批量添加ddl到向量数据库,添加documentation,重写generate_rewritten_question
This commit is contained in:
		| @@ -3,6 +3,7 @@ from email.policy import default | ||||
| from service.cus_vanna_srevice import CustomVanna, QdrantClient | ||||
| from decouple import config | ||||
| import flask | ||||
| from util import load_ddl_doc | ||||
|  | ||||
| from flask import Flask, Response, jsonify, request, send_from_directory | ||||
|  | ||||
| @@ -12,10 +13,10 @@ def connect_database(vn): | ||||
|         vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default='')) | ||||
|     elif db_type == 'mysql': | ||||
|         vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''), | ||||
|                             port=config('MYSQL_DATABASE_PORT', default=3306), | ||||
|                             port=int(config('MYSQL_DATABASE_PORT', default=3306)), | ||||
|                             user=config('MYSQL_DATABASE_USER', default=''), | ||||
|                             password=config('MYSQL_DATABASE_PASSWORD', default=''), | ||||
|                             database=config('MYSQL_DATABASE_DBNAME', default='')) | ||||
|                             dbname=config('MYSQL_DATABASE_DBNAME', default='')) | ||||
|     elif db_type == 'postgresql': | ||||
|         # 待补充 | ||||
|         pass | ||||
| @@ -24,26 +25,7 @@ def connect_database(vn): | ||||
|  | ||||
|  | ||||
| def load_train_data_ddl(vn: CustomVanna): | ||||
|     vn.train(ddl=""" | ||||
|                  create table db_user | ||||
|                  ( | ||||
|                      id        integer not null | ||||
|                          constraint db_user_pk | ||||
|                              primary key autoincrement, | ||||
|                      user_name TEXT    not null, | ||||
|                      age       integer not null, | ||||
|                      address   TEXT, | ||||
|                      gender    integer not null, | ||||
|                      email     TEXT | ||||
|                  ) | ||||
|  | ||||
|  | ||||
|                  """) | ||||
|     vn.train(documentation=''' | ||||
|                 gender 字段 0代表女性,1代表男性; | ||||
|                 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; | ||||
|                 语法为sqlite语法; | ||||
|         ''') | ||||
|     vn.train() | ||||
|  | ||||
|  | ||||
| def create_vana(): | ||||
| @@ -62,6 +44,8 @@ def create_vana(): | ||||
def init_vn(vn):
    """Prepare a Vanna instance for use and hand it back.

    Steps, in order:
      1. connect ``vn`` to the configured database via ``connect_database``;
      2. push the DDL statements and documentation snippets from
         ``load_ddl_doc`` into ``vn``;
      3. when the ``IS_FIRST_LOAD`` setting is truthy, additionally run
         ``load_train_data_ddl``.

    Returns the same ``vn`` object that was passed in.
    """
    print("--------------init vn-----connect----")
    connect_database(vn)

    # DDL and documentation are loaded on every initialisation,
    # independently of the IS_FIRST_LOAD switch below.
    load_ddl_doc.add_ddl(vn)
    load_ddl_doc.add_documentation(vn)

    is_first_load = config('IS_FIRST_LOAD', default=False, cast=bool)
    if not is_first_load:
        return vn

    load_train_data_ddl(vn)
    return vn
