diff --git a/main_service.py b/main_service.py index 4414ae3..6038fa2 100644 --- a/main_service.py +++ b/main_service.py @@ -1,5 +1,9 @@ +import copy from email.policy import default import logging +from functools import wraps + +from Demos.mmapfile_demo import page_size import util.utils from logging_config import LOGGING_CONFIG @@ -102,8 +106,9 @@ def generate_sql_2(): return jsonify({"type": "error", "error": "No question provided"}) try: id = cache.generate_id(question=question) + user_id = request.args.get("user_id") logger.info(f"Generate sql for {question}") - data = vn.generate_sql_2(question=question) + data = vn.generate_sql_2(question=question, cache=cache, user_id=user_id) logger.info("Generate sql result is {0}".format(data)) data['id'] = id sql = data["resp"]["sql"] @@ -116,21 +121,56 @@ def generate_sql_2(): logger.error(f"generate sql failed:{e}") return jsonify({"type": "error", "error": str(e)}) +def session_save(func): + @wraps(func) + def wrapper(*args, **kwargs): + id=request.args.get("id") + user_id = request.args.get("user_id") + logger.info(f" id: {id},user_id: {user_id}") + result = func(*args, **kwargs) + + datas=[] + session_len = int(config("SESSION_LENGTH", default=2)) + if cache.exists(id=user_id, field="data"): + datas = copy.deepcopy(cache.get(id=user_id, field="data")) + data = { + "id": id, + "question":cache.get(id=id, field="question"), + "sql":cache.get(id=id, field="sql") + } + datas.append(data) + logger.info("datas is {0}".format(datas)) + if len(datas) > session_len and session_len > 0: + datas=datas[-session_len:] + # 删除id对应的所有缓存值,因为已经run_sql完毕,改用user_id保存为上下文 + cache.delete(id=id, field="question") + cache.set(id=user_id, field="data", value=copy.deepcopy(datas)) + logger.info(f" user data {cache.get(user_id, field='data')}") + return result + + return wrapper + @app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"]) +@session_save @app.requires_cache(["sql"]) -def run_sql_2(id: str, sql: str, page_num=None, page_size=None): +def run_sql_2(id: str, sql: str): """ Run SQL --- parameters: - - name: user + - name: user_id in: query + required: true - name: id in: query|body type: string required: true + - name: page_size + in: query + -name: page_num + in: query responses: 200: schema: @@ -158,15 +198,12 @@ def run_sql_2(id: str, sql: str, page_num=None, page_size=None): # count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery" # df_count = vn.run_sql(count_sql) - # total_count = df_count[0]["total_count"] if df_count is not None else 0 + # print(df_count,"is type",type(df_count)) + # total_count = df_count.to_dict(orient="records")[0]["total_count"] + # logger.info("Total count is {0}".format(total_count)) df = vn.run_sql(sql=sql) - logger.info("") - app.cache.set(id=id, field="df", value=df) result = df.to_dict(orient='records') logger.info("df ---------------{0} {1}".format(result,type(result))) - # result = util.utils.deal_result(data=result) - - return jsonify( { "type": "success", @@ -180,7 +217,6 @@ def run_sql_2(id: str, sql: str, page_num=None, page_size=None): return jsonify({"type": "sql_error", "error": str(e)}) - if __name__ == '__main__': app.run(host='0.0.0.0', port=8084, debug=False) diff --git a/service/cus_vanna_srevice.py b/service/cus_vanna_srevice.py index a99963e..4a59999 100644 --- a/service/cus_vanna_srevice.py +++ b/service/cus_vanna_srevice.py @@ -1,3 +1,4 @@ +from dataclasses import field from email.policy import default from typing import List, Union, Any, Optional import time @@ -76,10 +77,10 @@ class OpenAICompatibleLLM(VannaBase): def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]: logger.info(f"start to run_sql_damengsql") - if not is_connection_alive(conn=self.conn): - logger.info("connection is not alive, reconnecting..........") - reconnect() try: + if not is_connection_alive(conn=self.conn): + logger.info("connection is not alive, reconnecting..........") + reconnect() # conn.ping(reconnect=True) cs = self.conn.cursor() cs.execute(sql) @@ -203,7 +204,7 @@ class OpenAICompatibleLLM(VannaBase): return response.choices[0].message.content - def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: + def generate_sql_2(self, question: str, cache=None,user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict: try: logger.info("Start to generate_sql_2 in cus_vanna_srevice") question_sql_list = self.get_similar_question_sql(question, **kwargs) @@ -215,15 +216,19 @@ class OpenAICompatibleLLM(VannaBase): template = get_base_template() sql_temp = template['template']['sql'] char_temp = template['template']['chart'] + history = None + if user_id and cache: + history = cache.get(id=user_id, field="data") # --------基于提示词,生成sql以及图表类型 sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文', schema=ddl_list, documentation=[train_ddl.train_document], - retrieved_examples_data=question_sql_list, + history=history,retrieved_examples_data=question_sql_list, data_training=question_sql_list,) - logger.info(f"sys_temp:{sys_temp}") + user_temp = sql_temp['user'].format(question=question, current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) logger.info(f"user_temp:{user_temp}") + logger.info(f"sys_temp:{sys_temp}") llm_response = self.submit_prompt( [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) logger.info(f"llm_response:{llm_response}") diff --git a/template.yaml b/template.yaml index 29e193c..378bf46 100644 --- a/template.yaml +++ b/template.yaml @@ -16,6 +16,7 @@ template: :[RAG核心区] 通过检索与当前问题最相关的历史问答对。**这是最高优先级的s参考**,优先从中寻找与用户问题意图或表述最相似的案例来指导你生成SQL。 :通用SQL示例库。当中没有足够参考时,可在此处寻找相似的用法、函数模板或Join思路作为补充参考。 :数据库或业务相关的补充文档。 + :上下文历史,可以通过上下文历史,丰富问题背景 :[可选] 上一次生成的SQL执行失败时的错误信息,用于修正和优化你的输出。 :[可选] 背景信息,如当前提问时间等。 用户的提问位于 块内。 @@ -156,6 +157,7 @@ template: {engine} {schema} {documentation} + {history} 电网雅江联通 diff --git a/util/train_ddl.py b/util/train_ddl.py index 6dba9ee..c413f21 100644 --- a/util/train_ddl.py +++ b/util/train_ddl.py @@ -443,12 +443,26 @@ person_ddl_sql = """ ], "relationships": [ { - "from": "ytenant_id", - "to_table": "租户表", + "from": "input_dept", + "to_table": "IUAP_APDOC_BASEDOC.org_orgs", "to_field": "id", "type": "foreign_key", - "comment": "关联租户信息" - } + "comment": "关联部门表" + }, + { + "from": "internal_dept", + "to_table": "IUAP_APDOC_BASEDOC.org_orgs", + "to_field": "id", + "type": "foreign_key", + "comment": "关联部门表" + }, + { + "from": "internal_unit", + "to_table": "IUAP_APDOC_BASEDOC.org_orgs", + "to_field": "id", + "type": "foreign_key", + "comment": "关联部门表" + }, ], "tags": ["人员管理", "人力资源", "审批流程", "基本信息", "工作信息"], @@ -509,7 +523,7 @@ rule_ddl=''' { "name": "region", "type": "VARCHAR(50)", - "comment": "区域", + "comment": "地区", "value":{ "1":"北京", "2":"成都", @@ -518,14 +532,14 @@ rule_ddl=''' "5": "林芝" }, "role": "dimension", - "tags": [ "考勤的位置","非办公区域不要混淆","枚举"] + "tags": [ "考勤的地区位置","非办公区域,不要混淆","枚举"] }, ], "relationships": [ { "from": "region", - "to_table": "区域配置表", - "to_field": "region_code", + "to_table": "t_yj_person_ac_area", + "to_field": "region", "type": "foreign_key", "comment": "关联区域配置信息" } @@ -617,4 +631,294 @@ user_status_ddl=''' "tags": ["人员状态", "状态记录", "地区管理", "西藏标识", "每日状态"] } +''' + +user_attendance_ddl = ''' +{ + "db_name": "YJOA_APPSERVICE_DB", + "table_name": "t_person_attendance_records", + "table_comment": "人员考勤记录表,存储员工的打卡记录、考勤状态和位置信息", + "columns": [ + { + "name": "id", + "type": "VARCHAR(200)", + "comment": "主键ID", + "role": "dimension", + "tags": ["主键", "ID标识"] + }, + { + "name": "person_name", + "type": "VARCHAR(50)", + "comment": "人员姓名", + "role": "dimension", + "tags": ["人员信息", "姓名"] + }, + { + "name": "person_id", + "type": "VARCHAR(200)", + "comment": "人员ID", + "role": "dimension", + "tags": ["人员标识", "关联字段"] + }, + { + "name": "phone_number", + "type": "VARCHAR(50)", + "comment": "手机号码", + "role": "dimension", + "tags": ["联系方式", "人员信息"] + }, + { + "name": "attendance_time", + "type": "DATETIME", + "comment": "考勤时间", + "role": "dimension", + "tags": ["时间戳", "打卡时间", "关键时间"] + }, + { + "name": "attendance_address", + "type": "VARCHAR(200)", + "comment": "考勤地址", + "role": "dimension", + "tags": ["位置信息", "打卡地点"] + }, + { + "name": "status", + "type": "INT", + "comment": "状态", + "value": { + "0": "在岗", + "1": "出差", + "2": "休假" + }, + "role": "dimension", + "tags": ["状态标识", "人员在岗状态"] + }, + { + "name": "original_id", + "type": "VARCHAR(200)", + "comment": "原始ID", + "role": "dimension", + "tags": ["原数据ID"] + }, + { + "name": "source", + "type": "VARCHAR(50)", + "comment": "数据来源", + "value": { + "APP": "手机应用", + "DEVICE": "考勤设备", + "SYSTEM": "系统导入" + }, + "role": "dimension", + "tags": ["来源系统", "数据渠道"] + }, + { + "name": "dr", + "type": "INT", + "comment": "删除标志", + "value": { + "0": "正常", + "1": "已删除" + }, + "role": "dimension", + "tags": ["软删除", "数据状态"] + }, + { + "name": "create_time", + "type": "DATETIME", + "comment": "创建时间", + "role": "dimension", + "tags": ["时间戳", "记录创建时间"] + }, + { + "name": "enter_or_exit", + "type": "INT", + "comment": "进出类型", + "value": { + "0": "进", + "1": "出" + }, + "role": "dimension", + "tags": ["进出标识", "打卡方向"] + }, + { + "name": "access_control_point", + "type": "VARCHAR(50)", + "comment": "门禁点", + "role": "dimension", + "tags": ["门禁位置", "打卡设备点"] + }, + { + "name": "by_st", + "type": "VARCHAR(20)", + "comment": "上午打卡时间", + "role": "dimension", + "tags": ["时间范围", "开始时间"] + }, + { + "name": "by_et", + "type": "VARCHAR(20)", + "comment": "下午打卡时间", + "role": "dimension", + "tags": ["时间范围", "结束时间"] + }, + { + "name": "by_st_field", + "type": "VARCHAR(50)", + "comment": "午休前打卡时间", + "role": "dimension", + "tags": ["中间打卡","时间配置"] + }, + { + "name": "by_et_field", + "type": "VARCHAR(50)", + "comment": "午休后打卡时间", + "role": "dimension", + "tags": ["中间打卡", "时间配置"] + }, + { + "name": "by_go_type", + "type": "VARCHAR(8)", + "comment": "打卡类型", + "role": "dimension", + "tags": ["类型标识", "打卡类型"] + } + ], + "relationships": [ + { + "from": "person_id", + "to_table": "t_pr3rl2oj_yj_person_database", + "to_field": "code", + "type": "foreign_key", + "comment": "关联人员基本信息" + }, + { + "from": "access_control_point", + "to_table": "t_yj_person_ac_position", + "to_field": "ac_point", + "type": "foreign_key", + "comment": "关联门禁点配置信息" + } + ], + "tags": ["考勤记录", "打卡数据", "人员考勤", "时间记录", "位置信息", "门禁系统"] +} +''' + +person_ac_position = ''' +{ + "db_name":"YJOA_APPSERVICE_DB", + "table_name": "t_yj_person_ac_position", + "table_comment": "门禁控制点位置记录", + "columns": [ + { + "name": "ac_point", + "type": "VARCHAR(50)", + "comment": "门禁点", + "role": "dimension", + "tags": ["门禁点", "门禁点标识"] + }, + { + "name": "position", + "type": "VARCHAR(50)", + "comment": "位置编号", + "role": "dimension", + "tags": ["门禁位置"] + }, + ], + "relationships": [ + { + "from": "ac_point", + "to_table": "t_yj_person_ac_area", + "to_field": "ac_point", + "type": "foreign_key", + "comment": "关联门禁区域关系表" + }, + ], + + "tags": ["门禁控制点","门禁位置"] +} +''' + +person_ac_area = ''' +{ + "db_name":"YJOA_APPSERVICE_DB", + "table_name": "t_yj_person_ac_area", + "table_comment": "门禁区域关系表", + "columns": [ + { + "name": "ac_point", + "type": "VARCHAR(50)", + "comment": "门禁点", + "role": "dimension", + "tags": ["门禁点", "门禁点标识"] + }, + { + "name": "area", + "type": "Int", + "comment": "区域位置", + "role": "dimension", + "tags": ["门禁所属区域"] + }, + { + "name": "region", + "type": "Int", + "comment": "地区位置", + "value":{ + "1":"北京", + "2":"成都", + "3":"秭归", + "4":"林芝市区", + "5":"拉萨", + "6":"米林", + "7":"派镇", + "8":"墨脱", + }, + "role": "dimension", + "tags": ["门禁所属地区"] + }, + ], + + "tags": ["门禁详情","门禁区域位置","门禁地区信息"] +} +''' + + +org_orgs_ddl = ''' +{ + "db_name":"IUAP_APDOC_BASEDOC", + "table_name": "org_orgs", + "table_comment": "人员状态记录表,记录人员每日考勤状态信息包括西藏地区标识", + "columns": [ + { + "name": "id", + "type": "VARCHAR(36)", + "comment": "主键ID", + "role": "dimension", + "tags": ["主键", "id标识"] + }, + { + "name": "code", + "type": "VARCHAR(50)", + "comment": "编号", + "role": "dimension", + "tags": ["部门编号"] + }, + { + "name": "name", + "type": "VARCHAR(50)", + "comment": "部门名称", + "role": "dimension", + "tags": ["部门名称","单位名称"] + }, + { + "name": "shortname", + "type": "VARCHAR(1152)", + "comment": "部门简称", + "role": "dimension", + "tags": ["部门名称","部门简称","部门缩写"] + }, + ], + + "tags": ["部门id","部门信息","部门名称"] +} ''' \ No newline at end of file