为memory添加ttl,启动定时清理

This commit is contained in:
yujj128
2025-10-10 16:39:59 +08:00
parent 3285f3bca7
commit 73cbc55d74
2 changed files with 135 additions and 23 deletions

View File

@@ -3,7 +3,7 @@ import logging
import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from decouple import config
import flask
from util import load_ddl_doc
@@ -67,6 +67,7 @@ def init_vn(vn):
from vanna.flask import VannaFlaskApp
vn = create_vana()
app = VannaFlaskApp(vn,chart=False)
app.cache = TTLCacheWrapper(app.cache, ttl = config('TTL_CACHE', cast=int))
init_vn(vn)
cache = app.cache
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
@@ -119,7 +120,7 @@ def generate_sql_2():
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
def run_sql_2(id: str, sql: str, page_num=None, page_size=None):
"""
Run SQL
---
@@ -155,6 +156,9 @@ def run_sql_2(id: str, sql: str):
}
)
# count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery"
# df_count = vn.run_sql(count_sql)
# total_count = df_count[0]["total_count"] if df_count is not None else 0
df = vn.run_sql(sql=sql)
logger.info("")
app.cache.set(id=id, field="df", value=df)
@@ -162,6 +166,7 @@ def run_sql_2(id: str, sql: str):
logger.info("df ---------------{0} {1}".format(result,type(result)))
# result = util.utils.deal_result(data=result)
return jsonify(
{
"type": "success",
@@ -174,6 +179,8 @@ def run_sql_2(id: str, sql: str):
logger.error(f"run sql failed:{e}")
return jsonify({"type": "sql_error", "error": str(e)})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=8084, debug=False)

View File

@@ -1,5 +1,8 @@
from email.policy import default
from typing import List, Union
from typing import List, Union, Any, Optional
import time
import threading
from vanna.flask import Cache, MemoryCache
import dmPython
import orjson
import pandas as pd
@@ -65,37 +68,55 @@ class OpenAICompatibleLLM(VannaBase):
port: int = None,
**kwargs
):
conn = None
self.conn = None
try:
conn = dmPython.connect(user=user, password=password, server=host, port=port)
self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
except Exception as e:
raise Exception(f"Failed to connect to dameng database: {e}")
def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
if conn:
try:
# conn.ping(reconnect=True)
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()
logger.info(f"start to run_sql_damengsql")
if not is_connection_alive(conn=self.conn):
logger.info("connection is not alive, reconnecting..........")
reconnect()
try:
# conn.ping(reconnect=True)
cs = self.conn.cursor()
cs.execute(sql)
results = cs.fetchall()
# Create a pandas dataframe from the results
df = pd.DataFrame(
results, columns=[desc[0] for desc in cs.description]
)
# Create a pandas dataframe from the results
df = pd.DataFrame(
results, columns=[desc[0] for desc in cs.description]
)
return df
except Exception as e:
conn.rollback()
raise e
return df
except Exception as e:
self.conn.rollback()
logger.error(f"Failed to execute sql query: {e}")
raise e
return None
def reconnect():
try:
self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
except Exception as e:
raise Exception(f"reconnect failed: {e}")
def is_connection_alive(conn) -> bool:
if conn is None:
return False
try:
cursor = conn.cursor()
cursor.execute("SELECT 1 FROM DUAL")
cursor.close()
return True
except Exception as e:
return False
self.run_sql_is_set = True
self.run_sql = run_sql_damengsql
def user_message(self, message: str) -> any:
return {"role": "user", "content": message}
@@ -291,3 +312,87 @@ class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
def __init__(self, llm_config=None, vector_store_config=None):
CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
OpenAICompatibleLLM.__init__(self, config_file=llm_config)
class TTLCacheWrapper:
"为MemoryCache()添加带ttl的包装器防治内存泄漏"
def __init__(self, cache: Optional[Cache] = None, ttl: int = 3600):
self.cache = cache or MemoryCache()
self.ttl = ttl
self._expiry_times = {}
self._cleanup_thread = None
self._start_cleanup()
def _start_cleanup(self):
"""启动后台清理线程"""
def cleanup():
while True:
current_time = time.time()
expired_keys = []
# 找出所有过期的key
for key_info, expiry in self._expiry_times.items():
if expiry <= current_time:
expired_keys.append(key_info)
# 清理过期数据
for key_info in expired_keys:
id, field = key_info
if hasattr(self.cache, 'delete'):
self.cache.delete(id=id)
del self._expiry_times[key_info]
time.sleep(180) # 每3分钟清理一次
self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
self._cleanup_thread.start()
def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
"""设置缓存值支持TTL"""
# 使用提供的TTL或默认TTL
actual_ttl = ttl if ttl is not None else self.ttl
# 调用原始cache的set方法
self.cache.set(id=id, field=field, value=value)
# 记录过期时间
key_info = (id, field)
self._expiry_times[key_info] = time.time() + actual_ttl
def get(self, id: str, field: str) -> Any:
"""获取缓存值,自动处理过期"""
key_info = (id, field)
# 检查是否过期
if key_info in self._expiry_times:
if time.time() > self._expiry_times[key_info]:
# 已过期删除并返回None
if hasattr(self.cache, 'delete'):
self.cache.delete(id=id)
del self._expiry_times[key_info]
return None
# 返回缓存值
return self.cache.get(id=id, field=field)
def delete(self, id: str, field: str):
"""删除缓存值"""
key_info = (id, field)
if hasattr(self.cache, 'delete'):
self.cache.delete(id=id)
if key_info in self._expiry_times:
del self._expiry_times[key_info]
def exists(self, id: str, field: str) -> bool:
"""检查缓存是否存在且未过期"""
key_info = (id, field)
if key_info in self._expiry_times:
if time.time() > self._expiry_times[key_info]:
# 已过期清理并返回False
self.delete(id=id, field=field)
return False
return self.get(id=id, field=field) is not None
# 代理其他方法到原始cache
def __getattr__(self, name):
return getattr(self.cache, name)