167 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			167 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from email.policy import default
 | |
| import dmPython
 | |
| import logging
 | |
| 
 | |
| from logging_config import LOGGING_CONFIG
 | |
| from service.cus_vanna_srevice import CustomVanna, QdrantClient
 | |
| from decouple import config
 | |
| import flask
 | |
| from util import load_ddl_doc
 | |
| 
 | |
| from flask import Flask, Response, jsonify, request, send_from_directory
 | |
| 
 | |
| 
 | |
# Module-level logger used by the request handlers in this module.
# NOTE(review): LOGGING_CONFIG is imported in this file but never applied here
# (no logging.config.dictConfig call is visible); unless logging_config
# configures logging when imported, logger.info()/debug() output may be
# filtered out by the root logger's default level — confirm.
logger = logging.getLogger(__name__)
 | |
def connect_database(vn):
    """Connect *vn* to the database selected by ``DATA_SOURCE_TYPE``.

    All settings are read through ``decouple.config``:

    * ``sqlite`` (the default) -- uses ``SQLITE_DATABASE_URL``.
    * ``mysql`` -- uses ``MYSQL_DATABASE_HOST``, ``MYSQL_DATABASE_PORT``
      (default 3306), ``MYSQL_DATABASE_USER``, ``MYSQL_DATABASE_PASSWORD``
      and ``MYSQL_DATABASE_DBNAME``.
    * ``dameng`` -- calls ``vn.connect_to_dameng()`` with no arguments; where
      that method gets its connection settings is not visible here.
    * ``postgresql`` -- not implemented yet.

    The type name is matched case-insensitively, ignoring surrounding
    whitespace. For an unimplemented or unknown type no connection is made.
    A warning is logged instead of failing silently, so a later "Please
    connect to a database" error from the SQL endpoint can be traced back to
    configuration.

    :param vn: the Vanna instance to connect (a ``CustomVanna``).
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite').strip().lower()

    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            port=int(config('MYSQL_DATABASE_PORT', default=3306)),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            dbname=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'dameng':
        vn.connect_to_dameng()
    elif db_type == 'postgresql':
        # TODO: PostgreSQL support still has to be added.
        logger.warning(
            "DATA_SOURCE_TYPE 'postgresql' is not implemented yet; "
            "no database connection was made."
        )
    else:
        # Previously a silent `pass`: make misconfiguration visible.
        logger.warning(
            "Unsupported DATA_SOURCE_TYPE %r; no database connection was made.",
            db_type,
        )
 | |
| 
 | |
| 
 | |
def load_train_data_ddl(vn: CustomVanna):
    """Run ``vn.train()`` with no arguments and discard its result.

    ``init_vn`` calls this only when ``IS_FIRST_LOAD`` is true.

    NOTE(review): no training data is passed to ``train()`` here. Whether
    ``CustomVanna.train`` does something useful without arguments (e.g.
    loads DDL, as this function's name suggests) is not visible in this
    file — confirm.

    :param vn: the Vanna instance to train.
    """
    train = vn.train
    train()
    return None
 | |
| 
 | |
| 
 | |
def create_vana():
    """Build a ``CustomVanna`` that uses an in-memory Qdrant client.

    The vector store gets ``QdrantClient(":memory:")``. Presumably anything
    stored there lives only as long as this process — confirm against
    Qdrant's in-memory mode.

    The LLM settings come from ``CHAT_MODEL_API_KEY``,
    ``CHAT_MODEL_BASE_URL`` and ``CHAT_MODEL_NAME``, each defaulting to an
    empty string. The database connection is set up separately by
    ``init_vn``.

    :return: the newly created ``CustomVanna`` instance.
    """
    print("----------------create---------")
    store_settings = {"client": QdrantClient(":memory:")}
    llm_settings = dict(
        api_key=config('CHAT_MODEL_API_KEY', default=''),
        api_base=config('CHAT_MODEL_BASE_URL', default=''),
        model=config('CHAT_MODEL_NAME', default=''),
    )
    return CustomVanna(
        vector_store_config=store_settings,
        llm_config=llm_settings,
    )
 | |
| 
 | |
| 
 | |
def init_vn(vn):
    """Connect *vn* to the configured database and seed it with training data.

    Steps, in order:

    1. ``connect_database(vn)``;
    2. ``load_ddl_doc.add_ddl(vn)`` then ``load_ddl_doc.add_documentation(vn)``;
    3. ``load_train_data_ddl(vn)``, only when the ``IS_FIRST_LOAD`` setting
       casts to ``True`` (default ``False``).

    :param vn: the Vanna instance to initialise.
    :return: the same *vn* object.
    """
    print("--------------init vn-----connect----")
    connect_database(vn)

    load_ddl_doc.add_ddl(vn)
    load_ddl_doc.add_documentation(vn)

    first_load = config('IS_FIRST_LOAD', default=False, cast=bool)
    if not first_load:
        return vn
    load_train_data_ddl(vn)
    return vn
 | |
| 
 | |
| 
 | |
# Imported mid-module in the original layout; left in place so module import
# order is unchanged.
from vanna.flask import VannaFlaskApp

# Module-level application wiring; all of this runs at import time:
# build the Vanna instance, wrap it in Vanna's Flask app (chart=False,
# presumably disabling chart generation), then connect the database and load
# training data.
vn = create_vana()
app = VannaFlaskApp(vn,chart=False)
# NOTE(review): init_vn runs *after* VannaFlaskApp is constructed; keep this
# order unless it is confirmed that VannaFlaskApp does not depend on it.
init_vn(vn)
# Cache owned by the Vanna Flask app. The custom routes in this module store
# the question and generated SQL (and later the query DataFrame) in it under
# an id from cache.generate_id().
cache = app.cache
 | |
@app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question
    ---
    parameters:
      - name: user
        in: query
      - name: question
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: sql
            id:
              type: string
            text:
              type: string
    """
    # The docstring above is a Swagger spec and is kept verbatim.
    # Flow:
    #   1. validate the `question` query parameter;
    #   2. allocate a cache id and ask vn.generate_sql_2 for the SQL;
    #   3. cache question + SQL under that id so /api/v0/run_sql_2 can run it.
    # As elsewhere in this module, failures are reported in the JSON body as
    # {"type": "error", "error": ...} rather than through the HTTP status.
    logger.info("Start to generate sql in main")
    question = flask.request.args.get("question")

    # Reject blank questions as well as missing ones; an empty string would
    # otherwise be sent to the model.
    if question is None or not question.strip():
        return jsonify({"type": "error", "error": "No question provided"})
    try:
        query_id = cache.generate_id(question=question)
        data = vn.generate_sql_2(question=question)
        data['id'] = query_id
        # vn.generate_sql_2 is expected to return a dict with
        # data["resp"]["sql"]; a missing key ends up in the except branch.
        sql = data["resp"]["sql"]
        logger.debug("Generated sql: %s", sql)
        cache.set(id=query_id, field="question", value=question)
        cache.set(id=query_id, field="sql", value=sql)
        logger.debug("generate_sql_2 response: %s", data)
        return jsonify(data)
    except Exception as e:
        # Keep the traceback in the server log; the client only sees str(e).
        logger.exception("Failed to generate SQL for question %r", question)
        return jsonify({"type": "error", "error": str(e)})
 | |
| 
 | |
| 
 | |
@app.flask_app.route("/api/v0/run_sql_2", methods=["GET"])
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
    """
    Run SQL
    ---
    parameters:
      - name: user
        in: query
      - name: id
        in: query|body
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: df
            id:
              type: string
            df:
              type: object
            should_generate_chart:
              type: boolean
    """
    # The docstring above is a Swagger spec and is kept verbatim.
    # `id` and `sql` are presumably supplied by requires_cache(["sql"]) from
    # the entry cached by generate_sql_2; `id` shadows the builtin but is part
    # of the interface the decorator calls, so it is not renamed.
    # Runs the SQL, stores the full DataFrame in the cache as field "df", and
    # returns the first 10 rows as a list of records.
    logger.info("Start to run sql in main")
    try:
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )

        df = vn.run_sql(sql=sql)
        app.cache.set(id=id, field="df", value=df)

        head = df.head(10)
        # SQL NULLs arrive as NaN/NaT. Flask's JSON encoder would emit the
        # bare token NaN, which is not valid JSON for most clients, so
        # replace missing values with None (serialised as null).
        head = head.astype(object).where(head.notna(), None)
        records = head.to_dict(orient='records')
        logger.debug("df preview (%d rows): %s", len(records), records)

        return jsonify(
            {
                "type": "df",
                "id": id,
                "df": records,
            }
        )

    except Exception as e:
        # Keep the traceback in the server log; the client only sees str(e).
        logger.exception("Failed to run sql for id %s", id)
        return jsonify({"type": "sql_error", "error": str(e)})
 | |
| 
 | |
if __name__ == '__main__':
    # Script entry point: serve on all interfaces, port 8084, debug off.
    serve_options = dict(host='0.0.0.0', port=8084, debug=False)
    app.run(**serve_options)
 | 
