184 lines
7.1 KiB
Python
184 lines
7.1 KiB
Python
from langchain_openai import ChatOpenAI
|
||
from typing import TypedDict, Annotated, List, Optional, Union
|
||
import pandas as pd
|
||
from langgraph.types import Command, interrupt
|
||
from langgraph.graph import StateGraph, END, START
|
||
from langgraph.checkpoint.memory import MemorySaver
|
||
import orjson, logging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
from decouple import config
|
||
from template.template import get_base_template
|
||
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
|
||
from datetime import datetime
|
||
from util import train_ddl
|
||
|
||
def _build_chat_llm(model_key: str, base_url_key: str, api_key_key: str) -> ChatOpenAI:
    """Create a ChatOpenAI client whose model, endpoint and API key come from config.

    Args:
        model_key: config/env key holding the model name (default '').
        base_url_key: config/env key holding the OpenAI-compatible base URL (default '').
        api_key_key: config/env key holding the API key (default 'empty').

    Every client built here shares an 8000-token output cap, 2 retries and
    the temperature from SQL_MODEL_TEMPERATURE (default 0.5).
    """
    return ChatOpenAI(
        model=config(model_key, default=''),
        base_url=config(base_url_key, default=''),
        # NOTE(review): the chat client below also reads SQL_MODEL_TEMPERATURE;
        # confirm there is no separate chat-model temperature setting intended.
        temperature=config('SQL_MODEL_TEMPERATURE', default=0.5, cast=float),
        max_tokens=8000,
        max_retries=2,
        api_key=config(api_key_key, default='empty'),
    )


# Client used for question rewriting and chart generation.
gen_history_llm = _build_chat_llm('CHAT_MODEL_NAME', 'CHAT_MODEL_BASE_URL', 'CHAT_MODEL_API_KEY')

# Client used for SQL generation.
gen_sql_llm = _build_chat_llm('SQL_MODEL_NAME', 'SQL_MODEL_BASE_URL', 'SQL_MODEL_API_KEY')
|
||
|
||
'''
State (context) of the agent that generates SQL and charts.
'''
|
||
|
||
|
||
class SqlAgentState(TypedDict):
    """State shared by the nodes of the SQL/chart generation workflow.

    Each node in this file returns a partial dict holding only the keys it updates.
    """

    # Question as asked by the user; required by every node.
    user_question: str

    # Output of _rewrite_user_question; preferred over user_question when set.
    rewritten_user_question: Optional[str]

    # Not read by the nodes in this file; presumably identifies the conversation — confirm.
    session_id: str

    # Only used in log messages by the nodes in this file.
    user_id: str

    # Prior conversation, formatted into the rewrite prompt; element format not visible here.
    history: Optional[list]

    # Parsed JSON from the SQL model; the nodes read its 'sql' and 'char_type' keys.
    gen_sql_result: Optional[dict]

    # Parsed JSON from the chart model; gen_chart_handler reads its 'title' key.
    gen_chart_result: Optional[dict]


    # str() of the exception raised by the latest failed SQL attempt.
    gen_sql_error: Optional[str]

    # str() of the exception raised by the latest failed chart attempt.
    gen_chart_error: Optional[str]

    # Not written or read by any node in this file.
    final_error_msg: Optional[str]


    # Number of SQL generation attempts made so far (incremented by _gen_sql).
    sql_retry_count: int

    # Number of chart generation attempts made so far (incremented by _gen_chart).
    chart_retry_count: int
