feat:增加预制问题-保存用户问题为后期加入rag

This commit is contained in:
雷雨
2025-11-03 17:32:56 +08:00
parent 24c2a1b6da
commit ed676ee633
4 changed files with 147 additions and 15 deletions

0
db_util/__init__.py Normal file
View File

68
db_util/db_main.py Normal file
View File

@@ -0,0 +1,68 @@
from sqlalchemy import Column, DateTime, String, create_engine, Boolean, Text
from sqlalchemy.orm import declarative_base, sessionmaker
# 申明基类对象
Base = declarative_base()
from decouple import config
DB_PATH = config('DB_PATH', default='E://pyptoject//sqlbot_agent//main.sqlite3')
class QuestionFeedBack(Base):
# 定义表名
__tablename__ = 'db_question_feedback'
# 定义字段
id = Column(String(255), primary_key=True)
create_time = Column(DateTime, nullable=False, )
question = Column(String(255), nullable=False)
sql = Column(String(500), nullable=False)
user_id = Column(String(100), nullable=False, default='1')
# 用户意见反馈
user_comment = Column(Text, nullable=True)
# 用户点赞,点踩
user_praise = Column(Boolean, nullable=False, default=False)
# 该数据是否被认为梳理过
is_process = Column(Boolean, nullable=False, default=False)
class QuestionFeedBack(Base):
# 定义表名
__tablename__ = 'db_question_feedback'
__table_args__ = {'extend_existing': True}
# 定义字段
id = Column(String(255), primary_key=True)
create_time = Column(DateTime, nullable=False, )
question = Column(String(255), nullable=False)
sql = Column(String(500), nullable=False)
user_id = Column(String(100), nullable=False, default='1')
# 用户意见反馈
user_comment = Column(Text, nullable=True)
# 用户点赞,点踩
user_praise = Column(Boolean, nullable=False, default=False)
# 该数据是否被认为梳理过
is_process = Column(Boolean, nullable=False, default=False)
class PredefinedQuestion(Base):
# 定义表名,预制问题表
__tablename__ = 'db_predefined_question'
__table_args__ = {'extend_existing': True}
# 定义字段
id = Column(String(255), primary_key=True)
question = Column(String(255), nullable=False)
user_id = Column(String(100), nullable=False, default='1')
# 该数据是否被认为梳理过
enable = Column(Boolean, nullable=False, default=True)
def to_dict(self):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
class SqliteSqlalchemy(object):
def __init__(self):
# 创建sqlite连接引擎
engine = create_engine(f'sqlite:///{DB_PATH}', echo=True)
# 创建表
Base.metadata.create_all(engine, checkfirst=True)
# 创建sqlite的session连接对象
self.session = sessionmaker(bind=engine)()

View File

