在藏表添加,feedback优化, 每轮错误状态清空
This commit is contained in:
@@ -37,7 +37,7 @@ def _run_sql(state: DateReportAgentState) -> dict:
|
||||
retry_sql = state.get('retry_sql', '')
|
||||
sql_correct = state.get('sql_correct', False)
|
||||
logger.info(f'in _run_sql state={state}')
|
||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql 节点")
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _run_sql 节点 ---------------")
|
||||
try:
|
||||
if retry_sql and not sql_correct:
|
||||
sql = retry_sql
|
||||
@@ -47,7 +47,7 @@ def _run_sql(state: DateReportAgentState) -> dict:
|
||||
logger.info(f"run sql result {result}")
|
||||
retry = state.get("retry_count", 0) + 1
|
||||
logger.info(f"old_sql:{sql}, retry sql {retry_sql}")
|
||||
dd= {'data': result, 'retry_count': retry, 'retry_sql': retry_sql}
|
||||
dd= {'data': result, 'retry_count': retry, 'retry_sql': retry_sql,'run_sql_error':''}
|
||||
#不需要模型反馈,运行成功后则成功
|
||||
if not NEED_MODEL_FEEDBACK_QA:
|
||||
dd['sql_correct']=True
|
||||
@@ -60,7 +60,7 @@ def _run_sql(state: DateReportAgentState) -> dict:
|
||||
|
||||
def _feedback_qa(state: DateReportAgentState) -> dict:
|
||||
from main_service import vn
|
||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _feedback_qa 节点")
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _feedback_qa 节点 ---------------")
|
||||
try:
|
||||
user_question = state['question']
|
||||
sql = state['sql']
|
||||
@@ -85,39 +85,65 @@ def _feedback_qa(state: DateReportAgentState) -> dict:
|
||||
logger.info(f"current state: {state}")
|
||||
return {"retry_sql": result["suggested_sql"]}
|
||||
if result["is_result_correct"]:
|
||||
return {"sql_correct": True}
|
||||
return {"sql_correct": True,"run_sql_error":""}
|
||||
except Exception as e:
|
||||
logger.error(f"Error feedback: {e}")
|
||||
|
||||
def handle_no_feedback(state: DateReportAgentState) -> str:
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 handle_no_feedback 节点 ---------------")
|
||||
sql_error = state.get('run_sql_error', '')
|
||||
data = state.get('data', {})
|
||||
sql_retry_count = state.get('retry_count', 0)
|
||||
if sql_error and len(sql_error) > 0:
|
||||
if sql_retry_count < 3:
|
||||
return '_run_sql'
|
||||
else:
|
||||
return END
|
||||
if data and len(data) > 0:
|
||||
return '_gen_report'
|
||||
else:
|
||||
if sql_retry_count < 2:
|
||||
return '_run_sql'
|
||||
else:
|
||||
return END
|
||||
|
||||
def handle_with_feedback(state: DateReportAgentState) -> str:
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 handle_with_feedback 节点 ---------------")
|
||||
sql_error = state.get('run_sql_error', '')
|
||||
# data = state.get('data', {})
|
||||
# sql_correct = state.get('sql_correct', False)
|
||||
sql_retry_count = state.get('retry_count', 0)
|
||||
logger.info(f"handle with feedback sql_retry_count is {sql_retry_count} sql error is {sql_error}")
|
||||
if sql_retry_count < 3:
|
||||
if sql_error and len(sql_error) > 0:
|
||||
return '_feedback_qa'
|
||||
else:
|
||||
return '_gen_report'
|
||||
# if sql_correct:
|
||||
# return '_gen_report'
|
||||
# else:
|
||||
# if sql_error and len(sql_error) > 0:
|
||||
# return '_feedback_qa'
|
||||
# else:
|
||||
# return END
|
||||
else:
|
||||
if sql_error and len(sql_error)>0:
|
||||
return END
|
||||
else:
|
||||
return '_gen_report'
|
||||
|
||||
def run_sql_hande(state: DateReportAgentState) -> str:
|
||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql_hande 节点")
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _run_sql_hande 节点 ---------------")
|
||||
logger.info(f'in run_sql_handle state={state}')
|
||||
# sql_error = state.get('run_sql_error', '')
|
||||
data = state.get('data', {})
|
||||
sql_correct = state.get('sql_correct', False)
|
||||
sql_retry_count = state.get('retry_count', 0)
|
||||
logger.info(f"sql_retry_count is {sql_retry_count} sql_correct is {sql_correct}")
|
||||
|
||||
if sql_retry_count < 3:
|
||||
if sql_correct:
|
||||
return '_gen_report'
|
||||
else:
|
||||
if NEED_MODEL_FEEDBACK_QA:
|
||||
return '_feedback_qa'
|
||||
else:
|
||||
return '_run_sql'
|
||||
if NEED_MODEL_FEEDBACK_QA:
|
||||
return handle_with_feedback(state)
|
||||
else:
|
||||
if data and len(data) > 0:
|
||||
return '_gen_report'
|
||||
else:
|
||||
END
|
||||
return handle_no_feedback(state)
|
||||
|
||||
|
||||
|
||||
def _gen_report(state: DateReportAgentState) -> dict:
|
||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_report 节点")
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _gen_report 节点 ---------------")
|
||||
sql = state['sql']
|
||||
question = state['question']
|
||||
run_sql_error = state.get('run_sql_error', '')
|
||||
|
||||
@@ -59,7 +59,7 @@ class SqlAgentState(TypedDict):
|
||||
|
||||
|
||||
def _rewrite_user_question(state: SqlAgentState) -> dict:
|
||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _rewrite_user_question 节点")
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _rewrite_user_question 节点---------------")
|
||||
user_question = state['user_question']
|
||||
template = get_base_template()
|
||||
rewrite_temp = template['template']['rewrite_question']
|
||||
@@ -75,7 +75,7 @@ def _rewrite_user_question(state: SqlAgentState) -> dict:
|
||||
|
||||
def _gen_sql(state: SqlAgentState) -> dict:
|
||||
from main_service import vn
|
||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_sql 节点")
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _gen_sql 节点---------------")
|
||||
question = state.get('rewritten_user_question', state['user_question'])
|
||||
service = vn
|
||||
question_sql_list = service.get_similar_question_sql(question)
|
||||
@@ -98,7 +98,7 @@ def _gen_sql(state: SqlAgentState) -> dict:
|
||||
logger.info(f"gensql result: {result}")
|
||||
result = orjson.loads(result)
|
||||
retry = state.get("sql_retry_count", 0) + 1
|
||||
return {'gen_sql_result': result, "sql_retry_count": retry}
|
||||
return {'gen_sql_result': result, "sql_retry_count": retry,"gen_sql_error":""}
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
@@ -108,7 +108,7 @@ def _gen_sql(state: SqlAgentState) -> dict:
|
||||
|
||||
|
||||
def _gen_chart(state: SqlAgentState) -> dict:
|
||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_chart 节点")
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _gen_chart 节点---------------")
|
||||
template = get_base_template()
|
||||
gen_sql_result = state.get('gen_sql_result', {})
|
||||
sql = gen_sql_result.get('sql', '')
|
||||
@@ -122,7 +122,7 @@ def _gen_chart(state: SqlAgentState) -> dict:
|
||||
rr = gen_history_llm.invoke(
|
||||
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}]).text()
|
||||
retry = state.get("chart_retry_count", 0) + 1
|
||||
return {"chart_retry_count": retry, 'gen_chart_result': orjson.loads(extract_nested_json(rr))}
|
||||
return {"chart_retry_count": retry, 'gen_chart_result': orjson.loads(extract_nested_json(rr)),"gen_chart_error":""}
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
@@ -132,6 +132,7 @@ def _gen_chart(state: SqlAgentState) -> dict:
|
||||
|
||||
# 如果生成sql失败,则重试2次,如果仍然失败,则返回错误信息
|
||||
def gen_sql_handler(state: SqlAgentState) -> str:
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 gen_sql_handler ---------------")
|
||||
sql_error = state.get('gen_sql_error', '')
|
||||
sql_result = state.get('gen_sql_result', {})
|
||||
sql_retry_count = state.get('sql_retry_count', 0)
|
||||
@@ -151,6 +152,7 @@ def gen_sql_handler(state: SqlAgentState) -> str:
|
||||
|
||||
|
||||
def gen_chart_handler(state: SqlAgentState) -> str:
|
||||
logger.info(f"user:{state.get('user_id', '1')} ---------------进入 _gen_chart ---------------")
|
||||
chart_error = state.get('gen_chart_error', '')
|
||||
chart_result = state.get('gen_chart_result', {})
|
||||
sql_retry_count = state.get('chart_retry_count', 0)
|
||||
|
||||
@@ -7,7 +7,8 @@ table_ddls = [
|
||||
train_ddl.person_database_ddl,train_ddl.person_status_ddl,
|
||||
train_ddl.person_attendance_ddl,train_ddl.person_ac_area,
|
||||
train_ddl.person_ac_position,
|
||||
train_ddl.org_orgs_ddl
|
||||
train_ddl.org_orgs_ddl,
|
||||
train_ddl.person_in_tibat
|
||||
|
||||
]
|
||||
list_documentions = [
|
||||
|
||||
@@ -948,6 +948,26 @@ question_and_answer = [
|
||||
"tags": ["员工", "个人", "考勤", "工作地", "区域", "最早在藏时间"],
|
||||
"category": "工作地考勤统计分析"
|
||||
},
|
||||
{
|
||||
"question": "XX中心在藏最长时间的人是谁",
|
||||
"answer": '''
|
||||
SELECT p."name" AS "姓名", p."code" AS "工号", COUNT(ps."id") AS "在藏天数"
|
||||
FROM YJOA_APPSERVICE_DB."t_yj_person_status" ps
|
||||
JOIN YJOA_APPSERVICE_DB."t_pr3rl2oj_yj_person_database" p ON ps."person_id" = p."code"
|
||||
WHERE ps."is_in_tibet" = 1
|
||||
AND ps."dr" = 0
|
||||
AND p."dr" = 0
|
||||
and p.internal_dept in (SELECT "id"
|
||||
FROM "IUAP_APDOC_BASEDOC"."org_orgs" START
|
||||
WITH "name"||"shortname" LIKE '%xx中心%' AND "dr"=0 AND "enable"=1 AND "code" LIKE '%CYJ%'
|
||||
CONNECT BY PRIOR "id" = "parentid"
|
||||
)
|
||||
GROUP BY p."name", p."code"
|
||||
ORDER BY COUNT (ps."id") DESC LIMIT 1
|
||||
''',
|
||||
"tags": ["员工", "个人", "考勤", "工作地", "区域", "累计在藏统计"],
|
||||
"category": "工作地考勤统计分析"
|
||||
},
|
||||
{
|
||||
"question": "张三从5月到10月每个月分别在藏多长时间",
|
||||
"answer": '''
|
||||
|
||||
@@ -11,6 +11,7 @@ train_document='''
|
||||
internal_dept和internal_unit是部门编号不是名称,注意区分
|
||||
查询部门信息时尽量使用internal_dept而非internal_unit<UNK>
|
||||
数信部不是数信中心,两者不能等价
|
||||
数信中心就叫数信中心,没有数字信息中心这个部门,请勿胡乱替换
|
||||
'''
|
||||
|
||||
person_database_ddl = """
|
||||
@@ -935,8 +936,7 @@ person_ac_position = '''
|
||||
}
|
||||
'''
|
||||
|
||||
person_ac_area = '''
|
||||
{
|
||||
person_ac_area = '''{
|
||||
"db_name":"YJOA_APPSERVICE_DB",
|
||||
"table_name": "t_yj_person_ac_area",
|
||||
"table_comment": "门禁与区域关系表",
|
||||
@@ -980,8 +980,80 @@ person_ac_area = '''
|
||||
|
||||
"tags": ["门禁详情","门禁与区域位置关联信息","门禁地区信息","枚举"]
|
||||
}
|
||||
|
||||
'''
|
||||
|
||||
person_in_tibat = '''
|
||||
{
|
||||
"db_name": "YJOA_APPSERVICE_DB",
|
||||
"table_name": "person_in_tibat",
|
||||
"table_comment": "人员在藏情况表",
|
||||
"columns": [
|
||||
{
|
||||
"name": "id",
|
||||
"type": "VARCHAR(50)",
|
||||
"comment": "主键ID",
|
||||
"role": "dimension",
|
||||
"tags": ["唯一标识"]
|
||||
},
|
||||
{
|
||||
"name": "person_id",
|
||||
"type": "VARCHAR(50)",
|
||||
"comment": "人员ID",
|
||||
"role": "dimension",
|
||||
"tags": ["人员标识"]
|
||||
},
|
||||
{
|
||||
"name": "year",
|
||||
"type": "VARCHAR(20)",
|
||||
"comment": "年份",
|
||||
"role": "dimension",
|
||||
"tags": ["时间维度"]
|
||||
},
|
||||
{
|
||||
"name": "create_time",
|
||||
"type": "DATETIME",
|
||||
"comment": "创建时间",
|
||||
"role": "dimension",
|
||||
"tags": ["时间信息"]
|
||||
},
|
||||
{
|
||||
"name": "update_time",
|
||||
"type": "DATETIME",
|
||||
"comment": "更新时间",
|
||||
"role": "dimension",
|
||||
"tags": ["时间信息"]
|
||||
},
|
||||
{
|
||||
"name": "dr",
|
||||
"type": "INT",
|
||||
"comment": "删除标记",
|
||||
"role": "dimension",
|
||||
"tags": ["数据状态"]
|
||||
},
|
||||
{
|
||||
"name": "continuous_in_tibet_days",
|
||||
"type": "INT",
|
||||
"comment": "连续在藏天数",
|
||||
"role": "metric",
|
||||
"tags": ["连续在藏天数"]
|
||||
}
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"from": "person_id",
|
||||
"to_table": "t_yj_person_database",
|
||||
"to_field": "code",
|
||||
"type": "foreign_key",
|
||||
"comment": "关联人员基本信息表"
|
||||
}
|
||||
],
|
||||
"tags": ["人员信息", "在藏情况", "时间统计", "人员轨迹"]
|
||||
}
|
||||
|
||||
'''
|
||||
|
||||
|
||||
org_orgs_ddl = '''
|
||||
{
|
||||
"db_name":"IUAP_APDOC_BASEDOC",
|
||||
|
||||
Reference in New Issue
Block a user