|
||
|
||
|
||
'''
Rewrite the user's question based on the context and the message history.
'''
|
||
|
||
|
||
def _rewrite_user_question(state: SqlAgentState) -> dict:
    """Rewrite the user's question using the conversation history.

    Formats the ``rewrite_question`` system prompt from the base template with
    the current question and ``history``, sends it to ``gen_history_llm`` and
    reads the ``question`` field of the JSON object found in the reply.

    Rewriting is a best-effort refinement. If the LLM call fails, the reply has
    no parseable JSON, or ``question`` is missing/empty/not a string, the
    original question is kept so that SQL generation can still proceed.

    Args:
        state: graph state; ``user_question`` is required, ``history`` optional.

    Returns:
        Partial state update ``{"rewritten_user_question": str}``.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _rewrite_user_question 节点")
    user_question = state['user_question']
    rewrite_temp = get_base_template()['template']['rewrite_question']
    # `history` is Optional in the state; avoid rendering "None" into the prompt.
    history = state.get('history') or []
    sys_prompt = rewrite_temp['system'].format(current_question=user_question, history=history)

    try:
        new_question = gen_history_llm.invoke(sys_prompt).text()
        logger.info("new_question:{0}".format(new_question))
        result = extract_nested_json(new_question)
        logger.info(f"result:{result}")
        parsed = orjson.loads(result)
    except Exception:
        # Do not abort the whole graph because the optional rewrite step failed.
        logger.exception("rewrite question failed, falling back to the original question")
        return {"rewritten_user_question": user_question}

    rewritten = parsed.get('question') if isinstance(parsed, dict) else None
    # Guard against a null/empty/non-string "question" in the model output.
    if not isinstance(rewritten, str) or not rewritten.strip():
        rewritten = user_question
    return {"rewritten_user_question": rewritten}
|
||
|
||
|
||
def _gen_sql(state: SqlAgentState) -> dict:
    """Generate SQL for the user's (rewritten) question.

    Looks up similar question/SQL pairs (at most three are kept) and related
    DDL through the vanna service. It then formats the ``sql`` prompt template,
    asks ``gen_sql_llm`` for a JSON answer and parses it into a dict.

    Retrieval and template loading happen outside the ``try``, so their
    exceptions propagate to the caller. Failures in prompt formatting, the LLM
    call or JSON parsing are reported through ``gen_sql_error`` instead.

    Returns:
        Success: ``{'gen_sql_result': dict, 'gen_sql_error': None,
        'sql_retry_count': n}``. The error is cleared explicitly. Otherwise a
        failure from an earlier attempt would stay in the state, and
        ``gen_sql_handler`` would retry or stop despite this success.

        Failure: ``{'gen_sql_error': str, 'sql_retry_count': n}``.

        ``n`` is the number of attempts made so far, including this one.
    """
    # Imported here rather than at module level, presumably to avoid a
    # circular import with main_service -- confirm.
    from main_service import vn

    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_sql 节点")
    # The rewritten question may be present but None/empty.
    question = state.get('rewritten_user_question') or state['user_question']
    attempt = (state.get("sql_retry_count") or 0) + 1

    question_sql_list = vn.get_similar_question_sql(question)
    if question_sql_list:
        question_sql_list = question_sql_list[:3]  # keep at most 3 few-shot examples
    ddl_list = vn.get_related_ddl(question)
    sql_temp = get_base_template()['template']['sql']
    history = []  # conversation history is not passed to the SQL prompt

    try:
        sys_temp = sql_temp['system'].format(
            engine=config("DB_ENGINE", default='mysql'), lang='中文',
            schema=ddl_list, documentation=[train_ddl.train_document],
            history=history, retrieved_examples_data=question_sql_list,
            data_training=question_sql_list,
        )
        user_temp = sql_temp['user'].format(
            question=question,
            current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        )
        rr = gen_sql_llm.invoke(
            [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}]
        ).text()
        result = extract_nested_json(rr)
        logger.info(f"gensql result: {result}")
        result = orjson.loads(result)
        # gen_sql_handler / _gen_chart call .get() on the result.
        if not isinstance(result, dict):
            raise ValueError(f"SQL model returned {type(result).__name__}, expected a JSON object")
    except Exception as e:
        # logger.exception records the traceback in the log instead of stderr.
        logger.exception("cus_vanna_srevice failed-------------------: ")
        return {'gen_sql_error': str(e), "sql_retry_count": attempt}

    return {'gen_sql_result': result, 'gen_sql_error': None, "sql_retry_count": attempt}
|
||
|
||
|
||
def _gen_chart(state: SqlAgentState) -> dict:
    """Generate a chart specification for the SQL produced by ``_gen_sql``.

    Formats the ``chart`` prompt template with the generated SQL, the chart
    type suggested by the SQL step and the question. It then asks
    ``gen_history_llm`` for a JSON answer and parses it into a dict.

    Template loading happens outside the ``try`` (its exceptions propagate).
    Failures in prompt formatting, the LLM call or JSON parsing are reported
    through ``gen_chart_error``.

    Returns:
        Success: ``{'chart_retry_count': n, 'gen_chart_result': dict,
        'gen_chart_error': None}``. The error is cleared so that a failure
        from an earlier attempt does not make ``gen_chart_handler`` retry or
        stop after this success.

        Failure: ``{'gen_chart_error': str, 'chart_retry_count': n}``.

        ``n`` is the number of attempts made so far, including this one.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_chart 节点")
    template = get_base_template()

    # gen_sql_result is Optional in the state and may be present as None.
    gen_sql_result = state.get('gen_sql_result')
    if not isinstance(gen_sql_result, dict):
        gen_sql_result = {}
    sql = gen_sql_result.get('sql') or ''
    # NOTE(review): the SQL step's JSON is read under 'char_type'; 'chart_type'
    # is accepted as a fallback in case the prompt emits that spelling -- confirm.
    chart_type = gen_sql_result.get('char_type') or gen_sql_result.get('chart_type') or ''
    # The rewritten question may be present but None/empty.
    question = state.get('rewritten_user_question') or state['user_question']
    chart_temp = template['template']['chart']
    attempt = (state.get("chart_retry_count") or 0) + 1

    try:
        sys_chart_prompt = chart_temp['system'].format(
            engine=config("DB_ENGINE", default='mysql'),
            lang='中文', sql=sql, chart_type=chart_type,
        )
        user_chart_prompt = chart_temp['user'].format(sql=sql, chart_type=chart_type, question=question)
        rr = gen_history_llm.invoke(
            [{'role': 'system', 'content': sys_chart_prompt}, {'role': 'user', 'content': user_chart_prompt}]
        ).text()
        chart = orjson.loads(extract_nested_json(rr))
        # gen_chart_handler calls .get('title') on the result.
        if not isinstance(chart, dict):
            raise ValueError(f"chart model returned {type(chart).__name__}, expected a JSON object")
    except Exception as e:
        logger.exception("gen chart failed")
        return {'gen_chart_error': str(e), "chart_retry_count": attempt}

    return {"chart_retry_count": attempt, 'gen_chart_result': chart, 'gen_chart_error': None}
