148 lines
5.6 KiB
Python
148 lines
5.6 KiB
Python
import json
|
||
import logging
|
||
from datetime import datetime
|
||
from typing import TypedDict, Optional
|
||
|
||
import orjson
|
||
from anyio import current_time
|
||
from langgraph.checkpoint.memory import MemorySaver
|
||
from langgraph.graph import StateGraph, END, START
|
||
|
||
from graph_chat.gen_sql_chart_agent import gen_history_llm
|
||
from util.utils import extract_nested_json
|
||
|
||
from service.conversation_service import update_conversation
|
||
from template.template import get_base_template
from decouple import config

logger = logging.getLogger(__name__)

# Feature flag read through python-decouple's ``config``.
#   True  -> a successful SQL run is not accepted automatically; the router
#            sends it to _feedback_qa, which may approve it or suggest a
#            replacement SQL.
#   False -> a successful SQL run is marked correct immediately.
NEED_MODEL_FEEDBACK_QA = config("NEED_MODEL_FEEDBACK_QA", default=False, cast=bool)
|
||
|
||
class DateReportAgentState(TypedDict):
    """Shared LangGraph state for the SQL-execution / result-report workflow."""

    # Conversation id; passed to update_conversation when a suggested SQL runs.
    id: str
    # Requesting user; only used in log messages in this module.
    user_id: str
    # SQL statement for this request.
    sql: str
    # Rows from the last successful run. _run_sql stores
    # ``df.to_dict(orient='records')``, which is a list of row dicts, not a dict.
    data: Optional[list]
    # LLM-written summary produced by _gen_report.
    summary: str
    # User-facing error message set when executing the SQL fails.
    run_sql_error: str
    # Number of _run_sql attempts so far; the router stops retrying at 3.
    retry_count: int
    # The user's natural-language question (fed into the LLM prompts).
    question: str
    # Replacement SQL suggested by _feedback_qa, run by _run_sql instead of ``sql``.
    retry_sql: str
    # Whether the last result has been accepted.
    sql_correct: bool
|
||
|
||
|
||
|
||
def _run_sql(state: DateReportAgentState) -> dict:
    """Execute the SQL for this request and return a partial state update.

    A replacement SQL proposed by ``_feedback_qa`` (``retry_sql``) takes
    precedence over the original ``sql``. When it is used,
    ``update_conversation(id=..., sql=...)`` is called with it before execution.

    Returns:
        On success: ``data`` (``df.to_dict(orient='records')`` of the result),
        ``sql`` set to the statement that produced it, the incremented
        ``retry_count``, ``retry_sql`` unchanged, ``run_sql_error`` cleared to
        ``''`` and, when model feedback is disabled, ``sql_correct=True``.
        On failure: the incremented ``retry_count`` and a user-facing
        ``run_sql_error`` message; the exception is logged.
    """
    from main_service import vn

    sql = state['sql']
    retry_sql = state.get('retry_sql', '')
    retry = state.get("retry_count", 0) + 1

    logger.info(f'in _run_sql state={state}')
    logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")

    try:
        # Prefer the suggested SQL whenever one exists, including the re-run
        # that follows _feedback_qa approving it (sql_correct=True): the
        # approval refers to *this* SQL's rows, so falling back to the
        # original statement would overwrite the approved result.
        if retry_sql:
            sql = retry_sql
            update_conversation(id=state.get('id'), sql=sql)

        df = vn.run_sql_2(sql)
        result = df.to_dict(orient='records')
        logger.info(f"run sql result {result}")
        logger.info(f"old_sql:{sql}, retry sql {retry_sql}")

        update = {
            'data': result,
            'retry_count': retry,
            'retry_sql': retry_sql,
            # Keep ``sql`` in sync with ``data`` so later nodes (_feedback_qa,
            # _gen_report) describe the statement that actually ran.
            'sql': sql,
            # A successful run supersedes any error left by an earlier attempt.
            'run_sql_error': '',
        }
        # Without model feedback, a successful run counts as correct.
        if not NEED_MODEL_FEEDBACK_QA:
            update['sql_correct'] = True
        return update

    except Exception as e:
        logger.exception(f"Error running sql: {sql}, error: {e}")
        return {'retry_count': retry, 'run_sql_error': '运行sql语句失败,请检查语法或联系管理员。'}
