Compare commits

...

2 Commits

Author SHA1 Message Date
yujj128
be0bc661e2 缓存上下文,表结构添加 2025-10-13 18:18:58 +08:00
yujj128
73cbc55d74 为memory添加ttl,启动定时清理 2025-10-10 16:39:59 +08:00
4 changed files with 498 additions and 39 deletions

View File

@@ -1,9 +1,13 @@
import copy
from email.policy import default
import logging
from functools import wraps
from Demos.mmapfile_demo import page_size
import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from decouple import config
import flask
from util import load_ddl_doc
@@ -67,6 +71,7 @@ def init_vn(vn):
from vanna.flask import VannaFlaskApp
vn = create_vana()
app = VannaFlaskApp(vn,chart=False)
app.cache = TTLCacheWrapper(app.cache, ttl = config('TTL_CACHE', cast=int))
init_vn(vn)
cache = app.cache
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
@@ -101,8 +106,9 @@ def generate_sql_2():
return jsonify({"type": "error", "error": "No question provided"})
try:
id = cache.generate_id(question=question)
user_id = request.args.get("user_id")
logger.info(f"Generate sql for {question}")
data = vn.generate_sql_2(question=question)
data = vn.generate_sql_2(question=question, cache=cache, user_id=user_id)
logger.info("Generate sql result is {0}".format(data))
data['id'] = id
sql = data["resp"]["sql"]
@@ -115,21 +121,56 @@ def generate_sql_2():
logger.error(f"generate sql failed:{e}")
return jsonify({"type": "error", "error": str(e)})
def session_save(func):
@wraps(func)
def wrapper(*args, **kwargs):
id=request.args.get("id")
user_id = request.args.get("user_id")
logger.info(f" id: {id},user_id: {user_id}")
result = func(*args, **kwargs)
datas=[]
session_len = int(config("SESSION_LENGTH", default=2))
if cache.exists(id=user_id, field="data"):
datas = copy.deepcopy(cache.get(id=user_id, field="data"))
data = {
"id": id,
"question":cache.get(id=id, field="question"),
"sql":cache.get(id=id, field="sql")
}
datas.append(data)
logger.info("datas is {0}".format(datas))
if len(datas) > session_len and session_len > 0:
datas=datas[-session_len:]
# 删除id对应的所有缓存值,因为已经run_sql完毕改用user_id保存为上下文
cache.delete(id=id, field="question")
cache.set(id=user_id, field="data", value=copy.deepcopy(datas))
logger.info(f" user data {cache.get(user_id, field='data')}")
return result
return wrapper
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
@session_save
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
"""
Run SQL
---
parameters:
- name: user
- name: user_id
in: query
required: true
- name: id
in: query|body
type: string
required: true
- name: page_size
in: query
-name: page_num
in: query
responses:
200:
schema:
@@ -155,13 +196,14 @@ def run_sql_2(id: str, sql: str):
}
)
# count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery"
# df_count = vn.run_sql(count_sql)
# print(df_count,"is type",type(df_count))
# total_count = df_count.to_dict(orient="records")[0]["total_count"]
# logger.info("Total count is {0}".format(total_count))
df = vn.run_sql(sql=sql)
logger.info("")
app.cache.set(id=id, field="df", value=df)
result = df.to_dict(orient='records')
logger.info("df ---------------{0} {1}".format(result,type(result)))
# result = util.utils.deal_result(data=result)
return jsonify(
{
"type": "success",
@@ -174,6 +216,7 @@ def run_sql_2(id: str, sql: str):
logger.error(f"run sql failed:{e}")
return jsonify({"type": "sql_error", "error": str(e)})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=8084, debug=False)

View File

@@ -1,5 +1,9 @@
from dataclasses import field
from email.policy import default
from typing import List, Union
from typing import List, Union, Any, Optional
import time
import threading
from vanna.flask import Cache, MemoryCache
import dmPython
import orjson
import pandas as pd
@@ -65,17 +69,20 @@ class OpenAICompatibleLLM(VannaBase):
port: int = None,
**kwargs
):
conn = None
self.conn = None
try:
conn = dmPython.connect(user=user, password=password, server=host, port=port)
self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
except Exception as e:
raise Exception(f"Failed to connect to dameng database: {e}")
def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
if conn:
logger.info(f"start to run_sql_damengsql")
try:
if not is_connection_alive(conn=self.conn):
logger.info("connection is not alive, reconnecting..........")
reconnect()
# conn.ping(reconnect=True)
cs = conn.cursor()
cs = self.conn.cursor()
cs.execute(sql)
results = cs.fetchall()
@@ -85,17 +92,32 @@ class OpenAICompatibleLLM(VannaBase):
)
return df
except Exception as e:
conn.rollback()
self.conn.rollback()
logger.error(f"Failed to execute sql query: {e}")
raise e
return None
def reconnect():
try:
self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
except Exception as e:
raise Exception(f"reconnect failed: {e}")
def is_connection_alive(conn) -> bool:
if conn is None:
return False
try:
cursor = conn.cursor()
cursor.execute("SELECT 1 FROM DUAL")
cursor.close()
return True
except Exception as e:
return False
self.run_sql_is_set = True
self.run_sql = run_sql_damengsql
def user_message(self, message: str) -> any:
return {"role": "user", "content": message}
@@ -182,7 +204,7 @@ class OpenAICompatibleLLM(VannaBase):
return response.choices[0].message.content
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
def generate_sql_2(self, question: str, cache=None,user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
try:
logger.info("Start to generate_sql_2 in cus_vanna_srevice")
question_sql_list = self.get_similar_question_sql(question, **kwargs)
@@ -194,15 +216,19 @@ class OpenAICompatibleLLM(VannaBase):
template = get_base_template()
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
history = None
if user_id and cache:
history = cache.get(id=user_id, field="data")
# --------基于提示词生成sql以及图表类型
sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
schema=ddl_list, documentation=[train_ddl.train_document],
retrieved_examples_data=question_sql_list,
history=history,retrieved_examples_data=question_sql_list,
data_training=question_sql_list,)
logger.info(f"sys_temp:{sys_temp}")
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
logger.info(f"user_temp:{user_temp}")
logger.info(f"sys_temp:{sys_temp}")
llm_response = self.submit_prompt(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
logger.info(f"llm_response:{llm_response}")
@@ -291,3 +317,87 @@ class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
def __init__(self, llm_config=None, vector_store_config=None):
CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
OpenAICompatibleLLM.__init__(self, config_file=llm_config)
class TTLCacheWrapper:
"为MemoryCache()添加带ttl的包装器防治内存泄漏"
def __init__(self, cache: Optional[Cache] = None, ttl: int = 3600):
self.cache = cache or MemoryCache()
self.ttl = ttl
self._expiry_times = {}
self._cleanup_thread = None
self._start_cleanup()
def _start_cleanup(self):
"""启动后台清理线程"""
def cleanup():
while True:
current_time = time.time()
expired_keys = []
# 找出所有过期的key
for key_info, expiry in self._expiry_times.items():
if expiry <= current_time:
expired_keys.append(key_info)
# 清理过期数据
for key_info in expired_keys:
id, field = key_info
if hasattr(self.cache, 'delete'):
self.cache.delete(id=id)
del self._expiry_times[key_info]
time.sleep(180) # 每3分钟清理一次
self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
self._cleanup_thread.start()
def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
"""设置缓存值支持TTL"""
# 使用提供的TTL或默认TTL
actual_ttl = ttl if ttl is not None else self.ttl
# 调用原始cache的set方法
self.cache.set(id=id, field=field, value=value)
# 记录过期时间
key_info = (id, field)
self._expiry_times[key_info] = time.time() + actual_ttl
def get(self, id: str, field: str) -> Any:
"""获取缓存值,自动处理过期"""
key_info = (id, field)
# 检查是否过期
if key_info in self._expiry_times:
if time.time() > self._expiry_times[key_info]:
# 已过期删除并返回None
if hasattr(self.cache, 'delete'):
self.cache.delete(id=id)
del self._expiry_times[key_info]
return None
# 返回缓存值
return self.cache.get(id=id, field=field)
def delete(self, id: str, field: str):
"""删除缓存值"""
key_info = (id, field)
if hasattr(self.cache, 'delete'):
self.cache.delete(id=id)
if key_info in self._expiry_times:
del self._expiry_times[key_info]
def exists(self, id: str, field: str) -> bool:
"""检查缓存是否存在且未过期"""
key_info = (id, field)
if key_info in self._expiry_times:
if time.time() > self._expiry_times[key_info]:
# 已过期清理并返回False
self.delete(id=id, field=field)
return False
return self.get(id=id, field=field) is not None
# 代理其他方法到原始cache
def __getattr__(self, name):
return getattr(self.cache, name)

View File

@@ -16,6 +16,7 @@ template:
<retrieved-examples>[RAG核心区] 通过检索与当前问题最相关的历史问答对。**这是最高优先级的s参考**优先从中寻找与用户问题意图或表述最相似的案例来指导你生成SQL。
<sql-examples>通用SQL示例库。当<retrieved-examples>中没有足够参考时可在此处寻找相似的用法、函数模板或Join思路作为补充参考。
<documentation>:数据库或业务相关的补充文档。
<history>:上下文历史,可以通过上下文历史,丰富问题背景
<error-msg>[可选] 上一次生成的SQL执行失败时的错误信息用于修正和优化你的输出。
<background-infos>[可选] 背景信息,如当前提问时间<current-time>等。
用户的提问位于 <user-question> 块内。
@@ -156,6 +157,7 @@ template:
<db-engine>{engine}</db-engine>
<m-schema>{schema}</m-schema>
<documentation>{documentation}</documentation>
<history>{history}</history>
<terminologies>
<terminology>
<words><word国网</word><word>电网</word><word>雅江</word><word>联通</word></words>

View File

@@ -443,12 +443,26 @@ person_ddl_sql = """
],
"relationships": [
{
"from": "ytenant_id",
"to_table": "租户表",
"from": "input_dept",
"to_table": "IUAP_APDOC_BASEDOC.org_orgs",
"to_field": "id",
"type": "foreign_key",
"comment": "关联租户信息"
}
"comment": "关联部门表"
},
{
"from": "internal_dept",
"to_table": "IUAP_APDOC_BASEDOC.org_orgs",
"to_field": "id",
"type": "foreign_key",
"comment": "关联部门表"
},
{
"from": "internal_unit",
"to_table": "IUAP_APDOC_BASEDOC.org_orgs",
"to_field": "id",
"type": "foreign_key",
"comment": "关联部门表"
},
],
"tags": ["人员管理", "人力资源", "审批流程", "基本信息", "工作信息"],
@@ -509,7 +523,7 @@ rule_ddl='''
{
"name": "region",
"type": "VARCHAR(50)",
"comment": "",
"comment": "",
"value":{
"1":"北京",
"2":"成都",
@@ -518,14 +532,14 @@ rule_ddl='''
"5": "林芝"
},
"role": "dimension",
"tags": [ "考勤的位置","非办公区域不要混淆","枚举"]
"tags": [ "考勤的地区位置","非办公区域不要混淆","枚举"]
},
],
"relationships": [
{
"from": "region",
"to_table": "区域配置表",
"to_field": "region_code",
"to_table": "t_yj_person_ac_area",
"to_field": "region",
"type": "foreign_key",
"comment": "关联区域配置信息"
}
@@ -618,3 +632,293 @@ user_status_ddl='''
"tags": ["人员状态", "状态记录", "地区管理", "西藏标识", "每日状态"]
}
'''
user_attendance_ddl = '''
{
"db_name": "YJOA_APPSERVICE_DB",
"table_name": "t_person_attendance_records",
"table_comment": "人员考勤记录表,存储员工的打卡记录、考勤状态和位置信息",
"columns": [
{
"name": "id",
"type": "VARCHAR(200)",
"comment": "主键ID",
"role": "dimension",
"tags": ["主键", "ID标识"]
},
{
"name": "person_name",
"type": "VARCHAR(50)",
"comment": "人员姓名",
"role": "dimension",
"tags": ["人员信息", "姓名"]
},
{
"name": "person_id",
"type": "VARCHAR(200)",
"comment": "人员ID",
"role": "dimension",
"tags": ["人员标识", "关联字段"]
},
{
"name": "phone_number",
"type": "VARCHAR(50)",
"comment": "手机号码",
"role": "dimension",
"tags": ["联系方式", "人员信息"]
},
{
"name": "attendance_time",
"type": "DATETIME",
"comment": "考勤时间",
"role": "dimension",
"tags": ["时间戳", "打卡时间", "关键时间"]
},
{
"name": "attendance_address",
"type": "VARCHAR(200)",
"comment": "考勤地址",
"role": "dimension",
"tags": ["位置信息", "打卡地点"]
},
{
"name": "status",
"type": "INT",
"comment": "状态",
"value": {
"0": "在岗",
"1": "出差",
"2": "休假"
},
"role": "dimension",
"tags": ["状态标识", "人员在岗状态"]
},
{
"name": "original_id",
"type": "VARCHAR(200)",
"comment": "原始ID",
"role": "dimension",
"tags": ["原数据ID"]
},
{
"name": "source",
"type": "VARCHAR(50)",
"comment": "数据来源",
"value": {
"APP": "手机应用",
"DEVICE": "考勤设备",
"SYSTEM": "系统导入"
},
"role": "dimension",
"tags": ["来源系统", "数据渠道"]
},
{
"name": "dr",
"type": "INT",
"comment": "删除标志",
"value": {
"0": "正常",
"1": "已删除"
},
"role": "dimension",
"tags": ["软删除", "数据状态"]
},
{
"name": "create_time",
"type": "DATETIME",
"comment": "创建时间",
"role": "dimension",
"tags": ["时间戳", "记录创建时间"]
},
{
"name": "enter_or_exit",
"type": "INT",
"comment": "进出类型",
"value": {
"0": "",
"1": ""
},
"role": "dimension",
"tags": ["进出标识", "打卡方向"]
},
{
"name": "access_control_point",
"type": "VARCHAR(50)",
"comment": "门禁点",
"role": "dimension",
"tags": ["门禁位置", "打卡设备点"]
},
{
"name": "by_st",
"type": "VARCHAR(20)",
"comment": "上午打卡时间",
"role": "dimension",
"tags": ["时间范围", "开始时间"]
},
{
"name": "by_et",
"type": "VARCHAR(20)",
"comment": "下午打卡时间",
"role": "dimension",
"tags": ["时间范围", "结束时间"]
},
{
"name": "by_st_field",
"type": "VARCHAR(50)",
"comment": "午休前打卡时间",
"role": "dimension",
"tags": ["中间打卡","时间配置"]
},
{
"name": "by_et_field",
"type": "VARCHAR(50)",
"comment": "午休后打卡时间",
"role": "dimension",
"tags": ["中间打卡", "时间配置"]
},
{
"name": "by_go_type",
"type": "VARCHAR(8)",
"comment": "打卡类型",
"role": "dimension",
"tags": ["类型标识", "打卡类型"]
}
],
"relationships": [
{
"from": "person_id",
"to_table": "t_pr3rl2oj_yj_person_database",
"to_field": "code",
"type": "foreign_key",
"comment": "关联人员基本信息"
},
{
"from": "access_control_point",
"to_table": "t_yj_person_ac_position",
"to_field": "ac_point",
"type": "foreign_key",
"comment": "关联门禁点配置信息"
}
],
"tags": ["考勤记录", "打卡数据", "人员考勤", "时间记录", "位置信息", "门禁系统"]
}
'''
person_ac_position = '''
{
"db_name":"YJOA_APPSERVICE_DB",
"table_name": "t_yj_person_ac_position",
"table_comment": "门禁控制点位置记录",
"columns": [
{
"name": "ac_point",
"type": "VARCHAR(50)",
"comment": "门禁点",
"role": "dimension",
"tags": ["门禁点", "门禁点标识"]
},
{
"name": "position",
"type": "VARCHAR(50)",
"comment": "位置编号",
"role": "dimension",
"tags": ["门禁位置"]
},
],
"relationships": [
{
"from": "ac_point",
"to_table": "t_yj_person_ac_area",
"to_field": "ac_point",
"type": "foreign_key",
"comment": "关联门禁区域关系表"
},
],
"tags": ["门禁控制点","门禁位置"]
}
'''
person_ac_area = '''
{
"db_name":"YJOA_APPSERVICE_DB",
"table_name": "t_yj_person_ac_area",
"table_comment": "门禁区域关系表",
"columns": [
{
"name": "ac_point",
"type": "VARCHAR(50)",
"comment": "门禁点",
"role": "dimension",
"tags": ["门禁点", "门禁点标识"]
},
{
"name": "area",
"type": "Int",
"comment": "区域位置",
"role": "dimension",
"tags": ["门禁所属区域"]
},
{
"name": "region",
"type": "Int",
"comment": "地区位置",
"value":{
"1":"北京",
"2":"成都",
"3":"秭归",
"4":"林芝市区",
"5":"拉萨",
"6":"米林",
"7":"派镇",
"8":"墨脱",
},
"role": "dimension",
"tags": ["门禁所属地区"]
},
],
"tags": ["门禁详情","门禁区域位置","门禁地区信息"]
}
'''
org_orgs_ddl = '''
{
"db_name":"IUAP_APDOC_BASEDOC",
"table_name": "org_orgs",
"table_comment": "人员状态记录表,记录人员每日考勤状态信息包括西藏地区标识",
"columns": [
{
"name": "id",
"type": "VARCHAR(36)",
"comment": "主键ID",
"role": "dimension",
"tags": ["主键", "id标识"]
},
{
"name": "code",
"type": "VARCHAR(50)",
"comment": "编号",
"role": "dimension",
"tags": ["部门编号"]
},
{
"name": "name",
"type": "VARCHAR(50)",
"comment": "部门名称",
"role": "dimension",
"tags": ["部门名称""单位名称"]
},
{
"name": "shortname",
"type": "VARCHAR(1152)",
"comment": "部门简称",
"role": "dimension",
"tags": ["部门名称","部门简称","部门缩写"]
},
],
"tags": ["部门id","部门信息","部门名称"]
}
'''