import json
import logging
import threading
import time
import traceback
from datetime import datetime
from typing import Any, List, Optional, Union

import dmPython
import orjson
import pandas as pd
import requests
from decouple import config
from openai import OpenAI
from qdrant_client import QdrantClient
from vanna.base import VannaBase
from vanna.flask import Cache, MemoryCache
from vanna.qdrant import Qdrant_VectorStore

from template.template import get_base_template
from util import train_ddl
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer

logger = logging.getLogger(__name__)


class OpenAICompatibleLLM(VannaBase):
    """Vanna LLM backend for an OpenAI-compatible chat-completions endpoint.

    Also provides a Dameng (DM) database connector and project-specific
    SQL / chart generation built on prompt templates.
    """

    def __init__(self, client=None, config_file=None):
        """Create the LLM backend.

        Args:
            client: Optional pre-built ``OpenAI`` client. When given,
                ``api_base`` / ``api_key`` are not required.
            config_file: Configuration dict. Keys read here: ``temperature``,
                ``max_tokens``, ``api_base``, ``api_key``. ``submit_prompt``
                additionally reads ``model`` / ``engine``. ``None`` is treated
                as an empty dict.

        Raises:
            Exception: if the deprecated ``api_type`` / ``api_version`` keys are
                present, or if no client is given and ``api_base`` / ``api_key``
                is missing.
        """
        VannaBase.__init__(self, config=config_file)
        # Tolerate config_file=None (the CustomVanna default) instead of
        # failing with "argument of type 'NoneType' is not iterable".
        settings = config_file if config_file is not None else {}

        # Defaults; can be overridden through the config dict.
        self.temperature = settings.get("temperature", 0.6)
        self.max_tokens = settings.get("max_tokens", 10000)

        if "api_type" in settings:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )
        if "api_version" in settings:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if "api_base" not in settings:
            raise Exception("Please passing api_base")
        if "api_key" not in settings:
            raise Exception("Please passing api_key")
        self.client = OpenAI(api_key=settings["api_key"], base_url=settings["api_base"])

    def system_message(self, message: str) -> any:
        """Wrap ``message`` as a chat message with role ``system``."""
        return {"role": "system", "content": message}

    def connect_to_dameng(
        self,
        host: str = None,
        dbname: str = None,
        user: str = None,
        password: str = None,
        port: int = None,
        **kwargs
    ):
        """Connect to a Dameng database and register ``self.run_sql``.

        ``dbname`` and ``kwargs`` are accepted for signature compatibility but
        are not forwarded to ``dmPython.connect``. The registered ``run_sql``
        checks liveness before each query and reconnects with the same
        credentials if the connection has gone stale.

        Raises:
            Exception: if the initial connection fails.
        """

        def _connect():
            return dmPython.connect(user=user, password=password, server=host, port=port)

        def is_connection_alive(conn) -> bool:
            """Return True if ``conn`` can still execute a trivial query."""
            if conn is None:
                return False
            try:
                cursor = conn.cursor()
                try:
                    cursor.execute("SELECT 1 FROM DUAL")
                finally:
                    cursor.close()
                return True
            except Exception:
                return False

        def reconnect():
            """Replace ``self.conn`` with a fresh connection."""
            try:
                self.conn = _connect()
            except Exception as e:
                raise Exception(f"reconnect failed: {e}") from e

        def _rollback_quietly():
            """Best-effort rollback that never masks the original error."""
            if self.conn is None:
                return
            try:
                self.conn.rollback()
            except Exception as rollback_error:
                logger.warning(f"rollback failed: {rollback_error}")

        def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
            """Execute ``sql`` and return all rows as a DataFrame.

            Raises:
                Exception: whatever the driver (or a failed reconnect) raised.
            """
            logger.info("start to run_sql_damengsql")
            cursor = None
            try:
                if not is_connection_alive(self.conn):
                    logger.info("connection is not alive, reconnecting..........")
                    reconnect()
                cursor = self.conn.cursor()
                cursor.execute(sql)
                results = cursor.fetchall()
                # Column names come from the cursor description.
                return pd.DataFrame(
                    results, columns=[desc[0] for desc in cursor.description]
                )
            except Exception as e:
                _rollback_quietly()
                logger.error(f"Failed to execute sql query: {e}")
                raise
            finally:
                # Always release the cursor, on success and on failure.
                if cursor is not None:
                    try:
                        cursor.close()
                    except Exception as close_error:
                        logger.warning(f"cursor close failed: {close_error}")

        self.conn = None
        try:
            self.conn = _connect()
        except Exception as e:
            raise Exception(f"Failed to connect to dameng database: {e}") from e

        self.run_sql_is_set = True
        self.run_sql = run_sql_damengsql

    def user_message(self, message: str) -> any:
        """Wrap ``message`` as a chat message with role ``user``."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """Wrap ``message`` as a chat message with role ``assistant``."""
        return {"role": "assistant", "content": message}

    def _resolve_model(self, num_tokens: float, **kwargs) -> str:
        """Pick the model name for a request; first match wins.

        1. ``kwargs['model']``  2. ``kwargs['engine']``
        3. ``self.config['engine']``  4. ``self.config['model']``
        5. ``"kimi"`` for prompts above ~3500 tokens, otherwise ``"doubao"``.
        """
        if kwargs.get("model", None) is not None:
            return kwargs["model"]
        if kwargs.get("engine", None) is not None:
            return kwargs["engine"]
        if self.config is not None and "engine" in self.config:
            return self.config["engine"]
        if self.config is not None and "model" in self.config:
            return self.config["model"]
        return "kimi" if num_tokens > 3500 else "doubao"

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a list of chat messages and return the assistant's reply text.

        Args:
            prompt: List of ``{"role": ..., "content": ...}`` dicts.
            **kwargs: Optional ``model`` / ``engine`` overrides (see
                ``_resolve_model``).

        Raises:
            Exception: if ``prompt`` is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")
        logger.debug(f"prompt: {prompt}")

        # Rough token estimate: ~4 characters per token.
        num_tokens = sum(len(message["content"]) for message in prompt) / 4
        model = self._resolve_model(num_tokens, **kwargs)
        logger.info(f"Using model {model} for {num_tokens} tokens (approx)")

        # openai>=1 has no ``engine`` parameter: an engine/deployment name is
        # passed as ``model``. Never log self.config here -- it holds api_key.
        response = self.client.chat.completions.create(
            model=model,
            messages=prompt,
            max_tokens=self.max_tokens,
            stop=None,
            temperature=self.temperature,
        )
        return response.choices[0].message.content

    def generate_sql_2(self, question: str, cache=None, user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
        """Generate SQL (and optionally a chart spec) for ``question``.

        Builds the SQL prompt from the base template with up to two similar
        question/SQL examples, related DDL, training documentation and, when
        ``user_id`` and ``cache`` are given, the cached ``"data"`` history.
        If the LLM answer implies a chart type, a second prompt produces the
        chart config.

        Args:
            question: Natural-language question.
            cache: Optional cache exposing ``get(id=..., field=...)``.
            user_id: Key used to look up conversation history in ``cache``.
            allow_llm_to_see_data: Accepted for interface compatibility; unused.
            **kwargs: Forwarded to the retrieval helpers and ``submit_prompt``.

        Returns:
            ``{"resp": <parsed SQL answer JSON>}`` plus ``"chart"`` when a
            chart type was detected.

        Raises:
            Exception: any failure is logged with its traceback and re-raised.
        """
        try:
            logger.info("Start to generate_sql_2 in cus_vanna_srevice")
            question_sql_list = self.get_similar_question_sql(question, **kwargs)
            if question_sql_list:
                # Keep at most two retrieved examples to bound prompt size.
                question_sql_list = question_sql_list[:2]
            ddl_list = self.get_related_ddl(question, **kwargs)
            template = get_base_template()
            sql_temp = template['template']['sql']
            chart_temp = template['template']['chart']
            db_engine = config("DB_ENGINE", default='mysql')

            history = None
            if user_id and cache:
                history = cache.get(id=user_id, field="data")

            # -------- Generate SQL and chart type from the prompt
            sys_temp = sql_temp['system'].format(
                engine=db_engine,
                lang='中文',
                schema=ddl_list,
                documentation=[train_ddl.train_document],
                history=history,
                retrieved_examples_data=question_sql_list,
                data_training=question_sql_list,
            )
            user_temp = sql_temp['user'].format(
                question=question,
                current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            )
            logger.info(f"user_temp:{user_temp}")
            logger.info(f"sys_temp:{sys_temp}")
            llm_response = self.submit_prompt(
                [{'role': 'system', 'content': sys_temp},
                 {'role': 'user', 'content': user_temp}],
                **kwargs)
            if llm_response is None:
                raise ValueError("LLM returned no content for the SQL prompt")
            llm_response = llm_response.strip()
            logger.info(f"llm_response:{llm_response}")

            answer_json = extract_nested_json(llm_response)
            logger.info(f"result:{answer_json}")
            result = {"resp": orjson.loads(answer_json)}
            sql = check_and_get_sql(llm_response)
            logger.info(f"sql:{sql}")

            # --------------- Generate chart
            chart_type = get_chart_type_from_sql_answer(llm_response)
            logger.info(f"chart type:{chart_type}")
            if chart_type:
                sys_char_temp = chart_temp['system'].format(
                    engine=db_engine, lang='中文', sql=sql, chart_type=chart_type)
                user_char_temp = chart_temp['user'].format(
                    sql=sql, chart_type=chart_type, question=question)
                llm_response2 = self.submit_prompt(
                    [{'role': 'system', 'content': sys_char_temp},
                     {'role': 'user', 'content': user_char_temp}],
                    **kwargs)
                logger.info(f"chart llm_response:{llm_response2}")
                result['chart'] = orjson.loads(extract_nested_json(llm_response2))
                logger.info(f"chart_response:{result}")

            logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
            return result
        except Exception:
            # logger.exception records the traceback through logging instead
            # of printing it to stderr.
            logger.exception("cus_vanna_srevice failed-------------------: ")
            raise

    def run_sql_2(self, sql):
        """Run ``sql`` via the registered ``run_sql``, logging failures before re-raising."""
        try:
            return self.run_sql(sql)
        except Exception:
            logger.exception("run_sql failed {0}".format(sql))
            raise

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """Merge a follow-up question with the previous one via the LLM.

        Returns ``new_question`` unchanged when there is no previous question;
        otherwise returns the LLM's merged question.
        """
        logger.info(f"generate_rewritten_question---------------{new_question}")
        if last_question is None:
            return new_question
        logger.debug("last question {0}".format(last_question))
        logger.debug("new question {0}".format(new_question))
        prompt = [
            self.system_message(
                "你的目标是将一系列相关问题合并成一个单一的问题。"
                "合并准则一、如果第二个问题与第一个问题无关且本身是完整独立的,则直接返回第二个问题。"
                "合并准则二、如果第二个问题与第一个问题相关,且要基于第一个问题的前提,请合并两个问题为一个问题,只需返回合并后的新问题,不要添加任何额外解释。"
                "合并准则三、理论上,合并后的问题应该能够通过单个SQL语句来回答"),
            self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
        ]
        return self.submit_prompt(prompt=prompt, **kwargs)


class CustomQdrant_VectorStore(Qdrant_VectorStore):
    """Qdrant vector store that embeds text through an external HTTP embedding API.

    Endpoint, model name, API key and request timeout come from the
    environment via ``decouple.config``: ``EMBEDDING_MODEL_NAME``,
    ``EMBEDDING_MODEL_BASE_URL``, ``EMBEDDING_MODEL_API_KEY``,
    ``EMBEDDING_TIMEOUT`` (seconds, default 60).
    """

    def __init__(
            self, config_file=None
    ):
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        # Without a timeout a stalled embedding server blocks the caller forever.
        self.embedding_timeout = config('EMBEDDING_TIMEOUT', default=60, cast=float)
        # Fresh dict per instance instead of a shared mutable default.
        super().__init__(config_file if config_file is not None else {})

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Return the embedding vector for ``data``.

        Posts ``{"model": ..., "sentences": [data], **kwargs}`` and reads the
        first vector from the response's ``"embeddings"`` list.

        Raises:
            RuntimeError: on a non-200 response.
            requests.Timeout: if the server does not answer within
                ``self.embedding_timeout`` seconds.
        """

        def _get_error_string(response: requests.Response) -> str:
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "sentences": [data],
        }
        request_body.update(kwargs)

        response = requests.post(
            url=f"{self.embedding_api_base}",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
            timeout=self.embedding_timeout,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )

        result = response.json()
        embeddings = result['embeddings']
        return embeddings[0]


