在藏表添加,feedback优化, 每轮错误状态清空

This commit is contained in:
yujj128
2025-11-08 15:24:28 +08:00
parent 703e0e2591
commit 896cc689f1
5 changed files with 153 additions and 32 deletions

View File

@@ -37,7 +37,7 @@ def _run_sql(state: DateReportAgentState) -> dict:
retry_sql = state.get('retry_sql', '')
sql_correct = state.get('sql_correct', False)
logger.info(f'in _run_sql state={state}')
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _run_sql 节点 ---------------")
try:
if retry_sql and not sql_correct:
sql = retry_sql
@@ -47,7 +47,7 @@ def _run_sql(state: DateReportAgentState) -> dict:
logger.info(f"run sql result {result}")
retry = state.get("retry_count", 0) + 1
logger.info(f"old_sql:{sql}, retry sql {retry_sql}")
dd= {'data': result, 'retry_count': retry, 'retry_sql': retry_sql}
dd= {'data': result, 'retry_count': retry, 'retry_sql': retry_sql,'run_sql_error':''}
#不需要模型反馈,运行成功后则成功
if not NEED_MODEL_FEEDBACK_QA:
dd['sql_correct']=True
@@ -60,7 +60,7 @@ def _run_sql(state: DateReportAgentState) -> dict:
def _feedback_qa(state: DateReportAgentState) -> dict:
from main_service import vn
logger.info(f"user:{state.get('user_id', '1')} 进入 _feedback_qa 节点")
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _feedback_qa 节点 ---------------")
try:
user_question = state['question']
sql = state['sql']
@@ -85,39 +85,65 @@ def _feedback_qa(state: DateReportAgentState) -> dict:
logger.info(f"current state: {state}")
return {"retry_sql": result["suggested_sql"]}
if result["is_result_correct"]:
return {"sql_correct": True}
return {"sql_correct": True,"run_sql_error":""}
except Exception as e:
logger.error(f"Error feedback: {e}")
def handle_no_feedback(state: DateReportAgentState) -> str:
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 handle_no_feedback 节点 ---------------")
sql_error = state.get('run_sql_error', '')
data = state.get('data', {})
sql_retry_count = state.get('retry_count', 0)
if sql_error and len(sql_error) > 0:
if sql_retry_count < 3:
return '_run_sql'
else:
return END
if data and len(data) > 0:
return '_gen_report'
else:
if sql_retry_count < 2:
return '_run_sql'
else:
return END
def handle_with_feedback(state: DateReportAgentState) -> str:
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 handle_with_feedback 节点 ---------------")
sql_error = state.get('run_sql_error', '')
# data = state.get('data', {})
# sql_correct = state.get('sql_correct', False)
sql_retry_count = state.get('retry_count', 0)
logger.info(f"handle with feedback sql_retry_count is {sql_retry_count} sql error is {sql_error}")
if sql_retry_count < 3:
if sql_error and len(sql_error) > 0:
return '_feedback_qa'
else:
return '_gen_report'
# if sql_correct:
# return '_gen_report'
# else:
# if sql_error and len(sql_error) > 0:
# return '_feedback_qa'
# else:
# return END
else:
if sql_error and len(sql_error)>0:
return END
else:
return '_gen_report'
def run_sql_hande(state: DateReportAgentState) -> str:
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql_hande 节点")
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _run_sql_hande 节点 ---------------")
logger.info(f'in run_sql_handle state={state}')
# sql_error = state.get('run_sql_error', '')
data = state.get('data', {})
sql_correct = state.get('sql_correct', False)
sql_retry_count = state.get('retry_count', 0)
logger.info(f"sql_retry_count is {sql_retry_count} sql_correct is {sql_correct}")
if sql_retry_count < 3:
if sql_correct:
return '_gen_report'
else:
if NEED_MODEL_FEEDBACK_QA:
return '_feedback_qa'
else:
return '_run_sql'
if NEED_MODEL_FEEDBACK_QA:
return handle_with_feedback(state)
else:
if data and len(data) > 0:
return '_gen_report'
else:
END
return handle_no_feedback(state)
def _gen_report(state: DateReportAgentState) -> dict:
logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_report 节点")
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _gen_report 节点 ---------------")
sql = state['sql']
question = state['question']
run_sql_error = state.get('run_sql_error', '')

View File

