"""LangGraph workflow that rewrites a user question, generates SQL and a chart spec.

The graph has three nodes executed in order:
``_rewrite_user_question`` -> ``_gen_sql`` -> ``_gen_chart``.
The two generation nodes are retried by their conditional-edge handlers until
they succeed or ``_MAX_ATTEMPTS`` attempts have been made.
"""
import logging
from datetime import datetime
from typing import TypedDict, Annotated, List, Optional, Union

import orjson
import pandas as pd
from decouple import config
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END, START
from langgraph.types import Command, interrupt

# Local imports are kept in their original relative order.
from template.template import get_base_template
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
from util import train_ddl

logger = logging.getLogger(__name__)

# Maximum number of attempts (first try included) for each generation node.
_MAX_ATTEMPTS = 2

# LLM used for question rewriting and chart generation.
# NOTE(review): temperature is read from SQL_MODEL_TEMPERATURE, the same key
# the SQL model uses -- confirm whether a chat-specific setting was intended.
gen_history_llm = ChatOpenAI(
    model=config('CHAT_MODEL_NAME', default=''),
    base_url=config('CHAT_MODEL_BASE_URL', default=''),
    temperature=config('SQL_MODEL_TEMPERATURE', default=0.5, cast=float),
    max_tokens=8000,
    max_retries=2,
    api_key=config('CHAT_MODEL_API_KEY', default='empty'),
)

# LLM used for SQL generation.
gen_sql_llm = ChatOpenAI(
    model=config('SQL_MODEL_NAME', default=''),
    base_url=config('SQL_MODEL_BASE_URL', default=''),
    temperature=config('SQL_MODEL_TEMPERATURE', default=0.5, cast=float),
    max_tokens=8000,
    max_retries=2,
    api_key=config('SQL_MODEL_API_KEY', default='empty'),
)


class SqlAgentState(TypedDict):
    """State shared by the nodes of the SQL / chart generation agent."""

    user_question: str  # question as typed by the user
    rewritten_user_question: Optional[str]  # output of _rewrite_user_question
    session_id: str
    user_id: str
    history: Optional[list]  # earlier messages, passed to the rewrite prompt
    gen_sql_result: Optional[dict]  # parsed JSON returned by the SQL model
    gen_chart_result: Optional[dict]  # parsed JSON returned by the chart model
    gen_sql_error: Optional[str]  # error text of the last failed SQL attempt
    gen_chart_error: Optional[str]  # error text of the last failed chart attempt
    final_error_msg: Optional[str]  # not written by any node in this module
    sql_retry_count: int  # number of SQL attempts made so far
    chart_retry_count: int  # number of chart attempts made so far


def _current_question(state: SqlAgentState) -> str:
    """Return the rewritten question, or the original one if no rewrite exists."""
    # ``or`` also covers the key being present with a None / empty value,
    # which ``dict.get(key, default)`` would pass through unchanged.
    return state.get('rewritten_user_question') or state['user_question']


def _rewrite_user_question(state: SqlAgentState) -> dict:
    """Rewrite the user question using the conversation history.

    Returns:
        A state update with ``rewritten_user_question``. If the LLM call fails
        or its answer is not a JSON object with a non-empty string
        ``question``, the original question is returned so the graph can
        continue instead of crashing.
    """
    logger.info("user:%s 进入 _rewrite_user_question 节点", state.get('user_id', '1'))
    user_question = state['user_question']
    rewrite_temp = get_base_template()['template']['rewrite_question']
    history = state.get('history') or []

    rewritten = None
    try:
        sys_prompt = rewrite_temp['system'].format(current_question=user_question, history=history)
        new_question = gen_history_llm.invoke(sys_prompt).text()
        logger.info("new_question:%s", new_question)
        result = extract_nested_json(new_question)
        logger.info("result:%s", result)
        parsed = orjson.loads(result)
        if isinstance(parsed, dict):
            rewritten = parsed.get('question')
    except Exception:
        # Same best-effort policy as the generation nodes: log and carry on.
        logger.exception("rewrite question failed, falling back to the original question")

    if not (isinstance(rewritten, str) and rewritten.strip()):
        rewritten = user_question
    return {"rewritten_user_question": rewritten}


def _gen_sql(state: SqlAgentState) -> dict:
    """Generate SQL for the current question.

    Returns:
        On success ``gen_sql_result`` (parsed JSON) with ``gen_sql_error``
        cleared; on failure ``gen_sql_error``. ``sql_retry_count`` is
        incremented in both cases.
    """
    # Function-scope import kept from the original (presumably to avoid a
    # circular import with main_service -- TODO confirm).
    from main_service import vn

    logger.info("user:%s 进入 _gen_sql 节点", state.get('user_id', '1'))
    attempt = (state.get("sql_retry_count") or 0) + 1
    question = _current_question(state)

    question_sql_list = vn.get_similar_question_sql(question)
    if question_sql_list:
        # Keep at most three retrieved examples in the prompt.
        question_sql_list = question_sql_list[:3]
    ddl_list = vn.get_related_ddl(question)
    sql_temp = get_base_template()['template']['sql']
    history = []  # history is not forwarded to the SQL prompt

    try:
        sys_temp = sql_temp['system'].format(
            engine=config("DB_ENGINE", default='mysql'),
            lang='中文',
            schema=ddl_list,
            documentation=[train_ddl.train_document],
            history=history,
            retrieved_examples_data=question_sql_list,
            data_training=question_sql_list,
        )
        user_temp = sql_temp['user'].format(
            question=question,
            current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        )
        rr = gen_sql_llm.invoke([
            {'role': 'system', 'content': sys_temp},
            {'role': 'user', 'content': user_temp},
        ]).text()
        result = extract_nested_json(rr)
        logger.info("gensql result: %s", result)
        result = orjson.loads(result)
    except Exception as e:
        logger.exception("cus_vanna_srevice failed-------------------: ")
        return {'gen_sql_error': str(e), "sql_retry_count": attempt}

    # Clear the error of an earlier failed attempt; otherwise the router would
    # still see it and treat this successful attempt as a failure.
    return {'gen_sql_result': result, 'gen_sql_error': None, "sql_retry_count": attempt}


