Merge branch 'dev_graph' of http://106.13.42.156:33077/lei_y601/sqlbot_agent into dev_graph

# Conflicts:
#	graph_chat/gen_data_report_agent.py
#	main_service.py
This commit is contained in:
yujj128
2025-11-07 18:12:44 +08:00
7 changed files with 125 additions and 61 deletions

2
.env
View File

@@ -54,5 +54,5 @@ DAMENG_DATABASE_USER=ai_view
TTL_CACHE=43200
SESSION_LENGTH=2
NEED_MODEL_FEEDBACK_QA=False
ALLOWED_USERS=ea17e939-10f1-492d-adf1-e6ac42f89d81,5fba57a1-0d5f-4988-bcc9-a2b340f8adf1,25bb2ddb-d26d-486f-9aab-5497a9f1e10e,793b49e9-98e6-4f15-8b56-148a691d3022,c76720b5-15f7-4d23-a213-19da38b64d89,103bdc74-a88c-4e1e-a379-794937443c77,c4b2e802-58ba-4927-82fe-f938deb41cf0,1cf59aed-bbf0-4f1c-9246-1d96cc5e2719,6b28c87e-061d-4aca-9a12-71c52e677e73,4c105cfa-ce70-4eca-8d36-c518eda89005,cc93a83b-b145-43f4-80af-fe4ce8a55a04

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, DateTime, String, create_engine, Boolean, Text
from sqlalchemy import Column, DateTime,Integer, String, create_engine, Boolean, Text
from sqlalchemy.orm import declarative_base, sessionmaker
# 申明基类对象
@@ -23,10 +23,12 @@ class QuestionFeedBack(Base):
user_id = Column(String(100), nullable=False, default='1')
# 用户意见反馈
user_comment = Column(Text, nullable=True)
# 用户点赞,点踩
user_praise = Column(Boolean, nullable=False, default=False)
# 用户点赞/点踩:0 代表未点赞,1 代表点赞,-1 代表点踩
user_praise = Column(Integer, nullable=False, default=0)
# 该数据是否被认为梳理过
is_process = Column(Boolean, nullable=False, default=False)
def to_dict(self):
    """Serialize the feedback row to a plain dict for JSON responses."""
    payload = {
        "id": self.id,
        "question": self.question,
        "user_id": self.user_id,
        "user_praise": self.user_praise,
    }
    return payload
class Conversation(Base):

View File

