| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  | import copy | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  | import logging | 
					
						
							| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  | from functools import wraps | 
					
						
							| 
									
										
										
										
											2025-09-25 16:49:25 +08:00
										 |  |  |  | import util.utils | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  | from logging_config import LOGGING_CONFIG | 
					
						
							| 
									
										
										
										
											2025-10-10 16:39:59 +08:00
										 |  |  |  | from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  | from decouple import config | 
					
						
							|  |  |  |  | import flask | 
					
						
							| 
									
										
										
										
											2025-09-23 16:29:56 +08:00
										 |  |  |  | from util import load_ddl_doc | 
					
						
							| 
									
										
										
										
											2025-10-15 10:28:34 +08:00
										 |  |  |  | from flask import Flask, Response, jsonify, request | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | logger = logging.getLogger(__name__) | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  | def connect_database(vn): | 
					
						
							|  |  |  |  |     db_type = config('DATA_SOURCE_TYPE', default='sqlite') | 
					
						
							|  |  |  |  |     if db_type == 'sqlite': | 
					
						
							|  |  |  |  |         vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default='')) | 
					
						
							|  |  |  |  |     elif db_type == 'mysql': | 
					
						
							|  |  |  |  |         vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''), | 
					
						
							| 
									
										
										
										
											2025-09-23 16:29:56 +08:00
										 |  |  |  |                             port=int(config('MYSQL_DATABASE_PORT', default=3306)), | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |                             user=config('MYSQL_DATABASE_USER', default=''), | 
					
						
							|  |  |  |  |                             password=config('MYSQL_DATABASE_PASSWORD', default=''), | 
					
						
							| 
									
										
										
										
											2025-09-24 14:17:07 +08:00
										 |  |  |  |                             database=config('MYSQL_DATABASE_DBNAME', default='')) | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |     elif db_type == 'dameng': | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |         # 待补充 | 
					
						
							| 
									
										
										
										
											2025-09-24 14:17:07 +08:00
										 |  |  |  |         vn.connect_to_dameng( | 
					
						
							|  |  |  |  |             host=config('DAMENG_DATABASE_HOST', default=''), | 
					
						
							|  |  |  |  |             port=config('DAMENG_DATABASE_PORT', default=3306), | 
					
						
							|  |  |  |  |             user=config('DAMENG_DATABASE_USER', default=''), | 
					
						
							|  |  |  |  |             password=config('DAMENG_DATABASE_PASSWORD', default=''), | 
					
						
							|  |  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |     else: | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def load_train_data_ddl(vn: CustomVanna): | 
					
						
							| 
									
										
										
										
											2025-09-23 16:29:56 +08:00
										 |  |  |  |     vn.train() | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def create_vana(): | 
					
						
							| 
									
										
										
										
											2025-09-25 10:32:01 +08:00
										 |  |  |  |     logger.info("----------------create vana ---------") | 
					
						
							|  |  |  |  |     q_client = QdrantClient(":memory:") if config('QDRANT_TYPE', default='memory') == 'memory' else QdrantClient( | 
					
						
							|  |  |  |  |         url=config('QDRANT_DB_HOST', default=''), port=config('QDRANT_DB_PORT', default=6333)) | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |     vn = CustomVanna( | 
					
						
							| 
									
										
										
										
											2025-09-25 10:32:01 +08:00
										 |  |  |  |         vector_store_config={"client": q_client}, | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |         llm_config={ | 
					
						
							|  |  |  |  |             "api_key": config('CHAT_MODEL_API_KEY', default=''), | 
					
						
							|  |  |  |  |             "api_base": config('CHAT_MODEL_BASE_URL', default=''), | 
					
						
							|  |  |  |  |             "model": config('CHAT_MODEL_NAME', default=''), | 
					
						
							| 
									
										
										
										
											2025-10-15 14:42:48 +08:00
										 |  |  |  |             'temperature':config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float), | 
					
						
							| 
									
										
										
										
											2025-10-23 16:23:19 +08:00
										 |  |  |  |             'max_tokens':config('CHAT_MODEL_MAX_TOKEN', default=5000), | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |         }, | 
					
						
							|  |  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-09-28 16:44:58 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |     return vn | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def init_vn(vn): | 
					
						
							| 
									
										
										
										
											2025-09-25 10:32:01 +08:00
										 |  |  |  |     logger.info("--------------init vana-----connect to datasouce db----") | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |     connect_database(vn) | 
					
						
							| 
									
										
										
										
											2025-10-24 17:32:38 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |     if config('IS_FIRST_LOAD', default=False, cast=bool): | 
					
						
							| 
									
										
										
										
											2025-10-24 17:32:38 +08:00
										 |  |  |  |         load_ddl_doc.add_ddl(vn) | 
					
						
							|  |  |  |  |         load_ddl_doc.add_documentation(vn) | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |         load_train_data_ddl(vn) | 
					
						
							|  |  |  |  |     return vn | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from vanna.flask import VannaFlaskApp | 
					
						
							|  |  |  |  | vn = create_vana() | 
					
						
							|  |  |  |  | app = VannaFlaskApp(vn,chart=False) | 
					
						
							| 
									
										
										
										
											2025-10-14 10:31:00 +08:00
										 |  |  |  | app.cache = TTLCacheWrapper(app.cache, ttl = config('TTL_CACHE', cast=int,default=60*60)) | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  | init_vn(vn) | 
					
						
							| 
									
										
										
										
											2025-09-23 20:06:32 +08:00
										 |  |  |  | cache = app.cache | 
					
						
							| 
									
										
										
										
											2025-09-29 11:22:56 +08:00
										 |  |  |  | @app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"]) | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  | def generate_sql_2(): | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     Generate SQL from a question | 
					
						
							|  |  |  |  |     --- | 
					
						
							|  |  |  |  |     parameters: | 
					
						
							|  |  |  |  |       - name: user | 
					
						
							|  |  |  |  |         in: query | 
					
						
							|  |  |  |  |       - name: question | 
					
						
							|  |  |  |  |         in: query | 
					
						
							|  |  |  |  |         type: string | 
					
						
							|  |  |  |  |         required: true | 
					
						
							|  |  |  |  |     responses: | 
					
						
							|  |  |  |  |       200: | 
					
						
							|  |  |  |  |         schema: | 
					
						
							|  |  |  |  |           type: object | 
					
						
							|  |  |  |  |           properties: | 
					
						
							|  |  |  |  |             type: | 
					
						
							|  |  |  |  |               type: string | 
					
						
							|  |  |  |  |               default: sql | 
					
						
							|  |  |  |  |             id: | 
					
						
							|  |  |  |  |               type: string | 
					
						
							|  |  |  |  |             text: | 
					
						
							|  |  |  |  |               type: string | 
					
						
							|  |  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |     logger.info("Start to generate sql in main") | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  |     question = flask.request.args.get("question") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     if question is None: | 
					
						
							|  |  |  |  |         return jsonify({"type": "error", "error": "No question provided"}) | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |     try: | 
					
						
							|  |  |  |  |         id = cache.generate_id(question=question) | 
					
						
							| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  |         user_id = request.args.get("user_id") | 
					
						
							| 
									
										
										
										
											2025-09-25 16:49:25 +08:00
										 |  |  |  |         logger.info(f"Generate sql for {question}") | 
					
						
							| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  |         data = vn.generate_sql_2(question=question, cache=cache, user_id=user_id) | 
					
						
							| 
									
										
										
										
											2025-09-25 16:49:25 +08:00
										 |  |  |  |         logger.info("Generate sql result is {0}".format(data)) | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |         data['id'] = id | 
					
						
							|  |  |  |  |         sql = data["resp"]["sql"] | 
					
						
							| 
									
										
										
										
											2025-09-25 10:32:01 +08:00
										 |  |  |  |         logger.info("generate sql is : "+ sql) | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |         cache.set(id=id, field="question", value=question) | 
					
						
							|  |  |  |  |         cache.set(id=id, field="sql", value=sql) | 
					
						
							| 
									
										
										
										
											2025-09-25 16:49:25 +08:00
										 |  |  |  |         data["type"]="success" | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |         return jsonify(data) | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-09-28 16:44:58 +08:00
										 |  |  |  |         logger.error(f"generate sql failed:{e}") | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |         return jsonify({"type": "error", "error": str(e)}) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  | def session_save(func): | 
					
						
							|  |  |  |  |     @wraps(func) | 
					
						
							|  |  |  |  |     def wrapper(*args, **kwargs): | 
					
						
							|  |  |  |  |         id=request.args.get("id") | 
					
						
							|  |  |  |  |         user_id = request.args.get("user_id") | 
					
						
							|  |  |  |  |         logger.info(f"   id: {id},user_id: {user_id}") | 
					
						
							|  |  |  |  |         result = func(*args, **kwargs) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         datas=[] | 
					
						
							|  |  |  |  |         session_len = int(config("SESSION_LENGTH", default=2)) | 
					
						
							|  |  |  |  |         if cache.exists(id=user_id, field="data"): | 
					
						
							|  |  |  |  |             datas = copy.deepcopy(cache.get(id=user_id, field="data")) | 
					
						
							|  |  |  |  |         data = { | 
					
						
							|  |  |  |  |             "id": id, | 
					
						
							|  |  |  |  |             "question":cache.get(id=id, field="question"), | 
					
						
							|  |  |  |  |             "sql":cache.get(id=id, field="sql") | 
					
						
							|  |  |  |  |         } | 
					
						
							|  |  |  |  |         datas.append(data) | 
					
						
							|  |  |  |  |         logger.info("datas is {0}".format(datas)) | 
					
						
							|  |  |  |  |         if len(datas) > session_len and session_len > 0: | 
					
						
							|  |  |  |  |             datas=datas[-session_len:] | 
					
						
							|  |  |  |  |         # 删除id对应的所有缓存值,因为已经run_sql完毕,改用user_id保存为上下文 | 
					
						
							|  |  |  |  |         cache.delete(id=id, field="question") | 
					
						
							|  |  |  |  |         cache.set(id=user_id, field="data", value=copy.deepcopy(datas)) | 
					
						
							|  |  |  |  |         logger.info(f" user data {cache.get(user_id, field='data')}") | 
					
						
							|  |  |  |  |         return result | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     return wrapper | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 16:49:25 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-29 11:22:56 +08:00
										 |  |  |  | @app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"]) | 
					
						
							| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  | @session_save | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  | @app.requires_cache(["sql"]) | 
					
						
							| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  | def run_sql_2(id: str, sql: str): | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |     """
 | 
					
						
							|  |  |  |  |     Run SQL | 
					
						
							|  |  |  |  |     --- | 
					
						
							|  |  |  |  |     parameters: | 
					
						
							| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  |       - name: user_id | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |         in: query | 
					
						
							| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  |         required: true | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |       - name: id | 
					
						
							|  |  |  |  |         in: query|body | 
					
						
							|  |  |  |  |         type: string | 
					
						
							|  |  |  |  |         required: true | 
					
						
							| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  |       - name: page_size | 
					
						
							|  |  |  |  |         in: query | 
					
						
							|  |  |  |  |       -name: page_num | 
					
						
							|  |  |  |  |         in: query | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |     responses: | 
					
						
							|  |  |  |  |       200: | 
					
						
							|  |  |  |  |         schema: | 
					
						
							|  |  |  |  |           type: object | 
					
						
							|  |  |  |  |           properties: | 
					
						
							|  |  |  |  |             type: | 
					
						
							|  |  |  |  |               type: string | 
					
						
							|  |  |  |  |               default: df | 
					
						
							|  |  |  |  |             id: | 
					
						
							|  |  |  |  |               type: string | 
					
						
							|  |  |  |  |             df: | 
					
						
							|  |  |  |  |               type: object | 
					
						
							|  |  |  |  |             should_generate_chart: | 
					
						
							|  |  |  |  |               type: boolean | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     logger.info("Start to run sql in main") | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         if not vn.run_sql_is_set: | 
					
						
							|  |  |  |  |             return jsonify( | 
					
						
							|  |  |  |  |                 { | 
					
						
							|  |  |  |  |                     "type": "error", | 
					
						
							|  |  |  |  |                     "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.", | 
					
						
							|  |  |  |  |                 } | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 16:39:59 +08:00
										 |  |  |  |         # count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery" | 
					
						
							|  |  |  |  |         # df_count = vn.run_sql(count_sql) | 
					
						
							| 
									
										
										
										
											2025-10-13 18:18:58 +08:00
										 |  |  |  |         # print(df_count,"is type",type(df_count)) | 
					
						
							|  |  |  |  |         # total_count = df_count.to_dict(orient="records")[0]["total_count"] | 
					
						
							|  |  |  |  |         # logger.info("Total count is {0}".format(total_count)) | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |         df = vn.run_sql(sql=sql) | 
					
						
							| 
									
										
										
										
											2025-09-26 17:35:23 +08:00
										 |  |  |  |         result = df.to_dict(orient='records') | 
					
						
							|  |  |  |  |         logger.info("df ---------------{0}   {1}".format(result,type(result))) | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |         return jsonify( | 
					
						
							|  |  |  |  |             { | 
					
						
							| 
									
										
										
										
											2025-09-25 16:49:25 +08:00
										 |  |  |  |                 "type": "success", | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |                 "id": id, | 
					
						
							| 
									
										
										
										
											2025-09-25 16:49:25 +08:00
										 |  |  |  |                 "df": result, | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |             } | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-09-28 16:44:58 +08:00
										 |  |  |  |         logger.error(f"run sql failed:{e}") | 
					
						
							| 
									
										
										
										
											2025-09-24 14:39:42 +08:00
										 |  |  |  |         return jsonify({"type": "sql_error", "error": str(e)}) | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 16:39:59 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-14 10:30:17 +08:00
										 |  |  |  | @app.flask_app.route("/yj_sqlbot/api/v0/verify", methods=["GET"]) | 
					
						
							|  |  |  |  | def verify_user(): | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         id = request.args.get("user_id") | 
					
						
							| 
									
										
										
										
											2025-10-14 15:47:11 +08:00
										 |  |  |  |         users = config('ALLOWED_USERS', default='') | 
					
						
							|  |  |  |  |         users = users.split(',') | 
					
						
							|  |  |  |  |         logger.info(f"allowed users {users}") | 
					
						
							| 
									
										
										
										
											2025-10-14 10:30:17 +08:00
										 |  |  |  |         for user in users: | 
					
						
							|  |  |  |  |             if user == id: | 
					
						
							|  |  |  |  |                 return jsonify({"type": "success", "verify": True}) | 
					
						
							|  |  |  |  |             else: | 
					
						
							|  |  |  |  |                 return jsonify({"type": "success", "verify": False}) | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         logger.error(f"verify user failed:{e}") | 
					
						
							|  |  |  |  |         return jsonify({"type": "error", "error": str(e)}) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-23 14:49:00 +08:00
										 |  |  |  | if __name__ == '__main__': | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     app.run(host='0.0.0.0', port=8084, debug=False) |