feat:初始化

This commit is contained in:
雷雨
2025-09-23 14:49:00 +08:00
parent e70a28be7d
commit 3ace3e5348
10 changed files with 944 additions and 0 deletions

0
service/__init__.py Normal file
View File

View File

@@ -0,0 +1,217 @@
from email.policy import default
from typing import List
import orjson
from vanna.base import VannaBase
from vanna.qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from openai import OpenAI
import requests
from decouple import config
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
import json
from template.template import get_base_template
from datetime import datetime
class OpenAICompatibleLLM(VannaBase):
def __init__(self, client=None, config_file=None):
VannaBase.__init__(self, config=config_file)
# default parameters - can be overrided using config
self.temperature = 0.5
self.max_tokens = 5000
if "temperature" in config_file:
self.temperature = config_file["temperature"]
if "max_tokens" in config_file:
self.max_tokens = config_file["max_tokens"]
if "api_type" in config_file:
raise Exception(
"Passing api_type is now deprecated. Please pass an OpenAI client instead."
)
if "api_version" in config_file:
raise Exception(
"Passing api_version is now deprecated. Please pass an OpenAI client instead."
)
if client is not None:
self.client = client
return
if "api_base" not in config_file:
raise Exception("Please passing api_base")
if "api_key" not in config_file:
raise Exception("Please passing api_key")
self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"])
def system_message(self, message: str) -> any:
return {"role": "system", "content": message}
def user_message(self, message: str) -> any:
return {"role": "user", "content": message}
def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message}
def submit_prompt(self, prompt, **kwargs) -> str:
if prompt is None:
raise Exception("Prompt is None")
if len(prompt) == 0:
raise Exception("Prompt is empty")
print(prompt)
num_tokens = 0
for message in prompt:
num_tokens += len(message["content"]) / 4
if kwargs.get("model", None) is not None:
model = kwargs.get("model", None)
print(
f"Using model {model} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
model=model,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif kwargs.get("engine", None) is not None:
engine = kwargs.get("engine", None)
print(
f"Using model {engine} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
engine=engine,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif self.config is not None and "engine" in self.config:
print(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
engine=self.config["engine"],
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif self.config is not None and "model" in self.config:
print(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
model=self.config["model"],
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
else:
if num_tokens > 3500:
model = "kimi"
else:
model = "doubao"
print(f"Using model {model} for {num_tokens} tokens (approx)")
response = self.client.chat.completions.create(
model=model,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
for choice in response.choices:
if "text" in choice:
return choice.text
return response.choices[0].message.content
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
template = get_base_template()
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
# --------基于提示词生成sql以及图标类型
sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list,
data_training=question_sql_list)
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
llm_response = self.submit_prompt(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
print(llm_response)
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
sql = check_and_get_sql(llm_response)
# ---------------生成图表
char_type = get_chart_type_from_sql_answer(llm_response)
if char_type:
sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
llm_response2 = self.submit_prompt(
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
print(llm_response2)
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
return result
class CustomQdrant_VectorStore(Qdrant_VectorStore):
def __init__(
self,
config_file={}
):
self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
super().__init__(config_file)
def generate_embedding(self, data: str, **kwargs) -> List[float]:
def _get_error_string(response: requests.Response) -> str:
try:
if response.content:
return response.json()["detail"]
except Exception:
pass
try:
response.raise_for_status()
except requests.HTTPError as e:
return str(e)
return "Unknown error"
request_body = {
"model": self.embedding_model_name,
"input": data,
}
request_body.update(kwargs)
response = requests.post(
url=f"{self.embedding_api_base}/v1/embeddings",
json=request_body,
headers={"Authorization": f"Bearer {self.embedding_api_key}"},
)
if response.status_code != 200:
raise RuntimeError(
f"Failed to create the embeddings, detail: {_get_error_string(response)}"
)
result = response.json()
embeddings = [d["embedding"] for d in result["data"]]
return embeddings[0]
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
def __init__(self, llm_config=None, vector_store_config=None):
CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
OpenAICompatibleLLM.__init__(self, config_file=llm_config)