import logging
import math
from email.policy import default  # NOTE(review): appears unused (likely an accidental IDE auto-import) — confirm and remove

import flask
from decouple import config
from flask import Flask, Response, jsonify, request, send_from_directory
from vanna.flask import VannaFlaskApp

import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient
from util import load_ddl_doc

logger = logging.getLogger(__name__)
# NOTE(review): LOGGING_CONFIG is imported but never passed to
# logging.config.dictConfig in this module; presumably importing
# logging_config configures logging as a side effect — confirm.


def _json_safe(value):
    """Return ``None`` for non-finite floats, otherwise ``value`` unchanged.

    ``NaN``/``Infinity`` are not valid JSON, but ``json.dumps`` (used by
    ``jsonify``) emits them by default, which breaks ``JSON.parse`` in
    browsers. Numpy float scalars subclass ``float``, so they are covered.
    """
    if isinstance(value, float) and not math.isfinite(value):
        return None
    return value


def connect_database(vn):
    """Connect ``vn`` to the data source selected by ``DATA_SOURCE_TYPE``.

    Supported values (case-insensitive) are ``sqlite`` (the default),
    ``mysql`` and ``dameng``; connection parameters are read with
    ``decouple.config``. Any other value leaves ``vn`` unconnected and logs a
    warning; ``/run_sql_2`` then reports that no database is connected.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite').strip().lower()
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            # config() returns a string when the variable is set; cast it.
            port=config('MYSQL_DATABASE_PORT', default=3306, cast=int),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            database=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'dameng':
        # TODO: Dameng (DM) support is still to be completed.
        # NOTE(review): 3306 is MySQL's default port; DM usually listens on
        # 5236 — confirm the intended fallback.
        vn.connect_to_dameng(
            host=config('DAMENG_DATABASE_HOST', default=''),
            port=config('DAMENG_DATABASE_PORT', default=3306, cast=int),
            user=config('DAMENG_DATABASE_USER', default=''),
            password=config('DAMENG_DATABASE_PASSWORD', default=''),
        )
    else:
        logger.warning("Unsupported DATA_SOURCE_TYPE %r; no database connected", db_type)


def load_train_data_ddl(vn: CustomVanna):
    """Invoke ``vn.train()`` with no arguments.

    NOTE(review): presumably ``CustomVanna.train()`` loads the initial
    training data when called without arguments — confirm.
    """
    vn.train()


def create_vana():
    """Build the ``CustomVanna`` instance (Qdrant vector store + chat LLM).

    ``QDRANT_TYPE == 'memory'`` (the default) uses an in-memory Qdrant
    client (``":memory:"``); any other value connects to
    ``QDRANT_DB_HOST``/``QDRANT_DB_PORT``. LLM settings come from the
    ``CHAT_MODEL_*`` variables.

    Returns:
        The new, not-yet-connected ``CustomVanna`` instance.
    """
    logger.info("----------------create vana ---------")
    if config('QDRANT_TYPE', default='memory') == 'memory':
        q_client = QdrantClient(":memory:")
    else:
        q_client = QdrantClient(
            url=config('QDRANT_DB_HOST', default=''),
            # Environment values are strings; the client expects an int port.
            port=config('QDRANT_DB_PORT', default=6333, cast=int),
        )
    vn = CustomVanna(
        vector_store_config={"client": q_client},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
        },
    )
    return vn


def init_vn(vn):
    """Connect ``vn`` to the data source and load DDL and documentation.

    When ``IS_FIRST_LOAD`` is truthy, ``load_train_data_ddl`` is also run.

    Returns:
        The same ``vn`` object, for chaining.
    """
    logger.info("--------------init vana-----connect to datasouce db----")
    connect_database(vn)
    load_ddl_doc.add_ddl(vn)
    load_ddl_doc.add_documentation(vn)
    if config('IS_FIRST_LOAD', default=False, cast=bool):
        load_train_data_ddl(vn)
    return vn


# Module-level wiring: importing this module creates the Vanna instance,
# builds the Flask app, connects to the database and loads DDL/docs.
vn = create_vana()
app = VannaFlaskApp(vn, chart=False)
init_vn(vn)
cache = app.cache


@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question

    Returns the result of ``vn.generate_sql_2`` plus ``id`` and
    ``type: success``; the generated SQL is cached under ``id`` for
    ``run_sql_2``. On failure returns ``type: error`` with an ``error`` text.
    ---
    parameters:
      - name: user
        in: query
      - name: question
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            resp:
              type: object
    """
    logger.info("Start to generate sql in main")
    question = flask.request.args.get("question")
    if question is None or not question.strip():
        return jsonify({"type": "error", "error": "No question provided"})
    try:
        query_id = cache.generate_id(question=question)
        logger.info("Generate sql for %s", question)
        data = vn.generate_sql_2(question=question)
        logger.info("Generate sql result is %s", data)
        data['id'] = query_id
        sql = data["resp"]["sql"]
        if sql is None:
            # Report explicitly instead of failing on string concatenation
            # and caching nothing usable for run_sql_2.
            return jsonify({"type": "error", "error": "No SQL generated"})
        logger.info("generate sql is : %s", sql)
        cache.set(id=query_id, field="question", value=question)
        cache.set(id=query_id, field="sql", value=sql)
        data["type"] = "success"
        return jsonify(data)
    except Exception as e:
        # Request boundary: log with traceback, report the error to the client.
        logger.exception("generate sql failed:%s", e)
        return jsonify({"type": "error", "error": str(e)})


@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
    """
    Run SQL

    Runs the SQL cached for ``id`` and returns the rows as a list of
    records (``df``), with non-finite floats (e.g. from NULLs in numeric
    columns) converted to ``null``. On failure returns ``type: sql_error``
    (or ``type: error`` when no database is connected).
    ---
    parameters:
      - name: user
        in: query
      - name: id
        in: query|body
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            df:
              type: array
    """
    # ``id``/``sql`` are presumably passed by keyword by app.requires_cache,
    # so the parameter names (including the builtin-shadowing ``id``) must stay.
    logger.info("Start to run sql in main")
    try:
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )
        df = vn.run_sql(sql=sql)
        cache.set(id=id, field="df", value=df)
        result = [
            {column: _json_safe(value) for column, value in row.items()}
            for row in df.to_dict(orient='records')
        ]
        logger.info("run sql returned %d rows", len(result))
        logger.debug("df ---------------%s", result)
        # TODO(review): util.utils.deal_result(data=result) was previously left
        # commented out here; its purpose isn't visible from this file.
        return jsonify(
            {
                "type": "success",
                "id": id,
                "df": result,
            }
        )
    except Exception as e:
        # Request boundary: log with traceback, report the error to the client.
        logger.exception("run sql failed:%s", e)
        return jsonify({"type": "sql_error", "error": str(e)})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8084, debug=False)