feat:图式节点-支持问数
This commit is contained in:
21
.env
21
.env
@@ -1,4 +1,4 @@
|
||||
IS_FIRST_LOAD=True
|
||||
IS_FIRST_LOAD=False
|
||||
|
||||
#CHAT_MODEL_BASE_URL=https://api.siliconflow.cn
|
||||
#CHAT_MODEL_API_KEY=sk-iyhiltycmrfnhrnbljsgqjrinhbztwdplyvuhfihcdlepole
|
||||
@@ -9,19 +9,26 @@ IS_FIRST_LOAD=True
|
||||
#CHAT_MODEL_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
#CHAT_MODEL_API_KEY=sk-72575159d3ec43a68c6e222a15719bed
|
||||
#CHAT_MODEL_NAME=qwen-plus
|
||||
CHAT_MODEL_BASE_URL=https://api.siliconflow.cn
|
||||
CHAT_MODEL_API_KEY=sk-tnmkzvzbipohjfbqxhictewzdgrrxoghbmicrfjgxbgdkjfq
|
||||
CHAT_MODEL_NAME=Qwen/Qwen3-32B
|
||||
CHAT_MODEL_BASE_URL=https://api.siliconflow.cn/v1
|
||||
CHAT_MODEL_API_KEY=sk-jixyuwdltfawojgwywogkkdoqwpxblabprxybltlacbpqnip
|
||||
CHAT_MODEL_NAME=zai-org/GLM-4.6
|
||||
|
||||
SQL_MODEL_BASE_URL=https://api.siliconflow.cn/v1
|
||||
SQL_MODEL_API_KEY=sk-jixyuwdltfawojgwywogkkdoqwpxblabprxybltlacbpqnip
|
||||
SQL_MODEL_NAME=Qwen/Qwen3-Coder-30B-A3B-Instruct
|
||||
SQL_MODEL_TEMPERATURE=0.5
|
||||
|
||||
|
||||
#使用ai中台的模型
|
||||
EMBEDDING_MODEL_BASE_URL=http://10.225.128.2:13206/member1/small-model/bge/encode
|
||||
EMBEDDING_MODEL_BASE_URL=https://api.siliconflow.cn
|
||||
EMBEDDING_MODEL_API_KEY=sk-iyhiltycmrfnhrnbljsgqjrinhbztwdplyvuhfihcdlepole
|
||||
EMBEDDING_MODEL_NAME=BAAI/bge-m3
|
||||
|
||||
#向量数据库
|
||||
#type:memory/remote,如果设置为remote,将IS_FIRST_LOAD 设置成false
|
||||
QDRANT_TYPE=memory
|
||||
QDRANT_TYPE=remote
|
||||
QDRANT_DB_HOST=106.13.42.156
|
||||
QDRANT_DB_PORT=16000
|
||||
QDRANT_DB_PORT=33088
|
||||
|
||||
#mysql ,sqlite,pg等
|
||||
DATA_SOURCE_TYPE=dameng
|
||||
|
||||
0
graph_chat/__init__.py
Normal file
0
graph_chat/__init__.py
Normal file
178
graph_chat/gen_sql_chart_agent.py
Normal file
178
graph_chat/gen_sql_chart_agent.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from langchain_openai import ChatOpenAI
|
||||
from typing import TypedDict, Annotated, List, Optional, Union
|
||||
import pandas as pd
|
||||
from langgraph.types import Command, interrupt
|
||||
from langgraph.graph import StateGraph, END, START
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
import orjson, logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from decouple import config
|
||||
from template.template import get_base_template
|
||||
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
|
||||
from datetime import datetime
|
||||
from util import train_ddl
|
||||
|
||||
def _build_llm(env_prefix: str) -> ChatOpenAI:
    """Create a ChatOpenAI client from the ``<env_prefix>_MODEL_*`` settings.

    The model name, base URL and API key are read from
    ``<env_prefix>_MODEL_NAME``, ``<env_prefix>_MODEL_BASE_URL`` and
    ``<env_prefix>_MODEL_API_KEY``.  The temperature is always read from
    ``SQL_MODEL_TEMPERATURE`` (default 0.5), for every prefix.
    """
    # NOTE(review): the chat client also takes its temperature from
    # SQL_MODEL_TEMPERATURE -- confirm this is intended.
    sampling_temperature = config('SQL_MODEL_TEMPERATURE', default=0.5, cast=float)
    return ChatOpenAI(
        model=config(f'{env_prefix}_MODEL_NAME', default=''),
        base_url=config(f'{env_prefix}_MODEL_BASE_URL', default=''),
        temperature=sampling_temperature,
        max_tokens=8000,
        max_retries=2,
        api_key=config(f'{env_prefix}_MODEL_API_KEY', default='empty'),
    )


