提示词优化,日志添加

This commit is contained in:
yujj128
2025-09-24 14:39:42 +08:00
parent 23916fc328
commit 2533126e36
5 changed files with 200 additions and 44 deletions

62
logging_config.py Normal file
View File

@@ -0,0 +1,62 @@
# logging_config.py
import logging
import logging.config
from pathlib import Path
# 确保 logs 目录存在
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
},
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s",
}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": "INFO",
"formatter": "default",
"stream": "ext://sys.stdout"
},
"file": {
"class": "logging.handlers.RotatingFileHandler", # 自动轮转
"level": "INFO",
"formatter": "detailed",
"filename": "logs/sqlbot.log",
"maxBytes": 10485760, # 10MB
"backupCount": 5, # 保留5个备份
"encoding": "utf8"
},
},
"root": {
"level": "INFO",
"handlers": ["console", "file"]
},
"loggers": {
"uvicorn": {
"level": "INFO",
"handlers": ["console", "file"],
"propagate": False
},
"uvicorn.error": {
"level": "INFO",
"handlers": ["console", "file"],
"propagate": False
},
"uvicorn.access": {
"level": "WARNING", # 只记录警告以上,避免刷屏
"handlers": ["file"], # 只写入文件
"propagate": False
}
}
}
# 应用配置
logging.config.dictConfig(LOGGING_CONFIG)

View File

@@ -1,5 +1,8 @@
from email.policy import default from email.policy import default
import dmPython
import logging
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient from service.cus_vanna_srevice import CustomVanna, QdrantClient
from decouple import config from decouple import config
import flask import flask
@@ -7,6 +10,8 @@ from util import load_ddl_doc
from flask import Flask, Response, jsonify, request, send_from_directory from flask import Flask, Response, jsonify, request, send_from_directory
logger = logging.getLogger(__name__)
def connect_database(vn): def connect_database(vn):
db_type = config('DATA_SOURCE_TYPE', default='sqlite') db_type = config('DATA_SOURCE_TYPE', default='sqlite')
if db_type == 'sqlite': if db_type == 'sqlite':
@@ -17,6 +22,8 @@ def connect_database(vn):
user=config('MYSQL_DATABASE_USER', default=''), user=config('MYSQL_DATABASE_USER', default=''),
password=config('MYSQL_DATABASE_PASSWORD', default=''), password=config('MYSQL_DATABASE_PASSWORD', default=''),
dbname=config('MYSQL_DATABASE_DBNAME', default='')) dbname=config('MYSQL_DATABASE_DBNAME', default=''))
elif db_type == 'dameng':
vn.connect_to_dameng( )
elif db_type == 'postgresql': elif db_type == 'postgresql':
# 待补充 # 待补充
pass pass
@@ -81,22 +88,78 @@ def generate_sql_2():
text: text:
type: string type: string
""" """
logger.info("Start to generate sql in main")
question = flask.request.args.get("question") question = flask.request.args.get("question")
if question is None: if question is None:
return jsonify({"type": "error", "error": "No question provided"}) return jsonify({"type": "error", "error": "No question provided"})
id = cache.generate_id(question=question) try:
data = vn.generate_sql_2(question=question) id = cache.generate_id(question=question)
data['id'] =id data = vn.generate_sql_2(question=question)
sql = data["resp"]["sql"] data['id'] = id
print("sql:",sql) sql = data["resp"]["sql"]
cache.set(id=id, field="question", value=question) print("sql:", sql)
cache.set(id=id, field="sql", value=sql) cache.set(id=id, field="question", value=question)
print("data---------------------------",data) cache.set(id=id, field="sql", value=sql)
print("data---------------------------", data)
return jsonify(data)
except Exception as e:
return jsonify({"type": "error", "error": str(e)})
return jsonify(data)
@app.flask_app.route("/api/v0/run_sql_2", methods=["GET"])
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
"""
Run SQL
---
parameters:
- name: user
in: query
- name: id
in: query|body
type: string
required: true
responses:
200:
schema:
type: object
properties:
type:
type: string
default: df
id:
type: string
df:
type: object
should_generate_chart:
type: boolean
"""
logger.info("Start to run sql in main")
try:
if not vn.run_sql_is_set:
return jsonify(
{
"type": "error",
"error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
}
)
df = vn.run_sql(sql=sql)
logger.info("")
app.cache.set(id=id, field="df", value=df)
x = df.head(10).to_dict(orient='records')
logger.info("df ---------------{0} {1}".format(x,type(x)))
return jsonify(
{
"type": "df",
"id": id,
"df": df.head(10).to_dict(orient='records'),
}
)
except Exception as e:
return jsonify({"type": "sql_error", "error": str(e)})
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -1,6 +1,6 @@
from email.policy import default from email.policy import default
from typing import List from typing import List
import dmPython
import orjson import orjson
import pandas as pd import pandas as pd
from vanna.base import VannaBase from vanna.base import VannaBase
@@ -22,7 +22,7 @@ class OpenAICompatibleLLM(VannaBase):
# default parameters - can be overrided using config # default parameters - can be overrided using config
self.temperature = 0.5 self.temperature = 0.5
self.max_tokens = 5000 self.max_tokens = 5000
self.conn = None
if "temperature" in config_file: if "temperature" in config_file:
self.temperature = config_file["temperature"] self.temperature = config_file["temperature"]
@@ -140,36 +140,57 @@ class OpenAICompatibleLLM(VannaBase):
return response.choices[0].message.content return response.choices[0].message.content
# def connect_to_dameng(self, host, port, username, password, database):
# try:
# self.conn = dmPython.connect(
# user=username,
# password=password,
# server=host,
# port=port, # 达梦默认端口5236
# autoCommit=True
# )
# print("达梦数据库连接成功")
# return True
# except Exception as e:
# print(f"连接失败: {e}")
# return False
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
question_sql_list = self.get_similar_question_sql(question, **kwargs) try:
ddl_list = self.get_related_ddl(question, **kwargs) question_sql_list = self.get_similar_question_sql(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs) ddl_list = self.get_related_ddl(question, **kwargs)
template = get_base_template() doc_list = self.get_related_documentation(question, **kwargs)
sql_temp = template['template']['sql'] template = get_base_template()
char_temp = template['template']['chart'] sql_temp = template['template']['sql']
# --------基于提示词生成sql以及图表类型 char_temp = template['template']['chart']
sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list, # --------基于提示词生成sql以及图表类型
data_training=question_sql_list) sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), lang='中文',
print("sys_temp", sys_temp) schema=ddl_list, documentation=doc_list,
user_temp = sql_temp['user'].format(question=question, data_training=question_sql_list)
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) print("sys_temp", sys_temp)
print("user_temp", user_temp) user_temp = sql_temp['user'].format(question=question,
llm_response = self.submit_prompt( current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) print("user_temp", user_temp)
print(llm_response) llm_response = self.submit_prompt(
result = {"resp": orjson.loads(extract_nested_json(llm_response))} [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
print("result", result) print(llm_response)
sql = check_and_get_sql(llm_response) result = {"resp": orjson.loads(extract_nested_json(llm_response))}
# ---------------生成图表 print("result", result)
char_type = get_chart_type_from_sql_answer(llm_response) sql = check_and_get_sql(llm_response)
if char_type: # ---------------生成图表
sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type) char_type = get_chart_type_from_sql_answer(llm_response)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) if char_type:
llm_response2 = self.submit_prompt( sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'),
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) lang='中文', sql=sql, chart_type=char_type)
print(llm_response2) user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
result['chart'] = orjson.loads(extract_nested_json(llm_response2)) llm_response2 = self.submit_prompt(
return result [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}],
**kwargs)
print(llm_response2)
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
return result
except Exception as e:
raise e
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
print("new_question---------------", new_question) print("new_question---------------", new_question)

