提示词优化,日志添加
This commit is contained in:
		
							
								
								
									
										62
									
								
								logging_config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								logging_config.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | ||||
| # logging_config.py | ||||
| import logging | ||||
| import logging.config | ||||
| from pathlib import Path | ||||
|  | ||||
| # 确保 logs 目录存在 | ||||
| log_dir = Path("logs") | ||||
| log_dir.mkdir(exist_ok=True) | ||||
|  | ||||
| LOGGING_CONFIG = { | ||||
|     "version": 1, | ||||
|     "disable_existing_loggers": False, | ||||
|     "formatters": { | ||||
|         "default": { | ||||
|             "format": "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", | ||||
|         }, | ||||
|         "detailed": { | ||||
|             "format": "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s", | ||||
|         } | ||||
|     }, | ||||
|     "handlers": { | ||||
|         "console": { | ||||
|             "class": "logging.StreamHandler", | ||||
|             "level": "INFO", | ||||
|             "formatter": "default", | ||||
|             "stream": "ext://sys.stdout" | ||||
|         }, | ||||
|         "file": { | ||||
|             "class": "logging.handlers.RotatingFileHandler",  # 自动轮转 | ||||
|             "level": "INFO", | ||||
|             "formatter": "detailed", | ||||
|             "filename": "logs/sqlbot.log", | ||||
|             "maxBytes": 10485760,  # 10MB | ||||
|             "backupCount": 5,      # 保留5个备份 | ||||
|             "encoding": "utf8" | ||||
|         }, | ||||
|     }, | ||||
|     "root": { | ||||
|         "level": "INFO", | ||||
|         "handlers": ["console", "file"] | ||||
|     }, | ||||
|     "loggers": { | ||||
|         "uvicorn": { | ||||
|             "level": "INFO", | ||||
|             "handlers": ["console", "file"], | ||||
|             "propagate": False | ||||
|         }, | ||||
|         "uvicorn.error": { | ||||
|             "level": "INFO", | ||||
|             "handlers": ["console", "file"], | ||||
|             "propagate": False | ||||
|         }, | ||||
|         "uvicorn.access": { | ||||
|             "level": "WARNING",  # 只记录警告以上,避免刷屏 | ||||
|             "handlers": ["file"],  # 只写入文件 | ||||
|             "propagate": False | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| # 应用配置 | ||||
| logging.config.dictConfig(LOGGING_CONFIG) | ||||
| @@ -1,5 +1,8 @@ | ||||
| from email.policy import default | ||||
| import dmPython | ||||
| import logging | ||||
|  | ||||
| from logging_config import LOGGING_CONFIG | ||||
| from service.cus_vanna_srevice import CustomVanna, QdrantClient | ||||
| from decouple import config | ||||
| import flask | ||||
| @@ -7,6 +10,8 @@ from util import load_ddl_doc | ||||
|  | ||||
| from flask import Flask, Response, jsonify, request, send_from_directory | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
| def connect_database(vn): | ||||
|     db_type = config('DATA_SOURCE_TYPE', default='sqlite') | ||||
|     if db_type == 'sqlite': | ||||
| @@ -17,6 +22,8 @@ def connect_database(vn): | ||||
|                             user=config('MYSQL_DATABASE_USER', default=''), | ||||
|                             password=config('MYSQL_DATABASE_PASSWORD', default=''), | ||||
|                             dbname=config('MYSQL_DATABASE_DBNAME', default='')) | ||||
|     elif db_type == 'dameng': | ||||
|         vn.connect_to_dameng( ) | ||||
|     elif db_type == 'postgresql': | ||||
|         # 待补充 | ||||
|         pass | ||||
| @@ -81,22 +88,78 @@ def generate_sql_2(): | ||||
|             text: | ||||
|               type: string | ||||
|     """ | ||||
|  | ||||
|     logger.info("Start to generate sql in main") | ||||
|     question = flask.request.args.get("question") | ||||
|  | ||||
|     if question is None: | ||||
|         return jsonify({"type": "error", "error": "No question provided"}) | ||||
|     id = cache.generate_id(question=question) | ||||
|     data = vn.generate_sql_2(question=question) | ||||
|     data['id'] =id | ||||
|     sql = data["resp"]["sql"] | ||||
|     print("sql:",sql) | ||||
|     cache.set(id=id, field="question", value=question) | ||||
|     cache.set(id=id, field="sql", value=sql) | ||||
|     print("data---------------------------",data) | ||||
|     try: | ||||
|         id = cache.generate_id(question=question) | ||||
|         data = vn.generate_sql_2(question=question) | ||||
|         data['id'] = id | ||||
|         sql = data["resp"]["sql"] | ||||
|         print("sql:", sql) | ||||
|         cache.set(id=id, field="question", value=question) | ||||
|         cache.set(id=id, field="sql", value=sql) | ||||
|         print("data---------------------------", data) | ||||
|         return jsonify(data) | ||||
|     except Exception as e: | ||||
|         return jsonify({"type": "error", "error": str(e)}) | ||||
|  | ||||
|     return jsonify(data) | ||||
|  | ||||
| @app.flask_app.route("/api/v0/run_sql_2", methods=["GET"]) | ||||
| @app.requires_cache(["sql"]) | ||||
| def run_sql_2(id: str, sql: str): | ||||
|     """ | ||||
|     Run SQL | ||||
|     --- | ||||
|     parameters: | ||||
|       - name: user | ||||
|         in: query | ||||
|       - name: id | ||||
|         in: query|body | ||||
|         type: string | ||||
|         required: true | ||||
|     responses: | ||||
|       200: | ||||
|         schema: | ||||
|           type: object | ||||
|           properties: | ||||
|             type: | ||||
|               type: string | ||||
|               default: df | ||||
|             id: | ||||
|               type: string | ||||
|             df: | ||||
|               type: object | ||||
|             should_generate_chart: | ||||
|               type: boolean | ||||
|     """ | ||||
|     logger.info("Start to run sql in main") | ||||
|     try: | ||||
|         if not vn.run_sql_is_set: | ||||
|             return jsonify( | ||||
|                 { | ||||
|                     "type": "error", | ||||
|                     "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.", | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|         df = vn.run_sql(sql=sql) | ||||
|         logger.info("") | ||||
|         app.cache.set(id=id, field="df", value=df) | ||||
|         x = df.head(10).to_dict(orient='records') | ||||
|         logger.info("df ---------------{0}   {1}".format(x,type(x))) | ||||
|         return jsonify( | ||||
|             { | ||||
|                 "type": "df", | ||||
|                 "id": id, | ||||
|                 "df": df.head(10).to_dict(orient='records'), | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     except Exception as e: | ||||
|         return jsonify({"type": "sql_error", "error": str(e)}) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| from email.policy import default | ||||
| from typing import List | ||||
|  | ||||
| import dmPython | ||||
| import orjson | ||||
| import pandas as pd | ||||
| from vanna.base import VannaBase | ||||
| @@ -22,7 +22,7 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|         # default parameters - can be overrided using config | ||||
|         self.temperature = 0.5 | ||||
|         self.max_tokens = 5000 | ||||
|  | ||||
|         self.conn = None | ||||
|         if "temperature" in config_file: | ||||
|             self.temperature = config_file["temperature"] | ||||
|  | ||||
| @@ -140,36 +140,57 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|  | ||||
|         return response.choices[0].message.content | ||||
|  | ||||
|     # def connect_to_dameng(self, host, port, username, password, database): | ||||
|     #     try: | ||||
|     #         self.conn = dmPython.connect( | ||||
|     #             user=username, | ||||
|     #             password=password, | ||||
|     #             server=host, | ||||
|     #             port=port,  # 达梦默认端口5236 | ||||
|     #             autoCommit=True | ||||
|     #         ) | ||||
|     #         print("达梦数据库连接成功") | ||||
|     #         return True | ||||
|     #     except Exception as e: | ||||
|     #         print(f"连接失败: {e}") | ||||
|     #         return False | ||||
|  | ||||
|     def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: | ||||
|         question_sql_list = self.get_similar_question_sql(question, **kwargs) | ||||
|         ddl_list = self.get_related_ddl(question, **kwargs) | ||||
|         doc_list = self.get_related_documentation(question, **kwargs) | ||||
|         template = get_base_template() | ||||
|         sql_temp = template['template']['sql'] | ||||
|         char_temp = template['template']['chart'] | ||||
|         # --------基于提示词,生成sql以及图表类型 | ||||
|         sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list, | ||||
|                                              data_training=question_sql_list) | ||||
|         print("sys_temp", sys_temp) | ||||
|         user_temp = sql_temp['user'].format(question=question, | ||||
|                                             current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) | ||||
|         print("user_temp", user_temp) | ||||
|         llm_response = self.submit_prompt( | ||||
|             [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) | ||||
|         print(llm_response) | ||||
|         result = {"resp": orjson.loads(extract_nested_json(llm_response))} | ||||
|         print("result", result) | ||||
|         sql = check_and_get_sql(llm_response) | ||||
|         # ---------------生成图表 | ||||
|         char_type = get_chart_type_from_sql_answer(llm_response) | ||||
|         if char_type: | ||||
|             sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type) | ||||
|             user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) | ||||
|             llm_response2 = self.submit_prompt( | ||||
|                 [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) | ||||
|             print(llm_response2) | ||||
|             result['chart'] = orjson.loads(extract_nested_json(llm_response2)) | ||||
|         return result | ||||
|         try: | ||||
|             question_sql_list = self.get_similar_question_sql(question, **kwargs) | ||||
|             ddl_list = self.get_related_ddl(question, **kwargs) | ||||
|             doc_list = self.get_related_documentation(question, **kwargs) | ||||
|             template = get_base_template() | ||||
|             sql_temp = template['template']['sql'] | ||||
|             char_temp = template['template']['chart'] | ||||
|             # --------基于提示词,生成sql以及图表类型 | ||||
|             sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), lang='中文', | ||||
|                                                  schema=ddl_list, documentation=doc_list, | ||||
|                                                  data_training=question_sql_list) | ||||
|             print("sys_temp", sys_temp) | ||||
|             user_temp = sql_temp['user'].format(question=question, | ||||
|                                                 current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) | ||||
|             print("user_temp", user_temp) | ||||
|             llm_response = self.submit_prompt( | ||||
|                 [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) | ||||
|             print(llm_response) | ||||
|             result = {"resp": orjson.loads(extract_nested_json(llm_response))} | ||||
|             print("result", result) | ||||
|             sql = check_and_get_sql(llm_response) | ||||
|             # ---------------生成图表 | ||||
|             char_type = get_chart_type_from_sql_answer(llm_response) | ||||
|             if char_type: | ||||
|                 sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), | ||||
|                                                            lang='中文', sql=sql, chart_type=char_type) | ||||
|                 user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) | ||||
|                 llm_response2 = self.submit_prompt( | ||||
|                     [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], | ||||
|                     **kwargs) | ||||
|                 print(llm_response2) | ||||
|                 result['chart'] = orjson.loads(extract_nested_json(llm_response2)) | ||||
|             return result | ||||
|         except Exception as e: | ||||
|             raise e | ||||
|  | ||||
|     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: | ||||
|         print("new_question---------------", new_question) | ||||
|   | ||||
| @@ -27,6 +27,9 @@ template: | ||||
|         <rule> | ||||
|           你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果只涉及查询人员信息,但没说具体哪些信息,可以不用查询所有信息,主要查询相关性较强的五个字段即可 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           不要编造<m-schema>内没有提供给你的表结构 | ||||
|         </rule> | ||||
| @@ -39,6 +42,9 @@ template: | ||||
|         <rule> | ||||
|           请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如遇字符串类型的日期要计算时,务必转化为合理的格式进行计算 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           请使用JSON格式返回你的回答: | ||||
|           若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} | ||||
|   | ||||
| @@ -103,6 +103,12 @@ list_documentions = [ | ||||
|     """ | ||||
|     <人员库表注意事项> | ||||
|         <rule> | ||||
|             查询address时,尽量使用like查询,如:select * from 人员库 where address like '%张三%'; | ||||
|             语法为mysql语法; | ||||
|             如果涉及下面<info>中的字段需要展示给用户看时请替换成相关代表 | ||||
|             birthday 字段涉及计算时,请转化为合理格式计算 | ||||
|         </rule> | ||||
|         <info> | ||||
|             person_status 字段 1代表草稿,2代表审批中,3代表制卡中,4代表已入库,5代表停用; | ||||
|             gender 字段 1代表男,2代表女 | ||||
|             is_internal 字段 0代表否,1代表是 | ||||
| @@ -115,9 +121,7 @@ list_documentions = [ | ||||
|             is_subcontractor 字段 0代表否,1代表是 | ||||
|             is_sign_confidentiality_agreement 字段 0代表否,1代表是 | ||||
|             DHDATASTA 字段 0代表新增 1代表更新 | ||||
|             查询address时,尽量使用like查询,如:select * from 人员库 where address like '%张三%'; | ||||
|             语法为mysql语法; | ||||
|         </rule> | ||||
|         </info> | ||||
|     </人员库表注意事项> | ||||
|     """, | ||||
| ] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128