@@ -14,7 +14,8 @@ from util.utils import extract_nested_json
from service.conversation_service import update_conversation
logger = logging.getLogger(__name__)
from template.template import get_base_template
from decouple import config
NEED_MODEL_FEEDBACK_QA=config("NEED_MODEL_FEEDBACK_QA", default=False, cast=bool)
class DateReportAgentState(TypedDict):
id: str
@@ -45,8 +46,12 @@ def _run_sql(state: DateReportAgentState) -> dict:
result = df.to_dict(orient='records')
logger.info(f"run sql result {result}")
retry = state.get("retry_count", 0) + 1
logger.info(f"old_sql:{sql} retry sql {retry_sql}")
return {'data': result, 'retry_count': retry, 'retry_sql': retry_sql}
logger.info(f"old_sql:{sql}, retry sql {retry_sql}")
dd= {'data': result, 'retry_count': retry, 'retry_sql': retry_sql}
#不需要模型反馈,运行成功后则成功
if not NEED_MODEL_FEEDBACK_QA:
dd['sql_correct']=True
return dd
except Exception as e:
retry = state.get("retry_count", 0) + 1
logger.error(f"Error running sql: {sql}, error: {e}")
@@ -88,39 +93,27 @@ def _feedback_qa(state: DateReportAgentState) -> dict:
def run_sql_hande(state: DateReportAgentState) -> str:
    """Route the graph after the _run_sql node.

    Returns the name of the next node: '_gen_report' when the SQL is
    confirmed correct (or retries are exhausted but data exists),
    '_feedback_qa' / '_run_sql' for another attempt depending on
    NEED_MODEL_FEEDBACK_QA, and END when retries are exhausted with no data.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql_hande 节点")
    logger.info(f'in run_sql_handle state={state}')
    data = state.get('data', {})
    sql_correct = state.get('sql_correct', False)
    sql_retry_count = state.get('retry_count', 0)
    logger.info(f"sql_retry_count is {sql_retry_count} sql_correct is {sql_correct}")
    if sql_retry_count < 3:
        if sql_correct:
            return '_gen_report'
        # SQL not yet confirmed correct: either route through the model
        # feedback node or retry the SQL node directly.
        if NEED_MODEL_FEEDBACK_QA:
            return '_feedback_qa'
        return '_run_sql'
    # Retries exhausted: report whatever data we have, otherwise stop.
    if data and len(data) > 0:
        return '_gen_report'
    # BUG FIX: the original evaluated END as a bare expression without
    # returning it, so this branch fell through and returned None,
    # which is not a valid edge target for the conditional router.
    return END
def _gen_report(state: DateReportAgentState) -> dict:
@@ -138,13 +131,16 @@ def _gen_report(state: DateReportAgentState) -> dict:
# Build the report-generation graph. The model-feedback node is wired in
# only when NEED_MODEL_FEEDBACK_QA is enabled; the merge left both the
# unconditional and conditional registrations in place, which would
# register the '_feedback_qa' node and its edges twice — keep only the
# conditional (resolved) wiring.
workflowx = StateGraph(DateReportAgentState)
workflowx.add_node("_run_sql", _run_sql)
if NEED_MODEL_FEEDBACK_QA:
    workflowx.add_node("_feedback_qa", _feedback_qa)
workflowx.add_node("_gen_report", _gen_report)
workflowx.add_edge(START, "_run_sql")
workflowx.add_edge("_gen_report", END)
if NEED_MODEL_FEEDBACK_QA:
    workflowx.add_edge("_feedback_qa", "_run_sql")
    workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report', '_feedback_qa'])
else:
    # Without model feedback the only retry path is re-running the SQL node.
    workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, "_run_sql", '_gen_report'])
memory = MemorySaver()
result_report_agent = workflowx.compile(checkpointer=memory)

View File

@@ -6,8 +6,9 @@ from functools import wraps
import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
from service.conversation_service import save_conversation,update_conversation,get_qa_by_id,get_latest_question
from service.question_feedback_service import save_save_question_async, query_predefined_question_list, \
update_user_feedBack, query_feedBack_question_list
from service.conversation_service import save_conversation, update_conversation, get_qa_by_id, get_latest_question
from decouple import config
import flask
from util import load_ddl_doc
@@ -125,17 +126,17 @@ def generate_sql_2():
return jsonify({"type": "error", "error": "No user_id or cvs_id provided"})
id = generate_timestamp_id()
logger.info(f"question_id: {id} user_id: {user_id} cvs_id: {cvs_id} question: {question}")
save_conversation(id,user_id,cvs_id,question)
save_conversation(id, user_id, cvs_id, question)
try:
logger.info(f"Generate sql for {question}")
data = vn.generate_sql_2(user_id,cvs_id,question,id,need_context)
data = vn.generate_sql_2(user_id, cvs_id, question, id, need_context)
logger.info("Generate sql result is {0}".format(data))
data['id'] = id
sql = data["resp"]["sql"]
logger.info("generate sql is : "+ sql)
logger.info("generate sql is : " + sql)
update_conversation(id, sql)
save_save_question_async(id, user_id, question, sql)
data["type"]="success"
data["type"] = "success"
return jsonify(data)
except Exception as e:
logger.error(f"generate sql failed:{e}")
@@ -295,17 +296,47 @@ def query_present_question():
return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'})
@app.flask_app.route("/yj_sqlbot/api/v0/query_feedback_question", methods=["POST"])
def query_feedback_question():
    """Look up user-feedback records for the posted list of question ids."""
    requested_ids = request.json.get("id_list", [])
    try:
        records = query_feedBack_question_list(requested_ids)
    except Exception as e:
        logger.error(f"查询用户反馈问题失败 failed:{e}")
        return jsonify({"type": "error", "error": f'查询用户反馈问题失败:{str(e)}'})
    return jsonify({"type": "success", "data": records})
@app.flask_app.route("/yj_sqlbot/api/v0/question_feed_back", methods=["PUT"])
def update_question_feed_back():
    """Update a question's praise/dislike feedback.

    Expects JSON body with "id" and "user_feedback" (an int flag —
    presumably 0 = none, 1 = praise, -1 = dislike, matching the
    user_praise column; confirm against the model definition).
    """
    feedback_id = request.json.get("id")  # renamed: 'id' shadows the builtin
    user_feedback = request.json.get("user_feedback")
    # BUG FIX: `not user_feedback` rejected the legal int value 0;
    # test explicitly against None instead.
    if not feedback_id or user_feedback is None:
        return jsonify({"type": "error", "error": "id 或者用户反馈为空"})
    try:
        update_user_feedBack(feedback_id, '', user_feedback)
        return jsonify({"type": "success"})
    except Exception as e:
        # BUG FIX: messages were copy-pasted from the predefined-question
        # endpoint ("查询预制问题失败"); this endpoint updates feedback.
        logger.error(f"更新用户反馈失败 failed:{e}")
        return jsonify({"type": "error", "error": f'更新用户反馈失败:{str(e)}'})
@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
def gen_graph_question():
try:
user_id = request.args.get("user_id")
cvs_id = request.args.get("cvs_id")
config = {"configurable": {"thread_id": uuid.uuid4()}}
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
question = flask.request.args.get("question")
question_context = get_latest_question(cvs_id, user_id,limit_count=2)
question_context = get_latest_question(cvs_id, user_id, limit_count=2)
history = []
i = 0
for q in question_context:
history.append({"role":"user",'content':q})
is_latest=False
if i==0:
is_latest=True
history.append({"role": "user", 'content': q, 'order': i,'is_latest':is_latest})
i +=1
logger.info(f"history is {history}")
initial_state: SqlAgentState = {
"user_id": user_id,
@@ -316,9 +347,10 @@ def gen_graph_question():
}
result = sql_chart_agent.invoke(initial_state, config=config)
new_question = result.get('rewritten_user_question',question)
new_question = result.get('rewritten_user_question', question)
id = str(uuid.uuid4())
save_conversation(id, user_id, cvs_id, new_question)
# cache.set(id=id, field="question", value=result.get('rewritten_user_question',question))
data = {
'id': id,
@@ -330,6 +362,7 @@ def gen_graph_question():
# cache.set(id=id, field="sql", value=data.get('sql', {}).get('sql', ''))
sql = data.get('sql', {}).get('sql', '')
update_conversation(id, sql)
save_save_question_async(id, user_id, new_question, sql)
return jsonify(data)
except Exception as e:
traceback.print_exc()
@@ -362,11 +395,9 @@ def run_sql_3():
"sql": sql,
"question": question,
"retry_count": 0,
"need_feedback": False,
}
config = {"configurable": {"thread_id": uuid.uuid4()}}
rr = result_report_agent.invoke(initial_state,config)
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
rr = result_report_agent.invoke(initial_state, config)
logger.info(f"rr.data is {rr.get('data', {})}")
return jsonify(
{

View File

@@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
pool = ThreadPoolExecutor(max_workers=10)
def save_question(id, user_id, question, sql):
logger.info(f"开始保存用户用户==>:{question}")
user_id = user_id if user_id else '1'
@@ -27,7 +28,7 @@ def save_save_question_async(id, user_id, question, sql):
pool.submit(save_question, id, user_id, question, sql)
def update_user_feedBack(id, user_comment, user_praise: bool):
def update_user_feedBack(id, user_comment, user_praise: int):
session = SqliteSqlalchemy().session
try:
session.query(QuestionFeedBack).filter(QuestionFeedBack.id == id).update(
@@ -39,12 +40,25 @@ def update_user_feedBack(id, user_comment, user_praise: bool):
finally:
session.close()
# Query all predefined (preset) questions.
def query_predefined_question_list() -> list:
    """Return all enabled predefined questions as a list of dicts.

    The session is now closed in a finally block so it is not leaked
    when the query raises (the original only closed on success).
    """
    session = SqliteSqlalchemy().session
    try:
        rows = session.query(PredefinedQuestion).filter(PredefinedQuestion.enable == True).all()
        # avoid shadowing the builtin `all`
        return [row.to_dict() for row in rows] if rows else []
    finally:
        session.close()
def query_feedBack_question_list(id_list: list) -> list:
    """Return feedback records whose ids are in id_list; [] for empty input.

    The session is now closed in a finally block so it is not leaked
    when the query raises (the original only closed on success).
    """
    if not id_list:  # covers both None and empty list
        return []
    session = SqliteSqlalchemy().session
    try:
        rows = session.query(QuestionFeedBack).filter(QuestionFeedBack.id.in_(id_list)).all()
        # avoid shadowing the builtin `all`
        return [row.to_dict() for row in rows] if rows else []
    finally:
        session.close()

View File

@@ -72,19 +72,7 @@ template:
</rule>
<rule>
<rule-title>术语标准化规则</rule-title>
<rule-detail>
[重要指令]解析用户问题,识别并替换所有已知的等价短语,将前面的短语换成后面的等价短语,:
** "数信部" -> "数字信息部" (必须替换!)
** "安质部" -> "安全质量部" (必须替换!)
例如:用户:查询数信部 → SQL: LIKE '%数字信息部%'
用户:查询安质部 → SQL: LIKE '%安全质量部%'
例如:查询数信部有多少人->查询数字信息部有多少人
</rule-detail>
<rule-detail>
禁止自动联想或替换:
- "数信部" 不等于 "数信中心"
- "数字信息部" 不等于 "数字信息中心"
</rule-detail>
<rule-detail>
数信中心和数信部都是部门,而非单位
</rule-detail>
@@ -220,10 +208,6 @@ template:
<words><word>数信中心</word><word>建设处</word><word>规划发展部</word><word>综合处</word></words>
<description>这些都可能是单位的名称,属于内部部门</description>
</terminology>
<terminology>
"数信部"务必替换成"数字信息部"
"安质部"务必替换成"安全质量部"
</terminology>
</terminologies>
<!-- [RAG 集成区] -->
<!-- 将从向量数据库/知识库中检索到的最相关的N个问答对放在这里 -->
@@ -552,6 +536,14 @@ template:
2. **生成最终问题**
* **如果判定为【关联】**:你必须将历史对话中的相关上下文信息**融合**到当前问题中,形成一个**完整清晰、无任何指代或歧义的新问题**。
* **如果判定为【不关联】**:你只需**原样输出**当前问题。
3. 短语指标库替换,如果问题出现短语指标,则替换为指标库中的指标。
<rule-detail>
3.1 :数信部 -> 数字信息部 (必须替换!)
3.2 :安质部 -> 安全质量部 (必须替换!)
3.3 :数信中心 -> 数信中心 (保持原样,禁止替换为其他名称!)
例如:用户:查询数信部 → SQL: LIKE '%数字信息部%'
用户:查询安质部 → SQL: LIKE '%安全质量部%'
</rule-detail>
# 输出格式要求
- **你的唯一输出必须是一个JSON对象**。
- **严禁**在JSON前后添加任何解释性文字、代码块标记如 ```json或任何其他内容。
@@ -576,6 +568,7 @@ template:
</examples>
####
# 上下文信息
注意:[历史对话中 order 越小代表消息越新,请优先基于最新消息进行关联性分析]
**对话历史:**
{history}
**当前用户问题:**

View File

@@ -1156,6 +1156,34 @@ question_and_answer = [
AND p."dr" = 0 LIMIT 1
'''
}
,
{
"question": "9月在林芝工作中有多少天迟到",
"answer": '''
SELECT COUNT(*) as '迟到天数'
FROM "YJOA_APPSERVICE_DB"."t_yj_person_status" ps
JOIN "YJOA_APPSERVICE_DB"."t_pr3rl2oj_yj_person_database" p ON ps."person_id" = p."code"
WHERE p."name" = '张三'
AND ps."status" IN ('1006', '1009', '6002', '6004')
AND ps."date_value" in (SELECT distinct (TO_CHAR(a."attendance_time", 'yyyy-MM-dd'))
FROM "YJOA_APPSERVICE_DB"."t_yj_person_attendance" a
LEFT JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_area" b
ON a."access_control_point" = b."ac_point"
WHERE a."person_name" = '张三'
and b.region = 5
AND a."attendance_time" >= '2025-09-01'
AND a."attendance_time"
< '2025-10-01'
AND a."dr" = 0
LIMIT 1000
)
AND ps."dr" = 0
AND p."dr" = 0
''',
"tags": ["员工", "个人", "考勤", "工作地", "区域", "工作天数"],
"category": "工作地考勤统计分析"
}
]