批量添加ddl到向量数据库,添加documention,重写generate_rewritten_question
添加id,sql,question到上下文
This commit is contained in:
		| @@ -55,8 +55,7 @@ from vanna.flask import VannaFlaskApp | |||||||
| vn = create_vana() | vn = create_vana() | ||||||
| app = VannaFlaskApp(vn,chart=False) | app = VannaFlaskApp(vn,chart=False) | ||||||
| init_vn(vn) | init_vn(vn) | ||||||
|  | cache = app.cache | ||||||
|  |  | ||||||
| @app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"]) | @app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"]) | ||||||
| def generate_sql_2(): | def generate_sql_2(): | ||||||
|     """ |     """ | ||||||
| @@ -87,10 +86,14 @@ def generate_sql_2(): | |||||||
|  |  | ||||||
|     if question is None: |     if question is None: | ||||||
|         return jsonify({"type": "error", "error": "No question provided"}) |         return jsonify({"type": "error", "error": "No question provided"}) | ||||||
|  |     id = cache.generate_id(question=question) | ||||||
|     #id = self.cache.generate_id(question=question) |  | ||||||
|     data = vn.generate_sql_2(question=question) |     data = vn.generate_sql_2(question=question) | ||||||
|  |     data['id'] =id | ||||||
|  |     sql = data["resp"]["sql"] | ||||||
|  |     print("sql:",sql) | ||||||
|  |     cache.set(id=id, field="question", value=question) | ||||||
|  |     cache.set(id=id, field="sql", value=sql) | ||||||
|  |     print("data---------------------------",data) | ||||||
|  |  | ||||||
|     return jsonify(data) |     return jsonify(data) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ from email.policy import default | |||||||
| from typing import List | from typing import List | ||||||
|  |  | ||||||
| import orjson | import orjson | ||||||
|  | import pandas as pd | ||||||
| from vanna.base import VannaBase | from vanna.base import VannaBase | ||||||
| from vanna.flask import MemoryCache | from vanna.flask import MemoryCache | ||||||
| from vanna.qdrant import Qdrant_VectorStore | from vanna.qdrant import Qdrant_VectorStore | ||||||
| @@ -18,8 +19,6 @@ from datetime import datetime | |||||||
| class OpenAICompatibleLLM(VannaBase): | class OpenAICompatibleLLM(VannaBase): | ||||||
|     def __init__(self, client=None, config_file=None): |     def __init__(self, client=None, config_file=None): | ||||||
|         VannaBase.__init__(self, config=config_file) |         VannaBase.__init__(self, config=config_file) | ||||||
|         self.cache = MemoryCache() |  | ||||||
|  |  | ||||||
|         # default parameters - can be overrided using config |         # default parameters - can be overrided using config | ||||||
|         self.temperature = 0.5 |         self.temperature = 0.5 | ||||||
|         self.max_tokens = 5000 |         self.max_tokens = 5000 | ||||||
| @@ -170,8 +169,6 @@ class OpenAICompatibleLLM(VannaBase): | |||||||
|                 [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) |                 [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) | ||||||
|             print(llm_response2) |             print(llm_response2) | ||||||
|             result['chart'] = orjson.loads(extract_nested_json(llm_response2)) |             result['chart'] = orjson.loads(extract_nested_json(llm_response2)) | ||||||
|         result['id'] = self.cache.generate_id(question=question) |  | ||||||
|         print("result----------------------{0}".format(result)) |  | ||||||
|         return result |         return result | ||||||
|  |  | ||||||
|     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: |     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: | ||||||
|   | |||||||
| @@ -14,13 +14,13 @@ from service.cus_vanna_srevice import CustomVanna | |||||||
| #                  ) | #                  ) | ||||||
| #     """, | #     """, | ||||||
| # ] | # ] | ||||||
| list_documentions = [ | # list_documentions = [ | ||||||
|     """ | #     """ | ||||||
|     gender 字段 0代表女性,1代表男性; | #     gender 字段 0代表女性,1代表男性; | ||||||
|     查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; | #     查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; | ||||||
|     语法为sqlite语法; | #     语法为sqlite语法; | ||||||
|     """, | #     """, | ||||||
| ] | # ] | ||||||
| table_ddls = [ | table_ddls = [ | ||||||
|     """ |     """ | ||||||
|     CREATE TABLE 人员库表 ( |     CREATE TABLE 人员库表 ( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128