import copy
import logging
from functools import wraps

import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from decouple import config
import flask
from util import load_ddl_doc
from flask import Flask, Response, jsonify, request
from vanna.flask import VannaFlaskApp

logger = logging.getLogger(__name__)


def connect_database(vn: CustomVanna) -> None:
    """Connect ``vn`` to the data-source database selected by ``DATA_SOURCE_TYPE``.

    Supported values are ``sqlite`` (the default), ``mysql`` and ``dameng``;
    connection parameters are read from the matching ``*_DATABASE_*`` settings.
    Any other value leaves ``vn`` unconnected and logs a warning.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite')
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            port=int(config('MYSQL_DATABASE_PORT', default=3306)),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            database=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'dameng':
        # TODO: Dameng support is still to be completed.
        # NOTE(review): 3306 is the MySQL default port; Dameng (DM) normally
        # listens on 5236 -- confirm the intended default.
        vn.connect_to_dameng(
            host=config('DAMENG_DATABASE_HOST', default=''),
            # Cast like the MySQL branch: values set in the environment are strings.
            port=int(config('DAMENG_DATABASE_PORT', default=3306)),
            user=config('DAMENG_DATABASE_USER', default=''),
            password=config('DAMENG_DATABASE_PASSWORD', default=''),
        )
    else:
        # Make a misconfiguration visible instead of silently skipping the
        # connection; run_sql_2 will then report that no database is connected.
        logger.warning("Unsupported DATA_SOURCE_TYPE %r; no database connected", db_type)


def load_train_data_ddl(vn: CustomVanna) -> None:
    """Call ``vn.train()`` with no arguments.

    What this trains on depends on ``CustomVanna.train`` (not visible here).
    """
    vn.train()


def create_vana() -> CustomVanna:
    """Build the ``CustomVanna`` instance (Qdrant vector store + chat LLM) from settings.

    ``QDRANT_TYPE=memory`` (the default) uses an in-process Qdrant client; any
    other value connects to ``QDRANT_DB_HOST``/``QDRANT_DB_PORT``.
    """
    logger.info("----------------create vana ---------")
    if config('QDRANT_TYPE', default='memory') == 'memory':
        q_client = QdrantClient(":memory:")
    else:
        q_client = QdrantClient(
            url=config('QDRANT_DB_HOST', default=''),
            # decouple returns str for environment values; the port must be an int.
            port=config('QDRANT_DB_PORT', default=6333, cast=int),
        )
    vn = CustomVanna(
        vector_store_config={"client": q_client},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
            'temperature': config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
            # Cast so an environment-provided value is not passed on as a string.
            'max_tokens': config('CHAT_MODEL_MAX_TOKEN', default=5000, cast=int),
        },
    )
    return vn


def init_vn(vn: CustomVanna) -> CustomVanna:
    """Connect ``vn`` to the data source and register DDL and documentation.

    When ``IS_FIRST_LOAD`` is true, ``load_train_data_ddl`` is also run.
    Returns the same ``vn`` instance.
    """
    logger.info("--------------init vana-----connect to datasouce db----")
    connect_database(vn)
    load_ddl_doc.add_ddl(vn)
    load_ddl_doc.add_documentation(vn)
    if config('IS_FIRST_LOAD', default=False, cast=bool):
        load_train_data_ddl(vn)
    return vn


vn = create_vana()
app = VannaFlaskApp(vn, chart=False)
# Replace the app's cache with a TTLCacheWrapper around it; the ttl comes from
# TTL_CACHE (default 3600, presumably seconds -- confirm in TTLCacheWrapper).
app.cache = TTLCacheWrapper(app.cache, ttl=config('TTL_CACHE', cast=int, default=60 * 60))
init_vn(vn)
cache = app.cache


@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question
    ---
    parameters:
      - name: user_id
        in: query
      - name: question
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            resp:
              type: object
            text:
              type: string
    """
    logger.info("Start to generate sql in main")
    question = request.args.get("question")
    if question is None:
        return jsonify({"type": "error", "error": "No question provided"})
    try:
        # Named request_id to avoid shadowing the ``id`` builtin.
        request_id = cache.generate_id(question=question)
        user_id = request.args.get("user_id")
        logger.info(f"Generate sql for {question}")
        data = vn.generate_sql_2(question=question, cache=cache, user_id=user_id)
        logger.info("Generate sql result is {0}".format(data))
        data['id'] = request_id
        sql = data["resp"]["sql"]
        # NOTE: if the model returned no SQL (None) this concatenation raises
        # TypeError, which is reported below as an error response.
        logger.info("generate sql is : " + sql)
        # Cache question/SQL under the new id so run_sql_2 can look them up.
        cache.set(id=request_id, field="question", value=question)
        cache.set(id=request_id, field="sql", value=sql)
        data["type"] = "success"
        return jsonify(data)
    except Exception as e:
        logger.exception(f"generate sql failed:{e}")
        return jsonify({"type": "error", "error": str(e)})


