Files
sqlbot_agent/service/cus_vanna_srevice.py
2025-11-07 15:30:27 +08:00

664 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List, Union, Any, Optional
import time
import threading
from vanna.flask import Cache, MemoryCache
import dmPython
import orjson
import pandas as pd
from vanna.base import VannaBase
from vanna.flask import MemoryCache
from vanna.qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from openai import OpenAI
import requests
from decouple import config
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
import json
from template.template import get_base_template
from datetime import datetime
import logging
from util import train_ddl
logger = logging.getLogger(__name__)
import traceback
from service.conversation_service import get_latest_question,update_conversation
limit_count = 3
class OpenAICompatibleLLM(VannaBase):
def __init__(self, client=None, config_file=None):
VannaBase.__init__(self, config=config_file)
# default parameters - can be overrided using config
self.temperature = 0.6
self.max_tokens = 10000
if "temperature" in config_file:
self.temperature = config_file["temperature"]
if "max_tokens" in config_file:
self.max_tokens = config_file["max_tokens"]
if "api_type" in config_file:
raise Exception(
"Passing api_type is now deprecated. Please pass an OpenAI client instead."
)
if "api_version" in config_file:
raise Exception(
"Passing api_version is now deprecated. Please pass an OpenAI client instead."
)
if client is not None:
self.client = client
return
if "api_base" not in config_file:
raise Exception("Please passing api_base")
if "api_key" not in config_file:
raise Exception("Please passing api_key")
self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"])
def system_message(self, message: str) -> any:
return {"role": "system", "content": message}
def connect_to_dameng(
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
**kwargs
):
self.conn = None
try:
self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
except Exception as e:
raise Exception(f"Failed to connect to dameng database: {e}")
def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
logger.info(f"start to run_sql_damengsql")
try:
if not is_connection_alive(conn=self.conn):
logger.info("connection is not alive, reconnecting..........")
reconnect()
# conn.ping(reconnect=True)
cs = self.conn.cursor()
cs.execute(sql)
results = cs.fetchall()
# Create a pandas dataframe from the results
df = pd.DataFrame(
results, columns=[desc[0] for desc in cs.description]
)
return df
except Exception as e:
self.conn.rollback()
logger.error(f"Failed to execute sql query: {e}")
raise e
return None
def reconnect():
try:
self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
except Exception as e:
raise Exception(f"reconnect failed: {e}")
def is_connection_alive(conn) -> bool:
if conn is None:
return False
try:
cursor = conn.cursor()
cursor.execute("SELECT 1 FROM DUAL")
cursor.close()
return True
except Exception as e:
return False
self.run_sql_is_set = True
self.run_sql = run_sql_damengsql
def user_message(self, message: str) -> any:
return {"role": "user", "content": message}
def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message}
def submit_prompt(self, prompt, **kwargs) -> str:
logger.info(f"submit prompt: {prompt}")
if prompt is None:
logger.info("test1")
raise Exception("Prompt is None")
if len(prompt) == 0:
logger.info("test2")
raise Exception("Prompt is empty")
print(prompt)
num_tokens = 0
for message in prompt:
num_tokens += len(message["content"]) / 4
logger.info("test3 {0}".format(num_tokens))
if kwargs.get("model", None) is not None:
logger.info("test4")
model = kwargs.get("model", None)
logger.info(
f"Using model {model} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
model=model,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif kwargs.get("engine", None) is not None:
logger.info("test5")
engine = kwargs.get("engine", None)
logger.info(
f"Using model {engine} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
engine=engine,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif self.config is not None and "engine" in self.config:
logger.info("test6")
logger.info(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
engine=self.config["engine"],
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif self.config is not None and "model" in self.config:
logger.info("test7")
logger.info(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
logger.info("config is ",self.config)
response = self.client.chat.completions.create(
model=self.config["model"],
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
# data = {
# "model": self.model,
# "prompt": prompt,
# "max_tokens": self.max_tokens,
# "temperature": self.temperature,
# }
# response = requests.post(
# url=f"{self.api_base}/completions",
# headers=self.headers,
# json=data
# )
else:
logger.info("test8")
if num_tokens > 3500:
model = "kimi"
else:
model = "doubao"
logger.info(f"5.Using model {model} for {num_tokens} tokens (approx)")
response = self.client.chat.completions.create(
model=model,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
for choice in response.choices:
if "text" in choice:
return choice.text
return response.choices[0].message.content
def generate_sql_2(self, user_id: str, cvs_id: str, question: str,id: str, need_context: bool, allow_llm_to_see_data=False, **kwargs) -> dict:
try:
logger.info("Start to generate_sql_2 in cus_vanna_srevice")
question_sql_list = self.get_similar_question_sql(question, **kwargs)
if question_sql_list and len(question_sql_list)>2:
question_sql_list=question_sql_list[:2]
ddl_list = self.get_related_ddl(question, **kwargs)
#doc_list = self.get_related_documentation(question, **kwargs)
template = get_base_template()
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
history = None
if need_context:
questions = get_latest_question(cvs_id, user_id,limit_count)
logger.info(f"latest_questions is {questions}")
if questions[0] != question:
raise Exception(f"上下文不匹配 {question} {questions[0]}")
new_question = self.generate_rewritten_question(questions,**kwargs)
logger.info(f"new_question is {new_question}")
question = new_question if new_question else question
update_conversation(id, meta=question)
# if user_id and cache:
# history = cache.get(id=user_id, field="data")
# --------基于提示词生成sql以及图表类型
sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
schema=ddl_list, documentation=[train_ddl.train_document],
history=history,retrieved_examples_data=question_sql_list,
data_training=question_sql_list,)
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
logger.info(f"user_temp:{user_temp}")
logger.info(f"sys_temp:{sys_temp}")
llm_response = self.submit_prompt(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
llm_response = str(llm_response.strip())
logger.info(f"llm_response:{llm_response}")
#优化中.......
result = extract_nested_json(llm_response)
logger.info(f"result:{result}")
result = {"resp": orjson.loads(result)}
logger.info(f"llm_response:{llm_response}")
sql = check_and_get_sql(llm_response)
logger.info(f"sql:{sql}")
# ---------------生成图表
char_type = get_chart_type_from_sql_answer(llm_response)
logger.info(f"chart type:{char_type}")
if char_type:
sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'),
lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
llm_response2 = self.submit_prompt(
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}],
**kwargs)
print(llm_response2)
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
logger.info(f"chart_response:{result}")
logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
return result
except Exception as e:
logger.info("cus_vanna_srevice failed-------------------: ")
traceback.print_exc()
raise e
def run_sql_2(self,sql):
try:
return self.run_sql(sql)
except Exception as e:
logger.error("run_sql failed {0}".format(sql))
raise e
def generate_rewritten_question(self, questions: str, **kwargs) -> str:
logger.info(f"generate_rewritten_question---------------{questions}")
new_question = questions[0]
context_question = questions[1:]
if not context_question:
return new_question
print("last question {0}".format(context_question))
print("new question {0}".format(new_question))
# sys_info = '''
# 你是一个问题补全助手先判断问题1是否存在信息不完整的情况如果不完整则根据上下文问题2问题3来补全问题1
# 按时间顺序从新到旧问题1、问题2、问题3问题1是用户当前提出的问题
#
# 【准则一】独立性优先
# 如果问题1本身含义完整不依赖其他问题的上下文也能被理解则直接返回问题1,禁止强行合并。
# 【准则二】最新问题优先
# 问题1始终作为核心只判断它是否需要利用前序问题补充自身信息当它含义完整时不再考虑合并。
# 合并时只能用较旧的问题问题2、问题3的信息来补全较新的问题问题1不能反向操作。
# 要以问题1的中心思想为准禁止合并后该表问题1的中心思想
# 【准则三】单向合并限制
# 只有问题1能与其他问题合并问题2和问题3之间不能单独合并。
# 【准则四】顺序依赖判断
# 只有当问题1明确依赖问题2或问题3的结果或上下文时才考虑合并。依赖特征包括
# - 问题1中包含"其中"、"这些"、"这个"、"那些"、"他"、"他们"等指代性词语
# - 问题1是在问题2或问题3基础上的细节追问
# - 问题1具有明确的顺序逻辑关系
# - 问题1缺少必要的主语或时间范围或具体查询信息等上下文信息
# 【准则五】合并范围选择
# - 如果问题1只依赖问题2则合并问题2+问题1
# - 如果问题1只依赖问题3则合并问题3+问题1
# - 如果问题1依赖问题2可得出结果依赖问题3也可得出结果则就近原则合并问题2+问题1
# - 如果问题1同时依赖问题2和问题3才能得出结果则合并问题3+问题2+问题1
# - 如果问题1不依赖任何其他问题则直接返回问题1
#
# 【准则六】合并执行原则
# 将选择的问题自然衔接成一个完整问题,不要添加任何解释性文字。
#
# 【准则七】SQL可行性验证
# 合并后的问题应该能够通过一条SQL查询语句来回答。
#
# 【准则八】兜底措施
# 当你无法判断问题一是否完整,也无法判断问题一是否依赖其他问题才能补全信息时,请直接向用户询问细节
#
#
# 示例:
# 输入问题1="早退多少天"问题2="其中迟到多少天"问题3="张三九月工作了多少天"
# 输出:"张三九月早退多少天"
#
# 输入问题1="这些天里张三是否有迟到问题2="李四考勤""问题3="张三九月的考勤"
# 输出:"张三九月的迟到情况"
#
# 输入问题1="最近一个月是否有迟到"问题2="李四考勤"问题3="张三九月的考勤"
# 输出:"张三最近一个月是否有迟到"
#
# 输入问题1="迟到了多少天"问题2="他哪几天迟到了"问题3="张三九月在林芝是否早退"
# 输出:"张三九月在林芝迟到了多少天"
#
# 输入问题1="张三九月休息了多少天" 问题2="张三九月迟到了多少天"问题3="张三其中迟到多少天"
# 输出:"张三九月休息了多少天"
#
# 输入问题1="张三九月考勤情况" 问题2="张三九月迟到了多少天"问题3="李四迟到多少天"
# 输出:"张三九月考勤情况"
# '''
# sys_info2 = '''
# 你是一个问题补全助手任务是判断用户当前提出的问题问题1是否信息完整。
#
# 若问题1语义完整且可独立理解则直接返回原问题1
#
# 若问题1信息缺失或存在指代依赖则根据其前序上下文问题2 和 问题3按时间倒序排列问题1 最新问题2 次新问题3 最早)进行最小必要补全,生成一个语义完整、忠实于原意、且能通过一条 SQL 查询回答的问题。
#
# 请严格遵循以下准则:
# <>
# <rule_title>独立性优先</rule_title>
# <rule>若问题1 本身语义完整即不依赖问题2 或问题3 也能被准确理解则直接返回问题1禁止强行合并或改写。也不再受下面的规则约束</rule>
#
# <rule_title>以最新问题为核心</rule_title>
#
# <rule>问题1 始终是查询意图的唯一来源</rule>
# <rule>补全时只能借用问题2 或问题3 中的信息如主语、时间、地点等补全问题1的缺失不得改变问题1 的核心要素。</rule>
# <rule>合并后的问题必须完全保留问题1 的原始意图、时间范围、查询对象和动作。</rule>
# <rule>如果问题1已有明确的时间、地点、人物等信息禁止用前序问题的不同信息进行覆盖替换。</rule>
#
# <rule_title>单向合并限制</rule_title>
# <rule>仅允许将问题1 与问题2 或问题3 合并。</rule>
# <rule>禁止问题2 与问题3 直接合并也禁止忽略问题1 进行其他组合。</rule>
#
# <rule_title>依赖判断标准</rule_title>
# <rule>仅当问题1 明确依赖前序问题的上下文时,才触发合并。</rule>
# <rule>包含指代词:如“这个”“这些”“其中”“他”“他们”等;</rule>
# <rule>是对前一个问题的细节追问(如追问数量、时间、条件等);</rule>
# <rule>存在顺序逻辑(如“然后呢?”“接下来怎么样?”);<rule>
# <rule>缺失关键要素:如主语、时间范围、地点、对象等,需从前序问题中补全。</rule>
#
# <rule_title>合并范围选择规则</rule_title>
# 根据依赖关系,按以下优先级确定合并方式:
# <rule>仅依赖问题2 → 合并为问题1 + 问题2</rule>
# <rule>仅依赖问题3 → 合并为问题1 + 问题3</rule>
# <rule>问题2 和问题3 均可独立支撑问题1 → 采用就近原则合并问题1 + 问题2</rule>
# <rule>必须同时依赖问题2 和问题3 才能完整理解问题1 → 合并为问题1 + 问题2 + 问题3</rule>
# <rule>不依赖任何前序问题 → 直接返回问题1</rule>
# <rule_title>自然衔接,无额外内容</rule_title>
# <rule>合并后的问题必须是一个语法通顺、语义连贯的完整问句,不得添加解释、连接词或说明性文字(如“根据前面的问题”“结合上下文”等)。</rule>
#
# <rule_title>SQL 可执行性</rule_title>
# <rule>合并后的问题必须能通过一条 SQL 查询直接回答。</rule>
# <rule>若合并后的问题模糊、多义、或无法映射到具体数据库字段,则不应合并</rule>
#
# <rule_title>兜底策略</rule_title>
# <rule>若无法明确判断问题1 是否完整,或无法确定其是否依赖前序问题,请不要猜测,而是主动向用户请求澄清或补充细节。</rule>
#
# <example>
# 输入问题1="早退多少天"问题2="其中迟到多少天"问题3="张三九月工作了多少天"
# 输出:"张三九月早退多少天"
#
# 输入问题1="这些天里张三是否有迟到问题2="李四考勤""问题3="张三九月的考勤"
# 正确输出:"张三九月的迟到情况"
# 错误输出:
#
# 输入问题1="最近一个月是否有迟到"问题2="李四考勤"问题3="张三九月的考勤"
# 输出:"张三最近一个月是否有迟到"
#
# 输入问题1="迟到了多少天"问题2="他哪几天迟到了"问题3="张三九月在林芝是否早退"
# 输出:"张三九月在林芝迟到了多少天"
#
# 输入问题1="张三九月休息了多少天" 问题2="张三九月迟到了多少天"问题3="张三其中迟到多少天"
# 输出:"张三九月休息了多少天"
#
# 输入问题1="张三九月考勤情况" 问题2="张三九月迟到了多少天"问题3="李四迟到多少天"
# 输出:"张三九月考勤情况"
#
# 输入问题1="张三9月在林芝上班多少天" 问题2="9月29的考勤"问题3="9月29的考勤"
# 输出:"张三9月在林芝上班多少天"
# </example>
# '''
# sys_info3 = '''
# '''
sys_info = '''
你是一个问题补全助手任务是判断用户当前提出的问题问题1是否信息完整。
处理流程:
先判断问题1是否语义完整
如果问题1 自身含义清晰、包含必要要素如主语、时间范围、具体查询目标等指代明确不依赖任何上下文也能被准确理解则直接返回问题1原文禁止任何形式的改写或合并。
仅当问题1信息缺失时才使用上下文补全
上下文包括前两个历史问题问题2较近、问题3较远
补全时遵循就近优先原则优先使用问题2 的信息仅当问题2 无法提供所需信息且问题3 可补全时才使用问题3。
若需同时依赖问题2 和问题3 才能补全,则按 问题3 + 问题2 + 问题1 的顺序融合。
若问题1 中包含“他”“她”“它”等代词,无需无需区分性别或语义类别,(他不一定代表男,她也不一定代表女),采用就近原则从上下文中找出最近的具有人名或明确实体的主语进行替换。
主语选择必须严格遵循时间顺序仅当问题2 中无有效主语时才考虑问题3。只要问题2 包含明确人名就必须使用问题2 的主语
补全要求:
合并后的问题必须是一个语法通顺、语义完整的自然问句。
不得添加任何解释性、连接性或说明性文字(如“根据前面的问题”“结合上下文”等)。
补全后的问题必须能通过一条 SQL 查询语句直接回答(即具备明确的查询对象、条件和指标)。
兜底策略:
如果你无法确定问题1 是否完整,或无法判断是否依赖上下文,或即使合并上下文仍无法形成完整、可执行的问题,请不要猜测或强行输出,而是直接向用户请求补充细节。
'''
prompt = [
self.system_message(sys_info),
# self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
self.user_message("问题1: " + new_question + "\n上下文: " +str(context_question))
]
return self.submit_prompt(prompt=prompt, **kwargs)
class CustomQdrant_VectorStore(Qdrant_VectorStore):
def __init__(
self,
config_file={}
):
self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
super().__init__(config_file)
def generate_embedding(self, data: str, **kwargs) -> List[float]:
def _get_error_string(response: requests.Response) -> str:
try:
if response.content:
return response.json()["detail"]
except Exception:
pass
try:
response.raise_for_status()
except requests.HTTPError as e:
return str(e)
return "Unknown error"
request_body = {
"model": self.embedding_model_name,
"encoding_format": "float",
"input": [data],
}
# request_body = {
# "model": self.embedding_model_name,
# "encoding_format": "float",
# "input": [{"type":"text","text":data}],
# }
request_body.update(kwargs)
response = requests.post(
url=f"{self.embedding_api_base}/v1/embeddings",
json=request_body,
headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
)
if response.status_code != 200:
raise RuntimeError(
f"Failed to create the embeddings, detail: {_get_error_string(response)}"
)
result = response.json()
# embeddings = result['embeddings']
embeddings = result['data'][0]['embedding']
# return embeddings[0]
return embeddings
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
def __init__(self, llm_config=None, vector_store_config=None):
CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
OpenAICompatibleLLM.__init__(self, config_file=llm_config)
class TTLCacheWrapper:
"为MemoryCache()添加带ttl的包装器防治内存泄漏"
def __init__(self, cache: Optional[Cache] = None, ttl: int = 3600):
self.cache = cache or MemoryCache()
self.ttl = ttl
self._expiry_times = {}
self._cleanup_thread = None
self._start_cleanup()
def _start_cleanup(self):
"""启动后台清理线程"""
def cleanup():
while True:
current_time = time.time()
expired_keys = []
# 找出所有过期的key
for key_info, expiry in self._expiry_times.items():
if expiry <= current_time:
expired_keys.append(key_info)
# 清理过期数据
for key_info in expired_keys:
id, field = key_info
if hasattr(self.cache, 'delete'):
self.cache.delete(id=id)
del self._expiry_times[key_info]
time.sleep(180) # 每3分钟清理一次
self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
self._cleanup_thread.start()
def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
"""设置缓存值支持TTL"""
# 使用提供的TTL或默认TTL
actual_ttl = ttl if ttl is not None else self.ttl
# 调用原始cache的set方法
self.cache.set(id=id, field=field, value=value)
# 记录过期时间
key_info = (id, field)
self._expiry_times[key_info] = time.time() + actual_ttl
def get(self, id: str, field: str) -> Any:
"""获取缓存值,自动处理过期"""
key_info = (id, field)
# 检查是否过期
if key_info in self._expiry_times:
if time.time() > self._expiry_times[key_info]:
# 已过期删除并返回None
if hasattr(self.cache, 'delete'):
self.cache.delete(id=id)
del self._expiry_times[key_info]
return None
# 返回缓存值
return self.cache.get(id=id, field=field)
def delete(self, id: str, field: str):
"""删除缓存值"""
key_info = (id, field)
if hasattr(self.cache, 'delete'):
self.cache.delete(id=id)
if key_info in self._expiry_times:
del self._expiry_times[key_info]
def exists(self, id: str, field: str) -> bool:
"""检查缓存是否存在且未过期"""
key_info = (id, field)
if key_info in self._expiry_times:
if time.time() > self._expiry_times[key_info]:
# 已过期清理并返回False
self.delete(id=id, field=field)
return False
return self.get(id=id, field=field) is not None
def items(self):
"""遍历所有未过期的缓存键值对"""
current_time = time.time()
items = []
for (id, field), expiry in self._expiry_times.items():
# 检查是否过期
if current_time <= expiry:
value = self.get(id=id, field=field)
if value is not None:
items.append({
'id': id,
'field': field,
'value': value,
'expires_at': expiry,
'time_left': expiry - current_time
})
return items
def get_latest_by_id(self, id: str, limit: int = 1, field_filter: str = None):
"""
获取指定ID下时间最近的缓存项
Args:
id: 要查询的ID
limit: 返回最近几条记录默认1条
field_filter: 可选的字段过滤,如只获取特定字段
Returns:
按时间倒序排列的缓存项列表
"""
current_time = time.time()
matched_items = []
# 找出该ID下所有未过期的缓存项
for (cache_id, field), expiry in self._expiry_times.items():
if cache_id == id and current_time <= expiry:
# 字段过滤
if field_filter and field != field_filter:
continue
value = self.get(id=id, field=field)
if value is not None:
matched_items.append({
'id': id,
'field': field,
'value': value,
'expires_at': expiry,
'created_time': expiry - self.ttl, # 估算创建时间
'time_left': expiry - current_time
})
# 按过期时间倒序排列(最近创建的排在前面)
matched_items.sort(key=lambda x: x['expires_at'], reverse=True)
return matched_items[:limit]
# 代理其他方法到原始cache
def __getattr__(self, name):
return getattr(self.cache, name)