缓存上下文,表结构添加

This commit is contained in:
yujj128
2025-10-13 18:18:58 +08:00
parent 73cbc55d74
commit be0bc661e2
4 changed files with 371 additions and 24 deletions

View File

@@ -1,5 +1,9 @@
import copy
from email.policy import default from email.policy import default
import logging import logging
from functools import wraps
from Demos.mmapfile_demo import page_size
import util.utils import util.utils
from logging_config import LOGGING_CONFIG from logging_config import LOGGING_CONFIG
@@ -102,8 +106,9 @@ def generate_sql_2():
return jsonify({"type": "error", "error": "No question provided"}) return jsonify({"type": "error", "error": "No question provided"})
try: try:
id = cache.generate_id(question=question) id = cache.generate_id(question=question)
user_id = request.args.get("user_id")
logger.info(f"Generate sql for {question}") logger.info(f"Generate sql for {question}")
data = vn.generate_sql_2(question=question) data = vn.generate_sql_2(question=question, cache=cache, user_id=user_id)
logger.info("Generate sql result is {0}".format(data)) logger.info("Generate sql result is {0}".format(data))
data['id'] = id data['id'] = id
sql = data["resp"]["sql"] sql = data["resp"]["sql"]
@@ -116,21 +121,56 @@ def generate_sql_2():
logger.error(f"generate sql failed:{e}") logger.error(f"generate sql failed:{e}")
return jsonify({"type": "error", "error": str(e)}) return jsonify({"type": "error", "error": str(e)})
def session_save(func):
@wraps(func)
def wrapper(*args, **kwargs):
id=request.args.get("id")
user_id = request.args.get("user_id")
logger.info(f" id: {id},user_id: {user_id}")
result = func(*args, **kwargs)
datas=[]
session_len = int(config("SESSION_LENGTH", default=2))
if cache.exists(id=user_id, field="data"):
datas = copy.deepcopy(cache.get(id=user_id, field="data"))
data = {
"id": id,
"question":cache.get(id=id, field="question"),
"sql":cache.get(id=id, field="sql")
}
datas.append(data)
logger.info("datas is {0}".format(datas))
if len(datas) > session_len and session_len > 0:
datas=datas[-session_len:]
# 删除id对应的所有缓存值,因为已经run_sql完毕改用user_id保存为上下文
cache.delete(id=id, field="question")
cache.set(id=user_id, field="data", value=copy.deepcopy(datas))
logger.info(f" user data {cache.get(user_id, field='data')}")
return result
return wrapper
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"]) @app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
@session_save
@app.requires_cache(["sql"]) @app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str, page_num=None, page_size=None): def run_sql_2(id: str, sql: str):
""" """
Run SQL Run SQL
--- ---
parameters: parameters:
- name: user - name: user_id
in: query in: query
required: true
- name: id - name: id
in: query|body in: query|body
type: string type: string
required: true required: true
- name: page_size
in: query
-name: page_num
in: query
responses: responses:
200: 200:
schema: schema:
@@ -158,15 +198,12 @@ def run_sql_2(id: str, sql: str, page_num=None, page_size=None):
# count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery" # count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery"
# df_count = vn.run_sql(count_sql) # df_count = vn.run_sql(count_sql)
# total_count = df_count[0]["total_count"] if df_count is not None else 0 # print(df_count,"is type",type(df_count))
# total_count = df_count.to_dict(orient="records")[0]["total_count"]
# logger.info("Total count is {0}".format(total_count))
df = vn.run_sql(sql=sql) df = vn.run_sql(sql=sql)
logger.info("")
app.cache.set(id=id, field="df", value=df)
result = df.to_dict(orient='records') result = df.to_dict(orient='records')
logger.info("df ---------------{0} {1}".format(result,type(result))) logger.info("df ---------------{0} {1}".format(result,type(result)))
# result = util.utils.deal_result(data=result)
return jsonify( return jsonify(
{ {
"type": "success", "type": "success",
@@ -180,7 +217,6 @@ def run_sql_2(id: str, sql: str, page_num=None, page_size=None):
return jsonify({"type": "sql_error", "error": str(e)}) return jsonify({"type": "sql_error", "error": str(e)})
if __name__ == '__main__': if __name__ == '__main__':
app.run(host='0.0.0.0', port=8084, debug=False) app.run(host='0.0.0.0', port=8084, debug=False)

View File

@@ -1,3 +1,4 @@
from dataclasses import field
from email.policy import default from email.policy import default
from typing import List, Union, Any, Optional from typing import List, Union, Any, Optional
import time import time
@@ -76,10 +77,10 @@ class OpenAICompatibleLLM(VannaBase):
def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]: def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
logger.info(f"start to run_sql_damengsql") logger.info(f"start to run_sql_damengsql")
try:
if not is_connection_alive(conn=self.conn): if not is_connection_alive(conn=self.conn):
logger.info("connection is not alive, reconnecting..........") logger.info("connection is not alive, reconnecting..........")
reconnect() reconnect()
try:
# conn.ping(reconnect=True) # conn.ping(reconnect=True)
cs = self.conn.cursor() cs = self.conn.cursor()
cs.execute(sql) cs.execute(sql)
@@ -203,7 +204,7 @@ class OpenAICompatibleLLM(VannaBase):
return response.choices[0].message.content return response.choices[0].message.content
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: def generate_sql_2(self, question: str, cache=None,user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
try: try:
logger.info("Start to generate_sql_2 in cus_vanna_srevice") logger.info("Start to generate_sql_2 in cus_vanna_srevice")
question_sql_list = self.get_similar_question_sql(question, **kwargs) question_sql_list = self.get_similar_question_sql(question, **kwargs)
@@ -215,15 +216,19 @@ class OpenAICompatibleLLM(VannaBase):
template = get_base_template() template = get_base_template()
sql_temp = template['template']['sql'] sql_temp = template['template']['sql']
char_temp = template['template']['chart'] char_temp = template['template']['chart']
history = None
if user_id and cache:
history = cache.get(id=user_id, field="data")
# --------基于提示词生成sql以及图表类型 # --------基于提示词生成sql以及图表类型
sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文', sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
schema=ddl_list, documentation=[train_ddl.train_document], schema=ddl_list, documentation=[train_ddl.train_document],
retrieved_examples_data=question_sql_list, history=history,retrieved_examples_data=question_sql_list,
data_training=question_sql_list,) data_training=question_sql_list,)
logger.info(f"sys_temp:{sys_temp}")
user_temp = sql_temp['user'].format(question=question, user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
logger.info(f"user_temp:{user_temp}") logger.info(f"user_temp:{user_temp}")
logger.info(f"sys_temp:{sys_temp}")
llm_response = self.submit_prompt( llm_response = self.submit_prompt(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
logger.info(f"llm_response:{llm_response}") logger.info(f"llm_response:{llm_response}")

View File

@@ -16,6 +16,7 @@ template:
<retrieved-examples>[RAG核心区] 通过检索与当前问题最相关的历史问答对。**这是最高优先级的s参考**优先从中寻找与用户问题意图或表述最相似的案例来指导你生成SQL。 <retrieved-examples>[RAG核心区] 通过检索与当前问题最相关的历史问答对。**这是最高优先级的s参考**优先从中寻找与用户问题意图或表述最相似的案例来指导你生成SQL。
<sql-examples>通用SQL示例库。当<retrieved-examples>中没有足够参考时可在此处寻找相似的用法、函数模板或Join思路作为补充参考。 <sql-examples>通用SQL示例库。当<retrieved-examples>中没有足够参考时可在此处寻找相似的用法、函数模板或Join思路作为补充参考。
<documentation>:数据库或业务相关的补充文档。 <documentation>:数据库或业务相关的补充文档。
<history>:上下文历史,可以通过上下文历史,丰富问题背景
<error-msg>[可选] 上一次生成的SQL执行失败时的错误信息用于修正和优化你的输出。 <error-msg>[可选] 上一次生成的SQL执行失败时的错误信息用于修正和优化你的输出。
<background-infos>[可选] 背景信息,如当前提问时间<current-time>等。 <background-infos>[可选] 背景信息,如当前提问时间<current-time>等。
用户的提问位于 <user-question> 块内。 用户的提问位于 <user-question> 块内。
@@ -156,6 +157,7 @@ template:
<db-engine>{engine}</db-engine> <db-engine>{engine}</db-engine>
<m-schema>{schema}</m-schema> <m-schema>{schema}</m-schema>
<documentation>{documentation}</documentation> <documentation>{documentation}</documentation>
<history>{history}</history>
<terminologies> <terminologies>
<terminology> <terminology>
<words><word国网</word><word>电网</word><word>雅江</word><word>联通</word></words> <words><word国网</word><word>电网</word><word>雅江</word><word>联通</word></words>

View File

@@ -443,12 +443,26 @@ person_ddl_sql = """
], ],
"relationships": [ "relationships": [
{ {
"from": "ytenant_id", "from": "input_dept",
"to_table": "租户表", "to_table": "IUAP_APDOC_BASEDOC.org_orgs",
"to_field": "id", "to_field": "id",
"type": "foreign_key", "type": "foreign_key",
"comment": "关联租户信息" "comment": "关联部门表"
} },
{
"from": "internal_dept",
"to_table": "IUAP_APDOC_BASEDOC.org_orgs",
"to_field": "id",
"type": "foreign_key",
"comment": "关联部门表"
},
{
"from": "internal_unit",
"to_table": "IUAP_APDOC_BASEDOC.org_orgs",
"to_field": "id",
"type": "foreign_key",
"comment": "关联部门表"
},
], ],
"tags": ["人员管理", "人力资源", "审批流程", "基本信息", "工作信息"], "tags": ["人员管理", "人力资源", "审批流程", "基本信息", "工作信息"],
@@ -509,7 +523,7 @@ rule_ddl='''
{ {
"name": "region", "name": "region",
"type": "VARCHAR(50)", "type": "VARCHAR(50)",
"comment": "", "comment": "",
"value":{ "value":{
"1":"北京", "1":"北京",
"2":"成都", "2":"成都",
@@ -518,14 +532,14 @@ rule_ddl='''
"5": "林芝" "5": "林芝"
}, },
"role": "dimension", "role": "dimension",
"tags": [ "考勤的位置","非办公区域不要混淆","枚举"] "tags": [ "考勤的地区位置","非办公区域不要混淆","枚举"]
}, },
], ],
"relationships": [ "relationships": [
{ {
"from": "region", "from": "region",
"to_table": "区域配置表", "to_table": "t_yj_person_ac_area",
"to_field": "region_code", "to_field": "region",
"type": "foreign_key", "type": "foreign_key",
"comment": "关联区域配置信息" "comment": "关联区域配置信息"
} }
@@ -618,3 +632,293 @@ user_status_ddl='''
"tags": ["人员状态", "状态记录", "地区管理", "西藏标识", "每日状态"] "tags": ["人员状态", "状态记录", "地区管理", "西藏标识", "每日状态"]
} }
''' '''
user_attendance_ddl = '''
{
"db_name": "YJOA_APPSERVICE_DB",
"table_name": "t_person_attendance_records",
"table_comment": "人员考勤记录表,存储员工的打卡记录、考勤状态和位置信息",
"columns": [
{
"name": "id",
"type": "VARCHAR(200)",
"comment": "主键ID",
"role": "dimension",
"tags": ["主键", "ID标识"]
},
{
"name": "person_name",
"type": "VARCHAR(50)",
"comment": "人员姓名",
"role": "dimension",
"tags": ["人员信息", "姓名"]
},
{
"name": "person_id",
"type": "VARCHAR(200)",
"comment": "人员ID",
"role": "dimension",
"tags": ["人员标识", "关联字段"]
},
{
"name": "phone_number",
"type": "VARCHAR(50)",
"comment": "手机号码",
"role": "dimension",
"tags": ["联系方式", "人员信息"]
},
{
"name": "attendance_time",
"type": "DATETIME",
"comment": "考勤时间",
"role": "dimension",
"tags": ["时间戳", "打卡时间", "关键时间"]
},
{
"name": "attendance_address",
"type": "VARCHAR(200)",
"comment": "考勤地址",
"role": "dimension",
"tags": ["位置信息", "打卡地点"]
},
{
"name": "status",
"type": "INT",
"comment": "状态",
"value": {
"0": "在岗",
"1": "出差",
"2": "休假"
},
"role": "dimension",
"tags": ["状态标识", "人员在岗状态"]
},
{
"name": "original_id",
"type": "VARCHAR(200)",
"comment": "原始ID",
"role": "dimension",
"tags": ["原数据ID"]
},
{
"name": "source",
"type": "VARCHAR(50)",
"comment": "数据来源",
"value": {
"APP": "手机应用",
"DEVICE": "考勤设备",
"SYSTEM": "系统导入"
},
"role": "dimension",
"tags": ["来源系统", "数据渠道"]
},
{
"name": "dr",
"type": "INT",
"comment": "删除标志",
"value": {
"0": "正常",
"1": "已删除"
},
"role": "dimension",
"tags": ["软删除", "数据状态"]
},
{
"name": "create_time",
"type": "DATETIME",
"comment": "创建时间",
"role": "dimension",
"tags": ["时间戳", "记录创建时间"]
},
{
"name": "enter_or_exit",
"type": "INT",
"comment": "进出类型",
"value": {
"0": "",
"1": ""
},
"role": "dimension",
"tags": ["进出标识", "打卡方向"]
},
{
"name": "access_control_point",
"type": "VARCHAR(50)",
"comment": "门禁点",
"role": "dimension",
"tags": ["门禁位置", "打卡设备点"]
},
{
"name": "by_st",
"type": "VARCHAR(20)",
"comment": "上午打卡时间",
"role": "dimension",
"tags": ["时间范围", "开始时间"]
},
{
"name": "by_et",
"type": "VARCHAR(20)",
"comment": "下午打卡时间",
"role": "dimension",
"tags": ["时间范围", "结束时间"]
},
{
"name": "by_st_field",
"type": "VARCHAR(50)",
"comment": "午休前打卡时间",
"role": "dimension",
"tags": ["中间打卡","时间配置"]
},
{
"name": "by_et_field",
"type": "VARCHAR(50)",
"comment": "午休后打卡时间",
"role": "dimension",
"tags": ["中间打卡", "时间配置"]
},
{
"name": "by_go_type",
"type": "VARCHAR(8)",
"comment": "打卡类型",
"role": "dimension",
"tags": ["类型标识", "打卡类型"]
}
],
"relationships": [
{
"from": "person_id",
"to_table": "t_pr3rl2oj_yj_person_database",
"to_field": "code",
"type": "foreign_key",
"comment": "关联人员基本信息"
},
{
"from": "access_control_point",
"to_table": "t_yj_person_ac_position",
"to_field": "ac_point",
"type": "foreign_key",
"comment": "关联门禁点配置信息"
}
],
"tags": ["考勤记录", "打卡数据", "人员考勤", "时间记录", "位置信息", "门禁系统"]
}
'''
person_ac_position = '''
{
"db_name":"YJOA_APPSERVICE_DB",
"table_name": "t_yj_person_ac_position",
"table_comment": "门禁控制点位置记录",
"columns": [
{
"name": "ac_point",
"type": "VARCHAR(50)",
"comment": "门禁点",
"role": "dimension",
"tags": ["门禁点", "门禁点标识"]
},
{
"name": "position",
"type": "VARCHAR(50)",
"comment": "位置编号",
"role": "dimension",
"tags": ["门禁位置"]
},
],
"relationships": [
{
"from": "ac_point",
"to_table": "t_yj_person_ac_area",
"to_field": "ac_point",
"type": "foreign_key",
"comment": "关联门禁区域关系表"
},
],
"tags": ["门禁控制点","门禁位置"]
}
'''
person_ac_area = '''
{
"db_name":"YJOA_APPSERVICE_DB",
"table_name": "t_yj_person_ac_area",
"table_comment": "门禁区域关系表",
"columns": [
{
"name": "ac_point",
"type": "VARCHAR(50)",
"comment": "门禁点",
"role": "dimension",
"tags": ["门禁点", "门禁点标识"]
},
{
"name": "area",
"type": "Int",
"comment": "区域位置",
"role": "dimension",
"tags": ["门禁所属区域"]
},
{
"name": "region",
"type": "Int",
"comment": "地区位置",
"value":{
"1":"北京",
"2":"成都",
"3":"秭归",
"4":"林芝市区",
"5":"拉萨",
"6":"米林",
"7":"派镇",
"8":"墨脱",
},
"role": "dimension",
"tags": ["门禁所属地区"]
},
],
"tags": ["门禁详情","门禁区域位置","门禁地区信息"]
}
'''
org_orgs_ddl = '''
{
"db_name":"IUAP_APDOC_BASEDOC",
"table_name": "org_orgs",
"table_comment": "人员状态记录表,记录人员每日考勤状态信息包括西藏地区标识",
"columns": [
{
"name": "id",
"type": "VARCHAR(36)",
"comment": "主键ID",
"role": "dimension",
"tags": ["主键", "id标识"]
},
{
"name": "code",
"type": "VARCHAR(50)",
"comment": "编号",
"role": "dimension",
"tags": ["部门编号"]
},
{
"name": "name",
"type": "VARCHAR(50)",
"comment": "部门名称",
"role": "dimension",
"tags": ["部门名称""单位名称"]
},
{
"name": "shortname",
"type": "VARCHAR(1152)",
"comment": "部门简称",
"role": "dimension",
"tags": ["部门名称","部门简称","部门缩写"]
},
],
"tags": ["部门id","部门信息","部门名称"]
}
'''