模型反馈修正

This commit is contained in:
yujj128
2025-11-07 18:06:37 +08:00
parent d2bf6d71ad
commit 53905324fe
4 changed files with 18 additions and 12 deletions

View File

@@ -35,6 +35,7 @@ def _run_sql(state: DateReportAgentState) -> dict:
sql = state['sql']
retry_sql = state.get('retry_sql', '')
sql_correct = state.get('sql_correct', False)
logger.info(f'in _run_sql state={state}')
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")
try:
if retry_sql and not sql_correct:
@@ -63,8 +64,9 @@ def _feedback_qa(state: DateReportAgentState) -> dict:
feedback_temp = template['template']['result_feedback']
logger.info(f"feedback_temp is {feedback_temp}")
ddl_list = vn.get_related_ddl(user_question)
qa_list = vn.get_similar_question_sql(user_question)
sys_promot = feedback_temp['system'].format(question=user_question, sql=sql, sql_result=sql_result,
current_time=datetime.now(),ddl_list=ddl_list)
current_time=datetime.now(),ddl_list=ddl_list,qa_list=qa_list)
logger.info(f"system_temp is {sys_promot}")
result = gen_history_llm.invoke(sys_promot).text()
logger.info(f"feedback result: {result}")
@@ -72,12 +74,9 @@ def _feedback_qa(state: DateReportAgentState) -> dict:
result = orjson.loads(result)
logger.info(f"提取json成功")
logger.info(f"result is {result} type:{type(result)}")
print("result is {0}".format(result.keys()))
if not result["is_result_correct"] and (result["suggested_sql"]):
logger.info("开始替换")
logger.info(f"suggested_sql is ".format(result["suggested_sql"]))
# state["sql"] = result["suggested_sql"]
# state["retry_sql"] = True
logger.info(f"current state: {state}")
return {"retry_sql": result["suggested_sql"]}
if result["is_result_correct"]:
@@ -88,11 +87,13 @@ def _feedback_qa(state: DateReportAgentState) -> dict:
def run_sql_hande(state: DateReportAgentState) -> str:
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql_hande 节点")
logger.info(f'in _run_sql_handle state={state}')
# sql_error = state.get('run_sql_error', '')
data = state.get('data', {})
sql_correct = state.get('sql_correct', False)
sql_retry_count = state.get('retry_count', 0)
logger.info("sql_retry_count is {0}".format(sql_retry_count))
logger.info(f"sql_retry_count is {sql_retry_count} sql_correct is {sql_correct}")
if sql_retry_count < 2:
if sql_correct:
@@ -142,9 +143,10 @@ workflowx.add_node("_gen_report", _gen_report)
workflowx.add_edge(START, "_run_sql")
workflowx.add_edge("_feedback_qa", "_run_sql")
workflowx.add_edge("_gen_report", END)
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report','_feedback_qa'])
memory = MemorySaver()
result_report_agent = workflowx.compile(checkpointer=memory)
png_data=result_report_agent .get_graph().draw_mermaid_png()
with open("D://graph2.png", "wb") as f:
f.write(png_data)
# png_data=result_report_agent .get_graph().draw_mermaid_png()
# with open("D://graph2.png", "wb") as f:
# f.write(png_data)

View File

@@ -179,6 +179,6 @@ workflow.add_conditional_edges('_gen_sql', gen_sql_handler, ['_gen_sql', '_gen_c
workflow.add_conditional_edges('_gen_chart', gen_chart_handler, ['_gen_chart', END])
memory = MemorySaver()
sql_chart_agent = workflow.compile(checkpointer=memory)
png_data=sql_chart_agent .get_graph().draw_mermaid_png()
# png_data=sql_chart_agent .get_graph().draw_mermaid_png()
# with open("D://graph.png", "wb") as f:
# f.write(png_data)

View File

@@ -300,7 +300,7 @@ def gen_graph_question():
try:
user_id = request.args.get("user_id")
cvs_id = request.args.get("cvs_id")
config = {"configurable": {"thread_id": cvs_id}}
config = {"configurable": {"thread_id": uuid.uuid4()}}
question = flask.request.args.get("question")
question_context = get_latest_question(cvs_id, user_id,limit_count=2)
history = []
@@ -345,6 +345,7 @@ def run_sql_3():
qa = get_qa_by_id(id)
sql = qa["sql"]
question = qa["question"]
logger.info(f"in main sql {sql} question {question}")
logger.info("Start to run sql in main")
try:
user_id = request.args.get("user_id")
@@ -361,8 +362,10 @@ def run_sql_3():
"sql": sql,
"question": question,
"retry_count": 0,
"need_feedback": False,
}
config = {"configurable": {"thread_id": 'dsds'}}
config = {"configurable": {"thread_id": uuid.uuid4()}}
rr = result_report_agent.invoke(initial_state,config)
logger.info(f"rr.data is {rr.get('data', {})}")
return jsonify(

View File

@@ -636,11 +636,13 @@ template:
[执行结果]: <sql_result>{sql_result}</sql_result>。
[当前时间]: <current_time>{current_time}</current_time>
[表结构信息]:<schema>{ddl_list}</schema>
[问答参考]:<question_answer>{qa_list}</question_answer>
# 核心规则
请根据以上信息,进行全面、细致的反思,并最终判断这个结果是否能正确回答原始问题。你的反思需要包含以下几个步骤:
1. **核对问题理解**
回顾“原始问题”,确认其核心意图、时间范围、筛选条件和所需字段。
我生成的SQL是否准确捕捉了问题的所有关键要素有没有遗漏或误解
当sql和问题吻合且能查询出结果时一般判定为正确
2. **通用业务定义**
上周:指完整的**上周一到上周日**
3. **审查SQL逻辑**
@@ -650,7 +652,6 @@ template:
4. **评估结果合理性**
观察“执行结果”,思考这个结果是否符合业务常识或数据的基本特征(例如,数量级、正负值、范围是否合理)?
结果的列名和内容是否与问题期望的输出一致?
如果结果为空或数据量异常,推测可能的原因。
5. **最终判断和建议**
**结论** : 用一句话明确指出:“结果正确”、“结果可能不正确”或“无法完全确定”。
**原因** : 简练地解释你做出此判断的核心理由。