237 lines
		
	
	
		
			7.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			237 lines
		
	
	
		
			7.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import copy
 | ||
| import logging
 | ||
| from functools import wraps
 | ||
| import util.utils
 | ||
| from logging_config import LOGGING_CONFIG
 | ||
| from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
 | ||
| from decouple import config
 | ||
| import flask
 | ||
| from util import load_ddl_doc
 | ||
| 
 | ||
| from flask import Flask, Response, jsonify, request, send_from_directory
 | ||
| 
 | ||
| 
 | ||
| logger = logging.getLogger(__name__)
 | ||
def connect_database(vn):
    """Connect ``vn`` to the data source selected by ``DATA_SOURCE_TYPE``.

    All connection settings are read through ``config``. The recognised
    values of ``DATA_SOURCE_TYPE`` are ``sqlite`` (the default), ``mysql``
    and ``dameng``. Any other value leaves ``vn`` unconnected; this is
    reported with a warning rather than ignored silently.

    Args:
        vn: Object exposing ``connect_to_sqlite``, ``connect_to_mysql`` and
            ``connect_to_dameng``.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite')
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            port=int(config('MYSQL_DATABASE_PORT', default=3306)),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            database=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'dameng':
        # TODO: Dameng support is still to be completed.
        vn.connect_to_dameng(
            host=config('DAMENG_DATABASE_HOST', default=''),
            # Cast exactly like the MySQL branch so the port always has the
            # same type, whether it comes from configuration or the default.
            port=int(config('DAMENG_DATABASE_PORT', default=3306)),
            user=config('DAMENG_DATABASE_USER', default=''),
            password=config('DAMENG_DATABASE_PASSWORD', default=''),
        )
    else:
        logger.warning(
            "Unsupported DATA_SOURCE_TYPE %r; no database connection was made",
            db_type,
        )
 | ||
| 
 | ||
| 
 | ||
def load_train_data_ddl(vn: CustomVanna):
    """Trigger training on ``vn`` by calling its ``train()`` with no arguments."""
    vn.train()
 | ||
| 
 | ||
| 
 | ||
def create_vana():
    """Build the ``CustomVanna`` instance used by the application.

    The Qdrant client is in-memory when ``QDRANT_TYPE`` is ``memory`` (the
    default); otherwise it targets ``QDRANT_DB_HOST`` / ``QDRANT_DB_PORT``.
    The chat-model settings (``CHAT_MODEL_API_KEY``, ``CHAT_MODEL_BASE_URL``,
    ``CHAT_MODEL_NAME``) are passed through as the LLM configuration.

    Returns:
        The newly constructed ``CustomVanna``.
    """
    logger.info("----------------create vana ---------")
    if config('QDRANT_TYPE', default='memory') == 'memory':
        q_client = QdrantClient(":memory:")
    else:
        q_client = QdrantClient(
            url=config('QDRANT_DB_HOST', default=''),
            # cast=int keeps the port an integer whether it is configured
            # explicitly or falls back to the default.
            port=config('QDRANT_DB_PORT', default=6333, cast=int),
        )

    vn = CustomVanna(
        vector_store_config={"client": q_client},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
        },
    )

    return vn
 | ||
| 
 | ||
| 
 | ||
def init_vn(vn):
    """Prepare ``vn`` for use and hand it back.

    Connects the data source, registers DDL and documentation through
    ``load_ddl_doc``, and runs training as well when ``IS_FIRST_LOAD`` is
    truthy.

    Returns:
        The same ``vn`` object that was passed in.
    """
    logger.info("--------------init vana-----connect to datasouce db----")
    connect_database(vn)

    # Register schema information first, then the documentation.
    load_ddl_doc.add_ddl(vn)
    load_ddl_doc.add_documentation(vn)

    first_load = config('IS_FIRST_LOAD', default=False, cast=bool)
    if first_load:
        load_train_data_ddl(vn)

    return vn
 | ||
| 
 | ||
| 
 | ||
# Imported at this point, right before the application object is assembled.
from vanna.flask import VannaFlaskApp

# Module-level wiring: build the Vanna instance and wrap it in the Flask app.
vn = create_vana()
app = VannaFlaskApp(vn, chart=False)

# Replace the app's cache with a TTL-aware wrapper around it; the TTL comes
# from TTL_CACHE and defaults to 60 * 60 (presumably seconds).
app.cache = TTLCacheWrapper(
    app.cache,
    ttl=config('TTL_CACHE', cast=int, default=60 * 60),
)

# Initialise only after the cache has been swapped.
init_vn(vn)

# Shorthand used by the route handlers defined in this module.
cache = app.cache
 | ||
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question

    Reads the question and the user_id from the query string, asks
    vn.generate_sql_2 for the SQL, and stores the question and the SQL in the
    cache under a newly generated id so that a later request can refer to
    them. Any failure is reported as a JSON body whose type is error.
    ---
    parameters:
      - name: user_id
        in: query
      - name: question
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            text:
              type: string
            resp:
              type: object
              properties:
                sql:
                  type: string
    """
    logger.info("Start to generate sql in main")
    question = request.args.get("question")

    # A blank question is as unusable as a missing one.
    if question is None or not question.strip():
        return jsonify({"type": "error", "error": "No question provided"})
    try:
        question_id = cache.generate_id(question=question)
        user_id = request.args.get("user_id")
        logger.info("Generate sql for %s", question)
        data = vn.generate_sql_2(question=question, cache=cache, user_id=user_id)
        logger.info("Generate sql result is %s", data)
        data['id'] = question_id
        sql = data["resp"]["sql"]
        # %-style arguments: building the message with "+" raised TypeError
        # whenever sql was not a str, failing the request from a log call.
        logger.info("generate sql is : %s", sql)
        cache.set(id=question_id, field="question", value=question)
        cache.set(id=question_id, field="sql", value=sql)
        data["type"] = "success"
        return jsonify(data)
    except Exception as e:
        # Top-level boundary of the view: keep the traceback in the log and
        # report the failure to the caller as JSON.
        logger.exception("generate sql failed:%s", e)
        return jsonify({"type": "error", "error": str(e)})
 | ||