# Client used for question rewriting and chart generation.
gen_history_llm = _build_llm('CHAT')
# Client used for SQL generation.
gen_sql_llm = _build_llm('SQL')
|
||||
|
||||
'''
|
||||
用于生成sql,生成图表的agent,上下文
|
||||
'''
|
||||
|
||||
|
||||
class SqlAgentState(TypedDict):
    """State passed between the nodes of the SQL/chart generation graph."""

    # --- request -----------------------------------------------------------
    # Question exactly as the user asked it.
    user_question: str
    # Standalone question written by ``_rewrite_user_question``.
    rewritten_user_question: Optional[str]
    session_id: str
    user_id: str
    # Earlier conversation messages used when rewriting the question.
    history: Optional[list]

    # --- node outputs ------------------------------------------------------
    # Parsed JSON answer of the SQL model (the handlers read its ``sql`` key).
    gen_sql_result: Optional[dict]
    # Parsed JSON answer of the chart model (the handler reads its ``title`` key).
    gen_chart_result: Optional[dict]

    # --- errors ------------------------------------------------------------
    gen_sql_error: Optional[str]
    gen_chart_error: Optional[str]
    final_error_msg: Optional[str]

    # --- attempt counters, incremented by the generating nodes ---------------
    sql_retry_count: int
    chart_retry_count: int
