418 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			418 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from dataclasses import field
 | ||
| from email.policy import default
 | ||
| from typing import List, Union, Any, Optional
 | ||
| import time
 | ||
| import threading
 | ||
| from vanna.flask import Cache, MemoryCache
 | ||
| import  dmPython
 | ||
| import orjson
 | ||
| import pandas as pd
 | ||
| from vanna.base import VannaBase
 | ||
| from vanna.flask import MemoryCache
 | ||
| from vanna.qdrant import Qdrant_VectorStore
 | ||
| from qdrant_client import QdrantClient
 | ||
| from openai import OpenAI
 | ||
| import requests
 | ||
| from decouple import config
 | ||
| from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
 | ||
| import json
 | ||
| from template.template import get_base_template
 | ||
| from datetime import datetime
 | ||
| import logging
 | ||
| from util import  train_ddl
 | ||
| logger = logging.getLogger(__name__)
 | ||
| import traceback
 | ||
class OpenAICompatibleLLM(VannaBase):
    """Vanna LLM backend that talks to an OpenAI-compatible chat endpoint.

    Besides prompt submission, the class wires a Dameng database connection
    into ``run_sql`` and implements the project's two-step "SQL, then chart"
    generation flow (``generate_sql_2``).
    """

    def __init__(self, client=None, config_file=None):
        """Configure sampling parameters and the OpenAI client.

        Args:
            client: Ready-made OpenAI client. When given, ``api_base`` and
                ``api_key`` are not required.
            config_file: Mapping of settings. Recognised keys are
                ``temperature``, ``max_tokens``, ``api_base`` and ``api_key``.
                ``None`` is treated as an empty mapping.

        Raises:
            Exception: If a deprecated key (``api_type``, ``api_version``) is
                present, or if no client is given and ``api_base`` /
                ``api_key`` is missing.
        """
        VannaBase.__init__(self, config=config_file)

        # The signature defaults to None; the membership tests below need a
        # container, so fall back to an empty mapping instead of crashing.
        settings = config_file if config_file is not None else {}

        # Default parameters - can be overridden through the config.
        self.temperature = 0.6
        self.max_tokens = 5000

        if "temperature" in settings:
            self.temperature = settings["temperature"]

        if "max_tokens" in settings:
            self.max_tokens = settings["max_tokens"]

        if "api_type" in settings:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if "api_version" in settings:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if "api_base" not in settings:
            raise Exception("Please passing api_base")

        if "api_key" not in settings:
            raise Exception("Please passing api_key")

        self.client = OpenAI(api_key=settings["api_key"], base_url=settings["api_base"])

    def system_message(self, message: str) -> Any:
        """Wrap ``message`` as a chat message with the ``system`` role."""
        return {"role": "system", "content": message}

    def connect_to_dameng(
            self,
            host: Optional[str] = None,
            dbname: Optional[str] = None,
            user: Optional[str] = None,
            password: Optional[str] = None,
            port: Optional[int] = None,
            **kwargs
    ):
        """Open a Dameng connection and install ``self.run_sql`` on top of it.

        Args:
            host: Server address, forwarded to ``dmPython.connect`` as ``server``.
            dbname: Accepted for signature compatibility; it is not forwarded
                to ``dmPython.connect``.
            user: Database user.
            password: Database password.
            port: Server port.
            **kwargs: Ignored.

        Raises:
            Exception: If the initial connection cannot be established.
        """

        def open_connection():
            return dmPython.connect(user=user, password=password, server=host, port=port)

        def close_quietly(cursor):
            # Best effort: a failure to close must never hide the real outcome.
            if cursor is None:
                return
            try:
                cursor.close()
            except Exception as close_error:
                logger.debug(f"Failed to close cursor: {close_error}")

        self.conn = None
        try:
            self.conn = open_connection()
        except Exception as e:
            raise Exception(f"Failed to connect to dameng database: {e}") from e

        def is_connection_alive(conn) -> bool:
            """Probe the connection with a trivial query."""
            if conn is None:
                return False
            try:
                cursor = conn.cursor()
            except Exception:
                return False
            try:
                cursor.execute("SELECT 1 FROM DUAL")
                return True
            except Exception:
                return False
            finally:
                close_quietly(cursor)

        def reconnect():
            """Replace ``self.conn`` with a fresh connection."""
            try:
                self.conn = open_connection()
            except Exception as e:
                raise Exception(f"reconnect failed: {e}") from e

        def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
            """Execute ``sql`` and return all rows as a DataFrame."""
            logger.info("start to run_sql_damengsql")
            cursor = None
            try:
                if not is_connection_alive(conn=self.conn):
                    logger.info("connection is not alive, reconnecting..........")
                    reconnect()
                cursor = self.conn.cursor()
                cursor.execute(sql)
                rows = cursor.fetchall()

                # Column names come from the cursor metadata.
                return pd.DataFrame(
                    rows, columns=[desc[0] for desc in cursor.description]
                )
            except Exception as e:
                logger.error(f"Failed to execute sql query: {e}")
                # The connection may be dead (e.g. reconnect just failed); a
                # rollback error must not replace the original exception.
                if self.conn is not None:
                    try:
                        self.conn.rollback()
                    except Exception as rollback_error:
                        logger.warning(f"Rollback failed: {rollback_error}")
                raise
            finally:
                close_quietly(cursor)

        self.run_sql_is_set = True
        self.run_sql = run_sql_damengsql

    def user_message(self, message: str) -> Any:
        """Wrap ``message`` as a chat message with the ``user`` role."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> Any:
        """Wrap ``message`` as a chat message with the ``assistant`` role."""
        return {"role": "assistant", "content": message}

    def _pick_model_argument(self, num_tokens, **kwargs):
        """Return the ``(keyword, value)`` pair that selects the remote model.

        Precedence: ``model`` kwarg, ``engine`` kwarg, ``engine`` in the
        config, ``model`` in the config, then a size-based fallback.
        """
        if kwargs.get("model", None) is not None:
            return "model", kwargs["model"]
        if kwargs.get("engine", None) is not None:
            return "engine", kwargs["engine"]
        if self.config is not None and "engine" in self.config:
            return "engine", self.config["engine"]
        if self.config is not None and "model" in self.config:
            return "model", self.config["model"]
        # Nothing configured: longer prompts go to "kimi", shorter to "doubao".
        return "model", ("kimi" if num_tokens > 3500 else "doubao")

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a list of chat messages to the LLM and return its answer.

        Args:
            prompt: Non-empty list of ``{"role": ..., "content": ...}`` messages.
            **kwargs: May carry ``model`` or ``engine`` to override the config.

        Returns:
            The text of the first choice that exposes ``text``, otherwise the
            message content of the first choice.

        Raises:
            Exception: If ``prompt`` is ``None`` or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # The prompt can carry schema and chat history, so keep it out of
        # stdout and log it at debug level only.
        logger.debug("prompt: %s", prompt)

        # Rough size estimate: about four characters per token.
        num_tokens = sum(len(message["content"]) / 4 for message in prompt)

        keyword, value = self._pick_model_argument(num_tokens, **kwargs)
        # Deliberately do not log self.config here: it holds the API key.
        logger.info("Using %s %s for %s tokens (approx)", keyword, value, num_tokens)

        # NOTE(review): the "engine" keyword is kept for backward compatibility;
        # confirm the installed OpenAI client still accepts it.
        response = self.client.chat.completions.create(
            messages=prompt,
            max_tokens=self.max_tokens,
            stop=None,
            temperature=self.temperature,
            **{keyword: value},
        )

        for choice in response.choices:
            if "text" in choice:
                return choice.text

        return response.choices[0].message.content

    def generate_sql_2(self, question: str, cache=None, user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
        """Generate SQL for ``question`` and, when applicable, a chart spec.

        Args:
            question: Natural-language question.
            cache: Optional cache; when given together with ``user_id`` the
                value stored under field ``"data"`` is injected as history.
            user_id: Cache id of the current user.
            allow_llm_to_see_data: Accepted for signature compatibility; unused.
            **kwargs: Forwarded to retrieval helpers and ``submit_prompt``.

        Returns:
            ``{"resp": <parsed JSON of the SQL answer>}`` plus a ``"chart"``
            entry when a chart type was detected in the answer.

        Raises:
            Exception: Any failure is logged with its traceback and re-raised.
        """
        try:
            logger.info("Start to generate_sql_2 in cus_vanna_srevice")
            question_sql_list = self.get_similar_question_sql(question, **kwargs)
            # Keep at most two retrieved examples in the prompt.
            if question_sql_list and len(question_sql_list) > 2:
                question_sql_list = question_sql_list[:2]

            ddl_list = self.get_related_ddl(question, **kwargs)
            template = get_base_template()
            sql_temp = template['template']['sql']
            char_temp = template['template']['chart']
            history = None
            if user_id and cache:
                history = cache.get(id=user_id, field="data")

            # -------- Step 1: build the prompt and ask for SQL + chart type.
            sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
                                                 schema=ddl_list, documentation=[train_ddl.train_document],
                                                 history=history, retrieved_examples_data=question_sql_list,
                                                 data_training=question_sql_list,)

            user_temp = sql_temp['user'].format(question=question,
                                                current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
            logger.info(f"user_temp:{user_temp}")
            logger.info(f"sys_temp:{sys_temp}")
            llm_response = self.submit_prompt(
                [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
            llm_response = str(llm_response.strip())
            logger.info(f"llm_response:{llm_response}")

            result = extract_nested_json(llm_response)
            logger.info(f"result:{result}")
            result = {"resp": orjson.loads(result)}
            sql = check_and_get_sql(llm_response)
            logger.info(f"sql:{sql}")

            # -------- Step 2: ask for the chart definition.
            char_type = get_chart_type_from_sql_answer(llm_response)
            logger.info(f"chart type:{char_type}")
            if char_type:
                sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'),
                                                           lang='中文', sql=sql, chart_type=char_type)
                user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
                llm_response2 = self.submit_prompt(
                    [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}],
                    **kwargs)
                logger.info(f"llm_response2:{llm_response2}")
                result['chart'] = orjson.loads(extract_nested_json(llm_response2))
                logger.info(f"chart_response:{result}")

            logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
            return result
        except Exception:
            # logger.exception records the traceback through the logging setup.
            logger.exception("cus_vanna_srevice failed-------------------: ")
            raise

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """Merge a follow-up question with the previous one when they relate.

        Args:
            last_question: Previous question, or ``None`` if there is none.
            new_question: The question just asked.
            **kwargs: Forwarded to ``submit_prompt``.

        Returns:
            ``new_question`` unchanged when there is no previous question,
            otherwise the LLM's combined (or passed-through) question.
        """
        logger.info(f"generate_rewritten_question---------------{new_question}")
        if last_question is None:
            return new_question

        prompt = [
            self.system_message(
                "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
            # Both questions must be sent; with only the new one the model has
            # nothing to combine it with.
            self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
        ]

        return self.submit_prompt(prompt=prompt, **kwargs)
 | ||
| 
 | ||
| 
 | ||
class CustomQdrant_VectorStore(Qdrant_VectorStore):
    """Qdrant vector store whose embeddings come from an HTTP embedding service.

    The service location, model name and API key are read from the settings
    ``EMBEDDING_MODEL_BASE_URL``, ``EMBEDDING_MODEL_NAME`` and
    ``EMBEDDING_MODEL_API_KEY``.
    """

    # Seconds to wait for the embedding service; without a timeout a stalled
    # server would block the calling thread indefinitely.
    embedding_timeout = 60

    def __init__(
            self,
            config_file=None
    ):
        """Read the embedding settings and initialise the Qdrant base class.

        Args:
            config_file: Configuration mapping for the Qdrant base class.
                ``None`` yields a fresh empty dict per instance, so instances
                never share one mutable default.
        """
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        super().__init__({} if config_file is None else config_file)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Return the embedding vector of ``data``.

        Args:
            data: Text to embed; sent as the single element of ``sentences``.
            **kwargs: Extra fields merged into the JSON request body.

        Returns:
            The first vector of the ``embeddings`` list in the response.

        Raises:
            RuntimeError: If the service answers with a non-200 status or the
                response body lacks a usable ``embeddings`` entry.
        """

        def _get_error_string(response: requests.Response) -> str:
            # Prefer the service's own "detail" message when it sent one.
            try:
                if response.content:
                    return response.json()["detail"]
            except (ValueError, KeyError, TypeError):
                # Body is not JSON or has no "detail"; use the HTTP status.
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "sentences": [data],
        }
        request_body.update(kwargs)

        response = requests.post(
            url=self.embedding_api_base,
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
            timeout=self.embedding_timeout,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )

        try:
            return response.json()['embeddings'][0]
        except (ValueError, KeyError, IndexError, TypeError) as exc:
            raise RuntimeError(
                "Failed to create the embeddings, detail: unexpected response payload"
            ) from exc
 | ||
