diff --git a/.env b/.env index c949e91..aa24d31 100644 --- a/.env +++ b/.env @@ -54,5 +54,5 @@ DAMENG_DATABASE_USER=ai_view TTL_CACHE=43200 SESSION_LENGTH=2 - +NEED_MODEL_FEEDBACK_QA=False ALLOWED_USERS=ea17e939-10f1-492d-adf1-e6ac42f89d81,5fba57a1-0d5f-4988-bcc9-a2b340f8adf1,25bb2ddb-d26d-486f-9aab-5497a9f1e10e,793b49e9-98e6-4f15-8b56-148a691d3022,c76720b5-15f7-4d23-a213-19da38b64d89,103bdc74-a88c-4e1e-a379-794937443c77,c4b2e802-58ba-4927-82fe-f938deb41cf0,1cf59aed-bbf0-4f1c-9246-1d96cc5e2719,6b28c87e-061d-4aca-9a12-71c52e677e73,4c105cfa-ce70-4eca-8d36-c518eda89005,cc93a83b-b145-43f4-80af-fe4ce8a55a04 \ No newline at end of file diff --git a/db_util/db_main.py b/db_util/db_main.py index 30d86d0..74da186 100644 --- a/db_util/db_main.py +++ b/db_util/db_main.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, DateTime, String, create_engine, Boolean, Text +from sqlalchemy import Column, DateTime,Integer, String, create_engine, Boolean, Text from sqlalchemy.orm import declarative_base, sessionmaker # 申明基类对象 @@ -23,10 +23,12 @@ class QuestionFeedBack(Base): user_id = Column(String(100), nullable=False, default='1') # 用户意见反馈 user_comment = Column(Text, nullable=True) - # 用户点赞,点踩 - user_praise = Column(Boolean, nullable=False, default=False) + # 用户点赞,点踩 0代表未点赞,1代表点赞,-1代表点踩 + user_praise = Column(Integer, nullable=False, default=0) # 该数据是否被认为梳理过 is_process = Column(Boolean, nullable=False, default=False) + def to_dict(self): + return {"id":self.id, "question":self.question, "user_id":self.user_id, "user_praise":self.user_praise} class Conversation(Base): diff --git a/graph_chat/gen_data_report_agent.py b/graph_chat/gen_data_report_agent.py index 5785046..4e1b13b 100644 --- a/graph_chat/gen_data_report_agent.py +++ b/graph_chat/gen_data_report_agent.py @@ -14,7 +14,8 @@ from util.utils import extract_nested_json from service.conversation_service import update_conversation logger = logging.getLogger(__name__) from template.template import get_base_template - +from decouple import config +NEED_MODEL_FEEDBACK_QA=config("NEED_MODEL_FEEDBACK_QA", default=False, cast=bool) class DateReportAgentState(TypedDict): id: str @@ -45,8 +46,12 @@ def _run_sql(state: DateReportAgentState) -> dict: result = df.to_dict(orient='records') logger.info(f"run sql result {result}") retry = state.get("retry_count", 0) + 1 - logger.info(f"old_sql:{sql} retry sql {retry_sql}") - return {'data': result, 'retry_count': retry, 'retry_sql': retry_sql} + logger.info(f"old_sql:{sql}, retry sql {retry_sql}") + dd= {'data': result, 'retry_count': retry, 'retry_sql': retry_sql} + #不需要模型反馈,运行成功后则成功 + if not NEED_MODEL_FEEDBACK_QA: + dd['sql_correct']=True + return dd except Exception as e: retry = state.get("retry_count", 0) + 1 logger.error(f"Error running sql: {sql}, error: {e}") @@ -88,39 +93,27 @@ def _feedback_qa(state: DateReportAgentState) -> dict: def run_sql_hande(state: DateReportAgentState) -> str: logger.info(f"user:{state.get('user_id', '1')} 进入 _run_sql_hande 节点") - logger.info(f'in _run_sql_handle state={state}') + logger.info(f'in run_sql_handle state={state}') # sql_error = state.get('run_sql_error', '') data = state.get('data', {}) sql_correct = state.get('sql_correct', False) sql_retry_count = state.get('retry_count', 0) logger.info(f"sql_retry_count is {sql_retry_count} sql_correct is {sql_correct}") - if sql_retry_count < 2: + if sql_retry_count < 3: if sql_correct: return '_gen_report' else: - return '_feedback_qa' + if NEED_MODEL_FEEDBACK_QA: + return '_feedback_qa' + else: + return '_run_sql' else: if data and len(data) > 0: return '_gen_report' else: END - # retry_sql = state.get('retry_sql', False) - # if sql_error and len(sql_error) > 0: - # if sql_retry_count < 2: - # return '_feedback_qa' - # # return '_run_sql' - # else: - # return END - # - # if data and len(data) > 0: - # return '_gen_report' - # else: - # if sql_retry_count < 2: - # return '_feedback_qa' - # else: - # return END def _gen_report(state: DateReportAgentState) -> dict: @@ -138,13 +131,16 @@ def _gen_report(state: DateReportAgentState) -> dict: workflowx = StateGraph(DateReportAgentState) workflowx.add_node("_run_sql", _run_sql) -workflowx.add_node("_feedback_qa", _feedback_qa) +if NEED_MODEL_FEEDBACK_QA: + workflowx.add_node("_feedback_qa", _feedback_qa) workflowx.add_node("_gen_report", _gen_report) workflowx.add_edge(START, "_run_sql") -workflowx.add_edge("_feedback_qa", "_run_sql") workflowx.add_edge("_gen_report", END) - -workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report','_feedback_qa']) +if NEED_MODEL_FEEDBACK_QA: + workflowx.add_edge("_feedback_qa", "_run_sql") + workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, '_gen_report','_feedback_qa']) +else: + workflowx.add_conditional_edges("_run_sql", run_sql_hande, [END, "_run_sql",'_gen_report']) memory = MemorySaver() result_report_agent = workflowx.compile(checkpointer=memory) # png_data=result_report_agent .get_graph().draw_mermaid_png() diff --git a/main_service.py b/main_service.py index e83c9ea..8336704 100644 --- a/main_service.py +++ b/main_service.py @@ -6,8 +6,9 @@ from functools import wraps import util.utils from logging_config import LOGGING_CONFIG from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper -from service.question_feedback_service import save_save_question_async,query_predefined_question_list -from service.conversation_service import save_conversation,update_conversation,get_qa_by_id,get_latest_question +from service.question_feedback_service import save_save_question_async, query_predefined_question_list, \ + update_user_feedBack, query_feedBack_question_list +from service.conversation_service import save_conversation, update_conversation, get_qa_by_id, get_latest_question from decouple import config import flask from util import load_ddl_doc @@ -125,17 +126,17 @@ def generate_sql_2(): return jsonify({"type": "error", "error": "No user_id or cvs_id provided"}) id = generate_timestamp_id() logger.info(f"question_id: {id} user_id: {user_id} cvs_id: {cvs_id} question: {question}") - save_conversation(id,user_id,cvs_id,question) + save_conversation(id, user_id, cvs_id, question) try: logger.info(f"Generate sql for {question}") - data = vn.generate_sql_2(user_id,cvs_id,question,id,need_context) + data = vn.generate_sql_2(user_id, cvs_id, question, id, need_context) logger.info("Generate sql result is {0}".format(data)) data['id'] = id sql = data["resp"]["sql"] - logger.info("generate sql is : "+ sql) + logger.info("generate sql is : " + sql) update_conversation(id, sql) save_save_question_async(id, user_id, question, sql) - data["type"]="success" + data["type"] = "success" return jsonify(data) except Exception as e: logger.error(f"generate sql failed:{e}") @@ -295,17 +296,47 @@ def query_present_question(): return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'}) +@app.flask_app.route("/yj_sqlbot/api/v0/query_feedback_question", methods=["POST"]) +def query_feedback_question(): + id_list = request.json.get("id_list", []) + try: + data = query_feedBack_question_list(id_list) + return jsonify({"type": "success", "data": data}) + except Exception as e: + logger.error(f"查询用户反馈问题失败 failed:{e}") + return jsonify({"type": "error", "error": f'查询用户反馈问题失败:{str(e)}'}) + + +@app.flask_app.route("/yj_sqlbot/api/v0/question_feed_back", methods=["PUT"]) +def update_question_feed_back(): + id = request.json.get("id") + user_feedback = request.json.get("user_feedback") + if not id or not user_feedback: + return jsonify({"type": "error", "error": "id 或者用户反馈为空"}) + try: + update_user_feedBack(id, '', user_feedback) + return jsonify({"type": "success"}) + except Exception as e: + logger.error(f"查询预制问题失败 failed:{e}") + return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'}) + + @app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"]) def gen_graph_question(): try: user_id = request.args.get("user_id") cvs_id = request.args.get("cvs_id") - config = {"configurable": {"thread_id": uuid.uuid4()}} + config = {"configurable": {"thread_id": str(uuid.uuid4())}} question = flask.request.args.get("question") - question_context = get_latest_question(cvs_id, user_id,limit_count=2) + question_context = get_latest_question(cvs_id, user_id, limit_count=2) history = [] + i = 0 for q in question_context: - history.append({"role":"user",'content':q}) + is_latest=False + if i==0: + is_latest=True + history.append({"role": "user", 'content': q, 'order': i,'is_latest':is_latest}) + i +=1 logger.info(f"history is {history}") initial_state: SqlAgentState = { "user_id": user_id, @@ -316,9 +347,10 @@ def gen_graph_question(): } result = sql_chart_agent.invoke(initial_state, config=config) - new_question = result.get('rewritten_user_question',question) + new_question = result.get('rewritten_user_question', question) id = str(uuid.uuid4()) save_conversation(id, user_id, cvs_id, new_question) + # cache.set(id=id, field="question", value=result.get('rewritten_user_question',question)) data = { 'id': id, @@ -330,6 +362,7 @@ def gen_graph_question(): # cache.set(id=id, field="sql", value=data.get('sql', {}).get('sql', '')) sql = data.get('sql', {}).get('sql', '') update_conversation(id, sql) + save_save_question_async(id, user_id, new_question, sql) return jsonify(data) except Exception as e: traceback.print_exc() @@ -362,11 +395,9 @@ def run_sql_3(): "sql": sql, "question": question, "retry_count": 0, - "need_feedback": False, - } - config = {"configurable": {"thread_id": uuid.uuid4()}} - rr = result_report_agent.invoke(initial_state,config) + config = {"configurable": {"thread_id": str(uuid.uuid4())}} + rr = result_report_agent.invoke(initial_state, config) logger.info(f"rr.data is {rr.get('data', {})}") return jsonify( { diff --git a/service/question_feedback_service.py b/service/question_feedback_service.py index 45d8d1f..e2dc6d6 100644 --- a/service/question_feedback_service.py +++ b/service/question_feedback_service.py @@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor pool = ThreadPoolExecutor(max_workers=10) + def save_question(id, user_id, question, sql): logger.info(f"开始保存用户用户==>:{question}") user_id = user_id if user_id else '1' @@ -27,7 +28,7 @@ def save_save_question_async(id, user_id, question, sql): pool.submit(save_question, id, user_id, question, sql) -def update_user_feedBack(id, user_comment, user_praise: bool): +def update_user_feedBack(id, user_comment, user_praise: int): session = SqliteSqlalchemy().session try: session.query(QuestionFeedBack).filter(QuestionFeedBack.id == id).update( @@ -39,12 +40,25 @@ def update_user_feedBack(id, user_comment, user_praise: bool): finally: session.close() + ''' 查询所有的预制问题 ''' + + def query_predefined_question_list() -> list: session = SqliteSqlalchemy().session all = session.query(PredefinedQuestion).filter(PredefinedQuestion.enable == True).all() all = [a.to_dict() for a in all] if all else [] session.close() return all + + +def query_feedBack_question_list(id_list: list) -> list: + if not id_list or len(id_list) == 0: + return [] + session = SqliteSqlalchemy().session + all = session.query(QuestionFeedBack).filter(QuestionFeedBack.id.in_(id_list)).all() + all = [a.to_dict() for a in all] if all else [] + session.close() + return all diff --git a/template.yaml b/template.yaml index 0983aec..f330174 100644 --- a/template.yaml +++ b/template.yaml @@ -72,19 +72,7 @@ template: 术语标准化规则 - - [重要指令]解析用户问题,识别并替换所有已知的等价短语,将前面的短语换成后面的等价短语,: - ** "数信部" -> "数字信息部" (必须替换!) - ** "安质部" -> "安全质量部" (必须替换!) - 例如:用户:查询数信部 → SQL: LIKE '%数字信息部%' - 用户:查询安质部 → SQL: LIKE '%安全质量部%' - 例如:查询数信部有多少人->查询数字信息部有多少人 - - - 禁止自动联想或替换: - - "数信部" 不等于 "数信中心" - - "数字信息部" 不等于 "数字信息中心" - + 数信中心和数信部都是部门,而非单位 @@ -220,10 +208,6 @@ template: 数信中心建设处规划发展部综合处 这些都可能是单位的名称,属于内部部门 - - "数信部"务必替换成"数字信息部" - "安质部"务必替换成"安全质量部" - @@ -552,6 +536,14 @@ template: 2. **生成最终问题**: * **如果判定为【关联】**:你必须将历史对话中的相关上下文信息**融合**到当前问题中,形成一个**完整清晰、无任何指代或歧义的新问题**。 * **如果判定为【不关联】**:你只需**原样输出**当前问题。 + 3. 短语指标库替换,如果问题出现短语指标,则替换为指标库中的指标。 + + 3.1 :数信部 -> 数字信息部 (必须替换!) + 3.2 :安质部 -> 安全质量部 (必须替换!) + 3.3 :数信中心 -> 数信中心 + 例如:用户:查询数信部 → SQL: LIKE '%数字信息部%' + 用户:查询安质部 → SQL: LIKE '%安全质量部%' + # 输出格式要求 - **你的唯一输出必须是一个JSON对象**。 - **严禁**在JSON前后添加任何解释性文字、代码块标记(如 ```json)或任何其他内容。 @@ -576,6 +568,7 @@ template: #### # 上下文信息 + 注意:[历史对话中order越小,代表消息越新,优先基于最新消息进行关联性分析] **对话历史:** {history} **当前用户问题:** diff --git a/util/q_and_a_test1.py b/util/q_and_a_test1.py index e9820cc..5ed0805 100644 --- a/util/q_and_a_test1.py +++ b/util/q_and_a_test1.py @@ -1156,6 +1156,34 @@ question_and_answer = [ AND p."dr" = 0 LIMIT 1 ''' } +, + { + "question": "9月在林芝工作中有多少天迟到?", + "answer": ''' + SELECT COUNT(*) as '迟到天数' + FROM "YJOA_APPSERVICE_DB"."t_yj_person_status" ps + JOIN "YJOA_APPSERVICE_DB"."t_pr3rl2oj_yj_person_database" p ON ps."person_id" = p."code" + WHERE p."name" = '张三' + AND ps."status" IN ('1006', '1009', '6002', '6004') + AND ps."date_value" in (SELECT distinct (TO_CHAR(a."attendance_time", 'yyyy-MM-dd')) + FROM "YJOA_APPSERVICE_DB"."t_yj_person_attendance" a + LEFT JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_area" b + ON a."access_control_point" = b."ac_point" + + WHERE a."person_name" = '张三' + and b.region = 5 + AND a."attendance_time" >= '2025-09-01' + AND a."attendance_time" + < '2025-10-01' + AND a."dr" = 0 + LIMIT 1000 + ) + AND ps."dr" = 0 + AND p."dr" = 0 + ''', + "tags": ["员工", "个人", "考勤", "工作地", "区域", "工作天数"], + "category": "工作地考勤统计分析" + } ]