diff --git a/main_service.py b/main_service.py index 7e014bd..4414ae3 100644 --- a/main_service.py +++ b/main_service.py @@ -3,7 +3,7 @@ import logging import util.utils from logging_config import LOGGING_CONFIG -from service.cus_vanna_srevice import CustomVanna, QdrantClient +from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper from decouple import config import flask from util import load_ddl_doc @@ -67,6 +67,7 @@ def init_vn(vn): from vanna.flask import VannaFlaskApp vn = create_vana() app = VannaFlaskApp(vn,chart=False) +app.cache = TTLCacheWrapper(app.cache, ttl = config('TTL_CACHE', cast=int)) init_vn(vn) cache = app.cache @app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"]) @@ -119,7 +120,7 @@ def generate_sql_2(): @app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"]) @app.requires_cache(["sql"]) -def run_sql_2(id: str, sql: str): +def run_sql_2(id: str, sql: str, page_num=None, page_size=None): """ Run SQL --- @@ -155,6 +156,9 @@ def run_sql_2(id: str, sql: str): } ) + # count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery" + # df_count = vn.run_sql(count_sql) + # total_count = df_count[0]["total_count"] if df_count is not None else 0 df = vn.run_sql(sql=sql) logger.info("") app.cache.set(id=id, field="df", value=df) @@ -162,6 +166,7 @@ def run_sql_2(id: str, sql: str): logger.info("df ---------------{0} {1}".format(result,type(result))) # result = util.utils.deal_result(data=result) + return jsonify( { "type": "success", @@ -174,6 +179,8 @@ def run_sql_2(id: str, sql: str): logger.error(f"run sql failed:{e}") return jsonify({"type": "sql_error", "error": str(e)}) + + if __name__ == '__main__': app.run(host='0.0.0.0', port=8084, debug=False) diff --git a/service/cus_vanna_srevice.py b/service/cus_vanna_srevice.py index f4da618..a99963e 100644 --- a/service/cus_vanna_srevice.py +++ b/service/cus_vanna_srevice.py @@ -1,5 +1,8 @@ from email.policy import default -from typing import List, Union +from typing import List, Union, Any, Optional +import time +import threading +from vanna.flask import Cache, MemoryCache import dmPython import orjson import pandas as pd @@ -65,37 +68,55 @@ class OpenAICompatibleLLM(VannaBase): port: int = None, **kwargs ): - conn = None + self.conn = None try: - conn = dmPython.connect(user=user, password=password, server=host, port=port) + self.conn = dmPython.connect(user=user, password=password, server=host, port=port) except Exception as e: raise Exception(f"Failed to connect to dameng database: {e}") def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]: - if conn: - try: - # conn.ping(reconnect=True) - cs = conn.cursor() - cs.execute(sql) - results = cs.fetchall() + logger.info(f"start to run_sql_damengsql") + if not is_connection_alive(conn=self.conn): + logger.info("connection is not alive, reconnecting..........") + reconnect() + try: + # conn.ping(reconnect=True) + cs = self.conn.cursor() + cs.execute(sql) + results = cs.fetchall() - # Create a pandas dataframe from the results - df = pd.DataFrame( - results, columns=[desc[0] for desc in cs.description] - ) + # Create a pandas dataframe from the results + df = pd.DataFrame( + results, columns=[desc[0] for desc in cs.description] + ) - return df - - - - except Exception as e: - conn.rollback() - raise e + return df + except Exception as e: + self.conn.rollback() + logger.error(f"Failed to execute sql query: {e}") + raise e return None + def reconnect(): + try: + self.conn = dmPython.connect(user=user, password=password, server=host, port=port) + except Exception as e: + raise Exception(f"reconnect failed: {e}") + def is_connection_alive(conn) -> bool: + if conn is None: + return False + try: + cursor = conn.cursor() + cursor.execute("SELECT 1 FROM DUAL") + cursor.close() + return True + except Exception as e: + return False + self.run_sql_is_set = True self.run_sql = run_sql_damengsql + def user_message(self, message: str) -> any: return {"role": "user", "content": message} @@ -290,4 +311,88 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore): class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM): def __init__(self, llm_config=None, vector_store_config=None): CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config) - OpenAICompatibleLLM.__init__(self, config_file=llm_config) \ No newline at end of file + OpenAICompatibleLLM.__init__(self, config_file=llm_config) + +class TTLCacheWrapper: + "为MemoryCache()添加带ttl的包装器,防治内存泄漏" + def __init__(self, cache: Optional[Cache] = None, ttl: int = 3600): + self.cache = cache or MemoryCache() + self.ttl = ttl + self._expiry_times = {} + self._cleanup_thread = None + self._start_cleanup() + + def _start_cleanup(self): + """启动后台清理线程""" + + def cleanup(): + while True: + current_time = time.time() + expired_keys = [] + + # 找出所有过期的key + for key_info, expiry in self._expiry_times.items(): + if expiry <= current_time: + expired_keys.append(key_info) + + # 清理过期数据 + for key_info in expired_keys: + id, field = key_info + if hasattr(self.cache, 'delete'): + self.cache.delete(id=id) + del self._expiry_times[key_info] + + time.sleep(180) # 每3分钟清理一次 + + self._cleanup_thread = threading.Thread(target=cleanup, daemon=True) + self._cleanup_thread.start() + + def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None): + """设置缓存值,支持TTL""" + # 使用提供的TTL或默认TTL + actual_ttl = ttl if ttl is not None else self.ttl + + # 调用原始cache的set方法 + self.cache.set(id=id, field=field, value=value) + + # 记录过期时间 + key_info = (id, field) + self._expiry_times[key_info] = time.time() + actual_ttl + + def get(self, id: str, field: str) -> Any: + """获取缓存值,自动处理过期""" + key_info = (id, field) + + # 检查是否过期 + if key_info in self._expiry_times: + if time.time() > self._expiry_times[key_info]: + # 已过期,删除并返回None + if hasattr(self.cache, 'delete'): + self.cache.delete(id=id) + del self._expiry_times[key_info] + return None + + # 返回缓存值 + return self.cache.get(id=id, field=field) + + def delete(self, id: str, field: str): + """删除缓存值""" + key_info = (id, field) + if hasattr(self.cache, 'delete'): + self.cache.delete(id=id) + if key_info in self._expiry_times: + del self._expiry_times[key_info] + + def exists(self, id: str, field: str) -> bool: + """检查缓存是否存在且未过期""" + key_info = (id, field) + if key_info in self._expiry_times: + if time.time() > self._expiry_times[key_info]: + # 已过期,清理并返回False + self.delete(id=id, field=field) + return False + return self.get(id=id, field=field) is not None + + # 代理其他方法到原始cache + def __getattr__(self, name): + return getattr(self.cache, name) \ No newline at end of file