def session_save(func):
    """Decorate a run-SQL view so each executed query is appended to the user's session.

    After ``func`` runs, the question and SQL cached under the request's ``id``
    are appended to the list stored at ``cache[user_id]["data"]``. That list is
    trimmed to the last ``SESSION_LENGTH`` entries (a non-positive value
    disables trimming). The cached ``question`` for ``id`` is then deleted.

    Nothing is recorded when no ``id`` is supplied or no SQL is cached for it,
    because then no query was run and the entry would carry no context.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        request_id = request.args.get("id")
        if request_id is None:
            # run_sql_2 documents ``id`` as query|body; mirror that here.
            body = request.get_json(silent=True)
            if isinstance(body, dict):
                request_id = body.get("id")
        user_id = request.args.get("user_id")
        logger.info(f" id: {request_id},user_id: {user_id}")
        result = func(*args, **kwargs)

        sql = cache.get(id=request_id, field="sql") if request_id is not None else None
        if sql is None:
            logger.warning("Skipping session save: no id or no cached sql")
            return result

        # NOTE(review): a missing user_id stores the session under the key None,
        # shared by every caller that omits it -- confirm this is intended.
        session_len = int(config("SESSION_LENGTH", default=2))
        datas = []
        if cache.exists(id=user_id, field="data"):
            # Work on a copy so the cached list is not mutated in place.
            datas = copy.deepcopy(cache.get(id=user_id, field="data"))
        datas.append({
            "id": request_id,
            "question": cache.get(id=request_id, field="question"),
            "sql": sql,
        })
        logger.info("datas is {0}".format(datas))
        if session_len > 0 and len(datas) > session_len:
            datas = datas[-session_len:]
        # run_sql has finished for this id, so its context is now kept under
        # user_id instead. NOTE(review): only "question" is deleted; "sql" stays
        # cached under the id -- confirm whether it should be removed too.
        cache.delete(id=request_id, field="question")
        cache.set(id=user_id, field="data", value=copy.deepcopy(datas))
        logger.info(f" user data {cache.get(user_id, field='data')}")
        return result

    return wrapper


@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
@session_save
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
    """
    Run SQL
    ---
    parameters:
      - name: user_id
        in: query
        required: true
      - name: id
        in: query|body
        type: string
        required: true
      - name: page_size
        in: query
        description: currently ignored
      - name: page_num
        in: query
        description: currently ignored
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            df:
              type: object
    """
    logger.info("Start to run sql in main")
    try:
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )
        # TODO: page_size/page_num are documented but not implemented yet
        # (a "SELECT COUNT(*) FROM (<sql>)" total-count query was sketched).
        df = vn.run_sql(sql=sql)
        result = df.to_dict(orient='records')
        logger.info("df ---------------{0} {1}".format(result, type(result)))
        return jsonify(
            {
                "type": "success",
                "id": id,
                "df": result,
            }
        )
    except Exception as e:
        logger.exception(f"run sql failed:{e}")
        return jsonify({"type": "sql_error", "error": str(e)})


@app.flask_app.route("/yj_sqlbot/api/v0/verify", methods=["GET"])
def verify_user():
    """Report whether the ``user_id`` query parameter is listed in ``ALLOWED_USERS``.

    ``ALLOWED_USERS`` is a comma-separated list. Entries are whitespace-stripped
    and empty entries are ignored, so an unset value verifies nobody (including
    an empty ``user_id``).
    """
    try:
        user_id = request.args.get("user_id")
        users = [u.strip() for u in config('ALLOWED_USERS', default='').split(',') if u.strip()]
        logger.info(f"allowed users {users}")
        # Check the whole list, not just its first entry.
        return jsonify({"type": "success", "verify": user_id in users})
    except Exception as e:
        logger.exception(f"verify user failed:{e}")
        return jsonify({"type": "error", "error": str(e)})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8084, debug=False)