|   | ||||
| @@ -146,21 +146,23 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|         template = get_base_template() | ||||
|         sql_temp = template['template']['sql'] | ||||
|         char_temp = template['template']['chart'] | ||||
|         # --------基于提示词,生成sql以及图标类型 | ||||
|         sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list, | ||||
|         # --------基于提示词,生成sql以及图表类型 | ||||
|         sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list, | ||||
|                                              data_training=question_sql_list) | ||||
|         print("sys_temp", sys_temp) | ||||
|         user_temp = sql_temp['user'].format(question=question, | ||||
|                                             current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) | ||||
|         print("user_temp", user_temp) | ||||
|         llm_response = self.submit_prompt( | ||||
|             [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) | ||||
|         print(llm_response) | ||||
|         result = {"resp": orjson.loads(extract_nested_json(llm_response))} | ||||
|  | ||||
|         print("result", result) | ||||
|         sql = check_and_get_sql(llm_response) | ||||
|         # ---------------生成图表 | ||||
|         char_type = get_chart_type_from_sql_answer(llm_response) | ||||
|         if char_type: | ||||
|             sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type) | ||||
|             sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type) | ||||
|             user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) | ||||
|             llm_response2 = self.submit_prompt( | ||||
|                 [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) | ||||
| @@ -168,6 +170,19 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|             result['chart'] = orjson.loads(extract_nested_json(llm_response2)) | ||||
|         return result | ||||
|  | ||||
|     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: | ||||
|         print("new_question---------------", new_question) | ||||
|         if last_question is None: | ||||
|             return new_question | ||||
|  | ||||
|         prompt = [ | ||||
|             self.system_message( | ||||
|                 "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."), | ||||
|             self.user_message(new_question), | ||||
|         ] | ||||
|  | ||||
|         return self.submit_prompt(prompt=prompt, **kwargs) | ||||
|  | ||||
|  | ||||
| class CustomQdrant_VectorStore(Qdrant_VectorStore): | ||||
|     def __init__( | ||||
|   | ||||
| @@ -36,6 +36,9 @@ template: | ||||
|         <rule> | ||||
|           若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           请使用JSON格式返回你的回答: | ||||
|           若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} | ||||
| @@ -240,11 +243,11 @@ template: | ||||
|         {current_time} | ||||
|         </current-time> | ||||
|       <background-infos> | ||||
|       | ||||
|        | ||||
|       <user-question> | ||||
|       {question} | ||||
|       </user-question> | ||||
|        | ||||
|    | ||||
|   chart: | ||||
|     system: | | ||||
|       <Instruction> | ||||
| @@ -312,7 +315,7 @@ template: | ||||
|           如果你无法根据提供的内容生成合适的JSON配置,则返回:{{"type":"error", "reason": "抱歉,我无法生成合适的图表配置"}} | ||||
|           可以的话,你可以稍微丰富一下错误信息,让用户知道可能的原因。例如:"reason": "无法生成配置:提供的SQL查询结果中没有找到适合作为分类(series)的字段。" | ||||
|         </rule> | ||||
|          | ||||
|        | ||||
|       <Rules> | ||||
|        | ||||
|       ### 以下<example>帮助你理解问题及返回格式的例子,不要将<example>内的表结构用来回答用户的问题 | ||||
| @@ -498,3 +501,6 @@ template: | ||||
|        | ||||
|       ### 子查询映射表: | ||||
|       {sub_query} | ||||
|  | ||||
|  | ||||
| Resources: | ||||
							
								
								
									
										131
									
								
								util/load_ddl_doc.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								util/load_ddl_doc.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,131 @@ | ||||
| from service.cus_vanna_srevice import CustomVanna | ||||
| # table_ddls = [ | ||||
| #     """ | ||||
| #     create table db_user | ||||
| #         ( | ||||
| #             id        integer not null | ||||
| #                 constraint db_user_pk | ||||
| #                     primary key autoincrement, | ||||
| #             user_name TEXT    not null, | ||||
| #             age       integer not null, | ||||
| #             address   TEXT, | ||||
| #             gender    integer not null, | ||||
| #             email     TEXT | ||||
| #                  ) | ||||
| #     """, | ||||
| # ] | ||||
| # list_documentions = [ | ||||
| #     """ | ||||
| #     gender 字段 0代表女性,1代表男性; | ||||
| #     查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; | ||||
| #     语法为sqlite语法; | ||||
| #     """, | ||||
| # ] | ||||
| table_ddls = [ | ||||
|     """ | ||||
|     CREATE TABLE 人员库表 ( | ||||
|     id VARCHAR(22) PRIMARY KEY COMMENT '主键', | ||||
|     name VARCHAR(600) DEFAULT NULL COMMENT '姓名', | ||||
|     gender VARCHAR(108) DEFAULT NULL COMMENT '性别', | ||||
|     id_card_type VARCHAR(108) DEFAULT NULL COMMENT '身份证件类型', | ||||
|     id_card VARCHAR(600) DEFAULT NULL COMMENT '身份证号码', | ||||
|     birthday VARCHAR(30) DEFAULT NULL COMMENT '出生日期', | ||||
|     native_place TEXT DEFAULT NULL COMMENT '籍贯', | ||||
|     nation TEXT DEFAULT NULL COMMENT '民族', | ||||
|     country TEXT DEFAULT NULL COMMENT '国籍', | ||||
|     residence_address TEXT DEFAULT NULL COMMENT '户籍地址', | ||||
|     highest_education VARCHAR(108) DEFAULT NULL COMMENT '最高学历', | ||||
|     highest_degree TEXT DEFAULT NULL COMMENT '最高学位', | ||||
|     graduate_school TEXT DEFAULT NULL COMMENT '毕业院校', | ||||
|     political_status TEXT DEFAULT NULL COMMENT '政治面貌', | ||||
|     phone_number TEXT DEFAULT NULL COMMENT '手机号', | ||||
|     email VARCHAR(600) DEFAULT NULL COMMENT '电子邮箱', | ||||
|     worker_id VARCHAR(200) DEFAULT NULL COMMENT '工号', | ||||
|     post TEXT DEFAULT NULL COMMENT '职务', | ||||
|     engage_post TEXT DEFAULT NULL COMMENT '现从事岗位', | ||||
|     work_unit TEXT DEFAULT NULL COMMENT '工作单位全称', | ||||
|     work_content TEXT DEFAULT NULL COMMENT '工作内容', | ||||
|     engage_contract_no VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同编号', | ||||
|     engage_contract_name VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同名称', | ||||
|     is_subcontractor VARCHAR(108) DEFAULT NULL COMMENT '是否分包商', | ||||
|     general_contractor_unit VARCHAR(600) DEFAULT NULL COMMENT '总包单位全称', | ||||
|     office_city TEXT DEFAULT NULL COMMENT '办公城市', | ||||
|     office_address TEXT DEFAULT NULL COMMENT '办公地点', | ||||
|     person_type TEXT DEFAULT NULL COMMENT '人员类型', | ||||
|     person_status VARCHAR(108) DEFAULT NULL COMMENT '人员状态', | ||||
|     is_internal VARCHAR(108) DEFAULT NULL COMMENT '是否内部员工', | ||||
|     internal_unit VARCHAR(108) DEFAULT NULL COMMENT '内部单位', | ||||
|     internal_dept VARCHAR(108) DEFAULT NULL COMMENT '内部部门', | ||||
|     external_unit VARCHAR(600) DEFAULT NULL COMMENT '外部单位', | ||||
|     external_dept VARCHAR(600) DEFAULT NULL COMMENT '外部部门', | ||||
|     to_dept VARCHAR(600) DEFAULT NULL COMMENT '所属处室', | ||||
|     pass_type VARCHAR(108) DEFAULT NULL COMMENT '通行证类型', | ||||
|     entry_date VARCHAR(30) DEFAULT NULL COMMENT '入场日期', | ||||
|     expected_departure_date VARCHAR(30) DEFAULT NULL COMMENT '预计离场日期', | ||||
|     expire_time DATETIME DEFAULT NULL COMMENT '失效时间', | ||||
|     verifystate INT DEFAULT NULL COMMENT '单据状态', | ||||
|     auditor VARCHAR(180) DEFAULT NULL COMMENT '终审审批人', | ||||
|     auditor1 VARCHAR(36) DEFAULT NULL COMMENT '处室负责人', | ||||
|     auditnote VARCHAR(200) DEFAULT NULL COMMENT '当前审批人', | ||||
|     procinst_id VARCHAR(36) DEFAULT NULL COMMENT '流程实例ID', | ||||
|     bizflow_id VARCHAR(36) DEFAULT NULL COMMENT '业务流id', | ||||
|     bizflowname VARCHAR(200) DEFAULT NULL COMMENT '流程名称', | ||||
|     bizflow_makebillcode VARCHAR(200) DEFAULT NULL COMMENT '单据转换规则编码', | ||||
|     bizflowinstance_id VARCHAR(36) DEFAULT NULL COMMENT '业务流实例id', | ||||
|     sourcegrand_id VARCHAR(108) DEFAULT NULL COMMENT '来源孙表id', | ||||
|     first_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据主表id', | ||||
|     firstchild_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据子表id', | ||||
|     firstbusiobj VARCHAR(108) DEFAULT NULL COMMENT '来源业务对象', | ||||
|     firstcode TEXT DEFAULT NULL COMMENT '来源单据号', | ||||
|     source_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据主表id', | ||||
|     sourcechild_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据子表id', | ||||
|     sourcebusiobj VARCHAR(36) DEFAULT NULL COMMENT '上游业务对象', | ||||
|     sourcecode VARCHAR(200) DEFAULT NULL COMMENT '上游单据号', | ||||
|     code TEXT DEFAULT NULL COMMENT '编码', | ||||
|     ytenant_id VARCHAR(64) DEFAULT NULL COMMENT '租户id', | ||||
|     photo TEXT DEFAULT NULL COMMENT '照片', | ||||
|     input_time DATETIME DEFAULT NULL COMMENT '录入时间', | ||||
|     create_time DATETIME DEFAULT NULL COMMENT '创建时间', | ||||
|     modify_time DATETIME DEFAULT NULL COMMENT '修改时间', | ||||
|     audit_time DATETIME DEFAULT NULL COMMENT '审批日期', | ||||
|     input_user VARCHAR(108) DEFAULT NULL COMMENT '录入人', | ||||
|     input_dept VARCHAR(108) DEFAULT NULL COMMENT '录入部门', | ||||
|     creator VARCHAR(60) DEFAULT NULL COMMENT '创建人', | ||||
|     modifier VARCHAR(60) DEFAULT NULL COMMENT '修改人', | ||||
|     sort INT DEFAULT NULL COMMENT '排序', | ||||
|     dr INT DEFAULT 0 COMMENT '逻辑删除:0-未删除,1-已删除', | ||||
|     DHDATASTA INT DEFAULT NULL COMMENT '推送状态', | ||||
|     pubts DATETIME DEFAULT NULL COMMENT '发布时间戳(或其他时间戳)' | ||||
|     ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='人员信息表'; | ||||
|     """, | ||||
| ] | ||||
# Business notes for the 人员库表 table, stored as documentation alongside the
# DDL in ``table_ddls``. Each entry is passed as-is to ``vn.add_documentation``.
#
# The example query must use names that exist in the DDL: the table is
# 人员库表 and the address-like columns are residence_address / office_address
# (there is no plain ``address`` column).
#
# NOTE(review): is_sign_confidentiality_agreement is described below but does
# not appear in the DDL in ``table_ddls`` - confirm whether the DDL is missing
# the column or the rule is stale.
list_documentions = [
    """
    <人员库表注意事项>
        <rule>
            person_status 字段 1代表草稿,2代表审批中,3代表制卡中,4代表已入库,5代表停用;
            gender 字段 1代表男,2代表女
            is_internal 字段 0代表否,1代表是
            pass_type 字段 1代表集团公司员工,2代表借调员工,3代表借用人员,4代表外部监管人员,5代表外协服务人员,6代表工勤人员,7代表来访人员
            person_type 字段 YG代表正式员工,PQ代表劳务派遣人员,QT代表其他柔性引进人员,WHZ代表合作单位,WLS代表临时访客,WQT代表其他外部人员
            dr 字段 为逻辑删除标记,0代表未删除,1代表已删除
            id_card_type 字段 1代表居民身份证,2代表护照,3代表港澳通行证
            highest_education 字段 1代表初中,2代表高中,3代表中专,4代表技校,5代表职高,6代表大专,7代表本科,8代表硕士,9代表博士
            highest_degree 字段 1代表学士学位,2代表硕士学位,3代表博士学位,4代表无
            is_subcontractor 字段 0代表否,1代表是
            is_sign_confidentiality_agreement 字段 0代表否,1代表是
            DHDATASTA 字段 0代表新增 1代表更新
            查询地址类字段(如residence_address、office_address)时,尽量使用like查询,如:select * from 人员库表 where residence_address like '%北京%';
            语法为mysql语法;
        </rule>
    </人员库表注意事项>
    """,
]
def add_ddl(vn: CustomVanna):
    """Register every statement in ``table_ddls`` with ``vn``, one ``add_ddl`` call each."""
    for statement in table_ddls:
        vn.add_ddl(statement)
|  | ||||
def add_documentation(vn: CustomVanna):
    """Register every entry of ``list_documentions`` with ``vn`` via ``add_documentation``."""
    for note in list_documentions:
        vn.add_documentation(note)
|  | ||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128