Files
sqlbot_agent/graph_chat/gen_data_report_agent.py
yujj128 703e0e2591 Merge branch 'dev_graph' of http://106.13.42.156:33077/lei_y601/sqlbot_agent into dev_graph
# Conflicts:
#	graph_chat/gen_data_report_agent.py
#	main_service.py
2025-11-07 18:12:44 +08:00

148 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
from datetime import datetime
from typing import TypedDict, Optional
import orjson
from anyio import current_time
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END, START
from graph_chat.gen_sql_chart_agent import gen_history_llm
from util.utils import extract_nested_json
from service.conversation_service import update_conversation
from template.template import get_base_template
from decouple import config

logger = logging.getLogger(__name__)

# Feature flag read through python-decouple's `config` (defaults to False).
# When True, a run whose result is not yet marked correct is routed to the
# LLM review node `_feedback_qa`, which may suggest a corrected SQL.
# When False, `_run_sql` marks any successful run as correct itself.
NEED_MODEL_FEEDBACK_QA = config("NEED_MODEL_FEEDBACK_QA", cast=bool, default=False)
class DateReportAgentState(TypedDict):
    """State shared by the nodes of the data-report graph.

    Each node returns a partial dict of these keys, which LangGraph applies
    as an update to the state.
    """

    # Conversation id; passed to `update_conversation` whenever SQL is run.
    id: str
    # Requesting user; used here only in log messages.
    user_id: str
    # SQL supplied for the report.
    sql: str
    # Rows from the last successful run, as produced by
    # `DataFrame.to_dict(orient='records')`, i.e. a list of row dicts.
    data: Optional[list[dict]]
    # Report text written by the LLM in `_gen_report`.
    summary: str
    # User-facing error message set when running the SQL fails.
    run_sql_error: str
    # Number of `_run_sql` attempts so far (successful or not).
    retry_count: int
    # The user's natural-language question.
    question: str
    # Corrected SQL suggested by `_feedback_qa`, if any.
    retry_sql: str
    # True once the SQL result has been accepted as correct.
    sql_correct: bool
def _run_sql(state: DateReportAgentState) -> dict:
    """Execute the report SQL and store its rows in the state.

    ``state['retry_sql']`` is executed instead of ``state['sql']`` when a
    corrected SQL is pending (set and the result not yet marked correct).
    Every call, successful or not, increments ``retry_count``.

    Returns:
        On success: ``data`` (list of row dicts), ``retry_count``,
        ``retry_sql``, an emptied ``run_sql_error`` (so an error from an
        earlier failed attempt is not reported alongside good data) and,
        when model feedback is disabled, ``sql_correct=True``.
        On failure: ``retry_count`` and a user-facing ``run_sql_error``.
    """
    # Local import, presumably to avoid a circular import with main_service.
    from main_service import vn

    original_sql = state['sql']
    retry_sql = state.get('retry_sql', '')
    sql_correct = state.get('sql_correct', False)
    retry = state.get("retry_count", 0) + 1
    logger.info(f'in _run_sql state={state}')
    logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")

    # A pending correction from _feedback_qa takes precedence over the original SQL.
    sql = retry_sql if retry_sql and not sql_correct else original_sql
    try:
        update_conversation(id=state.get('id'), sql=sql)
        df = vn.run_sql_2(sql)
        result = df.to_dict(orient='records')
    except Exception as e:
        logger.error(f"Error running sql: {sql}, error: {e}")
        return {'retry_count': retry, 'run_sql_error': '运行sql语句失败请检查语法或联系管理员。'}

    logger.info(f"run sql result {result}")
    logger.info(f"old_sql:{original_sql}, retry sql {retry_sql}")
    update = {
        'data': result,
        'retry_count': retry,
        'retry_sql': retry_sql,
        # Clear any error left by a previous failed attempt.
        'run_sql_error': '',
    }
    # Without model feedback, a successful run is accepted as correct.
    if not NEED_MODEL_FEEDBACK_QA:
        update['sql_correct'] = True
    return update
