feat: initialization
Files added in this commit (2 files, 217 additions, 0 deletions):
  service/__init__.py            new file, 0 lines
  service/cus_vanna_srevice.py   new file, 217 lines
service/cus_vanna_srevice.py (new file)
@@ -0,0 +1,217 @@
from typing import List

import orjson
from vanna.base import VannaBase
from vanna.qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from openai import OpenAI
import requests
from decouple import config
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
from template.template import get_base_template
from datetime import datetime


class OpenAICompatibleLLM(VannaBase):
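    """Vanna LLM backend for any OpenAI-compatible chat-completions API.

    Configuration keys read from ``config_file``: ``api_base``, ``api_key``,
    ``model``/``engine``, ``temperature`` and ``max_tokens``. Alternatively,
    a pre-built OpenAI client can be passed in directly.
    """
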
    def __init__(self, client=None, config_file=None):
        VannaBase.__init__(self, config=config_file)

        # Guard against a missing config so the membership checks below are safe.
        config_file = config_file or {}

        # Default parameters - can be overridden via the config.
        self.temperature = 0.5
        self.max_tokens = 5000

        if "temperature" in config_file:
            self.temperature = config_file["temperature"]

        if "max_tokens" in config_file:
            self.max_tokens = config_file["max_tokens"]

        if "api_type" in config_file:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if "api_version" in config_file:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if "api_base" not in config_file:
            raise Exception("Please pass api_base in the config")

        if "api_key" not in config_file:
            raise Exception("Please pass api_key in the config")

        self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"])
    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}
    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")
        print(prompt)

        # Rough token estimate: ~4 characters per token.
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # Model resolution order: explicit kwarg model/engine, then config
        # engine/model, then a size-based fallback.
        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            print(
                f"Using model {model} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)
            print(
                f"Using engine {engine} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=engine,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "engine" in self.config:
            print(
                f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=self.config["engine"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=self.config["model"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        else:
            if num_tokens > 3500:
                model = "kimi"
            else:
                model = "doubao"

            print(f"Using model {model} for {num_tokens} tokens (approx)")

            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )

        # Prefer a completion-style text field if one is present,
        # otherwise fall back to the first chat message content.
        for choice in response.choices:
            if getattr(choice, "text", None):
                return choice.text

        return response.choices[0].message.content

    def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
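        """Generate SQL (and optionally a chart spec) for ``question``.

        Retrieves similar question/SQL pairs, related DDL and documentation
        from the vector store, fills the prompt templates, and asks the LLM
        for a SQL answer; if the answer implies a chart type, a second call
        produces the chart definition.
        """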
        question_sql_list = self.get_similar_question_sql(question, **kwargs)
        ddl_list = self.get_related_ddl(question, **kwargs)
        doc_list = self.get_related_documentation(question, **kwargs)
        template = get_base_template()
        sql_temp = template['template']['sql']
        char_temp = template['template']['chart']
        # -------- Generate the SQL and the chart type from the prompt templates
        sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list,
                                             data_training=question_sql_list)
        user_temp = sql_temp['user'].format(question=question,
                                            current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        llm_response = self.submit_prompt(
            [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
        print(llm_response)
        result = {"resp": orjson.loads(extract_nested_json(llm_response))}

        sql = check_and_get_sql(llm_response)
        # --------------- Generate the chart
        char_type = get_chart_type_from_sql_answer(llm_response)
        if char_type:
            sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type)
            user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
            llm_response2 = self.submit_prompt(
                [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
            print(llm_response2)
            result['chart'] = orjson.loads(extract_nested_json(llm_response2))
        return result


class CustomQdrant_VectorStore(Qdrant_VectorStore):
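    """Qdrant-backed vector store whose embeddings come from an external
    OpenAI-compatible /v1/embeddings endpoint, configured through the
    EMBEDDING_MODEL_NAME / EMBEDDING_MODEL_BASE_URL / EMBEDDING_MODEL_API_KEY
    environment variables (read with python-decouple).
    """
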
    def __init__(
            self,
            config_file=None
    ):
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        super().__init__(config_file or {})

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
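        """Embed ``data`` by POSTing to the configured /v1/embeddings endpoint
        and return the first embedding vector from the response.
        """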
        def _get_error_string(response: requests.Response) -> str:
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "input": data,
        }
        request_body.update(kwargs)

        response = requests.post(
            url=f"{self.embedding_api_base}/v1/embeddings",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}"},
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )
        result = response.json()
        embeddings = [d["embedding"] for d in result["data"]]
        return embeddings[0]


class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
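    """Concrete Vanna class assembled from the custom Qdrant vector store and
    the OpenAI-compatible LLM defined above.
    """
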
    def __init__(self, llm_config=None, vector_store_config=None):
        CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
        OpenAICompatibleLLM.__init__(self, config_file=llm_config)
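
Not part of this commit: a minimal wiring sketch for CustomVanna. The endpoint URL, the placeholder API key, the "doubao" model name and the "client" key for the Qdrant vector-store config are assumptions for illustration only; the actual keys depend on vanna's Qdrant_VectorStore and on the deployment.

# Hypothetical usage sketch (assumptions noted above).
from qdrant_client import QdrantClient
from service.cus_vanna_srevice import CustomVanna

vn = CustomVanna(
    llm_config={
        "api_base": "http://localhost:8000/v1",  # any OpenAI-compatible endpoint (assumed)
        "api_key": "sk-...",                     # placeholder key
        "model": "doubao",                       # assumed model name
        "temperature": 0.5,
        "max_tokens": 5000,
    },
    vector_store_config={"client": QdrantClient(":memory:")},  # assumed config key
)

# Returns {"resp": ...} and, when a chart type is detected, also "chart".
result = vn.generate_sql_2("How many orders were placed each month?")
print(result)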