feat:图式节点-支持问数

This commit is contained in:
雷雨
2025-11-06 12:23:07 +08:00
parent 5edc62e9f7
commit e34a8b839f
7 changed files with 280 additions and 25 deletions

21
.env
View File

@@ -1,4 +1,4 @@
IS_FIRST_LOAD=True
IS_FIRST_LOAD=False
#CHAT_MODEL_BASE_URL=https://api.siliconflow.cn
#CHAT_MODEL_API_KEY=sk-iyhiltycmrfnhrnbljsgqjrinhbztwdplyvuhfihcdlepole
@@ -9,19 +9,26 @@ IS_FIRST_LOAD=True
#CHAT_MODEL_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
#CHAT_MODEL_API_KEY=sk-72575159d3ec43a68c6e222a15719bed
#CHAT_MODEL_NAME=qwen-plus
CHAT_MODEL_BASE_URL=https://api.siliconflow.cn
CHAT_MODEL_API_KEY=sk-tnmkzvzbipohjfbqxhictewzdgrrxoghbmicrfjgxbgdkjfq
CHAT_MODEL_NAME=Qwen/Qwen3-32B
CHAT_MODEL_BASE_URL=https://api.siliconflow.cn/v1
CHAT_MODEL_API_KEY=sk-jixyuwdltfawojgwywogkkdoqwpxblabprxybltlacbpqnip
CHAT_MODEL_NAME=zai-org/GLM-4.6
SQL_MODEL_BASE_URL=https://api.siliconflow.cn/v1
SQL_MODEL_API_KEY=sk-jixyuwdltfawojgwywogkkdoqwpxblabprxybltlacbpqnip
SQL_MODEL_NAME=Qwen/Qwen3-Coder-30B-A3B-Instruct
SQL_MODEL_TEMPERATURE=0.5
#使用ai中台的模型
EMBEDDING_MODEL_BASE_URL=http://10.225.128.2:13206/member1/small-model/bge/encode
EMBEDDING_MODEL_BASE_URL=https://api.siliconflow.cn
EMBEDDING_MODEL_API_KEY=sk-iyhiltycmrfnhrnbljsgqjrinhbztwdplyvuhfihcdlepole
EMBEDDING_MODEL_NAME=BAAI/bge-m3
#向量数据库
#type:memory/remote,如果设置为remote将IS_FIRST_LOAD 设置成false
QDRANT_TYPE=memory
QDRANT_TYPE=remote
QDRANT_DB_HOST=106.13.42.156
QDRANT_DB_PORT=16000
QDRANT_DB_PORT=33088
#mysql ,sqlite,pg等
DATA_SOURCE_TYPE=dameng

0
graph_chat/__init__.py Normal file
View File

View File