|
||||
|
||||
|
||||
'''
|
||||
基于上下文,历史消息,重写用户问题
|
||||
'''
|
||||
|
||||
|
||||
def _rewrite_user_question(state: SqlAgentState) -> dict:
    """Rewrite the user's question into a standalone question.

    Formats the ``rewrite_question`` system prompt with the current question
    and the conversation history, sends it to ``gen_history_llm`` and reads
    the ``question`` key of the JSON object found in the reply.

    Rewriting is a best-effort step: if the model call fails, the reply
    cannot be parsed, or the reply has no usable ``question``, the original
    question is returned unchanged so the graph can still proceed.

    Returns:
        A state update containing ``rewritten_user_question``.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _rewrite_user_question 节点")
    user_question = state['user_question']
    # ``history`` may be missing or explicitly None; render it as an empty list.
    history = state.get('history') or []
    try:
        rewrite_temp = get_base_template()['template']['rewrite_question']
        sys_prompt = rewrite_temp['system'].format(current_question=user_question, history=history)
        raw_answer = gen_history_llm.invoke(sys_prompt).text
        payload = extract_nested_json(raw_answer)
        logger.info(f"result:{payload}")
        parsed = orjson.loads(payload)
    except Exception:
        # Node boundary: log with traceback and degrade to the original question.
        logger.exception("_rewrite_user_question failed, falling back to the original question")
        return {"rewritten_user_question": user_question}

    rewritten = parsed.get('question') if isinstance(parsed, dict) else None
    return {"rewritten_user_question": rewritten or user_question}
|
||||
|
||||
|
||||
def _gen_sql(state: SqlAgentState) -> dict:
    """Generate SQL for the (rewritten) user question.

    Retrieves up to three similar question/SQL examples and the related DDL
    from the ``vn`` service, formats the ``sql`` prompt templates, calls
    ``gen_sql_llm`` and parses the JSON object found in the reply.

    Every call increments ``sql_retry_count``.  On success the parsed reply
    is stored in ``gen_sql_result`` and ``gen_sql_error`` is reset to None,
    so an error left by an earlier failed attempt does not make
    ``gen_sql_handler`` treat a successful retry as a failure.  On any
    exception (retrieval, prompt formatting, model call, JSON parsing) the
    message is stored in ``gen_sql_error`` so the handler can retry.

    Returns:
        A state update with ``sql_retry_count`` and either
        ``gen_sql_result`` (plus a cleared ``gen_sql_error``) or
        ``gen_sql_error``.
    """
    # Imported inside the function, presumably to avoid a circular import
    # with main_service -- TODO confirm.
    from main_service import vn

    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_sql 节点")
    attempt = state.get("sql_retry_count", 0) + 1
    # Fall back to the original question when no rewritten one is available.
    question = state.get('rewritten_user_question') or state['user_question']
    try:
        question_sql_list = vn.get_similar_question_sql(question)
        if question_sql_list and len(question_sql_list) > 3:
            question_sql_list = question_sql_list[:3]
        ddl_list = vn.get_related_ddl(question)
        sql_temp = get_base_template()['template']['sql']
        # NOTE(review): DB_ENGINE defaults to 'mysql'; confirm it is configured
        # for the data source actually in use.
        sys_temp = sql_temp['system'].format(
            engine=config("DB_ENGINE", default='mysql'),
            lang='中文',
            schema=ddl_list,
            documentation=[train_ddl.train_document],
            history=[],
            retrieved_examples_data=question_sql_list,
            data_training=question_sql_list,
        )
        user_temp = sql_temp['user'].format(
            question=question,
            current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        )
        answer = gen_sql_llm.invoke(
            [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}]
        ).text
        payload = extract_nested_json(answer)
        logger.info(f"gensql result: {payload}")
        return {
            'gen_sql_result': orjson.loads(payload),
            # Clear any error recorded by a previous failed attempt.
            'gen_sql_error': None,
            "sql_retry_count": attempt,
        }
    except Exception as e:
        logger.exception("_gen_sql failed on attempt %s", attempt)
        # Never store an empty message: the handler detects failure by truthiness.
        return {'gen_sql_error': str(e) or repr(e), "sql_retry_count": attempt}
|
||||
|
||||
|
||||
def _gen_chart(state: SqlAgentState) -> dict:
    """Generate a chart configuration for the SQL produced by ``_gen_sql``.

    Formats the ``chart`` prompt templates with the generated SQL, the chart
    type taken from ``gen_sql_result`` and the question, calls
    ``gen_history_llm`` and parses the JSON object found in the reply.

    Every call increments ``chart_retry_count``.  On success the parsed
    reply is stored in ``gen_chart_result`` and ``gen_chart_error`` is reset
    to None, so an error left by an earlier failed attempt does not make
    ``gen_chart_handler`` treat a successful retry as a failure.  On any
    exception the message is stored in ``gen_chart_error``.

    Returns:
        A state update with ``chart_retry_count`` and either
        ``gen_chart_result`` (plus a cleared ``gen_chart_error``) or
        ``gen_chart_error``.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_chart 节点")
    attempt = state.get("chart_retry_count", 0) + 1
    # Tolerate a missing or None SQL result instead of raising AttributeError.
    gen_sql_result = state.get('gen_sql_result') or {}
    sql = gen_sql_result.get('sql', '')
    # NOTE(review): the key read here is 'char_type' -- confirm it matches the
    # key the SQL prompt asks the model to emit.
    char_type = gen_sql_result.get('char_type', '')
    # Fall back to the original question when no rewritten one is available.
    question = state.get('rewritten_user_question') or state['user_question']
    try:
        char_temp = get_base_template()['template']['chart']
        sys_char_temp = char_temp['system'].format(
            engine=config("DB_ENGINE", default='mysql'),
            lang='中文',
            sql=sql,
            chart_type=char_type,
        )
        user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
        answer = gen_history_llm.invoke(
            [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]
        ).text
        return {
            "chart_retry_count": attempt,
            'gen_chart_result': orjson.loads(extract_nested_json(answer)),
            # Clear any error recorded by a previous failed attempt.
            'gen_chart_error': None,
        }
    except Exception as e:
        logger.exception("_gen_chart failed on attempt %s", attempt)
        # Never store an empty message: the handler detects failure by truthiness.
        return {'gen_chart_error': str(e) or repr(e), "chart_retry_count": attempt}
