diff --git a/main_service.py b/main_service.py
index 42eb408..a9aae26 100644
--- a/main_service.py
+++ b/main_service.py
@@ -3,6 +3,7 @@ from email.policy import default
from service.cus_vanna_srevice import CustomVanna, QdrantClient
from decouple import config
import flask
+from util import load_ddl_doc
from flask import Flask, Response, jsonify, request, send_from_directory
@@ -12,10 +13,10 @@ def connect_database(vn):
vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
elif db_type == 'mysql':
vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''),
- port=config('MYSQL_DATABASE_PORT', default=3306),
+ port=int(config('MYSQL_DATABASE_PORT', default=3306)),
user=config('MYSQL_DATABASE_USER', default=''),
password=config('MYSQL_DATABASE_PASSWORD', default=''),
- database=config('MYSQL_DATABASE_DBNAME', default=''))
+ dbname=config('MYSQL_DATABASE_DBNAME', default=''))
elif db_type == 'postgresql':
# 待补充
pass
@@ -24,26 +25,7 @@ def connect_database(vn):
def load_train_data_ddl(vn: CustomVanna):
- vn.train(ddl="""
- create table db_user
- (
- id integer not null
- constraint db_user_pk
- primary key autoincrement,
- user_name TEXT not null,
- age integer not null,
- address TEXT,
- gender integer not null,
- email TEXT
- )
-
-
- """)
- vn.train(documentation='''
- gender 字段 0代表女性,1代表男性;
- 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%';
- 语法为sqlite语法;
- ''')
+ vn.train()
def create_vana():
@@ -62,6 +44,8 @@ def create_vana():
def init_vn(vn):
print("--------------init vn-----connect----")
connect_database(vn)
+ load_ddl_doc.add_ddl(vn)
+ load_ddl_doc.add_documentation(vn)
if config('IS_FIRST_LOAD', default=False, cast=bool):
load_train_data_ddl(vn)
return vn
diff --git a/service/cus_vanna_srevice.py b/service/cus_vanna_srevice.py
index e104585..d0c68af 100644
--- a/service/cus_vanna_srevice.py
+++ b/service/cus_vanna_srevice.py
@@ -146,21 +146,23 @@ class OpenAICompatibleLLM(VannaBase):
template = get_base_template()
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
- # --------基于提示词,生成sql以及图标类型
- sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list,
+ # --------基于提示词,生成sql以及图表类型
+ sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list,
data_training=question_sql_list)
+ print("sys_temp", sys_temp)
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
+ print("user_temp", user_temp)
llm_response = self.submit_prompt(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
print(llm_response)
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
-
+ print("result", result)
sql = check_and_get_sql(llm_response)
# ---------------生成图表
char_type = get_chart_type_from_sql_answer(llm_response)
if char_type:
- sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type)
+ sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
llm_response2 = self.submit_prompt(
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
@@ -168,6 +170,19 @@ class OpenAICompatibleLLM(VannaBase):
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
return result
+ def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
+ print("new_question---------------", new_question)
+ if last_question is None:
+ return new_question
+
+ prompt = [
+ self.system_message(
+ "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
+ self.user_message(new_question),
+ ]
+
+ return self.submit_prompt(prompt=prompt, **kwargs)
+
class CustomQdrant_VectorStore(Qdrant_VectorStore):
def __init__(
diff --git a/template.yaml b/template.yaml
index 0e00b95..f6f59ac 100644
--- a/template.yaml
+++ b/template.yaml
@@ -36,6 +36,9 @@ template:
若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句
+
+ 请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别
+
请使用JSON格式返回你的回答:
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}}
@@ -240,11 +243,11 @@ template:
{current_time}
-
+
{question}
-
+
chart:
system: |
@@ -312,7 +315,7 @@ template:
如果你无法根据提供的内容生成合适的JSON配置,则返回:{{"type":"error", "reason": "抱歉,我无法生成合适的图表配置"}}
可以的话,你可以稍微丰富一下错误信息,让用户知道可能的原因。例如:"reason": "无法生成配置:提供的SQL查询结果中没有找到适合作为分类(series)的字段。"
-
+
### 以下帮助你理解问题及返回格式的例子,不要将内的表结构用来回答用户的问题
@@ -498,3 +501,6 @@ template:
### 子查询映射表:
{sub_query}
+
+
+Resources:
\ No newline at end of file
diff --git a/util/load_ddl_doc.py b/util/load_ddl_doc.py
new file mode 100644
index 0000000..a1e2d5a
--- /dev/null
+++ b/util/load_ddl_doc.py
@@ -0,0 +1,131 @@
+from service.cus_vanna_srevice import CustomVanna
+# table_ddls = [
+# """
+# create table db_user
+# (
+# id integer not null
+# constraint db_user_pk
+# primary key autoincrement,
+# user_name TEXT not null,
+# age integer not null,
+# address TEXT,
+# gender integer not null,
+# email TEXT
+# )
+# """,
+# ]
+# list_documentions = [
+# """
+# gender 字段 0代表女性,1代表男性;
+# 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%';
+# 语法为sqlite语法;
+# """,
+# ]
+table_ddls = [
+ """
+ CREATE TABLE 人员库表 (
+ id VARCHAR(22) PRIMARY KEY COMMENT '主键',
+ name VARCHAR(600) DEFAULT NULL COMMENT '姓名',
+ gender VARCHAR(108) DEFAULT NULL COMMENT '性别',
+ id_card_type VARCHAR(108) DEFAULT NULL COMMENT '身份证件类型',
+ id_card VARCHAR(600) DEFAULT NULL COMMENT '身份证号码',
+ birthday VARCHAR(30) DEFAULT NULL COMMENT '出生日期',
+ native_place TEXT DEFAULT NULL COMMENT '籍贯',
+ nation TEXT DEFAULT NULL COMMENT '民族',
+ country TEXT DEFAULT NULL COMMENT '国籍',
+ residence_address TEXT DEFAULT NULL COMMENT '户籍地址',
+ highest_education VARCHAR(108) DEFAULT NULL COMMENT '最高学历',
+ highest_degree TEXT DEFAULT NULL COMMENT '最高学位',
+ graduate_school TEXT DEFAULT NULL COMMENT '毕业院校',
+ political_status TEXT DEFAULT NULL COMMENT '政治面貌',
+ phone_number TEXT DEFAULT NULL COMMENT '手机号',
+ email VARCHAR(600) DEFAULT NULL COMMENT '电子邮箱',
+ worker_id VARCHAR(200) DEFAULT NULL COMMENT '工号',
+ post TEXT DEFAULT NULL COMMENT '职务',
+ engage_post TEXT DEFAULT NULL COMMENT '现从事岗位',
+ work_unit TEXT DEFAULT NULL COMMENT '工作单位全称',
+ work_content TEXT DEFAULT NULL COMMENT '工作内容',
+ engage_contract_no VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同编号',
+ engage_contract_name VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同名称',
+ is_subcontractor VARCHAR(108) DEFAULT NULL COMMENT '是否分包商',
+ general_contractor_unit VARCHAR(600) DEFAULT NULL COMMENT '总包单位全称',
+ office_city TEXT DEFAULT NULL COMMENT '办公城市',
+ office_address TEXT DEFAULT NULL COMMENT '办公地点',
+ person_type TEXT DEFAULT NULL COMMENT '人员类型',
+ person_status VARCHAR(108) DEFAULT NULL COMMENT '人员状态',
+ is_internal VARCHAR(108) DEFAULT NULL COMMENT '是否内部员工',
+ internal_unit VARCHAR(108) DEFAULT NULL COMMENT '内部单位',
+ internal_dept VARCHAR(108) DEFAULT NULL COMMENT '内部部门',
+ external_unit VARCHAR(600) DEFAULT NULL COMMENT '外部单位',
+ external_dept VARCHAR(600) DEFAULT NULL COMMENT '外部部门',
+ to_dept VARCHAR(600) DEFAULT NULL COMMENT '所属处室',
+ pass_type VARCHAR(108) DEFAULT NULL COMMENT '通行证类型',
+ entry_date VARCHAR(30) DEFAULT NULL COMMENT '入场日期',
+ expected_departure_date VARCHAR(30) DEFAULT NULL COMMENT '预计离场日期',
+ expire_time DATETIME DEFAULT NULL COMMENT '失效时间',
+ verifystate INT DEFAULT NULL COMMENT '单据状态',
+ auditor VARCHAR(180) DEFAULT NULL COMMENT '终审审批人',
+ auditor1 VARCHAR(36) DEFAULT NULL COMMENT '处室负责人',
+ auditnote VARCHAR(200) DEFAULT NULL COMMENT '当前审批人',
+ procinst_id VARCHAR(36) DEFAULT NULL COMMENT '流程实例ID',
+ bizflow_id VARCHAR(36) DEFAULT NULL COMMENT '业务流id',
+ bizflowname VARCHAR(200) DEFAULT NULL COMMENT '流程名称',
+ bizflow_makebillcode VARCHAR(200) DEFAULT NULL COMMENT '单据转换规则编码',
+ bizflowinstance_id VARCHAR(36) DEFAULT NULL COMMENT '业务流实例id',
+ sourcegrand_id VARCHAR(108) DEFAULT NULL COMMENT '来源孙表id',
+ first_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据主表id',
+ firstchild_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据子表id',
+ firstbusiobj VARCHAR(108) DEFAULT NULL COMMENT '来源业务对象',
+ firstcode TEXT DEFAULT NULL COMMENT '来源单据号',
+ source_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据主表id',
+ sourcechild_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据子表id',
+ sourcebusiobj VARCHAR(36) DEFAULT NULL COMMENT '上游业务对象',
+ sourcecode VARCHAR(200) DEFAULT NULL COMMENT '上游单据号',
+ code TEXT DEFAULT NULL COMMENT '编码',
+ ytenant_id VARCHAR(64) DEFAULT NULL COMMENT '租户id',
+ photo TEXT DEFAULT NULL COMMENT '照片',
+ input_time DATETIME DEFAULT NULL COMMENT '录入时间',
+ create_time DATETIME DEFAULT NULL COMMENT '创建时间',
+ modify_time DATETIME DEFAULT NULL COMMENT '修改时间',
+ audit_time DATETIME DEFAULT NULL COMMENT '审批日期',
+ input_user VARCHAR(108) DEFAULT NULL COMMENT '录入人',
+ input_dept VARCHAR(108) DEFAULT NULL COMMENT '录入部门',
+ creator VARCHAR(60) DEFAULT NULL COMMENT '创建人',
+ modifier VARCHAR(60) DEFAULT NULL COMMENT '修改人',
+ sort INT DEFAULT NULL COMMENT '排序',
+ dr INT DEFAULT 0 COMMENT '逻辑删除:0-未删除,1-已删除',
+ DHDATASTA INT DEFAULT NULL COMMENT '推送状态',
+ pubts DATETIME DEFAULT NULL COMMENT '发布时间戳(或其他时间戳)'
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='人员信息表';
+ """,
+]
+list_documentions = [
+ """
+ <人员库表注意事项>
+
+ person_status 字段 1代表草稿,2代表审批中,3代表制卡中,4代表已入库,5代表停用;
+ gender 字段 1代表男,2代表女
+ is_internal 字段 0代表否,1代表是
+ pass_type 字段 1代表集团公司员工,2代表借调员工,3代表借用人员,4代表外部监管人员,5代表外协服务人员,6代表工勤人员,7代表来访人员
+ person_type 字段 YG代表正式员工,PQ代表劳务派遣人员,QT代表其他柔性引进人员,WHZ代表合作单位,WLS代表临时访客,WQT代表其他外部人员
+ dr 字段 0代表否,1代表是
+ id_card_type 字段 1代表居民身份证,2代表护照,3代表港澳通行证
+ highest_education 字段 1代表初中,2代表高中,3代表中专,4代表技校,5代表职高,6代表大专,7代表本科,8代表硕士,9代表博士
+ highest_degree 字段 1代表学士学位,2代表硕士学位,3代表博士学位,4代表无
+ is_subcontractor 字段 0代表否,1代表是
+ is_sign_confidentiality_agreement 字段 0代表否,1代表是
+ DHDATASTA 字段 0代表新增 1代表更新
+ 查询address时,尽量使用like查询,如:select * from 人员库 where address like '%张三%';
+ 语法为mysql语法;
+
+ 人员库表注意事项>
+ """,
+]
+def add_ddl(vn: CustomVanna):
+ for ddl in table_ddls:
+ vn.add_ddl(ddl)
+
+def add_documentation(vn: CustomVanna):
+ for doc in list_documentions:
+ vn.add_documentation(doc)
+