批量添加ddl到向量数据库,添加documention,重写generate_rewritten_question
This commit is contained in:
@@ -3,6 +3,7 @@ from email.policy import default
|
|||||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient
|
from service.cus_vanna_srevice import CustomVanna, QdrantClient
|
||||||
from decouple import config
|
from decouple import config
|
||||||
import flask
|
import flask
|
||||||
|
from util import load_ddl_doc
|
||||||
|
|
||||||
from flask import Flask, Response, jsonify, request, send_from_directory
|
from flask import Flask, Response, jsonify, request, send_from_directory
|
||||||
|
|
||||||
@@ -12,10 +13,10 @@ def connect_database(vn):
|
|||||||
vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
|
vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
|
||||||
elif db_type == 'mysql':
|
elif db_type == 'mysql':
|
||||||
vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''),
|
vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''),
|
||||||
port=config('MYSQL_DATABASE_PORT', default=3306),
|
port=int(config('MYSQL_DATABASE_PORT', default=3306)),
|
||||||
user=config('MYSQL_DATABASE_USER', default=''),
|
user=config('MYSQL_DATABASE_USER', default=''),
|
||||||
password=config('MYSQL_DATABASE_PASSWORD', default=''),
|
password=config('MYSQL_DATABASE_PASSWORD', default=''),
|
||||||
database=config('MYSQL_DATABASE_DBNAME', default=''))
|
dbname=config('MYSQL_DATABASE_DBNAME', default=''))
|
||||||
elif db_type == 'postgresql':
|
elif db_type == 'postgresql':
|
||||||
# 待补充
|
# 待补充
|
||||||
pass
|
pass
|
||||||
@@ -24,26 +25,7 @@ def connect_database(vn):
|
|||||||
|
|
||||||
|
|
||||||
def load_train_data_ddl(vn: CustomVanna):
|
def load_train_data_ddl(vn: CustomVanna):
|
||||||
vn.train(ddl="""
|
vn.train()
|
||||||
create table db_user
|
|
||||||
(
|
|
||||||
id integer not null
|
|
||||||
constraint db_user_pk
|
|
||||||
primary key autoincrement,
|
|
||||||
user_name TEXT not null,
|
|
||||||
age integer not null,
|
|
||||||
address TEXT,
|
|
||||||
gender integer not null,
|
|
||||||
email TEXT
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
""")
|
|
||||||
vn.train(documentation='''
|
|
||||||
gender 字段 0代表女性,1代表男性;
|
|
||||||
查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%';
|
|
||||||
语法为sqlite语法;
|
|
||||||
''')
|
|
||||||
|
|
||||||
|
|
||||||
def create_vana():
|
def create_vana():
|
||||||
@@ -62,6 +44,8 @@ def create_vana():
|
|||||||
def init_vn(vn):
|
def init_vn(vn):
|
||||||
print("--------------init vn-----connect----")
|
print("--------------init vn-----connect----")
|
||||||
connect_database(vn)
|
connect_database(vn)
|
||||||
|
load_ddl_doc.add_ddl(vn)
|
||||||
|
load_ddl_doc.add_documentation(vn)
|
||||||
if config('IS_FIRST_LOAD', default=False, cast=bool):
|
if config('IS_FIRST_LOAD', default=False, cast=bool):
|
||||||
load_train_data_ddl(vn)
|
load_train_data_ddl(vn)
|
||||||
return vn
|
return vn
|
||||||
|
|||||||
@@ -146,21 +146,23 @@ class OpenAICompatibleLLM(VannaBase):
|
|||||||
template = get_base_template()
|
template = get_base_template()
|
||||||
sql_temp = template['template']['sql']
|
sql_temp = template['template']['sql']
|
||||||
char_temp = template['template']['chart']
|
char_temp = template['template']['chart']
|
||||||
# --------基于提示词,生成sql以及图标类型
|
# --------基于提示词,生成sql以及图表类型
|
||||||
sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list,
|
sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list,
|
||||||
data_training=question_sql_list)
|
data_training=question_sql_list)
|
||||||
|
print("sys_temp", sys_temp)
|
||||||
user_temp = sql_temp['user'].format(question=question,
|
user_temp = sql_temp['user'].format(question=question,
|
||||||
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||||
|
print("user_temp", user_temp)
|
||||||
llm_response = self.submit_prompt(
|
llm_response = self.submit_prompt(
|
||||||
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
|
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
|
||||||
print(llm_response)
|
print(llm_response)
|
||||||
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
|
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
|
||||||
|
print("result", result)
|
||||||
sql = check_and_get_sql(llm_response)
|
sql = check_and_get_sql(llm_response)
|
||||||
# ---------------生成图表
|
# ---------------生成图表
|
||||||
char_type = get_chart_type_from_sql_answer(llm_response)
|
char_type = get_chart_type_from_sql_answer(llm_response)
|
||||||
if char_type:
|
if char_type:
|
||||||
sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type)
|
sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type)
|
||||||
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
|
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
|
||||||
llm_response2 = self.submit_prompt(
|
llm_response2 = self.submit_prompt(
|
||||||
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
|
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
|
||||||
@@ -168,6 +170,19 @@ class OpenAICompatibleLLM(VannaBase):
|
|||||||
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
|
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
|
||||||
|
print("new_question---------------", new_question)
|
||||||
|
if last_question is None:
|
||||||
|
return new_question
|
||||||
|
|
||||||
|
prompt = [
|
||||||
|
self.system_message(
|
||||||
|
"Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
|
||||||
|
self.user_message(new_question),
|
||||||
|
]
|
||||||
|
|
||||||
|
return self.submit_prompt(prompt=prompt, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CustomQdrant_VectorStore(Qdrant_VectorStore):
|
class CustomQdrant_VectorStore(Qdrant_VectorStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -36,6 +36,9 @@ template:
|
|||||||
<rule>
|
<rule>
|
||||||
若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句
|
若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句
|
||||||
</rule>
|
</rule>
|
||||||
|
<rule>
|
||||||
|
请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别
|
||||||
|
</rule>
|
||||||
<rule>
|
<rule>
|
||||||
请使用JSON格式返回你的回答:
|
请使用JSON格式返回你的回答:
|
||||||
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}}
|
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}}
|
||||||
@@ -498,3 +501,6 @@ template:
|
|||||||
|
|
||||||
### 子查询映射表:
|
### 子查询映射表:
|
||||||
{sub_query}
|
{sub_query}
|
||||||
|
|
||||||
|
|
||||||
|
Resources:
|
||||||
131
util/load_ddl_doc.py
Normal file
131
util/load_ddl_doc.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
from service.cus_vanna_srevice import CustomVanna
|
||||||
|
# table_ddls = [
|
||||||
|
# """
|
||||||
|
# create table db_user
|
||||||
|
# (
|
||||||
|
# id integer not null
|
||||||
|
# constraint db_user_pk
|
||||||
|
# primary key autoincrement,
|
||||||
|
# user_name TEXT not null,
|
||||||
|
# age integer not null,
|
||||||
|
# address TEXT,
|
||||||
|
# gender integer not null,
|
||||||
|
# email TEXT
|
||||||
|
# )
|
||||||
|
# """,
|
||||||
|
# ]
|
||||||
|
# list_documentions = [
|
||||||
|
# """
|
||||||
|
# gender 字段 0代表女性,1代表男性;
|
||||||
|
# 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%';
|
||||||
|
# 语法为sqlite语法;
|
||||||
|
# """,
|
||||||
|
# ]
|
||||||
|
table_ddls = [
|
||||||
|
"""
|
||||||
|
CREATE TABLE 人员库表 (
|
||||||
|
id VARCHAR(22) PRIMARY KEY COMMENT '主键',
|
||||||
|
name VARCHAR(600) DEFAULT NULL COMMENT '姓名',
|
||||||
|
gender VARCHAR(108) DEFAULT NULL COMMENT '性别',
|
||||||
|
id_card_type VARCHAR(108) DEFAULT NULL COMMENT '身份证件类型',
|
||||||
|
id_card VARCHAR(600) DEFAULT NULL COMMENT '身份证号码',
|
||||||
|
birthday VARCHAR(30) DEFAULT NULL COMMENT '出生日期',
|
||||||
|
native_place TEXT DEFAULT NULL COMMENT '籍贯',
|
||||||
|
nation TEXT DEFAULT NULL COMMENT '民族',
|
||||||
|
country TEXT DEFAULT NULL COMMENT '国籍',
|
||||||
|
residence_address TEXT DEFAULT NULL COMMENT '户籍地址',
|
||||||
|
highest_education VARCHAR(108) DEFAULT NULL COMMENT '最高学历',
|
||||||
|
highest_degree TEXT DEFAULT NULL COMMENT '最高学位',
|
||||||
|
graduate_school TEXT DEFAULT NULL COMMENT '毕业院校',
|
||||||
|
political_status TEXT DEFAULT NULL COMMENT '政治面貌',
|
||||||
|
phone_number TEXT DEFAULT NULL COMMENT '手机号',
|
||||||
|
email VARCHAR(600) DEFAULT NULL COMMENT '电子邮箱',
|
||||||
|
worker_id VARCHAR(200) DEFAULT NULL COMMENT '工号',
|
||||||
|
post TEXT DEFAULT NULL COMMENT '职务',
|
||||||
|
engage_post TEXT DEFAULT NULL COMMENT '现从事岗位',
|
||||||
|
work_unit TEXT DEFAULT NULL COMMENT '工作单位全称',
|
||||||
|
work_content TEXT DEFAULT NULL COMMENT '工作内容',
|
||||||
|
engage_contract_no VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同编号',
|
||||||
|
engage_contract_name VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同名称',
|
||||||
|
is_subcontractor VARCHAR(108) DEFAULT NULL COMMENT '是否分包商',
|
||||||
|
general_contractor_unit VARCHAR(600) DEFAULT NULL COMMENT '总包单位全称',
|
||||||
|
office_city TEXT DEFAULT NULL COMMENT '办公城市',
|
||||||
|
office_address TEXT DEFAULT NULL COMMENT '办公地点',
|
||||||
|
person_type TEXT DEFAULT NULL COMMENT '人员类型',
|
||||||
|
person_status VARCHAR(108) DEFAULT NULL COMMENT '人员状态',
|
||||||
|
is_internal VARCHAR(108) DEFAULT NULL COMMENT '是否内部员工',
|
||||||
|
internal_unit VARCHAR(108) DEFAULT NULL COMMENT '内部单位',
|
||||||
|
internal_dept VARCHAR(108) DEFAULT NULL COMMENT '内部部门',
|
||||||
|
external_unit VARCHAR(600) DEFAULT NULL COMMENT '外部单位',
|
||||||
|
external_dept VARCHAR(600) DEFAULT NULL COMMENT '外部部门',
|
||||||
|
to_dept VARCHAR(600) DEFAULT NULL COMMENT '所属处室',
|
||||||
|
pass_type VARCHAR(108) DEFAULT NULL COMMENT '通行证类型',
|
||||||
|
entry_date VARCHAR(30) DEFAULT NULL COMMENT '入场日期',
|
||||||
|
expected_departure_date VARCHAR(30) DEFAULT NULL COMMENT '预计离场日期',
|
||||||
|
expire_time DATETIME DEFAULT NULL COMMENT '失效时间',
|
||||||
|
verifystate INT DEFAULT NULL COMMENT '单据状态',
|
||||||
|
auditor VARCHAR(180) DEFAULT NULL COMMENT '终审审批人',
|
||||||
|
auditor1 VARCHAR(36) DEFAULT NULL COMMENT '处室负责人',
|
||||||
|
auditnote VARCHAR(200) DEFAULT NULL COMMENT '当前审批人',
|
||||||
|
procinst_id VARCHAR(36) DEFAULT NULL COMMENT '流程实例ID',
|
||||||
|
bizflow_id VARCHAR(36) DEFAULT NULL COMMENT '业务流id',
|
||||||
|
bizflowname VARCHAR(200) DEFAULT NULL COMMENT '流程名称',
|
||||||
|
bizflow_makebillcode VARCHAR(200) DEFAULT NULL COMMENT '单据转换规则编码',
|
||||||
|
bizflowinstance_id VARCHAR(36) DEFAULT NULL COMMENT '业务流实例id',
|
||||||
|
sourcegrand_id VARCHAR(108) DEFAULT NULL COMMENT '来源孙表id',
|
||||||
|
first_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据主表id',
|
||||||
|
firstchild_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据子表id',
|
||||||
|
firstbusiobj VARCHAR(108) DEFAULT NULL COMMENT '来源业务对象',
|
||||||
|
firstcode TEXT DEFAULT NULL COMMENT '来源单据号',
|
||||||
|
source_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据主表id',
|
||||||
|
sourcechild_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据子表id',
|
||||||
|
sourcebusiobj VARCHAR(36) DEFAULT NULL COMMENT '上游业务对象',
|
||||||
|
sourcecode VARCHAR(200) DEFAULT NULL COMMENT '上游单据号',
|
||||||
|
code TEXT DEFAULT NULL COMMENT '编码',
|
||||||
|
ytenant_id VARCHAR(64) DEFAULT NULL COMMENT '租户id',
|
||||||
|
photo TEXT DEFAULT NULL COMMENT '照片',
|
||||||
|
input_time DATETIME DEFAULT NULL COMMENT '录入时间',
|
||||||
|
create_time DATETIME DEFAULT NULL COMMENT '创建时间',
|
||||||
|
modify_time DATETIME DEFAULT NULL COMMENT '修改时间',
|
||||||
|
audit_time DATETIME DEFAULT NULL COMMENT '审批日期',
|
||||||
|
input_user VARCHAR(108) DEFAULT NULL COMMENT '录入人',
|
||||||
|
input_dept VARCHAR(108) DEFAULT NULL COMMENT '录入部门',
|
||||||
|
creator VARCHAR(60) DEFAULT NULL COMMENT '创建人',
|
||||||
|
modifier VARCHAR(60) DEFAULT NULL COMMENT '修改人',
|
||||||
|
sort INT DEFAULT NULL COMMENT '排序',
|
||||||
|
dr INT DEFAULT 0 COMMENT '逻辑删除:0-未删除,1-已删除',
|
||||||
|
DHDATASTA INT DEFAULT NULL COMMENT '推送状态',
|
||||||
|
pubts DATETIME DEFAULT NULL COMMENT '发布时间戳(或其他时间戳)'
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='人员信息表';
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
list_documentions = [
|
||||||
|
"""
|
||||||
|
<人员库表注意事项>
|
||||||
|
<rule>
|
||||||
|
person_status 字段 1代表草稿,2代表审批中,3代表制卡中,4代表已入库,5代表停用;
|
||||||
|
gender 字段 1代表男,2代表女
|
||||||
|
is_internal 字段 0代表否,1代表是
|
||||||
|
pass_type 字段 1代表集团公司员工,2代表借调员工,3代表借用人员,4代表外部监管人员,5代表外协服务人员,6代表工勤人员,7代表来访人员
|
||||||
|
person_type 字段 YG代表正式员工,PQ代表劳务派遣人员,QT代表其他柔性引进人员,WHZ代表合作单位,WLS代表临时访客,WQT代表其他外部人员
|
||||||
|
dr 字段 0代表否,1代表是
|
||||||
|
id_card_type 字段 1代表居民身份证,2代表护照,3代表港澳通行证
|
||||||
|
highest_education 字段 1代表初中,2代表高中,3代表中专,4代表技校,5代表职高,6代表大专,7代表本科,8代表硕士,9代表博士
|
||||||
|
highest_degree 字段 1代表学士学位,2代表硕士学位,3代表博士学位,4代表无
|
||||||
|
is_subcontractor 字段 0代表否,1代表是
|
||||||
|
is_sign_confidentiality_agreement 字段 0代表否,1代表是
|
||||||
|
DHDATASTA 字段 0代表新增 1代表更新
|
||||||
|
查询address时,尽量使用like查询,如:select * from 人员库 where address like '%张三%';
|
||||||
|
语法为mysql语法;
|
||||||
|
</rule>
|
||||||
|
</人员库表注意事项>
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
def add_ddl(vn: CustomVanna):
|
||||||
|
for ddl in table_ddls:
|
||||||
|
vn.add_ddl(ddl)
|
||||||
|
|
||||||
|
def add_documentation(vn: CustomVanna):
|
||||||
|
for doc in list_documentions:
|
||||||
|
vn.add_documentation(doc)
|
||||||
|
|
||||||
Reference in New Issue
Block a user