|
||
|
||
|
||
def _feedback_qa(state: DateReportAgentState) -> Optional[dict]:
    """Ask the LLM to review the last SQL run and optionally propose a fix.

    Formats the base template's ``result_feedback`` system prompt with the user
    question, the SQL that was executed, its result rows, related DDL and
    similar question/SQL pairs, then parses the JSON object in the reply.

    Returns:
        ``{"sql_correct": True}`` when the model marks the result correct;
        ``{"retry_sql": <suggested sql>}`` when it marks it incorrect and
        suggests a replacement; ``None`` (no state update) otherwise, including
        when any step fails (the error is logged).
    """
    from main_service import vn

    logger.info(f"user:{state.get('user_id', '1')} 进入 _feedback_qa 节点")

    try:
        user_question = state['question']
        # _run_sql executes ``retry_sql`` in preference to ``sql`` when a
        # replacement exists, so review the statement that produced ``data``.
        sql = state.get('retry_sql') or state['sql']
        sql_result = state['data']

        template = get_base_template()
        feedback_temp = template['template']['result_feedback']
        logger.info(f"feedback_temp is {feedback_temp}")

        ddl_list = vn.get_related_ddl(user_question)
        qa_list = vn.get_similar_question_sql(user_question)
        sys_prompt = feedback_temp['system'].format(
            question=user_question,
            sql=sql,
            sql_result=sql_result,
            current_time=datetime.now(),
            ddl_list=ddl_list,
            qa_list=qa_list,
        )
        logger.info(f"system_temp is {sys_prompt}")

        raw_reply = gen_history_llm.invoke(sys_prompt).text()
        logger.info(f"feedback result: {raw_reply}")
        result = orjson.loads(extract_nested_json(raw_reply))
        logger.info("提取json成功")
        logger.info(f"result is {result} type:{type(result)}")

        if result["is_result_correct"]:
            return {"sql_correct": True}

        # The model may omit "suggested_sql"; treat that like an empty suggestion.
        suggested_sql = result.get("suggested_sql")
        if suggested_sql:
            logger.info("开始替换")
            logger.info(f"suggested_sql is {suggested_sql}")
            logger.info(f"current state: {state}")
            return {"retry_sql": suggested_sql}

    except Exception as e:
        logger.exception(f"Error feedback: {e}")

    # No usable verdict: leave the state unchanged.
    return None
|
||
|
||
|
||
|
||
def run_sql_hande(state: DateReportAgentState) -> str:
    """Choose the next node after ``_run_sql``.

    While fewer than 3 attempts have been made:
      * an accepted result (``sql_correct``) goes to ``_gen_report``;
      * otherwise go to ``_feedback_qa`` when model feedback is enabled,
        or retry ``_run_sql`` directly when it is not.
    Once the retry budget is used up, report whatever rows were obtained
    (``_gen_report``), or finish the graph (``END``) if there are none.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql_hande 节点")
    logger.info(f'in run_sql_handle state={state}')

    data = state.get('data', {})
    sql_correct = state.get('sql_correct', False)
    sql_retry_count = state.get('retry_count', 0)
    logger.info(f"sql_retry_count is {sql_retry_count} sql_correct is {sql_correct}")

    if sql_retry_count < 3:
        if sql_correct:
            return '_gen_report'
        return '_feedback_qa' if NEED_MODEL_FEEDBACK_QA else '_run_sql'

    # Retry budget exhausted: summarise any rows we have, otherwise stop.
    if data:
        return '_gen_report'
    # END must be returned; a bare ``END`` expression falls through and the
    # router returns None, which is not one of this edge's destinations.
    return END
|
||
|
||
|
||
|
||
def _gen_report(state: DateReportAgentState) -> dict:
    """Summarise the query outcome with the LLM.

    Fills the base template's ``result_summary`` system prompt with the SQL,
    the user question, the result rows and any ``run_sql_error``, and returns
    the model's text as the ``summary`` state update.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_report 节点")

    query, user_question = state['sql'], state['question']
    error_message = state.get('run_sql_error', '')
    rows = state.get('data', {})

    summary_template = get_base_template()['template']['result_summary']['system']
    prompt = summary_template.format(
        sql=query,
        error_message=error_message,
        question=user_question,
        data=rows,
    )
    return {'summary': gen_history_llm.invoke(prompt).text()}
|
||
|
||
|
||
# Workflow: START -> _run_sql; run_sql_hande then routes to _gen_report, a
# retry, or END. With NEED_MODEL_FEEDBACK_QA the retry path goes through
# _feedback_qa, which always hands back to _run_sql.
workflowx = StateGraph(DateReportAgentState)

workflowx.add_node("_run_sql", _run_sql)
if NEED_MODEL_FEEDBACK_QA:
    workflowx.add_node("_feedback_qa", _feedback_qa)
workflowx.add_node("_gen_report", _gen_report)

workflowx.add_edge(START, "_run_sql")
workflowx.add_edge("_gen_report", END)

if NEED_MODEL_FEEDBACK_QA:
    workflowx.add_edge("_feedback_qa", "_run_sql")

# Every destination run_sql_hande can return; the direct "_run_sql"
# self-retry only exists when model feedback is disabled.
workflowx.add_conditional_edges(
    "_run_sql",
    run_sql_hande,
    [END, '_gen_report', '_feedback_qa'] if NEED_MODEL_FEEDBACK_QA else [END, "_run_sql", '_gen_report'],
)

# MemorySaver keeps checkpoints in memory only.
memory = MemorySaver()
result_report_agent = workflowx.compile(checkpointer=memory)

# To inspect the graph, result_report_agent.get_graph().draw_mermaid_png()
# returns PNG bytes that can be written to a file opened in "wb" mode.