feat:增加预制问题-保存用户问题为后期加入rag
This commit is contained in:
0
db_util/__init__.py
Normal file
0
db_util/__init__.py
Normal file
68
db_util/db_main.py
Normal file
68
db_util/db_main.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from sqlalchemy import Column, DateTime, String, create_engine, Boolean, Text
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
|
||||
# 申明基类对象
|
||||
Base = declarative_base()
|
||||
from decouple import config
|
||||
|
||||
DB_PATH = config('DB_PATH', default='E://pyptoject//sqlbot_agent//main.sqlite3')
|
||||
|
||||
|
||||
class QuestionFeedBack(Base):
|
||||
# 定义表名
|
||||
__tablename__ = 'db_question_feedback'
|
||||
# 定义字段
|
||||
id = Column(String(255), primary_key=True)
|
||||
create_time = Column(DateTime, nullable=False, )
|
||||
question = Column(String(255), nullable=False)
|
||||
sql = Column(String(500), nullable=False)
|
||||
user_id = Column(String(100), nullable=False, default='1')
|
||||
# 用户意见反馈
|
||||
user_comment = Column(Text, nullable=True)
|
||||
# 用户点赞,点踩
|
||||
user_praise = Column(Boolean, nullable=False, default=False)
|
||||
# 该数据是否被认为梳理过
|
||||
is_process = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
|
||||
class QuestionFeedBack(Base):
|
||||
# 定义表名
|
||||
__tablename__ = 'db_question_feedback'
|
||||
__table_args__ = {'extend_existing': True}
|
||||
# 定义字段
|
||||
id = Column(String(255), primary_key=True)
|
||||
create_time = Column(DateTime, nullable=False, )
|
||||
question = Column(String(255), nullable=False)
|
||||
sql = Column(String(500), nullable=False)
|
||||
user_id = Column(String(100), nullable=False, default='1')
|
||||
# 用户意见反馈
|
||||
user_comment = Column(Text, nullable=True)
|
||||
# 用户点赞,点踩
|
||||
user_praise = Column(Boolean, nullable=False, default=False)
|
||||
# 该数据是否被认为梳理过
|
||||
is_process = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
|
||||
class PredefinedQuestion(Base):
|
||||
# 定义表名,预制问题表
|
||||
__tablename__ = 'db_predefined_question'
|
||||
__table_args__ = {'extend_existing': True}
|
||||
# 定义字段
|
||||
id = Column(String(255), primary_key=True)
|
||||
question = Column(String(255), nullable=False)
|
||||
user_id = Column(String(100), nullable=False, default='1')
|
||||
# 该数据是否被认为梳理过
|
||||
enable = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
def to_dict(self):
|
||||
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
|
||||
|
||||
|
||||
class SqliteSqlalchemy(object):
|
||||
def __init__(self):
|
||||
# 创建sqlite连接引擎
|
||||
engine = create_engine(f'sqlite:///{DB_PATH}', echo=True)
|
||||
# 创建表
|
||||
Base.metadata.create_all(engine, checkfirst=True)
|
||||
# 创建sqlite的session连接对象
|
||||
self.session = sessionmaker(bind=engine)()
|
||||
@@ -4,13 +4,15 @@ from functools import wraps
|
||||
import util.utils
|
||||
from logging_config import LOGGING_CONFIG
|
||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
|
||||
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
|
||||
from decouple import config
|
||||
import flask
|
||||
from util import load_ddl_doc
|
||||
from flask import Flask, Response, jsonify, request
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def connect_database(vn):
|
||||
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
|
||||
if db_type == 'sqlite':
|
||||
@@ -47,8 +49,8 @@ def create_vana():
|
||||
"api_key": config('CHAT_MODEL_API_KEY', default=''),
|
||||
"api_base": config('CHAT_MODEL_BASE_URL', default=''),
|
||||
"model": config('CHAT_MODEL_NAME', default=''),
|
||||
'temperature':config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
|
||||
'max_tokens':config('CHAT_MODEL_MAX_TOKEN', default=5000),
|
||||
'temperature': config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
|
||||
'max_tokens': config('CHAT_MODEL_MAX_TOKEN', default=5000),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -66,11 +68,14 @@ def init_vn(vn):
|
||||
|
||||
|
||||
from vanna.flask import VannaFlaskApp
|
||||
|
||||
vn = create_vana()
|
||||
app = VannaFlaskApp(vn,chart=False)
|
||||
app.cache = TTLCacheWrapper(app.cache, ttl = config('TTL_CACHE', cast=int,default=60*60))
|
||||
app = VannaFlaskApp(vn, chart=False)
|
||||
app.cache = TTLCacheWrapper(app.cache, ttl=config('TTL_CACHE', cast=int, default=60 * 60))
|
||||
init_vn(vn)
|
||||
cache = app.cache
|
||||
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
|
||||
def generate_sql_2():
|
||||
"""
|
||||
@@ -108,36 +113,38 @@ def generate_sql_2():
|
||||
logger.info("Generate sql result is {0}".format(data))
|
||||
data['id'] = id
|
||||
sql = data["resp"]["sql"]
|
||||
logger.info("generate sql is : "+ sql)
|
||||
logger.info("generate sql is : " + sql)
|
||||
cache.set(id=id, field="question", value=question)
|
||||
cache.set(id=id, field="sql", value=sql)
|
||||
data["type"]="success"
|
||||
data["type"] = "success"
|
||||
save_save_question_async(id, user_id, question, sql)
|
||||
return jsonify(data)
|
||||
except Exception as e:
|
||||
logger.error(f"generate sql failed:{e}")
|
||||
return jsonify({"type": "error", "error": str(e)})
|
||||
|
||||
|
||||
def session_save(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
id=request.args.get("id")
|
||||
id = request.args.get("id")
|
||||
user_id = request.args.get("user_id")
|
||||
logger.info(f" id: {id},user_id: {user_id}")
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
datas=[]
|
||||
datas = []
|
||||
session_len = int(config("SESSION_LENGTH", default=2))
|
||||
if cache.exists(id=user_id, field="data"):
|
||||
datas = copy.deepcopy(cache.get(id=user_id, field="data"))
|
||||
data = {
|
||||
"id": id,
|
||||
"question":cache.get(id=id, field="question"),
|
||||
"sql":cache.get(id=id, field="sql")
|
||||
"question": cache.get(id=id, field="question"),
|
||||
"sql": cache.get(id=id, field="sql")
|
||||
}
|
||||
datas.append(data)
|
||||
logger.info("datas is {0}".format(datas))
|
||||
if len(datas) > session_len and session_len > 0:
|
||||
datas=datas[-session_len:]
|
||||
datas = datas[-session_len:]
|
||||
# 删除id对应的所有缓存值,因为已经run_sql完毕,改用user_id保存为上下文
|
||||
cache.delete(id=id, field="question")
|
||||
cache.set(id=user_id, field="data", value=copy.deepcopy(datas))
|
||||
@@ -147,7 +154,6 @@ def session_save(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
|
||||
@session_save
|
||||
@app.requires_cache(["sql"])
|
||||
@@ -200,7 +206,7 @@ def run_sql_2(id: str, sql: str):
|
||||
# logger.info("Total count is {0}".format(total_count))
|
||||
df = vn.run_sql_2(sql=sql)
|
||||
result = df.to_dict(orient='records')
|
||||
logger.info("df ---------------{0} {1}".format(result,type(result)))
|
||||
logger.info("df ---------------{0} {1}".format(result, type(result)))
|
||||
return jsonify(
|
||||
{
|
||||
"type": "success",
|
||||
@@ -231,7 +237,15 @@ def verify_user():
|
||||
return jsonify({"type": "error", "error": str(e)})
|
||||
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"])
|
||||
def query_present_question():
|
||||
try:
|
||||
data=query_predefined_question_list()
|
||||
return jsonify({"type": "success", "data": data})
|
||||
except Exception as e:
|
||||
logger.error(f"查询预制问题失败 failed:{e}")
|
||||
return jsonify({"type": "error", "error":f'查询预制问题失败:{str(e)}'})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
app.run(host='0.0.0.0', port=8084, debug=False)
|
||||
|
||||
50
service/question_feedback_service.py
Normal file
50
service/question_feedback_service.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from db_util.db_main import QuestionFeedBack, SqliteSqlalchemy, PredefinedQuestion
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
pool = ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
def save_question(id, user_id, question, sql):
|
||||
logger.info(f"开始保存用户用户==>:{question}")
|
||||
user_id = user_id if user_id else '1'
|
||||
qq = QuestionFeedBack(id=id, user_id=user_id, question=question, sql=sql
|
||||
, create_time=datetime.now())
|
||||
session = SqliteSqlalchemy().session
|
||||
try:
|
||||
session.add(qq)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"保存用户问题与sql 错误,{e}")
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def save_save_question_async(id, user_id, question, sql):
|
||||
pool.submit(save_question, id, user_id, question, sql)
|
||||
|
||||
|
||||
def update_user_feedBack(id, user_comment, user_praise: bool):
|
||||
session = SqliteSqlalchemy().session
|
||||
try:
|
||||
session.query(QuestionFeedBack).filter(QuestionFeedBack.id == id).update(
|
||||
{QuestionFeedBack.user_comment: user_comment, QuestionFeedBack.user_praise: user_praise})
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"更新用户反馈错误,{e}")
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
'''
|
||||
查询所有的预制问题
|
||||
'''
|
||||
def query_predefined_question_list() -> list:
|
||||
session = SqliteSqlalchemy().session
|
||||
all = session.query(PredefinedQuestion).filter(PredefinedQuestion.enable == True).all()
|
||||
all = [a.to_dict() for a in all] if all else []
|
||||
session.close()
|
||||
return all
|
||||
Reference in New Issue
Block a user