Add a TTL to the memory cache and start a periodic cleanup task
@@ -3,7 +3,7 @@ import logging
 import util.utils
 from logging_config import LOGGING_CONFIG
-from service.cus_vanna_srevice import CustomVanna, QdrantClient
+from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
 from decouple import config
 import flask
 from util import load_ddl_doc
@@ -67,6 +67,7 @@ def init_vn(vn):
 from vanna.flask import VannaFlaskApp
 vn = create_vana()
 app = VannaFlaskApp(vn,chart=False)
+app.cache = TTLCacheWrapper(app.cache, ttl = config('TTL_CACHE', cast=int))
 init_vn(vn)
 cache = app.cache
 @app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
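The wrapper's lifetime is driven by a TTL_CACHE setting read through python-decouple. A minimal sketch of the assumed configuration; the file location, the 3600-second value, and the default fallback are illustrations, not part of the commit:

# .env (assumed location)
# TTL_CACHE=3600

from decouple import config

# Read the TTL in seconds; the commit itself omits the default fallback.
ttl_seconds = config('TTL_CACHE', cast=int, default=3600)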
@@ -119,7 +120,7 @@ def generate_sql_2():
 
 @app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
 @app.requires_cache(["sql"])
-def run_sql_2(id: str, sql: str):
+def run_sql_2(id: str, sql: str, page_num=None, page_size=None):
     """
     Run SQL
     ---
@@ -155,6 +156,9 @@ def run_sql_2(id: str, sql: str):
             }
         )
 
+        # count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery"
+        # df_count = vn.run_sql(count_sql)
+        # total_count = df_count[0]["total_count"] if df_count is not None else 0
         df = vn.run_sql(sql=sql)
         logger.info("")
         app.cache.set(id=id, field="df", value=df)
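The new page_num and page_size parameters are accepted here but not yet applied to the result (the count query remains commented out). A hedged sketch of how they might drive server-side pagination with pandas slicing; the helper name and the 1-based page convention are assumptions, not part of the commit:

import pandas as pd

def paginate_df(df: pd.DataFrame, page_num=None, page_size=None) -> pd.DataFrame:
    # No pagination requested: return the full frame.
    if not page_num or not page_size:
        return df
    start = (int(page_num) - 1) * int(page_size)
    # iloc slicing stays in bounds even when start runs past the end.
    return df.iloc[start:start + int(page_size)]

Inside run_sql_2 this could be applied as df = paginate_df(df, page_num, page_size) before the frame is cached and serialized.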
@@ -162,6 +166,7 @@ def run_sql_2(id: str, sql: str):
         logger.info("df ---------------{0} {1}".format(result,type(result)))
         # result = util.utils.deal_result(data=result)
 
+
         return jsonify(
             {
                 "type": "success",
@@ -174,6 +179,8 @@ def run_sql_2(id: str, sql: str):
         logger.error(f"run sql failed:{e}")
         return jsonify({"type": "sql_error", "error": str(e)})
 
+
+
 if __name__ == '__main__':
 
     app.run(host='0.0.0.0', port=8084, debug=False)
@@ -1,5 +1,8 @@
 from email.policy import default
-from typing import List, Union
+from typing import List, Union, Any, Optional
+import time
+import threading
+from vanna.flask import Cache, MemoryCache
 import dmPython
 import orjson
 import pandas as pd
@@ -65,17 +68,20 @@ class OpenAICompatibleLLM(VannaBase):
         port: int = None,
         **kwargs
     ):
-        conn = None
+        self.conn = None
         try:
-            conn = dmPython.connect(user=user, password=password, server=host, port=port)
+            self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
         except Exception as e:
             raise Exception(f"Failed to connect to dameng database: {e}")
 
         def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
-            if conn:
+            logger.info(f"start to run_sql_damengsql")
+            if not is_connection_alive(conn=self.conn):
+                logger.info("connection is not alive, reconnecting..........")
+                reconnect()
             try:
                 # conn.ping(reconnect=True)
-                cs = conn.cursor()
+                cs = self.conn.cursor()
                 cs.execute(sql)
                 results = cs.fetchall()
@@ -85,17 +91,32 @@ class OpenAICompatibleLLM(VannaBase):
                 )
 
                 return df
 
 
             except Exception as e:
-                conn.rollback()
+                self.conn.rollback()
+                logger.error(f"Failed to execute sql query: {e}")
                 raise e
             return None
 
+        def reconnect():
+            try:
+                self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
+            except Exception as e:
+                raise Exception(f"reconnect failed: {e}")
+        def is_connection_alive(conn) -> bool:
+            if conn is None:
+                return False
+            try:
+                cursor = conn.cursor()
+                cursor.execute("SELECT 1 FROM DUAL")
+                cursor.close()
+                return True
+            except Exception as e:
+                return False
 
         self.run_sql_is_set = True
         self.run_sql = run_sql_damengsql
 
 
     def user_message(self, message: str) -> any:
         return {"role": "user", "content": message}
 
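The liveness probe runs before each query, but a connection that drops mid-statement still surfaces as an exception to the caller. A hedged sketch of a retry-once pattern that could sit alongside the helpers above, in the same scope; the wrapper name is hypothetical and not part of this commit:

def run_sql_with_retry(sql: str):
    # First attempt; if the statement fails, re-open the dmPython
    # connection once and try again (hypothetical helper, not in the commit).
    try:
        return run_sql_damengsql(sql)
    except Exception:
        reconnect()
        return run_sql_damengsql(sql)

Retrying on every Exception would also repeat genuinely bad SQL, so a production variant would likely narrow the except clause to dmPython's connection-level error types.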
@@ -291,3 +312,87 @@ class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
     def __init__(self, llm_config=None, vector_store_config=None):
         CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
         OpenAICompatibleLLM.__init__(self, config_file=llm_config)
 
+
+class TTLCacheWrapper:
+    "Wrapper that adds a TTL to MemoryCache() to prevent memory leaks"
+    def __init__(self, cache: Optional[Cache] = None, ttl: int = 3600):
+        self.cache = cache or MemoryCache()
+        self.ttl = ttl
+        self._expiry_times = {}
+        self._cleanup_thread = None
+        self._start_cleanup()
+
+    def _start_cleanup(self):
+        """Start the background cleanup thread"""
+
+        def cleanup():
+            while True:
+                current_time = time.time()
+                expired_keys = []
+
+                # Collect every key whose TTL has passed
+                for key_info, expiry in self._expiry_times.items():
+                    if expiry <= current_time:
+                        expired_keys.append(key_info)
+
+                # Drop the expired entries
+                for key_info in expired_keys:
+                    id, field = key_info
+                    if hasattr(self.cache, 'delete'):
+                        self.cache.delete(id=id)
+                    del self._expiry_times[key_info]
+
+                time.sleep(180)  # run the sweep every 3 minutes
+
+        self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
+        self._cleanup_thread.start()
+
+    def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
+        """Set a cache value, with TTL support"""
+        # Use the caller-supplied TTL or fall back to the default
+        actual_ttl = ttl if ttl is not None else self.ttl
+
+        # Delegate to the wrapped cache's set method
+        self.cache.set(id=id, field=field, value=value)
+
+        # Record the expiry time
+        key_info = (id, field)
+        self._expiry_times[key_info] = time.time() + actual_ttl
+
+    def get(self, id: str, field: str) -> Any:
+        """Get a cache value, handling expiry automatically"""
+        key_info = (id, field)
+
+        # Check whether the entry has expired
+        if key_info in self._expiry_times:
+            if time.time() > self._expiry_times[key_info]:
+                # Expired: delete it and return None
+                if hasattr(self.cache, 'delete'):
+                    self.cache.delete(id=id)
+                del self._expiry_times[key_info]
+                return None
+
+        # Return the cached value
+        return self.cache.get(id=id, field=field)
+
+    def delete(self, id: str, field: str):
+        """Delete a cache value"""
+        key_info = (id, field)
+        if hasattr(self.cache, 'delete'):
+            self.cache.delete(id=id)
+        if key_info in self._expiry_times:
+            del self._expiry_times[key_info]
+
+    def exists(self, id: str, field: str) -> bool:
+        """Check whether a cache entry exists and has not expired"""
+        key_info = (id, field)
+        if key_info in self._expiry_times:
+            if time.time() > self._expiry_times[key_info]:
+                # Expired: clean up and return False
+                self.delete(id=id, field=field)
+                return False
+        return self.get(id=id, field=field) is not None
+
+    # Proxy any other attribute to the wrapped cache
+    def __getattr__(self, name):
+        return getattr(self.cache, name)
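To make the wrapper's behaviour concrete, here is a minimal usage sketch. It assumes TTLCacheWrapper is imported from service.cus_vanna_srevice, as the app file does; the 2-second TTL is only for demonstration:

import time
from vanna.flask import MemoryCache
from service.cus_vanna_srevice import TTLCacheWrapper

cache = TTLCacheWrapper(MemoryCache(), ttl=2)        # entries live for roughly 2 seconds
cache.set(id="req-1", field="sql", value="SELECT 1")
print(cache.get(id="req-1", field="sql"))            # -> "SELECT 1"
time.sleep(3)
print(cache.get(id="req-1", field="sql"))            # -> None, the getter treats the entry as expired

One caveat on the design: the sweep thread iterates _expiry_times while request handlers insert into it without a lock, which can raise "dictionary changed size during iteration"; a hardened variant would guard the dict with threading.Lock or iterate over list(self._expiry_times.items()).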