class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Vanna instance combining the custom Qdrant store and the OpenAI-compatible LLM."""

    def __init__(self, llm_config=None, vector_store_config=None):
        CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
        OpenAICompatibleLLM.__init__(self, config_file=llm_config)


class TTLCacheWrapper:
    """Add per-entry time-to-live (TTL) to a vanna ``Cache`` to prevent memory leaks.

    Entries are keyed by ``(id, field)``. Expired entries are evicted lazily on
    read and by a background daemon thread every ``cleanup_interval`` seconds.
    The bookkeeping dicts are guarded by a re-entrant lock because the sweeper
    thread mutates them concurrently with callers. Attributes not defined here
    are proxied to the wrapped cache.
    """

    def __init__(self, cache: Optional[Cache] = None, ttl: int = 3600, cleanup_interval: float = 180):
        """
        Args:
            cache: Cache to wrap; a new ``MemoryCache`` when falsy.
            ttl: Default lifetime of an entry in seconds.
            cleanup_interval: Seconds between background sweeps (default 3 min).
        """
        self.cache = cache or MemoryCache()
        self.ttl = ttl
        self.cleanup_interval = cleanup_interval
        self._lock = threading.RLock()
        self._expiry_times = {}    # (id, field) -> absolute expiry timestamp
        self._created_times = {}   # (id, field) -> timestamp of the last set()
        self._fields_by_id = {}    # id -> set of tracked fields
        self._cleanup_thread = None
        self._start_cleanup()

    def _evict_locked(self, id: str, field: str):
        """Drop ``(id, field)`` from the wrapped cache and the TTL records.

        The caller must hold ``self._lock``. The ``Cache`` API can only delete
        a whole id, so the id is deleted only when no other tracked field
        remains under it. Otherwise just this field's value is released.
        """
        key_info = (id, field)
        self._expiry_times.pop(key_info, None)
        self._created_times.pop(key_info, None)
        siblings = self._fields_by_id.get(id)
        if siblings is not None:
            siblings.discard(field)
            if not siblings:
                del self._fields_by_id[id]
        if not siblings and hasattr(self.cache, 'delete'):
            self.cache.delete(id=id)
        else:
            self.cache.set(id=id, field=field, value=None)

    def _sweep_expired(self) -> int:
        """Evict every entry whose expiry is <= now; return how many were evicted."""
        current_time = time.time()
        with self._lock:
            expired_keys = [key_info for key_info, expiry in self._expiry_times.items()
                            if expiry <= current_time]
            for id, field in expired_keys:
                self._evict_locked(id, field)
        return len(expired_keys)

    def _start_cleanup(self):
        """Start the background daemon thread that sweeps expired entries."""

        def cleanup():
            while True:
                try:
                    self._sweep_expired()
                except Exception:
                    # One failed sweep must not end the thread, otherwise
                    # expired data would accumulate for the rest of the process.
                    logger.exception("TTLCacheWrapper cleanup failed")
                time.sleep(self.cleanup_interval)

        self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
        self._cleanup_thread.start()

    def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
        """Store ``value`` under ``(id, field)``; it expires after ``ttl`` (default ``self.ttl``) seconds."""
        actual_ttl = ttl if ttl is not None else self.ttl
        now = time.time()
        key_info = (id, field)
        with self._lock:
            self.cache.set(id=id, field=field, value=value)
            self._expiry_times[key_info] = now + actual_ttl
            self._created_times[key_info] = now
            self._fields_by_id.setdefault(id, set()).add(field)

    def get(self, id: str, field: str) -> Any:
        """Return the cached value, or None if it is missing or has expired (expired entries are evicted)."""
        key_info = (id, field)
        with self._lock:
            expiry = self._expiry_times.get(key_info)
            if expiry is not None and time.time() > expiry:
                self._evict_locked(id, field)
                return None
            return self.cache.get(id=id, field=field)

    def delete(self, id: str, field: str):
        """Remove ``(id, field)``; other tracked fields under the same id are kept."""
        with self._lock:
            self._evict_locked(id, field)

    def exists(self, id: str, field: str) -> bool:
        """Return True if ``(id, field)`` holds a non-None, unexpired value."""
        return self.get(id=id, field=field) is not None

    def items(self):
        """Return a list of dicts describing every unexpired tracked entry."""
        current_time = time.time()
        with self._lock:
            snapshot = list(self._expiry_times.items())
        items = []
        for (id, field), expiry in snapshot:
            if current_time <= expiry:
                value = self.get(id=id, field=field)
                if value is not None:
                    items.append({
                        'id': id,
                        'field': field,
                        'value': value,
                        'expires_at': expiry,
                        'time_left': expiry - current_time
                    })
        return items

    def get_latest_by_id(self, id: str, limit: int = 1, field_filter: str = None):
        """Return the most recently set, unexpired entries under ``id``.

        Args:
            id: Id to query.
            limit: Maximum number of entries to return (default 1).
            field_filter: If given, only consider this field.

        Returns:
            List of entry dicts, newest ``set()`` first.
        """
        current_time = time.time()
        with self._lock:
            snapshot = [
                (field, expiry, self._created_times.get((cache_id, field), expiry - self.ttl))
                for (cache_id, field), expiry in self._expiry_times.items()
                if cache_id == id
            ]
        matched_items = []
        for field, expiry, created in snapshot:
            if current_time > expiry:
                continue
            if field_filter and field != field_filter:
                continue
            value = self.get(id=id, field=field)
            if value is not None:
                matched_items.append({
                    'id': id,
                    'field': field,
                    'value': value,
                    'expires_at': expiry,
                    'created_time': created,
                    'time_left': expiry - current_time
                })
        # Sort by actual creation time; expiry order differs when per-entry TTLs vary.
        matched_items.sort(key=lambda x: x['created_time'], reverse=True)
        return matched_items[:limit]

    # Proxy any other attribute to the wrapped cache.
    def __getattr__(self, name):
        # Only reached when normal lookup fails. Guard 'cache' itself so a
        # half-initialized instance (copy/unpickle) cannot recurse forever.
        if name == 'cache':
            raise AttributeError(name)
        return getattr(self.cache, name)