From ed676ee63373b9e8a538b0d9f4176211c8db81f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=B7=E9=9B=A8?= Date: Mon, 3 Nov 2025 17:32:56 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E5=A2=9E=E5=8A=A0=E9=A2=84=E5=88=B6?= =?UTF-8?q?=E9=97=AE=E9=A2=98-=E4=BF=9D=E5=AD=98=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E9=97=AE=E9=A2=98=E4=B8=BA=E5=90=8E=E6=9C=9F=E5=8A=A0=E5=85=A5?= =?UTF-8?q?rag?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- db_util/__init__.py | 0 db_util/db_main.py | 68 ++++++++++++++++++++++++++++ main_service.py | 44 ++++++++++++------ service/question_feedback_service.py | 50 ++++++++++++++++++++ 4 files changed, 147 insertions(+), 15 deletions(-) create mode 100644 db_util/__init__.py create mode 100644 db_util/db_main.py create mode 100644 service/question_feedback_service.py diff --git a/db_util/__init__.py b/db_util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/db_util/db_main.py b/db_util/db_main.py new file mode 100644 index 0000000..80d87f6 --- /dev/null +++ b/db_util/db_main.py @@ -0,0 +1,68 @@ +from sqlalchemy import Column, DateTime, String, create_engine, Boolean, Text +from sqlalchemy.orm import declarative_base, sessionmaker + +# 申明基类对象 +Base = declarative_base() +from decouple import config + +DB_PATH = config('DB_PATH', default='E://pyptoject//sqlbot_agent//main.sqlite3') + + +class QuestionFeedBack(Base): + # 定义表名 + __tablename__ = 'db_question_feedback' + # 定义字段 + id = Column(String(255), primary_key=True) + create_time = Column(DateTime, nullable=False, ) + question = Column(String(255), nullable=False) + sql = Column(String(500), nullable=False) + user_id = Column(String(100), nullable=False, default='1') + # 用户意见反馈 + user_comment = Column(Text, nullable=True) + # 用户点赞,点踩 + user_praise = Column(Boolean, nullable=False, default=False) + # 该数据是否被认为梳理过 + is_process = Column(Boolean, nullable=False, default=False) + + +class QuestionFeedBack(Base): + # 定义表名 + __tablename__ = 'db_question_feedback' + __table_args__ = {'extend_existing': True} + # 定义字段 + id = Column(String(255), primary_key=True) + create_time = Column(DateTime, nullable=False, ) + question = Column(String(255), nullable=False) + sql = Column(String(500), nullable=False) + user_id = Column(String(100), nullable=False, default='1') + # 用户意见反馈 + user_comment = Column(Text, nullable=True) + # 用户点赞,点踩 + user_praise = Column(Boolean, nullable=False, default=False) + # 该数据是否被认为梳理过 + is_process = Column(Boolean, nullable=False, default=False) + + +class PredefinedQuestion(Base): + # 定义表名,预制问题表 + __tablename__ = 'db_predefined_question' + __table_args__ = {'extend_existing': True} + # 定义字段 + id = Column(String(255), primary_key=True) + question = Column(String(255), nullable=False) + user_id = Column(String(100), nullable=False, default='1') + # 该数据是否被认为梳理过 + enable = Column(Boolean, nullable=False, default=True) + + def to_dict(self): + return {c.name: getattr(self, c.name) for c in self.__table__.columns} + + +class SqliteSqlalchemy(object): + def __init__(self): + # 创建sqlite连接引擎 + engine = create_engine(f'sqlite:///{DB_PATH}', echo=True) + # 创建表 + Base.metadata.create_all(engine, checkfirst=True) + # 创建sqlite的session连接对象 + self.session = sessionmaker(bind=engine)() diff --git a/main_service.py b/main_service.py index bc9fe82..24d7fc6 100644 --- a/main_service.py +++ b/main_service.py @@ -4,13 +4,15 @@ from functools import wraps import util.utils from logging_config import LOGGING_CONFIG from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper +from service.question_feedback_service import save_save_question_async,query_predefined_question_list from decouple import config import flask from util import load_ddl_doc from flask import Flask, Response, jsonify, request - logger = logging.getLogger(__name__) + + def connect_database(vn): db_type = config('DATA_SOURCE_TYPE', default='sqlite') if db_type == 'sqlite': @@ -47,8 +49,8 @@ def create_vana(): "api_key": config('CHAT_MODEL_API_KEY', default=''), "api_base": config('CHAT_MODEL_BASE_URL', default=''), "model": config('CHAT_MODEL_NAME', default=''), - 'temperature':config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float), - 'max_tokens':config('CHAT_MODEL_MAX_TOKEN', default=5000), + 'temperature': config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float), + 'max_tokens': config('CHAT_MODEL_MAX_TOKEN', default=5000), }, ) @@ -66,11 +68,14 @@ def init_vn(vn): from vanna.flask import VannaFlaskApp + vn = create_vana() -app = VannaFlaskApp(vn,chart=False) -app.cache = TTLCacheWrapper(app.cache, ttl = config('TTL_CACHE', cast=int,default=60*60)) +app = VannaFlaskApp(vn, chart=False) +app.cache = TTLCacheWrapper(app.cache, ttl=config('TTL_CACHE', cast=int, default=60 * 60)) init_vn(vn) cache = app.cache + + @app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"]) def generate_sql_2(): """ @@ -108,36 +113,38 @@ def generate_sql_2(): logger.info("Generate sql result is {0}".format(data)) data['id'] = id sql = data["resp"]["sql"] - logger.info("generate sql is : "+ sql) + logger.info("generate sql is : " + sql) cache.set(id=id, field="question", value=question) cache.set(id=id, field="sql", value=sql) - data["type"]="success" + data["type"] = "success" + save_save_question_async(id, user_id, question, sql) return jsonify(data) except Exception as e: logger.error(f"generate sql failed:{e}") return jsonify({"type": "error", "error": str(e)}) + def session_save(func): @wraps(func) def wrapper(*args, **kwargs): - id=request.args.get("id") + id = request.args.get("id") user_id = request.args.get("user_id") logger.info(f" id: {id},user_id: {user_id}") result = func(*args, **kwargs) - datas=[] + datas = [] session_len = int(config("SESSION_LENGTH", default=2)) if cache.exists(id=user_id, field="data"): datas = copy.deepcopy(cache.get(id=user_id, field="data")) data = { "id": id, - "question":cache.get(id=id, field="question"), - "sql":cache.get(id=id, field="sql") + "question": cache.get(id=id, field="question"), + "sql": cache.get(id=id, field="sql") } datas.append(data) logger.info("datas is {0}".format(datas)) if len(datas) > session_len and session_len > 0: - datas=datas[-session_len:] + datas = datas[-session_len:] # 删除id对应的所有缓存值,因为已经run_sql完毕,改用user_id保存为上下文 cache.delete(id=id, field="question") cache.set(id=user_id, field="data", value=copy.deepcopy(datas)) @@ -147,7 +154,6 @@ def session_save(func): return wrapper - @app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"]) @session_save @app.requires_cache(["sql"]) @@ -200,7 +206,7 @@ def run_sql_2(id: str, sql: str): # logger.info("Total count is {0}".format(total_count)) df = vn.run_sql_2(sql=sql) result = df.to_dict(orient='records') - logger.info("df ---------------{0} {1}".format(result,type(result))) + logger.info("df ---------------{0} {1}".format(result, type(result))) return jsonify( { "type": "success", @@ -231,7 +237,15 @@ def verify_user(): return jsonify({"type": "error", "error": str(e)}) +@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"]) +def query_present_question(): + try: + data=query_predefined_question_list() + return jsonify({"type": "success", "data": data}) + except Exception as e: + logger.error(f"查询预制问题失败 failed:{e}") + return jsonify({"type": "error", "error":f'查询预制问题失败:{str(e)}'}) + if __name__ == '__main__': - app.run(host='0.0.0.0', port=8084, debug=False) diff --git a/service/question_feedback_service.py b/service/question_feedback_service.py new file mode 100644 index 0000000..45d8d1f --- /dev/null +++ b/service/question_feedback_service.py @@ -0,0 +1,50 @@ +from db_util.db_main import QuestionFeedBack, SqliteSqlalchemy, PredefinedQuestion +from datetime import datetime +import logging + +logger = logging.getLogger(__name__) +from concurrent.futures import ThreadPoolExecutor + +pool = ThreadPoolExecutor(max_workers=10) + +def save_question(id, user_id, question, sql): + logger.info(f"开始保存用户用户==>:{question}") + user_id = user_id if user_id else '1' + qq = QuestionFeedBack(id=id, user_id=user_id, question=question, sql=sql + , create_time=datetime.now()) + session = SqliteSqlalchemy().session + try: + session.add(qq) + session.commit() + except Exception as e: + logger.error(f"保存用户问题与sql 错误,{e}") + session.rollback() + finally: + session.close() + + +def save_save_question_async(id, user_id, question, sql): + pool.submit(save_question, id, user_id, question, sql) + + +def update_user_feedBack(id, user_comment, user_praise: bool): + session = SqliteSqlalchemy().session + try: + session.query(QuestionFeedBack).filter(QuestionFeedBack.id == id).update( + {QuestionFeedBack.user_comment: user_comment, QuestionFeedBack.user_praise: user_praise}) + session.commit() + except Exception as e: + logger.error(f"更新用户反馈错误,{e}") + session.rollback() + finally: + session.close() + +''' +查询所有的预制问题 +''' +def query_predefined_question_list() -> list: + session = SqliteSqlalchemy().session + all = session.query(PredefinedQuestion).filter(PredefinedQuestion.enable == True).all() + all = [a.to_dict() for a in all] if all else [] + session.close() + return all