提示词优化,日志添加

This commit is contained in:
yujj128
2025-09-24 14:39:42 +08:00
parent 23916fc328
commit 2533126e36
5 changed files with 200 additions and 44 deletions

62
logging_config.py Normal file
View File

@@ -0,0 +1,62 @@
# logging_config.py
"""Process-wide logging setup.

Importing this module has two side effects: it creates the ``logs``
directory (relative to the current working directory) and it applies
``LOGGING_CONFIG`` through ``logging.config.dictConfig``.
"""
import logging
import logging.config
from pathlib import Path

# Make sure the log directory exists before the file handler opens its file.
log_dir = Path("logs")
log_dir.mkdir(parents=True, exist_ok=True)

LOGGING_CONFIG = {
    "version": 1,
    # Keep loggers that were created before this configuration is applied.
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
        },
        "detailed": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "INFO",
            "formatter": "default",
            "stream": "ext://sys.stdout",
        },
        "file": {
            "class": "logging.handlers.RotatingFileHandler",  # rotates automatically
            "level": "INFO",
            "formatter": "detailed",
            # Derived from log_dir so the directory created above and the
            # handler's target can never drift apart.
            "filename": str(log_dir / "sqlbot.log"),
            "maxBytes": 10485760,  # 10 MB
            "backupCount": 5,  # keep 5 rotated backups
            "encoding": "utf8",
        },
    },
    "root": {
        "level": "INFO",
        "handlers": ["console", "file"],
    },
    "loggers": {
        "uvicorn": {
            "level": "INFO",
            "handlers": ["console", "file"],
            "propagate": False,
        },
        "uvicorn.error": {
            "level": "INFO",
            "handlers": ["console", "file"],
            "propagate": False,
        },
        "uvicorn.access": {
            "level": "WARNING",  # warnings and above only, to avoid flooding the log
            "handlers": ["file"],  # written to the file only, not the console
            "propagate": False,
        },
    },
}

# Apply the configuration at import time.
logging.config.dictConfig(LOGGING_CONFIG)

View File

