"""HTTP entry point for the SQL bot service.

Builds a ``CustomVanna`` instance, wraps it in a ``VannaFlaskApp`` and
registers the ``/yj_sqlbot/api/v0/*`` routes on the underlying Flask app.

All handlers report failures as a JSON body carrying ``type``/``error`` keys
with the default HTTP status; callers must inspect the body.
"""
import copy
import json
import logging
import time
import traceback
import uuid
from functools import wraps

import flask
from decouple import config
from flask import Flask, Response, jsonify, request
from vanna.flask import VannaFlaskApp

# Project-local imports are kept in their original relative order.
import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from service.question_feedback_service import (
    save_save_question_async,
    query_predefined_question_list,
    update_user_feedBack,
    query_feedBack_question_list,
)
from service.conversation_service import (
    save_conversation,
    update_conversation,
    get_qa_by_id,
    get_latest_question,
    get_all_conversations_by_user,
)
from util import load_ddl_doc
from graph_chat.gen_sql_chart_agent import SqlAgentState, sql_chart_agent
from graph_chat.gen_data_report_agent import result_report_agent, DateReportAgentState

logger = logging.getLogger(__name__)
# NOTE(review): DEBUG on 'sqlalchemy.engine' is very verbose; confirm this is
# intended outside of development.
logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG)

# Query-string spellings that mean "off" for a boolean flag.
_FALSY_ARG_VALUES = frozenset({"", "0", "false", "no", "off", "none", "null"})


def _parse_bool_arg(raw):
    """Interpret a query-string flag as a boolean.

    ``bool("false")`` is ``True``, so the raw string cannot be passed to
    ``bool`` directly. A missing value and the explicit "off" spellings are
    ``False``; any other value is ``True``.

    Args:
        raw: The raw query-string value, or ``None`` when absent.

    Returns:
        The flag as a ``bool``.
    """
    if raw is None:
        return False
    return raw.strip().lower() not in _FALSY_ARG_VALUES


def _json_body():
    """Return the request's JSON object, or ``{}`` when absent or not an object."""
    payload = request.get_json(silent=True)
    return payload if isinstance(payload, dict) else {}


def _build_history(cvs_id, user_id, limit_count=2):
    """Build the chat-history list handed to the graph agents.

    Args:
        cvs_id: Conversation id passed to ``get_latest_question``.
        user_id: User id passed to ``get_latest_question``.
        limit_count: Maximum number of previous questions to fetch.

    Returns:
        A list of ``{"role", "content", "order", "is_latest"}`` dicts in the
        order returned by ``get_latest_question``; the first entry is flagged
        ``is_latest``. Empty when there is no previous question.
    """
    questions = get_latest_question(cvs_id, user_id, limit_count=limit_count) or []
    return [
        {"role": "user", 'content': q, 'order': i, 'is_latest': i == 0}
        for i, q in enumerate(questions)
    ]


def _stored_error_message(answer):
    """Extract the ``error`` text from a stored conversation answer.

    Args:
        answer: The stored answer; a JSON string or an already-decoded dict.

    Returns:
        The ``error`` value, or ``""`` when the answer cannot be decoded or
        has no such key.
    """
    if isinstance(answer, (str, bytes, bytearray)):
        try:
            answer = json.loads(answer)
        except ValueError:
            return ""
    if isinstance(answer, dict):
        return answer.get("error", "")
    return ""


def generate_timestamp_id():
    """Return a question id of the form ``Q<epoch milliseconds>``.

    NOTE(review): two calls within the same millisecond produce the same id.
    """
    # Millisecond resolution (seconds * 1000, truncated).
    timestamp_ms = int(time.time() * 1000)
    return f"Q{timestamp_ms}"


