Merge branch 'dev_graph' of http://106.13.42.156:33077/lei_y601/sqlbot_agent into dev_graph
# Conflicts: # graph_chat/gen_data_report_agent.py # main_service.py
This commit is contained in:
2
.env
2
.env
@@ -54,5 +54,5 @@ DAMENG_DATABASE_USER=ai_view
|
||||
TTL_CACHE=43200
|
||||
SESSION_LENGTH=2
|
||||
|
||||
|
||||
NEED_MODEL_FEEDBACK_QA=False
|
||||
ALLOWED_USERS=ea17e939-10f1-492d-adf1-e6ac42f89d81,5fba57a1-0d5f-4988-bcc9-a2b340f8adf1,25bb2ddb-d26d-486f-9aab-5497a9f1e10e,793b49e9-98e6-4f15-8b56-148a691d3022,c76720b5-15f7-4d23-a213-19da38b64d89,103bdc74-a88c-4e1e-a379-794937443c77,c4b2e802-58ba-4927-82fe-f938deb41cf0,1cf59aed-bbf0-4f1c-9246-1d96cc5e2719,6b28c87e-061d-4aca-9a12-71c52e677e73,4c105cfa-ce70-4eca-8d36-c518eda89005,cc93a83b-b145-43f4-80af-fe4ce8a55a04
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, DateTime, String, create_engine, Boolean, Text
|
||||
from sqlalchemy import Column, DateTime,Integer, String, create_engine, Boolean, Text
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
|
||||
# 申明基类对象
|
||||
@@ -23,10 +23,12 @@ class QuestionFeedBack(Base):
|
||||
user_id = Column(String(100), nullable=False, default='1')
|
||||
# 用户意见反馈
|
||||
user_comment = Column(Text, nullable=True)
|
||||
# 用户点赞,点踩
|
||||
user_praise = Column(Boolean, nullable=False, default=False)
|
||||
# 用户点赞,点踩 0代表未点赞,1代表点赞,-1代表点踩
|
||||
user_praise = Column(Integer, nullable=False, default=0)
|
||||
# 该数据是否被认为梳理过
|
||||
is_process = Column(Boolean, nullable=False, default=False)
|
||||
def to_dict(self):
|
||||
return {"id":self.id, "question":self.question, "user_id":self.user_id, "user_praise":self.user_praise}
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
|
||||
@@ -14,7 +14,8 @@ from util.utils import extract_nested_json
|
||||
from service.conversation_service import update_conversation
|
||||
logger = logging.getLogger(__name__)
|
||||
from template.template import get_base_template
|
||||
|
||||
from decouple import config
|
||||
NEED_MODEL_FEEDBACK_QA=config("NEED_MODEL_FEEDBACK_QA", default=False, cast=bool)
|
||||
|
||||
class DateReportAgentState(TypedDict):
|
||||
id: str
|
||||
@@ -45,8 +46,12 @@ def _run_sql(state: DateReportAgentState) -> dict:
|
||||
result = df.to_dict(orient='records')
|
||||
logger.info(f"run sql result {result}")
|
||||
retry = state.get("retry_count", 0) + 1
|
||||
logger.info(f"old_sql:{sql} retry sql {retry_sql}")
|
||||
return {'data': result, 'retry_count': retry, 'retry_sql': retry_sql}
|
||||
logger.info(f"old_sql:{sql}, retry sql {retry_sql}")
|
||||
dd= {'data': result, 'retry_count': retry, 'retry_sql': retry_sql}
|
||||
#不需要模型反馈,运行成功后则成功
|
||||
if not NEED_MODEL_FEEDBACK_QA:
|
||||
dd['sql_correct']=True
|
||||
return dd
|
||||
except Exception as e:
|
||||
retry = state.get("retry_count", 0) + 1
|
||||
logger.error(f"Error running sql: {sql}, error: {e}")
|
||||
@@ -88,39 +93,27 @@ def _feedback_qa(state: DateReportAgentState) -> dict:
|
||||
|
||||
def run_sql_hande(state: DateReportAgentState) -> str:
|
||||
logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql_hande 节点")
|
||||
logger.info(f'in _run_sql_handle state={state}')
|
||||
logger.info(f'in run_sql_handle state={state}')
|
||||
# sql_error = state.get('run_sql_error', '')
|
||||
data = state.get('data', {})
|
||||
sql_correct = state.get('sql_correct', False)
|
||||
sql_retry_count = state.get('retry_count', 0)
|
||||
logger.info(f"sql_retry_count is {sql_retry_count} sql_correct is {sql_correct}")
|
||||
|
||||
if sql_retry_count < 2:
|
||||
if sql_retry_count < 3:
|
||||
if sql_correct:
|
||||
return '_gen_report'
|
||||
else:
|
||||
return '_feedback_qa'
|
||||
if NEED_MODEL_FEEDBACK_QA:
|
||||
return '_feedback_qa'
|
||||
else:
|
||||
return '_run_sql'
|
||||
else:
|
||||
if data and len(data) > 0:
|
||||
return '_gen_report'
|
||||
else:
|
||||
END
|
||||
# retry_sql = state.get('retry_sql', False)
|
||||
|
||||
# if sql_error and len(sql_error) > 0:
|
||||
# if sql_retry_count < 2:
|
||||
# return '_feedback_qa'
|
||||
# # return '_run_sql'
|
||||
# else:
|
||||
# return END
|
||||
#
|
||||
# if data and len(data) > 0:
|
||||
# return '_gen_report'
|
||||
# else:
|
||||
# if sql_retry_count < 2:
|
||||
# return '_feedback_qa'
|
||||
# else:
|
||||
# return END
|
||||
|
||||
|
||||
def _gen_report(state: DateReportAgentState) -> dict:
|
||||
@@ -138,13 +131,16 @@ def _gen_report(state: DateReportAgentState) -> dict:
|
||||
|
||||
workflowx = StateGraph(DateReportAgentState)
|
||||
workflowx.add_node("_run_sql", _run_sql)
|
||||
workflowx.add_node("_feedback_qa", _feedback_qa)
|
||||
if NEED_MODEL_FEEDBACK_QA:
|
||||
workflowx.add_node("_feedback_qa", _feedback_qa)
|
||||
workflowx.add_node("_gen_report", _gen_report)
|
||||
workflowx.add_edge(START, "_run_sql")
|
||||
workflowx.add_edge("_feedback_qa", "_run_sql")
|
||||
workflowx.add_edge("_gen_report", END)
|
||||
|
||||
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report','_feedback_qa'])
|
||||
if NEED_MODEL_FEEDBACK_QA:
|
||||
workflowx.add_edge("_feedback_qa", "_run_sql")
|
||||
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report','_feedback_qa'])
|
||||
else:
|
||||
workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, "_run_sql",'_gen_report'])
|
||||
memory = MemorySaver()
|
||||
result_report_agent = workflowx.compile(checkpointer=memory)
|
||||
# png_data=result_report_agent .get_graph().draw_mermaid_png()
|
||||
|
||||
@@ -6,8 +6,9 @@ from functools import wraps
|
||||
import util.utils
|
||||
from logging_config import LOGGING_CONFIG
|
||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
|
||||
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
|
||||
from service.conversation_service import save_conversation,update_conversation,get_qa_by_id,get_latest_question
|
||||
from service.question_feedback_service import save_save_question_async, query_predefined_question_list, \
|
||||
update_user_feedBack, query_feedBack_question_list
|
||||
from service.conversation_service import save_conversation, update_conversation, get_qa_by_id, get_latest_question
|
||||
from decouple import config
|
||||
import flask
|
||||
from util import load_ddl_doc
|
||||
@@ -125,17 +126,17 @@ def generate_sql_2():
|
||||
return jsonify({"type": "error", "error": "No user_id or cvs_id provided"})
|
||||
id = generate_timestamp_id()
|
||||
logger.info(f"question_id: {id} user_id: {user_id} cvs_id: {cvs_id} question: {question}")
|
||||
save_conversation(id,user_id,cvs_id,question)
|
||||
save_conversation(id, user_id, cvs_id, question)
|
||||
try:
|
||||
logger.info(f"Generate sql for {question}")
|
||||
data = vn.generate_sql_2(user_id,cvs_id,question,id,need_context)
|
||||
data = vn.generate_sql_2(user_id, cvs_id, question, id, need_context)
|
||||
logger.info("Generate sql result is {0}".format(data))
|
||||
data['id'] = id
|
||||
sql = data["resp"]["sql"]
|
||||
logger.info("generate sql is : "+ sql)
|
||||
logger.info("generate sql is : " + sql)
|
||||
update_conversation(id, sql)
|
||||
save_save_question_async(id, user_id, question, sql)
|
||||
data["type"]="success"
|
||||
data["type"] = "success"
|
||||
return jsonify(data)
|
||||
except Exception as e:
|
||||
logger.error(f"generate sql failed:{e}")
|
||||
@@ -295,17 +296,47 @@ def query_present_question():
|
||||
return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'})
|
||||
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/query_feedback_question", methods=["POST"])
|
||||
def query_feedback_question():
|
||||
id_list = request.json.get("id_list", [])
|
||||
try:
|
||||
data = query_feedBack_question_list(id_list)
|
||||
return jsonify({"type": "success", "data": data})
|
||||
except Exception as e:
|
||||
logger.error(f"查询用户反馈问题失败 failed:{e}")
|
||||
return jsonify({"type": "error", "error": f'查询用户反馈问题失败:{str(e)}'})
|
||||
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/question_feed_back", methods=["PUT"])
|
||||
def update_question_feed_back():
|
||||
id = request.json.get("id")
|
||||
user_feedback = request.json.get("user_feedback")
|
||||
if not id or not user_feedback:
|
||||
return jsonify({"type": "error", "error": "id 或者用户反馈为空"})
|
||||
try:
|
||||
update_user_feedBack(id, '', user_feedback)
|
||||
return jsonify({"type": "success"})
|
||||
except Exception as e:
|
||||
logger.error(f"查询预制问题失败 failed:{e}")
|
||||
return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'})
|
||||
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
|
||||
def gen_graph_question():
|
||||
try:
|
||||
user_id = request.args.get("user_id")
|
||||
cvs_id = request.args.get("cvs_id")
|
||||
config = {"configurable": {"thread_id": uuid.uuid4()}}
|
||||
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
|
||||
question = flask.request.args.get("question")
|
||||
question_context = get_latest_question(cvs_id, user_id,limit_count=2)
|
||||
question_context = get_latest_question(cvs_id, user_id, limit_count=2)
|
||||
history = []
|
||||
i = 0
|
||||
for q in question_context:
|
||||
history.append({"role":"user",'content':q})
|
||||
is_latest=False
|
||||
if i==0:
|
||||
is_latest=True
|
||||
history.append({"role": "user", 'content': q, 'order': i,'is_latest':is_latest})
|
||||
i +=1
|
||||
logger.info(f"history is {history}")
|
||||
initial_state: SqlAgentState = {
|
||||
"user_id": user_id,
|
||||
@@ -316,9 +347,10 @@ def gen_graph_question():
|
||||
}
|
||||
|
||||
result = sql_chart_agent.invoke(initial_state, config=config)
|
||||
new_question = result.get('rewritten_user_question',question)
|
||||
new_question = result.get('rewritten_user_question', question)
|
||||
id = str(uuid.uuid4())
|
||||
save_conversation(id, user_id, cvs_id, new_question)
|
||||
|
||||
# cache.set(id=id, field="question", value=result.get('rewritten_user_question',question))
|
||||
data = {
|
||||
'id': id,
|
||||
@@ -330,6 +362,7 @@ def gen_graph_question():
|
||||
# cache.set(id=id, field="sql", value=data.get('sql', {}).get('sql', ''))
|
||||
sql = data.get('sql', {}).get('sql', '')
|
||||
update_conversation(id, sql)
|
||||
save_save_question_async(id, user_id, new_question, sql)
|
||||
return jsonify(data)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
@@ -362,11 +395,9 @@ def run_sql_3():
|
||||
"sql": sql,
|
||||
"question": question,
|
||||
"retry_count": 0,
|
||||
"need_feedback": False,
|
||||
|
||||
}
|
||||
config = {"configurable": {"thread_id": uuid.uuid4()}}
|
||||
rr = result_report_agent.invoke(initial_state,config)
|
||||
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
|
||||
rr = result_report_agent.invoke(initial_state, config)
|
||||
logger.info(f"rr.data is {rr.get('data', {})}")
|
||||
return jsonify(
|
||||
{
|
||||
|
||||
@@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
pool = ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
|
||||
def save_question(id, user_id, question, sql):
|
||||
logger.info(f"开始保存用户用户==>:{question}")
|
||||
user_id = user_id if user_id else '1'
|
||||
@@ -27,7 +28,7 @@ def save_save_question_async(id, user_id, question, sql):
|
||||
pool.submit(save_question, id, user_id, question, sql)
|
||||
|
||||
|
||||
def update_user_feedBack(id, user_comment, user_praise: bool):
|
||||
def update_user_feedBack(id, user_comment, user_praise: int):
|
||||
session = SqliteSqlalchemy().session
|
||||
try:
|
||||
session.query(QuestionFeedBack).filter(QuestionFeedBack.id == id).update(
|
||||
@@ -39,12 +40,25 @@ def update_user_feedBack(id, user_comment, user_praise: bool):
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
'''
|
||||
查询所有的预制问题
|
||||
'''
|
||||
|
||||
|
||||
def query_predefined_question_list() -> list:
|
||||
session = SqliteSqlalchemy().session
|
||||
all = session.query(PredefinedQuestion).filter(PredefinedQuestion.enable == True).all()
|
||||
all = [a.to_dict() for a in all] if all else []
|
||||
session.close()
|
||||
return all
|
||||
|
||||
|
||||
def query_feedBack_question_list(id_list: list) -> list:
|
||||
if not id_list or len(id_list) == 0:
|
||||
return []
|
||||
session = SqliteSqlalchemy().session
|
||||
all = session.query(QuestionFeedBack).filter(QuestionFeedBack.id.in_(id_list)).all()
|
||||
all = [a.to_dict() for a in all] if all else []
|
||||
session.close()
|
||||
return all
|
||||
|
||||
@@ -72,19 +72,7 @@ template:
|
||||
</rule>
|
||||
<rule>
|
||||
<rule-title>术语标准化规则</rule-title>
|
||||
<rule-detail>
|
||||
[重要指令]解析用户问题,识别并替换所有已知的等价短语,将前面的短语换成后面的等价短语,:
|
||||
** "数信部" -> "数字信息部" (必须替换!)
|
||||
** "安质部" -> "安全质量部" (必须替换!)
|
||||
例如:用户:查询数信部 → SQL: LIKE '%数字信息部%'
|
||||
用户:查询安质部 → SQL: LIKE '%安全质量部%'
|
||||
例如:查询数信部有多少人->查询数字信息部有多少人
|
||||
</rule-detail>
|
||||
<rule-detail>
|
||||
禁止自动联想或替换:
|
||||
- "数信部" 不等于 "数信中心"
|
||||
- "数字信息部" 不等于 "数字信息中心"
|
||||
</rule-detail>
|
||||
|
||||
<rule-detail>
|
||||
数信中心和数信部都是部门,而非单位
|
||||
</rule-detail>
|
||||
@@ -220,10 +208,6 @@ template:
|
||||
<words><word>数信中心</word><word>建设处</word><word>规划发展部</word><word>综合处</word></words>
|
||||
<description>这些都可能是单位的名称,属于内部部门</description>
|
||||
</terminology>
|
||||
<terminology>
|
||||
"数信部"务必替换成"数字信息部"
|
||||
"安质部"务必替换成"安全质量部"
|
||||
</terminology>
|
||||
</terminologies>
|
||||
<!-- [RAG 集成区] -->
|
||||
<!-- 将从向量数据库/知识库中检索到的最相关的N个问答对放在这里 -->
|
||||
@@ -552,6 +536,14 @@ template:
|
||||
2. **生成最终问题**:
|
||||
* **如果判定为【关联】**:你必须将历史对话中的相关上下文信息**融合**到当前问题中,形成一个**完整清晰、无任何指代或歧义的新问题**。
|
||||
* **如果判定为【不关联】**:你只需**原样输出**当前问题。
|
||||
3. 短语指标库替换,如果问题出现短语指标,则替换为指标库中的指标。
|
||||
<rule-detail>
|
||||
3.1 :数信部 -> 数字信息部 (必须替换!)
|
||||
3.2 :安质部 -> 安全质量部 (必须替换!)
|
||||
3.3 :数信中心 -> 数信中心
|
||||
例如:用户:查询数信部 → SQL: LIKE '%数字信息部%'
|
||||
用户:查询安质部 → SQL: LIKE '%安全质量部%'
|
||||
</rule-detail>
|
||||
# 输出格式要求
|
||||
- **你的唯一输出必须是一个JSON对象**。
|
||||
- **严禁**在JSON前后添加任何解释性文字、代码块标记(如 ```json)或任何其他内容。
|
||||
@@ -576,6 +568,7 @@ template:
|
||||
</examples>
|
||||
####
|
||||
# 上下文信息
|
||||
注意:[历史对话中order越小,代表消息越新,优先基于最新消息进行关联性分析]
|
||||
**对话历史:**
|
||||
{history}
|
||||
**当前用户问题:**
|
||||
|
||||
@@ -1156,6 +1156,34 @@ question_and_answer = [
|
||||
AND p."dr" = 0 LIMIT 1
|
||||
'''
|
||||
}
|
||||
,
|
||||
{
|
||||
"question": "9月在林芝工作中有多少天迟到?",
|
||||
"answer": '''
|
||||
|
||||
SELECT COUNT(*) as '迟到天数'
|
||||
FROM "YJOA_APPSERVICE_DB"."t_yj_person_status" ps
|
||||
JOIN "YJOA_APPSERVICE_DB"."t_pr3rl2oj_yj_person_database" p ON ps."person_id" = p."code"
|
||||
WHERE p."name" = '张三'
|
||||
AND ps."status" IN ('1006', '1009', '6002', '6004')
|
||||
AND ps."date_value" in (SELECT distinct (TO_CHAR(a."attendance_time", 'yyyy-MM-dd'))
|
||||
FROM "YJOA_APPSERVICE_DB"."t_yj_person_attendance" a
|
||||
LEFT JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_area" b
|
||||
ON a."access_control_point" = b."ac_point"
|
||||
|
||||
WHERE a."person_name" = '张三'
|
||||
and b.region = 5
|
||||
AND a."attendance_time" >= '2025-09-01'
|
||||
AND a."attendance_time"
|
||||
< '2025-10-01'
|
||||
AND a."dr" = 0
|
||||
LIMIT 1000
|
||||
)
|
||||
AND ps."dr" = 0
|
||||
AND p."dr" = 0
|
||||
''',
|
||||
"tags": ["员工", "个人", "考勤", "工作地", "区域", "工作天数"],
|
||||
"category": "工作地考勤统计分析"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user