Files
sqlbot_agent/graph_chat/gen_sql_chart_agent.py
2025-11-06 16:23:51 +08:00

182 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain_openai import ChatOpenAI
from typing import TypedDict, Annotated, List, Optional, Union
import pandas as pd
from langgraph.types import Command, interrupt
from langgraph.graph import StateGraph, END, START
from langgraph.checkpoint.memory import MemorySaver
import orjson, logging
logger = logging.getLogger(__name__)
from decouple import config
from template.template import get_base_template
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
from datetime import datetime
from util import train_ddl
def _build_llm(model_key: str, base_url_key: str, api_key_key: str) -> ChatOpenAI:
    """Create a ChatOpenAI client whose model, endpoint and API key come from config.

    Every client built here shares the same sampling temperature (read from
    ``SQL_MODEL_TEMPERATURE``, default 0.5), an 8000-token output cap and two
    client-side retries.

    Args:
        model_key: config key holding the model name (default '').
        base_url_key: config key holding the OpenAI-compatible base URL (default '').
        api_key_key: config key holding the API key (default 'empty').
    """
    return ChatOpenAI(
        model=config(model_key, default=''),
        base_url=config(base_url_key, default=''),
        # NOTE(review): the chat client also reads SQL_MODEL_TEMPERATURE — confirm intended.
        temperature=config('SQL_MODEL_TEMPERATURE', default=0.5, cast=float),
        max_tokens=8000,
        max_retries=2,
        api_key=config(api_key_key, default='empty'),
    )


# Chat-model client: used for question rewriting and chart generation in this module.
gen_history_llm = _build_llm('CHAT_MODEL_NAME', 'CHAT_MODEL_BASE_URL', 'CHAT_MODEL_API_KEY')
# SQL-model client: used for SQL generation in this module.
gen_sql_llm = _build_llm('SQL_MODEL_NAME', 'SQL_MODEL_BASE_URL', 'SQL_MODEL_API_KEY')
class SqlAgentState(TypedDict):
    """Agent state (context) shared by the SQL-generation and chart-generation nodes."""

    # Question exactly as the user asked it.
    user_question: str
    # Standalone question written by _rewrite_user_question.
    rewritten_user_question: Union[str, None]
    # Not read by the nodes in this module.
    session_id: str
    # Only used in log messages by the nodes in this module.
    user_id: str
    # Prior conversation, formatted into the question-rewrite prompt.
    history: Union[list, None]
    # Parsed JSON from the SQL model (nodes read its 'sql' and 'char_type' keys).
    gen_sql_result: Union[dict, None]
    # Parsed JSON from the chart model (the router checks its 'title' key).
    gen_chart_result: Union[dict, None]
    # str() of the exception raised by the last failed SQL-generation attempt.
    gen_sql_error: Union[str, None]
    # str() of the exception raised by the last failed chart-generation attempt.
    gen_chart_error: Union[str, None]
    # Not written by the nodes in this module.
    final_error_msg: Union[str, None]
    # Number of SQL-generation attempts made so far.
    sql_retry_count: int
    # Number of chart-generation attempts made so far.
    chart_retry_count: int