def connect_database(vn):
    """Connect ``vn`` to the data source selected by ``DATA_SOURCE_TYPE``.

    Supported values are ``sqlite`` (default), ``mysql`` and ``dameng``.
    An unknown value leaves ``vn`` unconnected and logs a warning; the
    run-SQL endpoints then report that no database is connected.

    Args:
        vn: The Vanna instance to connect.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite')
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            port=config('MYSQL_DATABASE_PORT', default=3306, cast=int),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            database=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'dameng':
        # TODO: Dameng support is incomplete.
        # NOTE(review): the 3306 default is inherited from the MySQL branch;
        # confirm the correct default port for Dameng.
        vn.connect_to_dameng(
            host=config('DAMENG_DATABASE_HOST', default=''),
            # Cast like the MySQL port: values read from the environment
            # arrive as strings.
            port=config('DAMENG_DATABASE_PORT', default=3306, cast=int),
            user=config('DAMENG_DATABASE_USER', default=''),
            password=config('DAMENG_DATABASE_PASSWORD', default=''),
        )
    else:
        logger.warning(
            "Unsupported DATA_SOURCE_TYPE %r; no database connection configured",
            db_type,
        )


def load_train_data_ddl(vn: CustomVanna):
    """Run ``vn.train()`` with no arguments."""
    vn.train()


def create_vana():
    """Create the ``CustomVanna`` instance from configuration.

    Uses an in-memory Qdrant client when ``QDRANT_TYPE`` is ``memory``
    (the default), otherwise a client for ``QDRANT_DB_HOST``/``QDRANT_DB_PORT``.

    Returns:
        The configured ``CustomVanna`` instance (not yet connected to a
        database).
    """
    logger.info("----------------create vana ---------")
    if config('QDRANT_TYPE', default='memory') == 'memory':
        q_client = QdrantClient(":memory:")
    else:
        q_client = QdrantClient(
            url=config('QDRANT_DB_HOST', default=''),
            port=config('QDRANT_DB_PORT', default=6333, cast=int),
        )
    vanna_instance = CustomVanna(
        vector_store_config={"client": q_client},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
            'temperature': config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
            # Cast so an environment-supplied value is an int, not a string.
            'max_tokens': config('CHAT_MODEL_MAX_TOKEN', default=5000, cast=int),
        },
    )
    return vanna_instance


def init_vn(vn):
    """Connect ``vn`` to the database and optionally load training data.

    When ``IS_FIRST_LOAD`` is true, DDL and documentation are added through
    ``load_ddl_doc`` and ``vn.train()`` is run.

    Args:
        vn: The Vanna instance to initialise.

    Returns:
        The same ``vn`` instance.
    """
    logger.info("--------------init vana-----connect to datasouce db----")
    connect_database(vn)
    if config('IS_FIRST_LOAD', default=False, cast=bool):
        load_ddl_doc.add_ddl(vn)
        load_ddl_doc.add_documentation(vn)
        load_train_data_ddl(vn)
    return vn


vn = create_vana()
app = VannaFlaskApp(vn, chart=False)
# Default TTL is 60 * 60 — presumably seconds (one hour); confirm against
# TTLCacheWrapper.
app.cache = TTLCacheWrapper(app.cache, ttl=config('TTL_CACHE', cast=int, default=60 * 60))
init_vn(vn)
cache = app.cache

# NOTE: the commented-out cache-based decorators (``requires_cache_2`` and
# ``session_save``) that used to live here were removed; the handlers below
# read conversation state through ``service.conversation_service`` instead.


@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question
    ---
    parameters:
      - name: user_id
        in: query
        type: string
        required: true
      - name: cvs_id
        in: query
        type: string
        required: true
      - name: question
        in: query
        type: string
        required: true
      - name: need_context
        in: query
        type: string
        required: false
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
            id:
              type: string
            resp:
              type: object
    """
    logger.info("Start to generate sql in main")
    question = request.args.get("question")
    if question is None:
        return jsonify({"type": "error", "error": "No question provided"})
    user_id = request.args.get("user_id")
    cvs_id = request.args.get("cvs_id")
    need_context = _parse_bool_arg(request.args.get("need_context"))
    if user_id is None or cvs_id is None:
        return jsonify({"type": "error", "error": "No user_id or cvs_id provided"})

    question_id = generate_timestamp_id()
    logger.info(f"question_id: {question_id} user_id: {user_id} cvs_id: {cvs_id} question: {question}")
    try:
        # Inside the try so a persistence failure yields the JSON error
        # envelope rather than an unhandled exception.
        save_conversation(question_id, user_id, cvs_id, question)
        logger.info(f"Generate sql for {question}")
        data = vn.generate_sql_2(user_id, cvs_id, question, question_id, need_context)
        logger.info("Generate sql result is {0}".format(data))
        data['id'] = question_id
        sql = data["resp"]["sql"]
        logger.info("generate sql is : %s", sql)
        update_conversation(question_id, sql)
        save_save_question_async(question_id, user_id, question, sql)
        data["type"] = "success"
        return jsonify(data)
    except Exception as e:
        logger.exception(f"generate sql failed:{e}")
        return jsonify({"type": "error", "error": str(e)})