|
||
|
||
|
||
# Maximum number of SQL generation attempts, the first one included (i.e. at
# most one retry). Compared against `sql_retry_count`, which _gen_sql
# increments on every attempt.
_MAX_SQL_ATTEMPTS = 2


def gen_sql_handler(state: SqlAgentState) -> str:
    """Route the graph after ``_gen_sql``.

    Goes to ``'_gen_chart'`` when there is no recorded SQL error and the
    result holds a non-blank ``sql`` string. Otherwise it returns to
    ``'_gen_sql'`` while fewer than ``_MAX_SQL_ATTEMPTS`` attempts have been
    made, and after that ends the graph (``END``). Any error is left in
    ``gen_sql_error`` for the caller; ``final_error_msg`` is not set here.

    Returns:
        The name of the next node, or ``END``.
    """
    sql_error = state.get('gen_sql_error')
    # The result is Optional in the state and comes from LLM JSON, so it may
    # be None or not an object at all.
    sql_result = state.get('gen_sql_result')
    if not isinstance(sql_result, dict):
        sql_result = {}
    sql = sql_result.get('sql')
    attempts = state.get('sql_retry_count') or 0

    if not sql_error and isinstance(sql, str) and sql.strip():
        return '_gen_chart'
    # Error or empty SQL: retry until the attempt budget is spent.
    return '_gen_sql' if attempts < _MAX_SQL_ATTEMPTS else END
|
||
|
||
|
||
# Maximum number of chart generation attempts, the first one included (i.e. at
# most one retry). Compared against `chart_retry_count`, which _gen_chart
# increments on every attempt.
_MAX_CHART_ATTEMPTS = 2


def gen_chart_handler(state: SqlAgentState) -> str:
    """Route the graph after ``_gen_chart``.

    Ends the graph (``END``) when there is no recorded chart error and the
    result has a non-blank ``title``. Otherwise it returns to ``'_gen_chart'``
    while fewer than ``_MAX_CHART_ATTEMPTS`` attempts have been made, and after
    that ends the graph with the error (if any) left in ``gen_chart_error``.

    Returns:
        ``'_gen_chart'`` or ``END``.
    """
    chart_error = state.get('gen_chart_error')
    # The result is Optional in the state and comes from LLM JSON, so it may
    # be None or not an object at all.
    chart_result = state.get('gen_chart_result')
    if not isinstance(chart_result, dict):
        chart_result = {}
    title = chart_result.get('title')
    attempts = state.get('chart_retry_count') or 0

    if not chart_error and isinstance(title, str) and title.strip():
        return END
    # Error or missing title: retry until the attempt budget is spent.
    return '_gen_chart' if attempts < _MAX_CHART_ATTEMPTS else END
|
||
|
||
|
||
# Workflow topology:
#   START -> _rewrite_user_question -> _gen_sql
#   _gen_sql   --gen_sql_handler-->   _gen_sql | _gen_chart | END
#   _gen_chart --gen_chart_handler--> _gen_chart | END
workflow = StateGraph(SqlAgentState)

for _node_name, _node_fn in (
    ("_rewrite_user_question", _rewrite_user_question),
    ("_gen_sql", _gen_sql),
    ("_gen_chart", _gen_chart),
):
    workflow.add_node(_node_name, _node_fn)
del _node_name, _node_fn

workflow.add_edge(START, "_rewrite_user_question")
workflow.add_edge("_rewrite_user_question", "_gen_sql")
# The lists name every node each router may return, so the graph knows the possible targets.
workflow.add_conditional_edges("_gen_sql", gen_sql_handler, ["_gen_sql", "_gen_chart", END])
workflow.add_conditional_edges("_gen_chart", gen_chart_handler, ["_gen_chart", END])

# In-process checkpointer.
# NOTE(review): with a checkpointer, a later invocation on the same thread
# presumably resumes from the saved state, so retry counters and errors may
# carry over unless callers reset them -- confirm.
memory = MemorySaver()
sql_chart_agent = workflow.compile(checkpointer=memory)

# Debug aid: sql_chart_agent.get_graph().draw_mermaid_png() returns PNG bytes of the graph.