def _rewrite_user_question(state: SqlAgentState) -> dict:
    """Rewrite the user question based on the context and message history.

    Formats the ``rewrite_question`` system template with the current question
    and ``history``, sends it to ``gen_history_llm`` and expects a JSON object
    in the reply whose ``question`` field holds the rewritten question.

    Returns:
        ``{"rewritten_user_question": <str>}``. Falls back to the original
        ``user_question`` when the model call fails, the reply contains no
        parseable JSON object, or ``question`` is missing / not a non-empty string.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _rewrite_user_question 节点")
    user_question = state['user_question']
    rewrite_temp = get_base_template()['template']['rewrite_question']
    # `history` is Optional in the state schema: treat an explicit None as "no history".
    history = state.get('history') or []
    sys_prompt = rewrite_temp['system'].format(current_question=user_question, history=history)
    try:
        new_question = gen_history_llm.invoke(sys_prompt).text
        logger.info("new_question:{0}".format(new_question))
        result = extract_nested_json(new_question)
        logger.info(f"result:{result}")
        parsed = orjson.loads(result)
    except Exception:
        # Rewriting is only a refinement; the downstream nodes can work with the
        # original question, so degrade instead of failing the whole graph.
        logger.exception("rewrite question failed, falling back to the original question")
        return {"rewritten_user_question": user_question}
    rewritten = parsed.get('question') if isinstance(parsed, dict) else None
    if not isinstance(rewritten, str) or not rewritten.strip():
        rewritten = user_question
    return {"rewritten_user_question": rewritten}
def _gen_sql(state: SqlAgentState) -> dict:
    """Generate a SQL answer for the (rewritten) user question.

    Retrieves up to 3 similar question/SQL examples and the related DDL from the
    vanna service ``vn``, fills the ``sql`` prompt template, calls
    ``gen_sql_llm`` and parses the JSON object found in its reply.

    Returns:
        On success: ``{'gen_sql_result': <dict>, 'gen_sql_error': None,
        'sql_retry_count': n}``. The error is cleared explicitly so that an error
        left over from an earlier failed attempt does not make ``gen_sql_handler``
        discard a valid result.
        On any failure (retrieval, prompt formatting, model call, JSON parsing,
        non-object JSON): ``{'gen_sql_error': <message>, 'sql_retry_count': n}``.
        ``sql_retry_count`` is incremented on every attempt.
    """
    # Imported lazily; presumably to avoid a circular import with main_service — TODO confirm.
    from main_service import vn
    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_sql 节点")
    # Fall back to the raw question also when the rewritten one is present but None/empty.
    question = state.get('rewritten_user_question') or state['user_question']
    retry = (state.get("sql_retry_count") or 0) + 1
    sql_temp = get_base_template()['template']['sql']
    # This node always sends an empty history to the SQL prompt.
    history = []
    try:
        question_sql_list = vn.get_similar_question_sql(question)
        # Keep at most 3 few-shot examples to bound the prompt size.
        if question_sql_list:
            question_sql_list = question_sql_list[:3]
        ddl_list = vn.get_related_ddl(question)
        sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
                                             schema=ddl_list, documentation=[train_ddl.train_document],
                                             history=history, retrieved_examples_data=question_sql_list,
                                             data_training=question_sql_list, )
        user_temp = sql_temp['user'].format(question=question,
                                            current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        rr = gen_sql_llm.invoke(
            [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}]).text
        result = extract_nested_json(rr)
        logger.info(f"gensql result: {result}")
        result = orjson.loads(result)
        if not isinstance(result, dict):
            # The router reads result['sql']; anything but an object is a failed attempt.
            raise ValueError(f"expected a JSON object from the SQL model, got {type(result).__name__}")
        return {'gen_sql_result': result, 'gen_sql_error': None, "sql_retry_count": retry}
    except Exception as e:
        logger.exception("cus_vanna_srevice failed-------------------: ")
        return {'gen_sql_error': str(e), "sql_retry_count": retry}
def _gen_chart(state: SqlAgentState) -> dict:
    """Generate a chart specification for the SQL produced by ``_gen_sql``.

    Fills the ``chart`` prompt template with the generated SQL, its chart type
    and the (rewritten) question, calls ``gen_history_llm`` and parses the JSON
    object found in its reply.

    Returns:
        On success: ``{'chart_retry_count': n, 'gen_chart_result': <dict>,
        'gen_chart_error': None}``. The error is cleared so a leftover error from
        an earlier failed attempt does not trigger another retry.
        On any failure: ``{'gen_chart_error': <message>, 'chart_retry_count': n}``.
        ``chart_retry_count`` is incremented on every attempt.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_chart 节点")
    # `gen_sql_result` may be present but None; don't crash outside the error handling.
    gen_sql_result = state.get('gen_sql_result') or {}
    sql = gen_sql_result.get('sql', '')
    # NOTE(review): reads key 'char_type' (not 'chart_type'); confirm this matches the SQL prompt's output.
    char_type = gen_sql_result.get('char_type', '')
    question = state.get('rewritten_user_question') or state['user_question']
    char_temp = get_base_template()['template']['chart']
    retry = (state.get("chart_retry_count") or 0) + 1
    try:
        sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'),
                                                   lang='中文', sql=sql, chart_type=char_type)
        user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
        rr = gen_history_llm.invoke(
            [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]).text
        result = orjson.loads(extract_nested_json(rr))
        if not isinstance(result, dict):
            # The router reads result['title']; anything but an object is a failed attempt.
            raise ValueError(f"expected a JSON object from the chart model, got {type(result).__name__}")
        return {"chart_retry_count": retry, 'gen_chart_result': result, 'gen_chart_error': None}
    except Exception as e:
        logger.exception("gen chart failed")
        return {'gen_chart_error': str(e), "chart_retry_count": retry}