@@ -4,13 +4,15 @@ from functools import wraps
import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
from decouple import config
import flask
from util import load_ddl_doc
from flask import Flask, Response, jsonify, request
logger = logging.getLogger(__name__)
def connect_database(vn):
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
if db_type == 'sqlite':
@@ -47,8 +49,8 @@ def create_vana():
"api_key": config('CHAT_MODEL_API_KEY', default=''),
"api_base": config('CHAT_MODEL_BASE_URL', default=''),
"model": config('CHAT_MODEL_NAME', default=''),
'temperature':config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
'max_tokens':config('CHAT_MODEL_MAX_TOKEN', default=5000),
'temperature': config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
'max_tokens': config('CHAT_MODEL_MAX_TOKEN', default=5000),
},
)
@@ -66,11 +68,14 @@ def init_vn(vn):
from vanna.flask import VannaFlaskApp
vn = create_vana()
app = VannaFlaskApp(vn,chart=False)
app.cache = TTLCacheWrapper(app.cache, ttl = config('TTL_CACHE', cast=int,default=60*60))
app = VannaFlaskApp(vn, chart=False)
app.cache = TTLCacheWrapper(app.cache, ttl=config('TTL_CACHE', cast=int, default=60 * 60))
init_vn(vn)
cache = app.cache
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
"""
@@ -108,36 +113,38 @@ def generate_sql_2():
logger.info("Generate sql result is {0}".format(data))
data['id'] = id
sql = data["resp"]["sql"]
logger.info("generate sql is : "+ sql)
logger.info("generate sql is : " + sql)
cache.set(id=id, field="question", value=question)
cache.set(id=id, field="sql", value=sql)
data["type"]="success"
data["type"] = "success"
save_save_question_async(id, user_id, question, sql)
return jsonify(data)
except Exception as e:
logger.error(f"generate sql failed:{e}")
return jsonify({"type": "error", "error": str(e)})
def session_save(func):
@wraps(func)
def wrapper(*args, **kwargs):
id=request.args.get("id")
id = request.args.get("id")
user_id = request.args.get("user_id")
logger.info(f" id: {id},user_id: {user_id}")
result = func(*args, **kwargs)
datas=[]
datas = []
session_len = int(config("SESSION_LENGTH", default=2))
if cache.exists(id=user_id, field="data"):
datas = copy.deepcopy(cache.get(id=user_id, field="data"))
data = {
"id": id,
"question":cache.get(id=id, field="question"),
"sql":cache.get(id=id, field="sql")
"question": cache.get(id=id, field="question"),
"sql": cache.get(id=id, field="sql")
}
datas.append(data)
logger.info("datas is {0}".format(datas))
if len(datas) > session_len and session_len > 0:
datas=datas[-session_len:]
datas = datas[-session_len:]
# 删除id对应的所有缓存值,因为已经run_sql完毕改用user_id保存为上下文
cache.delete(id=id, field="question")
cache.set(id=user_id, field="data", value=copy.deepcopy(datas))
@@ -147,7 +154,6 @@ def session_save(func):
return wrapper
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
@session_save
@app.requires_cache(["sql"])
@@ -200,7 +206,7 @@ def run_sql_2(id: str, sql: str):
# logger.info("Total count is {0}".format(total_count))
df = vn.run_sql_2(sql=sql)
result = df.to_dict(orient='records')
logger.info("df ---------------{0} {1}".format(result,type(result)))
logger.info("df ---------------{0} {1}".format(result, type(result)))
return jsonify(
{
"type": "success",
@@ -231,7 +237,15 @@ def verify_user():
return jsonify({"type": "error", "error": str(e)})
@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"])
def query_present_question():
try:
data=query_predefined_question_list()
return jsonify({"type": "success", "data": data})
except Exception as e:
logger.error(f"查询预制问题失败 failed:{e}")
return jsonify({"type": "error", "error":f'查询预制问题失败:{str(e)}'})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=8084, debug=False)

View File

@@ -0,0 +1,50 @@
from db_util.db_main import QuestionFeedBack, SqliteSqlalchemy, PredefinedQuestion
from datetime import datetime
import logging
logger = logging.getLogger(__name__)
from concurrent.futures import ThreadPoolExecutor
pool = ThreadPoolExecutor(max_workers=10)
def save_question(id, user_id, question, sql):
logger.info(f"开始保存用户用户==>:{question}")
user_id = user_id if user_id else '1'
qq = QuestionFeedBack(id=id, user_id=user_id, question=question, sql=sql
, create_time=datetime.now())
session = SqliteSqlalchemy().session
try:
session.add(qq)
session.commit()
except Exception as e:
logger.error(f"保存用户问题与sql 错误,{e}")
session.rollback()
finally:
session.close()
def save_save_question_async(id, user_id, question, sql):
pool.submit(save_question, id, user_id, question, sql)
def update_user_feedBack(id, user_comment, user_praise: bool):
session = SqliteSqlalchemy().session
try:
session.query(QuestionFeedBack).filter(QuestionFeedBack.id == id).update(
{QuestionFeedBack.user_comment: user_comment, QuestionFeedBack.user_praise: user_praise})
session.commit()
except Exception as e:
logger.error(f"更新用户反馈错误,{e}")
session.rollback()
finally:
session.close()
'''
查询所有的预制问题
'''
def query_predefined_question_list() -> list:
session = SqliteSqlalchemy().session
all = session.query(PredefinedQuestion).filter(PredefinedQuestion.enable == True).all()
all = [a.to_dict() for a in all] if all else []
session.close()
return all