# Source-viewer header captured with this file: 148 lines, 5.3 KiB, Python.
import json
|
||
import logging
|
||
from datetime import datetime
|
||
from typing import TypedDict, Optional
|
||
|
||
import orjson
|
||
from anyio import current_time
|
||
from langgraph.checkpoint.memory import MemorySaver
|
||
from langgraph.graph import StateGraph, END, START
|
||
|
||
from graph_chat.gen_sql_chart_agent import gen_history_llm
|
||
from util.utils import extract_nested_json
|
||
|
||
from service.conversation_service import update_conversation
|
||
# Module-level logger used by the graph nodes and the router in this file.
logger = logging.getLogger(__name__)
|
||
from template.template import get_base_template
|
||
|
||
|
||
class DateReportAgentState(TypedDict):
    """Shared LangGraph state for the SQL-result report workflow.

    Nodes in this module read keys with ``state[...]`` / ``state.get(...)`` and
    return partial dicts that LangGraph merges back into this state.
    """

    # Conversation record id; passed to ``update_conversation`` when the SQL is replaced.
    id: str
    # Requesting user; only used in log messages in this module.
    user_id: str
    # SQL statement to execute.
    sql: str
    # Rows from ``df.to_dict(orient='records')`` in ``_run_sql``: a list of
    # per-row dicts, not a single dict.
    data: Optional[list]
    # Natural-language summary produced by ``_gen_report``.
    summary: str
    # User-facing error text set by ``_run_sql`` when execution fails.
    run_sql_error: str
    # Number of ``_run_sql`` executions so far; ``run_sql_hande`` stops retrying at 2.
    retry_count: int
    # The user's original question.
    question: str
    # Replacement SQL suggested by ``_feedback_qa``; empty/missing when none.
    retry_sql: str
    # Set to True by ``_feedback_qa`` when the LLM judges the result correct.
    sql_correct: bool
||
|
||
|
||
|
||
def _run_sql(state: DateReportAgentState) -> dict:
    """Execute the current SQL for the request and store the resulting rows.

    When ``_feedback_qa`` has proposed a replacement (``retry_sql``) and the
    result has not yet been judged correct, the replacement is executed instead
    of ``state['sql']`` and persisted with ``update_conversation``.

    Args:
        state: Current graph state; ``sql`` is required, the other keys are
            read with defaults.

    Returns:
        A partial state update. On success: ``data`` (list of row dicts),
        ``sql`` (the statement that was actually executed, so later nodes and
        re-runs see the SQL that produced ``data``), an incremented
        ``retry_count``, ``retry_sql`` as found in the state, and an empty
        ``run_sql_error`` (clearing any error from an earlier attempt).
        On failure: the incremented ``retry_count`` and a user-facing
        ``run_sql_error`` message.
    """
    # Imported inside the function, presumably to avoid a circular import with
    # main_service -- NOTE(review): confirm.
    from main_service import vn

    original_sql = state['sql']
    retry_sql = state.get('retry_sql', '')
    sql_correct = state.get('sql_correct', False)
    # Every attempt (success or failure) counts toward the budget checked in run_sql_hande.
    retry = state.get("retry_count", 0) + 1

    logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")

    sql = original_sql
    try:
        if retry_sql and not sql_correct:
            sql = retry_sql
            update_conversation(id=state.get('id'), sql=sql)

        df = vn.run_sql_2(sql)
        result = df.to_dict(orient='records')
        logger.info(f"run sql result {result}")
        logger.info(f"old_sql:{original_sql} retry sql {retry_sql}")

        return {
            'data': result,
            # Keep `sql` in step with `data`. Otherwise the report summarises the
            # wrong statement, and once sql_correct is set the next run would
            # fall back to the original SQL.
            'sql': sql,
            'retry_count': retry,
            'retry_sql': retry_sql,
            'run_sql_error': '',
        }
    except Exception as e:
        logger.exception(f"Error running sql: {sql}, error: {e}")
        return {'retry_count': retry, 'run_sql_error': '运行sql语句失败,请检查语法或联系管理员。'}
