"""Flask entry point exposing a Vanna text-to-SQL service.

Builds a ``CustomVanna`` instance, connects it to the database selected by the
``DATA_SOURCE_TYPE`` setting, optionally seeds it with training data, and wraps
it in a ``VannaFlaskApp`` extended with a ``/api/v0/generate_sql_2`` route.
"""
# NOTE(review): `default` below appears unused (the `default=` arguments to
# `config` are keyword arguments, not this name); kept to avoid changing the
# module namespace.
from email.policy import default

import flask
from decouple import config
from flask import Flask, Response, jsonify, request, send_from_directory
from vanna.flask import VannaFlaskApp

from service.cus_vanna_srevice import CustomVanna, QdrantClient


def connect_database(vn: CustomVanna) -> None:
    """Connect *vn* to the database selected by ``DATA_SOURCE_TYPE``.

    Supported values are ``sqlite`` (the default) and ``mysql``. Connection
    parameters are read through ``decouple.config``. Any other value leaves
    *vn* without a database connection and prints a warning.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite')
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            # Values read from the environment / .env are strings; cast so the
            # port is always an int, as the MySQL connection expects.
            port=config('MYSQL_DATABASE_PORT', default=3306, cast=int),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            database=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'postgresql':
        # TODO: PostgreSQL support still needs to be implemented.
        print(f"WARNING: DATA_SOURCE_TYPE={db_type!r} is not supported yet; "
              f"no database connected")
    else:
        print(f"WARNING: unknown DATA_SOURCE_TYPE={db_type!r}; "
              f"no database connected")


def load_train_data_ddl(vn: CustomVanna) -> None:
    """Seed *vn* with the ``db_user`` table DDL and business-rule documentation."""
    vn.train(ddl="""
create table db_user
(
    id integer not null
        constraint db_user_pk
            primary key autoincrement,
    user_name TEXT not null,
    age integer not null,
    address TEXT,
    gender integer not null,
    email TEXT
)
""")
    # Business rules for the LLM (the text is in Chinese):
    #   - gender: 0 means female, 1 means male;
    #   - when filtering on address prefer LIKE, e.g.
    #     select * from db_user where address like '%北京%';
    #   - the SQL dialect is SQLite.
    # NOTE(review): the dialect line is fixed to SQLite even when
    # DATA_SOURCE_TYPE is 'mysql'; confirm this is intended.
    vn.train(documentation='''
gender 字段 0代表女性,1代表男性;
查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%';
语法为sqlite语法;
''')


def create_vana() -> CustomVanna:
    """Build a ``CustomVanna`` with an in-memory Qdrant store and the configured chat model.

    The chat model's API key, base URL and model name come from
    ``CHAT_MODEL_API_KEY``, ``CHAT_MODEL_BASE_URL`` and ``CHAT_MODEL_NAME``.
    """
    print("----------------create---------")
    return CustomVanna(
        # ":memory:" - the vector store lives only for this process.
        vector_store_config={"client": QdrantClient(":memory:")},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
        },
    )


def init_vn(vn: CustomVanna) -> CustomVanna:
    """Connect *vn* to its database and, when ``IS_FIRST_LOAD`` is true, train it.

    Returns the same *vn* instance.
    """
    print("--------------init vn-----connect----")
    connect_database(vn)
    # NOTE(review): create_vana() uses an in-memory vector store, so training
    # data does not survive a restart; with IS_FIRST_LOAD false the store
    # starts empty. Confirm this flag is intended with a non-persistent store.
    if config('IS_FIRST_LOAD', default=False, cast=bool):
        load_train_data_ddl(vn)
    return vn


vn = create_vana()
app = VannaFlaskApp(vn, chart=False)
init_vn(vn)


@app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question

    Returns the result of ``vn.generate_sql_2`` as JSON. A missing or blank
    ``question`` yields ``{"type": "error", "error": "No question provided"}``.
    ---
    parameters:
      - name: user
        in: query
      - name: question
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: sql
            id:
              type: string
            text:
              type: string
    """
    question = request.args.get("question")
    # A whitespace-only question is as useless as a missing one; reject it
    # instead of sending an empty prompt to the LLM.
    if question is None or not question.strip():
        return jsonify({"type": "error", "error": "No question provided"})
    data = vn.generate_sql_2(question=question)
    return jsonify(data)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8084, debug=False)