@@ -1,5 +1,8 @@
from email.policy import default
import dmPython
import logging
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient
from decouple import config
import flask
@@ -7,6 +10,8 @@ from util import load_ddl_doc
from flask import Flask, Response, jsonify, request, send_from_directory
logger = logging.getLogger(__name__)
def connect_database(vn):
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
if db_type == 'sqlite':
@@ -17,6 +22,8 @@ def connect_database(vn):
user=config('MYSQL_DATABASE_USER', default=''),
password=config('MYSQL_DATABASE_PASSWORD', default=''),
dbname=config('MYSQL_DATABASE_DBNAME', default=''))
elif db_type == 'dameng':
vn.connect_to_dameng( )
elif db_type == 'postgresql':
# 待补充
pass
@@ -81,11 +88,12 @@ def generate_sql_2():
text:
type: string
"""
logger.info("Start to generate sql in main")
question = flask.request.args.get("question")
if question is None:
return jsonify({"type": "error", "error": "No question provided"})
try:
id = cache.generate_id(question=question)
data = vn.generate_sql_2(question=question)
data['id'] = id
@@ -94,10 +102,65 @@ def generate_sql_2():
cache.set(id=id, field="question", value=question)
cache.set(id=id, field="sql", value=sql)
print("data---------------------------", data)
return jsonify(data)
except Exception as e:
return jsonify({"type": "error", "error": str(e)})
@app.flask_app.route("/api/v0/run_sql_2", methods=["GET"])
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
    """
    Run SQL
    ---
    parameters:
      - name: user
        in: query
      - name: id
        in: query|body
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: df
            id:
              type: string
            df:
              type: object
            should_generate_chart:
              type: boolean
    """
    # NOTE(review): the response spec above lists `should_generate_chart`, but
    # the payload built below does not include it -- confirm which is intended.
    logger.info("Start to run sql in main")
    try:
        # Guard clause: without a configured connection there is nothing to run.
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )

        df = vn.run_sql(sql=sql)
        # The full result is cached under this id; only a preview is returned.
        app.cache.set(id=id, field="df", value=df)

        # Build the 10-row preview once and reuse it for both the log line
        # and the response body.
        preview = df.head(10).to_dict(orient='records')
        logger.info("df ---------------%s %s", preview, type(preview))

        return jsonify(
            {
                "type": "df",
                "id": id,
                "df": preview,
            }
        )
    except Exception as e:
        # Record the traceback; the client still receives the same error payload.
        logger.exception("Failed to run sql in main")
        return jsonify({"type": "sql_error", "error": str(e)})
# Script entry point: start the app listening on all interfaces (0.0.0.0), port 8084, with debug disabled.
if __name__ == '__main__':
app.run(host='0.0.0.0', port=8084, debug=False)

View File

@@ -1,6 +1,6 @@
from email.policy import default
from typing import List
import dmPython
import orjson
import pandas as pd
from vanna.base import VannaBase
@@ -22,7 +22,7 @@ class OpenAICompatibleLLM(VannaBase):
# default parameters - can be overridden using config
self.temperature = 0.5
self.max_tokens = 5000
self.conn = None
if "temperature" in config_file:
self.temperature = config_file["temperature"]
@@ -140,7 +140,23 @@ class OpenAICompatibleLLM(VannaBase):
return response.choices[0].message.content
# def connect_to_dameng(self, host, port, username, password, database):
# try:
# self.conn = dmPython.connect(
# user=username,
# password=password,
# server=host,
# port=port, # 达梦默认端口5236
# autoCommit=True
# )
# print("达梦数据库连接成功")
# return True
# except Exception as e:
# print(f"连接失败: {e}")
# return False
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
try:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
@@ -148,7 +164,8 @@ class OpenAICompatibleLLM(VannaBase):
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
# --------基于提示词生成sql以及图表类型
sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list,
sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), lang='中文',
schema=ddl_list, documentation=doc_list,
data_training=question_sql_list)
print("sys_temp", sys_temp)
user_temp = sql_temp['user'].format(question=question,
@@ -163,13 +180,17 @@ class OpenAICompatibleLLM(VannaBase):
# ---------------生成图表
char_type = get_chart_type_from_sql_answer(llm_response)
if char_type:
sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type)
sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'),
lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
llm_response2 = self.submit_prompt(
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}],
**kwargs)
print(llm_response2)
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
return result
except Exception as e:
raise e
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
print("new_question---------------", new_question)

View File

@@ -27,6 +27,9 @@ template:
<rule>
你只能生成查询用的SQL语句不得生成增删改相关或操作数据库以及操作数据库数据的SQL
</rule>
<rule>
如果只涉及查询人员信息,但没说具体哪些信息,可以不用查询所有信息,主要查询相关性较强的五个字段即可
</rule>
<rule>
不要编造<m-schema>内没有提供给你的表结构
</rule>
@@ -39,6 +42,9 @@ template:
<rule>
请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别
</rule>
<rule>
如遇字符串类型的日期要计算时,务必转化为合理的格式进行计算
</rule>
<rule>
请使用JSON格式返回你的回答:
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}}

View File

@@ -103,6 +103,12 @@ list_documentions = [
"""
<人员库表注意事项>
<rule>
查询address时,尽量使用like查询如:select * from 人员库 where address like '%张三%';
语法为mysql语法;
如果涉及下面<info>中的字段需要展示给用户看时请替换成相关代表
birthday 字段涉及计算时,请转化为合理格式计算
</rule>
<info>
person_status 字段 1代表草稿2代表审批中3代表制卡中4代表已入库5代表停用;
gender 字段 1代表男2代表女
is_internal 字段 0代表否1代表是
@@ -115,9 +121,7 @@ list_documentions = [
is_subcontractor 字段 0代表否1代表是
is_sign_confidentiality_agreement 字段 0代表否1代表是
DHDATASTA 字段 0代表新增 1代表更新
查询address时,尽量使用like查询如:select * from 人员库 where address like '%张三%';
语法为mysql语法;
</rule>
</info>
</人员库表注意事项>
""",
]