feat:增加预制问题-保存用户问题为后期加入rag
This commit is contained in:
0
db_util/__init__.py
Normal file
0
db_util/__init__.py
Normal file
68
db_util/db_main.py
Normal file
68
db_util/db_main.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
from sqlalchemy import Column, DateTime, String, create_engine, Boolean, Text
|
||||||
|
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||||
|
|
||||||
|
# 申明基类对象
|
||||||
|
Base = declarative_base()
|
||||||
|
from decouple import config
|
||||||
|
|
||||||
|
DB_PATH = config('DB_PATH', default='E://pyptoject//sqlbot_agent//main.sqlite3')
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionFeedBack(Base):
|
||||||
|
# 定义表名
|
||||||
|
__tablename__ = 'db_question_feedback'
|
||||||
|
# 定义字段
|
||||||
|
id = Column(String(255), primary_key=True)
|
||||||
|
create_time = Column(DateTime, nullable=False, )
|
||||||
|
question = Column(String(255), nullable=False)
|
||||||
|
sql = Column(String(500), nullable=False)
|
||||||
|
user_id = Column(String(100), nullable=False, default='1')
|
||||||
|
# 用户意见反馈
|
||||||
|
user_comment = Column(Text, nullable=True)
|
||||||
|
# 用户点赞,点踩
|
||||||
|
user_praise = Column(Boolean, nullable=False, default=False)
|
||||||
|
# 该数据是否被认为梳理过
|
||||||
|
is_process = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionFeedBack(Base):
|
||||||
|
# 定义表名
|
||||||
|
__tablename__ = 'db_question_feedback'
|
||||||
|
__table_args__ = {'extend_existing': True}
|
||||||
|
# 定义字段
|
||||||
|
id = Column(String(255), primary_key=True)
|
||||||
|
create_time = Column(DateTime, nullable=False, )
|
||||||
|
question = Column(String(255), nullable=False)
|
||||||
|
sql = Column(String(500), nullable=False)
|
||||||
|
user_id = Column(String(100), nullable=False, default='1')
|
||||||
|
# 用户意见反馈
|
||||||
|
user_comment = Column(Text, nullable=True)
|
||||||
|
# 用户点赞,点踩
|
||||||
|
user_praise = Column(Boolean, nullable=False, default=False)
|
||||||
|
# 该数据是否被认为梳理过
|
||||||
|
is_process = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class PredefinedQuestion(Base):
|
||||||
|
# 定义表名,预制问题表
|
||||||
|
__tablename__ = 'db_predefined_question'
|
||||||
|
__table_args__ = {'extend_existing': True}
|
||||||
|
# 定义字段
|
||||||
|
id = Column(String(255), primary_key=True)
|
||||||
|
question = Column(String(255), nullable=False)
|
||||||
|
user_id = Column(String(100), nullable=False, default='1')
|
||||||
|
# 该数据是否被认为梳理过
|
||||||
|
enable = Column(Boolean, nullable=False, default=True)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteSqlalchemy(object):
|
||||||
|
def __init__(self):
|
||||||
|
# 创建sqlite连接引擎
|
||||||
|
engine = create_engine(f'sqlite:///{DB_PATH}', echo=True)
|
||||||
|
# 创建表
|
||||||
|
Base.metadata.create_all(engine, checkfirst=True)
|
||||||
|
# 创建sqlite的session连接对象
|
||||||
|
self.session = sessionmaker(bind=engine)()
|
||||||
@@ -4,13 +4,15 @@ from functools import wraps
|
|||||||
import util.utils
|
import util.utils
|
||||||
from logging_config import LOGGING_CONFIG
|
from logging_config import LOGGING_CONFIG
|
||||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
|
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
|
||||||
|
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
|
||||||
from decouple import config
|
from decouple import config
|
||||||
import flask
|
import flask
|
||||||
from util import load_ddl_doc
|
from util import load_ddl_doc
|
||||||
from flask import Flask, Response, jsonify, request
|
from flask import Flask, Response, jsonify, request
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def connect_database(vn):
|
def connect_database(vn):
|
||||||
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
|
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
|
||||||
if db_type == 'sqlite':
|
if db_type == 'sqlite':
|
||||||
@@ -47,8 +49,8 @@ def create_vana():
|
|||||||
"api_key": config('CHAT_MODEL_API_KEY', default=''),
|
"api_key": config('CHAT_MODEL_API_KEY', default=''),
|
||||||
"api_base": config('CHAT_MODEL_BASE_URL', default=''),
|
"api_base": config('CHAT_MODEL_BASE_URL', default=''),
|
||||||
"model": config('CHAT_MODEL_NAME', default=''),
|
"model": config('CHAT_MODEL_NAME', default=''),
|
||||||
'temperature':config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
|
'temperature': config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
|
||||||
'max_tokens':config('CHAT_MODEL_MAX_TOKEN', default=5000),
|
'max_tokens': config('CHAT_MODEL_MAX_TOKEN', default=5000),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,11 +68,14 @@ def init_vn(vn):
|
|||||||
|
|
||||||
|
|
||||||
from vanna.flask import VannaFlaskApp
|
from vanna.flask import VannaFlaskApp
|
||||||
|
|
||||||
vn = create_vana()
|
vn = create_vana()
|
||||||
app = VannaFlaskApp(vn,chart=False)
|
app = VannaFlaskApp(vn, chart=False)
|
||||||
app.cache = TTLCacheWrapper(app.cache, ttl = config('TTL_CACHE', cast=int,default=60*60))
|
app.cache = TTLCacheWrapper(app.cache, ttl=config('TTL_CACHE', cast=int, default=60 * 60))
|
||||||
init_vn(vn)
|
init_vn(vn)
|
||||||
cache = app.cache
|
cache = app.cache
|
||||||
|
|
||||||
|
|
||||||
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
|
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
|
||||||
def generate_sql_2():
|
def generate_sql_2():
|
||||||
"""
|
"""
|
||||||
@@ -108,36 +113,38 @@ def generate_sql_2():
|
|||||||
logger.info("Generate sql result is {0}".format(data))
|
logger.info("Generate sql result is {0}".format(data))
|
||||||
data['id'] = id
|
data['id'] = id
|
||||||
sql = data["resp"]["sql"]
|
sql = data["resp"]["sql"]
|
||||||
logger.info("generate sql is : "+ sql)
|
logger.info("generate sql is : " + sql)
|
||||||
cache.set(id=id, field="question", value=question)
|
cache.set(id=id, field="question", value=question)
|
||||||
cache.set(id=id, field="sql", value=sql)
|
cache.set(id=id, field="sql", value=sql)
|
||||||
data["type"]="success"
|
data["type"] = "success"
|
||||||
|
save_save_question_async(id, user_id, question, sql)
|
||||||
return jsonify(data)
|
return jsonify(data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"generate sql failed:{e}")
|
logger.error(f"generate sql failed:{e}")
|
||||||
return jsonify({"type": "error", "error": str(e)})
|
return jsonify({"type": "error", "error": str(e)})
|
||||||
|
|
||||||
|
|
||||||
def session_save(func):
|
def session_save(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
id=request.args.get("id")
|
id = request.args.get("id")
|
||||||
user_id = request.args.get("user_id")
|
user_id = request.args.get("user_id")
|
||||||
logger.info(f" id: {id},user_id: {user_id}")
|
logger.info(f" id: {id},user_id: {user_id}")
|
||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
|
|
||||||
datas=[]
|
datas = []
|
||||||
session_len = int(config("SESSION_LENGTH", default=2))
|
session_len = int(config("SESSION_LENGTH", default=2))
|
||||||
if cache.exists(id=user_id, field="data"):
|
if cache.exists(id=user_id, field="data"):
|
||||||
datas = copy.deepcopy(cache.get(id=user_id, field="data"))
|
datas = copy.deepcopy(cache.get(id=user_id, field="data"))
|
||||||
data = {
|
data = {
|
||||||
"id": id,
|
"id": id,
|
||||||
"question":cache.get(id=id, field="question"),
|
"question": cache.get(id=id, field="question"),
|
||||||
"sql":cache.get(id=id, field="sql")
|
"sql": cache.get(id=id, field="sql")
|
||||||
}
|
}
|
||||||
datas.append(data)
|
datas.append(data)
|
||||||
logger.info("datas is {0}".format(datas))
|
logger.info("datas is {0}".format(datas))
|
||||||
if len(datas) > session_len and session_len > 0:
|
if len(datas) > session_len and session_len > 0:
|
||||||
datas=datas[-session_len:]
|
datas = datas[-session_len:]
|
||||||
# 删除id对应的所有缓存值,因为已经run_sql完毕,改用user_id保存为上下文
|
# 删除id对应的所有缓存值,因为已经run_sql完毕,改用user_id保存为上下文
|
||||||
cache.delete(id=id, field="question")
|
cache.delete(id=id, field="question")
|
||||||
cache.set(id=user_id, field="data", value=copy.deepcopy(datas))
|
cache.set(id=user_id, field="data", value=copy.deepcopy(datas))
|
||||||
@@ -147,7 +154,6 @@ def session_save(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
|
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
|
||||||
@session_save
|
@session_save
|
||||||
@app.requires_cache(["sql"])
|
@app.requires_cache(["sql"])
|
||||||
@@ -200,7 +206,7 @@ def run_sql_2(id: str, sql: str):
|
|||||||
# logger.info("Total count is {0}".format(total_count))
|
# logger.info("Total count is {0}".format(total_count))
|
||||||
df = vn.run_sql_2(sql=sql)
|
df = vn.run_sql_2(sql=sql)
|
||||||
result = df.to_dict(orient='records')
|
result = df.to_dict(orient='records')
|
||||||
logger.info("df ---------------{0} {1}".format(result,type(result)))
|
logger.info("df ---------------{0} {1}".format(result, type(result)))
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"type": "success",
|
"type": "success",
|
||||||
@@ -231,7 +237,15 @@ def verify_user():
|
|||||||
return jsonify({"type": "error", "error": str(e)})
|
return jsonify({"type": "error", "error": str(e)})
|
||||||
|
|
||||||
|
|
||||||
|
@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"])
|
||||||
|
def query_present_question():
|
||||||
|
try:
|
||||||
|
data=query_predefined_question_list()
|
||||||
|
return jsonify({"type": "success", "data": data})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询预制问题失败 failed:{e}")
|
||||||
|
return jsonify({"type": "error", "error":f'查询预制问题失败:{str(e)}'})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
app.run(host='0.0.0.0', port=8084, debug=False)
|
app.run(host='0.0.0.0', port=8084, debug=False)
|
||||||
|
|||||||
50
service/question_feedback_service.py
Normal file
50
service/question_feedback_service.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from db_util.db_main import QuestionFeedBack, SqliteSqlalchemy, PredefinedQuestion
|
||||||
|
from datetime import datetime
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
pool = ThreadPoolExecutor(max_workers=10)
|
||||||
|
|
||||||
|
def save_question(id, user_id, question, sql):
|
||||||
|
logger.info(f"开始保存用户用户==>:{question}")
|
||||||
|
user_id = user_id if user_id else '1'
|
||||||
|
qq = QuestionFeedBack(id=id, user_id=user_id, question=question, sql=sql
|
||||||
|
, create_time=datetime.now())
|
||||||
|
session = SqliteSqlalchemy().session
|
||||||
|
try:
|
||||||
|
session.add(qq)
|
||||||
|
session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存用户问题与sql 错误,{e}")
|
||||||
|
session.rollback()
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def save_save_question_async(id, user_id, question, sql):
|
||||||
|
pool.submit(save_question, id, user_id, question, sql)
|
||||||
|
|
||||||
|
|
||||||
|
def update_user_feedBack(id, user_comment, user_praise: bool):
|
||||||
|
session = SqliteSqlalchemy().session
|
||||||
|
try:
|
||||||
|
session.query(QuestionFeedBack).filter(QuestionFeedBack.id == id).update(
|
||||||
|
{QuestionFeedBack.user_comment: user_comment, QuestionFeedBack.user_praise: user_praise})
|
||||||
|
session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新用户反馈错误,{e}")
|
||||||
|
session.rollback()
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
'''
|
||||||
|
查询所有的预制问题
|
||||||
|
'''
|
||||||
|
def query_predefined_question_list() -> list:
|
||||||
|
session = SqliteSqlalchemy().session
|
||||||
|
all = session.query(PredefinedQuestion).filter(PredefinedQuestion.enable == True).all()
|
||||||
|
all = [a.to_dict() for a in all] if all else []
|
||||||
|
session.close()
|
||||||
|
return all
|
||||||
Reference in New Issue
Block a user