| 
 | ||
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Vanna instance combining the Qdrant vector store with the LLM backend."""

    def __init__(self, llm_config=None, vector_store_config=None):
        """Initialise the vector store first, then the LLM backend.

        Args:
            llm_config: Settings for ``OpenAICompatibleLLM``.
            vector_store_config: Settings for ``CustomQdrant_VectorStore``.

        ``None`` is replaced by an empty dict for both arguments:
        ``OpenAICompatibleLLM.__init__`` runs membership tests on its config,
        which raise ``TypeError`` on ``None``.
        NOTE(review): the Qdrant base presumably also expects a mapping - confirm.
        """
        store_settings = {} if vector_store_config is None else vector_store_config
        llm_settings = {} if llm_config is None else llm_config
        CustomQdrant_VectorStore.__init__(self, config_file=store_settings)
        OpenAICompatibleLLM.__init__(self, config_file=llm_settings)
 | ||
| 
 | ||
class TTLCacheWrapper:
    """Wrap a cache backend and expire its entries after a time-to-live.

    Expiry is tracked per ``(id, field)`` pair. The backend's ``delete`` is
    called with the id only, so when one field expires or is deleted the
    bookkeeping for every field of that id is dropped with it. A daemon
    thread sweeps expired entries periodically so an otherwise unbounded
    backend does not grow forever.

    Attributes not defined here are delegated to the wrapped cache.
    """

    def __init__(self, cache: "Optional[Cache]" = None, ttl: int = 3600,
                 cleanup_interval: float = 180):
        """Create the wrapper and start the background sweeper.

        Args:
            cache: Backend to wrap; a new ``MemoryCache`` when ``None``.
            ttl: Default time-to-live in seconds.
            cleanup_interval: Seconds between background sweeps.
        """
        # "is not None" rather than truthiness, so a valid cache that happens
        # to evaluate falsy is not silently replaced.
        self.cache = cache if cache is not None else MemoryCache()
        self.ttl = ttl
        self.cleanup_interval = cleanup_interval
        self._expiry_times = {}
        # Guards _expiry_times, which the sweeper thread and callers share.
        self._lock = threading.RLock()
        self._stop_event = threading.Event()
        self._cleanup_thread = None
        self._start_cleanup()

    def _start_cleanup(self):
        """Start the daemon thread that periodically removes expired entries."""

        def cleanup():
            # Event.wait returns True once stop_cleanup() is called, which
            # ends the loop; until then it acts as the sleep between sweeps.
            while not self._stop_event.wait(self.cleanup_interval):
                try:
                    self._purge_expired()
                except Exception:
                    # A failing backend must not kill the sweeper for good.
                    logging.getLogger(__name__).exception("TTL cache cleanup failed")

        self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
        self._cleanup_thread.start()

    def stop_cleanup(self):
        """Ask the background sweeper to exit."""
        self._stop_event.set()

    def _evict(self, keys):
        """Drop ``keys`` from the backend and the bookkeeping.

        The caller must hold ``self._lock``.
        """
        doomed = set(keys)
        if not doomed:
            return
        if hasattr(self.cache, 'delete'):
            ids = {entry_id for entry_id, _ in doomed}
            for entry_id in ids:
                self.cache.delete(id=entry_id)
            # The backend removed whole ids, so sibling fields are gone too.
            doomed.update(key for key in self._expiry_times if key[0] in ids)
        for key in doomed:
            self._expiry_times.pop(key, None)

    def _purge_expired(self):
        """Evict every entry whose expiry time has passed."""
        with self._lock:
            now = time.time()
            expired = [key for key, expiry in self._expiry_times.items() if expiry <= now]
            self._evict(expired)

    def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
        """Store ``value`` and record when it expires.

        Args:
            id: Entry id.
            field: Field name under that id.
            value: Value to store.
            ttl: Time-to-live in seconds; the wrapper default when ``None``.
        """
        actual_ttl = ttl if ttl is not None else self.ttl
        with self._lock:
            self.cache.set(id=id, field=field, value=value)
            self._expiry_times[(id, field)] = time.time() + actual_ttl

    def get(self, id: str, field: str) -> Any:
        """Return the cached value, or ``None`` if it has expired."""
        key_info = (id, field)
        with self._lock:
            expiry = self._expiry_times.get(key_info)
            if expiry is not None and expiry <= time.time():
                self._evict([key_info])
                return None
            return self.cache.get(id=id, field=field)

    def delete(self, id: str, field: str):
        """Delete the entry from the backend and forget its expiry."""
        with self._lock:
            self._evict([(id, field)])

    def exists(self, id: str, field: str) -> bool:
        """Return True when a non-expired, non-None value is cached."""
        # get() already evicts expired entries and returns None for them.
        return self.get(id=id, field=field) is not None

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped cache."""
        # Without this guard, looking up "cache" before __init__ has set it
        # (e.g. on an instance created via __new__) would recurse forever.
        if name == "cache":
            raise AttributeError(name)
        return getattr(self.cache, name)
