From 6fe29a51ce76b14a4c7ff762786f9cf275035f1d Mon Sep 17 00:00:00 2001 From: yujj128 Date: Fri, 7 Nov 2025 15:30:27 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=8F=8D=E9=A6=88=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- graph_chat/gen_data_report_agent.py | 98 +++++++++++++++++++++++++---- graph_chat/gen_sql_chart_agent.py | 6 +- main_service.py | 42 +++++++++---- service/conversation_service.py | 10 +-- service/cus_vanna_srevice.py | 2 +- template.yaml | 42 +++++++++++++ 6 files changed, 166 insertions(+), 34 deletions(-) diff --git a/graph_chat/gen_data_report_agent.py b/graph_chat/gen_data_report_agent.py index 4c6b7d9..0993cb8 100644 --- a/graph_chat/gen_data_report_agent.py +++ b/graph_chat/gen_data_report_agent.py @@ -1,12 +1,17 @@ import json import logging +from datetime import datetime from typing import TypedDict, Optional +import orjson +from anyio import current_time from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import StateGraph, END, START from graph_chat.gen_sql_chart_agent import gen_history_llm +from util.utils import extract_nested_json +from service.conversation_service import update_conversation logger = logging.getLogger(__name__) from template.template import get_base_template @@ -20,41 +25,103 @@ class DateReportAgentState(TypedDict): run_sql_error: str retry_count: int question:str + retry_sql:str + sql_correct:bool + def _run_sql(state: DateReportAgentState) -> dict: from main_service import vn sql = state['sql'] + retry_sql = state.get('retry_sql', '') + sql_correct = state.get('sql_correct', False) + logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点") try: + if retry_sql and not sql_correct: + sql = retry_sql + update_conversation(id=state.get('id'), sql=sql) df = vn.run_sql_2(sql) result = df.to_dict(orient='records') + logger.info(f"run sql result {result}") retry = state.get("retry_count", 0) + 1 - return {'data': result, 'retry_count': retry} + logger.info(f"old_sql:{sql} retry sql {retry_sql}") + return {'data': result, 'retry_count': retry, 'retry_sql': retry_sql} except Exception as e: retry = state.get("retry_count", 0) + 1 logger.error(f"Error running sql: {sql}, error: {e}") return {'retry_count': retry, 'run_sql_error': '运行sql语句失败,请检查语法或联系管理员。'} +def _feedback_qa(state: DateReportAgentState) -> dict: + logger.info(f"user:{state.get('user_id', '1')} 进入 _feedback_qa 节点") + try: + user_question = state['question'] + sql = state['sql'] + sql_result = state['data'] + template = get_base_template() + feedback_temp = template['template']['result_feedback'] + logger.info(f"feedback_temp is {feedback_temp}") + sys_promot = feedback_temp['system'].format(question=user_question, sql=sql, sql_result=sql_result, + current_time=datetime.now()) + logger.info(f"system_temp is {sys_promot}") + result = gen_history_llm.invoke(sys_promot).text() + logger.info(f"feedback result: {result}") + result = extract_nested_json(result) + result = orjson.loads(result) + logger.info(f"提取json成功") + logger.info(f"result is {result} type:{type(result)}") + print("result is {0}".format(result.keys())) + if not result["is_result_correct"] and (result["suggested_sql"]): + logger.info("开始替换") + logger.info(f"suggested_sql is ".format(result["suggested_sql"])) + # state["sql"] = result["suggested_sql"] + # state["retry_sql"] = True + logger.info(f"current state: {state}") + return {"retry_sql": result["suggested_sql"]} + if result["is_result_correct"]: + return {"sql_correct": True} + except Exception as e: + logger.error(f"Error feedback: {e}") + + + def run_sql_hande(state: DateReportAgentState) -> str: - sql_error = state.get('run_sql_error', '') + # sql_error = state.get('run_sql_error', '') data = state.get('data', {}) + sql_correct = state.get('sql_correct', False) sql_retry_count = state.get('retry_count', 0) - if sql_error and len(sql_error) > 0: - if sql_retry_count < 2: - return '_run_sql' + logger.info("sql_retry_count is {0}".format(sql_retry_count)) + + if sql_retry_count < 2: + if sql_correct: + return '_gen_report' else: - return END - if data and len(data) > 0: - return '_gen_report' + return '_feedback_qa' else: - if sql_retry_count < 2: - return '_run_sql' + if data and len(data) > 0: + return '_gen_report' else: - return END + END + # retry_sql = state.get('retry_sql', False) + + # if sql_error and len(sql_error) > 0: + # if sql_retry_count < 2: + # return '_feedback_qa' + # # return '_run_sql' + # else: + # return END + # + # if data and len(data) > 0: + # return '_gen_report' + # else: + # if sql_retry_count < 2: + # return '_feedback_qa' + # else: + # return END def _gen_report(state: DateReportAgentState) -> dict: + logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_report 节点") sql = state['sql'] question = state['question'] run_sql_error = state.get('run_sql_error', '') @@ -62,15 +129,20 @@ def _gen_report(state: DateReportAgentState) -> dict: template = get_base_template() result_summary = template['template']['result_summary']['system'] pro = result_summary.format(sql=sql, error_message=run_sql_error, question=question, data=data) - txt = gen_history_llm.invoke(pro).text + txt = gen_history_llm.invoke(pro).text() return {'summary': txt} workflowx = StateGraph(DateReportAgentState) workflowx.add_node("_run_sql", _run_sql) +workflowx.add_node("_feedback_qa", _feedback_qa) workflowx.add_node("_gen_report", _gen_report) workflowx.add_edge(START, "_run_sql") +workflowx.add_edge("_feedback_qa", "_run_sql") workflowx.add_edge("_gen_report", END) -workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report']) +workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report','_feedback_qa']) memory = MemorySaver() result_report_agent = workflowx.compile(checkpointer=memory) +png_data=result_report_agent .get_graph().draw_mermaid_png() +with open("D://graph2.png", "wb") as f: + f.write(png_data) \ No newline at end of file diff --git a/graph_chat/gen_sql_chart_agent.py b/graph_chat/gen_sql_chart_agent.py index c587b84..880ac21 100644 --- a/graph_chat/gen_sql_chart_agent.py +++ b/graph_chat/gen_sql_chart_agent.py @@ -65,7 +65,7 @@ def _rewrite_user_question(state: SqlAgentState) -> dict: rewrite_temp = template['template']['rewrite_question'] history = state.get('history', []) sys_promot = rewrite_temp['system'].format(current_question=user_question, history=history) - new_question = gen_history_llm.invoke(sys_promot).text + new_question = gen_history_llm.invoke(sys_promot).text() logger.info("new_question:{0}".format(new_question)) result = extract_nested_json(new_question) logger.info(f"result:{result}") @@ -93,7 +93,7 @@ def _gen_sql(state: SqlAgentState) -> dict: user_temp = sql_temp['user'].format(question=question, current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) rr = gen_sql_llm.invoke( - [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}]).text + [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}]).text() result = extract_nested_json(rr) logger.info(f"gensql result: {result}") result = orjson.loads(result) @@ -120,7 +120,7 @@ def _gen_chart(state: SqlAgentState) -> dict: lang='中文', sql=sql, chart_type=char_type) user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) rr = gen_history_llm.invoke( - [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]).text + [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]).text() retry = state.get("chart_retry_count", 0) + 1 return {"chart_retry_count": retry, 'gen_chart_result': orjson.loads(extract_nested_json(rr))} except Exception as e: diff --git a/main_service.py b/main_service.py index 9dc24d0..a968029 100644 --- a/main_service.py +++ b/main_service.py @@ -7,7 +7,7 @@ import util.utils from logging_config import LOGGING_CONFIG from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper from service.question_feedback_service import save_save_question_async,query_predefined_question_list -from service.conversation_service import save_conversation,update_conversation,get_sql_by_id +from service.conversation_service import save_conversation,update_conversation,get_qa_by_id,get_latest_question from decouple import config import flask from util import load_ddl_doc @@ -133,7 +133,7 @@ def generate_sql_2(): data['id'] = id sql = data["resp"]["sql"] logger.info("generate sql is : "+ sql) - update_conversation(cvs_id, id, sql) + update_conversation(id, sql) save_save_question_async(id, user_id, question, sql) data["type"]="success" return jsonify(data) @@ -241,7 +241,8 @@ def run_sql_2(): logger.info("Start to run sql in main") try: id = request.args.get("id") - sql = get_sql_by_id(id) + qa = get_qa_by_id(id) + sql = qa["sql"] logger.info(f"sql is {sql}") if not vn.run_sql_is_set: return jsonify( @@ -297,18 +298,28 @@ def query_present_question(): @app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"]) def gen_graph_question(): try: - config = {"configurable": {"thread_id": '1233'}} + user_id = request.args.get("user_id") + cvs_id = request.args.get("cvs_id") + config = {"configurable": {"thread_id": cvs_id}} question = flask.request.args.get("question") + question_context = get_latest_question(cvs_id, user_id,limit_count=2) + history = [] + for q in question_context: + history.append({"role":"user",'content':q}) + logger.info(f"history is {history}") initial_state: SqlAgentState = { + "user_id": user_id, "user_question": question, - "history": [{"role":"user",'content':'谭杰明昨日打卡记录查询'}], + "history": history, "sql_retry_count": 0, "chart_retry_count": 0 } - result = sql_chart_agent.invoke(initial_state, config=config) - id = str(uuid.uuid4()) - cache.set(id=id, field="question", value=result.get('rewritten_user_question',question)) + result = sql_chart_agent.invoke(initial_state, config=config) + new_question = result.get('rewritten_user_question',question) + id = str(uuid.uuid4()) + save_conversation(id, user_id, cvs_id, new_question) + # cache.set(id=id, field="question", value=result.get('rewritten_user_question',question)) data = { 'id': id, 'sql': result.get("gen_sql_result", {}), @@ -316,7 +327,9 @@ def gen_graph_question(): 'gen_sql_error': result.get("gen_sql_error", None), 'gen_chart_error': result.get("gen_chart_error", None), } - cache.set(id=id, field="sql", value=data.get('sql', {}).get('sql', '')) + # cache.set(id=id, field="sql", value=data.get('sql', {}).get('sql', '')) + sql = data.get('sql', {}).get('sql', '') + update_conversation(id, sql) return jsonify(data) except Exception as e: traceback.print_exc() @@ -325,9 +338,13 @@ def gen_graph_question(): @app.flask_app.route("/yj_sqlbot/api/v0/run_sql_3", methods=["GET"]) -@session_save -@app.requires_cache(["sql",'question']) -def run_sql_3(id: str, sql: str, question: str): +# @session_save +# @app.requires_cache(["sql",'question']) +def run_sql_3(): + id = request.args.get("id") + qa = get_qa_by_id(id) + sql = qa["sql"] + question = qa["question"] logger.info("Start to run sql in main") try: user_id = request.args.get("user_id") @@ -347,6 +364,7 @@ def run_sql_3(id: str, sql: str, question: str): } config = {"configurable": {"thread_id": 'dsds'}} rr = result_report_agent.invoke(initial_state,config) + logger.info(f"rr.data is {rr.get('data', {})}") return jsonify( { 'data': rr.get('data', {}), diff --git a/service/conversation_service.py b/service/conversation_service.py index 6c3dcdd..b4ee9fc 100644 --- a/service/conversation_service.py +++ b/service/conversation_service.py @@ -41,14 +41,14 @@ def get_conversation(cvs_id: str): # } -def update_conversation(cvs_id: str, id: str, sql=None, meta=None): +def update_conversation(id: str, sql=None, meta=None): """更新sql到对应question""" session = SqliteSqlalchemy().session try: if sql: - session.query(Conversation).filter(Conversation.cvs_id == cvs_id, Conversation.id == id).update({Conversation.sql: sql}) + session.query(Conversation).filter(Conversation.id == id).update({Conversation.sql: sql}) if meta: - session.query(Conversation).filter(Conversation.cvs_id == cvs_id, Conversation.id == id).update({Conversation.meta: meta}) + session.query(Conversation).filter(Conversation.id == id).update({Conversation.meta: meta}) session.commit() except Exception as e: session.rollback() @@ -72,12 +72,12 @@ def get_latest_question(cvs_id, user_id, limit_count): session.close() -def get_sql_by_id(id: str): +def get_qa_by_id(id: str): session = SqliteSqlalchemy().session try: result = session.query(Conversation).filter_by(id=id).first() if result: - return result.sql + return {"question":result.question, "sql":result.sql} return None except Exception as e: logger.error(f"get_sql_by_id error {e}") diff --git a/service/cus_vanna_srevice.py b/service/cus_vanna_srevice.py index 7b5a218..669a8de 100644 --- a/service/cus_vanna_srevice.py +++ b/service/cus_vanna_srevice.py @@ -247,7 +247,7 @@ class OpenAICompatibleLLM(VannaBase): new_question = self.generate_rewritten_question(questions,**kwargs) logger.info(f"new_question is {new_question}") question = new_question if new_question else question - update_conversation(cvs_id, id, meta=question) + update_conversation(id, meta=question) # if user_id and cache: # history = cache.get(id=user_id, field="data") diff --git a/template.yaml b/template.yaml index d02cc3f..45a8e9c 100644 --- a/template.yaml +++ b/template.yaml @@ -624,4 +624,46 @@ template: # 最终输出 请根据以上规则,直接输出最终生成的回复话术,不要包含任何其他解释或格式。 + result_feedback: + system: | + # 角色 + 你是一位经验丰富的数据分析师和SQL专家。你的核心任务是对SQL查询和其执行结果进行批判性审查。 + # 你的任务 + 我(用户)向你提供了一系列信息,请你进行分析和反思。 + # 输入信息 + [原始问题]: {question} + [生成的SQL]: {sql}。 + [执行结果]: {sql_result}。 + [当前时间]: {current_time} + # 核心规则 + 请根据以上信息,进行全面、细致的反思,并最终判断这个结果是否能正确回答原始问题。你的反思需要包含以下几个步骤: + 1. **核对问题理解**: + 回顾“原始问题”,确认其核心意图、时间范围、筛选条件和所需字段。 + 我生成的SQL是否准确捕捉了问题的所有关键要素?有没有遗漏或误解? + 2. **通用业务定义** + 上周:指完整的**上周一到上周日** + 3. **审查SQL逻辑**: + 逐行分析“生成的SQL”,检查其语法结构是否正确。 + 关键子句(如 `WHERE`, `GROUP BY`, `HAVING`, `JOIN`)是否准确地反映了查询意图? + 聚合函数(如 `SUM`, `COUNT`, `AVG`)的使用是否恰当?分组维度是否正确? + 4. **评估结果合理性**: + 观察“执行结果”,思考这个结果是否符合业务常识或数据的基本特征(例如,数量级、正负值、范围是否合理)? + 结果的列名和内容是否与问题期望的输出一致? + 如果结果为空或数据量异常,推测可能的原因。 + 5. **最终判断和建议**: + **结论** : 用一句话明确指出:“结果正确”、“结果可能不正确”或“无法完全确定”。 + **原因** : 简练地解释你做出此判断的核心理由。 + **改进建议** (如果结果不正确): + * 具体指出SQL中可能存在的逻辑错误。 + * 给出一个或多个修改后的SQL版本。 + * 解释你为什么这样修改。 + # 输出格式 + 请严格按照以下JSON格式输出你的分析结果: + ```json + {{ + "conclusion": "结果正确 OR 结果可能不正确 OR 无法完全确定", + "reasoning": "对问题理解、SQL逻辑和结果合理性的综合分析说明。", + "is_result_correct": true OR false, + "suggested_sql": "如果结论为不正确,请在此处提供修改后的SQL。如果正确,则为null。" + }} Resources: \ No newline at end of file