def _feedback_qa(state: DateReportAgentState) -> dict:
    """Ask the LLM whether the last SQL result answers the user's question.

    Fills the ``result_feedback`` system template with the question, the SQL
    that produced ``state['data']``, those rows, the current time, and the
    output of ``vn.get_related_ddl`` / ``vn.get_similar_question_sql``. It then
    parses the model's JSON reply, which must contain ``is_result_correct``
    and ``suggested_sql``.

    Returns:
        ``{"retry_sql": <suggestion>}`` when the model rejects the result and
        suggests a new SQL.
        ``{"sql_correct": True, "sql": <reviewed sql>}`` when it accepts the
        result.
        ``None`` (no state update) when it rejects without a suggestion or
        when any step fails; the error is logged with its traceback.
    """
    # Local import, presumably to avoid a circular import with main_service.
    from main_service import vn
    logger.info(f"user:{state.get('user_id', '1')} 进入 _feedback_qa 节点")
    try:
        user_question = state['question']
        retry_sql = state.get('retry_sql', '')
        # Same selection rule as _run_sql, so the model reviews the SQL that
        # actually produced state['data'] rather than always the original one.
        if retry_sql and not state.get('sql_correct', False):
            sql = retry_sql
        else:
            sql = state['sql']
        sql_result = state['data']

        feedback_temp = get_base_template()['template']['result_feedback']
        logger.info(f"feedback_temp is {feedback_temp}")
        ddl_list = vn.get_related_ddl(user_question)
        qa_list = vn.get_similar_question_sql(user_question)
        system_prompt = feedback_temp['system'].format(
            question=user_question,
            sql=sql,
            sql_result=sql_result,
            current_time=datetime.now(),
            ddl_list=ddl_list,
            qa_list=qa_list,
        )
        logger.info(f"system_temp is {system_prompt}")

        raw_reply = gen_history_llm.invoke(system_prompt).text()
        logger.info(f"feedback result: {raw_reply}")
        result = orjson.loads(extract_nested_json(raw_reply))
        logger.info("提取json成功")
        logger.info(f"result is {result} type:{type(result)}")

        if not result["is_result_correct"] and result["suggested_sql"]:
            logger.info("开始替换")
            logger.info(f"suggested_sql is {result['suggested_sql']}")
            logger.info(f"current state: {state}")
            return {"retry_sql": result["suggested_sql"]}
        if result["is_result_correct"]:
            # The next node is always _run_sql. With sql_correct=True it runs
            # state['sql'], so promote the approved SQL. Otherwise an accepted
            # correction would be replaced by the original query.
            return {"sql_correct": True, "sql": sql}
    except Exception as e:
        logger.exception(f"Error feedback: {e}")
    return None
def run_sql_hande(state: DateReportAgentState) -> str:
    """Choose the node to run after ``_run_sql``.

    While ``retry_count`` is below 3, the route depends on ``sql_correct``:
      * marked correct: ``_gen_report``;
      * not correct, model feedback on: ``_feedback_qa``;
      * not correct, model feedback off: back to ``_run_sql``.

    Once 3 attempts have been made, any rows obtained are reported via
    ``_gen_report``. If there are none, the graph ends.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql_hande 节点")
    logger.info(f'in run_sql_handle state={state}')
    data = state.get('data', {})
    sql_correct = state.get('sql_correct', False)
    sql_retry_count = state.get('retry_count', 0)
    logger.info(f"sql_retry_count is {sql_retry_count} sql_correct is {sql_correct}")
    if sql_retry_count < 3:
        if sql_correct:
            return '_gen_report'
        return '_feedback_qa' if NEED_MODEL_FEEDBACK_QA else '_run_sql'
    if data:
        return '_gen_report'
    # Previously `END` was evaluated but not returned, so the router yielded None.
    return END
def _gen_report(state: DateReportAgentState) -> dict:
    """Have the LLM write a report for the SQL result.

    Fills the ``result_summary`` system template with the SQL, the user's
    question, the rows and any recorded SQL error message. The model's text
    is returned under ``summary``.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_report 节点")
    sql_text = state['sql']
    user_question = state['question']
    error_text = state.get('run_sql_error', '')
    rows = state.get('data', {})

    summary_template = get_base_template()['template']['result_summary']['system']
    prompt = summary_template.format(
        sql=sql_text,
        error_message=error_text,
        question=user_question,
        data=rows,
    )
    report_text = gen_history_llm.invoke(prompt).text()
    return {'summary': report_text}
# Report workflow:
#   START -> _run_sql -(run_sql_hande)-> _gen_report | END | <retry target>
#   _gen_report -> END
# The retry target is _feedback_qa (which always returns to _run_sql) when
# NEED_MODEL_FEEDBACK_QA is on, or _run_sql itself when it is off.
workflowx = StateGraph(DateReportAgentState)

# Nodes, registered in the order _run_sql, [_feedback_qa], _gen_report.
workflowx.add_node("_run_sql", _run_sql)
if NEED_MODEL_FEEDBACK_QA:
    workflowx.add_node("_feedback_qa", _feedback_qa)
workflowx.add_node("_gen_report", _gen_report)

# Fixed entry and exit edges.
workflowx.add_edge(START, "_run_sql")
workflowx.add_edge("_gen_report", END)

# Destinations allowed out of _run_sql; they must cover every value that
# run_sql_hande can return in the current mode.
if NEED_MODEL_FEEDBACK_QA:
    workflowx.add_edge("_feedback_qa", "_run_sql")
    _run_sql_routes = [END, '_gen_report', '_feedback_qa']
else:
    _run_sql_routes = [END, "_run_sql", '_gen_report']
workflowx.add_conditional_edges("_run_sql", run_sql_hande, _run_sql_routes)

# State checkpoints are held in memory by LangGraph's MemorySaver.
memory = MemorySaver()
result_report_agent = workflowx.compile(checkpointer=memory)

# Debug helper: uncomment to render the compiled graph as a PNG.
# png_data = result_report_agent.get_graph().draw_mermaid_png()
# with open("D://graph2.png", "wb") as f:
#     f.write(png_data)