批量添加ddl到向量数据库,添加documention,重写generate_rewritten_question

This commit is contained in:
yujj128
2025-09-23 16:29:56 +08:00
parent 3ace3e5348
commit a1e92edf0a
4 changed files with 165 additions and 29 deletions

View File

@@ -3,6 +3,7 @@ from email.policy import default
from service.cus_vanna_srevice import CustomVanna, QdrantClient
from decouple import config
import flask
from util import load_ddl_doc
from flask import Flask, Response, jsonify, request, send_from_directory
@@ -12,10 +13,10 @@ def connect_database(vn):
vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
elif db_type == 'mysql':
vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''),
port=config('MYSQL_DATABASE_PORT', default=3306),
port=int(config('MYSQL_DATABASE_PORT', default=3306)),
user=config('MYSQL_DATABASE_USER', default=''),
password=config('MYSQL_DATABASE_PASSWORD', default=''),
database=config('MYSQL_DATABASE_DBNAME', default=''))
dbname=config('MYSQL_DATABASE_DBNAME', default=''))
elif db_type == 'postgresql':
# 待补充
pass
@@ -24,26 +25,7 @@ def connect_database(vn):
def load_train_data_ddl(vn: CustomVanna):
vn.train(ddl="""
create table db_user
(
id integer not null
constraint db_user_pk
primary key autoincrement,
user_name TEXT not null,
age integer not null,
address TEXT,
gender integer not null,
email TEXT
)
""")
vn.train(documentation='''
gender 字段 0代表女性1代表男性;
查询address时,尽量使用like查询如:select * from db_user where address like '%北京%';
语法为sqlite语法;
''')
vn.train()
def create_vana():
@@ -62,6 +44,8 @@ def create_vana():
def init_vn(vn):
print("--------------init vn-----connect----")
connect_database(vn)
load_ddl_doc.add_ddl(vn)
load_ddl_doc.add_documentation(vn)
if config('IS_FIRST_LOAD', default=False, cast=bool):
load_train_data_ddl(vn)
return vn

View File

