From 53905324feb4ad0452f9f85735a88b1526bc6a00 Mon Sep 17 00:00:00 2001 From: yujj128 Date: Fri, 7 Nov 2025 18:06:37 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=8F=8D=E9=A6=88=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- graph_chat/gen_data_report_agent.py | 18 ++++++++++-------- graph_chat/gen_sql_chart_agent.py | 2 +- main_service.py | 7 +++++-- template.yaml | 3 ++- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/graph_chat/gen_data_report_agent.py b/graph_chat/gen_data_report_agent.py index def383f..5785046 100644 --- a/graph_chat/gen_data_report_agent.py +++ b/graph_chat/gen_data_report_agent.py @@ -35,6 +35,7 @@ def _run_sql(state: DateReportAgentState) -> dict: sql = state['sql'] retry_sql = state.get('retry_sql', '') sql_correct = state.get('sql_correct', False) + logger.info(f'in _run_sql state={state}') logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点") try: if retry_sql and not sql_correct: @@ -63,8 +64,9 @@ def _feedback_qa(state: DateReportAgentState) -> dict: feedback_temp = template['template']['result_feedback'] logger.info(f"feedback_temp is {feedback_temp}") ddl_list = vn.get_related_ddl(user_question) + qa_list = vn.get_similar_question_sql(user_question) sys_promot = feedback_temp['system'].format(question=user_question, sql=sql, sql_result=sql_result, - current_time=datetime.now(),ddl_list=ddl_list) + current_time=datetime.now(),ddl_list=ddl_list,qa_list=qa_list) logger.info(f"system_temp is {sys_promot}") result = gen_history_llm.invoke(sys_promot).text() logger.info(f"feedback result: {result}") @@ -72,12 +74,9 @@ def _feedback_qa(state: DateReportAgentState) -> dict: result = orjson.loads(result) logger.info(f"提取json成功") logger.info(f"result is {result} type:{type(result)}") - print("result is {0}".format(result.keys())) if not result["is_result_correct"] and (result["suggested_sql"]): logger.info("开始替换") logger.info(f"suggested_sql is ".format(result["suggested_sql"])) - # state["sql"] = result["suggested_sql"] - # state["retry_sql"] = True logger.info(f"current state: {state}") return {"retry_sql": result["suggested_sql"]} if result["is_result_correct"]: @@ -88,11 +87,13 @@ def _feedback_qa(state: DateReportAgentState) -> dict: def run_sql_hande(state: DateReportAgentState) -> str: + logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql_hande 节点") + logger.info(f'in _run_sql_handle state={state}') # sql_error = state.get('run_sql_error', '') data = state.get('data', {}) sql_correct = state.get('sql_correct', False) sql_retry_count = state.get('retry_count', 0) - logger.info("sql_retry_count is {0}".format(sql_retry_count)) + logger.info(f"sql_retry_count is {sql_retry_count} sql_correct is {sql_correct}") if sql_retry_count < 2: if sql_correct: @@ -142,9 +143,10 @@ workflowx.add_node("_gen_report", _gen_report) workflowx.add_edge(START, "_run_sql") workflowx.add_edge("_feedback_qa", "_run_sql") workflowx.add_edge("_gen_report", END) + workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report','_feedback_qa']) memory = MemorySaver() result_report_agent = workflowx.compile(checkpointer=memory) -png_data=result_report_agent .get_graph().draw_mermaid_png() -with open("D://graph2.png", "wb") as f: - f.write(png_data) \ No newline at end of file +# png_data=result_report_agent .get_graph().draw_mermaid_png() +# with open("D://graph2.png", "wb") as f: +# f.write(png_data) \ No newline at end of file diff --git a/graph_chat/gen_sql_chart_agent.py b/graph_chat/gen_sql_chart_agent.py index 880ac21..2a7ab7b 100644 --- a/graph_chat/gen_sql_chart_agent.py +++ b/graph_chat/gen_sql_chart_agent.py @@ -179,6 +179,6 @@ workflow.add_conditional_edges('_gen_sql', gen_sql_handler, ['_gen_sql', '_gen_c workflow.add_conditional_edges('_gen_chart', gen_chart_handler, ['_gen_chart', END]) memory = MemorySaver() sql_chart_agent = workflow.compile(checkpointer=memory) -png_data=sql_chart_agent .get_graph().draw_mermaid_png() +# png_data=sql_chart_agent .get_graph().draw_mermaid_png() # with open("D://graph.png", "wb") as f: # f.write(png_data) \ No newline at end of file diff --git a/main_service.py b/main_service.py index a968029..e83c9ea 100644 --- a/main_service.py +++ b/main_service.py @@ -300,7 +300,7 @@ def gen_graph_question(): try: user_id = request.args.get("user_id") cvs_id = request.args.get("cvs_id") - config = {"configurable": {"thread_id": cvs_id}} + config = {"configurable": {"thread_id": uuid.uuid4()}} question = flask.request.args.get("question") question_context = get_latest_question(cvs_id, user_id,limit_count=2) history = [] @@ -345,6 +345,7 @@ def run_sql_3(): qa = get_qa_by_id(id) sql = qa["sql"] question = qa["question"] + logger.info(f"in main sql {sql} question {question}") logger.info("Start to run sql in main") try: user_id = request.args.get("user_id") @@ -361,8 +362,10 @@ def run_sql_3(): "sql": sql, "question": question, "retry_count": 0, + "need_feedback": False, + } - config = {"configurable": {"thread_id": 'dsds'}} + config = {"configurable": {"thread_id": uuid.uuid4()}} rr = result_report_agent.invoke(initial_state,config) logger.info(f"rr.data is {rr.get('data', {})}") return jsonify( diff --git a/template.yaml b/template.yaml index b5d09d1..0983aec 100644 --- a/template.yaml +++ b/template.yaml @@ -636,11 +636,13 @@ template: [执行结果]: {sql_result}。 [当前时间]: {current_time} [表结构信息]:{ddl_list} + [问答参考]:{qa_list} # 核心规则 请根据以上信息,进行全面、细致的反思,并最终判断这个结果是否能正确回答原始问题。你的反思需要包含以下几个步骤: 1. **核对问题理解**: 回顾“原始问题”,确认其核心意图、时间范围、筛选条件和所需字段。 我生成的SQL是否准确捕捉了问题的所有关键要素?有没有遗漏或误解? + 当sql和问题吻合,且能查询出结果时,一般判定为正确 2. **通用业务定义** 上周:指完整的**上周一到上周日** 3. **审查SQL逻辑**: @@ -650,7 +652,6 @@ template: 4. **评估结果合理性**: 观察“执行结果”,思考这个结果是否符合业务常识或数据的基本特征(例如,数量级、正负值、范围是否合理)? 结果的列名和内容是否与问题期望的输出一致? - 如果结果为空或数据量异常,推测可能的原因。 5. **最终判断和建议**: **结论** : 用一句话明确指出:“结果正确”、“结果可能不正确”或“无法完全确定”。 **原因** : 简练地解释你做出此判断的核心理由。