"""Flask entry point for the Vanna-based text-to-SQL service.

Builds a ``CustomVanna`` instance backed by an in-memory Qdrant vector store,
wraps it in ``VannaFlaskApp`` (charts disabled), connects it to the database
selected by configuration, loads DDL/documentation, and registers an extra
``/api/v0/generate_sql_2`` endpoint. Run directly to serve on port 8084.
"""
import logging
from email.policy import default  # NOTE(review): appears unused in this module; likely an accidental auto-import

import flask
from decouple import config
from flask import Flask, Response, jsonify, request, send_from_directory
from vanna.flask import VannaFlaskApp

from service.cus_vanna_srevice import CustomVanna, QdrantClient
from util import load_ddl_doc

logger = logging.getLogger(__name__)


def connect_database(vn):
    """Connect *vn* to the database selected by the ``DATA_SOURCE_TYPE`` setting.

    ``DATA_SOURCE_TYPE`` is matched case-insensitively with surrounding
    whitespace ignored. It defaults to ``sqlite``.

    * ``sqlite`` -- connects to ``SQLITE_DATABASE_URL``.
    * ``mysql`` -- connects using the ``MYSQL_DATABASE_HOST/PORT/USER/PASSWORD/DBNAME``
      settings. The port defaults to 3306.
    * ``postgresql`` -- recognised but not implemented yet.

    For ``postgresql`` or any unrecognised value, no connection is made and a
    warning is logged. Before this change, it was a silent no-op.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite').strip().lower()

    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            port=config('MYSQL_DATABASE_PORT', default=3306, cast=int),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            dbname=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'postgresql':
        # TODO: PostgreSQL support is still to be implemented.
        logger.warning(
            "DATA_SOURCE_TYPE=postgresql is not implemented yet; no database connected"
        )
    else:
        logger.warning(
            "Unsupported DATA_SOURCE_TYPE %r; no database connected", db_type
        )


def load_train_data_ddl(vn: CustomVanna):
    """Run the training step used on first load (see ``IS_FIRST_LOAD`` in ``init_vn``).

    NOTE(review): ``train()`` is called with no arguments. Vanna's base
    ``train()`` only acts on the question/sql/ddl/documentation/plan arguments
    it is given. Unless ``CustomVanna`` overrides it, this looks like a no-op --
    confirm the intended behaviour.
    """
    vn.train()


def create_vana():
    """Create the ``CustomVanna`` instance used by the app.

    The vector store is an in-memory Qdrant client (``":memory:"``), so nothing
    stored in it outlives the process. The LLM is configured from
    ``CHAT_MODEL_API_KEY``, ``CHAT_MODEL_BASE_URL`` and ``CHAT_MODEL_NAME``.
    """
    print("----------------create---------")
    vn = CustomVanna(
        vector_store_config={"client": QdrantClient(":memory:")},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
        },
    )
    return vn


def init_vn(vn):
    """Connect *vn* to its database and load DDL and documentation into it.

    This runs on every start because the vector store created in
    ``create_vana`` is in-memory. When ``IS_FIRST_LOAD`` is truthy, it also
    runs ``load_train_data_ddl``. Returns the same *vn* for convenience.
    """
    print("--------------init vn-----connect----")
    connect_database(vn)
    load_ddl_doc.add_ddl(vn)
    load_ddl_doc.add_documentation(vn)
    if config('IS_FIRST_LOAD', default=False, cast=bool):
        load_train_data_ddl(vn)
    return vn


vn = create_vana()
app = VannaFlaskApp(vn, chart=False)
init_vn(vn)


@app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question.

    Delegates to ``CustomVanna.generate_sql_2`` and returns its result via
    ``jsonify``. A missing or blank ``question`` yields an error payload
    (HTTP 200, ``type: error``), matching the other Vanna API endpoints.
    ---
    parameters:
      - name: user
        in: query
        type: string
        required: false
        description: currently not read by this handler
      - name: question
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: sql
            id:
              type: string
            text:
              type: string
    """
    question = flask.request.args.get("question")
    # Treat empty / whitespace-only questions like missing ones so they are not
    # forwarded to the LLM.
    if question is None or not question.strip():
        return jsonify({"type": "error", "error": "No question provided"})

    data = vn.generate_sql_2(question=question)
    return jsonify(data)


if __name__ == '__main__':
    # Listen on all interfaces on the fixed port 8084.
    app.run(host='0.0.0.0', port=8084, debug=False)