179 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			179 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from email.policy import default
 | |
| import logging
 | |
| 
 | |
| import util.utils
 | |
| from logging_config import LOGGING_CONFIG
 | |
| from service.cus_vanna_srevice import CustomVanna, QdrantClient
 | |
| from decouple import config
 | |
| import flask
 | |
| from util import load_ddl_doc
 | |
| 
 | |
| from flask import Flask, Response, jsonify, request, send_from_directory
 | |
| 
 | |
| 
 | |
logger = logging.getLogger(__name__)


def connect_database(vn):
    """Connect ``vn`` to the data source selected by ``DATA_SOURCE_TYPE``.

    Recognised values are ``sqlite`` (the default), ``mysql`` and ``dameng``.
    Connection parameters are read with ``decouple.config`` from the
    environment / ``.env`` file. Any other value leaves ``vn`` unconnected and
    logs a warning.

    Args:
        vn: The Vanna instance to connect; it must provide the
            ``connect_to_*`` method matching the selected data source.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite')
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            # Environment values are strings; the driver needs an int port.
            port=config('MYSQL_DATABASE_PORT', default=3306, cast=int),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            database=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'dameng':
        # TODO: Dameng support is still to be completed.
        vn.connect_to_dameng(
            host=config('DAMENG_DATABASE_HOST', default=''),
            # 5236 is Dameng's standard listener port (3306 is MySQL's and was
            # a copy-paste slip); cast to int like the MySQL branch.
            port=config('DAMENG_DATABASE_PORT', default=5236, cast=int),
            user=config('DAMENG_DATABASE_USER', default=''),
            password=config('DAMENG_DATABASE_PASSWORD', default=''),
        )
    else:
        # Stay non-fatal as before, but make the misconfiguration visible
        # instead of silently running without a database.
        logger.warning(
            "Unsupported DATA_SOURCE_TYPE %r; no database connected", db_type
        )
 | |
| 
 | |
| 
 | |
def load_train_data_ddl(vn: CustomVanna):
    """Invoke the training step of ``vn`` without passing any training data.

    NOTE(review): ``train`` is called with no arguments here. Unless
    ``CustomVanna`` overrides it to load DDL on its own, this may add nothing
    to the vector store — confirm against ``CustomVanna.train``.
    """
    run_training = vn.train
    run_training()
 | |
| 
 | |
| 
 | |
def create_vana():
    """Build the ``CustomVanna`` instance backed by a Qdrant vector store.

    With ``QDRANT_TYPE=memory`` (the default) an in-process ``:memory:``
    Qdrant client is used. Any other value connects to
    ``QDRANT_DB_HOST``:``QDRANT_DB_PORT``. The chat model is configured from
    ``CHAT_MODEL_API_KEY``, ``CHAT_MODEL_BASE_URL`` and ``CHAT_MODEL_NAME``.

    Returns:
        The configured ``CustomVanna`` instance (no data source connected yet;
        see ``init_vn``).
    """
    logger.info("----------------create vana ---------")
    if config('QDRANT_TYPE', default='memory') == 'memory':
        q_client = QdrantClient(":memory:")
    else:
        q_client = QdrantClient(
            url=config('QDRANT_DB_HOST', default=''),
            # Environment values are strings; cast so the client gets an int port.
            port=config('QDRANT_DB_PORT', default=6333, cast=int),
        )
    vn = CustomVanna(
        vector_store_config={"client": q_client},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
        },
    )
    return vn
 | |
| 
 | |
| 
 | |
def init_vn(vn):
    """Connect ``vn`` to its data source and load DDL plus documentation.

    When the ``IS_FIRST_LOAD`` setting is true, the training step
    (``load_train_data_ddl``) is run as well.

    Args:
        vn: The Vanna instance produced by ``create_vana``.

    Returns:
        The same ``vn`` object.
    """
    logger.info("--------------init vana-----connect to datasouce db----")
    connect_database(vn)

    # Register schema DDL first, then the free-text documentation.
    load_ddl_doc.add_ddl(vn)
    load_ddl_doc.add_documentation(vn)

    first_load = config('IS_FIRST_LOAD', default=False, cast=bool)
    if first_load:
        load_train_data_ddl(vn)

    return vn
 | |
| 
 | |
| 
 | |
# Application wiring at import time: build the Vanna instance, wrap it in the
# Vanna Flask app with charts disabled, then connect the data source and load
# DDL/documentation.
from vanna.flask import VannaFlaskApp

vn = create_vana()
app = VannaFlaskApp(vn, chart=False)
init_vn(vn)

# Shared cache used by the custom endpoints in this module to store the
# question / SQL / result under a generated id.
cache = app.cache
 | |
@app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question
    Errors are reported as {"type": "error", "error": ...} with HTTP 200.
    ---
    parameters:
      - name: user
        in: query
      - name: question
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: sql
            id:
              type: string
            text:
              type: string
    """
    logger.info("Start to generate sql in main")
    question = flask.request.args.get("question")

    # A blank question is as useless as a missing one; don't send it to the LLM.
    if question is None or not question.strip():
        return jsonify({"type": "error", "error": "No question provided"})
    try:
        question_id = cache.generate_id(question=question)
        logger.info("Generate sql for %s", question)
        data = vn.generate_sql_2(question=question)
        logger.info("Generate sql result is %s", data)
        data['id'] = question_id
        # An unexpected result shape raises here and is reported by the
        # except branch below.
        sql = data["resp"]["sql"]
        if not sql:
            # Previously this only failed by accident (str + None TypeError);
            # make the "no SQL produced" case an explicit error.
            return jsonify({"type": "error", "error": "No SQL generated"})
        logger.info("generate sql is : %s", sql)
        # Cached so /api/v0/run_sql_2 can look the SQL up by id.
        cache.set(id=question_id, field="question", value=question)
        cache.set(id=question_id, field="sql", value=sql)
        data["type"] = "success"
        return jsonify(data)
    except Exception as e:
        # The original message lacked the f-prefix and logged a literal "{e}";
        # logger.exception also records the traceback.
        logger.exception("generate sql failed: %s", e)
        return jsonify({"type": "error", "error": str(e)})
 | |
