Files
sqlbot_agent/service/cus_vanna_srevice.py
2025-11-01 10:16:06 +08:00

501 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List, Union, Any, Optional
import time
import threading
from vanna.flask import Cache, MemoryCache
import dmPython
import orjson
import pandas as pd
from vanna.base import VannaBase
from vanna.flask import MemoryCache
from vanna.qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from openai import OpenAI
import requests
from decouple import config
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
import json
from template.template import get_base_template
from datetime import datetime
import logging
from util import train_ddl
logger = logging.getLogger(__name__)
import traceback
class OpenAICompatibleLLM(VannaBase):
    """Vanna LLM backend for any OpenAI-compatible chat-completions endpoint.

    Besides ``submit_prompt`` (the hook Vanna uses to query the LLM), this class
    adds a Dameng (DM) database connector (``connect_to_dameng``) and the
    project-specific ``generate_sql_2`` / ``generate_rewritten_question`` flows.
    """

    # Fallback models used by submit_prompt when neither the call kwargs nor the
    # config name a model/engine: prompts estimated above the threshold go to
    # the long-context model.
    _LONG_PROMPT_MODEL = "kimi"
    _SHORT_PROMPT_MODEL = "doubao"
    _LONG_PROMPT_TOKEN_THRESHOLD = 3500

    def __init__(self, client=None, config_file=None):
        """Initialise the LLM backend.

        Args:
            client: Optional pre-built OpenAI-compatible client. When given,
                ``api_base``/``api_key`` are not required.
            config_file: Settings dict. Keys read here: ``temperature``,
                ``max_tokens``, ``api_base``, ``api_key``. ``None`` is treated
                as an empty dict.

        Raises:
            Exception: if the deprecated ``api_type``/``api_version`` keys are
                present, or if no *client* is given and ``api_base`` or
                ``api_key`` is missing.
        """
        if config_file is None:
            config_file = {}
        VannaBase.__init__(self, config=config_file)

        # Defaults; can be overridden through config_file.
        self.temperature = config_file.get("temperature", 0.6)
        self.max_tokens = config_file.get("max_tokens", 10000)

        for deprecated_key in ("api_type", "api_version"):
            if deprecated_key in config_file:
                raise Exception(
                    f"Passing {deprecated_key} is now deprecated. Please pass an OpenAI client instead."
                )

        if client is not None:
            self.client = client
            return

        if "api_base" not in config_file:
            raise Exception("Please passing api_base")
        if "api_key" not in config_file:
            raise Exception("Please passing api_key")
        self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"])

    def system_message(self, message: str) -> dict:
        """Wrap *message* as a chat message with the ``system`` role."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap *message* as a chat message with the ``user`` role."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap *message* as a chat message with the ``assistant`` role."""
        return {"role": "assistant", "content": message}

    def connect_to_dameng(
        self,
        host: Optional[str] = None,
        dbname: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        port: Optional[int] = None,
        **kwargs
    ):
        """Connect to a Dameng database and register it as Vanna's ``run_sql``.

        ``dbname`` and extra ``kwargs`` are accepted for signature compatibility
        but are not passed to ``dmPython.connect``.

        Raises:
            Exception: if the initial connection fails (original error chained).
        """

        def _open_connection():
            return dmPython.connect(user=user, password=password, server=host, port=port)

        self.conn = None
        try:
            self.conn = _open_connection()
        except Exception as e:
            raise Exception(f"Failed to connect to dameng database: {e}") from e

        def is_connection_alive(conn) -> bool:
            """Probe *conn* with a trivial query; any failure counts as dead."""
            if conn is None:
                return False
            try:
                cursor = conn.cursor()
                try:
                    cursor.execute("SELECT 1 FROM DUAL")
                finally:
                    cursor.close()
                return True
            except Exception:
                return False

        def reconnect():
            """Replace ``self.conn`` with a fresh connection using the original credentials."""
            if self.conn is not None:
                # Best effort: the stale handle is being replaced; closing it just
                # avoids leaking it, and a failure here must not block reconnecting.
                try:
                    self.conn.close()
                except Exception:
                    logger.debug("closing stale dameng connection failed", exc_info=True)
            try:
                self.conn = _open_connection()
            except Exception as e:
                raise Exception(f"reconnect failed: {e}") from e

        def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
            """Execute *sql* and return all rows as a DataFrame (reconnecting if needed)."""
            logger.info("start to run_sql_damengsql")
            try:
                if not is_connection_alive(conn=self.conn):
                    logger.info("connection is not alive, reconnecting..........")
                    reconnect()
                cs = self.conn.cursor()
                try:
                    cs.execute(sql)
                    results = cs.fetchall()
                    # Read column names before the cursor is closed.
                    columns = [desc[0] for desc in cs.description]
                finally:
                    cs.close()
                return pd.DataFrame(results, columns=columns)
            except Exception as e:
                # Rollback is best effort: on a broken/absent connection it would
                # raise and hide the real error we are about to re-raise.
                try:
                    if self.conn is not None:
                        self.conn.rollback()
                except Exception:
                    logger.warning("rollback failed after sql error", exc_info=True)
                logger.error(f"Failed to execute sql query: {e}")
                raise

        self.run_sql_is_set = True
        self.run_sql = run_sql_damengsql

    def _select_model(self, num_tokens: float, kwargs: dict) -> str:
        """Pick the model name for one request.

        Priority: ``kwargs["model"]``, ``kwargs["engine"]``,
        ``self.config["engine"]``, ``self.config["model"]``, then a fallback
        chosen by the estimated prompt size.
        """
        if kwargs.get("model") is not None:
            return kwargs["model"]
        # The openai>=1.0 client's chat.completions.create() has no ``engine``
        # argument, so an engine/deployment name is sent as ``model``.
        if kwargs.get("engine") is not None:
            return kwargs["engine"]
        if self.config is not None and "engine" in self.config:
            return self.config["engine"]
        if self.config is not None and "model" in self.config:
            return self.config["model"]
        if num_tokens > self._LONG_PROMPT_TOKEN_THRESHOLD:
            return self._LONG_PROMPT_MODEL
        return self._SHORT_PROMPT_MODEL

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a chat *prompt* to the LLM and return the reply text.

        Args:
            prompt: List of ``{"role": ..., "content": ...}`` messages.
            **kwargs: May carry ``model`` or ``engine`` overrides.

        Returns:
            The message content of the first choice.

        Raises:
            Exception: if *prompt* is ``None`` or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Rough estimate: about four characters per token.
        num_tokens = sum(len(message["content"]) / 4 for message in prompt)
        model = self._select_model(num_tokens, kwargs)
        # Never log self.config here: it holds the API key.
        logger.info("Using model %s for %s tokens (approx)", model, num_tokens)
        logger.debug("prompt: %s", prompt)

        response = self.client.chat.completions.create(
            model=model,
            messages=prompt,
            max_tokens=self.max_tokens,
            stop=None,
            temperature=self.temperature,
        )
        return response.choices[0].message.content

    def generate_sql_2(self, question: str, cache=None, user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
        """Generate SQL (and optionally a chart spec) for *question*.

        Retrieves up to two similar question/SQL examples and the related DDL,
        fills the ``sql`` prompt template (including the user's cached history
        when *cache* and *user_id* are given) and asks the LLM. If a chart type
        is found in the answer, a second LLM call fills the ``chart`` template.

        Args:
            question: Natural-language question.
            cache: Optional cache exposing ``get(id=..., field=...)``; the
                ``"data"`` field stored for *user_id* is used as history.
            user_id: Key used to look up history in *cache*.
            allow_llm_to_see_data: Accepted for interface compatibility; unused.
            **kwargs: Forwarded to the vector-store lookups and ``submit_prompt``.

        Returns:
            ``{"resp": <JSON parsed from the SQL answer>}``, plus ``"chart"``
            with the parsed chart JSON when a chart type was detected.

        Raises:
            Any error from retrieval, the LLM or JSON parsing, after logging it
            with its traceback.
        """
        try:
            logger.info("Start to generate_sql_2 in cus_vanna_srevice")
            question_sql_list = self.get_similar_question_sql(question, **kwargs)
            if question_sql_list and len(question_sql_list) > 2:
                # Keep the prompt compact: only the two closest examples.
                question_sql_list = question_sql_list[:2]
            ddl_list = self.get_related_ddl(question, **kwargs)

            template = get_base_template()
            sql_temp = template['template']['sql']
            chart_temp = template['template']['chart']
            db_engine = config("DB_ENGINE", default='mysql')

            history = None
            if user_id and cache:
                history = cache.get(id=user_id, field="data")

            # -------- Generate SQL and chart type from the prompt template
            sys_temp = sql_temp['system'].format(
                engine=db_engine, lang='中文',
                schema=ddl_list, documentation=[train_ddl.train_document],
                history=history, retrieved_examples_data=question_sql_list,
                data_training=question_sql_list,
            )
            user_temp = sql_temp['user'].format(
                question=question,
                current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            )
            logger.info(f"user_temp:{user_temp}")
            logger.info(f"sys_temp:{sys_temp}")
            llm_response = self.submit_prompt(
                [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
            llm_response = str(llm_response.strip())
            logger.info(f"llm_response:{llm_response}")

            result = extract_nested_json(llm_response)
            logger.info(f"result:{result}")
            result = {"resp": orjson.loads(result)}
            sql = check_and_get_sql(llm_response)
            logger.info(f"sql:{sql}")

            # --------------- Generate the chart spec
            chart_type = get_chart_type_from_sql_answer(llm_response)
            logger.info(f"chart type:{chart_type}")
            if chart_type:
                sys_chart_temp = chart_temp['system'].format(
                    engine=db_engine, lang='中文', sql=sql, chart_type=chart_type)
                user_chart_temp = chart_temp['user'].format(sql=sql, chart_type=chart_type, question=question)
                chart_llm_response = self.submit_prompt(
                    [{'role': 'system', 'content': sys_chart_temp}, {'role': 'user', 'content': user_chart_temp}],
                    **kwargs)
                logger.info(f"chart llm_response:{chart_llm_response}")
                result['chart'] = orjson.loads(extract_nested_json(chart_llm_response))
                logger.info(f"chart_response:{result}")

            logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
            return result
        except Exception:
            logger.exception("cus_vanna_srevice failed-------------------: ")
            raise

    def run_sql_2(self, sql):
        """Run *sql* via the registered ``run_sql``; log the statement and re-raise on failure."""
        try:
            return self.run_sql(sql)
        except Exception:
            logger.error("run_sql failed {0}".format(sql))
            raise

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """Ask the LLM to merge *last_question* and *new_question* into one question.

        Returns *new_question* unchanged when there is no previous question.
        """
        logger.info(f"generate_rewritten_question---------------{new_question}")
        if last_question is None:
            return new_question
        logger.debug("last question {0}".format(last_question))
        logger.debug("new question {0}".format(new_question))
        prompt = [
            self.system_message(
                "你的目标是将一系列相关问题合并成一个单一的问题。"
                "合并准则一、如果第二个问题与第一个问题无关且本身是完整独立的,则直接返回第二个问题。"
                "合并准则二、如果第二个问题与第一个问题相关,且要基于第一个问题的前提,请合并两个问题为一个问题,只需返回合并后的新问题,不要添加任何额外解释。"
                "合并准则三、理论上合并后的问题应该能够通过单个SQL语句来回答"),
            self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
        ]
        return self.submit_prompt(prompt=prompt, **kwargs)
class CustomQdrant_VectorStore(Qdrant_VectorStore):
    """Qdrant vector store whose embeddings come from an external HTTP service.

    The endpoint, model name and API key are read from the
    ``EMBEDDING_MODEL_BASE_URL``, ``EMBEDDING_MODEL_NAME`` and
    ``EMBEDDING_MODEL_API_KEY`` settings. ``EMBEDDING_REQUEST_TIMEOUT``
    (seconds, default 60) bounds each embedding request.
    """

    def __init__(
        self,
        config_file=None
    ):
        """Read embedding settings, then initialise the Qdrant store.

        Args:
            config_file: Settings forwarded to ``Qdrant_VectorStore``. ``None``
                becomes a fresh ``{}`` so instances never share one dict.
        """
        # Set before super().__init__(): presumably the parent constructor may
        # already need embeddings (e.g. to size collections) -- confirm.
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        self.embedding_timeout = config('EMBEDDING_REQUEST_TIMEOUT', default=60, cast=float)
        super().__init__(config_file if config_file is not None else {})

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Return the embedding vector for *data* from the embedding service.

        The request body is ``{"model": ..., "sentences": [data]}`` updated
        with *kwargs*; the first entry of the response's ``embeddings`` list
        is returned.

        Raises:
            RuntimeError: on a non-200 response or an empty ``embeddings`` list.
            requests.RequestException: on network errors, including timeouts.
        """

        def _get_error_string(response: requests.Response) -> str:
            """Best-effort human-readable error from a failed response."""
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "sentences": [data],
        }
        request_body.update(kwargs)
        response = requests.post(
            url=f"{self.embedding_api_base}",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
            # Without a timeout a stalled embedding service blocks the caller forever.
            timeout=getattr(self, "embedding_timeout", 60),
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )
        embeddings = response.json()['embeddings']
        if not embeddings:
            raise RuntimeError("Failed to create the embeddings, detail: empty 'embeddings' in response")
        return embeddings[0]
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Project Vanna: custom Qdrant retrieval combined with an OpenAI-compatible LLM."""

    def __init__(self, llm_config=None, vector_store_config=None):
        """Initialise both bases, each with its own settings dict.

        Args:
            llm_config: Settings for ``OpenAICompatibleLLM`` (``None`` -> ``{}``).
            vector_store_config: Settings for ``CustomQdrant_VectorStore``
                (``None`` -> ``{}``).
        """
        # Bases are initialised explicitly (not via cooperative super()) so each
        # gets its own config. Order matters: OpenAICompatibleLLM.__init__ runs
        # last and calls VannaBase.__init__(config=llm_config), so LLM settings
        # win for shared base-class state (presumably self.config -- confirm).
        CustomQdrant_VectorStore.__init__(
            self, config_file={} if vector_store_config is None else vector_store_config
        )
        OpenAICompatibleLLM.__init__(
            self, config_file={} if llm_config is None else llm_config
        )
class TTLCacheWrapper:
    """Add per-entry TTLs to a cache (``MemoryCache`` by default) to stop unbounded growth.

    Expiry is tracked per ``(id, field)`` pair. The wrapped cache is only ever
    asked to ``delete(id=...)``, so when one field expires every field stored
    under that id is dropped from the wrapped cache. A daemon thread sweeps
    expired entries every ``cleanup_interval`` seconds, and reads also check
    expiry. The bookkeeping dicts are shared between that thread and callers,
    so every access goes through ``self._lock``. Attributes not defined here are
    delegated to the wrapped cache.
    """

    def __init__(self, cache: Optional["Cache"] = None, ttl: int = 3600, cleanup_interval: float = 180):
        """
        Args:
            cache: Cache to wrap; a new ``MemoryCache`` when ``None``.
            ttl: Default time-to-live in seconds used by ``set``.
            cleanup_interval: Seconds between background sweeps (default 180).
        """
        # ``is not None`` so an empty-but-valid cache (falsy via __len__) is kept.
        self.cache = cache if cache is not None else MemoryCache()
        self.ttl = ttl
        self.cleanup_interval = cleanup_interval
        # (id, field) -> absolute expiry time (time.time() based).
        self._expiry_times = {}
        # (id, field) -> (set time, set sequence number); the sequence number
        # orders sets that land in the same clock tick.
        self._created_times = {}
        self._set_seq = 0
        self._lock = threading.RLock()
        self._stop_event = threading.Event()
        self._cleanup_thread = None
        self._start_cleanup()

    def _start_cleanup(self):
        """Start the daemon thread that periodically purges expired entries."""

        def cleanup():
            # Event.wait() doubles as the sleep and lets stop_cleanup() end the loop.
            while not self._stop_event.wait(self.cleanup_interval):
                try:
                    self._purge_expired()
                except Exception:
                    # Keep sweeping: a dead sweeper would silently stop eviction.
                    logger.exception("TTLCacheWrapper cleanup failed")

        self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
        self._cleanup_thread.start()

    def stop_cleanup(self, timeout: Optional[float] = None):
        """Stop the background sweeper and wait up to *timeout* seconds for it."""
        self._stop_event.set()
        if self._cleanup_thread is not None:
            self._cleanup_thread.join(timeout)

    def _evict(self, key_info):
        """Drop *key_info*'s id from the wrapped cache and forget stale TTL records.

        Caller must hold ``self._lock``.
        """
        cache_id = key_info[0]
        if hasattr(self.cache, 'delete'):
            self.cache.delete(id=cache_id)
            # The whole id is gone, so every record under it is now stale.
            stale = [key for key in self._expiry_times if key[0] == cache_id]
        else:
            # Nothing was removed from the wrapped cache; only forget this key.
            stale = [key_info]
        for key in stale:
            self._expiry_times.pop(key, None)
            self._created_times.pop(key, None)

    def _purge_expired(self):
        """Evict every entry whose expiry time has passed."""
        now = time.time()
        with self._lock:
            expired = [key for key, expiry in self._expiry_times.items() if expiry <= now]
            for key_info in expired:
                # An earlier eviction of the same id may already have removed it.
                if key_info in self._expiry_times:
                    self._evict(key_info)

    def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
        """Store *value* under ``(id, field)``, expiring after *ttl* (default ``self.ttl``) seconds."""
        actual_ttl = ttl if ttl is not None else self.ttl
        key_info = (id, field)
        with self._lock:
            self.cache.set(id=id, field=field, value=value)
            now = time.time()
            self._set_seq += 1
            self._created_times[key_info] = (now, self._set_seq)
            self._expiry_times[key_info] = now + actual_ttl

    def get(self, id: str, field: str) -> Any:
        """Return the cached value, or ``None`` if missing or expired (expired entries are evicted)."""
        key_info = (id, field)
        with self._lock:
            expiry = self._expiry_times.get(key_info)
            if expiry is not None and time.time() > expiry:
                self._evict(key_info)
                return None
            return self.cache.get(id=id, field=field)

    def delete(self, id: str, field: str):
        """Delete ``(id, field)``; with a ``delete``-capable cache the whole id is removed."""
        with self._lock:
            self._evict((id, field))

    def exists(self, id: str, field: str) -> bool:
        """Return True if ``(id, field)`` holds a non-None, unexpired value."""
        key_info = (id, field)
        with self._lock:
            expiry = self._expiry_times.get(key_info)
            if expiry is not None and time.time() > expiry:
                self._evict(key_info)
                return False
        return self.get(id=id, field=field) is not None

    def items(self):
        """Return a list of dicts describing every unexpired, non-None entry."""
        current_time = time.time()
        with self._lock:
            # Snapshot so concurrent set()/eviction cannot break the iteration.
            snapshot = list(self._expiry_times.items())
        live_items = []
        for (cache_id, field), expiry in snapshot:
            if current_time > expiry:
                continue
            value = self.get(id=cache_id, field=field)
            if value is not None:
                live_items.append({
                    'id': cache_id,
                    'field': field,
                    'value': value,
                    'expires_at': expiry,
                    'time_left': expiry - current_time
                })
        return live_items

    def get_latest_by_id(self, id: str, limit: int = 1, field_filter: Optional[str] = None):
        """Return the most recently set, unexpired entries stored under *id*.

        Args:
            id: Cache id to query.
            limit: Maximum number of entries to return (default 1).
            field_filter: If given, only entries with this field are considered.

        Returns:
            Up to *limit* dicts (keys ``id``, ``field``, ``value``,
            ``expires_at``, ``created_time``, ``time_left``), newest first.
        """
        current_time = time.time()
        with self._lock:
            candidates = [
                (field, expiry, self._created_times.get((cache_id, field), (expiry - self.ttl, 0)))
                for (cache_id, field), expiry in self._expiry_times.items()
                if cache_id == id
                and current_time <= expiry
                and (not field_filter or field == field_filter)
            ]
        ranked = []
        for field, expiry, created in candidates:
            value = self.get(id=id, field=field)
            if value is not None:
                ranked.append((created, {
                    'id': id,
                    'field': field,
                    'value': value,
                    'expires_at': expiry,
                    'created_time': created[0],
                    'time_left': expiry - current_time
                }))
        # Order by set time, not expiry: per-call TTLs make the two disagree.
        ranked.sort(key=lambda pair: pair[0], reverse=True)
        return [item for _, item in ranked[:limit]]

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped cache."""
        # Without this guard a missing ``cache`` (e.g. object built via __new__
        # or mid-unpickling) would recurse through __getattr__ forever.
        if name == "cache":
            raise AttributeError(name)
        return getattr(self.cache, name)