@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
def run_sql_2():
    """
    Run the SQL stored for a conversation entry
    ---
    parameters:
      - name: id
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
            id:
              type: string
            df:
              type: array
              items:
                type: object
    """
    logger.info("Start to run sql in main")
    try:
        question_id = request.args.get("id")
        qa = get_qa_by_id(question_id)
        if not qa:
            # Same guard and message as run_sql_3.
            return jsonify({"type": "error", "error": f'获取对话失败:{question_id}'})
        sql = qa["sql"]
        logger.info(f"sql is {sql}")
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )
        df = vn.run_sql_2(sql=sql)
        result = df.to_dict(orient='records')
        # Log the size only; the rows themselves may be large.
        logger.info("run_sql_2 returned %d rows", len(result))
        return jsonify(
            {
                "type": "success",
                "id": question_id,
                "df": result,
            }
        )
    except Exception as e:
        logger.exception(f"run sql failed:{e}")
        return jsonify({"type": "sql_error", "error": str(e)})


@app.flask_app.route("/yj_sqlbot/api/v0/verify", methods=["GET"])
def verify_user():
    """Report whether ``user_id`` is listed in ``ALLOWED_USERS``.

    ``ALLOWED_USERS`` is a comma-separated list. Surrounding whitespace is
    ignored and blank entries are dropped, so an empty allow-list matches
    no one (including an empty ``user_id``).

    Returns:
        JSON ``{"type": "success", "verify": <bool>}``, or an error envelope.
    """
    try:
        user_id = request.args.get("user_id")
        raw_users = config('ALLOWED_USERS', default='')
        allowed_users = {name.strip() for name in raw_users.split(',') if name.strip()}
        logger.info(f"allowed users {sorted(allowed_users)}")
        # Check the whole list, not just its first entry.
        return jsonify({"type": "success", "verify": user_id in allowed_users})
    except Exception as e:
        logger.exception(f"verify user failed:{e}")
        return jsonify({"type": "error", "error": str(e)})


@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"])
def query_present_question():
    """Return the predefined question list.

    Returns:
        JSON ``{"type": "success", "data": ...}``, or an error envelope.
    """
    try:
        data = query_predefined_question_list()
        return jsonify({"type": "success", "data": data})
    except Exception as e:
        logger.exception(f"查询预制问题失败 failed:{e}")
        return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'})


@app.flask_app.route("/yj_sqlbot/api/v0/query_feedback_question", methods=["POST"])
def query_feedback_question():
    """Return feedback records for the ids in the JSON body's ``id_list``.

    A missing or non-object body is treated as an empty ``id_list``.

    Returns:
        JSON ``{"type": "success", "data": ...}``, or an error envelope.
    """
    id_list = _json_body().get("id_list", [])
    try:
        data = query_feedBack_question_list(id_list)
        return jsonify({"type": "success", "data": data})
    except Exception as e:
        logger.exception(f"查询用户反馈问题失败 failed:{e}")
        return jsonify({"type": "error", "error": f'查询用户反馈问题失败:{str(e)}'})


