模型反馈修正
This commit is contained in:
@@ -1,12 +1,17 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import TypedDict, Optional
|
||||
|
||||
import orjson
|
||||
from anyio import current_time
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.graph import StateGraph, END, START
|
||||
|
||||
from graph_chat.gen_sql_chart_agent import gen_history_llm
|
||||
from util.utils import extract_nested_json
|
||||
|
||||
from service.conversation_service import update_conversation
|
||||
logger = logging.getLogger(__name__)
|
||||
from template.template import get_base_template
|
||||
|
||||
@@ -20,41 +25,103 @@ class DateReportAgentState(TypedDict):
|
||||
run_sql_error: str
|
||||
retry_count: int
|
||||
question:str
|
||||
retry_sql:str
|
||||
sql_correct:bool
|
||||
|
||||
|
||||
|
||||
def _run_sql(state: DateReportAgentState) -> dict:
|
||||
from main_service import vn
|
||||
sql = state['sql']
|
||||
retry_sql = state.get('retry_sql', '')
|
||||
sql_correct = state.get('sql_correct', False)
|
||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")
|
||||
try:
|
||||
if retry_sql and not sql_correct:
|
||||
sql = retry_sql
|
||||
update_conversation(id=state.get('id'), sql=sql)
|
||||
df = vn.run_sql_2(sql)
|
||||
result = df.to_dict(orient='records')
|
||||
logger.info(f"run sql result {result}")
|
||||
retry = state.get("retry_count", 0) + 1
|
||||
return {'data': result, 'retry_count': retry}
|
||||
logger.info(f"old_sql:{sql} retry sql {retry_sql}")
|
||||
return {'data': result, 'retry_count': retry, 'retry_sql': retry_sql}
|
||||
except Exception as e:
|
||||
retry = state.get("retry_count", 0) + 1
|
||||
logger.error(f"Error running sql: {sql}, error: {e}")
|
||||
return {'retry_count': retry, 'run_sql_error': '运行sql语句失败,请检查语法或联系管理员。'}
|
||||
|
||||
|
||||
def _feedback_qa(state: DateReportAgentState) -> Optional[dict]:
    """Ask the LLM to review the executed SQL and its result.

    Formats the ``result_feedback`` system prompt with the question, the SQL,
    the result rows and the current time, invokes ``gen_history_llm``, and
    parses the JSON verdict extracted from the reply.

    Args:
        state: Graph state. ``question``, ``sql`` and ``data`` are read with
            ``[]``; a missing key is handled like any other review failure.

    Returns:
        ``{"sql_correct": True}`` when the verdict marks the result correct;
        ``{"retry_sql": <sql>}`` when it marks the result incorrect and
        supplies a non-empty ``suggested_sql``; ``None`` (no state update)
        otherwise, including when any step of the review raises. Errors are
        logged and never propagated.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _feedback_qa 节点")
    try:
        user_question = state['question']
        sql = state['sql']
        sql_result = state['data']

        feedback_temp = get_base_template()['template']['result_feedback']
        logger.info(f"feedback_temp is {feedback_temp}")
        system_prompt = feedback_temp['system'].format(
            question=user_question,
            sql=sql,
            sql_result=sql_result,
            current_time=datetime.now(),
        )
        logger.info(f"system_temp is {system_prompt}")

        raw_reply = gen_history_llm.invoke(system_prompt).text()
        logger.info(f"feedback result: {raw_reply}")
        result = orjson.loads(extract_nested_json(raw_reply))
        logger.info("提取json成功")
        logger.info(f"result is {result} type:{type(result)}")

        # "is_result_correct" is mandatory in the verdict: a reply without it
        # is treated as a failed review (KeyError -> logged below).
        if result["is_result_correct"]:
            return {"sql_correct": True}

        # "suggested_sql" may be absent or null; only a non-empty value is
        # worth a retry.
        suggested_sql = result.get("suggested_sql")
        if suggested_sql:
            logger.info("开始替换")
            # The original f-string had no placeholder, so the SQL was never
            # actually written to the log.
            logger.info(f"suggested_sql is {suggested_sql}")
            logger.info(f"current state: {state}")
            return {"retry_sql": suggested_sql}
    except Exception as e:
        # Best-effort node: a failed review must not abort the graph run.
        logger.exception(f"Error feedback: {e}")
    return None
|
||||
|
||||
|
||||
|
||||
def run_sql_hande(state: DateReportAgentState) -> str:
|
||||
sql_error = state.get('run_sql_error', '')
|
||||
# sql_error = state.get('run_sql_error', '')
|
||||
data = state.get('data', {})
|
||||
sql_correct = state.get('sql_correct', False)
|
||||
sql_retry_count = state.get('retry_count', 0)
|
||||
if sql_error and len(sql_error) > 0:
|
||||
logger.info("sql_retry_count is {0}".format(sql_retry_count))
|
||||
|
||||
if sql_retry_count < 2:
|
||||
return '_run_sql'
|
||||
if sql_correct:
|
||||
return '_gen_report'
|
||||
else:
|
||||
return '_feedback_qa'
|
||||
else:
|
||||
return END
|
||||
if data and len(data) > 0:
|
||||
return '_gen_report'
|
||||
else:
|
||||
if sql_retry_count < 2:
|
||||
return '_run_sql'
|
||||
else:
|
||||
return END
|
||||
END
|
||||
# retry_sql = state.get('retry_sql', False)
|
||||
|
||||
# if sql_error and len(sql_error) > 0:
|
||||
# if sql_retry_count < 2:
|
||||
# return '_feedback_qa'
|
||||
# # return '_run_sql'
|
||||
# else:
|
||||
# return END
|
||||
#
|
||||
# if data and len(data) > 0:
|
||||
# return '_gen_report'
|
||||
# else:
|
||||
# if sql_retry_count < 2:
|
||||
# return '_feedback_qa'
|
||||
# else:
|
||||
# return END
|
||||
|
||||
|
||||
def _gen_report(state: DateReportAgentState) -> dict:
|
||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_report 节点")
|
||||
sql = state['sql']
|
||||
question = state['question']
|
||||
run_sql_error = state.get('run_sql_error', '')
|
||||
@@ -62,15 +129,20 @@ def _gen_report(state: DateReportAgentState) -> dict:
|
||||
template = get_base_template()
|
||||
result_summary = template['template']['result_summary']['system']
|
||||
pro = result_summary.format(sql=sql, error_message=run_sql_error, question=question, data=data)
|
||||
txt = gen_history_llm.invoke(pro).text
|
||||
txt = gen_history_llm.invoke(pro).text()
|
||||
return {'summary': txt}
|
||||
|
||||
|
||||
workflowx = StateGraph(DateReportAgentState)
|
||||
workflowx.add_node("_run_sql", _run_sql)
|
||||
workflowx.add_node("_feedback_qa", _feedback_qa)
|
||||
workflowx.add_node("_gen_report", _gen_report)
|
||||
workflowx.add_edge(START, "_run_sql")
|
||||
workflowx.add_edge("_feedback_qa", "_run_sql")
|
||||
workflowx.add_edge("_gen_report", END)
|
||||
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report'])
|
||||
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report','_feedback_qa'])
|
||||
memory = MemorySaver()
|
||||
result_report_agent = workflowx.compile(checkpointer=memory)
|
||||
# Debug aid: dump the compiled graph as a PNG. This runs at import time, so it
# is kept best-effort: a rendering failure or an unwritable target path is
# logged instead of preventing the module from being imported.
# NOTE(review): draw_mermaid_png() may need network access depending on the
# configured draw method, and the target path is Windows-specific. Consider
# moving this behind a config flag.
try:
    png_data = result_report_agent.get_graph().draw_mermaid_png()
    with open("D://graph2.png", "wb") as f:
        f.write(png_data)
except Exception as e:
    logger.warning(f"skip writing graph png: {e}")
|
||||
@@ -65,7 +65,7 @@ def _rewrite_user_question(state: SqlAgentState) -> dict:
|
||||
rewrite_temp = template['template']['rewrite_question']
|
||||
history = state.get('history', [])
|
||||
sys_promot = rewrite_temp['system'].format(current_question=user_question, history=history)
|
||||
new_question = gen_history_llm.invoke(sys_promot).text
|
||||
new_question = gen_history_llm.invoke(sys_promot).text()
|
||||
logger.info("new_question:{0}".format(new_question))
|
||||
result = extract_nested_json(new_question)
|
||||
logger.info(f"result:{result}")
|
||||
@@ -93,7 +93,7 @@ def _gen_sql(state: SqlAgentState) -> dict:
|
||||
user_temp = sql_temp['user'].format(question=question,
|
||||
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
rr = gen_sql_llm.invoke(
|
||||
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}]).text
|
||||
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}]).text()
|
||||
result = extract_nested_json(rr)
|
||||
logger.info(f"gensql result: {result}")
|
||||
result = orjson.loads(result)
|
||||
@@ -120,7 +120,7 @@ def _gen_chart(state: SqlAgentState) -> dict:
|
||||
lang='中文', sql=sql, chart_type=char_type)
|
||||
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
|
||||
rr = gen_history_llm.invoke(
|
||||
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]).text
|
||||
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]).text()
|
||||
retry = state.get("chart_retry_count", 0) + 1
|
||||
return {"chart_retry_count": retry, 'gen_chart_result': orjson.loads(extract_nested_json(rr))}
|
||||
except Exception as e:
|
||||
|
||||
@@ -7,7 +7,7 @@ import util.utils
|
||||
from logging_config import LOGGING_CONFIG
|
||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
|
||||
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
|
||||
from service.conversation_service import save_conversation,update_conversation,get_sql_by_id
|
||||
from service.conversation_service import save_conversation,update_conversation,get_qa_by_id,get_latest_question
|
||||
from decouple import config
|
||||
import flask
|
||||
from util import load_ddl_doc
|
||||
@@ -133,7 +133,7 @@ def generate_sql_2():
|
||||
data['id'] = id
|
||||
sql = data["resp"]["sql"]
|
||||
logger.info("generate sql is : "+ sql)
|
||||
update_conversation(cvs_id, id, sql)
|
||||
update_conversation(id, sql)
|
||||
save_save_question_async(id, user_id, question, sql)
|
||||
data["type"]="success"
|
||||
return jsonify(data)
|
||||
@@ -241,7 +241,8 @@ def run_sql_2():
|
||||
logger.info("Start to run sql in main")
|
||||
try:
|
||||
id = request.args.get("id")
|
||||
sql = get_sql_by_id(id)
|
||||
qa = get_qa_by_id(id)
|
||||
sql = qa["sql"]
|
||||
logger.info(f"sql is {sql}")
|
||||
if not vn.run_sql_is_set:
|
||||
return jsonify(
|
||||
@@ -297,18 +298,28 @@ def query_present_question():
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
|
||||
def gen_graph_question():
|
||||
try:
|
||||
config = {"configurable": {"thread_id": '1233'}}
|
||||
user_id = request.args.get("user_id")
|
||||
cvs_id = request.args.get("cvs_id")
|
||||
config = {"configurable": {"thread_id": cvs_id}}
|
||||
question = flask.request.args.get("question")
|
||||
question_context = get_latest_question(cvs_id, user_id,limit_count=2)
|
||||
history = []
|
||||
for q in question_context:
|
||||
history.append({"role":"user",'content':q})
|
||||
logger.info(f"history is {history}")
|
||||
initial_state: SqlAgentState = {
|
||||
"user_id": user_id,
|
||||
"user_question": question,
|
||||
"history": [{"role":"user",'content':'谭杰明昨日打卡记录查询'}],
|
||||
"history": history,
|
||||
"sql_retry_count": 0,
|
||||
"chart_retry_count": 0
|
||||
}
|
||||
result = sql_chart_agent.invoke(initial_state, config=config)
|
||||
id = str(uuid.uuid4())
|
||||
cache.set(id=id, field="question", value=result.get('rewritten_user_question',question))
|
||||
|
||||
result = sql_chart_agent.invoke(initial_state, config=config)
|
||||
new_question = result.get('rewritten_user_question',question)
|
||||
id = str(uuid.uuid4())
|
||||
save_conversation(id, user_id, cvs_id, new_question)
|
||||
# cache.set(id=id, field="question", value=result.get('rewritten_user_question',question))
|
||||
data = {
|
||||
'id': id,
|
||||
'sql': result.get("gen_sql_result", {}),
|
||||
@@ -316,7 +327,9 @@ def gen_graph_question():
|
||||
'gen_sql_error': result.get("gen_sql_error", None),
|
||||
'gen_chart_error': result.get("gen_chart_error", None),
|
||||
}
|
||||
cache.set(id=id, field="sql", value=data.get('sql', {}).get('sql', ''))
|
||||
# cache.set(id=id, field="sql", value=data.get('sql', {}).get('sql', ''))
|
||||
sql = data.get('sql', {}).get('sql', '')
|
||||
update_conversation(id, sql)
|
||||
return jsonify(data)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
@@ -325,9 +338,13 @@ def gen_graph_question():
|
||||
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_3", methods=["GET"])
|
||||
@session_save
|
||||
@app.requires_cache(["sql",'question'])
|
||||
def run_sql_3(id: str, sql: str, question: str):
|
||||
# @session_save
|
||||
# @app.requires_cache(["sql",'question'])
|
||||
def run_sql_3():
|
||||
id = request.args.get("id")
|
||||
qa = get_qa_by_id(id)
|
||||
sql = qa["sql"]
|
||||
question = qa["question"]
|
||||
logger.info("Start to run sql in main")
|
||||
try:
|
||||
user_id = request.args.get("user_id")
|
||||
@@ -347,6 +364,7 @@ def run_sql_3(id: str, sql: str, question: str):
|
||||
}
|
||||
config = {"configurable": {"thread_id": 'dsds'}}
|
||||
rr = result_report_agent.invoke(initial_state,config)
|
||||
logger.info(f"rr.data is {rr.get('data', {})}")
|
||||
return jsonify(
|
||||
{
|
||||
'data': rr.get('data', {}),
|
||||
|
||||
@@ -41,14 +41,14 @@ def get_conversation(cvs_id: str):
|
||||
# }
|
||||
|
||||
|
||||
def update_conversation(cvs_id: str, id: str, sql=None, meta=None):
|
||||
def update_conversation(id: str, sql=None, meta=None):
|
||||
"""更新sql到对应question"""
|
||||
session = SqliteSqlalchemy().session
|
||||
try:
|
||||
if sql:
|
||||
session.query(Conversation).filter(Conversation.cvs_id == cvs_id, Conversation.id == id).update({Conversation.sql: sql})
|
||||
session.query(Conversation).filter(Conversation.id == id).update({Conversation.sql: sql})
|
||||
if meta:
|
||||
session.query(Conversation).filter(Conversation.cvs_id == cvs_id, Conversation.id == id).update({Conversation.meta: meta})
|
||||
session.query(Conversation).filter(Conversation.id == id).update({Conversation.meta: meta})
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
@@ -72,12 +72,12 @@ def get_latest_question(cvs_id, user_id, limit_count):
|
||||
session.close()
|
||||
|
||||
|
||||
def get_sql_by_id(id: str):
|
||||
def get_qa_by_id(id: str):
|
||||
session = SqliteSqlalchemy().session
|
||||
try:
|
||||
result = session.query(Conversation).filter_by(id=id).first()
|
||||
if result:
|
||||
return result.sql
|
||||
return {"question":result.question, "sql":result.sql}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"get_sql_by_id error {e}")
|
||||
|
||||
@@ -247,7 +247,7 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
new_question = self.generate_rewritten_question(questions,**kwargs)
|
||||
logger.info(f"new_question is {new_question}")
|
||||
question = new_question if new_question else question
|
||||
update_conversation(cvs_id, id, meta=question)
|
||||
update_conversation(id, meta=question)
|
||||
|
||||
# if user_id and cache:
|
||||
# history = cache.get(id=user_id, field="data")
|
||||
|
||||
@@ -624,4 +624,46 @@ template:
|
||||
# 最终输出
|
||||
请根据以上规则,直接输出最终生成的回复话术,不要包含任何其他解释或格式。
|
||||
|
||||
result_feedback:
|
||||
system: |
|
||||
# 角色
|
||||
你是一位经验丰富的数据分析师和SQL专家。你的核心任务是对SQL查询和其执行结果进行批判性审查。
|
||||
# 你的任务
|
||||
我(用户)向你提供了一系列信息,请你进行分析和反思。
|
||||
# 输入信息
|
||||
[原始问题]: <question>{question}</question>
|
||||
[生成的SQL]: <sql>{sql}</sql>。
|
||||
[执行结果]: <sql_result>{sql_result}</sql_result>。
|
||||
[当前时间]: <current_time>{current_time}</current_time>
|
||||
# 核心规则
|
||||
请根据以上信息,进行全面、细致的反思,并最终判断这个结果是否能正确回答原始问题。你的反思需要包含以下几个步骤:
|
||||
1. **核对问题理解**:
|
||||
回顾“原始问题”,确认其核心意图、时间范围、筛选条件和所需字段。
|
||||
我生成的SQL是否准确捕捉了问题的所有关键要素?有没有遗漏或误解?
|
||||
2. **通用业务定义**
|
||||
上周:指完整的**上周一到上周日**
|
||||
3. **审查SQL逻辑**:
|
||||
逐行分析“生成的SQL”,检查其语法结构是否正确。
|
||||
关键子句(如 `WHERE`, `GROUP BY`, `HAVING`, `JOIN`)是否准确地反映了查询意图?
|
||||
聚合函数(如 `SUM`, `COUNT`, `AVG`)的使用是否恰当?分组维度是否正确?
|
||||
4. **评估结果合理性**:
|
||||
观察“执行结果”,思考这个结果是否符合业务常识或数据的基本特征(例如,数量级、正负值、范围是否合理)?
|
||||
结果的列名和内容是否与问题期望的输出一致?
|
||||
如果结果为空或数据量异常,推测可能的原因。
|
||||
5. **最终判断和建议**:
|
||||
**结论** : 用一句话明确指出:“结果正确”、“结果可能不正确”或“无法完全确定”。
|
||||
**原因** : 简练地解释你做出此判断的核心理由。
|
||||
**改进建议** (如果结果不正确):
|
||||
* 具体指出SQL中可能存在的逻辑错误。
|
||||
* 给出一个或多个修改后的SQL版本。
|
||||
* 解释你为什么这样修改。
|
||||
# 输出格式
|
||||
请严格按照以下JSON格式输出你的分析结果:
|
||||
```json
|
||||
{{
|
||||
"conclusion": "结果正确 OR 结果可能不正确 OR 无法完全确定",
|
||||
"reasoning": "对问题理解、SQL逻辑和结果合理性的综合分析说明。",
|
||||
"is_result_correct": true OR false,
|
||||
"suggested_sql": "如果结论为不正确,请在此处提供修改后的SQL。如果正确,则为null。"
|
||||
}}
|
||||
Resources:
|
||||
Reference in New Issue
Block a user