feat:初始化
This commit is contained in:
		
							
								
								
									
										115
									
								
								main_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								main_service.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,115 @@ | ||||
| from email.policy import default | ||||
|  | ||||
| from service.cus_vanna_srevice import CustomVanna, QdrantClient | ||||
| from decouple import config | ||||
| import flask | ||||
|  | ||||
| from flask import Flask, Response, jsonify, request, send_from_directory | ||||
|  | ||||
| def connect_database(vn): | ||||
|     db_type = config('DATA_SOURCE_TYPE', default='sqlite') | ||||
|     if db_type == 'sqlite': | ||||
|         vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default='')) | ||||
|     elif db_type == 'mysql': | ||||
|         vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''), | ||||
|                             port=config('MYSQL_DATABASE_PORT', default=3306), | ||||
|                             user=config('MYSQL_DATABASE_USER', default=''), | ||||
|                             password=config('MYSQL_DATABASE_PASSWORD', default=''), | ||||
|                             database=config('MYSQL_DATABASE_DBNAME', default='')) | ||||
|     elif db_type == 'postgresql': | ||||
|         # 待补充 | ||||
|         pass | ||||
|     else: | ||||
|         pass | ||||
|  | ||||
|  | ||||
| def load_train_data_ddl(vn: CustomVanna): | ||||
|     vn.train(ddl=""" | ||||
|                  create table db_user | ||||
|                  ( | ||||
|                      id        integer not null | ||||
|                          constraint db_user_pk | ||||
|                              primary key autoincrement, | ||||
|                      user_name TEXT    not null, | ||||
|                      age       integer not null, | ||||
|                      address   TEXT, | ||||
|                      gender    integer not null, | ||||
|                      email     TEXT | ||||
|                  ) | ||||
|  | ||||
|  | ||||
|                  """) | ||||
|     vn.train(documentation=''' | ||||
|                 gender 字段 0代表女性,1代表男性; | ||||
|                 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; | ||||
|                 语法为sqlite语法; | ||||
|         ''') | ||||
|  | ||||
|  | ||||
| def create_vana(): | ||||
|     print("----------------create---------") | ||||
|     vn = CustomVanna( | ||||
|         vector_store_config={"client": QdrantClient(":memory:")}, | ||||
|         llm_config={ | ||||
|             "api_key": config('CHAT_MODEL_API_KEY', default=''), | ||||
|             "api_base": config('CHAT_MODEL_BASE_URL', default=''), | ||||
|             "model": config('CHAT_MODEL_NAME', default=''), | ||||
|         }, | ||||
|     ) | ||||
|     return vn | ||||
|  | ||||
|  | ||||
| def init_vn(vn): | ||||
|     print("--------------init vn-----connect----") | ||||
|     connect_database(vn) | ||||
|     if config('IS_FIRST_LOAD', default=False, cast=bool): | ||||
|         load_train_data_ddl(vn) | ||||
|     return vn | ||||
|  | ||||
|  | ||||
| from vanna.flask import VannaFlaskApp | ||||
| vn = create_vana() | ||||
| app = VannaFlaskApp(vn,chart=False) | ||||
| init_vn(vn) | ||||
|  | ||||
|  | ||||
| @app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"]) | ||||
| def generate_sql_2(): | ||||
|     """ | ||||
|     Generate SQL from a question | ||||
|     --- | ||||
|     parameters: | ||||
|       - name: user | ||||
|         in: query | ||||
|       - name: question | ||||
|         in: query | ||||
|         type: string | ||||
|         required: true | ||||
|     responses: | ||||
|       200: | ||||
|         schema: | ||||
|           type: object | ||||
|           properties: | ||||
|             type: | ||||
|               type: string | ||||
|               default: sql | ||||
|             id: | ||||
|               type: string | ||||
|             text: | ||||
|               type: string | ||||
|     """ | ||||
|     question = flask.request.args.get("question") | ||||
|  | ||||
|     if question is None: | ||||
|         return jsonify({"type": "error", "error": "No question provided"}) | ||||
|  | ||||
|     #id = self.cache.generate_id(question=question) | ||||
|     data = vn.generate_sql_2(question=question) | ||||
|  | ||||
|  | ||||
|     return jsonify(data) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|     app.run(host='0.0.0.0', port=8084, debug=False) | ||||
		Reference in New Issue
	
	Block a user
	 雷雨
					雷雨