664 lines
31 KiB
Python
664 lines
31 KiB
Python
|
||
from typing import List, Union, Any, Optional
|
||
import time
|
||
import threading
|
||
from vanna.flask import Cache, MemoryCache
|
||
import dmPython
|
||
import orjson
|
||
import pandas as pd
|
||
from vanna.base import VannaBase
|
||
from vanna.flask import MemoryCache
|
||
from vanna.qdrant import Qdrant_VectorStore
|
||
from qdrant_client import QdrantClient
|
||
from openai import OpenAI
|
||
import requests
|
||
from decouple import config
|
||
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
|
||
import json
|
||
from template.template import get_base_template
|
||
from datetime import datetime
|
||
import logging
|
||
from util import train_ddl
|
||
logger = logging.getLogger(__name__)
|
||
import traceback
|
||
from service.conversation_service import get_latest_question,update_conversation
|
||
|
||
limit_count = 3
|
||
|
||
class OpenAICompatibleLLM(VannaBase):
|
||
def __init__(self, client=None, config_file=None):
|
||
VannaBase.__init__(self, config=config_file)
|
||
# default parameters - can be overrided using config
|
||
self.temperature = 0.6
|
||
self.max_tokens = 10000
|
||
|
||
if "temperature" in config_file:
|
||
self.temperature = config_file["temperature"]
|
||
|
||
if "max_tokens" in config_file:
|
||
self.max_tokens = config_file["max_tokens"]
|
||
|
||
if "api_type" in config_file:
|
||
raise Exception(
|
||
"Passing api_type is now deprecated. Please pass an OpenAI client instead."
|
||
)
|
||
|
||
if "api_version" in config_file:
|
||
raise Exception(
|
||
"Passing api_version is now deprecated. Please pass an OpenAI client instead."
|
||
)
|
||
|
||
if client is not None:
|
||
self.client = client
|
||
return
|
||
|
||
if "api_base" not in config_file:
|
||
raise Exception("Please passing api_base")
|
||
|
||
if "api_key" not in config_file:
|
||
raise Exception("Please passing api_key")
|
||
|
||
self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"])
|
||
|
||
def system_message(self, message: str) -> any:
|
||
return {"role": "system", "content": message}
|
||
|
||
    def connect_to_dameng(
        self,
        host: str = None,
        dbname: str = None,
        user: str = None,
        password: str = None,
        port: int = None,
        **kwargs
    ):
        """Connect to a Dameng (DM) database and install ``self.run_sql``.

        Opens a dmPython connection stored on ``self.conn`` and binds a
        closure as ``self.run_sql`` that executes a query and returns a
        pandas DataFrame. The closure transparently checks liveness and
        reconnects before each query.

        Args:
            host: Database server address (passed to dmPython as ``server``).
            dbname: Accepted for interface symmetry; not used by the
                dmPython connect call below.
            user: Database user name.
            password: Database password.
            port: Database port.
            **kwargs: Ignored; accepted for interface compatibility.

        Raises:
            Exception: if the initial connection cannot be established.
        """
        self.conn = None
        try:
            self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
        except Exception as e:
            raise Exception(f"Failed to connect to dameng database: {e}")

        def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
            """Execute *sql* on the shared connection and return a DataFrame.

            Reconnects first if the connection no longer answers a probe
            query; rolls back and re-raises on any execution failure.
            """
            logger.info(f"start to run_sql_damengsql")
            try:
                if not is_connection_alive(conn=self.conn):
                    logger.info("connection is not alive, reconnecting..........")
                    reconnect()
                # conn.ping(reconnect=True)
                cs = self.conn.cursor()
                cs.execute(sql)
                results = cs.fetchall()

                # Create a pandas dataframe from the results
                # (column names are taken from the cursor description).
                df = pd.DataFrame(
                    results, columns=[desc[0] for desc in cs.description]
                )

                return df
            except Exception as e:
                # Roll back any partial transaction state before re-raising.
                self.conn.rollback()
                logger.error(f"Failed to execute sql query: {e}")
                raise e
            # NOTE(review): unreachable — both branches above either return
            # or raise before this line.
            return None

        def reconnect():
            """Re-open ``self.conn`` with the credentials captured at connect time."""
            try:
                self.conn = dmPython.connect(user=user, password=password, server=host, port=port)
            except Exception as e:
                raise Exception(f"reconnect failed: {e}")

        def is_connection_alive(conn) -> bool:
            """Return True when *conn* answers a trivial probe query."""
            if conn is None:
                return False
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT 1 FROM DUAL")
                cursor.close()
                return True
            except Exception as e:
                return False

        # Expose the closure through the VannaBase run_sql hook.
        self.run_sql_is_set = True
        self.run_sql = run_sql_damengsql