|
||||
|
||||
|
||||
# Routing after ``_gen_sql``: retry while fewer than 2 attempts have been made,
# otherwise stop and leave the error in the state.
def gen_sql_handler(state: SqlAgentState) -> str:
    """Choose the node that follows ``_gen_sql``.

    Returns:
        ``'_gen_chart'`` when no error is recorded and the SQL result holds a
        non-empty ``sql`` value; ``'_gen_sql'`` when the attempt failed or
        produced no SQL and ``sql_retry_count`` is still below 2; ``END``
        otherwise.
    """
    can_retry = (state.get('sql_retry_count') or 0) < 2
    if state.get('gen_sql_error'):
        return '_gen_sql' if can_retry else END

    # The result may be missing, None, or not a JSON object; treat all of
    # these as "no SQL produced" instead of raising.
    sql_result = state.get('gen_sql_result')
    sql = sql_result.get('sql') if isinstance(sql_result, dict) else None
    if sql:
        return '_gen_chart'
    return '_gen_sql' if can_retry else END
|
||||
|
||||
def gen_chart_handler(state: SqlAgentState) -> str:
    """Choose the node that follows ``_gen_chart``.

    A chart result counts as successful when it holds a non-empty ``title``.

    Returns:
        ``END`` when no error is recorded and the chart result has a title;
        ``'_gen_chart'`` when the attempt failed or produced no title and
        ``chart_retry_count`` is still below 2; ``END`` otherwise.
    """
    can_retry = (state.get('chart_retry_count') or 0) < 2
    if state.get('gen_chart_error'):
        return '_gen_chart' if can_retry else END

    # The result may be missing, None, or not a JSON object; treat all of
    # these as "no chart produced" instead of raising.
    chart_result = state.get('gen_chart_result')
    title = chart_result.get('title') if isinstance(chart_result, dict) else None
    if title:
        return END
    return '_gen_chart' if can_retry else END
|
||||
|
||||
|
||||
# In-memory checkpointer handed to the compiled graph.
memory = MemorySaver()

# Graph: START -> _rewrite_user_question -> _gen_sql -> _gen_chart -> END.
# The two generating nodes may loop back to themselves; the handlers decide.
workflow = StateGraph(SqlAgentState)

workflow.add_node("_rewrite_user_question", _rewrite_user_question)
workflow.add_node("_gen_sql", _gen_sql)
workflow.add_node("_gen_chart", _gen_chart)

workflow.add_edge(START, "_rewrite_user_question")
workflow.add_edge("_rewrite_user_question", "_gen_sql")
workflow.add_conditional_edges(
    '_gen_sql',
    gen_sql_handler,
    ['_gen_sql', '_gen_chart', END],
)
workflow.add_conditional_edges(
    '_gen_chart',
    gen_chart_handler,
    ['_gen_chart', END],
)