# Maximum number of SQL-generation attempts, including the first one.
# NOTE(review): the original comment said "retry 2 times", but the code always
# allowed 2 attempts in total (1 retry); that behavior is kept here.
_MAX_SQL_ATTEMPTS = 2


def gen_sql_handler(state: SqlAgentState) -> str:
    """Route after ``_gen_sql``: continue to chart generation, retry, or stop.

    - An error was recorded: retry ``_gen_sql`` while fewer than
      ``_MAX_SQL_ATTEMPTS`` attempts have been made, otherwise ``END``.
    - The result holds a non-blank ``sql`` string: go to ``_gen_chart``.
    - Anything else (missing / null / blank ``sql``, non-dict result): retry or
      ``END`` under the same attempt limit.
    """
    attempts = state.get('sql_retry_count') or 0
    retry_or_end = '_gen_sql' if attempts < _MAX_SQL_ATTEMPTS else END
    if state.get('gen_sql_error'):
        return retry_or_end
    sql_result = state.get('gen_sql_result')
    sql = sql_result.get('sql') if isinstance(sql_result, dict) else None
    # Guard against {"sql": null} / non-string values, which used to make len() raise.
    if isinstance(sql, str) and sql.strip():
        return '_gen_chart'
    return retry_or_end
# Maximum number of chart-generation attempts, including the first one.
_MAX_CHART_ATTEMPTS = 2


def gen_chart_handler(state: SqlAgentState) -> str:
    """Route after ``_gen_chart``: finish, or retry chart generation.

    - An error was recorded: retry ``_gen_chart`` while fewer than
      ``_MAX_CHART_ATTEMPTS`` attempts have been made, otherwise ``END``.
    - The result holds a non-blank ``title`` string: ``END`` (success).
    - Anything else (missing / null / blank ``title``, non-dict result): retry
      or ``END`` under the same attempt limit.
    """
    attempts = state.get('chart_retry_count') or 0
    retry_or_end = '_gen_chart' if attempts < _MAX_CHART_ATTEMPTS else END
    if state.get('gen_chart_error'):
        return retry_or_end
    chart_result = state.get('gen_chart_result')
    title = chart_result.get('title') if isinstance(chart_result, dict) else None
    # Guard against {"title": null} / non-string values, which used to make len() raise.
    if isinstance(title, str) and title.strip():
        return END
    return retry_or_end
# Graph wiring:
#   START -> _rewrite_user_question -> _gen_sql
#   _gen_sql   --gen_sql_handler-->   _gen_sql (retry) | _gen_chart | END
#   _gen_chart --gen_chart_handler--> _gen_chart (retry) | END
workflow = StateGraph(SqlAgentState)
workflow.add_node("_rewrite_user_question", _rewrite_user_question)
workflow.add_node("_gen_sql", _gen_sql)
workflow.add_node("_gen_chart", _gen_chart)
workflow.add_edge(START, "_rewrite_user_question")
workflow.add_edge("_rewrite_user_question", "_gen_sql")
workflow.add_conditional_edges('_gen_sql', gen_sql_handler, ['_gen_sql', '_gen_chart', END])
workflow.add_conditional_edges('_gen_chart', gen_chart_handler, ['_gen_chart', END])
# In-process checkpointer.
# NOTE(review): state checkpointed per thread_id (retry counters, previous
# results and errors) is likely carried into later invocations that reuse the
# same thread_id unless callers reset those keys — confirm callers do.
memory = MemorySaver()
sql_chart_agent = workflow.compile(checkpointer=memory)


def render_graph_png(output_path: Optional[str] = None) -> bytes:
    """Render the compiled agent graph to PNG bytes via Mermaid (debug helper).

    NOTE(review): with default arguments, LangChain's ``draw_mermaid_png``
    renders through a remote Mermaid service, so this needs network access.

    Args:
        output_path: if given, the PNG bytes are also written to this file.

    Returns:
        The PNG image bytes.
    """
    png = sql_chart_agent.get_graph().draw_mermaid_png()
    if output_path:
        with open(output_path, "wb") as fh:
            fh.write(png)
    return png


# The graph used to be rendered unconditionally at import time. That costs a
# network round-trip and makes the module unimportable when rendering fails,
# and the result was never used in this file. Rendering is now opt-in;
# `png_data` is kept (None by default) for anything that imports the name.
png_data = render_graph_png() if config('RENDER_AGENT_GRAPH', default=False, cast=bool) else None