批量添加ddl到向量数据库,添加documention,重写generate_rewritten_question
This commit is contained in:
		
							
								
								
									
										12
									
								
								.env
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								.env
									
									
									
									
									
								
							| @@ -9,14 +9,14 @@ EMBEDDING_MODEL_API_KEY=sk-cjiakyfpzamtxgxitcbxvwyvaulnygmyxqpykkgngsvfqhuy | ||||
| EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-8B | ||||
|  | ||||
| #mysql ,sqlite,pg等 | ||||
| DATA_SOURCE_TYPE=sqlite | ||||
| DATA_SOURCE_TYPE=mysql | ||||
|  | ||||
| #sqlite 连接信息 | ||||
| SQLITE_DATABASE_URL=E://db/db_flights.sqlite | ||||
|  | ||||
| #mysql 连接信息 | ||||
| MYSQL_DATABASE_HOST= | ||||
| MYSQL_DATABASE_PORT= | ||||
| MYSQL_DATABASE_PASSWORD= | ||||
| MYSQL_DATABASE_USER= | ||||
| MYSQL_DATABASE_DBNAME= | ||||
| MYSQL_DATABASE_HOST=192.168.31.115 | ||||
| MYSQL_DATABASE_PORT=3306 | ||||
| MYSQL_DATABASE_PASSWORD=Admin1234! | ||||
| MYSQL_DATABASE_USER=yu | ||||
| MYSQL_DATABASE_DBNAME=test | ||||
|   | ||||
| @@ -82,6 +82,7 @@ def generate_sql_2(): | ||||
|             text: | ||||
|               type: string | ||||
|     """ | ||||
|  | ||||
|     question = flask.request.args.get("question") | ||||
|  | ||||
|     if question is None: | ||||
|   | ||||
| @@ -3,6 +3,7 @@ from typing import List | ||||
|  | ||||
| import orjson | ||||
| from vanna.base import VannaBase | ||||
| from vanna.flask import MemoryCache | ||||
| from vanna.qdrant import Qdrant_VectorStore | ||||
| from qdrant_client import QdrantClient | ||||
| from openai import OpenAI | ||||
| @@ -17,6 +18,7 @@ from datetime import datetime | ||||
| class OpenAICompatibleLLM(VannaBase): | ||||
|     def __init__(self, client=None, config_file=None): | ||||
|         VannaBase.__init__(self, config=config_file) | ||||
|         self.cache = MemoryCache() | ||||
|  | ||||
|         # default parameters - can be overrided using config | ||||
|         self.temperature = 0.5 | ||||
| @@ -168,6 +170,8 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|                 [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) | ||||
|             print(llm_response2) | ||||
|             result['chart'] = orjson.loads(extract_nested_json(llm_response2)) | ||||
|         result['id'] = self.cache.generate_id(question=question) | ||||
|         print("result----------------------{0}".format(result)) | ||||
|         return result | ||||
|  | ||||
|     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: | ||||
|   | ||||
| @@ -14,13 +14,13 @@ from service.cus_vanna_srevice import CustomVanna | ||||
| #                  ) | ||||
| #     """, | ||||
| # ] | ||||
| # list_documentions = [ | ||||
| #     """ | ||||
| #     gender 字段 0代表女性,1代表男性; | ||||
| #     查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; | ||||
| #     语法为sqlite语法; | ||||
| #     """, | ||||
| # ] | ||||
| list_documentions = [ | ||||
|     """ | ||||
|     gender 字段 0代表女性,1代表男性; | ||||
|     查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; | ||||
|     语法为sqlite语法; | ||||
|     """, | ||||
| ] | ||||
| table_ddls = [ | ||||
|     """ | ||||
|     CREATE TABLE 人员库表 ( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128