@@ -0,0 +1,178 @@
from langchain_openai import ChatOpenAI
from typing import TypedDict, Annotated, List, Optional, Union
import pandas as pd
from langgraph.types import Command, interrupt
from langgraph.graph import StateGraph, END, START
from langgraph.checkpoint.memory import MemorySaver
import orjson, logging
logger = logging.getLogger(__name__)
from decouple import config
from template.template import get_base_template
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
from datetime import datetime
from util import train_ddl
gen_history_llm = ChatOpenAI(
model=config('CHAT_MODEL_NAME', default=''),
base_url=config('CHAT_MODEL_BASE_URL', default=''),
temperature=config('SQL_MODEL_TEMPERATURE', default=0.5, cast=float),
max_tokens=8000,
max_retries=2,
api_key=config('CHAT_MODEL_API_KEY', default='empty'),
)
gen_sql_llm = ChatOpenAI(
model=config('SQL_MODEL_NAME', default=''),
base_url=config('SQL_MODEL_BASE_URL', default=''),
temperature=config('SQL_MODEL_TEMPERATURE', default=0.5, cast=float),
max_tokens=8000,
max_retries=2,
api_key=config('SQL_MODEL_API_KEY', default='empty'),
)
'''
用于生成sql,生成图表的agent上下文
'''
class SqlAgentState(TypedDict):
user_question: str
rewritten_user_question: Optional[str]
session_id: str
user_id: str
history: Optional[list]
gen_sql_result: Optional[dict]
gen_chart_result: Optional[dict]
gen_sql_error: Optional[str]
gen_chart_error: Optional[str]
final_error_msg: Optional[str]
sql_retry_count: int
chart_retry_count: int
'''
基于上下文,历史消息,重写用户问题
'''
def _rewrite_user_question(state: SqlAgentState) -> dict:
logger.info(f"user:{state.get('user_id', '1')} 进入 _rewrite_user_question 节点")
user_question = state['user_question']
template = get_base_template()
rewrite_temp = template['template']['rewrite_question']
history = state.get('history', [])
sys_promot = rewrite_temp['system'].format(current_question=user_question, history=history)
new_question = gen_history_llm.invoke(sys_promot).text
result = extract_nested_json(new_question)
logger.info(f"result:{result}")
result = orjson.loads(result)
return {"rewritten_user_question": result.get('question', user_question)}
def _gen_sql(state: SqlAgentState) -> dict:
from main_service import vn
logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_sql 节点")
question = state.get('rewritten_user_question', state['user_question'])
service = vn
question_sql_list = service.get_similar_question_sql(question)
if question_sql_list and len(question_sql_list) > 3:
question_sql_list = question_sql_list[:3]
ddl_list = service.get_related_ddl(question)
template = get_base_template()
sql_temp = template['template']['sql']
history = []
try:
sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
schema=ddl_list, documentation=[train_ddl.train_document],
history=history, retrieved_examples_data=question_sql_list,
data_training=question_sql_list, )
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
rr = gen_sql_llm.invoke(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}]).text
result = extract_nested_json(rr)
logger.info(f"gensql result: {result}")
result = orjson.loads(result)
retry = state.get("sql_retry_count", 0) + 1
return {'gen_sql_result': result,"sql_retry_count": retry}
except Exception as e:
import traceback
traceback.print_exc()
retry = state.get("sql_retry_count", 0) + 1
logger.info("cus_vanna_srevice failed-------------------: ")
return {'gen_sql_error': str(e), "sql_retry_count": retry}
def _gen_chart(state: SqlAgentState) -> dict:
logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_chart 节点")
template = get_base_template()
gen_sql_result=state.get('gen_sql_result', {})
sql = gen_sql_result.get('sql', '')
char_type = gen_sql_result.get('char_type', '')
question = state.get('rewritten_user_question', state['user_question'])
char_temp = template['template']['chart']
try:
sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'),
lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
rr=gen_history_llm.invoke([{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]).text
retry = state.get("chart_retry_count", 0) + 1
return {"chart_retry_count": retry,'gen_chart_result': orjson.loads(extract_nested_json(rr))}
except Exception as e:
import traceback
traceback.print_exc()
retry = state.get("chart_retry_count", 0) + 1
return {'gen_chart_error': str(e), "chart_retry_count": retry}
# 如果生成sql失败则重试2次如果仍然失败则返回错误信息
def gen_sql_handler(state: SqlAgentState) -> str:
sql_error = state.get('gen_sql_error', '')
sql_result = state.get('gen_sql_result', {})
sql_retry_count = state.get('sql_retry_count', 0)
if sql_error and len(sql_error) > 0:
if sql_retry_count < 2:
return '_gen_sql'
else:
return END
sql = sql_result.get('sql', '')
if len(sql) > 0:
return '_gen_chart'
else:
if sql_retry_count < 2:
return '_gen_sql'
else:
return END
def gen_chart_handler(state: SqlAgentState) -> str:
chart_error = state.get('gen_chart_error', '')
chart_result = state.get('gen_chart_result', {})
sql_retry_count = state.get('chart_retry_count', 0)
if chart_error and len(chart_error) > 0:
if sql_retry_count < 2:
return '_gen_chart'
else:
return END
title = chart_result.get('title', '')
if len(title) > 0:
return END
else:
if sql_retry_count < 2:
return '_gen_chart'
else:
return END
workflow = StateGraph(SqlAgentState)
workflow.add_node("_rewrite_user_question", _rewrite_user_question)
workflow.add_node("_gen_sql", _gen_sql)
workflow.add_node("_gen_chart",_gen_chart)
workflow.add_edge(START, "_rewrite_user_question")
workflow.add_edge("_rewrite_user_question", "_gen_sql")
workflow.add_conditional_edges('_gen_sql', gen_sql_handler, ['_gen_sql', '_gen_chart',END])
workflow.add_conditional_edges('_gen_chart', gen_chart_handler, [ '_gen_chart',END])
memory = MemorySaver()
sql_chart_agent = workflow.compile(checkpointer=memory)

View File

@@ -4,15 +4,15 @@ from functools import wraps
import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
from service.question_feedback_service import save_save_question_async, query_predefined_question_list
from decouple import config
import flask
from util import load_ddl_doc
from flask import Flask, Response, jsonify, request
from graph_chat.gen_sql_chart_agent import SqlAgentState, sql_chart_agent
import traceback
logger = logging.getLogger(__name__)
def connect_database(vn):
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
if db_type == 'sqlite':
@@ -59,7 +59,7 @@ def create_vana():
def init_vn(vn):
logger.info("--------------init vana-----connect to datasouce db----")
connect_database(vn)
#connect_database(vn)
if config('IS_FIRST_LOAD', default=False, cast=bool):
load_ddl_doc.add_ddl(vn)
load_ddl_doc.add_documentation(vn)
@@ -240,11 +240,36 @@ def verify_user():
@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"])
def query_present_question():
try:
data=query_predefined_question_list()
return jsonify({"type": "success", "data": data})
data = query_predefined_question_list()
return jsonify({"type": "success", "data": data})
except Exception as e:
logger.error(f"查询预制问题失败 failed:{e}")
return jsonify({"type": "error", "error":f'查询预制问题失败:{str(e)}'})
return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'})
@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
def gen_graph_question():
try:
config = {"configurable": {"thread_id": '1233'}}
question = flask.request.args.get("question")
initial_state: SqlAgentState = {
"user_question": question,
"history": [{"role":"user","content":"宋亚澜9月在林芝工作多少天"},{"role":"user","content":"余佳佳9月在林芝工作多少天"}],
"sql_retry_count": 0,
"chart_retry_count": 0
}
result=sql_chart_agent.invoke(initial_state, config=config)
data={
'sql':result.get("gen_sql_result",{}),
'chart':result.get("gen_chart_result",{}),
'gen_sql_error':result.get("gen_sql_error",None),
'gen_chart_error':result.get("gen_chart_error",None),
}
return jsonify(data)
except Exception as e:
traceback.print_exc()
logger.error(f"查询预制问题失败 failed:{e}")
return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'})
if __name__ == '__main__':

View File

@@ -2,4 +2,6 @@ vanna ==0.7.9
vanna[openai]
vanna[qdrant]
python-decouple==3.8
dmPython==2.5.22
dmPython==2.5.22
langgraph
langchain-openai

View File

@@ -331,15 +331,15 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
"model": self.embedding_model_name,
"sentences": [data],
}
# request_body = {
# "model": self.embedding_model_name,
# "encoding_format": "float",
# "input": [data],
# }
request_body = {
"model": self.embedding_model_name,
"encoding_format": "float",
"input": [data],
}
request_body.update(kwargs)
response = requests.post(
url=f"{self.embedding_api_base}",
url=f"{self.embedding_api_base}/v1/embeddings",
json=request_body,
headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
)
@@ -348,10 +348,10 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
f"Failed to create the embeddings, detail: {_get_error_string(response)}"
)
result = response.json()
embeddings = result['embeddings']
# embeddings = result['data'][0]['embedding']
return embeddings[0]
# return embeddings
#embeddings = result['embeddings']
embeddings = result['data'][0]['embedding']
#return embeddings[0]
return embeddings
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
def __init__(self, llm_config=None, vector_store_config=None):