def _gen_chart(state: SqlAgentState) -> dict:
    """Generate a chart spec from the generated SQL.

    Returns:
        On success ``gen_chart_result`` (parsed JSON) with ``gen_chart_error``
        cleared; on failure ``gen_chart_error``. ``chart_retry_count`` is
        incremented in both cases.
    """
    logger.info("user:%s 进入 _gen_chart 节点", state.get('user_id', '1'))
    attempt = (state.get("chart_retry_count") or 0) + 1

    gen_sql_result = state.get('gen_sql_result')
    if not isinstance(gen_sql_result, dict):
        gen_sql_result = {}
    sql = gen_sql_result.get('sql', '')
    # 'char_type' is the key read from the SQL model output (spelling kept).
    char_type = gen_sql_result.get('char_type', '')
    question = _current_question(state)
    char_temp = get_base_template()['template']['chart']

    try:
        sys_char_temp = char_temp['system'].format(
            engine=config("DB_ENGINE", default='mysql'),
            lang='中文',
            sql=sql,
            chart_type=char_type,
        )
        user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
        rr = gen_history_llm.invoke([
            {'role': 'system', 'content': sys_char_temp},
            {'role': 'user', 'content': user_char_temp},
        ]).text()
        chart = orjson.loads(extract_nested_json(rr))
    except Exception as e:
        logger.exception("gen chart failed")
        return {'gen_chart_error': str(e), "chart_retry_count": attempt}

    # Clear a stale error from an earlier failed attempt (see _gen_sql).
    return {"chart_retry_count": attempt, 'gen_chart_result': chart, 'gen_chart_error': None}


def _route_after_attempt(error, succeeded: bool, attempts: int, retry_node: str, next_node: str) -> str:
    """Pick the next node: advance on success, else retry until the limit, then END."""
    if succeeded and not error:
        return next_node
    return retry_node if attempts < _MAX_ATTEMPTS else END


def _non_empty_field(result, key: str) -> bool:
    """Return True if ``result`` is a dict holding a truthy value under ``key``."""
    return isinstance(result, dict) and bool(result.get(key))


def gen_sql_handler(state: SqlAgentState) -> str:
    """Route after ``_gen_sql``.

    Goes to ``_gen_chart`` when a non-empty ``sql`` was produced without
    error; otherwise re-runs ``_gen_sql`` until ``_MAX_ATTEMPTS`` attempts
    have been made, then ends the graph.
    """
    return _route_after_attempt(
        error=state.get('gen_sql_error'),
        succeeded=_non_empty_field(state.get('gen_sql_result'), 'sql'),
        attempts=state.get('sql_retry_count') or 0,
        retry_node='_gen_sql',
        next_node='_gen_chart',
    )


def gen_chart_handler(state: SqlAgentState) -> str:
    """Route after ``_gen_chart``.

    Ends the graph when a chart with a non-empty ``title`` was produced
    without error; otherwise re-runs ``_gen_chart`` until ``_MAX_ATTEMPTS``
    attempts have been made, then ends the graph.
    """
    return _route_after_attempt(
        error=state.get('gen_chart_error'),
        succeeded=_non_empty_field(state.get('gen_chart_result'), 'title'),
        attempts=state.get('chart_retry_count') or 0,
        retry_node='_gen_chart',
        next_node=END,
    )


workflow = StateGraph(SqlAgentState)
workflow.add_node("_rewrite_user_question", _rewrite_user_question)
workflow.add_node("_gen_sql", _gen_sql)
workflow.add_node("_gen_chart", _gen_chart)
workflow.add_edge(START, "_rewrite_user_question")
workflow.add_edge("_rewrite_user_question", "_gen_sql")
workflow.add_conditional_edges('_gen_sql', gen_sql_handler, ['_gen_sql', '_gen_chart', END])
workflow.add_conditional_edges('_gen_chart', gen_chart_handler, ['_gen_chart', END])

# NOTE(review): with a checkpointer, state saved under a thread id may carry
# retry counters / errors into the next invocation unless the caller supplies
# fresh values in its input -- confirm against the caller.
memory = MemorySaver()
sql_chart_agent = workflow.compile(checkpointer=memory)

# Debug aid: render the compiled graph to a PNG.
# png_data = sql_chart_agent.get_graph().draw_mermaid_png()
# with open("D://graph.png", "wb") as f:
#     f.write(png_data)