116 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			116 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from email.policy import default
 | ||
| 
 | ||
| from service.cus_vanna_srevice import CustomVanna, QdrantClient
 | ||
| from decouple import config
 | ||
| import flask
 | ||
| 
 | ||
| from flask import Flask, Response, jsonify, request, send_from_directory
 | ||
| 
 | ||
def connect_database(vn):
    """Connect *vn* to the database selected by the ``DATA_SOURCE_TYPE`` setting.

    All settings are read through ``decouple.config`` (environment / ``.env``).
    ``DATA_SOURCE_TYPE`` defaults to ``'sqlite'``.

    * ``sqlite`` -- uses ``SQLITE_DATABASE_URL``.
    * ``mysql``  -- uses ``MYSQL_DATABASE_HOST``, ``MYSQL_DATABASE_PORT``
      (default 3306), ``MYSQL_DATABASE_USER``, ``MYSQL_DATABASE_PASSWORD``
      and ``MYSQL_DATABASE_DBNAME``.

    Args:
        vn: Vanna instance exposing ``connect_to_sqlite`` / ``connect_to_mysql``.

    Raises:
        NotImplementedError: ``DATA_SOURCE_TYPE`` is ``'postgresql'``, which is
            not wired up yet.
        ValueError: ``DATA_SOURCE_TYPE`` is any other unsupported value.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite')
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            # Values read from the environment are strings; the MySQL driver
            # expects an integer port, so cast explicitly.
            port=config('MYSQL_DATABASE_PORT', default=3306, cast=int),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            # Vanna's connect_to_mysql names this parameter ``dbname``; passing
            # ``database=`` leaves dbname unset and lands in **kwargs instead.
            dbname=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'postgresql':
        # Previously a silent no-op (marked "to be added"), which left vn
        # unconnected and only failed later on the first query. Fail fast instead.
        raise NotImplementedError(
            "DATA_SOURCE_TYPE 'postgresql' is not supported yet")
    else:
        raise ValueError(
            f"Unsupported DATA_SOURCE_TYPE {db_type!r}; "
            "expected 'sqlite' or 'mysql'")
 | ||
| 
 | ||
| 
 | ||
def load_train_data_ddl(vn: CustomVanna):
    """Seed *vn* with the sample ``db_user`` schema and notes on querying it.

    Two training items are added, in this order:
    the table DDL, then free-text documentation for the model.
    """
    # Schema of the sample table (SQLite flavour: ``autoincrement``).
    schema_ddl = """
                 create table db_user
                 (
                     id        integer not null
                         constraint db_user_pk
                             primary key autoincrement,
                     user_name TEXT    not null,
                     age       integer not null,
                     address   TEXT,
                     gender    integer not null,
                     email     TEXT
                 )


                 """
    # Hints for the model (Chinese). Translation: "gender: 0 means female,
    # 1 means male; when querying address prefer LIKE, e.g.
    # select * from db_user where address like '%北京%' (Beijing);
    # the SQL syntax is SQLite."
    # NOTE(review): the dialect hint is hard-coded to SQLite even when
    # DATA_SOURCE_TYPE selects mysql -- confirm this is intended.
    usage_notes = '''
                gender 字段 0代表女性,1代表男性;
                查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%';
                语法为sqlite语法;
        '''
    vn.train(ddl=schema_ddl)
    vn.train(documentation=usage_notes)
 | ||
| 
 | ||
| 
 | ||
def create_vana():
    """Build and return the CustomVanna instance used by the app.

    The vector store is a Qdrant client opened with ``":memory:"``.
    The chat model's API key, base URL and model name come from
    ``CHAT_MODEL_API_KEY``, ``CHAT_MODEL_BASE_URL`` and ``CHAT_MODEL_NAME``
    (each defaulting to an empty string).
    """
    print("----------------create---------")
    vector_settings = {"client": QdrantClient(":memory:")}
    llm_settings = {
        "api_key": config('CHAT_MODEL_API_KEY', default=''),
        "api_base": config('CHAT_MODEL_BASE_URL', default=''),
        "model": config('CHAT_MODEL_NAME', default=''),
    }
    return CustomVanna(vector_store_config=vector_settings,
                       llm_config=llm_settings)
 | ||
| 
 | ||
| 
 | ||
def init_vn(vn):
    """Connect *vn* to its database and optionally seed training data.

    Training data is loaded via ``load_train_data_ddl`` only when the
    ``IS_FIRST_LOAD`` setting is truthy (default ``False``).

    Returns:
        The same *vn* object that was passed in.
    """
    print("--------------init vn-----connect----")
    connect_database(vn)
    should_seed = config('IS_FIRST_LOAD', default=False, cast=bool)
    if should_seed:
        load_train_data_ddl(vn)
    return vn
 | ||
| 
 | ||
| 
 | ||
| from vanna.flask import VannaFlaskApp
 | ||
# Import-time wiring: build the Vanna instance, wrap it in VannaFlaskApp, then
# connect it to the configured database (and optionally seed training data).
# chart=False presumably disables chart generation in the UI -- confirm
# against the VannaFlaskApp docs.
vn = create_vana()
app = VannaFlaskApp(vn,chart=False)
init_vn(vn)
 | ||
| 
 | ||
| 
 | ||
@app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question
    ---
    parameters:
      - name: user
        in: query
      - name: question
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: sql
            id:
              type: string
            text:
              type: string
    """
    # The docstring above is parsed as the Swagger spec, so notes live here.
    # A missing, empty or whitespace-only question is rejected up front rather
    # than spending an LLM call on an empty prompt. The error payload matches
    # the original handler's shape (returned with HTTP 200).
    question = flask.request.args.get("question", "").strip()
    if not question:
        return jsonify({"type": "error", "error": "No question provided"})

    # Whatever CustomVanna.generate_sql_2 returns is serialised as-is; it is
    # assumed to be JSON-serialisable -- TODO confirm in cus_vanna_srevice.
    data = vn.generate_sql_2(question=question)
    return jsonify(data)
 | ||
| 
 | ||
| 
 | ||
if __name__ == '__main__':
    # Serve on every interface at port 8084 with Flask debug mode off.
    run_options = dict(host='0.0.0.0', port=8084, debug=False)
    app.run(**run_options)
 | 
