提示词优化,日志添加
This commit is contained in:
		
							
								
								
									
										62
									
								
								logging_config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								logging_config.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | |||||||
|  | # logging_config.py | ||||||
|  | import logging | ||||||
|  | import logging.config | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | # 确保 logs 目录存在 | ||||||
|  | log_dir = Path("logs") | ||||||
|  | log_dir.mkdir(exist_ok=True) | ||||||
|  |  | ||||||
|  | LOGGING_CONFIG = { | ||||||
|  |     "version": 1, | ||||||
|  |     "disable_existing_loggers": False, | ||||||
|  |     "formatters": { | ||||||
|  |         "default": { | ||||||
|  |             "format": "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", | ||||||
|  |         }, | ||||||
|  |         "detailed": { | ||||||
|  |             "format": "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s", | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|  |     "handlers": { | ||||||
|  |         "console": { | ||||||
|  |             "class": "logging.StreamHandler", | ||||||
|  |             "level": "INFO", | ||||||
|  |             "formatter": "default", | ||||||
|  |             "stream": "ext://sys.stdout" | ||||||
|  |         }, | ||||||
|  |         "file": { | ||||||
|  |             "class": "logging.handlers.RotatingFileHandler",  # 自动轮转 | ||||||
|  |             "level": "INFO", | ||||||
|  |             "formatter": "detailed", | ||||||
|  |             "filename": "logs/sqlbot.log", | ||||||
|  |             "maxBytes": 10485760,  # 10MB | ||||||
|  |             "backupCount": 5,      # 保留5个备份 | ||||||
|  |             "encoding": "utf8" | ||||||
|  |         }, | ||||||
|  |     }, | ||||||
|  |     "root": { | ||||||
|  |         "level": "INFO", | ||||||
|  |         "handlers": ["console", "file"] | ||||||
|  |     }, | ||||||
|  |     "loggers": { | ||||||
|  |         "uvicorn": { | ||||||
|  |             "level": "INFO", | ||||||
|  |             "handlers": ["console", "file"], | ||||||
|  |             "propagate": False | ||||||
|  |         }, | ||||||
|  |         "uvicorn.error": { | ||||||
|  |             "level": "INFO", | ||||||
|  |             "handlers": ["console", "file"], | ||||||
|  |             "propagate": False | ||||||
|  |         }, | ||||||
|  |         "uvicorn.access": { | ||||||
|  |             "level": "WARNING",  # 只记录警告以上,避免刷屏 | ||||||
|  |             "handlers": ["file"],  # 只写入文件 | ||||||
|  |             "propagate": False | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # 应用配置 | ||||||
|  | logging.config.dictConfig(LOGGING_CONFIG) | ||||||
| @@ -1,5 +1,8 @@ | |||||||
| from email.policy import default | from email.policy import default | ||||||
|  | import dmPython | ||||||
|  | import logging | ||||||
|  |  | ||||||
|  | from logging_config import LOGGING_CONFIG | ||||||
| from service.cus_vanna_srevice import CustomVanna, QdrantClient | from service.cus_vanna_srevice import CustomVanna, QdrantClient | ||||||
| from decouple import config | from decouple import config | ||||||
| import flask | import flask | ||||||
| @@ -7,6 +10,8 @@ from util import load_ddl_doc | |||||||
|  |  | ||||||
| from flask import Flask, Response, jsonify, request, send_from_directory | from flask import Flask, Response, jsonify, request, send_from_directory | ||||||
|  |  | ||||||
|  |  | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
| def connect_database(vn): | def connect_database(vn): | ||||||
|     db_type = config('DATA_SOURCE_TYPE', default='sqlite') |     db_type = config('DATA_SOURCE_TYPE', default='sqlite') | ||||||
|     if db_type == 'sqlite': |     if db_type == 'sqlite': | ||||||
| @@ -17,6 +22,8 @@ def connect_database(vn): | |||||||
|                             user=config('MYSQL_DATABASE_USER', default=''), |                             user=config('MYSQL_DATABASE_USER', default=''), | ||||||
|                             password=config('MYSQL_DATABASE_PASSWORD', default=''), |                             password=config('MYSQL_DATABASE_PASSWORD', default=''), | ||||||
|                             dbname=config('MYSQL_DATABASE_DBNAME', default='')) |                             dbname=config('MYSQL_DATABASE_DBNAME', default='')) | ||||||
|  |     elif db_type == 'dameng': | ||||||
|  |         vn.connect_to_dameng( ) | ||||||
|     elif db_type == 'postgresql': |     elif db_type == 'postgresql': | ||||||
|         # 待补充 |         # 待补充 | ||||||
|         pass |         pass | ||||||
| @@ -81,22 +88,78 @@ def generate_sql_2(): | |||||||
|             text: |             text: | ||||||
|               type: string |               type: string | ||||||
|     """ |     """ | ||||||
|  |     logger.info("Start to generate sql in main") | ||||||
|     question = flask.request.args.get("question") |     question = flask.request.args.get("question") | ||||||
|  |  | ||||||
|     if question is None: |     if question is None: | ||||||
|         return jsonify({"type": "error", "error": "No question provided"}) |         return jsonify({"type": "error", "error": "No question provided"}) | ||||||
|     id = cache.generate_id(question=question) |     try: | ||||||
|     data = vn.generate_sql_2(question=question) |         id = cache.generate_id(question=question) | ||||||
|     data['id'] =id |         data = vn.generate_sql_2(question=question) | ||||||
|     sql = data["resp"]["sql"] |         data['id'] = id | ||||||
|     print("sql:",sql) |         sql = data["resp"]["sql"] | ||||||
|     cache.set(id=id, field="question", value=question) |         print("sql:", sql) | ||||||
|     cache.set(id=id, field="sql", value=sql) |         cache.set(id=id, field="question", value=question) | ||||||
|     print("data---------------------------",data) |         cache.set(id=id, field="sql", value=sql) | ||||||
|  |         print("data---------------------------", data) | ||||||
|  |         return jsonify(data) | ||||||
|  |     except Exception as e: | ||||||
|  |         return jsonify({"type": "error", "error": str(e)}) | ||||||
|  |  | ||||||
|     return jsonify(data) |  | ||||||
|  |  | ||||||
|  | @app.flask_app.route("/api/v0/run_sql_2", methods=["GET"]) | ||||||
|  | @app.requires_cache(["sql"]) | ||||||
|  | def run_sql_2(id: str, sql: str): | ||||||
|  |     """ | ||||||
|  |     Run SQL | ||||||
|  |     --- | ||||||
|  |     parameters: | ||||||
|  |       - name: user | ||||||
|  |         in: query | ||||||
|  |       - name: id | ||||||
|  |         in: query|body | ||||||
|  |         type: string | ||||||
|  |         required: true | ||||||
|  |     responses: | ||||||
|  |       200: | ||||||
|  |         schema: | ||||||
|  |           type: object | ||||||
|  |           properties: | ||||||
|  |             type: | ||||||
|  |               type: string | ||||||
|  |               default: df | ||||||
|  |             id: | ||||||
|  |               type: string | ||||||
|  |             df: | ||||||
|  |               type: object | ||||||
|  |             should_generate_chart: | ||||||
|  |               type: boolean | ||||||
|  |     """ | ||||||
|  |     logger.info("Start to run sql in main") | ||||||
|  |     try: | ||||||
|  |         if not vn.run_sql_is_set: | ||||||
|  |             return jsonify( | ||||||
|  |                 { | ||||||
|  |                     "type": "error", | ||||||
|  |                     "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.", | ||||||
|  |                 } | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         df = vn.run_sql(sql=sql) | ||||||
|  |         logger.info("") | ||||||
|  |         app.cache.set(id=id, field="df", value=df) | ||||||
|  |         x = df.head(10).to_dict(orient='records') | ||||||
|  |         logger.info("df ---------------{0}   {1}".format(x,type(x))) | ||||||
|  |         return jsonify( | ||||||
|  |             { | ||||||
|  |                 "type": "df", | ||||||
|  |                 "id": id, | ||||||
|  |                 "df": df.head(10).to_dict(orient='records'), | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         return jsonify({"type": "sql_error", "error": str(e)}) | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| from email.policy import default | from email.policy import default | ||||||
| from typing import List | from typing import List | ||||||
|  | import dmPython | ||||||
| import orjson | import orjson | ||||||
| import pandas as pd | import pandas as pd | ||||||
| from vanna.base import VannaBase | from vanna.base import VannaBase | ||||||
| @@ -22,7 +22,7 @@ class OpenAICompatibleLLM(VannaBase): | |||||||
|         # default parameters - can be overrided using config |         # default parameters - can be overrided using config | ||||||
|         self.temperature = 0.5 |         self.temperature = 0.5 | ||||||
|         self.max_tokens = 5000 |         self.max_tokens = 5000 | ||||||
|  |         self.conn = None | ||||||
|         if "temperature" in config_file: |         if "temperature" in config_file: | ||||||
|             self.temperature = config_file["temperature"] |             self.temperature = config_file["temperature"] | ||||||
|  |  | ||||||
| @@ -140,36 +140,57 @@ class OpenAICompatibleLLM(VannaBase): | |||||||
|  |  | ||||||
|         return response.choices[0].message.content |         return response.choices[0].message.content | ||||||
|  |  | ||||||
|  |     # def connect_to_dameng(self, host, port, username, password, database): | ||||||
|  |     #     try: | ||||||
|  |     #         self.conn = dmPython.connect( | ||||||
|  |     #             user=username, | ||||||
|  |     #             password=password, | ||||||
|  |     #             server=host, | ||||||
|  |     #             port=port,  # 达梦默认端口5236 | ||||||
|  |     #             autoCommit=True | ||||||
|  |     #         ) | ||||||
|  |     #         print("达梦数据库连接成功") | ||||||
|  |     #         return True | ||||||
|  |     #     except Exception as e: | ||||||
|  |     #         print(f"连接失败: {e}") | ||||||
|  |     #         return False | ||||||
|  |  | ||||||
|     def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: |     def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: | ||||||
|         question_sql_list = self.get_similar_question_sql(question, **kwargs) |         try: | ||||||
|         ddl_list = self.get_related_ddl(question, **kwargs) |             question_sql_list = self.get_similar_question_sql(question, **kwargs) | ||||||
|         doc_list = self.get_related_documentation(question, **kwargs) |             ddl_list = self.get_related_ddl(question, **kwargs) | ||||||
|         template = get_base_template() |             doc_list = self.get_related_documentation(question, **kwargs) | ||||||
|         sql_temp = template['template']['sql'] |             template = get_base_template() | ||||||
|         char_temp = template['template']['chart'] |             sql_temp = template['template']['sql'] | ||||||
|         # --------基于提示词,生成sql以及图表类型 |             char_temp = template['template']['chart'] | ||||||
|         sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list, |             # --------基于提示词,生成sql以及图表类型 | ||||||
|                                              data_training=question_sql_list) |             sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), lang='中文', | ||||||
|         print("sys_temp", sys_temp) |                                                  schema=ddl_list, documentation=doc_list, | ||||||
|         user_temp = sql_temp['user'].format(question=question, |                                                  data_training=question_sql_list) | ||||||
|                                             current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) |             print("sys_temp", sys_temp) | ||||||
|         print("user_temp", user_temp) |             user_temp = sql_temp['user'].format(question=question, | ||||||
|         llm_response = self.submit_prompt( |                                                 current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) | ||||||
|             [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) |             print("user_temp", user_temp) | ||||||
|         print(llm_response) |             llm_response = self.submit_prompt( | ||||||
|         result = {"resp": orjson.loads(extract_nested_json(llm_response))} |                 [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) | ||||||
|         print("result", result) |             print(llm_response) | ||||||
|         sql = check_and_get_sql(llm_response) |             result = {"resp": orjson.loads(extract_nested_json(llm_response))} | ||||||
|         # ---------------生成图表 |             print("result", result) | ||||||
|         char_type = get_chart_type_from_sql_answer(llm_response) |             sql = check_and_get_sql(llm_response) | ||||||
|         if char_type: |             # ---------------生成图表 | ||||||
|             sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type) |             char_type = get_chart_type_from_sql_answer(llm_response) | ||||||
|             user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) |             if char_type: | ||||||
|             llm_response2 = self.submit_prompt( |                 sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), | ||||||
|                 [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) |                                                            lang='中文', sql=sql, chart_type=char_type) | ||||||
|             print(llm_response2) |                 user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) | ||||||
|             result['chart'] = orjson.loads(extract_nested_json(llm_response2)) |                 llm_response2 = self.submit_prompt( | ||||||
|         return result |                     [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], | ||||||
|  |                     **kwargs) | ||||||
|  |                 print(llm_response2) | ||||||
|  |                 result['chart'] = orjson.loads(extract_nested_json(llm_response2)) | ||||||
|  |             return result | ||||||
|  |         except Exception as e: | ||||||
|  |             raise e | ||||||
|  |  | ||||||
|     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: |     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: | ||||||
|         print("new_question---------------", new_question) |         print("new_question---------------", new_question) | ||||||
|   | |||||||
| @@ -27,6 +27,9 @@ template: | |||||||
|         <rule> |         <rule> | ||||||
|           你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL |           你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL | ||||||
|         </rule> |         </rule> | ||||||
|  |         <rule> | ||||||
|  |           如果只涉及查询人员信息,但没说具体哪些信息,可以不用查询所有信息,主要查询相关性较强的五个字段即可 | ||||||
|  |         </rule> | ||||||
|         <rule> |         <rule> | ||||||
|           不要编造<m-schema>内没有提供给你的表结构 |           不要编造<m-schema>内没有提供给你的表结构 | ||||||
|         </rule> |         </rule> | ||||||
| @@ -39,6 +42,9 @@ template: | |||||||
|         <rule> |         <rule> | ||||||
|           请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别 |           请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别 | ||||||
|         </rule> |         </rule> | ||||||
|  |         <rule> | ||||||
|  |           如遇字符串类型的日期要计算时,务必转化为合理的格式进行计算 | ||||||
|  |         </rule> | ||||||
|         <rule> |         <rule> | ||||||
|           请使用JSON格式返回你的回答: |           请使用JSON格式返回你的回答: | ||||||
|           若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} |           若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} | ||||||
|   | |||||||
| @@ -103,6 +103,12 @@ list_documentions = [ | |||||||
|     """ |     """ | ||||||
|     <人员库表注意事项> |     <人员库表注意事项> | ||||||
|         <rule> |         <rule> | ||||||
|  |             查询address时,尽量使用like查询,如:select * from 人员库 where address like '%张三%'; | ||||||
|  |             语法为mysql语法; | ||||||
|  |             如果涉及下面<info>中的字段需要展示给用户看时请替换成相关代表 | ||||||
|  |             birthday 字段涉及计算时,请转化为合理格式计算 | ||||||
|  |         </rule> | ||||||
|  |         <info> | ||||||
|             person_status 字段 1代表草稿,2代表审批中,3代表制卡中,4代表已入库,5代表停用; |             person_status 字段 1代表草稿,2代表审批中,3代表制卡中,4代表已入库,5代表停用; | ||||||
|             gender 字段 1代表男,2代表女 |             gender 字段 1代表男,2代表女 | ||||||
|             is_internal 字段 0代表否,1代表是 |             is_internal 字段 0代表否,1代表是 | ||||||
| @@ -115,9 +121,7 @@ list_documentions = [ | |||||||
|             is_subcontractor 字段 0代表否,1代表是 |             is_subcontractor 字段 0代表否,1代表是 | ||||||
|             is_sign_confidentiality_agreement 字段 0代表否,1代表是 |             is_sign_confidentiality_agreement 字段 0代表否,1代表是 | ||||||
|             DHDATASTA 字段 0代表新增 1代表更新 |             DHDATASTA 字段 0代表新增 1代表更新 | ||||||
|             查询address时,尽量使用like查询,如:select * from 人员库 where address like '%张三%'; |         </info> | ||||||
|             语法为mysql语法; |  | ||||||
|         </rule> |  | ||||||
|     </人员库表注意事项> |     </人员库表注意事项> | ||||||
|     """, |     """, | ||||||
| ] | ] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128