Compare commits
2 Commits
2bd4b934ce
...
703e0e2591
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
703e0e2591 | ||
|
|
53905324fe |
@@ -36,6 +36,7 @@ def _run_sql(state: DateReportAgentState) -> dict:
|
|||||||
sql = state['sql']
|
sql = state['sql']
|
||||||
retry_sql = state.get('retry_sql', '')
|
retry_sql = state.get('retry_sql', '')
|
||||||
sql_correct = state.get('sql_correct', False)
|
sql_correct = state.get('sql_correct', False)
|
||||||
|
logger.info(f'in _run_sql state={state}')
|
||||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")
|
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")
|
||||||
try:
|
try:
|
||||||
if retry_sql and not sql_correct:
|
if retry_sql and not sql_correct:
|
||||||
@@ -68,8 +69,9 @@ def _feedback_qa(state: DateReportAgentState) -> dict:
|
|||||||
feedback_temp = template['template']['result_feedback']
|
feedback_temp = template['template']['result_feedback']
|
||||||
logger.info(f"feedback_temp is {feedback_temp}")
|
logger.info(f"feedback_temp is {feedback_temp}")
|
||||||
ddl_list = vn.get_related_ddl(user_question)
|
ddl_list = vn.get_related_ddl(user_question)
|
||||||
|
qa_list = vn.get_similar_question_sql(user_question)
|
||||||
sys_promot = feedback_temp['system'].format(question=user_question, sql=sql, sql_result=sql_result,
|
sys_promot = feedback_temp['system'].format(question=user_question, sql=sql, sql_result=sql_result,
|
||||||
current_time=datetime.now(),ddl_list=ddl_list)
|
current_time=datetime.now(),ddl_list=ddl_list,qa_list=qa_list)
|
||||||
logger.info(f"system_temp is {sys_promot}")
|
logger.info(f"system_temp is {sys_promot}")
|
||||||
result = gen_history_llm.invoke(sys_promot).text()
|
result = gen_history_llm.invoke(sys_promot).text()
|
||||||
logger.info(f"feedback result: {result}")
|
logger.info(f"feedback result: {result}")
|
||||||
@@ -77,12 +79,9 @@ def _feedback_qa(state: DateReportAgentState) -> dict:
|
|||||||
result = orjson.loads(result)
|
result = orjson.loads(result)
|
||||||
logger.info(f"提取json成功")
|
logger.info(f"提取json成功")
|
||||||
logger.info(f"result is {result} type:{type(result)}")
|
logger.info(f"result is {result} type:{type(result)}")
|
||||||
print("result is {0}".format(result.keys()))
|
|
||||||
if not result["is_result_correct"] and (result["suggested_sql"]):
|
if not result["is_result_correct"] and (result["suggested_sql"]):
|
||||||
logger.info("开始替换")
|
logger.info("开始替换")
|
||||||
logger.info(f"suggested_sql is ".format(result["suggested_sql"]))
|
logger.info(f"suggested_sql is ".format(result["suggested_sql"]))
|
||||||
# state["sql"] = result["suggested_sql"]
|
|
||||||
# state["retry_sql"] = True
|
|
||||||
logger.info(f"current state: {state}")
|
logger.info(f"current state: {state}")
|
||||||
return {"retry_sql": result["suggested_sql"]}
|
return {"retry_sql": result["suggested_sql"]}
|
||||||
if result["is_result_correct"]:
|
if result["is_result_correct"]:
|
||||||
@@ -93,11 +92,13 @@ def _feedback_qa(state: DateReportAgentState) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def run_sql_hande(state: DateReportAgentState) -> str:
|
def run_sql_hande(state: DateReportAgentState) -> str:
|
||||||
|
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql_hande 节点")
|
||||||
|
logger.info(f'in run_sql_handle state={state}')
|
||||||
# sql_error = state.get('run_sql_error', '')
|
# sql_error = state.get('run_sql_error', '')
|
||||||
data = state.get('data', {})
|
data = state.get('data', {})
|
||||||
sql_correct = state.get('sql_correct', False)
|
sql_correct = state.get('sql_correct', False)
|
||||||
sql_retry_count = state.get('retry_count', 0)
|
sql_retry_count = state.get('retry_count', 0)
|
||||||
logger.info("sql_retry_count is {0}".format(sql_retry_count))
|
logger.info(f"sql_retry_count is {sql_retry_count} sql_correct is {sql_correct}")
|
||||||
|
|
||||||
if sql_retry_count < 3:
|
if sql_retry_count < 3:
|
||||||
if sql_correct:
|
if sql_correct:
|
||||||
@@ -142,6 +143,6 @@ else:
|
|||||||
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, "_run_sql",'_gen_report'])
|
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, "_run_sql",'_gen_report'])
|
||||||
memory = MemorySaver()
|
memory = MemorySaver()
|
||||||
result_report_agent = workflowx.compile(checkpointer=memory)
|
result_report_agent = workflowx.compile(checkpointer=memory)
|
||||||
png_data=result_report_agent .get_graph().draw_mermaid_png()
|
# png_data=result_report_agent .get_graph().draw_mermaid_png()
|
||||||
with open("D://graph2.png", "wb") as f:
|
# with open("D://graph2.png", "wb") as f:
|
||||||
f.write(png_data)
|
# f.write(png_data)
|
||||||
@@ -179,6 +179,6 @@ workflow.add_conditional_edges('_gen_sql', gen_sql_handler, ['_gen_sql', '_gen_c
|
|||||||
workflow.add_conditional_edges('_gen_chart', gen_chart_handler, ['_gen_chart', END])
|
workflow.add_conditional_edges('_gen_chart', gen_chart_handler, ['_gen_chart', END])
|
||||||
memory = MemorySaver()
|
memory = MemorySaver()
|
||||||
sql_chart_agent = workflow.compile(checkpointer=memory)
|
sql_chart_agent = workflow.compile(checkpointer=memory)
|
||||||
png_data=sql_chart_agent .get_graph().draw_mermaid_png()
|
# png_data=sql_chart_agent .get_graph().draw_mermaid_png()
|
||||||
# with open("D://graph.png", "wb") as f:
|
# with open("D://graph.png", "wb") as f:
|
||||||
# f.write(png_data)
|
# f.write(png_data)
|
||||||
@@ -19,7 +19,6 @@ import traceback
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def generate_timestamp_id():
|
def generate_timestamp_id():
|
||||||
"""生成基于时间戳的ID"""
|
"""生成基于时间戳的ID"""
|
||||||
# 获取当前时间戳(秒级)
|
# 获取当前时间戳(秒级)
|
||||||
@@ -27,6 +26,7 @@ def generate_timestamp_id():
|
|||||||
return f"Q{timestamp}"
|
return f"Q{timestamp}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def connect_database(vn):
|
def connect_database(vn):
|
||||||
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
|
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
|
||||||
if db_type == 'sqlite':
|
if db_type == 'sqlite':
|
||||||
@@ -180,6 +180,8 @@ def generate_sql_2():
|
|||||||
# return decorator
|
# return decorator
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# def session_save(func):
|
# def session_save(func):
|
||||||
# @wraps(func)
|
# @wraps(func)
|
||||||
# def wrapper(*args, **kwargs):
|
# def wrapper(*args, **kwargs):
|
||||||
@@ -206,6 +208,7 @@ def generate_sql_2():
|
|||||||
# return wrapper
|
# return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
|
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
|
||||||
# @session_save
|
# @session_save
|
||||||
# @requires_cache_2(required_keys=["sql"])
|
# @requires_cache_2(required_keys=["sql"])
|
||||||
@@ -375,6 +378,7 @@ def run_sql_3():
|
|||||||
qa = get_qa_by_id(id)
|
qa = get_qa_by_id(id)
|
||||||
sql = qa["sql"]
|
sql = qa["sql"]
|
||||||
question = qa["question"]
|
question = qa["question"]
|
||||||
|
logger.info(f"in main sql {sql} question {question}")
|
||||||
logger.info("Start to run sql in main")
|
logger.info("Start to run sql in main")
|
||||||
try:
|
try:
|
||||||
user_id = request.args.get("user_id")
|
user_id = request.args.get("user_id")
|
||||||
|
|||||||
@@ -629,11 +629,13 @@ template:
|
|||||||
[执行结果]: <sql_result>{sql_result}</sql_result>。
|
[执行结果]: <sql_result>{sql_result}</sql_result>。
|
||||||
[当前时间]: <current_time>{current_time}</current_time>
|
[当前时间]: <current_time>{current_time}</current_time>
|
||||||
[表结构信息]:<schema>{ddl_list}</schema>
|
[表结构信息]:<schema>{ddl_list}</schema>
|
||||||
|
[问答参考]:<question_answer>{qa_list}</question_answer>
|
||||||
# 核心规则
|
# 核心规则
|
||||||
请根据以上信息,进行全面、细致的反思,并最终判断这个结果是否能正确回答原始问题。你的反思需要包含以下几个步骤:
|
请根据以上信息,进行全面、细致的反思,并最终判断这个结果是否能正确回答原始问题。你的反思需要包含以下几个步骤:
|
||||||
1. **核对问题理解**:
|
1. **核对问题理解**:
|
||||||
回顾“原始问题”,确认其核心意图、时间范围、筛选条件和所需字段。
|
回顾“原始问题”,确认其核心意图、时间范围、筛选条件和所需字段。
|
||||||
我生成的SQL是否准确捕捉了问题的所有关键要素?有没有遗漏或误解?
|
我生成的SQL是否准确捕捉了问题的所有关键要素?有没有遗漏或误解?
|
||||||
|
当sql和问题吻合,且能查询出结果时,一般判定为正确
|
||||||
2. **通用业务定义**
|
2. **通用业务定义**
|
||||||
上周:指完整的**上周一到上周日**
|
上周:指完整的**上周一到上周日**
|
||||||
3. **审查SQL逻辑**:
|
3. **审查SQL逻辑**:
|
||||||
@@ -643,7 +645,6 @@ template:
|
|||||||
4. **评估结果合理性**:
|
4. **评估结果合理性**:
|
||||||
观察“执行结果”,思考这个结果是否符合业务常识或数据的基本特征(例如,数量级、正负值、范围是否合理)?
|
观察“执行结果”,思考这个结果是否符合业务常识或数据的基本特征(例如,数量级、正负值、范围是否合理)?
|
||||||
结果的列名和内容是否与问题期望的输出一致?
|
结果的列名和内容是否与问题期望的输出一致?
|
||||||
如果结果为空或数据量异常,推测可能的原因。
|
|
||||||
5. **最终判断和建议**:
|
5. **最终判断和建议**:
|
||||||
**结论** : 用一句话明确指出:“结果正确”、“结果可能不正确”或“无法完全确定”。
|
**结论** : 用一句话明确指出:“结果正确”、“结果可能不正确”或“无法完全确定”。
|
||||||
**原因** : 简练地解释你做出此判断的核心理由。
|
**原因** : 简练地解释你做出此判断的核心理由。
|
||||||
|
|||||||
Reference in New Issue
Block a user