|
||
|
||
|
||
def user_message(self, message: str) -> any:
|
||
return {"role": "user", "content": message}
|
||
|
||
def assistant_message(self, message: str) -> any:
|
||
return {"role": "assistant", "content": message}
|
||
|
||
def submit_prompt(self, prompt, **kwargs) -> str:
|
||
logger.info(f"submit prompt: {prompt}")
|
||
if prompt is None:
|
||
logger.info("test1")
|
||
raise Exception("Prompt is None")
|
||
|
||
if len(prompt) == 0:
|
||
logger.info("test2")
|
||
raise Exception("Prompt is empty")
|
||
print(prompt)
|
||
|
||
num_tokens = 0
|
||
for message in prompt:
|
||
num_tokens += len(message["content"]) / 4
|
||
logger.info("test3 {0}".format(num_tokens))
|
||
|
||
if kwargs.get("model", None) is not None:
|
||
logger.info("test4")
|
||
model = kwargs.get("model", None)
|
||
logger.info(
|
||
f"Using model {model} for {num_tokens} tokens (approx)"
|
||
)
|
||
response = self.client.chat.completions.create(
|
||
model=model,
|
||
messages=prompt,
|
||
max_tokens=self.max_tokens,
|
||
stop=None,
|
||
temperature=self.temperature,
|
||
)
|
||
elif kwargs.get("engine", None) is not None:
|
||
logger.info("test5")
|
||
engine = kwargs.get("engine", None)
|
||
logger.info(
|
||
f"Using model {engine} for {num_tokens} tokens (approx)"
|
||
)
|
||
response = self.client.chat.completions.create(
|
||
engine=engine,
|
||
messages=prompt,
|
||
max_tokens=self.max_tokens,
|
||
stop=None,
|
||
temperature=self.temperature,
|
||
)
|
||
elif self.config is not None and "engine" in self.config:
|
||
logger.info("test6")
|
||
logger.info(
|
||
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
|
||
)
|
||
response = self.client.chat.completions.create(
|
||
engine=self.config["engine"],
|
||
messages=prompt,
|
||
max_tokens=self.max_tokens,
|
||
stop=None,
|
||
temperature=self.temperature,
|
||
)
|
||
elif self.config is not None and "model" in self.config:
|
||
logger.info("test7")
|
||
logger.info(
|
||
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
|
||
)
|
||
logger.info("config is ",self.config)
|
||
response = self.client.chat.completions.create(
|
||
model=self.config["model"],
|
||
messages=prompt,
|
||
max_tokens=self.max_tokens,
|
||
stop=None,
|
||
temperature=self.temperature,
|
||
)
|
||
# data = {
|
||
# "model": self.model,
|
||
# "prompt": prompt,
|
||
# "max_tokens": self.max_tokens,
|
||
# "temperature": self.temperature,
|
||
# }
|
||
# response = requests.post(
|
||
# url=f"{self.api_base}/completions",
|
||
# headers=self.headers,
|
||
# json=data
|
||
# )
|
||
else:
|
||
logger.info("test8")
|
||
if num_tokens > 3500:
|
||
model = "kimi"
|
||
else:
|
||
model = "doubao"
|
||
|
||
logger.info(f"5.Using model {model} for {num_tokens} tokens (approx)")
|
||
|
||
response = self.client.chat.completions.create(
|
||
model=model,
|
||
messages=prompt,
|
||
max_tokens=self.max_tokens,
|
||
stop=None,
|
||
temperature=self.temperature,
|
||
)
|
||
for choice in response.choices:
|
||
if "text" in choice:
|
||
return choice.text
|
||
|
||
return response.choices[0].message.content
|
||
|
||
    def generate_sql_2(self, user_id: str, cvs_id: str, question: str,id: str, need_context: bool, allow_llm_to_see_data=False, **kwargs) -> dict:
        """Generate SQL (and optionally a chart spec) for a user question.

        Flow:
            1. Retrieve similar question/SQL pairs (capped at 2) and related DDL.
            2. If ``need_context``, rewrite the question using the latest
               ``limit_count`` conversation questions and persist the rewrite.
            3. Build system/user prompts from the project templates and ask
               the LLM; parse its JSON answer and extract the SQL.
            4. If the answer names a chart type, run a second LLM call to
               produce the chart configuration.

        Args:
            user_id: Id of the asking user.
            cvs_id: Conversation id used to fetch recent questions.
            question: The raw user question.
            id: Conversation record id updated with the rewritten question.
            need_context: Whether to rewrite the question from history.
            allow_llm_to_see_data: Accepted for interface compatibility;
                not used in this implementation.
            **kwargs: Forwarded to retrieval and LLM calls.

        Returns:
            dict with key ``"resp"`` (parsed LLM JSON answer) and, when a
            chart type was detected, key ``"chart"``.

        Raises:
            Exception: when the latest stored question does not match
                *question*, or on any downstream failure (re-raised after
                logging the traceback).
        """
        try:
            logger.info("Start to generate_sql_2 in cus_vanna_srevice")
            question_sql_list = self.get_similar_question_sql(question, **kwargs)
            # Keep at most the two closest examples to bound prompt size.
            if question_sql_list and len(question_sql_list)>2:
                question_sql_list=question_sql_list[:2]
            ddl_list = self.get_related_ddl(question, **kwargs)
            #doc_list = self.get_related_documentation(question, **kwargs)
            template = get_base_template()
            sql_temp = template['template']['sql']
            char_temp = template['template']['chart']
            history = None
            if need_context:
                questions = get_latest_question(cvs_id, user_id,limit_count)
                logger.info(f"latest_questions is {questions}")
                # Sanity check: the newest stored question must be the one
                # we were asked about, otherwise the context is stale.
                # NOTE(review): raises IndexError if `questions` is empty —
                # confirm get_latest_question always returns at least one row.
                if questions[0] != question:
                    raise Exception(f"上下文不匹配 {question} {questions[0]}")
                new_question = self.generate_rewritten_question(questions,**kwargs)
                logger.info(f"new_question is {new_question}")
                question = new_question if new_question else question
                update_conversation(id, meta=question)

            # if user_id and cache:
            #     history = cache.get(id=user_id, field="data")
            # -------- Build the prompt and generate the SQL + chart type.
            sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
                                                 schema=ddl_list, documentation=[train_ddl.train_document],
                                                 history=history,retrieved_examples_data=question_sql_list,
                                                 data_training=question_sql_list,)

            user_temp = sql_temp['user'].format(question=question,
                                                current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
            logger.info(f"user_temp:{user_temp}")
            logger.info(f"sys_temp:{sys_temp}")
            llm_response = self.submit_prompt(
                [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
            llm_response = str(llm_response.strip())
            logger.info(f"llm_response:{llm_response}")
            # Extract the JSON payload embedded in the LLM's free-form answer.
            result = extract_nested_json(llm_response)
            logger.info(f"result:{result}")
            result = {"resp": orjson.loads(result)}
            logger.info(f"llm_response:{llm_response}")
            sql = check_and_get_sql(llm_response)
            logger.info(f"sql:{sql}")
            # --------------- Second pass: generate the chart configuration.
            char_type = get_chart_type_from_sql_answer(llm_response)
            logger.info(f"chart type:{char_type}")
            if char_type:
                sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'),
                                                           lang='中文', sql=sql, chart_type=char_type)
                user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
                llm_response2 = self.submit_prompt(
                    [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}],
                    **kwargs)
                print(llm_response2)
                result['chart'] = orjson.loads(extract_nested_json(llm_response2))
                logger.info(f"chart_response:{result}")

            logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
            return result
        except Exception as e:

            logger.info("cus_vanna_srevice failed-------------------: ")
            traceback.print_exc()
            raise e
|
||
|
||
def run_sql_2(self,sql):
|
||
try:
|
||
return self.run_sql(sql)
|
||
except Exception as e:
|
||
logger.error("run_sql failed {0}".format(sql))
|
||
raise e
|
||
|
||
    def generate_rewritten_question(self, questions: List[str], **kwargs) -> str:
        """Rewrite the newest question using conversation context, via the LLM.

        ``questions`` is ordered newest-first: ``questions[0]`` is the current
        question, the rest are prior questions used as context. If there is no
        context, the current question is returned unchanged; otherwise the LLM
        is asked (with a Chinese system prompt) to complete the question only
        when it is ambiguous or references earlier turns.

        Args:
            questions: Non-empty list of questions, newest first.
                (Annotation corrected from ``str`` — the body indexes and
                slices it as a sequence.)
            **kwargs: Forwarded to :meth:`submit_prompt`.

        Returns:
            The (possibly rewritten) question text from the LLM, or the
            original question when no context exists.
        """
        logger.info(f"generate_rewritten_question---------------{questions}")
        new_question = questions[0]
        context_question = questions[1:]

        # No prior turns: nothing to merge, return the question as-is.
        if not context_question:
            return new_question
        print("last question {0}".format(context_question))
        print("new question {0}".format(new_question))
        # NOTE: two earlier, longer drafts of this system prompt were removed
        # from here as dead commented-out text; see version control history.
        sys_info = '''
        你是一个问题补全助手,任务是判断用户当前提出的问题(问题1)是否信息完整。

        处理流程:
        先判断问题1是否语义完整:
        如果问题1 自身含义清晰、包含必要要素(如主语、时间范围、具体查询目标等),指代明确,不依赖任何上下文也能被准确理解,则直接返回问题1原文,禁止任何形式的改写或合并。
        仅当问题1信息缺失时,才使用上下文补全:
        上下文包括前两个历史问题:问题2(较近)、问题3(较远)。
        补全时遵循就近优先原则:优先使用问题2 的信息;仅当问题2 无法提供所需信息,且问题3 可补全时,才使用问题3。
        若需同时依赖问题2 和问题3 才能补全,则按 问题3 + 问题2 + 问题1 的顺序融合。
        若问题1 中包含“他”“她”“它”等代词,无需无需区分性别或语义类别,(他不一定代表男,她也不一定代表女),采用就近原则从上下文中找出最近的具有人名或明确实体的主语进行替换。
        主语选择必须严格遵循时间顺序:仅当问题2 中无有效主语时,才考虑问题3。只要问题2 包含明确人名,就必须使用问题2 的主语
        补全要求:
        合并后的问题必须是一个语法通顺、语义完整的自然问句。
        不得添加任何解释性、连接性或说明性文字(如“根据前面的问题”“结合上下文”等)。
        补全后的问题必须能通过一条 SQL 查询语句直接回答(即具备明确的查询对象、条件和指标)。
        兜底策略:
        如果你无法确定问题1 是否完整,或无法判断是否依赖上下文,或即使合并上下文仍无法形成完整、可执行的问题,请不要猜测或强行输出,而是直接向用户请求补充细节。
        '''
        prompt = [
            self.system_message(sys_info),
            # self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
            self.user_message("问题1: " + new_question + "\n上下文: " +str(context_question))
        ]

        return self.submit_prompt(prompt=prompt, **kwargs)
|
||
|
||
|
||
class CustomQdrant_VectorStore(Qdrant_VectorStore):
    """Qdrant vector store whose embeddings come from an OpenAI-compatible HTTP endpoint.

    Embedding endpoint settings are read from the environment via
    ``decouple.config``: EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_BASE_URL,
    EMBEDDING_MODEL_API_KEY.
    """

    def __init__(
        self,
        config_file=None
    ):
        """Read embedding settings from the environment and init the base store.

        Args:
            config_file: Optional dict forwarded to Qdrant_VectorStore.
                Bug fix: the original used a mutable default ``{}`` shared
                across all calls; ``None`` + local dict avoids aliasing.
        """
        if config_file is None:
            config_file = {}
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        super().__init__(config_file)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """POST *data* to ``{base}/v1/embeddings`` and return the vector.

        Args:
            data: The text to embed (sent as a one-element ``input`` list).
            **kwargs: Extra fields merged into the request body (may
                override ``model`` / ``encoding_format`` / ``input``).

        Returns:
            The embedding vector from ``response["data"][0]["embedding"]``.

        Raises:
            RuntimeError: when the endpoint answers with a non-200 status.
        """
        def _get_error_string(response: requests.Response) -> str:
            # Prefer the server-supplied "detail" field, fall back to the
            # HTTP error text, then to a generic message.
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "encoding_format": "float",
            "input": [data],
        }
        request_body.update(kwargs)

        response = requests.post(
            url=f"{self.embedding_api_base}/v1/embeddings",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )
        result = response.json()
        embeddings = result['data'][0]['embedding']
        return embeddings