View File

@@ -538,4 +538,47 @@ template:
{sub_query}
rewrite_question:
system: |
# 角色
你是一个专业的智能数据分析助手的对话上下文理解模块。
# 任务
你的核心任务是判断用户的当前问题是否依赖于之前的对话历史,并据此生成一个独立的、可供数据分析系统直接执行的问题。
# 核心规则
1. **识别关联性**
* 分析当前问题是否包含指代词(如“它”、“这个”、“那个”)、简称(如“去年同期”、“此产品”)、或省略的主语/宾语。
* 如果存在上述特征,且这些信息能在历史对话中找到明确的对应项,则判定为**关联**。
* 如果当前问题是一个全新的、独立的查询,与历史话题无关,则判定为**不关联**。
2. **生成最终问题**
* **如果判定为【关联】**:你必须将历史对话中的相关上下文信息**融合**到当前问题中,形成一个**完整清晰、无任何指代或歧义的新问题**。
* **如果判定为【不关联】**:你只需**原样输出**当前问题。
# 输出格式要求
- **你的唯一输出必须是一个JSON对象**。
- **严禁**在JSON前后添加任何解释性文字、代码块标记如 ```json或任何其他内容。
- JSON的键固定为 `question`,值为处理后的最终问题字符串。
**最终输出格式示例**
{{"question": "A产品上个月的成本是多少"}}
<examples>
## 示例1
**对话历史:**
Q: B产品上月的销量是 A: 200件
**当前用户问题:**
那利润率呢?
**输出:**
{{"question": "B产品上月的利润率是多少"}}
## 示例2
**对话历史:**
Q: 列出所有华东区的销售员 A: 张三、李四、王五
**当前用户问题:**
帮我预定一张明天去上海的机票
**输出:**
{{"question": "帮我预定一张明天去上海的机票"}}
</examples>
####
# 上下文信息
**对话历史:**
{history}
**当前用户问题:**
{current_question}
Resources: