为memory添加ttl,启动定时清理
This commit is contained in:
		| @@ -3,7 +3,7 @@ import logging | ||||
|  | ||||
| import util.utils | ||||
| from logging_config import LOGGING_CONFIG | ||||
| from service.cus_vanna_srevice import CustomVanna, QdrantClient | ||||
| from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper | ||||
| from decouple import config | ||||
| import flask | ||||
| from util import load_ddl_doc | ||||
| @@ -67,6 +67,7 @@ def init_vn(vn): | ||||
| from vanna.flask import VannaFlaskApp | ||||
| vn = create_vana() | ||||
| app = VannaFlaskApp(vn,chart=False) | ||||
| app.cache = TTLCacheWrapper(app.cache, ttl = config('TTL_CACHE', cast=int)) | ||||
| init_vn(vn) | ||||
| cache = app.cache | ||||
| @app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"]) | ||||
| @@ -119,7 +120,7 @@ def generate_sql_2(): | ||||
|  | ||||
| @app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"]) | ||||
| @app.requires_cache(["sql"]) | ||||
| def run_sql_2(id: str, sql: str): | ||||
| def run_sql_2(id: str, sql: str, page_num=None, page_size=None): | ||||
|     """ | ||||
|     Run SQL | ||||
|     --- | ||||
| @@ -155,6 +156,9 @@ def run_sql_2(id: str, sql: str): | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|         # count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery" | ||||
|         # df_count = vn.run_sql(count_sql) | ||||
|         # total_count = df_count[0]["total_count"] if df_count is not None else 0 | ||||
|         df = vn.run_sql(sql=sql) | ||||
|         logger.info("") | ||||
|         app.cache.set(id=id, field="df", value=df) | ||||
| @@ -162,6 +166,7 @@ def run_sql_2(id: str, sql: str): | ||||
|         logger.info("df ---------------{0}   {1}".format(result,type(result))) | ||||
|         # result = util.utils.deal_result(data=result) | ||||
|  | ||||
|  | ||||
|         return jsonify( | ||||
|             { | ||||
|                 "type": "success", | ||||
| @@ -174,6 +179,8 @@ def run_sql_2(id: str, sql: str): | ||||
|         logger.error(f"run sql failed:{e}") | ||||
|         return jsonify({"type": "sql_error", "error": str(e)}) | ||||
|  | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|     app.run(host='0.0.0.0', port=8084, debug=False) | ||||
|   | ||||
| @@ -1,5 +1,8 @@ | ||||
| from email.policy import default | ||||
| from typing import List, Union | ||||
| from typing import List, Union, Any, Optional | ||||
| import time | ||||
| import threading | ||||
| from vanna.flask import Cache, MemoryCache | ||||
| import  dmPython | ||||
| import orjson | ||||
| import pandas as pd | ||||
| @@ -65,37 +68,55 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|             port: int = None, | ||||
|             **kwargs | ||||
|     ): | ||||
|         conn = None | ||||
|         self.conn = None | ||||
|         try: | ||||
|             conn = dmPython.connect(user=user, password=password, server=host, port=port) | ||||
|             self.conn = dmPython.connect(user=user, password=password, server=host, port=port) | ||||
|         except Exception as e: | ||||
|             raise Exception(f"Failed to connect to dameng database: {e}") | ||||
|  | ||||
|         def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]: | ||||
|             if conn: | ||||
|                 try: | ||||
|                     # conn.ping(reconnect=True) | ||||
|                     cs = conn.cursor() | ||||
|                     cs.execute(sql) | ||||
|                     results = cs.fetchall() | ||||
|             logger.info(f"start to run_sql_damengsql") | ||||
|             if not is_connection_alive(conn=self.conn): | ||||
|                 logger.info("connection is not alive, reconnecting..........") | ||||
|                 reconnect() | ||||
|             try: | ||||
|                 # conn.ping(reconnect=True) | ||||
|                 cs = self.conn.cursor() | ||||
|                 cs.execute(sql) | ||||
|                 results = cs.fetchall() | ||||
|  | ||||
|                     # Create a pandas dataframe from the results | ||||
|                     df = pd.DataFrame( | ||||
|                         results, columns=[desc[0] for desc in cs.description] | ||||
|                     ) | ||||
|                 # Create a pandas dataframe from the results | ||||
|                 df = pd.DataFrame( | ||||
|                     results, columns=[desc[0] for desc in cs.description] | ||||
|                 ) | ||||
|  | ||||
|                     return df | ||||
|  | ||||
|  | ||||
|  | ||||
|                 except Exception as e: | ||||
|                     conn.rollback() | ||||
|                     raise e | ||||
|                 return df | ||||
|             except Exception as e: | ||||
|                 self.conn.rollback() | ||||
|                 logger.error(f"Failed to execute sql query: {e}") | ||||
|                 raise e | ||||
|             return None | ||||
|  | ||||
|         def reconnect(): | ||||
|             try: | ||||
|                 self.conn = dmPython.connect(user=user, password=password, server=host, port=port) | ||||
|             except Exception as e: | ||||
|                 raise Exception(f"reconnect failed: {e}") | ||||
|         def is_connection_alive(conn) -> bool: | ||||
|             if conn is None: | ||||
|                 return False | ||||
|             try: | ||||
|                 cursor = conn.cursor() | ||||
|                 cursor.execute("SELECT 1 FROM DUAL") | ||||
|                 cursor.close() | ||||
|                 return True | ||||
|             except Exception as e: | ||||
|                 return False | ||||
|  | ||||
|         self.run_sql_is_set = True | ||||
|         self.run_sql = run_sql_damengsql | ||||
|  | ||||
|  | ||||
|     def user_message(self, message: str) -> any: | ||||
|         return {"role": "user", "content": message} | ||||
|  | ||||
| @@ -291,3 +312,87 @@ class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM): | ||||
|     def __init__(self, llm_config=None, vector_store_config=None): | ||||
|         CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config) | ||||
|         OpenAICompatibleLLM.__init__(self, config_file=llm_config) | ||||
|  | ||||
| class TTLCacheWrapper: | ||||
|     "为MemoryCache()添加带ttl的包装器,防治内存泄漏" | ||||
|     def __init__(self, cache: Optional[Cache] = None, ttl: int = 3600): | ||||
|         self.cache = cache or MemoryCache() | ||||
|         self.ttl = ttl | ||||
|         self._expiry_times = {} | ||||
|         self._cleanup_thread = None | ||||
|         self._start_cleanup() | ||||
|  | ||||
|     def _start_cleanup(self): | ||||
|         """启动后台清理线程""" | ||||
|  | ||||
|         def cleanup(): | ||||
|             while True: | ||||
|                 current_time = time.time() | ||||
|                 expired_keys = [] | ||||
|  | ||||
|                 # 找出所有过期的key | ||||
|                 for key_info, expiry in self._expiry_times.items(): | ||||
|                     if expiry <= current_time: | ||||
|                         expired_keys.append(key_info) | ||||
|  | ||||
|                 # 清理过期数据 | ||||
|                 for key_info in expired_keys: | ||||
|                     id, field = key_info | ||||
|                     if hasattr(self.cache, 'delete'): | ||||
|                         self.cache.delete(id=id) | ||||
|                     del self._expiry_times[key_info] | ||||
|  | ||||
|                 time.sleep(180)  # 每3分钟清理一次 | ||||
|  | ||||
|         self._cleanup_thread = threading.Thread(target=cleanup, daemon=True) | ||||
|         self._cleanup_thread.start() | ||||
|  | ||||
|     def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None): | ||||
|         """设置缓存值,支持TTL""" | ||||
|         # 使用提供的TTL或默认TTL | ||||
|         actual_ttl = ttl if ttl is not None else self.ttl | ||||
|  | ||||
|         # 调用原始cache的set方法 | ||||
|         self.cache.set(id=id, field=field, value=value) | ||||
|  | ||||
|         # 记录过期时间 | ||||
|         key_info = (id, field) | ||||
|         self._expiry_times[key_info] = time.time() + actual_ttl | ||||
|  | ||||
|     def get(self, id: str, field: str) -> Any: | ||||
|         """获取缓存值,自动处理过期""" | ||||
|         key_info = (id, field) | ||||
|  | ||||
|         # 检查是否过期 | ||||
|         if key_info in self._expiry_times: | ||||
|             if time.time() > self._expiry_times[key_info]: | ||||
|                 # 已过期,删除并返回None | ||||
|                 if hasattr(self.cache, 'delete'): | ||||
|                     self.cache.delete(id=id) | ||||
|                 del self._expiry_times[key_info] | ||||
|                 return None | ||||
|  | ||||
|         # 返回缓存值 | ||||
|         return self.cache.get(id=id, field=field) | ||||
|  | ||||
|     def delete(self, id: str, field: str): | ||||
|         """删除缓存值""" | ||||
|         key_info = (id, field) | ||||
|         if hasattr(self.cache, 'delete'): | ||||
|             self.cache.delete(id=id) | ||||
|         if key_info in self._expiry_times: | ||||
|             del self._expiry_times[key_info] | ||||
|  | ||||
|     def exists(self, id: str, field: str) -> bool: | ||||
|         """检查缓存是否存在且未过期""" | ||||
|         key_info = (id, field) | ||||
|         if key_info in self._expiry_times: | ||||
|             if time.time() > self._expiry_times[key_info]: | ||||
|                 # 已过期,清理并返回False | ||||
|                 self.delete(id=id, field=field) | ||||
|                 return False | ||||
|         return self.get(id=id, field=field) is not None | ||||
|  | ||||
|     # 代理其他方法到原始cache | ||||
|     def __getattr__(self, name): | ||||
|         return getattr(self.cache, name) | ||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128