@app.flask_app.route("/yj_sqlbot/api/v0/question_feed_back", methods=["PUT"])
def update_question_feed_back():
    """Store the user's feedback for a question.

    Reads ``id`` and ``user_feedback`` from the JSON body; both are required.

    Returns:
        JSON ``{"type": "success"}``, or an error envelope.
    """
    body = _json_body()
    question_id = body.get("id")
    user_feedback = body.get("user_feedback")
    if not question_id or not user_feedback:
        return jsonify({"type": "error", "error": "id 或者用户反馈为空"})
    try:
        update_user_feedBack(question_id, '', user_feedback)
        return jsonify({"type": "success"})
    except Exception as e:
        # The previous message was copied from the predefined-question
        # handler and did not describe this operation.
        logger.exception(f"更新用户反馈失败 failed:{e}")
        return jsonify({"type": "error", "error": f'更新用户反馈失败:{str(e)}'})


@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
def gen_graph_question():
    """Generate SQL and a chart config for a question via the graph agent.

    Reads ``user_id``, ``cvs_id`` and ``question`` from the query string,
    saves a new conversation entry, invokes ``sql_chart_agent`` and stores
    the resulting SQL, chart config and (on failure) error text.

    Returns:
        JSON with ``id``, ``sql``, ``chart``, ``gen_sql_error`` and
        ``gen_chart_error``, or an error envelope.
    """
    try:
        user_id = request.args.get("user_id")
        logger.info(f"Into gen_graph_question => user {user_id}")
        cvs_id = request.args.get("cvs_id")
        # Named graph_config so the module-level decouple ``config`` is not
        # shadowed inside this function.
        graph_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
        question = request.args.get("question")

        logger.info("Start to get question context")
        history = _build_history(cvs_id, user_id)
        logger.info(f"question context is {history}")

        initial_state: SqlAgentState = {
            "user_id": user_id,
            "user_question": question,
            "history": history,
            "sql_retry_count": 0,
            "chart_retry_count": 0,
        }
        conversation_id = str(uuid.uuid4())
        logger.info("Start to save conversation info (id,user_id,cvs_id,question)")
        save_conversation(conversation_id, user_id, cvs_id, question)

        logger.info("Enter the graph node=>gen_sql, gen_chart")
        result = sql_chart_agent.invoke(initial_state, config=graph_config)
        logger.info("End====>gen_sql, gen_chart")

        new_question = result.get('rewritten_user_question', question)
        if new_question:
            logger.info(f"new_question is {new_question}")
            update_conversation(id=conversation_id, meta={'new_question': new_question})

        # ``or {}`` also covers the key being present with a None value.
        sql_result = result.get("gen_sql_result") or {}
        logger.info("gen_sql_result => {0}".format(sql_result))
        data = {
            'id': conversation_id,
            'sql': sql_result,
            'chart': result.get("gen_chart_result", {}),
            'gen_sql_error': result.get("gen_sql_error", None),
            'gen_chart_error': result.get("gen_chart_error", None),
        }
        sql = sql_result.get('sql', '')
        if not sql_result.get('success', False):
            logger.info("SQL generation failed.save error info to table")
            error_msg = sql_result.get('message', '')
            update_conversation(
                id=conversation_id,
                answer={'type_error': 'error', 'error': error_msg},
            )
        update_conversation(id=conversation_id, sql=sql, chart_cfg=data['chart'])
        save_save_question_async(conversation_id, user_id, new_question, sql)
        return jsonify(data)
    except Exception as e:
        # The previous message was copied from the predefined-question
        # handler and did not describe this operation.
        logger.exception(f"生成SQL或图表失败 failed:{e}")
        return jsonify({"type": "error", "error": f'生成SQL或图表失败:{str(e)}'})