|
||
|
||
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Concrete Vanna class combining the Qdrant vector store with the OpenAI-compatible LLM."""

    def __init__(self, llm_config=None, vector_store_config=None):
        # Each base is initialized explicitly with its own config dict;
        # the vector store is set up before the LLM backend.
        CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
        OpenAICompatibleLLM.__init__(self, config_file=llm_config)
|
||
|
||
class TTLCacheWrapper:
    """TTL-enforcing wrapper around a flask-Vanna ``Cache`` (default ``MemoryCache``).

    Tracks an expiry timestamp per ``(id, field)`` key and evicts expired
    entries both lazily (on access) and via a daemon cleanup thread that
    runs every 3 minutes, preventing unbounded memory growth.

    Thread-safety fix: the original mutated ``_expiry_times`` from the
    cleanup thread with no synchronization while other threads iterated or
    wrote it, risking ``RuntimeError: dictionary changed size during
    iteration`` and lost updates. All access now goes through an RLock.

    NOTE(review): the wrapped cache's ``delete`` takes only ``id``, so
    evicting one expired field drops every field stored under that id —
    confirm this matches the intended semantics of the underlying Cache.
    """

    def __init__(self, cache: Optional["Cache"] = None, ttl: int = 3600):
        """Create the wrapper.

        Args:
            cache: Backing cache; a new ``MemoryCache`` when omitted.
            ttl: Default time-to-live in seconds for entries set without
                an explicit ttl.
        """
        self.cache = cache if cache is not None else MemoryCache()
        self.ttl = ttl
        # (id, field) -> absolute expiry timestamp (time.time() based).
        self._expiry_times = {}
        # Guards _expiry_times against the cleanup thread. RLock so that
        # methods may call each other without deadlocking.
        self._lock = threading.RLock()
        self._cleanup_thread = None
        self._start_cleanup()

    def _start_cleanup(self):
        """Start the daemon thread that evicts expired entries every 3 minutes."""

        def cleanup():
            while True:
                current_time = time.time()
                with self._lock:
                    # Snapshot the expired keys first, then delete — never
                    # mutate the dict while iterating it.
                    expired_keys = [
                        key_info
                        for key_info, expiry in self._expiry_times.items()
                        if expiry <= current_time
                    ]
                    for key_info in expired_keys:
                        entry_id, _field = key_info
                        if hasattr(self.cache, 'delete'):
                            self.cache.delete(id=entry_id)
                        self._expiry_times.pop(key_info, None)

                time.sleep(180)  # sweep every 3 minutes

        self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
        self._cleanup_thread.start()

    def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
        """Store *value* under (*id*, *field*) with a TTL.

        Args:
            ttl: Per-entry TTL in seconds; falls back to the instance default.
        """
        actual_ttl = ttl if ttl is not None else self.ttl
        self.cache.set(id=id, field=field, value=value)
        with self._lock:
            self._expiry_times[(id, field)] = time.time() + actual_ttl

    def get(self, id: str, field: str) -> Any:
        """Return the cached value, or None if missing or expired (evicting it)."""
        key_info = (id, field)
        with self._lock:
            expiry = self._expiry_times.get(key_info)
            if expiry is not None and time.time() > expiry:
                # Expired: evict and report a miss.
                if hasattr(self.cache, 'delete'):
                    self.cache.delete(id=id)
                self._expiry_times.pop(key_info, None)
                return None
        return self.cache.get(id=id, field=field)

    def delete(self, id: str, field: str):
        """Remove the entry and its expiry bookkeeping."""
        if hasattr(self.cache, 'delete'):
            self.cache.delete(id=id)
        with self._lock:
            self._expiry_times.pop((id, field), None)

    def exists(self, id: str, field: str) -> bool:
        """Return True when the entry is present, non-None, and not expired."""
        key_info = (id, field)
        with self._lock:
            expiry = self._expiry_times.get(key_info)
        if expiry is not None and time.time() > expiry:
            # Expired: clean up and report absence.
            self.delete(id=id, field=field)
            return False
        return self.get(id=id, field=field) is not None

    def items(self):
        """Return a list of dicts describing all unexpired entries."""
        current_time = time.time()
        with self._lock:
            snapshot = list(self._expiry_times.items())

        items = []
        for (entry_id, field), expiry in snapshot:
            if current_time <= expiry:
                value = self.get(id=entry_id, field=field)
                if value is not None:
                    items.append({
                        'id': entry_id,
                        'field': field,
                        'value': value,
                        'expires_at': expiry,
                        'time_left': expiry - current_time
                    })

        return items

    def get_latest_by_id(self, id: str, limit: int = 1, field_filter: str = None):
        """Return the most recently created unexpired entries for *id*.

        Args:
            id: The cache id to query.
            limit: Maximum number of entries to return (newest first).
            field_filter: When given, only entries with this field name.

        Returns:
            List of entry dicts sorted by expiry time descending.
            ``created_time`` is estimated as ``expires_at - default ttl``,
            so it is only accurate for entries set with the default TTL.
        """
        current_time = time.time()
        with self._lock:
            snapshot = list(self._expiry_times.items())

        matched_items = []
        for (cache_id, field), expiry in snapshot:
            if cache_id == id and current_time <= expiry:
                if field_filter and field != field_filter:
                    continue

                value = self.get(id=id, field=field)
                if value is not None:
                    matched_items.append({
                        'id': id,
                        'field': field,
                        'value': value,
                        'expires_at': expiry,
                        'created_time': expiry - self.ttl,  # estimated creation time
                        'time_left': expiry - current_time
                    })

        # Newest entries (latest expiry) first.
        matched_items.sort(key=lambda x: x['expires_at'], reverse=True)

        return matched_items[:limit]

    def __getattr__(self, name):
        """Proxy unknown attributes to the wrapped cache."""
        # Guard against infinite recursion when 'cache' itself is missing
        # (e.g. during unpickling or a partially-run __init__).
        if name == 'cache':
            raise AttributeError(name)
        return getattr(self.cache, name)