@@ -146,21 +146,23 @@ class OpenAICompatibleLLM(VannaBase):
template = get_base_template()
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
# --------基于提示词生成sql以及图类型
sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list,
# --------基于提示词生成sql以及图类型
sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list,
data_training=question_sql_list)
print("sys_temp", sys_temp)
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("user_temp", user_temp)
llm_response = self.submit_prompt(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
print(llm_response)
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
print("result", result)
sql = check_and_get_sql(llm_response)
# ---------------生成图表
char_type = get_chart_type_from_sql_answer(llm_response)
if char_type:
sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type)
sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
llm_response2 = self.submit_prompt(
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
@@ -168,6 +170,19 @@ class OpenAICompatibleLLM(VannaBase):
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
return result
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
print("new_question---------------", new_question)
if last_question is None:
return new_question
prompt = [
self.system_message(
"Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
self.user_message(new_question),
]
return self.submit_prompt(prompt=prompt, **kwargs)
class CustomQdrant_VectorStore(Qdrant_VectorStore):
def __init__(

View File

@@ -36,6 +36,9 @@ template:
<rule>
若用户提问中提供了参考SQL你需要判断该SQL是否是查询语句
</rule>
<rule>
请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别
</rule>
<rule>
请使用JSON格式返回你的回答:
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}}
@@ -240,11 +243,11 @@ template:
{current_time}
</current-time>
<background-infos>
<user-question>
{question}
</user-question>
chart:
system: |
<Instruction>
@@ -312,7 +315,7 @@ template:
如果你无法根据提供的内容生成合适的JSON配置则返回{{"type":"error", "reason": "抱歉,我无法生成合适的图表配置"}}
可以的话,你可以稍微丰富一下错误信息,让用户知道可能的原因。例如:"reason": "无法生成配置提供的SQL查询结果中没有找到适合作为分类(series)的字段。"
</rule>
<Rules>
### 以下<example>帮助你理解问题及返回格式的例子,不要将<example>内的表结构用来回答用户的问题
@@ -498,3 +501,6 @@ template:
### 子查询映射表:
{sub_query}
Resources:

131
util/load_ddl_doc.py Normal file
View File

@@ -0,0 +1,131 @@
from service.cus_vanna_srevice import CustomVanna
# table_ddls = [
# """
# create table db_user
# (
# id integer not null
# constraint db_user_pk
# primary key autoincrement,
# user_name TEXT not null,
# age integer not null,
# address TEXT,
# gender integer not null,
# email TEXT
# )
# """,
# ]
# list_documentions = [
# """
# gender 字段 0代表女性1代表男性;
# 查询address时,尽量使用like查询如:select * from db_user where address like '%北京%';
# 语法为sqlite语法;
# """,
# ]
table_ddls = [
"""
CREATE TABLE 人员库表 (
id VARCHAR(22) PRIMARY KEY COMMENT '主键',
name VARCHAR(600) DEFAULT NULL COMMENT '姓名',
gender VARCHAR(108) DEFAULT NULL COMMENT '性别',
id_card_type VARCHAR(108) DEFAULT NULL COMMENT '身份证件类型',
id_card VARCHAR(600) DEFAULT NULL COMMENT '身份证号码',
birthday VARCHAR(30) DEFAULT NULL COMMENT '出生日期',
native_place TEXT DEFAULT NULL COMMENT '籍贯',
nation TEXT DEFAULT NULL COMMENT '民族',
country TEXT DEFAULT NULL COMMENT '国籍',
residence_address TEXT DEFAULT NULL COMMENT '户籍地址',
highest_education VARCHAR(108) DEFAULT NULL COMMENT '最高学历',
highest_degree TEXT DEFAULT NULL COMMENT '最高学位',
graduate_school TEXT DEFAULT NULL COMMENT '毕业院校',
political_status TEXT DEFAULT NULL COMMENT '政治面貌',
phone_number TEXT DEFAULT NULL COMMENT '手机号',
email VARCHAR(600) DEFAULT NULL COMMENT '电子邮箱',
worker_id VARCHAR(200) DEFAULT NULL COMMENT '工号',
post TEXT DEFAULT NULL COMMENT '职务',
engage_post TEXT DEFAULT NULL COMMENT '现从事岗位',
work_unit TEXT DEFAULT NULL COMMENT '工作单位全称',
work_content TEXT DEFAULT NULL COMMENT '工作内容',
engage_contract_no VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同编号',
engage_contract_name VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同名称',
is_subcontractor VARCHAR(108) DEFAULT NULL COMMENT '是否分包商',
general_contractor_unit VARCHAR(600) DEFAULT NULL COMMENT '总包单位全称',
office_city TEXT DEFAULT NULL COMMENT '办公城市',
office_address TEXT DEFAULT NULL COMMENT '办公地点',
person_type TEXT DEFAULT NULL COMMENT '人员类型',
person_status VARCHAR(108) DEFAULT NULL COMMENT '人员状态',
is_internal VARCHAR(108) DEFAULT NULL COMMENT '是否内部员工',
internal_unit VARCHAR(108) DEFAULT NULL COMMENT '内部单位',
internal_dept VARCHAR(108) DEFAULT NULL COMMENT '内部部门',
external_unit VARCHAR(600) DEFAULT NULL COMMENT '外部单位',
external_dept VARCHAR(600) DEFAULT NULL COMMENT '外部部门',
to_dept VARCHAR(600) DEFAULT NULL COMMENT '所属处室',
pass_type VARCHAR(108) DEFAULT NULL COMMENT '通行证类型',
entry_date VARCHAR(30) DEFAULT NULL COMMENT '入场日期',
expected_departure_date VARCHAR(30) DEFAULT NULL COMMENT '预计离场日期',
expire_time DATETIME DEFAULT NULL COMMENT '失效时间',
verifystate INT DEFAULT NULL COMMENT '单据状态',
auditor VARCHAR(180) DEFAULT NULL COMMENT '终审审批人',
auditor1 VARCHAR(36) DEFAULT NULL COMMENT '处室负责人',
auditnote VARCHAR(200) DEFAULT NULL COMMENT '当前审批人',
procinst_id VARCHAR(36) DEFAULT NULL COMMENT '流程实例ID',
bizflow_id VARCHAR(36) DEFAULT NULL COMMENT '业务流id',
bizflowname VARCHAR(200) DEFAULT NULL COMMENT '流程名称',
bizflow_makebillcode VARCHAR(200) DEFAULT NULL COMMENT '单据转换规则编码',
bizflowinstance_id VARCHAR(36) DEFAULT NULL COMMENT '业务流实例id',
sourcegrand_id VARCHAR(108) DEFAULT NULL COMMENT '来源孙表id',
first_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据主表id',
firstchild_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据子表id',
firstbusiobj VARCHAR(108) DEFAULT NULL COMMENT '来源业务对象',
firstcode TEXT DEFAULT NULL COMMENT '来源单据号',
source_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据主表id',
sourcechild_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据子表id',
sourcebusiobj VARCHAR(36) DEFAULT NULL COMMENT '上游业务对象',
sourcecode VARCHAR(200) DEFAULT NULL COMMENT '上游单据号',
code TEXT DEFAULT NULL COMMENT '编码',
ytenant_id VARCHAR(64) DEFAULT NULL COMMENT '租户id',
photo TEXT DEFAULT NULL COMMENT '照片',
input_time DATETIME DEFAULT NULL COMMENT '录入时间',
create_time DATETIME DEFAULT NULL COMMENT '创建时间',
modify_time DATETIME DEFAULT NULL COMMENT '修改时间',
audit_time DATETIME DEFAULT NULL COMMENT '审批日期',
input_user VARCHAR(108) DEFAULT NULL COMMENT '录入人',
input_dept VARCHAR(108) DEFAULT NULL COMMENT '录入部门',
creator VARCHAR(60) DEFAULT NULL COMMENT '创建人',
modifier VARCHAR(60) DEFAULT NULL COMMENT '修改人',
sort INT DEFAULT NULL COMMENT '排序',
dr INT DEFAULT 0 COMMENT '逻辑删除0-未删除1-已删除',
DHDATASTA INT DEFAULT NULL COMMENT '推送状态',
pubts DATETIME DEFAULT NULL COMMENT '发布时间戳(或其他时间戳)'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='人员信息表';
""",
]
list_documentions = [
"""
<人员库表注意事项>
<rule>
person_status 字段 1代表草稿2代表审批中3代表制卡中4代表已入库5代表停用;
gender 字段 1代表男2代表女
is_internal 字段 0代表否1代表是
pass_type 字段 1代表集团公司员工2代表借调员工3代表借用人员4代表外部监管人员5代表外协服务人员6代表工勤人员7代表来访人员
person_type 字段 YG代表正式员工PQ代表劳务派遣人员QT代表其他柔性引进人员WHZ代表合作单位WLS代表临时访客WQT代表其他外部人员
dr 字段 0代表否1代表是
id_card_type 字段 1代表居民身份证2代表护照3代表港澳通行证
highest_education 字段 1代表初中2代表高中3代表中专4代表技校5代表职高6代表大专7代表本科8代表硕士9代表博士
highest_degree 字段 1代表学士学位2代表硕士学位3代表博士学位4代表无
is_subcontractor 字段 0代表否1代表是
is_sign_confidentiality_agreement 字段 0代表否1代表是
DHDATASTA 字段 0代表新增 1代表更新
查询address时,尽量使用like查询如:select * from 人员库 where address like '%张三%';
语法为mysql语法;
</rule>
</人员库表注意事项>
""",
]
def add_ddl(vn: CustomVanna):
for ddl in table_ddls:
vn.add_ddl(ddl)
def add_documentation(vn: CustomVanna):
for doc in list_documentions:
vn.add_documentation(doc)