import json
import logging
from datetime import datetime
from typing import TypedDict, Optional

import orjson
from anyio import current_time
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END, START

from graph_chat.gen_sql_chart_agent import gen_history_llm
from util.utils import extract_nested_json
from service.conversation_service import update_conversation
from template.template import get_base_template

logger = logging.getLogger(__name__)

# Maximum number of SQL executions before the router stops asking the LLM for
# feedback. This is the value the original router hard-coded.
MAX_SQL_RUNS = 2


class DateReportAgentState(TypedDict):
    """State shared by the nodes of the SQL-result report graph."""
    id: str  # forwarded to update_conversation (presumably the conversation id)
    user_id: str  # only used for logging in this module
    sql: str  # SQL to run; updated by _run_sql to the statement actually executed
    data: Optional[list]  # rows from DataFrame.to_dict(orient='records')
    summary: str  # report text produced by _gen_report
    run_sql_error: str  # user-facing error from the last failed run, '' after a success
    retry_count: int  # number of times _run_sql has executed
    question: str  # the user's question
    retry_sql: str  # LLM-suggested replacement SQL, '' if none
    sql_correct: bool  # set by _feedback_qa when the LLM accepts the result


def _run_sql(state: DateReportAgentState) -> dict:
    """Execute the current SQL, or the LLM-suggested replacement, via ``vn.run_sql_2``.

    Returns a partial state update. On success it holds the rows, the SQL that
    actually ran, and a cleared error. On failure it holds a user-facing error
    message. ``retry_count`` is incremented either way so the router can bound
    the number of attempts.
    """
    # Imported lazily, presumably to avoid a circular import with main_service.
    from main_service import vn

    sql = state['sql']
    retry_sql = state.get('retry_sql', '')
    sql_correct = state.get('sql_correct', False)
    retry = state.get("retry_count", 0) + 1
    logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")
    try:
        # Switch to the LLM's suggestion unless the current SQL was already validated,
        # and persist the replacement on the conversation record.
        if retry_sql and not sql_correct:
            sql = retry_sql
            update_conversation(id=state.get('id'), sql=sql)
        df = vn.run_sql_2(sql)
        result = df.to_dict(orient='records')
        logger.info(f"run sql result {result}")
        logger.info(f"old_sql:{sql} retry sql {retry_sql}")
        return {
            'data': result,
            'retry_count': retry,
            'retry_sql': retry_sql,
            # Keep `sql` in sync with the statement that produced `data`, so the
            # feedback and report nodes describe what actually ran.
            'sql': sql,
            # A successful run supersedes any error from an earlier attempt.
            'run_sql_error': '',
        }
    except Exception as e:
        logger.error(f"Error running sql: {sql}, error: {e}")
        return {'retry_count': retry, 'run_sql_error': '运行sql语句失败,请检查语法或联系管理员。'}


def _feedback_qa(state: DateReportAgentState) -> dict:
    """Ask the LLM whether the SQL result answers the user's question.

    Returns:
        ``{"sql_correct": True}`` when the LLM accepts the result;
        ``{"retry_sql": <suggested SQL>}`` when it rejects the result and
        proposes a replacement; otherwise ``{}``, including on any error, so
        that the graph always receives a dict update.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _feedback_qa 节点")
    try:
        user_question = state['question']
        sql = state['sql']
        # `data` is absent when the first run failed. Pass None to the prompt
        # (giving the LLM a chance to propose a fix) instead of raising KeyError.
        sql_result = state.get('data')
        template = get_base_template()
        feedback_temp = template['template']['result_feedback']
        logger.info(f"feedback_temp is {feedback_temp}")
        sys_promot = feedback_temp['system'].format(question=user_question, sql=sql, sql_result=sql_result,
                                                    current_time=datetime.now())
        logger.info(f"system_temp is {sys_promot}")
        result = gen_history_llm.invoke(sys_promot).text()
        logger.info(f"feedback result: {result}")
        result = extract_nested_json(result)
        result = orjson.loads(result)
        logger.info(f"提取json成功")
        logger.info(f"result is {result} type:{type(result)}")
        logger.debug("result is {0}".format(result.keys()))
        is_correct = result.get("is_result_correct")
        suggested_sql = result.get("suggested_sql")
        if not is_correct and suggested_sql:
            logger.info("开始替换")
            logger.info(f"suggested_sql is {suggested_sql}")
            logger.info(f"current state: {state}")
            return {"retry_sql": suggested_sql}
        if is_correct:
            return {"sql_correct": True}
    except Exception as e:
        logger.error(f"Error feedback: {e}")
    return {}


def run_sql_hande(state: DateReportAgentState) -> str:
    """Route after ``_run_sql``.

    A result the LLM has already validated always goes to the report. While
    fewer than ``MAX_SQL_RUNS`` executions have happened, the LLM is asked for
    feedback. Once the budget is spent, the graph reports if there are rows
    and otherwise ends.
    """
    data = state.get('data', {})
    sql_correct = state.get('sql_correct', False)
    sql_retry_count = state.get('retry_count', 0)
    logger.info("sql_retry_count is {0}".format(sql_retry_count))
    if sql_correct:
        return '_gen_report'
    if sql_retry_count < MAX_SQL_RUNS:
        return '_feedback_qa'
    if data:
        return '_gen_report'
    # Must be *returned*: a bare `END` expression would make the router return None.
    return END


def _after_feedback(state: DateReportAgentState) -> str:
    """Route after ``_feedback_qa``.

    An accepted result goes straight to the report rather than re-executing the
    same query. Anything else goes back to ``_run_sql``, which picks up
    ``retry_sql`` if one was suggested.
    """
    if state.get('sql_correct', False):
        return '_gen_report'
    return '_run_sql'


def _gen_report(state: DateReportAgentState) -> dict:
    """Summarize the question, SQL, result rows and any run error with the LLM.

    Fills the ``result_summary`` system template and returns ``{'summary': text}``.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_report 节点")
    sql = state['sql']
    question = state['question']
    run_sql_error = state.get('run_sql_error', '')
    data = state.get('data', {})
    template = get_base_template()
    result_summary = template['template']['result_summary']['system']
    pro = result_summary.format(sql=sql, error_message=run_sql_error, question=question, data=data)
    txt = gen_history_llm.invoke(pro).text()
    return {'summary': txt}


workflowx = StateGraph(DateReportAgentState)
workflowx.add_node("_run_sql", _run_sql)
workflowx.add_node("_feedback_qa", _feedback_qa)
workflowx.add_node("_gen_report", _gen_report)
workflowx.add_edge(START, "_run_sql")
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report', '_feedback_qa'])
workflowx.add_conditional_edges("_feedback_qa", _after_feedback, ['_run_sql', '_gen_report'])
workflowx.add_edge("_gen_report", END)

memory = MemorySaver()
result_report_agent = workflowx.compile(checkpointer=memory)

# Debug artifact: dump the graph diagram when the module is imported. It is
# best-effort so that a missing D: drive (e.g. on a Linux server), or a
# rendering failure (draw_mermaid_png may need network access depending on the
# draw method), does not make importing the agent fail.
# NOTE(review): consider moving this under `if __name__ == "__main__":`.
try:
    png_data = result_report_agent.get_graph().draw_mermaid_png()
    with open("D://graph2.png", "wb") as f:
        f.write(png_data)
except Exception as e:
    logger.warning(f"could not write graph diagram: {e}")