| 
 | ||
def session_save(func):
    """Decorate a view so each call is appended to the user's session history.

    After ``func`` returns, the question and SQL cached under the request's
    ``id`` are appended to the list stored under ``user_id`` (cache field
    ``data``). The list is trimmed to the last ``SESSION_LENGTH`` entries
    when that setting is positive. The view's own result is returned
    unchanged.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        request_id = request.args.get("id")
        user_id = request.args.get("user_id")
        logger.info(f"   id: {request_id},user_id: {user_id}")
        result = func(*args, **kwargs)

        max_entries = int(config("SESSION_LENGTH", default=2))

        # Continue from a copy of the stored history when there is one.
        if cache.exists(id=user_id, field="data"):
            history = copy.deepcopy(cache.get(id=user_id, field="data"))
        else:
            history = []

        entry = {
            "id": request_id,
            "question": cache.get(id=request_id, field="question"),
            "sql": cache.get(id=request_id, field="sql"),
        }
        history.append(entry)
        logger.info("datas is {0}".format(history))

        # Keep only the most recent entries; a non-positive limit disables trimming.
        if 0 < max_entries < len(history):
            history = history[-max_entries:]

        # run_sql has finished for this id, so its cached question is dropped
        # and the context is kept under user_id instead.
        # NOTE(review): only the "question" field is deleted here; the "sql"
        # field stays cached, although the original comment said all values
        # for the id are removed - confirm which is intended.
        cache.delete(id=request_id, field="question")
        cache.set(id=user_id, field="data", value=copy.deepcopy(history))
        logger.info(f" user data {cache.get(user_id, field='data')}")
        return result

    return wrapper
 | ||
| 
 | ||
| 
 | ||
| 
 | ||
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
@session_save
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
    """
    Run SQL

    Executes the given SQL through vn.run_sql and returns the rows as a list
    of records. The id and sql arguments are expected to be supplied by the
    requires_cache decorator (presumably from the cache entry of the
    requested id). Any failure is reported as a JSON body whose type is
    error or sql_error.
    ---
    parameters:
      - name: user_id
        in: query
        required: true
      - name: id
        in: query|body
        type: string
        required: true
      - name: page_size
        in: query
      - name: page_num
        in: query
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            df:
              type: object
    """
    logger.info("Start to run sql in main")
    try:
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )

        # TODO: page_size / page_num are accepted but not applied yet; the
        # whole result set is returned (a total-count query was sketched
        # here earlier and never enabled).
        df = vn.run_sql(sql=sql)
        result = df.to_dict(orient='records')
        logger.info("df ---------------%s   %s", result, type(result))
        return jsonify(
            {
                "type": "success",
                "id": id,
                "df": result,
            }
        )

    except Exception as e:
        # Top-level boundary of the view: keep the traceback in the log and
        # report the failure to the caller as JSON.
        logger.exception("run sql failed:%s", e)
        return jsonify({"type": "sql_error", "error": str(e)})
 | ||
| 
 | ||
| 
 | ||
@app.flask_app.route("/yj_sqlbot/api/v0/verify", methods=["GET"])
def verify_user():
    """Report whether the ``user_id`` query parameter is an allowed user.

    ``ALLOWED_USERS`` is read as a comma-separated list. Surrounding
    whitespace is stripped from each entry and empty entries are dropped, so
    an unset or empty setting allows nobody.

    Returns:
        A JSON response ``{"type": "success", "verify": <bool>}``, or
        ``{"type": "error", "error": <message>}`` if the check itself fails.
    """
    try:
        user_id = request.args.get("user_id")
        raw_users = config('ALLOWED_USERS', default='')
        allowed_users = {
            name.strip() for name in raw_users.split(',') if name.strip()
        }
        logger.info("allowed users %s", sorted(allowed_users))
        # Check against the whole list. The previous loop returned on its
        # first iteration, so only the first configured user could ever be
        # verified, and an empty user_id matched an unset ALLOWED_USERS.
        verified = user_id in allowed_users
        return jsonify({"type": "success", "verify": verified})
    except Exception as e:
        logger.exception("verify user failed:%s", e)
        return jsonify({"type": "error", "error": str(e)})
 | ||
| 
 | ||
| 
 | ||
| 
 | ||
if __name__ == "__main__":
    # Direct execution: start the app on host 0.0.0.0, port 8084, debug off.
    app.run(host="0.0.0.0", port=8084, debug=False)
 | 
