from typing import List
from datetime import datetime

import orjson
import requests
from decouple import config
from openai import OpenAI
from qdrant_client import QdrantClient
from vanna.base import VannaBase
from vanna.qdrant import Qdrant_VectorStore

from template.template import get_base_template
from util.utils import (
    check_and_get_sql,
    extract_nested_json,
    get_chart_type_from_sql_answer,
)


class OpenAICompatibleLLM(VannaBase):
    def __init__(self, client=None, config_file=None):
        VannaBase.__init__(self, config=config_file)

        # Guard against a missing config so the membership checks below don't raise.
        config_file = config_file or {}

        # Default parameters - can be overridden via config.
        self.temperature = 0.5
        self.max_tokens = 5000

        if "temperature" in config_file:
            self.temperature = config_file["temperature"]

        if "max_tokens" in config_file:
            self.max_tokens = config_file["max_tokens"]

        if "api_type" in config_file:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if "api_version" in config_file:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if "api_base" not in config_file:
            raise Exception("Please pass api_base")

        if "api_key" not in config_file:
            raise Exception("Please pass api_key")

        self.client = OpenAI(
            api_key=config_file["api_key"], base_url=config_file["api_base"]
        )

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        print(prompt)

        # Rough token estimate: ~4 characters per token.
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # Model resolution order: explicit kwarg, then config, then a
        # size-based fallback.
        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)
            print(f"Using engine {engine} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                engine=engine,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "engine" in self.config:
            print(
                f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=self.config["engine"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=self.config["model"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        else:
            # No model configured: pick one based on the estimated prompt size.
            if num_tokens > 3500:
                model = "kimi"
            else:
                model = "doubao"
            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )

        # Legacy completion-style responses expose a top-level text field.
        for choice in response.choices:
            if getattr(choice, "text", None):
                return choice.text

        return response.choices[0].message.content
    def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
        # Retrieve context for the prompt from the vector store.
        question_sql_list = self.get_similar_question_sql(question, **kwargs)
        ddl_list = self.get_related_ddl(question, **kwargs)
        doc_list = self.get_related_documentation(question, **kwargs)

        template = get_base_template()
        sql_temp = template['template']['sql']
        chart_temp = template['template']['chart']

        # Generate the SQL and the chart type from the prompt.
        sys_temp = sql_temp['system'].format(
            engine='sqlite',
            lang='中文',
            schema=ddl_list,
            documentation=doc_list,
            data_training=question_sql_list,
        )
        user_temp = sql_temp['user'].format(
            question=question,
            current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        )
        llm_response = self.submit_prompt(
            [
                {'role': 'system', 'content': sys_temp},
                {'role': 'user', 'content': user_temp},
            ],
            **kwargs,
        )
        print(llm_response)

        result = {"resp": orjson.loads(extract_nested_json(llm_response))}
        sql = check_and_get_sql(llm_response)

        # Generate the chart definition, if the answer calls for one.
        chart_type = get_chart_type_from_sql_answer(llm_response)
        if chart_type:
            sys_chart_temp = chart_temp['system'].format(
                engine='sqlite', lang='中文', sql=sql, chart_type=chart_type
            )
            user_chart_temp = chart_temp['user'].format(
                sql=sql, chart_type=chart_type, question=question
            )
            llm_response2 = self.submit_prompt(
                [
                    {'role': 'system', 'content': sys_chart_temp},
                    {'role': 'user', 'content': user_chart_temp},
                ],
                **kwargs,
            )
            print(llm_response2)
            result['chart'] = orjson.loads(extract_nested_json(llm_response2))

        return result


class CustomQdrant_VectorStore(Qdrant_VectorStore):
    def __init__(self, config_file=None):
        # Embedding endpoint settings come from the environment via python-decouple.
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        super().__init__(config_file or {})

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        def _get_error_string(response: requests.Response) -> str:
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "input": data,
        }
        request_body.update(kwargs)

        response = requests.post(
            url=f"{self.embedding_api_base}/v1/embeddings",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}"},
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )

        result = response.json()
        embeddings = [d["embedding"] for d in result["data"]]
        return embeddings[0]


class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    def __init__(self, llm_config=None, vector_store_config=None):
        CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
        OpenAICompatibleLLM.__init__(self, config_file=llm_config)
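

# Example usage (a minimal sketch): wire CustomVanna to an OpenAI-compatible
# endpoint and a local Qdrant instance. The URLs, key, and model name below
# are placeholder assumptions, not values shipped with this module; vanna's
# Qdrant_VectorStore is assumed to accept a QdrantClient under the "client"
# config key.
if __name__ == "__main__":
    vn = CustomVanna(
        llm_config={
            "api_base": "http://localhost:8000/v1",  # hypothetical endpoint
            "api_key": "sk-placeholder",             # hypothetical key
            "model": "doubao",                       # any chat model the endpoint serves
            "temperature": 0.2,
        },
        vector_store_config={
            "client": QdrantClient(url="http://localhost:6333"),  # hypothetical Qdrant
        },
    )
    answer = vn.generate_sql_2("How many orders were placed in the last 7 days?")
    print(answer)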