from typing import List, Union, Any, Optional
import time
import threading
from vanna.flask import Cache, MemoryCache
import dmPython
import orjson
import pandas as pd
from vanna.base import VannaBase
from vanna.flask import MemoryCache
from vanna.qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from openai import OpenAI
import requests
from decouple import config
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
import json
from template.template import get_base_template
from datetime import datetime
import logging
from util import train_ddl
import traceback

logger = logging.getLogger(__name__)


class OpenAICompatibleLLM(VannaBase):
    """Vanna LLM backend for any OpenAI-compatible chat-completions API.

    Also provides ``connect_to_dameng`` to run generated SQL against a
    Dameng (DM) database through ``dmPython``.
    """

    def __init__(self, client=None, config_file=None):
        """Configure sampling parameters and the OpenAI client.

        Args:
            client: A pre-built ``openai.OpenAI`` client. When given, the
                ``api_base``/``api_key`` config keys are not required.
            config_file: Settings dict. Keys read here: ``temperature``,
                ``max_tokens``, ``api_base``, ``api_key`` (``model``/``engine``
                are read later by ``submit_prompt``). ``None`` is treated as
                an empty dict.

        Raises:
            Exception: if the deprecated ``api_type``/``api_version`` keys are
                present, or if no client is given and ``api_base`` or
                ``api_key`` is missing.
        """
        VannaBase.__init__(self, config=config_file)
        # Treat a missing config as empty so the membership checks below do
        # not raise TypeError (CustomVanna passes llm_config=None by default).
        cfg = config_file if config_file is not None else {}

        # Default parameters - can be overridden through the config.
        self.temperature = cfg.get("temperature", 0.6)
        self.max_tokens = cfg.get("max_tokens", 5000)

        if "api_type" in cfg:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )
        if "api_version" in cfg:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if "api_base" not in cfg:
            raise Exception("Please passing api_base")
        if "api_key" not in cfg:
            raise Exception("Please passing api_key")
        self.client = OpenAI(api_key=cfg["api_key"], base_url=cfg["api_base"])

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def connect_to_dameng(
        self,
        host: str = None,
        dbname: str = None,
        user: str = None,
        password: str = None,
        port: int = None,
        **kwargs
    ):
        """Open a Dameng connection and install ``self.run_sql``.

        The credentials are captured in closures so ``run_sql`` can reconnect
        when the connection stops answering ``SELECT 1 FROM DUAL``.
        ``dbname`` and ``**kwargs`` are accepted but not passed to
        ``dmPython.connect``.

        Raises:
            Exception: if the initial connection fails (the driver error is
                chained as the cause).
        """
        self.conn = None
        try:
            self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
        except Exception as e:
            raise Exception(f"Failed to connect to dameng database: {e}") from e

        def is_connection_alive(conn) -> bool:
            """Return True if ``conn`` can run a trivial query."""
            if conn is None:
                return False
            try:
                cursor = conn.cursor()
                try:
                    cursor.execute("SELECT 1 FROM DUAL")
                finally:
                    # Close the probe cursor even when the probe fails.
                    cursor.close()
                return True
            except Exception:
                return False

        def reconnect():
            """Replace ``self.conn`` with a fresh connection using the captured credentials."""
            try:
                self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
            except Exception as e:
                raise Exception(f"reconnect failed: {e}") from e

        def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
            """Execute ``sql`` and return all rows as a DataFrame.

            Column names come from ``cursor.description``. On failure the
            transaction is rolled back (best-effort) and the error re-raised.
            """
            logger.info("start to run_sql_damengsql")
            try:
                if not is_connection_alive(conn=self.conn):
                    logger.info("connection is not alive, reconnecting..........")
                    reconnect()
                cs = self.conn.cursor()
                try:
                    cs.execute(sql)
                    results = cs.fetchall()
                    columns = [desc[0] for desc in cs.description]
                finally:
                    # Previously the cursor was never closed (resource leak).
                    cs.close()
                return pd.DataFrame(results, columns=columns)
            except Exception as e:
                # Rollback is best-effort: if it fails (e.g. the connection is
                # dead or reconnect failed) it must not mask the original error.
                if self.conn is not None:
                    try:
                        self.conn.rollback()
                    except Exception:
                        logger.warning("rollback failed after sql error", exc_info=True)
                logger.error(f"Failed to execute sql query: {e}")
                raise

        self.run_sql_is_set = True
        self.run_sql = run_sql_damengsql

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a chat prompt and return the assistant's reply text.

        Args:
            prompt: Non-empty list of ``{"role": ..., "content": ...}`` dicts.
            **kwargs: ``model`` or ``engine`` override the configured model.

        Raises:
            Exception: if ``prompt`` is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")
        logger.debug(f"prompt:{prompt}")

        # Rough token estimate: ~4 characters per token.
        num_tokens = sum(len(message["content"]) / 4 for message in prompt)

        model_kwargs = self._select_model_kwargs(num_tokens, **kwargs)
        response = self.client.chat.completions.create(
            messages=prompt,
            max_tokens=self.max_tokens,
            stop=None,
            temperature=self.temperature,
            **model_kwargs,
        )

        for choice in response.choices:
            if "text" in choice:
                return choice.text
        return response.choices[0].message.content

    def _select_model_kwargs(self, num_tokens: float, **kwargs) -> dict:
        """Return the ``model=``/``engine=`` keyword for the completion request.

        Precedence: ``model`` kwarg, ``engine`` kwarg, ``self.config["engine"]``,
        ``self.config["model"]``, then "kimi" when the estimate exceeds 3500
        tokens and "doubao" otherwise.
        """
        if kwargs.get("model", None) is not None:
            model = kwargs["model"]
            logger.info(f"Using model {model} for {num_tokens} tokens (approx)")
            return {"model": model}
        if kwargs.get("engine", None) is not None:
            engine = kwargs["engine"]
            logger.info(f"Using model {engine} for {num_tokens} tokens (approx)")
            # NOTE(review): openai>=1.0 chat.completions.create has no ``engine``
            # parameter; confirm the configured client accepts it.
            return {"engine": engine}
        if self.config is not None and "engine" in self.config:
            logger.info(f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)")
            return {"engine": self.config["engine"]}
        if self.config is not None and "model" in self.config:
            # Never log self.config itself: it can contain the api_key.
            logger.info(f"Using model {self.config['model']} for {num_tokens} tokens (approx)")
            return {"model": self.config["model"]}
        model = "kimi" if num_tokens > 3500 else "doubao"
        logger.info(f"5.Using model {model} for {num_tokens} tokens (approx)")
        return {"model": model}

    def generate_sql_2(self, question: str, cache=None, user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
        """Generate SQL (and optionally chart config) for ``question`` via the LLM.

        Builds the SQL prompt from the base template, similar question/SQL
        examples (at most two), related DDL and, when ``user_id`` and ``cache``
        are given, the cached ``"data"`` history for that user.

        Returns:
            ``{"resp": <parsed JSON from the SQL answer>}``, plus a ``"chart"``
            key with the parsed chart JSON when a chart type is detected.

        Raises:
            Exception: any failure is logged with traceback and re-raised.
        """
        try:
            logger.info("Start to generate_sql_2 in cus_vanna_srevice")
            question_sql_list = self.get_similar_question_sql(question, **kwargs)
            # Only the two closest examples go into the prompt.
            if question_sql_list:
                question_sql_list = question_sql_list[:2]
            ddl_list = self.get_related_ddl(question, **kwargs)

            template = get_base_template()
            sql_temp = template['template']['sql']
            char_temp = template['template']['chart']
            db_engine = config("DB_ENGINE", default='mysql')

            history = None
            if user_id and cache:
                history = cache.get(id=user_id, field="data")

            # -------- Generate the SQL (and chart type) from the prompt
            sys_temp = sql_temp['system'].format(
                engine=db_engine,
                lang='中文',
                schema=ddl_list,
                documentation=[train_ddl.train_document],
                history=history,
                retrieved_examples_data=question_sql_list,
                data_training=question_sql_list,
            )
            user_temp = sql_temp['user'].format(
                question=question,
                current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            )
            logger.info(f"user_temp:{user_temp}")
            logger.info(f"sys_temp:{sys_temp}")

            llm_response = self.submit_prompt(
                [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}],
                **kwargs,
            )
            llm_response = str(llm_response.strip())
            logger.info(f"llm_response:{llm_response}")

            result = extract_nested_json(llm_response)
            logger.info(f"result:{result}")
            result = {"resp": orjson.loads(result)}

            sql = check_and_get_sql(llm_response)
            logger.info(f"sql:{sql}")

            # -------- Generate the chart configuration, if a chart type was chosen
            char_type = get_chart_type_from_sql_answer(llm_response)
            logger.info(f"chart type:{char_type}")
            if char_type:
                sys_char_temp = char_temp['system'].format(
                    engine=db_engine, lang='中文', sql=sql, chart_type=char_type
                )
                user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
                llm_response2 = self.submit_prompt(
                    [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}],
                    **kwargs,
                )
                logger.info(f"chart llm_response:{llm_response2}")
                result['chart'] = orjson.loads(extract_nested_json(llm_response2))
                logger.info(f"chart_response:{result}")

            logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
            return result
        except Exception:
            # logger.exception records the traceback; a bare raise keeps it intact.
            logger.exception("cus_vanna_srevice failed-------------------: ")
            raise

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """Merge ``new_question`` with ``last_question`` via the LLM when related.

        Returns ``new_question`` unchanged when ``last_question`` is None.
        """
        logger.info(f"generate_rewritten_question---------------{new_question}")
        if last_question is None:
            return new_question

        prompt = [
            self.system_message(
                "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
            self.user_message(new_question),
        ]

        return self.submit_prompt(prompt=prompt, **kwargs)


class CustomQdrant_VectorStore(Qdrant_VectorStore):
    """Qdrant vector store whose embeddings come from an HTTP embedding service.

    The endpoint, model name and API key are read from the environment
    variables ``EMBEDDING_MODEL_BASE_URL``, ``EMBEDDING_MODEL_NAME`` and
    ``EMBEDDING_MODEL_API_KEY`` via ``decouple.config``.
    """

    def __init__(
            self,
            config_file=None
    ):
        # None instead of a shared mutable ``{}`` default; the parent reads
        # keys from the dict, so normalise None to an empty dict.
        if config_file is None:
            config_file = {}
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        super().__init__(config_file)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """POST ``data`` to the embedding endpoint and return its vector.

        Extra ``kwargs`` are merged into the request body.

        Raises:
            RuntimeError: on a non-200 response or an empty ``embeddings`` list.
        """
        def _get_error_string(response: requests.Response) -> str:
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "sentences": [data],
        }
        request_body.update(kwargs)

        response = requests.post(
            url=self.embedding_api_base,
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )

        result = response.json()
        embeddings = result['embeddings']
        if not embeddings:
            raise RuntimeError("Failed to create the embeddings, detail: empty embeddings list")
        return embeddings[0]