View File

@@ -27,6 +27,9 @@ template:
<rule> <rule>
你只能生成查询用的SQL语句不得生成增删改相关或操作数据库以及操作数据库数据的SQL 你只能生成查询用的SQL语句不得生成增删改相关或操作数据库以及操作数据库数据的SQL
</rule> </rule>
<rule>
如果只涉及查询人员信息,但没说具体哪些信息,可以不用查询所有信息,主要查询相关性较强的五个字段即可
</rule>
<rule> <rule>
不要编造<m-schema>内没有提供给你的表结构 不要编造<m-schema>内没有提供给你的表结构
</rule> </rule>
@@ -39,6 +42,9 @@ template:
<rule> <rule>
请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别 请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别
</rule> </rule>
<rule>
如遇字符串类型的日期要计算时,务必转化为合理的格式进行计算
</rule>
<rule> <rule>
请使用JSON格式返回你的回答: 请使用JSON格式返回你的回答:
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} 若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}}

View File

@@ -103,6 +103,12 @@ list_documentions = [
""" """
<人员库表注意事项> <人员库表注意事项>
<rule> <rule>
查询address时,尽量使用like查询如:select * from 人员库 where address like '%张三%';
语法为mysql语法;
如果涉及下面<info>中的字段需要展示给用户看时请替换成相关代表
birthday 字段涉及计算时,请转化为合理格式计算
</rule>
<info>
person_status 字段 1代表草稿2代表审批中3代表制卡中4代表已入库5代表停用; person_status 字段 1代表草稿2代表审批中3代表制卡中4代表已入库5代表停用;
gender 字段 1代表男2代表女 gender 字段 1代表男2代表女
is_internal 字段 0代表否1代表是 is_internal 字段 0代表否1代表是
@@ -115,9 +121,7 @@ list_documentions = [
is_subcontractor 字段 0代表否1代表是 is_subcontractor 字段 0代表否1代表是
is_sign_confidentiality_agreement 字段 0代表否1代表是 is_sign_confidentiality_agreement 字段 0代表否1代表是
DHDATASTA 字段 0代表新增 1代表更新 DHDATASTA 字段 0代表新增 1代表更新
查询address时,尽量使用like查询如:select * from 人员库 where address like '%张三%'; </info>
语法为mysql语法;
</rule>
</人员库表注意事项> </人员库表注意事项>
""", """,
] ]