模型反馈修正

This commit is contained in:
yujj128
2025-11-07 15:30:27 +08:00
parent 3f7cd23bdc
commit 6fe29a51ce
6 changed files with 166 additions and 34 deletions

View File

@@ -1,12 +1,17 @@
import json
import logging
from datetime import datetime
from typing import TypedDict, Optional
import orjson
from anyio import current_time
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END, START
from graph_chat.gen_sql_chart_agent import gen_history_llm
from util.utils import extract_nested_json
from service.conversation_service import update_conversation
logger = logging.getLogger(__name__)
from template.template import get_base_template
@@ -20,41 +25,103 @@ class DateReportAgentState(TypedDict):
run_sql_error: str
retry_count: int
question:str
retry_sql:str
sql_correct:bool
def _run_sql(state: DateReportAgentState) -> dict:
    """Execute the SQL held in the graph state and return the rows as a state update.

    If a previous feedback pass stored a replacement query under ``retry_sql``
    and the result has not been marked correct (``sql_correct``), that query is
    executed instead of ``state['sql']`` and persisted with
    ``update_conversation``.

    Args:
        state: Current graph state. Reads ``sql`` (required) and, optionally,
            ``retry_sql``, ``sql_correct``, ``retry_count``, ``user_id``, ``id``.

    Returns:
        On success, a dict with ``data`` (the result of
        ``df.to_dict(orient='records')``), the incremented ``retry_count`` and
        the ``retry_sql`` value that was read from the state.
        On failure, a dict with the incremented ``retry_count`` and a
        user-facing ``run_sql_error`` message; the exception is logged, not
        re-raised.
    """
    # Imported inside the function — presumably to avoid a circular import
    # with main_service (TODO confirm).
    from main_service import vn

    original_sql = state['sql']
    sql = original_sql
    retry_sql = state.get('retry_sql', '')
    sql_correct = state.get('sql_correct', False)
    logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")
    # Every execution attempt counts, whether it succeeds or fails.
    retry = state.get("retry_count", 0) + 1
    try:
        if retry_sql and not sql_correct:
            # The feedback step proposed a corrected query: run it and store
            # it as the SQL of this conversation entry.
            sql = retry_sql
            update_conversation(id=state.get('id'), sql=sql)
        df = vn.run_sql_2(sql)
        result = df.to_dict(orient='records')
        logger.info(f"run sql result {result}")
        # Log the query the state started with, not the substituted one.
        logger.info(f"old_sql:{original_sql} retry sql {retry_sql}")
        return {'data': result, 'retry_count': retry, 'retry_sql': retry_sql}
    except Exception as e:
        logger.error(f"Error running sql: {sql}, error: {e}")
        return {'retry_count': retry, 'run_sql_error': '运行sql语句失败请检查语法或联系管理员。'}
