import logging
from datetime import datetime
from typing import Optional, TypedDict

import orjson
from decouple import config
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph

from graph_chat.gen_sql_chart_agent import gen_history_llm
from service.conversation_service import update_conversation
from template.template import get_base_template
from util.utils import extract_nested_json

logger = logging.getLogger(__name__)

# When True, an LLM reviews failed runs and may propose a corrected query;
# when False, a successful execution is accepted as-is.
NEED_MODEL_FEEDBACK_QA = config("NEED_MODEL_FEEDBACK_QA", default=False, cast=bool)


class DateReportAgentState(TypedDict):
    id: str                 # conversation id, used to persist corrected SQL
    user_id: str
    sql: str                # SQL produced upstream
    data: Optional[list]    # rows from DataFrame.to_dict(orient='records')
    summary: str            # final natural-language report
    run_sql_error: str      # non-empty after a failed execution
    retry_count: int        # number of _run_sql attempts so far
    question: str           # the user's original question
    retry_sql: str          # corrected SQL suggested by _feedback_qa
    sql_correct: bool       # True once the result is accepted


def _run_sql(state: DateReportAgentState) -> dict:
    """Execute the SQL (or the suggested retry_sql) and store the rows in state."""
    # Lazy import, presumably to avoid a circular import with main_service.
    from main_service import vn

    sql = state['sql']
    retry_sql = state.get('retry_sql', '')
    sql_correct = state.get('sql_correct', False)
    # Every attempt counts toward the retry budget, whether it succeeds or fails.
    retry = state.get('retry_count', 0) + 1
    logger.info(f"in _run_sql state={state}")
    logger.info(f"user:{state.get('user_id', '1')} --------------- entering _run_sql node ---------------")
    try:
        # Prefer the LLM-suggested SQL once one exists and the result is not yet accepted.
        if retry_sql and not sql_correct:
            sql = retry_sql
            update_conversation(id=state.get('id'), sql=sql)
        df = vn.run_sql_2(sql)
        result = df.to_dict(orient='records')
        logger.info(f"run sql result {result}")
        logger.info(f"original sql: {state['sql']}, retry sql: {retry_sql}")
        update = {'data': result, 'retry_count': retry, 'retry_sql': retry_sql, 'run_sql_error': ''}
        # Without model feedback, a successful execution is treated as correct.
        if not NEED_MODEL_FEEDBACK_QA:
            update['sql_correct'] = True
        return update
    except Exception as e:
        logger.error(f"Error running sql: {sql}, error: {e}")
        return {
            'retry_count': retry,
            'run_sql_error': 'Failed to run the SQL statement; please check the syntax or contact the administrator.',
        }


def _feedback_qa(state: DateReportAgentState) -> dict:
    """Ask the LLM whether the SQL result answers the question; capture a suggested fix if not."""
    from main_service import vn

    logger.info(f"user:{state.get('user_id', '1')} --------------- entering _feedback_qa node ---------------")
    try:
        user_question = state['question']
        sql = state['sql']
        sql_result = state.get('data')  # absent if every run so far has failed
        template = get_base_template()
        feedback_temp = template['template']['result_feedback']
        logger.info(f"feedback_temp is {feedback_temp}")
        ddl_list = vn.get_related_ddl(user_question)
        qa_list = vn.get_similar_question_sql(user_question)
        sys_prompt = feedback_temp['system'].format(
            question=user_question,
            sql=sql,
            sql_result=sql_result,
            current_time=datetime.now(),
            ddl_list=ddl_list,
            qa_list=qa_list,
        )
        logger.info(f"system_temp is {sys_prompt}")
        result = gen_history_llm.invoke(sys_prompt).text()
        logger.info(f"feedback result: {result}")
        # The reply is expected to embed JSON with is_result_correct and suggested_sql.
        result = orjson.loads(extract_nested_json(result))
        logger.info("JSON extracted successfully")
        logger.info(f"result is {result} type:{type(result)}")
        if result["is_result_correct"]:
            return {"sql_correct": True, "run_sql_error": ""}
        if result["suggested_sql"]:
            logger.info("replacing SQL with the suggested query")
            logger.info(f"suggested_sql is {result['suggested_sql']}")
            logger.info(f"current state: {state}")
            return {"retry_sql": result["suggested_sql"]}
    except Exception as e:
        logger.error(f"Error feedback: {e}")
    # No usable feedback: leave the state unchanged.
    return {}


def handle_no_feedback(state: DateReportAgentState) -> str:
    """Route after _run_sql when model feedback is disabled."""
    logger.info(f"user:{state.get('user_id', '1')} --------------- entering handle_no_feedback router ---------------")
    sql_error = state.get('run_sql_error', '')
    data = state.get('data', {})
    sql_retry_count = state.get('retry_count', 0)
    # Execution failed: rerun, at most 3 attempts in total.
    if sql_error:
        return '_run_sql' if sql_retry_count < 3 else END
    # Rows returned: summarize them.
    if data:
        return '_gen_report'
    # Empty result: rerun, at most 2 attempts in total.
    return '_run_sql' if sql_retry_count < 2 else END


def handle_with_feedback(state: DateReportAgentState) -> str:
    """Route after _run_sql when model feedback is enabled."""
    logger.info(f"user:{state.get('user_id', '1')} --------------- entering handle_with_feedback router ---------------")
    sql_error = state.get('run_sql_error', '')
    # data = state.get('data', {})
    # sql_correct = state.get('sql_correct', False)
    sql_retry_count = state.get('retry_count', 0)
    logger.info(f"handle with feedback sql_retry_count is {sql_retry_count} sql error is {sql_error}")
    if sql_retry_count < 3:
        # Failed run: ask the LLM for a corrected query; otherwise summarize.
        return '_feedback_qa' if sql_error else '_gen_report'
        # if sql_correct:
        #     return '_gen_report'
        # else:
        #     if sql_error and len(sql_error) > 0:
        #         return '_feedback_qa'
        #     else:
        #         return END
    # Retry budget exhausted: stop on error, otherwise summarize what we have.
    return END if sql_error else '_gen_report'


def run_sql_hande(state: DateReportAgentState) -> str:
    """Conditional-edge router after _run_sql; dispatches on NEED_MODEL_FEEDBACK_QA."""
    logger.info(f"user:{state.get('user_id', '1')} --------------- entering run_sql_hande router ---------------")
    logger.info(f"in run_sql_handle state={state}")
    if NEED_MODEL_FEEDBACK_QA:
        return handle_with_feedback(state)
    return handle_no_feedback(state)


def _gen_report(state: DateReportAgentState) -> dict:
    """Summarize the query result (or the execution error) for the user's question."""
    logger.info(f"user:{state.get('user_id', '1')} --------------- entering _gen_report node ---------------")
    sql = state['sql']
    question = state['question']
    run_sql_error = state.get('run_sql_error', '')
    data = state.get('data', {})
    template = get_base_template()
    result_summary = template['template']['result_summary']['system']
    prompt = result_summary.format(sql=sql, error_message=run_sql_error, question=question, data=data)
    txt = gen_history_llm.invoke(prompt).text()
    return {'summary': txt}


# Graph: START -> _run_sql -> (router) -> _gen_report -> END, plus either a
# _run_sql self-loop (no feedback) or a _feedback_qa -> _run_sql correction loop.
workflowx = StateGraph(DateReportAgentState)
workflowx.add_node("_run_sql", _run_sql)
if NEED_MODEL_FEEDBACK_QA:
    workflowx.add_node("_feedback_qa", _feedback_qa)
workflowx.add_node("_gen_report", _gen_report)

workflowx.add_edge(START, "_run_sql")
workflowx.add_edge("_gen_report", END)
if NEED_MODEL_FEEDBACK_QA:
    workflowx.add_edge("_feedback_qa", "_run_sql")
    workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, "_gen_report", "_feedback_qa"])
else:
    workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, "_run_sql", "_gen_report"])

memory = MemorySaver()
result_report_agent = workflowx.compile(checkpointer=memory)

# To render the graph as a PNG:
# png_data = result_report_agent.get_graph().draw_mermaid_png()
# with open("D://graph2.png", "wb") as f:
#     f.write(png_data)
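

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical helpers; not part of the original module).
# result_report_agent is compiled with a MemorySaver checkpointer, so LangGraph
# expects each invoke() to carry a thread id under config["configurable"].
# The two helpers below are an assumed convenience layer, and their names are
# illustrative only: make_report_config builds that config, and
# make_initial_report_state seeds the retry bookkeeping fields with neutral
# values. Seeding them explicitly matters when a thread id is reused, because
# the checkpointer would otherwise carry retry_count and retry_sql over from
# the previous run on that thread.
# ---------------------------------------------------------------------------


def make_report_config(thread_id: str) -> dict:
    """Hypothetical helper: per-run config; the checkpointer keys saved state by thread_id."""
    return {"configurable": {"thread_id": thread_id}}


def make_initial_report_state(conversation_id: str, user_id: str,
                              question: str, sql: str) -> dict:
    """Hypothetical helper: a fresh input state for result_report_agent.invoke()."""
    return {
        "id": conversation_id,
        "user_id": user_id,
        "question": question,
        "sql": sql,
        "retry_count": 0,      # _run_sql adds 1 per attempt, success or failure
        "retry_sql": "",       # _feedback_qa fills this with a suggested query
        "sql_correct": False,  # becomes True once the result is accepted
        "run_sql_error": "",   # non-empty after a failed execution
    }


# Example call (needs the real vn and LLM backends, so it is left commented out):
# final_state = result_report_agent.invoke(
#     make_initial_report_state("conv-1", "u-42", "Total sales by month?", "SELECT ..."),
#     config=make_report_config("conv-1"),
# )
# logger.info(final_state.get("summary"))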