@@ -59,7 +59,7 @@ class SqlAgentState(TypedDict):
def _rewrite_user_question(state: SqlAgentState) -> dict:
logger.info(f"user:{state.get('user_id', '1')} 进入 _rewrite_user_question 节点")
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _rewrite_user_question 节点---------------")
user_question = state['user_question']
template = get_base_template()
rewrite_temp = template['template']['rewrite_question']
@@ -75,7 +75,7 @@ def _rewrite_user_question(state: SqlAgentState) -> dict:
def _gen_sql(state: SqlAgentState) -> dict:
from main_service import vn
logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_sql 节点")
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _gen_sql 节点---------------")
question = state.get('rewritten_user_question', state['user_question'])
service = vn
question_sql_list = service.get_similar_question_sql(question)
@@ -98,7 +98,7 @@ def _gen_sql(state: SqlAgentState) -> dict:
logger.info(f"gensql result: {result}")
result = orjson.loads(result)
retry = state.get("sql_retry_count", 0) + 1
return {'gen_sql_result': result, "sql_retry_count": retry}
return {'gen_sql_result': result, "sql_retry_count": retry,"gen_sql_error":""}
except Exception as e:
import traceback
traceback.print_exc()
@@ -108,7 +108,7 @@ def _gen_sql(state: SqlAgentState) -> dict:
def _gen_chart(state: SqlAgentState) -> dict:
logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_chart 节点")
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _gen_chart 节点---------------")
template = get_base_template()
gen_sql_result = state.get('gen_sql_result', {})
sql = gen_sql_result.get('sql', '')
@@ -122,7 +122,7 @@ def _gen_chart(state: SqlAgentState) -> dict:
rr = gen_history_llm.invoke(
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]).text()
retry = state.get("chart_retry_count", 0) + 1
return {"chart_retry_count": retry, 'gen_chart_result': orjson.loads(extract_nested_json(rr))}
return {"chart_retry_count": retry, 'gen_chart_result': orjson.loads(extract_nested_json(rr)),"gen_chart_error":""}
except Exception as e:
import traceback
traceback.print_exc()
@@ -132,6 +132,7 @@ def _gen_chart(state: SqlAgentState) -> dict:
# 如果生成sql失败则重试2次如果仍然失败则返回错误信息
def gen_sql_handler(state: SqlAgentState) -> str:
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 gen_sql_handler ---------------")
sql_error = state.get('gen_sql_error', '')
sql_result = state.get('gen_sql_result', {})
sql_retry_count = state.get('sql_retry_count', 0)
@@ -151,6 +152,7 @@ def gen_sql_handler(state: SqlAgentState) -> str:
def gen_chart_handler(state: SqlAgentState) -> str:
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _gen_chart ---------------")
chart_error = state.get('gen_chart_error', '')
chart_result = state.get('gen_chart_result', {})
sql_retry_count = state.get('chart_retry_count', 0)

View File

@@ -7,7 +7,8 @@ table_ddls = [
train_ddl.person_database_ddl,train_ddl.person_status_ddl,
train_ddl.person_attendance_ddl,train_ddl.person_ac_area,
train_ddl.person_ac_position,
train_ddl.org_orgs_ddl
train_ddl.org_orgs_ddl,
train_ddl.person_in_tibat
]
list_documentions = [

View File

@@ -948,6 +948,26 @@ question_and_answer = [
"tags": ["员工", "个人", "考勤", "工作地", "区域", "最早在藏时间"],
"category": "工作地考勤统计分析"
},
{
"question": "XX中心在藏最长时间的人是谁",
"answer": '''
SELECT p."name" AS "姓名", p."code" AS "工号", COUNT(ps."id") AS "在藏天数"
FROM YJOA_APPSERVICE_DB."t_yj_person_status" ps
JOIN YJOA_APPSERVICE_DB."t_pr3rl2oj_yj_person_database" p ON ps."person_id" = p."code"
WHERE ps."is_in_tibet" = 1
AND ps."dr" = 0
AND p."dr" = 0
and p.internal_dept in (SELECT "id"
FROM "IUAP_APDOC_BASEDOC"."org_orgs" START
WITH "name"||"shortname" LIKE '%xx中心%' AND "dr"=0 AND "enable"=1 AND "code" LIKE '%CYJ%'
CONNECT BY PRIOR "id" = "parentid"
)
GROUP BY p."name", p."code"
ORDER BY COUNT (ps."id") DESC LIMIT 1
''',
"tags": ["员工", "个人", "考勤", "工作地", "区域", "累计在藏统计"],
"category": "工作地考勤统计分析"
},
{
"question": "张三从5月到10月每个月分别在藏多长时间",
"answer": '''

View File

@@ -11,6 +11,7 @@ train_document='''
internal_dept和internal_unit是部门编号不是名称注意区分
查询部门信息时尽量使用internal_dept而非internal_unit<UNK>
数信部不是数信中心,两者不能等价
数信中心就叫数信中心,没有数字信息中心这个部门,请勿胡乱替换
'''
person_database_ddl = """
@@ -935,8 +936,7 @@ person_ac_position = '''
}
'''
person_ac_area = '''
{
person_ac_area = '''{
"db_name":"YJOA_APPSERVICE_DB",
"table_name": "t_yj_person_ac_area",
"table_comment": "门禁与区域关系表",
@@ -980,8 +980,80 @@ person_ac_area = '''
"tags": ["门禁详情","门禁与区域位置关联信息","门禁地区信息","枚举"]
}
'''
person_in_tibat = '''
{
"db_name": "YJOA_APPSERVICE_DB",
"table_name": "person_in_tibat",
"table_comment": "人员在藏情况表",
"columns": [
{
"name": "id",
"type": "VARCHAR(50)",
"comment": "主键ID",
"role": "dimension",
"tags": ["唯一标识"]
},
{
"name": "person_id",
"type": "VARCHAR(50)",
"comment": "人员ID",
"role": "dimension",
"tags": ["人员标识"]
},
{
"name": "year",
"type": "VARCHAR(20)",
"comment": "年份",
"role": "dimension",
"tags": ["时间维度"]
},
{
"name": "create_time",
"type": "DATETIME",
"comment": "创建时间",
"role": "dimension",
"tags": ["时间信息"]
},
{
"name": "update_time",
"type": "DATETIME",
"comment": "更新时间",
"role": "dimension",
"tags": ["时间信息"]
},
{
"name": "dr",
"type": "INT",
"comment": "删除标记",
"role": "dimension",
"tags": ["数据状态"]
},
{
"name": "continuous_in_tibet_days",
"type": "INT",
"comment": "连续在藏天数",
"role": "metric",
"tags": ["连续在藏天数"]
}
],
"relationships": [
{
"from": "person_id",
"to_table": "t_yj_person_database",
"to_field": "code",
"type": "foreign_key",
"comment": "关联人员基本信息表"
}
],
"tags": ["人员信息", "在藏情况", "时间统计", "人员轨迹"]
}
'''
org_orgs_ddl = '''
{
"db_name":"IUAP_APDOC_BASEDOC",