批量添加ddl到向量数据库,添加documention,重写generate_rewritten_question
This commit is contained in:
@@ -3,6 +3,7 @@ from email.policy import default
|
||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient
|
||||
from decouple import config
|
||||
import flask
|
||||
from util import load_ddl_doc
|
||||
|
||||
from flask import Flask, Response, jsonify, request, send_from_directory
|
||||
|
||||
@@ -12,10 +13,10 @@ def connect_database(vn):
|
||||
vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
|
||||
elif db_type == 'mysql':
|
||||
vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''),
|
||||
port=config('MYSQL_DATABASE_PORT', default=3306),
|
||||
port=int(config('MYSQL_DATABASE_PORT', default=3306)),
|
||||
user=config('MYSQL_DATABASE_USER', default=''),
|
||||
password=config('MYSQL_DATABASE_PASSWORD', default=''),
|
||||
database=config('MYSQL_DATABASE_DBNAME', default=''))
|
||||
dbname=config('MYSQL_DATABASE_DBNAME', default=''))
|
||||
elif db_type == 'postgresql':
|
||||
# 待补充
|
||||
pass
|
||||
@@ -24,26 +25,7 @@ def connect_database(vn):
|
||||
|
||||
|
||||
def load_train_data_ddl(vn: CustomVanna):
|
||||
vn.train(ddl="""
|
||||
create table db_user
|
||||
(
|
||||
id integer not null
|
||||
constraint db_user_pk
|
||||
primary key autoincrement,
|
||||
user_name TEXT not null,
|
||||
age integer not null,
|
||||
address TEXT,
|
||||
gender integer not null,
|
||||
email TEXT
|
||||
)
|
||||
|
||||
|
||||
""")
|
||||
vn.train(documentation='''
|
||||
gender 字段 0代表女性,1代表男性;
|
||||
查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%';
|
||||
语法为sqlite语法;
|
||||
''')
|
||||
vn.train()
|
||||
|
||||
|
||||
def create_vana():
|
||||
@@ -62,6 +44,8 @@ def create_vana():
|
||||
def init_vn(vn):
|
||||
print("--------------init vn-----connect----")
|
||||
connect_database(vn)
|
||||
load_ddl_doc.add_ddl(vn)
|
||||
load_ddl_doc.add_documentation(vn)
|
||||
if config('IS_FIRST_LOAD', default=False, cast=bool):
|
||||
load_train_data_ddl(vn)
|
||||
return vn
|
||||
|
||||
@@ -146,21 +146,23 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
template = get_base_template()
|
||||
sql_temp = template['template']['sql']
|
||||
char_temp = template['template']['chart']
|
||||
# --------基于提示词,生成sql以及图标类型
|
||||
sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list,
|
||||
# --------基于提示词,生成sql以及图表类型
|
||||
sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list,
|
||||
data_training=question_sql_list)
|
||||
print("sys_temp", sys_temp)
|
||||
user_temp = sql_temp['user'].format(question=question,
|
||||
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
print("user_temp", user_temp)
|
||||
llm_response = self.submit_prompt(
|
||||
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
|
||||
print(llm_response)
|
||||
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
|
||||
|
||||
print("result", result)
|
||||
sql = check_and_get_sql(llm_response)
|
||||
# ---------------生成图表
|
||||
char_type = get_chart_type_from_sql_answer(llm_response)
|
||||
if char_type:
|
||||
sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type)
|
||||
sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type)
|
||||
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
|
||||
llm_response2 = self.submit_prompt(
|
||||
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
|
||||
@@ -168,6 +170,19 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
|
||||
return result
|
||||
|
||||
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
|
||||
print("new_question---------------", new_question)
|
||||
if last_question is None:
|
||||
return new_question
|
||||
|
||||
prompt = [
|
||||
self.system_message(
|
||||
"Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
|
||||
self.user_message(new_question),
|
||||
]
|
||||
|
||||
return self.submit_prompt(prompt=prompt, **kwargs)
|
||||
|
||||
|
||||
class CustomQdrant_VectorStore(Qdrant_VectorStore):
|
||||
def __init__(
|
||||
|
||||
@@ -36,6 +36,9 @@ template:
|
||||
<rule>
|
||||
若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句
|
||||
</rule>
|
||||
<rule>
|
||||
请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别
|
||||
</rule>
|
||||
<rule>
|
||||
请使用JSON格式返回你的回答:
|
||||
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}}
|
||||
@@ -498,3 +501,6 @@ template:
|
||||
|
||||
### 子查询映射表:
|
||||
{sub_query}
|
||||
|
||||
|
||||
Resources:
|
||||
131
util/load_ddl_doc.py
Normal file
131
util/load_ddl_doc.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from service.cus_vanna_srevice import CustomVanna
|
||||
# table_ddls = [
|
||||
# """
|
||||
# create table db_user
|
||||
# (
|
||||
# id integer not null
|
||||
# constraint db_user_pk
|
||||
# primary key autoincrement,
|
||||
# user_name TEXT not null,
|
||||
# age integer not null,
|
||||
# address TEXT,
|
||||
# gender integer not null,
|
||||
# email TEXT
|
||||
# )
|
||||
# """,
|
||||
# ]
|
||||
# list_documentions = [
|
||||
# """
|
||||
# gender 字段 0代表女性,1代表男性;
|
||||
# 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%';
|
||||
# 语法为sqlite语法;
|
||||
# """,
|
||||
# ]
|
||||
table_ddls = [
|
||||
"""
|
||||
CREATE TABLE 人员库表 (
|
||||
id VARCHAR(22) PRIMARY KEY COMMENT '主键',
|
||||
name VARCHAR(600) DEFAULT NULL COMMENT '姓名',
|
||||
gender VARCHAR(108) DEFAULT NULL COMMENT '性别',
|
||||
id_card_type VARCHAR(108) DEFAULT NULL COMMENT '身份证件类型',
|
||||
id_card VARCHAR(600) DEFAULT NULL COMMENT '身份证号码',
|
||||
birthday VARCHAR(30) DEFAULT NULL COMMENT '出生日期',
|
||||
native_place TEXT DEFAULT NULL COMMENT '籍贯',
|
||||
nation TEXT DEFAULT NULL COMMENT '民族',
|
||||
country TEXT DEFAULT NULL COMMENT '国籍',
|
||||
residence_address TEXT DEFAULT NULL COMMENT '户籍地址',
|
||||
highest_education VARCHAR(108) DEFAULT NULL COMMENT '最高学历',
|
||||
highest_degree TEXT DEFAULT NULL COMMENT '最高学位',
|
||||
graduate_school TEXT DEFAULT NULL COMMENT '毕业院校',
|
||||
political_status TEXT DEFAULT NULL COMMENT '政治面貌',
|
||||
phone_number TEXT DEFAULT NULL COMMENT '手机号',
|
||||
email VARCHAR(600) DEFAULT NULL COMMENT '电子邮箱',
|
||||
worker_id VARCHAR(200) DEFAULT NULL COMMENT '工号',
|
||||
post TEXT DEFAULT NULL COMMENT '职务',
|
||||
engage_post TEXT DEFAULT NULL COMMENT '现从事岗位',
|
||||
work_unit TEXT DEFAULT NULL COMMENT '工作单位全称',
|
||||
work_content TEXT DEFAULT NULL COMMENT '工作内容',
|
||||
engage_contract_no VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同编号',
|
||||
engage_contract_name VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同名称',
|
||||
is_subcontractor VARCHAR(108) DEFAULT NULL COMMENT '是否分包商',
|
||||
general_contractor_unit VARCHAR(600) DEFAULT NULL COMMENT '总包单位全称',
|
||||
office_city TEXT DEFAULT NULL COMMENT '办公城市',
|
||||
office_address TEXT DEFAULT NULL COMMENT '办公地点',
|
||||
person_type TEXT DEFAULT NULL COMMENT '人员类型',
|
||||
person_status VARCHAR(108) DEFAULT NULL COMMENT '人员状态',
|
||||
is_internal VARCHAR(108) DEFAULT NULL COMMENT '是否内部员工',
|
||||
internal_unit VARCHAR(108) DEFAULT NULL COMMENT '内部单位',
|
||||
internal_dept VARCHAR(108) DEFAULT NULL COMMENT '内部部门',
|
||||
external_unit VARCHAR(600) DEFAULT NULL COMMENT '外部单位',
|
||||
external_dept VARCHAR(600) DEFAULT NULL COMMENT '外部部门',
|
||||
to_dept VARCHAR(600) DEFAULT NULL COMMENT '所属处室',
|
||||
pass_type VARCHAR(108) DEFAULT NULL COMMENT '通行证类型',
|
||||
entry_date VARCHAR(30) DEFAULT NULL COMMENT '入场日期',
|
||||
expected_departure_date VARCHAR(30) DEFAULT NULL COMMENT '预计离场日期',
|
||||
expire_time DATETIME DEFAULT NULL COMMENT '失效时间',
|
||||
verifystate INT DEFAULT NULL COMMENT '单据状态',
|
||||
auditor VARCHAR(180) DEFAULT NULL COMMENT '终审审批人',
|
||||
auditor1 VARCHAR(36) DEFAULT NULL COMMENT '处室负责人',
|
||||
auditnote VARCHAR(200) DEFAULT NULL COMMENT '当前审批人',
|
||||
procinst_id VARCHAR(36) DEFAULT NULL COMMENT '流程实例ID',
|
||||
bizflow_id VARCHAR(36) DEFAULT NULL COMMENT '业务流id',
|
||||
bizflowname VARCHAR(200) DEFAULT NULL COMMENT '流程名称',
|
||||
bizflow_makebillcode VARCHAR(200) DEFAULT NULL COMMENT '单据转换规则编码',
|
||||
bizflowinstance_id VARCHAR(36) DEFAULT NULL COMMENT '业务流实例id',
|
||||
sourcegrand_id VARCHAR(108) DEFAULT NULL COMMENT '来源孙表id',
|
||||
first_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据主表id',
|
||||
firstchild_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据子表id',
|
||||
firstbusiobj VARCHAR(108) DEFAULT NULL COMMENT '来源业务对象',
|
||||
firstcode TEXT DEFAULT NULL COMMENT '来源单据号',
|
||||
source_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据主表id',
|
||||
sourcechild_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据子表id',
|
||||
sourcebusiobj VARCHAR(36) DEFAULT NULL COMMENT '上游业务对象',
|
||||
sourcecode VARCHAR(200) DEFAULT NULL COMMENT '上游单据号',
|
||||
code TEXT DEFAULT NULL COMMENT '编码',
|
||||
ytenant_id VARCHAR(64) DEFAULT NULL COMMENT '租户id',
|
||||
photo TEXT DEFAULT NULL COMMENT '照片',
|
||||
input_time DATETIME DEFAULT NULL COMMENT '录入时间',
|
||||
create_time DATETIME DEFAULT NULL COMMENT '创建时间',
|
||||
modify_time DATETIME DEFAULT NULL COMMENT '修改时间',
|
||||
audit_time DATETIME DEFAULT NULL COMMENT '审批日期',
|
||||
input_user VARCHAR(108) DEFAULT NULL COMMENT '录入人',
|
||||
input_dept VARCHAR(108) DEFAULT NULL COMMENT '录入部门',
|
||||
creator VARCHAR(60) DEFAULT NULL COMMENT '创建人',
|
||||
modifier VARCHAR(60) DEFAULT NULL COMMENT '修改人',
|
||||
sort INT DEFAULT NULL COMMENT '排序',
|
||||
dr INT DEFAULT 0 COMMENT '逻辑删除:0-未删除,1-已删除',
|
||||
DHDATASTA INT DEFAULT NULL COMMENT '推送状态',
|
||||
pubts DATETIME DEFAULT NULL COMMENT '发布时间戳(或其他时间戳)'
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='人员信息表';
|
||||
""",
|
||||
]
|
||||
list_documentions = [
|
||||
"""
|
||||
<人员库表注意事项>
|
||||
<rule>
|
||||
person_status 字段 1代表草稿,2代表审批中,3代表制卡中,4代表已入库,5代表停用;
|
||||
gender 字段 1代表男,2代表女
|
||||
is_internal 字段 0代表否,1代表是
|
||||
pass_type 字段 1代表集团公司员工,2代表借调员工,3代表借用人员,4代表外部监管人员,5代表外协服务人员,6代表工勤人员,7代表来访人员
|
||||
person_type 字段 YG代表正式员工,PQ代表劳务派遣人员,QT代表其他柔性引进人员,WHZ代表合作单位,WLS代表临时访客,WQT代表其他外部人员
|
||||
dr 字段 0代表否,1代表是
|
||||
id_card_type 字段 1代表居民身份证,2代表护照,3代表港澳通行证
|
||||
highest_education 字段 1代表初中,2代表高中,3代表中专,4代表技校,5代表职高,6代表大专,7代表本科,8代表硕士,9代表博士
|
||||
highest_degree 字段 1代表学士学位,2代表硕士学位,3代表博士学位,4代表无
|
||||
is_subcontractor 字段 0代表否,1代表是
|
||||
is_sign_confidentiality_agreement 字段 0代表否,1代表是
|
||||
DHDATASTA 字段 0代表新增 1代表更新
|
||||
查询address时,尽量使用like查询,如:select * from 人员库 where address like '%张三%';
|
||||
语法为mysql语法;
|
||||
</rule>
|
||||
</人员库表注意事项>
|
||||
""",
|
||||
]
|
||||
def add_ddl(vn: CustomVanna):
|
||||
for ddl in table_ddls:
|
||||
vn.add_ddl(ddl)
|
||||
|
||||
def add_documentation(vn: CustomVanna):
|
||||
for doc in list_documentions:
|
||||
vn.add_documentation(doc)
|
||||
|
||||
Reference in New Issue
Block a user