|
||
|
||
|
||
def _feedback_qa(state: DateReportAgentState) -> dict:
    """Have the LLM review the SQL result and decide whether to retry.

    Fills the ``result_feedback`` system template with the user's question,
    the SQL that produced ``data`` and the rows themselves. It then sends the
    prompt to ``gen_history_llm`` and parses the JSON object in the reply.

    Args:
        state: Current graph state; ``question``, ``sql``/``retry_sql`` and
            ``data`` are read.

    Returns:
        - ``{"retry_sql": <suggested sql>}`` when the model reports the result
          as incorrect and proposes replacement SQL.
        - ``{"sql_correct": True}`` when the model reports the result as correct.
        - ``{}`` (no state change) otherwise, including when the LLM call or
          JSON extraction fails.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _feedback_qa 节点")

    try:
        user_question = state['question']
        # While the result is not yet judged correct, _run_sql executes a pending
        # retry_sql instead of `sql`, and this node is only routed to in that
        # situation. Review the SQL that actually produced `data`.
        sql = state.get('retry_sql') or state['sql']
        sql_result = state['data']

        template = get_base_template()
        feedback_temp = template['template']['result_feedback']
        logger.info(f"feedback_temp is {feedback_temp}")

        sys_prompt = feedback_temp['system'].format(
            question=user_question,
            sql=sql,
            sql_result=sql_result,
            current_time=datetime.now(),
        )
        logger.info(f"system_temp is {sys_prompt}")

        raw_reply = gen_history_llm.invoke(sys_prompt).text()
        logger.info(f"feedback result: {raw_reply}")

        result = orjson.loads(extract_nested_json(raw_reply))
        logger.info("提取json成功")
        logger.info(f"result is {result} type:{type(result)}")
    except Exception as e:
        logger.exception(f"Error feedback: {e}")
        return {}

    # The reply must be a JSON object; anything else carries no usable verdict.
    if not isinstance(result, dict):
        logger.error(f"Error feedback: expected a JSON object, got {type(result)}")
        return {}
    logger.info("result is {0}".format(result.keys()))

    is_correct = result.get("is_result_correct")
    suggested_sql = result.get("suggested_sql")

    if not is_correct and suggested_sql:
        logger.info("开始替换")
        logger.info(f"suggested_sql is {suggested_sql}")
        logger.info(f"current state: {state}")
        return {"retry_sql": suggested_sql}

    if is_correct:
        return {"sql_correct": True}

    # Judged incorrect but no alternative offered: leave the state unchanged.
    return {}
|
||
|
||
|
||
|
||
def run_sql_hande(state: DateReportAgentState) -> str:
    """Choose the next node after ``_run_sql``.

    While fewer than 2 executions have happened, go to ``_gen_report`` if the
    result has been judged correct, otherwise to ``_feedback_qa`` for review.
    Once that budget is spent, report whatever rows exist, or end the graph
    when there are none.

    Returns:
        One of ``'_gen_report'``, ``'_feedback_qa'`` or ``END`` -- the targets
        registered for this router with ``add_conditional_edges``.
    """
    data = state.get('data', {})
    sql_correct = state.get('sql_correct', False)
    sql_retry_count = state.get('retry_count', 0)

    logger.info("sql_retry_count is {0}".format(sql_retry_count))

    if sql_retry_count < 2:
        return '_gen_report' if sql_correct else '_feedback_qa'

    if data:
        return '_gen_report'

    # Must be returned explicitly. A bare `END` expression makes the router
    # return None, which is not one of the registered targets.
    return END
|
||
|
||
|
||
def _gen_report(state: DateReportAgentState) -> dict:
    """Summarise the query outcome with the LLM.

    Fills the ``result_summary`` system template with the SQL, the user's
    question, the execution error (empty string if none) and the fetched rows.
    Returns the model's text under ``summary``.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_report 节点")

    prompt_fields = {
        'sql': state['sql'],
        'question': state['question'],
        'error_message': state.get('run_sql_error', ''),
        'data': state.get('data', {}),
    }
    summary_template = get_base_template()['template']['result_summary']['system']
    summary_prompt = summary_template.format(**prompt_fields)

    return {'summary': gen_history_llm.invoke(summary_prompt).text()}
|
||
|
||
|
||
# Report workflow: run the SQL, optionally let the LLM review the result and
# propose replacement SQL (looping back to _run_sql), then summarise.
workflowx = StateGraph(DateReportAgentState)

workflowx.add_node("_run_sql", _run_sql)
workflowx.add_node("_feedback_qa", _feedback_qa)
workflowx.add_node("_gen_report", _gen_report)

workflowx.add_edge(START, "_run_sql")
workflowx.add_edge("_feedback_qa", "_run_sql")
workflowx.add_edge("_gen_report", END)
# run_sql_hande must return one of these targets.
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report', '_feedback_qa'])

memory = MemorySaver()

result_report_agent = workflowx.compile(checkpointer=memory)


if __name__ == "__main__":
    # Debug helper: render the graph diagram. This used to run on every import.
    # The output path is a machine-specific Windows path, and draw_mermaid_png()
    # presumably contacts a remote Mermaid renderer by default (NOTE(review):
    # confirm). Either could make importing this module fail.
    png_data = result_report_agent.get_graph().draw_mermaid_png()
    with open("D://graph2.png", "wb") as f:
        f.write(png_data)