| 
 | |
| 
 | |
| 
 | |
@app.flask_app.route("/api/v0/run_sql_2", methods=["GET"])
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
    """
    Run SQL
    ``id`` and ``sql`` are injected by ``app.requires_cache`` (presumably from
    the cache entry named by the request's ``id``). The result rows are cached
    under ``df`` and returned as a list of records; failures are reported as
    {"type": "sql_error", "error": ...} with HTTP 200.
    ---
    parameters:
      - name: user
        in: query
      - name: id
        in: query|body
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: df
            id:
              type: string
            df:
              type: object
    """
    logger.info("Start to run sql in main")
    try:
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )

        df = vn.run_sql(sql=sql)
        app.cache.set(id=id, field="df", value=df)
        result = df.to_dict(orient='records')
        # Full result sets can be large or sensitive: log at DEBUG and let
        # logging format lazily instead of building the string eagerly.
        logger.debug("df ---------------%s   %s", result, type(result))

        return jsonify(
            {
                "type": "success",
                "id": id,
                "df": result,
            }
        )

    except Exception as e:
        # The original message lacked the f-prefix and logged a literal "{e}".
        logger.exception("run sql failed: %s", e)
        return jsonify({"type": "sql_error", "error": str(e)})
 | |
| 
 | |
if __name__ == '__main__':
    # Host/port may be overridden via SERVER_HOST / SERVER_PORT; the defaults
    # are the values previously hard-coded here.
    app.run(
        host=config('SERVER_HOST', default='0.0.0.0'),
        port=config('SERVER_PORT', default=8084, cast=int),
        debug=False,
    )
 | 