sql_chart_agent = workflow.compile(checkpointer=memory)
|
||||
@@ -4,15 +4,15 @@ from functools import wraps
|
||||
import util.utils
|
||||
from logging_config import LOGGING_CONFIG
|
||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
|
||||
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
|
||||
from service.question_feedback_service import save_save_question_async, query_predefined_question_list
|
||||
from decouple import config
|
||||
import flask
|
||||
from util import load_ddl_doc
|
||||
from flask import Flask, Response, jsonify, request
|
||||
|
||||
from graph_chat.gen_sql_chart_agent import SqlAgentState, sql_chart_agent
|
||||
import traceback
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def connect_database(vn):
|
||||
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
|
||||
if db_type == 'sqlite':
|
||||
@@ -59,7 +59,7 @@ def create_vana():
|
||||
|
||||
def init_vn(vn):
|
||||
logger.info("--------------init vana-----connect to datasouce db----")
|
||||
connect_database(vn)
|
||||
#connect_database(vn)
|
||||
if config('IS_FIRST_LOAD', default=False, cast=bool):
|
||||
load_ddl_doc.add_ddl(vn)
|
||||
load_ddl_doc.add_documentation(vn)
|
||||
@@ -240,11 +240,36 @@ def verify_user():
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"])
|
||||
def query_present_question():
|
||||
try:
|
||||
data=query_predefined_question_list()
|
||||
return jsonify({"type": "success", "data": data})
|
||||
data = query_predefined_question_list()
|
||||
return jsonify({"type": "success", "data": data})
|
||||
except Exception as e:
|
||||
logger.error(f"查询预制问题失败 failed:{e}")
|
||||
return jsonify({"type": "error", "error":f'查询预制问题失败:{str(e)}'})
|
||||
return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'})
|
||||
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
def gen_graph_question():
    """Run the SQL/chart generation graph for one question.

    Query parameters:
        question: the user's question (required).
        session_id: optional identifier used as the graph ``thread_id``;
            when absent a fresh one is generated for this request.

    Returns:
        A JSON object with ``sql``, ``chart``, ``gen_sql_error`` and
        ``gen_chart_error`` on completion, or
        ``{"type": "error", "error": ...}`` when the question is missing or
        the graph raises.
    """
    import uuid

    question = flask.request.args.get("question")
    if not question or not question.strip():
        return jsonify({"type": "error", "error": "question 参数不能为空"})
    try:
        # A constant thread id would make every caller share one checkpointed
        # state; use the caller's session or a per-request id instead.
        thread_id = flask.request.args.get("session_id") or uuid.uuid4().hex
        # Named ``run_config`` so it does not shadow ``decouple.config``.
        run_config = {"configurable": {"thread_id": thread_id}}
        initial_state: SqlAgentState = {
            "user_question": question,
            "session_id": thread_id,
            # TODO(review): this history is hard-coded sample data; replace it
            # with the caller's real conversation history.
            "history": [{"role":"user","content":"宋亚澜9月在林芝工作多少天"},{"role":"user","content":"余佳佳9月在林芝工作多少天"}],
            "sql_retry_count": 0,
            "chart_retry_count": 0,
            # Reset outputs so values checkpointed by an earlier run on the
            # same thread cannot leak into this one.
            "rewritten_user_question": None,
            "gen_sql_result": None,
            "gen_chart_result": None,
            "gen_sql_error": None,
            "gen_chart_error": None,
        }
        result = sql_chart_agent.invoke(initial_state, config=run_config)
        data = {
            'sql': result.get("gen_sql_result") or {},
            'chart': result.get("gen_chart_result") or {},
            'gen_sql_error': result.get("gen_sql_error", None),
            'gen_chart_error': result.get("gen_chart_error", None),
        }
        return jsonify(data)
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"生成SQL/图表失败 failed:{e}")
        return jsonify({"type": "error", "error": f'生成SQL/图表失败:{str(e)}'})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -2,4 +2,6 @@ vanna ==0.7.9
