# Files
# sqlbot_agent/service/cus_vanna_srevice.py
#
# 417 lines
# 16 KiB
# Python
# Raw Normal View History
#
# 2025-10-16 15:00:38 +08:00
from typing import List, Union, Any, Optional
import time
import threading
from vanna.flask import Cache, MemoryCache
import dmPython
2025-09-23 14:49:00 +08:00
import orjson
import pandas as pd
2025-09-23 14:49:00 +08:00
from vanna.base import VannaBase
from vanna.flask import MemoryCache
2025-09-23 14:49:00 +08:00
from vanna.qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from openai import OpenAI
import requests
from decouple import config
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
import json
from template.template import get_base_template
from datetime import datetime
2025-09-25 08:35:49 +08:00
import logging
from util import train_ddl
2025-09-25 08:35:49 +08:00
logger = logging.getLogger(__name__)
2025-09-28 16:44:58 +08:00
import traceback
2025-09-23 14:49:00 +08:00
class OpenAICompatibleLLM(VannaBase):
    """Vanna LLM backend for any OpenAI-compatible chat-completions endpoint.

    Besides prompt submission it carries a DM (Dameng) database connector
    (``connect_to_dameng``) and the project's SQL + chart generation flow
    (``generate_sql_2``).
    """

    def __init__(self, client=None, config_file=None):
        """Set sampling defaults and the OpenAI client.

        Args:
            client: a ready ``OpenAI`` client; when given, ``api_base`` and
                ``api_key`` are not required in *config_file*.
            config_file: LLM settings dict (``temperature``, ``max_tokens``,
                ``api_base``, ``api_key``, ``model``/``engine``); also handed to
                ``VannaBase`` as its config. ``None`` is treated as empty.

        Raises:
            Exception: if the deprecated ``api_type``/``api_version`` keys are
                present, or if no *client* is given and ``api_base``/``api_key``
                is missing.
        """
        VannaBase.__init__(self, config=config_file)
        # `"x" in None` raises TypeError, so run the membership checks against
        # an empty dict when no config was supplied.
        cfg = config_file if config_file is not None else {}

        # Defaults, overridable through the config.
        self.temperature = cfg["temperature"] if "temperature" in cfg else 0.6
        self.max_tokens = cfg["max_tokens"] if "max_tokens" in cfg else 5000

        if "api_type" in cfg:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )
        if "api_version" in cfg:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )
        if client is not None:
            self.client = client
            return
        if "api_base" not in cfg:
            raise Exception("Please passing api_base")
        if "api_key" not in cfg:
            raise Exception("Please passing api_key")
        self.client = OpenAI(api_key=cfg["api_key"], base_url=cfg["api_base"])

    def system_message(self, message: str) -> dict:
        """Wrap *message* as a chat message with role ``system``."""
        return {"role": "system", "content": message}

    def connect_to_dameng(
        self,
        host: str = None,
        dbname: str = None,
        user: str = None,
        password: str = None,
        port: int = None,
        **kwargs
    ):
        """Connect to a DM (Dameng) database and install ``self.run_sql``.

        Only *user*, *password*, *host* and *port* are forwarded to
        ``dmPython.connect``; *dbname* and *kwargs* are accepted but unused.
        The installed ``run_sql`` re-checks the connection before every query
        and reconnects with the same credentials when it is no longer alive.

        Raises:
            Exception: if the initial connection fails (original error chained).
        """
        self.conn = None
        try:
            self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
        except Exception as e:
            raise Exception(f"Failed to connect to dameng database: {e}") from e

        def _close_quietly(resource):
            """Best-effort ``close()``; errors are ignored because cleanup must not mask the real outcome."""
            if resource is None:
                return
            try:
                resource.close()
            except Exception:
                pass

        def is_connection_alive(conn) -> bool:
            """Return True if *conn* can run a trivial query, False otherwise."""
            if conn is None:
                return False
            cursor = None
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT 1 FROM DUAL")
                return True
            except Exception:
                return False
            finally:
                _close_quietly(cursor)

        def reconnect():
            """Replace ``self.conn`` with a fresh connection using the original credentials."""
            # Release the stale handle first; leave None behind so a failed
            # reconnect is detected as "not alive" on the next query.
            _close_quietly(self.conn)
            self.conn = None
            try:
                self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
            except Exception as e:
                raise Exception(f"reconnect failed: {e}") from e

        def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
            """Execute *sql* and return all rows as a DataFrame (columns from the cursor description).

            On failure the transaction is rolled back (best effort), the error is
            logged and re-raised unchanged.
            """
            logger.info("start to run_sql_damengsql")
            cs = None
            try:
                if not is_connection_alive(conn=self.conn):
                    logger.info("connection is not alive, reconnecting..........")
                    reconnect()
                cs = self.conn.cursor()
                cs.execute(sql)
                results = cs.fetchall()
                return pd.DataFrame(
                    results, columns=[desc[0] for desc in cs.description]
                )
            except Exception as e:
                logger.error(f"Failed to execute sql query: {e}")
                # A failing rollback (e.g. no live connection) must not hide
                # the original error.
                if self.conn is not None:
                    try:
                        self.conn.rollback()
                    except Exception:
                        logger.warning("rollback failed after sql error", exc_info=True)
                raise
            finally:
                _close_quietly(cs)

        self.run_sql_is_set = True
        self.run_sql = run_sql_damengsql

    def user_message(self, message: str) -> dict:
        """Wrap *message* as a chat message with role ``user``."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap *message* as a chat message with role ``assistant``."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send *prompt* (a list of chat messages) and return the reply text.

        Target selection, first match wins: ``kwargs["model"]``,
        ``kwargs["engine"]``, ``self.config["engine"]``, ``self.config["model"]``;
        otherwise ``"kimi"`` for prompts estimated above 3500 tokens and
        ``"doubao"`` for smaller ones.

        Raises:
            Exception: if *prompt* is ``None`` or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")
        logger.debug(f"prompt:{prompt}")

        # Rough size estimate: ~4 characters per token.
        num_tokens = sum(len(message["content"] or "") for message in prompt) / 4

        if kwargs.get("model", None) is not None:
            target = {"model": kwargs["model"]}
            logger.info(f"Using model {target['model']} for {num_tokens} tokens (approx)")
        elif kwargs.get("engine", None) is not None:
            target = {"engine": kwargs["engine"]}
            logger.info(f"Using model {target['engine']} for {num_tokens} tokens (approx)")
        elif self.config is not None and "engine" in self.config:
            target = {"engine": self.config["engine"]}
            logger.info(f"Using engine {target['engine']} for {num_tokens} tokens (approx)")
        elif self.config is not None and "model" in self.config:
            # Do not log self.config itself: it contains the API key.
            target = {"model": self.config["model"]}
            logger.info(f"Using model {target['model']} for {num_tokens} tokens (approx)")
        else:
            target = {"model": "kimi" if num_tokens > 3500 else "doubao"}
            logger.info(f"5.Using model {target['model']} for {num_tokens} tokens (approx)")

        # NOTE(review): `engine=` is forwarded exactly as before; the
        # openai>=1.0 client's create() may reject it — confirm before relying
        # on the engine branches.
        response = self.client.chat.completions.create(
            messages=prompt,
            max_tokens=self.max_tokens,
            stop=None,
            temperature=self.temperature,
            **target,
        )
        # Completion-style choices expose `.text`; chat choices use `.message.content`.
        for choice in response.choices:
            if "text" in choice:
                return choice.text
        return response.choices[0].message.content

    def generate_sql_2(self, question: str, cache=None, user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
        """Generate SQL (and, when applicable, a chart spec) for *question*.

        Retrieves similar question/SQL examples (at most two) and related DDL,
        fills the ``sql`` prompt template (including the cached ``"data"``
        history for *user_id* when *cache* and *user_id* are given) and parses
        the JSON embedded in the LLM reply. If a chart type can be derived from
        the reply, a second prompt built from the ``chart`` template is sent.
        *allow_llm_to_see_data* is accepted for interface compatibility but unused.

        Returns:
            ``{"resp": <parsed SQL answer>}`` plus ``"chart": <parsed chart answer>``
            when a chart type was found.

        Raises:
            Any error from retrieval, the LLM calls or JSON parsing; it is logged
            with its traceback and re-raised unchanged.
        """
        try:
            logger.info("Start to generate_sql_2 in cus_vanna_srevice")
            question_sql_list = self.get_similar_question_sql(question, **kwargs)
            if question_sql_list:
                # Keep only the two most similar examples to bound prompt size.
                question_sql_list = question_sql_list[:2]
            ddl_list = self.get_related_ddl(question, **kwargs)

            template = get_base_template()
            sql_temp = template['template']['sql']
            char_temp = template['template']['chart']
            db_engine = config("DB_ENGINE", default='mysql')

            history = None
            if user_id and cache is not None:
                history = cache.get(id=user_id, field="data")

            # -------- build prompts to generate the SQL and chart type
            sys_temp = sql_temp['system'].format(engine=db_engine, lang='中文',
                                                 schema=ddl_list, documentation=[train_ddl.train_document],
                                                 history=history, retrieved_examples_data=question_sql_list,
                                                 data_training=question_sql_list,)
            user_temp = sql_temp['user'].format(question=question,
                                                current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
            logger.info(f"user_temp:{user_temp}")
            logger.info(f"sys_temp:{sys_temp}")

            llm_response = self.submit_prompt(
                [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
            llm_response = str(llm_response.strip())
            logger.info(f"llm_response:{llm_response}")

            result = extract_nested_json(llm_response)
            logger.info(f"result:{result}")
            result = {"resp": orjson.loads(result)}

            sql = check_and_get_sql(llm_response)
            logger.info(f"sql:{sql}")

            # --------------- generate the chart
            char_type = get_chart_type_from_sql_answer(llm_response)
            logger.info(f"chart type:{char_type}")
            if char_type:
                sys_char_temp = char_temp['system'].format(engine=db_engine,
                                                           lang='中文', sql=sql, chart_type=char_type)
                user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
                llm_response2 = self.submit_prompt(
                    [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}],
                    **kwargs)
                logger.info(f"llm_response2:{llm_response2}")
                result['chart'] = orjson.loads(extract_nested_json(llm_response2))
                logger.info(f"chart_response:{result}")

            logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
            return result
        except Exception:
            logger.exception("cus_vanna_srevice failed-------------------: ")
            raise

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """Merge *new_question* with *last_question* into one self-contained question via the LLM.

        Returns *new_question* unchanged when there is no previous question.
        Both questions are sent; the system prompt tells the model to return
        the second one as-is when it is unrelated to the first.
        """
        logger.info(f"generate_rewritten_question---------------{new_question}")
        if last_question is None:
            return new_question
        prompt = [
            self.system_message(
                "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
            # The model can only combine questions it is shown, so include both.
            self.user_message(
                "First question: " + last_question + "\nSecond question: " + new_question),
        ]
        return self.submit_prompt(prompt=prompt, **kwargs)
class CustomQdrant_VectorStore(Qdrant_VectorStore):
    """Qdrant vector store whose embeddings come from an external HTTP embedding service.

    The service URL, model name, bearer token and request timeout are read from
    the ``EMBEDDING_MODEL_BASE_URL``, ``EMBEDDING_MODEL_NAME``,
    ``EMBEDDING_MODEL_API_KEY`` and ``EMBEDDING_MODEL_TIMEOUT`` settings via
    ``decouple.config``.
    """

    def __init__(
        self,
        config_file=None
    ):
        """Read the embedding settings, then initialise the Qdrant store.

        Args:
            config_file: configuration passed to ``Qdrant_VectorStore``;
                ``None`` is replaced by a fresh empty dict.
        """
        # Set before super().__init__() so they are available if the base
        # initialiser already needs to embed something.
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        # Seconds to wait for the embedding service; without a timeout a stuck
        # service blocks the caller forever.
        self.embedding_timeout = config('EMBEDDING_MODEL_TIMEOUT', default=60, cast=float)
        # A fresh dict per instance avoids the shared mutable default and lets
        # callers (e.g. CustomVanna) pass None.
        super().__init__({} if config_file is None else config_file)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed *data* through the configured HTTP embedding service.

        POSTs ``{"model": <name>, "sentences": [data]}`` (updated with any extra
        *kwargs*) and returns the first entry of the ``embeddings`` list in the
        JSON reply.

        Raises:
            RuntimeError: on a non-200 status or a reply without embeddings.
        """
        def _get_error_string(response: requests.Response) -> str:
            """Best-effort human-readable error from a failed response."""
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "sentences": [data],
        }
        request_body.update(kwargs)
        response = requests.post(
            url=self.embedding_api_base,
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
            timeout=self.embedding_timeout,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )
        try:
            embeddings = response.json()['embeddings']
        except (ValueError, KeyError, TypeError) as e:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: unexpected response body: {response.text[:200]}"
            ) from e
        if not embeddings:
            raise RuntimeError("Failed to create the embeddings, detail: empty 'embeddings' list")
        return embeddings[0]
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Project Vanna instance: custom Qdrant vector store + OpenAI-compatible LLM.

    Each base class is initialised explicitly, vector store first and LLM
    second. ``OpenAICompatibleLLM.__init__`` calls ``VannaBase.__init__`` with
    *llm_config*, so running it last presumably leaves ``self.config`` as the
    LLM config (confirm the vector-store init sets nothing afterwards).
    """

    def __init__(self, llm_config=None, vector_store_config=None):
        init_vector_store = CustomQdrant_VectorStore.__init__
        init_llm = OpenAICompatibleLLM.__init__
        # Order matters: the LLM initialiser must run after the store's.
        init_vector_store(self, config_file=vector_store_config)
        init_llm(self, config_file=llm_config)
class TTLCacheWrapper:
    """Wrap a cache (vanna ``MemoryCache`` by default) with TTL expiry to prevent memory leaks.

    Expiry times are tracked per ``(id, field)``. Expired entries are removed
    lazily by ``get``/``exists`` and periodically by a background daemon thread.
    The wrapped cache's ``delete`` is called with ``id`` only, which removes
    every field of that id, so the expiry records of the whole id are dropped
    together with it. Attributes not defined here are proxied to the wrapped
    cache.
    """

    def __init__(self, cache: "Optional[Cache]" = None, ttl: int = 3600, cleanup_interval: float = 180):
        """
        Args:
            cache: cache to wrap; a new ``MemoryCache`` is created when ``None``.
            ttl: default time-to-live in seconds for entries written via ``set``.
            cleanup_interval: seconds between background sweeps (default 3 minutes).
        """
        self.cache = cache if cache is not None else MemoryCache()
        self.ttl = ttl
        self.cleanup_interval = cleanup_interval
        # id -> {field: absolute expiry timestamp (time.time() based)}
        self._expiry_times = {}
        # Serialises access to _expiry_times and keeps cache writes/evictions
        # atomic with respect to the sweeper thread. Re-entrant because the
        # eviction helper is called with the lock already held.
        self._lock = threading.RLock()
        self._cleanup_thread = None
        self._start_cleanup()

    def _start_cleanup(self):
        """Start the daemon thread that sweeps expired entries every ``cleanup_interval`` seconds."""
        def cleanup():
            while True:
                try:
                    self._purge_expired()
                except Exception:
                    # Keep the sweeper alive: if it died, expired entries would
                    # accumulate again unnoticed.
                    logging.getLogger(__name__).exception("TTLCacheWrapper cleanup failed")
                time.sleep(self.cleanup_interval)

        self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
        self._cleanup_thread.start()

    def _purge_expired(self, now: Optional[float] = None) -> None:
        """Evict every id that has at least one expired field (*now* defaults to ``time.time()``)."""
        now = time.time() if now is None else now
        with self._lock:
            # Collect first: evicting mutates _expiry_times.
            expired_ids = [
                key_id for key_id, fields in self._expiry_times.items()
                if any(expiry <= now for expiry in fields.values())
            ]
            for key_id in expired_ids:
                self._evict_id(key_id)

    def _evict_id(self, id: str) -> None:
        """Delete *id* from the wrapped cache (if it supports ``delete``) and forget all its expiry records."""
        with self._lock:
            if hasattr(self.cache, 'delete'):
                self.cache.delete(id=id)
            # All fields of this id are gone from the cache, so every record
            # for it is stale; leaving them would later delete fresh data.
            self._expiry_times.pop(id, None)

    def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
        """Store *value* under ``(id, field)`` expiring after *ttl* seconds (default: ``self.ttl``)."""
        actual_ttl = ttl if ttl is not None else self.ttl
        with self._lock:
            self.cache.set(id=id, field=field, value=value)
            self._expiry_times.setdefault(id, {})[field] = time.time() + actual_ttl

    def get(self, id: str, field: str) -> Any:
        """Return the cached value, or ``None`` if it has expired (the id is then evicted)."""
        with self._lock:
            expiry = self._expiry_times.get(id, {}).get(field)
            if expiry is not None and time.time() > expiry:
                self._evict_id(id)
                return None
            return self.cache.get(id=id, field=field)

    def delete(self, id: str, field: str):
        """Delete a cached entry.

        The wrapped cache only deletes by ``id``, so every field of *id* is
        removed, not just *field*.
        """
        self._evict_id(id)

    def exists(self, id: str, field: str) -> bool:
        """Return True if ``(id, field)`` holds a non-``None`` value that has not expired."""
        return self.get(id=id, field=field) is not None

    # Proxy every other attribute to the wrapped cache.
    def __getattr__(self, name):
        # __getattr__ only runs for missing attributes; if `cache` itself is
        # missing (e.g. instance built without __init__) avoid infinite recursion.
        if name == "cache":
            raise AttributeError(name)
        return getattr(self.cache, name)