@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_3", methods=["GET"])
def run_sql_3():
    """Run the stored SQL for a conversation entry and produce a report.

    Reads ``id``, ``user_id`` and ``cvs_id`` from the query string, invokes
    ``result_report_agent`` with the stored SQL and question, and stores the
    outcome on the conversation.

    Returns:
        JSON with ``data``, ``summary`` and ``run_sql_error``, or an error
        envelope.
    """
    conversation_id = request.args.get("id")
    try:
        qa = get_qa_by_id(conversation_id)
        if not qa:
            return jsonify({"type": "error", "error": f'获取对话失败:{conversation_id}'})
        logger.info(f"Start to run_sql_3 => {qa}")
        sql = qa["sql"]
        if not sql:
            # No SQL was generated; surface the stored generation error when
            # there is one, otherwise a generic message.
            stored_error = _stored_error_message(qa["answer"])
            return jsonify({"type": "error", "error": stored_error or 'sql 生成失败,请联系管理员'})
        question = qa["question"]
        logger.info(f"in main sql {sql} question {question}")
        logger.info("Start to run sql in main")

        user_id = request.args.get("user_id")
        cvs_id = request.args.get("cvs_id")
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using "
                             "vn.connect_to_... in order to run SQL queries.",
                }
            )

        initial_state: DateReportAgentState = {
            "id": conversation_id,
            "user_id": user_id,
            "sql": sql,
            "question": question,
            "retry_count": 0,
            "history": _build_history(cvs_id, user_id),
        }
        # Named graph_config so the module-level decouple ``config`` is not
        # shadowed inside this function.
        graph_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
        report = result_report_agent.invoke(initial_state, graph_config)
        data = report.get('data', {})
        summary = report.get('summary', '')
        run_sql_error = report.get('run_sql_error', '')
        logger.debug(f"data type is {type(data)} data is {data} summary is {summary}")

        if data and not run_sql_error:
            update_conversation(id=conversation_id, answer={'data': data, 'summary': summary})
        elif run_sql_error:
            update_conversation(
                id=conversation_id,
                answer={'type_error': 'error', 'error': "sql执行失败"},
            )
        logger.info(f"run_Sql finish run_sql_error => {run_sql_error}")
        return jsonify(
            {
                'data': data,
                'summary': summary,
                # The raw error is logged above; the client only gets a
                # generic message.
                'run_sql_error': "sql执行失败" if run_sql_error else None,
            }
        )
    except Exception as e:
        logger.exception(f"run sql failed:{e}")
        return jsonify({"type": "sql_error", "error": str(e)})


@app.flask_app.route("/yj_sqlbot/api/v0/get_history", methods=["GET"])
def get_history():
    """Return one page of a user's conversation history.

    Query parameters: ``user_id``, ``page_num`` (default 1, minimum 1) and
    ``page_size`` (default 10; values outside 1..100 fall back to 10).
    Missing or non-integer paging values use the defaults.

    Returns:
        JSON with ``total_pages``, ``total_count`` and ``history`` (a list of
        ``id``/``question``/``answer``/``chart_cfg``/``user_praise`` entries;
        ``user_praise`` is null when no feedback exists), or an error
        envelope.
    """
    logger.info(f"in main get history {request.args}")
    try:
        user_id = request.args.get("user_id")
        page_num = request.args.get("page_num", default=1, type=int)
        page_size = request.args.get("page_size", default=10, type=int)
        if page_num < 1:
            page_num = 1
        if page_size < 1 or page_size > 100:
            page_size = 10
        offset = (page_num - 1) * page_size

        result = get_all_conversations_by_user(user_id, offset, page_size)
        rows = result.get('data') or []
        total_count = result.get('total_count', 0)
        # Ceiling division.
        total_pages = (total_count + page_size - 1) // page_size
        logger.info(f"offset {offset} pagesize {page_size} total_count {total_count} total_pages {total_pages}")

        feedback_questions = query_feedBack_question_list([row.id for row in rows])
        praise_by_id = {q["id"]: q["user_praise"] for q in feedback_questions}
        logger.info(f"map_list {praise_by_id}")

        history = [
            {
                'id': row.id,
                'question': row.question,
                'answer': json.loads(row.answer) if row.answer else None,
                'chart_cfg': json.loads(row.chart_cfg) if row.chart_cfg else None,
                # A conversation without a feedback row must not fail the
                # whole page.
                'user_praise': praise_by_id.get(row.id),
            }
            for row in rows
        ]
        return jsonify(
            {
                "total_pages": total_pages,
                "total_count": total_count,
                "history": history,
            }
        )
    except Exception as e:
        logger.exception(f"get history failed:{e}")
        return jsonify({"type": "get history error", "error": str(e)})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8084, debug=False)