Files
sqlbot_agent/service/cus_vanna_srevice.py
2025-09-24 14:39:42 +08:00

254 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from email.policy import default
from typing import List
import dmPython
import orjson
import pandas as pd
from vanna.base import VannaBase
from vanna.flask import MemoryCache
from vanna.qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from openai import OpenAI
import requests
from decouple import config
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
import json
from template.template import get_base_template
from datetime import datetime
class OpenAICompatibleLLM(VannaBase):
    """Vanna LLM backend for any OpenAI-compatible chat-completions API.

    Implements the Vanna chat hooks (``system_message``, ``user_message``,
    ``assistant_message``, ``submit_prompt``) and adds a template-driven
    ``generate_sql_2`` that produces SQL plus an optional chart definition.
    """

    def __init__(self, client=None, config_file=None):
        """Configure sampling parameters and the chat client.

        Args:
            client: Optional pre-built OpenAI client. When given, ``api_base``
                and ``api_key`` are not required in ``config_file``.
            config_file: Dict of LLM settings. Keys read here: ``temperature``
                (default 0.5), ``max_tokens`` (default 5000), ``api_base`` and
                ``api_key``; ``model``/``engine`` are read by ``submit_prompt``
                through ``self.config``. ``None`` is treated as an empty dict.

        Raises:
            Exception: if the deprecated ``api_type``/``api_version`` keys are
                present, or if no ``client`` is given and ``api_base`` or
                ``api_key`` is missing.
        """
        # Accept the declared default (None) as "no settings" instead of
        # failing on the first membership test below.
        if config_file is None:
            config_file = {}
        VannaBase.__init__(self, config=config_file)
        # Defaults, overridable through config_file.
        self.temperature = config_file.get("temperature", 0.5)
        self.max_tokens = config_file.get("max_tokens", 5000)
        self.conn = None
        for deprecated_key in ("api_type", "api_version"):
            if deprecated_key in config_file:
                raise Exception(
                    f"Passing {deprecated_key} is now deprecated. Please pass an OpenAI client instead."
                )
        if client is not None:
            self.client = client
            return
        if "api_base" not in config_file:
            raise Exception("Please passing api_base")
        if "api_key" not in config_file:
            raise Exception("Please passing api_key")
        self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"])

    def system_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with role ``system``."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with role ``user``."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with role ``assistant``."""
        return {"role": "assistant", "content": message}

    def _select_completion_target(self, kwargs, num_tokens):
        """Return (``create()`` kwargs naming the model/engine, log line) for one request."""
        if kwargs.get("model", None) is not None:
            model = kwargs["model"]
            return {"model": model}, f"Using model {model} for {num_tokens} tokens (approx)"
        if kwargs.get("engine", None) is not None:
            engine = kwargs["engine"]
            # NOTE(review): the openai>=1.0 client's create() does not appear to
            # accept `engine`; confirm against the installed SDK version.
            return {"engine": engine}, f"Using model {engine} for {num_tokens} tokens (approx)"
        if self.config is not None and "engine" in self.config:
            engine = self.config["engine"]
            return {"engine": engine}, f"Using engine {engine} for {num_tokens} tokens (approx)"
        if self.config is not None and "model" in self.config:
            model = self.config["model"]
            return {"model": model}, f"Using model {model} for {num_tokens} tokens (approx)"
        # No explicit choice: route long prompts to "kimi", short ones to "doubao".
        model = "kimi" if num_tokens > 3500 else "doubao"
        return {"model": model}, f"Using model {model} for {num_tokens} tokens (approx)"

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a chat prompt and return the assistant's reply text.

        The model/engine is chosen in this order: ``model`` kwarg, ``engine``
        kwarg, ``engine`` in ``self.config``, ``model`` in ``self.config``;
        otherwise ``"kimi"`` for prompts estimated above 3500 tokens and
        ``"doubao"`` for the rest.

        Args:
            prompt: List of ``{"role": ..., "content": ...}`` message dicts.
            **kwargs: May carry ``model`` or ``engine``.

        Returns:
            The text of the first choice.

        Raises:
            Exception: if ``prompt`` is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")
        print(prompt)
        # Rough token estimate: about 4 characters per token.
        num_tokens = sum(len(message["content"]) for message in prompt) / 4

        target, log_line = self._select_completion_target(kwargs, num_tokens)
        print(log_line)
        response = self.client.chat.completions.create(
            messages=prompt,
            max_tokens=self.max_tokens,
            stop=None,
            temperature=self.temperature,
            **target,
        )
        # Legacy completion-style choices expose `text`; chat choices carry
        # `message.content`.
        for choice in response.choices:
            if "text" in choice:
                return choice.text
        return response.choices[0].message.content

    def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
        """Generate SQL (and, when applicable, a chart definition) for ``question``.

        Retrieves similar question/SQL pairs, related DDL and documentation,
        fills the ``sql`` templates from ``get_base_template()`` and asks the
        LLM. If a chart type can be derived from that answer, a second prompt
        built from the ``chart`` templates is sent.

        Args:
            question: Natural-language question.
            allow_llm_to_see_data: Currently unused; kept for interface compatibility.
            **kwargs: Forwarded to the retrieval methods and ``submit_prompt``.

        Returns:
            ``{"resp": <parsed JSON from the SQL answer>}``, plus ``"chart"``
            (parsed JSON from the chart answer) when a chart type was found.
        """
        question_sql_list = self.get_similar_question_sql(question, **kwargs)
        ddl_list = self.get_related_ddl(question, **kwargs)
        doc_list = self.get_related_documentation(question, **kwargs)
        templates = get_base_template()['template']
        sql_temp = templates['sql']
        chart_temp = templates['chart']
        # Database dialect injected into both prompt templates.
        db_type = config("DATA_SOURCE_TYPE", default='mysql')

        # Step 1: generate the SQL (and the chart type) from the prompt templates.
        sys_temp = sql_temp['system'].format(engine=db_type, lang='中文',
                                             schema=ddl_list, documentation=doc_list,
                                             data_training=question_sql_list)
        print("sys_temp", sys_temp)
        user_temp = sql_temp['user'].format(question=question,
                                            current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        print("user_temp", user_temp)
        llm_response = self.submit_prompt(
            [self.system_message(sys_temp), self.user_message(user_temp)], **kwargs)
        print(llm_response)
        result = {"resp": orjson.loads(extract_nested_json(llm_response))}
        print("result", result)
        sql = check_and_get_sql(llm_response)

        # Step 2: generate the chart definition when the answer names a chart type.
        chart_type = get_chart_type_from_sql_answer(llm_response)
        if chart_type:
            sys_chart_temp = chart_temp['system'].format(engine=db_type, lang='中文',
                                                         sql=sql, chart_type=chart_type)
            user_chart_temp = chart_temp['user'].format(sql=sql, chart_type=chart_type, question=question)
            chart_response = self.submit_prompt(
                [self.system_message(sys_chart_temp), self.user_message(user_chart_temp)],
                **kwargs)
            print(chart_response)
            result['chart'] = orjson.loads(extract_nested_json(chart_response))
        return result

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """Merge a follow-up question with the previous one when they are related.

        Args:
            last_question: The previous question; None or empty when there is none.
            new_question: The follow-up question.

        Returns:
            ``new_question`` unchanged when there is no previous question;
            otherwise the LLM's combined (or self-contained) question.
        """
        print("new_question---------------", new_question)
        if not last_question:
            return new_question
        prompt = [
            self.system_message(
                "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
            # Both questions must be sent: the system prompt asks the model to
            # combine them, which is impossible if only the second one is given.
            self.user_message(f"First question: {last_question}\nSecond question: {new_question}"),
        ]
        return self.submit_prompt(prompt=prompt, **kwargs)
class CustomQdrant_VectorStore(Qdrant_VectorStore):
    """Qdrant vector store that embeds text via an OpenAI-compatible ``/v1/embeddings`` endpoint.

    Endpoint settings are read with ``decouple.config`` from
    ``EMBEDDING_MODEL_NAME``, ``EMBEDDING_MODEL_BASE_URL``,
    ``EMBEDDING_MODEL_API_KEY`` and ``EMBEDDING_MODEL_TIMEOUT`` (seconds,
    default 60).
    """

    def __init__(
        self,
        config_file=None
    ):
        """Read the embedding endpoint settings, then initialise the Qdrant store.

        Args:
            config_file: Settings forwarded to ``Qdrant_VectorStore.__init__``.
                ``None`` (the default) means an empty dict. A fresh dict is
                created per call rather than sharing a mutable default.
        """
        if config_file is None:
            config_file = {}
        # Set before the base initializer runs. NOTE(review): presumably
        # Qdrant_VectorStore.__init__ may already call generate_embedding
        # (e.g. to size collections); confirm before reordering.
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        self.embedding_timeout = config('EMBEDDING_MODEL_TIMEOUT', default=60, cast=float)
        super().__init__(config_file)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Return the embedding vector for ``data``.

        Extra keyword arguments are merged into the JSON request body and may
        override ``model``/``input``.

        Raises:
            RuntimeError: on a non-200 response, or when the response holds no embeddings.
            requests.RequestException: on connection failure or timeout.
        """
        def _get_error_string(response: requests.Response) -> str:
            # Prefer a server-provided message: "detail" (FastAPI style) or
            # "error.message" (OpenAI style); otherwise fall back to the HTTP status.
            payload = None
            if response.content:
                try:
                    payload = response.json()
                except ValueError:
                    payload = None
            if isinstance(payload, dict):
                if "detail" in payload:
                    return str(payload["detail"])
                error = payload.get("error")
                if isinstance(error, dict) and "message" in error:
                    return str(error["message"])
                if isinstance(error, str):
                    return error
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "input": data,
        }
        request_body.update(kwargs)
        # Strip a trailing slash so ".../" does not produce "//v1/embeddings".
        base_url = self.embedding_api_base.rstrip("/")
        response = requests.post(
            url=f"{base_url}/v1/embeddings",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}"},
            # Without a timeout, requests can block indefinitely on a stalled server.
            timeout=getattr(self, "embedding_timeout", 60),
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )
        items = response.json()["data"]
        if not items:
            raise RuntimeError(
                "Failed to create the embeddings, detail: response contained no embeddings"
            )
        return items[0]["embedding"]
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Vanna instance combining the custom Qdrant vector store with the OpenAI-compatible LLM."""

    def __init__(self, llm_config=None, vector_store_config=None):
        """Initialise both parents explicitly with their own settings.

        Args:
            llm_config: Settings for ``OpenAICompatibleLLM`` (``api_base``,
                ``api_key``, ``model``, ...). ``None`` means an empty dict.
            vector_store_config: Settings for ``CustomQdrant_VectorStore``.
                ``None`` means an empty dict.
        """
        # Both parents expect a dict; passing the declared None defaults
        # straight through would fail inside their initializers.
        if llm_config is None:
            llm_config = {}
        if vector_store_config is None:
            vector_store_config = {}
        CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
        # Called second: it re-runs VannaBase.__init__ with llm_config, so
        # presumably self.config ends up as llm_config (which submit_prompt
        # reads for "model"/"engine") — confirm if the order is changed.
        OpenAICompatibleLLM.__init__(self, config_file=llm_config)