86 lines
2.7 KiB
Python
86 lines
2.7 KiB
Python
from datetime import datetime
|
|
|
|
from db_util.db_main import Conversation, SqliteSqlalchemy
|
|
import logging
|
|
|
|
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
|
|
|
|
def save_conversation(id, user_id, cvs_id, question):
    """Insert a new Conversation row and commit it.

    Args:
        id: Value stored in ``Conversation.id``.
        user_id: Value stored in ``Conversation.user_id``.
        cvs_id: Value stored in ``Conversation.cvs_id``.
        question: Value stored in ``Conversation.question``.

    ``create_time`` is set to ``datetime.now()`` (naive local time).

    If adding or committing raises an ``Exception``, the transaction is
    rolled back and the error is logged; it is not re-raised (best-effort
    save). The session is always closed.
    """
    cvs = Conversation(
        id=id,
        user_id=user_id,
        cvs_id=cvs_id,
        question=question,
        create_time=datetime.now(),
    )
    session = SqliteSqlalchemy().session
    try:
        session.add(cvs)
        session.commit()
    except Exception:
        # Narrowed from a bare ``except:`` so that KeyboardInterrupt and
        # SystemExit still propagate. Failures are now logged with a
        # traceback instead of disappearing silently.
        session.rollback()
        logger.exception("save_conversation error for id=%s cvs_id=%s", id, cvs_id)
    finally:
        session.close()
|
|
|
|
def get_conversation(cvs_id: str):
    """Return all Conversation rows that belong to conversation ``cvs_id``.

    Args:
        cvs_id: Conversation identifier, matched against ``Conversation.cvs_id``.

    Returns:
        A list of Conversation rows (possibly empty). Returns ``None`` if
        the query raised; in that case the error is logged.
    """
    session = SqliteSqlalchemy().session
    try:
        # Filter on the cvs_id column. Elsewhere in this module a ``cvs_id``
        # argument always maps to Conversation.cvs_id, and a row id is passed
        # separately as ``id``.
        # NOTE(review): the previous code compared against Conversation.id;
        # confirm that no caller relied on passing a row id here.
        results = (
            session.query(Conversation)
            .filter(Conversation.cvs_id == cvs_id)
            .all()
        )
        # Log after executing the query. The old code logged the un-executed
        # Query object, which printed the SQL text rather than the results.
        logger.info("conversation %s returned %d rows", cvs_id, len(results))
        return results
    except Exception:
        # Previously logged at INFO with no traceback.
        logger.exception("get conversation with id %s error", cvs_id)
        return None
    finally:
        session.close()
|
|
|
|
|
|
# def get_all_conversations_by_user(user_id):
|
|
# session = SqliteSqlalchemy().session
|
|
# user_cvs = []
|
|
# try:
|
|
# results = session.query(Conversation).filter(Conversation.user_id == user_id).all()
|
|
# logger.info(f"conversation {user_id} results is {results}")
|
|
# cvs = {}
|
|
# for rs in results:
|
|
# cvs[rs.cvs_id] = {
|
|
#
|
|
# }
|
|
|
|
|
|
def update_conversation(cvs_id: str, id: str, sql=None, meta=None):
    """Update the ``sql`` and/or ``meta`` columns of the matching question row.

    The rows updated are those where ``Conversation.cvs_id == cvs_id`` and
    ``Conversation.id == id``.

    Args:
        cvs_id: Value matched against ``Conversation.cvs_id``.
        id: Value matched against ``Conversation.id``.
        sql: New value for ``Conversation.sql``; skipped when falsy.
        meta: New value for ``Conversation.meta``; skipped when falsy.

    If the update raises an ``Exception``, it is rolled back and logged; the
    error is not re-raised. The session is always closed.
    """
    session = SqliteSqlalchemy().session
    try:
        values = {}
        if sql:
            values[Conversation.sql] = sql
        if meta:
            values[Conversation.meta] = meta
        if values:
            # Issue a single UPDATE for both columns instead of two statements.
            session.query(Conversation).filter(
                Conversation.cvs_id == cvs_id,
                Conversation.id == id,
            ).update(values)
        session.commit()
    except Exception:
        session.rollback()
        # Previously the exception was swallowed without any trace.
        logger.exception("update_conversation error for cvs_id=%s id=%s", cvs_id, id)
    finally:
        session.close()
|
|
|
|
|
|
def get_latest_question(cvs_id, user_id, limit_count):
    """Return the most recent questions for a conversation and user.

    Args:
        cvs_id: Value matched against ``Conversation.cvs_id``.
        user_id: Value matched against ``Conversation.user_id``.
        limit_count: Maximum number of rows to fetch (passed to ``.limit``).

    Returns:
        A list of ``Conversation.question`` values, newest first (ordered by
        ``create_time`` descending). Returns ``None`` if the query raised; in
        that case the error is logged.
    """
    session = SqliteSqlalchemy().session
    try:
        rows = (
            session.query(Conversation)
            .filter_by(cvs_id=cvs_id, user_id=user_id)
            .order_by(Conversation.create_time.desc())
            .limit(limit_count)
            .all()
        )
        return [row.question for row in rows]
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception(f"get_latest_question error {e}")
        return None
    finally:
        session.close()
|
|
|
|
|
|
def get_sql_by_id(id: str):
    """Return the ``sql`` column of the first Conversation row with this ``id``.

    Args:
        id: Value matched against ``Conversation.id``.

    Returns:
        The row's ``sql`` value. Returns ``None`` when no row matches, or when
        the query raised; in that case the error is logged.
    """
    session = SqliteSqlalchemy().session
    try:
        row = session.query(Conversation).filter_by(id=id).first()
        return row.sql if row is not None else None
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception(f"get_sql_by_id error {e}")
        return None
    finally:
        session.close()
|