class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Vanna instance combining the custom Qdrant store with the OpenAI-compatible LLM."""

    def __init__(self, llm_config=None, vector_store_config=None):
        CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
        OpenAICompatibleLLM.__init__(self, config_file=llm_config)


class TTLCacheWrapper:
    """Wrap a vanna ``Cache`` (default ``MemoryCache()``) with TTLs to prevent memory leaks.

    Expiry is tracked per ``(id, field)``. Eviction calls the wrapped cache's
    ``delete(id=...)``, which removes the whole ``id`` entry (all of its
    fields), so all expiry entries for that ``id`` are dropped together.
    A daemon thread sweeps expired entries every ``cleanup_interval`` seconds.
    """

    def __init__(self, cache: Optional[Cache] = None, ttl: int = 3600, cleanup_interval: float = 180):
        self.cache = cache or MemoryCache()
        self.ttl = ttl
        self.cleanup_interval = cleanup_interval
        self._expiry_times = {}
        # Guards _expiry_times: the cleanup thread and request threads mutate it
        # concurrently. Re-entrant because exists() calls get().
        self._lock = threading.RLock()
        self._cleanup_thread = None
        self._start_cleanup()

    def _evict(self, id: str, field: str) -> None:
        """Remove ``id`` from the wrapped cache and forget its expiry entries.

        Caller must hold ``self._lock``.
        """
        if hasattr(self.cache, 'delete'):
            self.cache.delete(id=id)
            # delete() dropped every field of ``id``; drop their expiry entries
            # too so a stale one cannot later evict freshly re-set data.
            stale = [key_info for key_info in self._expiry_times if key_info[0] == id]
        else:
            stale = [(id, field)]
        for key_info in stale:
            self._expiry_times.pop(key_info, None)

    def _purge_expired(self) -> None:
        """Evict every entry whose expiry time has passed."""
        current_time = time.time()
        with self._lock:
            expired_keys = [key_info for key_info, expiry in self._expiry_times.items()
                            if expiry <= current_time]
            for id, field in expired_keys:
                # May already be gone if a sibling field of the same id was evicted.
                if (id, field) in self._expiry_times:
                    self._evict(id, field)

    def _start_cleanup(self):
        """Start the background cleanup thread."""
        def cleanup():
            while True:
                try:
                    self._purge_expired()
                except Exception:
                    # Keep the sweeper alive: an uncaught error would otherwise
                    # end the thread and stop all future expiry.
                    logger.exception("TTLCacheWrapper cleanup failed")
                time.sleep(self.cleanup_interval)

        self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
        self._cleanup_thread.start()

    def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
        """Store ``value`` under ``(id, field)`` with ``ttl`` seconds (default ``self.ttl``)."""
        actual_ttl = ttl if ttl is not None else self.ttl
        with self._lock:
            self.cache.set(id=id, field=field, value=value)
            self._expiry_times[(id, field)] = time.time() + actual_ttl

    def get(self, id: str, field: str) -> Any:
        """Return the cached value, or None if it is missing or has expired."""
        with self._lock:
            expiry = self._expiry_times.get((id, field))
            if expiry is not None and time.time() > expiry:
                self._evict(id, field)
                return None
            return self.cache.get(id=id, field=field)

    def delete(self, id: str, field: str):
        """Delete the cached entry (the wrapped cache drops the whole ``id``)."""
        with self._lock:
            self._evict(id, field)

    def exists(self, id: str, field: str) -> bool:
        """Return True if ``(id, field)`` holds a non-None, unexpired value."""
        with self._lock:
            expiry = self._expiry_times.get((id, field))
            if expiry is not None and time.time() > expiry:
                self._evict(id, field)
                return False
            return self.get(id=id, field=field) is not None

    # Proxy any other attribute to the wrapped cache.
    def __getattr__(self, name):
        # __getattr__ only runs for attributes not found normally; guard
        # "cache" so a partially-initialised instance raises AttributeError
        # instead of recursing forever.
        if name == "cache":
            raise AttributeError(name)
        return getattr(self.cache, name)