批量添加ddl到向量数据库,添加documention,重写generate_rewritten_question
This commit is contained in:
		| @@ -3,6 +3,7 @@ from email.policy import default | ||||
| from service.cus_vanna_srevice import CustomVanna, QdrantClient | ||||
| from decouple import config | ||||
| import flask | ||||
| from util import load_ddl_doc | ||||
|  | ||||
| from flask import Flask, Response, jsonify, request, send_from_directory | ||||
|  | ||||
| @@ -12,10 +13,10 @@ def connect_database(vn): | ||||
|         vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default='')) | ||||
|     elif db_type == 'mysql': | ||||
|         vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''), | ||||
|                             port=config('MYSQL_DATABASE_PORT', default=3306), | ||||
|                             port=int(config('MYSQL_DATABASE_PORT', default=3306)), | ||||
|                             user=config('MYSQL_DATABASE_USER', default=''), | ||||
|                             password=config('MYSQL_DATABASE_PASSWORD', default=''), | ||||
|                             database=config('MYSQL_DATABASE_DBNAME', default='')) | ||||
|                             dbname=config('MYSQL_DATABASE_DBNAME', default='')) | ||||
|     elif db_type == 'postgresql': | ||||
|         # 待补充 | ||||
|         pass | ||||
| @@ -24,26 +25,7 @@ def connect_database(vn): | ||||
|  | ||||
|  | ||||
| def load_train_data_ddl(vn: CustomVanna): | ||||
|     vn.train(ddl=""" | ||||
|                  create table db_user | ||||
|                  ( | ||||
|                      id        integer not null | ||||
|                          constraint db_user_pk | ||||
|                              primary key autoincrement, | ||||
|                      user_name TEXT    not null, | ||||
|                      age       integer not null, | ||||
|                      address   TEXT, | ||||
|                      gender    integer not null, | ||||
|                      email     TEXT | ||||
|                  ) | ||||
|  | ||||
|  | ||||
|                  """) | ||||
|     vn.train(documentation=''' | ||||
|                 gender 字段 0代表女性,1代表男性; | ||||
|                 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; | ||||
|                 语法为sqlite语法; | ||||
|         ''') | ||||
|     vn.train() | ||||
|  | ||||
|  | ||||
| def create_vana(): | ||||
| @@ -62,6 +44,8 @@ def create_vana(): | ||||
| def init_vn(vn): | ||||
|     print("--------------init vn-----connect----") | ||||
|     connect_database(vn) | ||||
|     load_ddl_doc.add_ddl(vn) | ||||
|     load_ddl_doc.add_documentation(vn) | ||||
|     if config('IS_FIRST_LOAD', default=False, cast=bool): | ||||
|         load_train_data_ddl(vn) | ||||
|     return vn | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128