|
||||
vanna[openai]
|
||||
vanna[qdrant]
|
||||
python-decouple==3.8
|
||||
dmPython==2.5.22
|
||||
dmPython==2.5.22
|
||||
langgraph
|
||||
langchain-openai
|
||||
|
||||
@@ -331,15 +331,15 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
|
||||
"model": self.embedding_model_name,
|
||||
"sentences": [data],
|
||||
}
|
||||
# request_body = {
|
||||
# "model": self.embedding_model_name,
|
||||
# "encoding_format": "float",
|
||||
# "input": [data],
|
||||
# }
|
||||
request_body = {
|
||||
"model": self.embedding_model_name,
|
||||
"encoding_format": "float",
|
||||
"input": [data],
|
||||
}
|
||||
request_body.update(kwargs)
|
||||
|
||||
response = requests.post(
|
||||
url=f"{self.embedding_api_base}",
|
||||
url=f"{self.embedding_api_base}/v1/embeddings",
|
||||
json=request_body,
|
||||
headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
|
||||
)
|
||||
@@ -348,10 +348,10 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
|
||||
f"Failed to create the embeddings, detail: {_get_error_string(response)}"
|
||||
)
|
||||
result = response.json()
|
||||
embeddings = result['embeddings']
|
||||
# embeddings = result['data'][0]['embedding']
|
||||
return embeddings[0]
|
||||
# return embeddings
|
||||
#embeddings = result['embeddings']
|
||||
embeddings = result['data'][0]['embedding']
|
||||
#return embeddings[0]
|
||||
return embeddings
|
||||
|
||||
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
|
||||
def __init__(self, llm_config=None, vector_store_config=None):
|
||||
|
||||
@@ -538,4 +538,47 @@ template:
|
||||
{sub_query}
|
||||
|
||||
|
||||
rewrite_question:
|
||||
system: |
|
||||
# 角色
|
||||
你是一个专业的智能数据分析助手的对话上下文理解模块。
|
||||
# 任务
|
||||
你的核心任务是判断用户的当前问题是否依赖于之前的对话历史,并据此生成一个独立的、可供数据分析系统直接执行的问题。
|
||||
# 核心规则
|
||||
1. **识别关联性**:
|
||||
* 分析当前问题是否包含指代词(如“它”、“这个”、“那个”)、简称(如“去年同期”、“此产品”)、或省略的主语/宾语。
|
||||
* 如果存在上述特征,且这些信息能在历史对话中找到明确的对应项,则判定为**关联**。
|
||||
* 如果当前问题是一个全新的、独立的查询,与历史话题无关,则判定为**不关联**。
|
||||
2. **生成最终问题**:
|
||||
* **如果判定为【关联】**:你必须将历史对话中的相关上下文信息**融合**到当前问题中,形成一个**完整清晰、无任何指代或歧义的新问题**。
|
||||
* **如果判定为【不关联】**:你只需**原样输出**当前问题。
|
||||
# 输出格式要求
|
||||
- **你的唯一输出必须是一个JSON对象**。
|
||||
- **严禁**在JSON前后添加任何解释性文字、代码块标记(如 ```json)或任何其他内容。
|
||||
- JSON的键固定为 `question`,值为处理后的最终问题字符串。
|
||||
**最终输出格式示例**:
|
||||
{{"question": "A产品上个月的成本是多少?"}}
|
||||
<examples>
|
||||
## 示例1
|
||||
**对话历史:**
|
||||
Q: B产品上月的销量是? A: 200件
|
||||
**当前用户问题:**
|
||||
那利润率呢?
|
||||
**输出:**
|
||||
{{"question": "B产品上月的利润率是多少?"}}
|
||||
## 示例2
|
||||
**对话历史:**
|
||||
Q: 列出所有华东区的销售员 A: 张三、李四、王五
|
||||
**当前用户问题:**
|
||||
帮我预定一张明天去上海的机票
|
||||
**输出:**
|
||||
{{"question": "帮我预定一张明天去上海的机票"}}
|
||||
</examples>
|
||||
####
|
||||
# 上下文信息
|
||||
**对话历史:**
|
||||
{history}
|
||||
**当前用户问题:**
|
||||
{current_question}
|
||||
|
||||
Resources:
|
||||
Reference in New Issue
Block a user