diff --git a/main_service.py b/main_service.py index 42eb408..a9aae26 100644 --- a/main_service.py +++ b/main_service.py @@ -3,6 +3,7 @@ from email.policy import default from service.cus_vanna_srevice import CustomVanna, QdrantClient from decouple import config import flask +from util import load_ddl_doc from flask import Flask, Response, jsonify, request, send_from_directory @@ -12,10 +13,10 @@ def connect_database(vn): vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default='')) elif db_type == 'mysql': vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''), - port=config('MYSQL_DATABASE_PORT', default=3306), + port=int(config('MYSQL_DATABASE_PORT', default=3306)), user=config('MYSQL_DATABASE_USER', default=''), password=config('MYSQL_DATABASE_PASSWORD', default=''), - database=config('MYSQL_DATABASE_DBNAME', default='')) + dbname=config('MYSQL_DATABASE_DBNAME', default='')) elif db_type == 'postgresql': # 待补充 pass @@ -24,26 +25,7 @@ def connect_database(vn): def load_train_data_ddl(vn: CustomVanna): - vn.train(ddl=""" - create table db_user - ( - id integer not null - constraint db_user_pk - primary key autoincrement, - user_name TEXT not null, - age integer not null, - address TEXT, - gender integer not null, - email TEXT - ) - - - """) - vn.train(documentation=''' - gender 字段 0代表女性,1代表男性; - 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; - 语法为sqlite语法; - ''') + vn.train() def create_vana(): @@ -62,6 +44,8 @@ def create_vana(): def init_vn(vn): print("--------------init vn-----connect----") connect_database(vn) + load_ddl_doc.add_ddl(vn) + load_ddl_doc.add_documentation(vn) if config('IS_FIRST_LOAD', default=False, cast=bool): load_train_data_ddl(vn) return vn diff --git a/service/cus_vanna_srevice.py b/service/cus_vanna_srevice.py index e104585..d0c68af 100644 --- a/service/cus_vanna_srevice.py +++ b/service/cus_vanna_srevice.py @@ -146,21 +146,23 @@ class OpenAICompatibleLLM(VannaBase): template = get_base_template() sql_temp = template['template']['sql'] char_temp = template['template']['chart'] - # --------基于提示词,生成sql以及图标类型 - sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list, + # --------基于提示词,生成sql以及图表类型 + sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list, data_training=question_sql_list) + print("sys_temp", sys_temp) user_temp = sql_temp['user'].format(question=question, current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + print("user_temp", user_temp) llm_response = self.submit_prompt( [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) print(llm_response) result = {"resp": orjson.loads(extract_nested_json(llm_response))} - + print("result", result) sql = check_and_get_sql(llm_response) # ---------------生成图表 char_type = get_chart_type_from_sql_answer(llm_response) if char_type: - sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type) + sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type) user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) llm_response2 = self.submit_prompt( [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) @@ -168,6 +170,19 @@ class OpenAICompatibleLLM(VannaBase): result['chart'] = orjson.loads(extract_nested_json(llm_response2)) return result + def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: + print("new_question---------------", new_question) + if last_question is None: + return new_question + + prompt = [ + self.system_message( + "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."), + self.user_message(new_question), + ] + + return self.submit_prompt(prompt=prompt, **kwargs) + class CustomQdrant_VectorStore(Qdrant_VectorStore): def __init__( diff --git a/template.yaml b/template.yaml index 0e00b95..f6f59ac 100644 --- a/template.yaml +++ b/template.yaml @@ -36,6 +36,9 @@ template: 若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句 + + 请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别 + 请使用JSON格式返回你的回答: 若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} @@ -240,11 +243,11 @@ template: {current_time} - + {question} - + chart: system: | @@ -312,7 +315,7 @@ template: 如果你无法根据提供的内容生成合适的JSON配置,则返回:{{"type":"error", "reason": "抱歉,我无法生成合适的图表配置"}} 可以的话,你可以稍微丰富一下错误信息,让用户知道可能的原因。例如:"reason": "无法生成配置:提供的SQL查询结果中没有找到适合作为分类(series)的字段。" - + ### 以下帮助你理解问题及返回格式的例子,不要将内的表结构用来回答用户的问题 @@ -498,3 +501,6 @@ template: ### 子查询映射表: {sub_query} + + +Resources: \ No newline at end of file diff --git a/util/load_ddl_doc.py b/util/load_ddl_doc.py new file mode 100644 index 0000000..a1e2d5a --- /dev/null +++ b/util/load_ddl_doc.py @@ -0,0 +1,131 @@ +from service.cus_vanna_srevice import CustomVanna +# table_ddls = [ +# """ +# create table db_user +# ( +# id integer not null +# constraint db_user_pk +# primary key autoincrement, +# user_name TEXT not null, +# age integer not null, +# address TEXT, +# gender integer not null, +# email TEXT +# ) +# """, +# ] +# list_documentions = [ +# """ +# gender 字段 0代表女性,1代表男性; +# 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; +# 语法为sqlite语法; +# """, +# ] +table_ddls = [ + """ + CREATE TABLE 人员库表 ( + id VARCHAR(22) PRIMARY KEY COMMENT '主键', + name VARCHAR(600) DEFAULT NULL COMMENT '姓名', + gender VARCHAR(108) DEFAULT NULL COMMENT '性别', + id_card_type VARCHAR(108) DEFAULT NULL COMMENT '身份证件类型', + id_card VARCHAR(600) DEFAULT NULL COMMENT '身份证号码', + birthday VARCHAR(30) DEFAULT NULL COMMENT '出生日期', + native_place TEXT DEFAULT NULL COMMENT '籍贯', + nation TEXT DEFAULT NULL COMMENT '民族', + country TEXT DEFAULT NULL COMMENT '国籍', + residence_address TEXT DEFAULT NULL COMMENT '户籍地址', + highest_education VARCHAR(108) DEFAULT NULL COMMENT '最高学历', + highest_degree TEXT DEFAULT NULL COMMENT '最高学位', + graduate_school TEXT DEFAULT NULL COMMENT '毕业院校', + political_status TEXT DEFAULT NULL COMMENT '政治面貌', + phone_number TEXT DEFAULT NULL COMMENT '手机号', + email VARCHAR(600) DEFAULT NULL COMMENT '电子邮箱', + worker_id VARCHAR(200) DEFAULT NULL COMMENT '工号', + post TEXT DEFAULT NULL COMMENT '职务', + engage_post TEXT DEFAULT NULL COMMENT '现从事岗位', + work_unit TEXT DEFAULT NULL COMMENT '工作单位全称', + work_content TEXT DEFAULT NULL COMMENT '工作内容', + engage_contract_no VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同编号', + engage_contract_name VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同名称', + is_subcontractor VARCHAR(108) DEFAULT NULL COMMENT '是否分包商', + general_contractor_unit VARCHAR(600) DEFAULT NULL COMMENT '总包单位全称', + office_city TEXT DEFAULT NULL COMMENT '办公城市', + office_address TEXT DEFAULT NULL COMMENT '办公地点', + person_type TEXT DEFAULT NULL COMMENT '人员类型', + person_status VARCHAR(108) DEFAULT NULL COMMENT '人员状态', + is_internal VARCHAR(108) DEFAULT NULL COMMENT '是否内部员工', + internal_unit VARCHAR(108) DEFAULT NULL COMMENT '内部单位', + internal_dept VARCHAR(108) DEFAULT NULL COMMENT '内部部门', + external_unit VARCHAR(600) DEFAULT NULL COMMENT '外部单位', + external_dept VARCHAR(600) DEFAULT NULL COMMENT '外部部门', + to_dept VARCHAR(600) DEFAULT NULL COMMENT '所属处室', + pass_type VARCHAR(108) DEFAULT NULL COMMENT '通行证类型', + entry_date VARCHAR(30) DEFAULT NULL COMMENT '入场日期', + expected_departure_date VARCHAR(30) DEFAULT NULL COMMENT '预计离场日期', + expire_time DATETIME DEFAULT NULL COMMENT '失效时间', + verifystate INT DEFAULT NULL COMMENT '单据状态', + auditor VARCHAR(180) DEFAULT NULL COMMENT '终审审批人', + auditor1 VARCHAR(36) DEFAULT NULL COMMENT '处室负责人', + auditnote VARCHAR(200) DEFAULT NULL COMMENT '当前审批人', + procinst_id VARCHAR(36) DEFAULT NULL COMMENT '流程实例ID', + bizflow_id VARCHAR(36) DEFAULT NULL COMMENT '业务流id', + bizflowname VARCHAR(200) DEFAULT NULL COMMENT '流程名称', + bizflow_makebillcode VARCHAR(200) DEFAULT NULL COMMENT '单据转换规则编码', + bizflowinstance_id VARCHAR(36) DEFAULT NULL COMMENT '业务流实例id', + sourcegrand_id VARCHAR(108) DEFAULT NULL COMMENT '来源孙表id', + first_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据主表id', + firstchild_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据子表id', + firstbusiobj VARCHAR(108) DEFAULT NULL COMMENT '来源业务对象', + firstcode TEXT DEFAULT NULL COMMENT '来源单据号', + source_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据主表id', + sourcechild_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据子表id', + sourcebusiobj VARCHAR(36) DEFAULT NULL COMMENT '上游业务对象', + sourcecode VARCHAR(200) DEFAULT NULL COMMENT '上游单据号', + code TEXT DEFAULT NULL COMMENT '编码', + ytenant_id VARCHAR(64) DEFAULT NULL COMMENT '租户id', + photo TEXT DEFAULT NULL COMMENT '照片', + input_time DATETIME DEFAULT NULL COMMENT '录入时间', + create_time DATETIME DEFAULT NULL COMMENT '创建时间', + modify_time DATETIME DEFAULT NULL COMMENT '修改时间', + audit_time DATETIME DEFAULT NULL COMMENT '审批日期', + input_user VARCHAR(108) DEFAULT NULL COMMENT '录入人', + input_dept VARCHAR(108) DEFAULT NULL COMMENT '录入部门', + creator VARCHAR(60) DEFAULT NULL COMMENT '创建人', + modifier VARCHAR(60) DEFAULT NULL COMMENT '修改人', + sort INT DEFAULT NULL COMMENT '排序', + dr INT DEFAULT 0 COMMENT '逻辑删除:0-未删除,1-已删除', + DHDATASTA INT DEFAULT NULL COMMENT '推送状态', + pubts DATETIME DEFAULT NULL COMMENT '发布时间戳(或其他时间戳)' + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='人员信息表'; + """, +] +list_documentions = [ + """ + <人员库表注意事项> + + person_status 字段 1代表草稿,2代表审批中,3代表制卡中,4代表已入库,5代表停用; + gender 字段 1代表男,2代表女 + is_internal 字段 0代表否,1代表是 + pass_type 字段 1代表集团公司员工,2代表借调员工,3代表借用人员,4代表外部监管人员,5代表外协服务人员,6代表工勤人员,7代表来访人员 + person_type 字段 YG代表正式员工,PQ代表劳务派遣人员,QT代表其他柔性引进人员,WHZ代表合作单位,WLS代表临时访客,WQT代表其他外部人员 + dr 字段 0代表否,1代表是 + id_card_type 字段 1代表居民身份证,2代表护照,3代表港澳通行证 + highest_education 字段 1代表初中,2代表高中,3代表中专,4代表技校,5代表职高,6代表大专,7代表本科,8代表硕士,9代表博士 + highest_degree 字段 1代表学士学位,2代表硕士学位,3代表博士学位,4代表无 + is_subcontractor 字段 0代表否,1代表是 + is_sign_confidentiality_agreement 字段 0代表否,1代表是 + DHDATASTA 字段 0代表新增 1代表更新 + 查询address时,尽量使用like查询,如:select * from 人员库 where address like '%张三%'; + 语法为mysql语法; + + + """, +] +def add_ddl(vn: CustomVanna): + for ddl in table_ddls: + vn.add_ddl(ddl) + +def add_documentation(vn: CustomVanna): + for doc in list_documentions: + vn.add_documentation(doc) +