from typing import List
from datetime import datetime

import orjson
import requests
from decouple import config
from openai import OpenAI
from qdrant_client import QdrantClient
from vanna.base import VannaBase
from vanna.qdrant import Qdrant_VectorStore

from template.template import get_base_template
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer


class OpenAICompatibleLLM(VannaBase):
    def __init__(self, client=None, config_file=None):
        if config_file is None:
            config_file = {}

        VannaBase.__init__(self, config=config_file)

        # Default parameters - can be overridden via config
        self.temperature = 0.5
        self.max_tokens = 5000

        if "temperature" in config_file:
            self.temperature = config_file["temperature"]

        if "max_tokens" in config_file:
            self.max_tokens = config_file["max_tokens"]

        if "api_type" in config_file:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if "api_version" in config_file:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        # An explicit client takes precedence over api_base/api_key in the config
        if client is not None:
            self.client = client
            return

        if "api_base" not in config_file:
            raise Exception("Please pass api_base in the config")

        if "api_key" not in config_file:
            raise Exception("Please pass api_key in the config")

        self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"])

    def system_message(self, message: str) -> dict:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")
        print(prompt)

        # Rough token estimate: ~4 characters per token
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # Model resolution order: explicit kwarg, then config, then a size-based default
        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            print(
                f"Using model {model} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)
            print(
                f"Using engine {engine} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=engine,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "engine" in self.config:
            print(
                f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=self.config["engine"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=self.config["model"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        else:
            # No model configured: pick a default based on the prompt size
            if num_tokens > 3500:
                model = "kimi"
            else:
                model = "doubao"

            print(f"Using model {model} for {num_tokens} tokens (approx)")

            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )

        # Prefer a completion-style `text` field if present; otherwise return the chat message content
        for choice in response.choices:
            if "text" in choice:
                return choice.text

        return response.choices[0].message.content

    def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
        # Retrieve similar question/SQL pairs, related DDL and documentation from the vector store
        question_sql_list = self.get_similar_question_sql(question, **kwargs)
        ddl_list = self.get_related_ddl(question, **kwargs)
        doc_list = self.get_related_documentation(question, **kwargs)
        template = get_base_template()
        sql_temp = template['template']['sql']
        char_temp = template['template']['chart']
        # -------- Generate the SQL and the chart type from the prompt
        sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list,
                                             data_training=question_sql_list)
        user_temp = sql_temp['user'].format(question=question,
                                            current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        llm_response = self.submit_prompt(
            [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
        print(llm_response)
        result = {"resp": orjson.loads(extract_nested_json(llm_response))}

        sql = check_and_get_sql(llm_response)
        # --------------- Generate the chart configuration
        char_type = get_chart_type_from_sql_answer(llm_response)
        if char_type:
            sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type)
            user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
            llm_response2 = self.submit_prompt(
                [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
            print(llm_response2)
            result['chart'] = orjson.loads(extract_nested_json(llm_response2))
        return result


class CustomQdrant_VectorStore(Qdrant_VectorStore):
    def __init__(
            self,
            config_file=None
    ):
        # Embedding service settings come from environment variables via python-decouple
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        super().__init__(config_file or {})

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        def _get_error_string(response: requests.Response) -> str:
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "input": data,
        }
        request_body.update(kwargs)

        # Call an OpenAI-compatible /v1/embeddings endpoint
        response = requests.post(
            url=f"{self.embedding_api_base}/v1/embeddings",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}"},
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )
        result = response.json()
        embeddings = [d["embedding"] for d in result["data"]]
        return embeddings[0]


class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Combines the Qdrant-backed vector store with the OpenAI-compatible LLM."""

    def __init__(self, llm_config=None, vector_store_config=None):
        CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
        OpenAICompatibleLLM.__init__(self, config_file=llm_config)

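
# --- Usage sketch (illustrative only) ---
# A minimal example of wiring CustomVanna together. The endpoint URLs, API key,
# model names and Qdrant location below are placeholders/assumptions, not values
# defined by this module; adjust them to your own environment. The "client" key
# is one way to configure vanna's Qdrant_VectorStore; accepted keys depend on
# the vanna version in use.
if __name__ == "__main__":
    vn = CustomVanna(
        llm_config={
            "api_base": "http://localhost:8000/v1",  # assumed OpenAI-compatible endpoint
            "api_key": "sk-placeholder",             # placeholder key
            "model": "doubao",                       # model picked up by submit_prompt via self.config
            "temperature": 0.3,
            "max_tokens": 4000,
        },
        vector_store_config={
            "client": QdrantClient(url="http://localhost:6333"),  # assumed local Qdrant instance
        },
    )

    # Generate SQL (and optionally a chart spec) for a natural-language question
    print(vn.generate_sql_2("What were the total sales per month?"))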