403 lines
15 KiB
Python
403 lines
15 KiB
Python
from dataclasses import field
|
||
from email.policy import default
|
||
from typing import List, Union, Any, Optional
|
||
import time
|
||
import threading
|
||
from vanna.flask import Cache, MemoryCache
|
||
import dmPython
|
||
import orjson
|
||
import pandas as pd
|
||
from vanna.base import VannaBase
|
||
from vanna.flask import MemoryCache
|
||
from vanna.qdrant import Qdrant_VectorStore
|
||
from qdrant_client import QdrantClient
|
||
from openai import OpenAI
|
||
import requests
|
||
from decouple import config
|
||
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
|
||
import json
|
||
from template.template import get_base_template
|
||
from datetime import datetime
|
||
import logging
|
||
from util import train_ddl
|
||
logger = logging.getLogger(__name__)
|
||
import traceback
|
||
class OpenAICompatibleLLM(VannaBase):
    """Vanna LLM backend that talks to an OpenAI-compatible chat endpoint.

    Besides the chat plumbing (``submit_prompt`` and the message helpers) the
    class can register a Dameng database connection as Vanna's ``run_sql``
    and offers ``generate_sql_2``, a template-driven SQL + chart generator.
    """

    def __init__(self, client=None, config_file=None):
        """Set sampling defaults and create (or adopt) the OpenAI client.

        Args:
            client: Ready-made client object. When given it is stored as
                ``self.client`` and ``api_base`` / ``api_key`` are not needed.
            config_file: Optional mapping. Recognised keys: ``temperature``,
                ``max_tokens``, ``api_base``, ``api_key``. ``None`` is treated
                as an empty mapping.

        Raises:
            Exception: If the deprecated ``api_type`` / ``api_version`` keys
                are present, or if no client is given and ``api_base`` or
                ``api_key`` is missing.
        """
        VannaBase.__init__(self, config=config_file)

        # ``config_file`` defaults to None; normalise it so the membership
        # tests below do not raise TypeError.
        settings = config_file if config_file is not None else {}

        # default parameters - can be overridden using config
        self.temperature = 0.5
        self.max_tokens = 5000

        if "temperature" in settings:
            self.temperature = settings["temperature"]

        if "max_tokens" in settings:
            self.max_tokens = settings["max_tokens"]

        if "api_type" in settings:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if "api_version" in settings:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if "api_base" not in settings:
            raise Exception("Please passing api_base")

        if "api_key" not in settings:
            raise Exception("Please passing api_key")

        self.client = OpenAI(api_key=settings["api_key"], base_url=settings["api_base"])

    def system_message(self, message: str) -> Any:
        """Wrap ``message`` as a chat message with role ``system``."""
        return {"role": "system", "content": message}

    def connect_to_dameng(
        self,
        host: str = None,
        dbname: str = None,
        user: str = None,
        password: str = None,
        port: int = None,
        **kwargs
    ):
        """Connect to a Dameng database and install ``self.run_sql``.

        Args:
            host: Server address, passed to ``dmPython.connect`` as ``server``.
            dbname: Accepted for signature compatibility; not used here.
            user: Login user.
            password: Login password.
            port: Server port.
            **kwargs: Accepted and ignored.

        Raises:
            Exception: If the initial connection attempt fails.
        """

        def open_connection():
            # Single place that knows the connection parameters, shared by
            # the initial connect and by reconnect().
            return dmPython.connect(user=user, password=password, server=host, port=port)

        self.conn = None
        try:
            self.conn = open_connection()
        except Exception as e:
            raise Exception(f"Failed to connect to dameng database: {e}") from e

        def is_connection_alive(conn) -> bool:
            """Return True if a trivial query succeeds on ``conn``."""
            if conn is None:
                return False
            try:
                cursor = conn.cursor()
                try:
                    cursor.execute("SELECT 1 FROM DUAL")
                finally:
                    cursor.close()
                return True
            except Exception:
                return False

        def reconnect():
            """Replace ``self.conn`` with a fresh connection."""
            try:
                self.conn = open_connection()
            except Exception as e:
                raise Exception(f"reconnect failed: {e}") from e

        def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
            """Execute ``sql`` and return all rows as a DataFrame.

            Column names come from ``cursor.description``. Any failure is
            logged, a rollback is attempted, and the error is re-raised.
            """
            logger.info("start to run_sql_damengsql")
            cursor = None
            try:
                if not is_connection_alive(conn=self.conn):
                    logger.info("connection is not alive, reconnecting..........")
                    reconnect()
                cursor = self.conn.cursor()
                cursor.execute(sql)
                rows = cursor.fetchall()

                # Create a pandas dataframe from the results
                return pd.DataFrame(
                    rows, columns=[desc[0] for desc in cursor.description]
                )
            except Exception as e:
                logger.error(f"Failed to execute sql query: {e}")
                # Rollback is best effort: the connection may be the very
                # thing that failed, and a rollback error must not hide the
                # original exception.
                try:
                    if self.conn is not None:
                        self.conn.rollback()
                except Exception as rollback_error:
                    logger.warning(f"rollback failed: {rollback_error}")
                raise
            finally:
                # Always release the cursor, on success and on failure.
                if cursor is not None:
                    try:
                        cursor.close()
                    except Exception as close_error:
                        logger.debug(f"cursor close failed: {close_error}")

        self.run_sql_is_set = True
        self.run_sql = run_sql_damengsql

    def user_message(self, message: str) -> Any:
        """Wrap ``message`` as a chat message with role ``user``."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> Any:
        """Wrap ``message`` as a chat message with role ``assistant``."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` to the chat completions endpoint and return the reply.

        The target is chosen in this order: ``kwargs["model"]``,
        ``kwargs["engine"]``, ``self.config["engine"]``,
        ``self.config["model"]``; otherwise ``"kimi"`` when the rough token
        estimate exceeds 3500 and ``"doubao"`` when it does not.

        Args:
            prompt: Non-empty list of message dicts, each with a ``content`` key.
            **kwargs: May carry ``model`` or ``engine``.

        Returns:
            The content of the first choice's message.

        Raises:
            Exception: If ``prompt`` is ``None`` or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        logger.debug(f"prompt:{prompt}")

        # Rough estimate: about four characters per token.
        num_tokens = sum(len(message["content"]) / 4 for message in prompt)

        # Resolve which keyword ("model" or "engine") and which value go to
        # the API, so the request itself is written only once.
        if kwargs.get("model", None) is not None:
            target = {"model": kwargs["model"]}
        elif kwargs.get("engine", None) is not None:
            target = {"engine": kwargs["engine"]}
        elif self.config is not None and "engine" in self.config:
            target = {"engine": self.config["engine"]}
        elif self.config is not None and "model" in self.config:
            target = {"model": self.config["model"]}
        elif num_tokens > 3500:
            target = {"model": "kimi"}
        else:
            target = {"model": "doubao"}

        (kind, value), = target.items()
        logger.info(f"Using {kind} {value} for {num_tokens} tokens (approx)")

        response = self.client.chat.completions.create(
            messages=prompt,
            max_tokens=self.max_tokens,
            stop=None,
            temperature=self.temperature,
            **target,
        )

        # Prefer a choice exposing ``text`` when one exists.
        for choice in response.choices:
            if "text" in choice:
                return choice.text

        return response.choices[0].message.content

    def generate_sql_2(self, question: str, cache=None, user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
        """Generate SQL for ``question`` and, when applicable, a chart spec.

        Args:
            question: The user's question.
            cache: Optional cache; with ``user_id`` it supplies the ``"data"``
                field that is passed to the prompt as history.
            user_id: Cache id of the current user.
            allow_llm_to_see_data: Accepted for signature compatibility; not
                used in this method.
            **kwargs: Forwarded to the retrieval helpers and ``submit_prompt``.

        Returns:
            A dict with key ``"resp"`` (JSON parsed from the SQL answer) and,
            when a chart type was detected, key ``"chart"`` (JSON parsed from
            the chart answer).

        Raises:
            Exception: Any failure is logged with its traceback and re-raised.
        """
        try:
            logger.info("Start to generate_sql_2 in cus_vanna_srevice")
            question_sql_list = self.get_similar_question_sql(question, **kwargs)
            # NOTE(review): more than two examples are cut down to one while
            # one or two are kept as they are - confirm this is intended.
            if question_sql_list and len(question_sql_list) > 2:
                question_sql_list = question_sql_list[:1]

            ddl_list = self.get_related_ddl(question, **kwargs)
            template = get_base_template()
            sql_temp = template['template']['sql']
            char_temp = template['template']['chart']

            history = None
            if user_id and cache:
                history = cache.get(id=user_id, field="data")

            # -------- build the prompt that yields the SQL and the chart type
            sys_temp = sql_temp['system'].format(
                engine=config("DB_ENGINE", default='mysql'),
                lang='中文',
                schema=ddl_list,
                documentation=[train_ddl.train_document],
                history=history,
                retrieved_examples_data=question_sql_list,
                data_training=question_sql_list,
            )
            user_temp = sql_temp['user'].format(
                question=question,
                current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            )
            logger.info(f"user_temp:{user_temp}")
            logger.info(f"sys_temp:{sys_temp}")

            llm_response = self.submit_prompt(
                [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}],
                **kwargs)
            logger.info(f"llm_response:{llm_response}")

            result = {"resp": orjson.loads(extract_nested_json(llm_response))}
            sql = check_and_get_sql(llm_response)
            logger.info(f"sql:{sql}")

            # --------------- generate the chart
            char_type = get_chart_type_from_sql_answer(llm_response)
            logger.info(f"chart type:{char_type}")
            if char_type:
                sys_char_temp = char_temp['system'].format(
                    engine=config("DB_ENGINE", default='mysql'),
                    lang='中文', sql=sql, chart_type=char_type)
                user_char_temp = char_temp['user'].format(
                    sql=sql, chart_type=char_type, question=question)
                llm_response2 = self.submit_prompt(
                    [{'role': 'system', 'content': sys_char_temp},
                     {'role': 'user', 'content': user_char_temp}],
                    **kwargs)
                logger.debug(f"chart llm_response:{llm_response2}")
                result['chart'] = orjson.loads(extract_nested_json(llm_response2))
            logger.info(f"chart_response:{result}")

            logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
            return result
        except Exception:
            # logger.exception records the traceback through the logging
            # system; the bare ``raise`` keeps the original traceback intact.
            logger.exception("cus_vanna_srevice failed-------------------: ")
            raise

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """Merge a follow-up question with the previous one when they relate.

        Args:
            last_question: The previous question, or ``None`` if there is none.
            new_question: The question just asked.
            **kwargs: Forwarded to ``submit_prompt``.

        Returns:
            ``new_question`` unchanged when there is no previous question,
            otherwise the model's reply.
        """
        logger.info(f"generate_rewritten_question---------------{new_question}")
        if last_question is None:
            return new_question

        prompt = [
            self.system_message(
                "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
            # Both questions must be in the prompt; with only the new one the
            # model has nothing to combine it with.
            self.user_message(
                "First question: " + last_question + "\nSecond question: " + new_question),
        ]

        return self.submit_prompt(prompt=prompt, **kwargs)
|
||
|
||
|
||
class CustomQdrant_VectorStore(Qdrant_VectorStore):
    """Qdrant vector store whose embeddings come from an HTTP embedding service.

    The service location, model name and API key are read through
    ``decouple.config`` (``EMBEDDING_MODEL_BASE_URL``, ``EMBEDDING_MODEL_NAME``,
    ``EMBEDDING_MODEL_API_KEY``).
    """

    def __init__(
            self,
            config_file=None
    ):
        """Read the embedding settings and initialise the Qdrant store.

        Args:
            config_file: Configuration mapping forwarded to
                ``Qdrant_VectorStore``. ``None`` is replaced by a new empty
                dict, so no default dict is shared between instances.
        """
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        # Seconds to wait for the embedding service; without a limit a
        # stalled service would block the caller indefinitely.
        self.embedding_timeout = config('EMBEDDING_MODEL_TIMEOUT', default=60, cast=float)
        super().__init__({} if config_file is None else config_file)

    @staticmethod
    def _get_error_string(response: requests.Response) -> str:
        """Return the most specific error description available for ``response``."""
        try:
            if response.content:
                return str(response.json()["detail"])
        except Exception:
            # The body is not JSON or has no "detail"; fall back to the
            # HTTP status below.
            pass
        try:
            response.raise_for_status()
        except requests.HTTPError as e:
            return str(e)
        return "Unknown error"

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Request an embedding for ``data`` from the embedding service.

        Args:
            data: Text to embed; sent as the only element of ``sentences``.
            **kwargs: Extra fields merged into the JSON request body.

        Returns:
            The first entry of the ``embeddings`` list in the response.

        Raises:
            RuntimeError: If the request fails, the status is not 200, or the
                response carries no embeddings.
        """
        request_body = {
            "model": self.embedding_model_name,
            "sentences": [data],
        }
        request_body.update(kwargs)

        # Fall back to 60 s if the instance was built without __init__.
        timeout = getattr(self, "embedding_timeout", 60)
        try:
            response = requests.post(
                url=self.embedding_api_base,
                json=request_body,
                headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
                timeout=timeout,
            )
        except requests.RequestException as e:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {e}"
            ) from e

        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {self._get_error_string(response)}"
            )

        result = response.json()
        embeddings = result.get('embeddings') if isinstance(result, dict) else None
        if not embeddings:
            # Report a malformed response the same way as other failures
            # instead of letting a bare KeyError / IndexError escape.
            raise RuntimeError(
                "Failed to create the embeddings, detail: response contains no embeddings"
            )
        return embeddings[0]
|
||
|
||
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Vanna implementation combining the Qdrant store with the OpenAI-compatible LLM."""

    def __init__(self, llm_config=None, vector_store_config=None):
        """Initialise the vector store first, then the LLM.

        Args:
            llm_config: Mapping passed to ``OpenAICompatibleLLM.__init__`` as
                ``config_file``.
            vector_store_config: Mapping passed to
                ``CustomQdrant_VectorStore.__init__`` as ``config_file``.

        ``None`` is replaced by an empty dict for both arguments, because the
        base initialisers run membership tests on the mapping. With an empty
        ``llm_config`` the LLM initialiser still raises its explicit
        "api_base" error, since no client is passed here.
        """
        CustomQdrant_VectorStore.__init__(
            self, config_file={} if vector_store_config is None else vector_store_config)
        OpenAICompatibleLLM.__init__(
            self, config_file={} if llm_config is None else llm_config)
|
||
|
||
class TTLCacheWrapper:
    """TTL wrapper around a cache (``MemoryCache`` by default) to bound memory use.

    Expiry times are tracked per ``(id, field)`` pair. The wrapped cache's
    ``delete`` is called with the id only, so when one field of an id expires
    or is deleted the whole id is removed from the wrapped cache, and this
    wrapper forgets the expiry records of every field of that id.

    A daemon thread sweeps expired entries every ``cleanup_interval`` seconds.
    Attributes not defined here are forwarded to the wrapped cache.
    """

    def __init__(self, cache: Optional["Cache"] = None, ttl: int = 3600,
                 cleanup_interval: float = 180):
        """Create the wrapper and start the background sweep.

        Args:
            cache: Cache to wrap; a new ``MemoryCache`` is created when ``None``.
            ttl: Default lifetime of an entry, in seconds.
            cleanup_interval: Seconds between two background sweeps.
        """
        # ``is not None`` rather than ``or``: a cache object that happens to
        # be falsy must not be silently replaced.
        self.cache = cache if cache is not None else MemoryCache()
        self.ttl = ttl
        self.cleanup_interval = cleanup_interval
        self._expiry_times = {}
        # Guards _expiry_times, which is touched by callers and by the sweep
        # thread. Re-entrant because public methods call each other.
        self._lock = threading.RLock()
        self._stop_event = threading.Event()
        self._cleanup_thread = None
        self._start_cleanup()

    def _evict(self, key_info):
        """Remove ``key_info`` from the wrapped cache and the expiry table.

        Must be called with ``self._lock`` held.
        """
        entry_id = key_info[0]
        if hasattr(self.cache, 'delete'):
            self.cache.delete(id=entry_id)
            # delete() removed every field of this id, so every expiry record
            # of the id is stale; a leftover record would later delete data
            # stored afresh under the same id.
            stale = [key for key in self._expiry_times if key[0] == entry_id]
        else:
            stale = [key_info]
        for key in stale:
            self._expiry_times.pop(key, None)

    def _remove_expired(self):
        """Evict every entry whose expiry time has passed."""
        current_time = time.time()
        with self._lock:
            # Collect first: _evict mutates the dict being inspected.
            expired_keys = [
                key for key, expiry in self._expiry_times.items()
                if expiry <= current_time
            ]
            for key_info in expired_keys:
                # May already be gone through a sibling field's eviction.
                if key_info in self._expiry_times:
                    self._evict(key_info)

    def _start_cleanup(self):
        """Start the background cleanup thread."""

        def cleanup():
            while not self._stop_event.is_set():
                try:
                    self._remove_expired()
                except Exception:
                    # One failing sweep must not end the thread, otherwise
                    # expired entries would never be swept again.
                    logging.getLogger(__name__).exception("TTL cache cleanup failed")
                # Sleeps for the interval but returns at once when stopped.
                self._stop_event.wait(self.cleanup_interval)

        self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
        self._cleanup_thread.start()

    def stop_cleanup(self):
        """Ask the background cleanup thread to finish."""
        self._stop_event.set()

    def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
        """Store ``value`` and record when it expires.

        Args:
            id: Entry id.
            field: Field name within the entry.
            value: Value to store.
            ttl: Lifetime in seconds; the wrapper's default when ``None``.
        """
        actual_ttl = ttl if ttl is not None else self.ttl
        with self._lock:
            self.cache.set(id=id, field=field, value=value)
            self._expiry_times[(id, field)] = time.time() + actual_ttl

    def get(self, id: str, field: str) -> Any:
        """Return the cached value, or ``None`` when it has expired."""
        key_info = (id, field)
        with self._lock:
            expiry = self._expiry_times.get(key_info)
            if expiry is not None and time.time() > expiry:
                # Expired: evict and report a miss.
                self._evict(key_info)
                return None
        return self.cache.get(id=id, field=field)

    def delete(self, id: str, field: str):
        """Delete the entry from the wrapped cache and drop its expiry records."""
        with self._lock:
            self._evict((id, field))

    def exists(self, id: str, field: str) -> bool:
        """Return True when a non-expired, non-None value is stored."""
        # get() already evicts expired entries and returns None for them.
        return self.get(id=id, field=field) is not None

    # Forward everything else to the wrapped cache.
    def __getattr__(self, name):
        # Called only when normal lookup fails. If "cache" itself is missing
        # (instance created without __init__), reading self.cache would
        # re-enter this method without end.
        if name == "cache":
            raise AttributeError(name)
        return getattr(self.cache, name)