def _feedback_qa(state: DateReportAgentState) -> Optional[dict]:
    """Ask the LLM to review the executed SQL and its result.

    The ``result_feedback`` system prompt is filled with the question, the SQL,
    the query result and the current time. The model's reply is expected to be
    a JSON object with at least ``is_result_correct`` and ``suggested_sql``.

    Args:
        state: Current graph state. Reads ``question``, ``sql`` and ``data``
            (all required) and ``user_id`` (optional).

    Returns:
        ``{"retry_sql": <sql>}`` when the model judges the result incorrect and
        proposes a replacement query; ``{"sql_correct": True}`` when it judges
        the result correct; ``None`` (no state update) when neither applies or
        when any step fails. Failures are logged, not raised.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _feedback_qa 节点")
    try:
        user_question = state['question']
        sql = state['sql']
        sql_result = state['data']

        template = get_base_template()
        feedback_temp = template['template']['result_feedback']
        logger.info(f"feedback_temp is {feedback_temp}")
        sys_promot = feedback_temp['system'].format(
            question=user_question,
            sql=sql,
            sql_result=sql_result,
            current_time=datetime.now(),
        )
        logger.info(f"system_temp is {sys_promot}")

        raw_reply = gen_history_llm.invoke(sys_promot).text()
        logger.info(f"feedback result: {raw_reply}")
        result = orjson.loads(extract_nested_json(raw_reply))
        logger.info("提取json成功")
        logger.info(f"result is {result} type:{type(result)}")
        logger.debug("result is {0}".format(result.keys()))

        # Use .get so a reply that omits a key is treated as "no verdict"
        # instead of raising KeyError into the broad handler below.
        is_correct = result.get("is_result_correct")
        suggested_sql = result.get("suggested_sql")

        if not is_correct and suggested_sql:
            logger.info("开始替换")
            logger.info(f"suggested_sql is {suggested_sql}")
            logger.info(f"current state: {state}")
            return {"retry_sql": suggested_sql}
        if is_correct:
            return {"sql_correct": True}
    except Exception as e:
        # Best-effort node: a failed review must not abort the graph run.
        logger.exception(f"Error feedback: {e}")
    return None
def run_sql_hande(state: DateReportAgentState) -> str:
    """Choose the node that follows ``_run_sql``.

    Only returns targets the graph declares for this router: ``'_gen_report'``,
    ``'_feedback_qa'`` or ``END``.

    Routing:
        * ``retry_count`` below 2: go to ``'_gen_report'`` if the feedback step
          has already marked the SQL correct (``sql_correct``); otherwise go to
          ``'_feedback_qa'`` for a review.
        * ``retry_count`` of 2 or more: go to ``'_gen_report'`` if any rows are
          available in ``data``; otherwise stop at ``END``.

    Args:
        state: Current graph state. Reads ``data``, ``sql_correct`` and
            ``retry_count``, all optional.

    Returns:
        The name of the next node, or ``END``.
    """
    data = state.get('data', {})
    sql_correct = state.get('sql_correct', False)
    sql_retry_count = state.get('retry_count', 0)
    logger.info("sql_retry_count is {0}".format(sql_retry_count))

    if sql_retry_count < 2:
        # Still within the retry budget: report if verified, else ask for review.
        return '_gen_report' if sql_correct else '_feedback_qa'

    # Retry budget spent: report whatever rows exist, or give up.
    return '_gen_report' if data else END
def _gen_report(state: DateReportAgentState) -> dict:
logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_report 节点")
sql = state['sql']
question = state['question']
run_sql_error = state.get('run_sql_error', '')
@@ -62,15 +129,20 @@ def _gen_report(state: DateReportAgentState) -> dict:
template = get_base_template()
result_summary = template['template']['result_summary']['system']
pro = result_summary.format(sql=sql, error_message=run_sql_error, question=question, data=data)
txt = gen_history_llm.invoke(pro).text
txt = gen_history_llm.invoke(pro).text()
return {'summary': txt}
# Result-report graph:
#   START -> _run_sql -> (run_sql_hande) -> _feedback_qa | _gen_report | END
#   _feedback_qa -> _run_sql, _gen_report -> END
workflowx = StateGraph(DateReportAgentState)
workflowx.add_node("_run_sql", _run_sql)
workflowx.add_node("_feedback_qa", _feedback_qa)
workflowx.add_node("_gen_report", _gen_report)
workflowx.add_edge(START, "_run_sql")
workflowx.add_edge("_feedback_qa", "_run_sql")
workflowx.add_edge("_gen_report", END)
# Register the router exactly once, listing every target it may return.
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report', '_feedback_qa'])
memory = MemorySaver()
result_report_agent = workflowx.compile(checkpointer=memory)

# Debug aid: dump a picture of the graph. This runs at import time and writes
# to a hard-coded Windows path, so it is best-effort only — a failure here
# (missing drive, unavailable renderer) must not prevent the module loading.
# NOTE(review): consider moving this behind a debug flag or a CLI entry point.
try:
    png_data = result_report_agent.get_graph().draw_mermaid_png()
    with open("D://graph2.png", "wb") as f:
        f.write(png_data)
except Exception as e:
    logger.warning(f"skip writing graph image: {e}")

View File

@@ -65,7 +65,7 @@ def _rewrite_user_question(state: SqlAgentState) -> dict:
rewrite_temp = template['template']['rewrite_question']
history = state.get('history', [])
sys_promot = rewrite_temp['system'].format(current_question=user_question, history=history)
new_question = gen_history_llm.invoke(sys_promot).text
new_question = gen_history_llm.invoke(sys_promot).text()
logger.info("new_question:{0}".format(new_question))
result = extract_nested_json(new_question)
logger.info(f"result:{result}")
@@ -93,7 +93,7 @@ def _gen_sql(state: SqlAgentState) -> dict:
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
rr = gen_sql_llm.invoke(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}]).text
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}]).text()
result = extract_nested_json(rr)
logger.info(f"gensql result: {result}")
result = orjson.loads(result)
@@ -120,7 +120,7 @@ def _gen_chart(state: SqlAgentState) -> dict:
lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
rr = gen_history_llm.invoke(
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]).text
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]).text()
retry = state.get("chart_retry_count", 0) + 1
return {"chart_retry_count": retry, 'gen_chart_result': orjson.loads(extract_nested_json(rr))}
except Exception as e:

View File

@@ -7,7 +7,7 @@ import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
from service.conversation_service import save_conversation,update_conversation,get_sql_by_id
from service.conversation_service import save_conversation,update_conversation,get_qa_by_id,get_latest_question
from decouple import config
import flask
from util import load_ddl_doc
@@ -133,7 +133,7 @@ def generate_sql_2():
data['id'] = id
sql = data["resp"]["sql"]
logger.info("generate sql is : "+ sql)
update_conversation(cvs_id, id, sql)
update_conversation(id, sql)
save_save_question_async(id, user_id, question, sql)
data["type"]="success"
return jsonify(data)
@@ -241,7 +241,8 @@ def run_sql_2():
logger.info("Start to run sql in main")
try:
id = request.args.get("id")
sql = get_sql_by_id(id)
qa = get_qa_by_id(id)
sql = qa["sql"]
logger.info(f"sql is {sql}")
if not vn.run_sql_is_set:
return jsonify(
@@ -297,18 +298,28 @@ def query_present_question():
@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
def gen_graph_question():
try:
config = {"configurable": {"thread_id": '1233'}}
user_id = request.args.get("user_id")
cvs_id = request.args.get("cvs_id")
config = {"configurable": {"thread_id": cvs_id}}
question = flask.request.args.get("question")
question_context = get_latest_question(cvs_id, user_id,limit_count=2)
history = []
for q in question_context:
history.append({"role":"user",'content':q})
logger.info(f"history is {history}")
initial_state: SqlAgentState = {
"user_id": user_id,
"user_question": question,
"history": [{"role":"user",'content':'谭杰明昨日打卡记录查询'}],
"history": history,
"sql_retry_count": 0,
"chart_retry_count": 0
}
result = sql_chart_agent.invoke(initial_state, config=config)
id = str(uuid.uuid4())
cache.set(id=id, field="question", value=result.get('rewritten_user_question',question))
result = sql_chart_agent.invoke(initial_state, config=config)
new_question = result.get('rewritten_user_question',question)
id = str(uuid.uuid4())
save_conversation(id, user_id, cvs_id, new_question)
# cache.set(id=id, field="question", value=result.get('rewritten_user_question',question))
data = {
'id': id,
'sql': result.get("gen_sql_result", {}),
@@ -316,7 +327,9 @@ def gen_graph_question():
'gen_sql_error': result.get("gen_sql_error", None),
'gen_chart_error': result.get("gen_chart_error", None),
}
cache.set(id=id, field="sql", value=data.get('sql', {}).get('sql', ''))
# cache.set(id=id, field="sql", value=data.get('sql', {}).get('sql', ''))
sql = data.get('sql', {}).get('sql', '')
update_conversation(id, sql)
return jsonify(data)
except Exception as e:
traceback.print_exc()
@@ -325,9 +338,13 @@ def gen_graph_question():
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_3", methods=["GET"])
@session_save
@app.requires_cache(["sql",'question'])
def run_sql_3(id: str, sql: str, question: str):
# @session_save
# @app.requires_cache(["sql",'question'])
def run_sql_3():
id = request.args.get("id")
qa = get_qa_by_id(id)
sql = qa["sql"]
question = qa["question"]
logger.info("Start to run sql in main")
try:
user_id = request.args.get("user_id")
@@ -347,6 +364,7 @@ def run_sql_3(id: str, sql: str, question: str):
}
config = {"configurable": {"thread_id": 'dsds'}}
rr = result_report_agent.invoke(initial_state,config)
logger.info(f"rr.data is {rr.get('data', {})}")
return jsonify(
{
'data': rr.get('data', {}),

View File

@@ -41,14 +41,14 @@ def get_conversation(cvs_id: str):
# }
def update_conversation(cvs_id: str, id: str, sql=None, meta=None):
def update_conversation(id: str, sql=None, meta=None):
"""更新sql到对应question"""
session = SqliteSqlalchemy().session
try:
if sql:
session.query(Conversation).filter(Conversation.cvs_id == cvs_id, Conversation.id == id).update({Conversation.sql: sql})
session.query(Conversation).filter(Conversation.id == id).update({Conversation.sql: sql})
if meta:
session.query(Conversation).filter(Conversation.cvs_id == cvs_id, Conversation.id == id).update({Conversation.meta: meta})
session.query(Conversation).filter(Conversation.id == id).update({Conversation.meta: meta})
session.commit()
except Exception as e:
session.rollback()
@@ -72,12 +72,12 @@ def get_latest_question(cvs_id, user_id, limit_count):
session.close()
def get_sql_by_id(id: str):
def get_qa_by_id(id: str):
session = SqliteSqlalchemy().session
try:
result = session.query(Conversation).filter_by(id=id).first()
if result:
return result.sql
return {"question":result.question, "sql":result.sql}
return None
except Exception as e:
logger.error(f"get_sql_by_id error {e}")

View File

@@ -247,7 +247,7 @@ class OpenAICompatibleLLM(VannaBase):
new_question = self.generate_rewritten_question(questions,**kwargs)
logger.info(f"new_question is {new_question}")
question = new_question if new_question else question
update_conversation(cvs_id, id, meta=question)
update_conversation(id, meta=question)
# if user_id and cache:
# history = cache.get(id=user_id, field="data")

View File

@@ -624,4 +624,46 @@ template:
# 最终输出
请根据以上规则,直接输出最终生成的回复话术,不要包含任何其他解释或格式。
result_feedback:
system: |
# 角色
你是一位经验丰富的数据分析师和SQL专家。你的核心任务是对SQL查询和其执行结果进行批判性审查。
# 你的任务
我(用户)向你提供了一系列信息,请你进行分析和反思。
# 输入信息
[原始问题]: <question>{question}</question>
[生成的SQL]: <sql>{sql}</sql>。
[执行结果]: <sql_result>{sql_result}</sql_result>。
[当前时间]: <current_time>{current_time}</current_time>
# 核心规则
请根据以上信息,进行全面、细致的反思,并最终判断这个结果是否能正确回答原始问题。你的反思需要包含以下几个步骤:
1. **核对问题理解**
回顾“原始问题”,确认其核心意图、时间范围、筛选条件和所需字段。
我生成的SQL是否准确捕捉了问题的所有关键要素?有没有遗漏或误解?
2. **通用业务定义**
上周:指完整的**上周一到上周日**
3. **审查SQL逻辑**
逐行分析“生成的SQL”,检查其语法结构是否正确。
关键子句(如 `WHERE`, `GROUP BY`, `HAVING`, `JOIN`)是否准确地反映了查询意图?
聚合函数(如 `SUM`, `COUNT`, `AVG`)的使用是否恰当?分组维度是否正确?
4. **评估结果合理性**
观察“执行结果”,思考这个结果是否符合业务常识或数据的基本特征(例如,数量级、正负值、范围是否合理)?
结果的列名和内容是否与问题期望的输出一致?
如果结果为空或数据量异常,推测可能的原因。
5. **最终判断和建议**
**结论** : 用一句话明确指出:“结果正确”、“结果可能不正确”或“无法完全确定”。
**原因** : 简练地解释你做出此判断的核心理由。
**改进建议** (如果结果不正确):
* 具体指出SQL中可能存在的逻辑错误。
* 给出一个或多个修改后的SQL版本。
* 解释你为什么这样修改。
# 输出格式
请严格按照以下JSON格式输出你的分析结果:
```json
{{
"conclusion": "结果正确 OR 结果可能不正确 OR 无法完全确定",
"reasoning": "对问题理解、SQL逻辑和结果合理性的综合分析说明。",
"is_result_correct": true OR false,
"suggested_sql": "如果结论为不正确,请在此处提供修改后的SQL;如果正确则为null。"
}}
Resources: