120 lines
4.4 KiB
Python
120 lines
4.4 KiB
Python
import json
|
|
from datetime import datetime
|
|
import pandas as pd
|
|
from db_util.db_main import Conversation, SqliteSqlalchemy
|
|
import logging
|
|
|
|
# Module-level logger used by every conversation persistence helper in this file.
logger = logging.getLogger(__name__)
|
|
|
|
def save_conversation(id, user_id, cvs_id, question):
    """Insert a new Conversation row for one question.

    The row's ``create_time`` is stamped with ``datetime.now()`` (naive local
    time). If adding or committing raises, the error is logged with its
    traceback and the transaction is rolled back; the exception is not
    re-raised, so the function returns ``None`` either way.

    Args:
        id: Value for the row's ``id`` column.
        user_id: Value for the row's ``user_id`` column.
        cvs_id: Value for the row's ``cvs_id`` column.
        question: Value for the row's ``question`` column.
    """
    logger.info(f'save_conversation => id {id} user_id {user_id} cvs_id {cvs_id} question {question}')
    cvs = Conversation(id=id, user_id=user_id, cvs_id=cvs_id, question=question, create_time=datetime.now())
    session = SqliteSqlalchemy().session
    try:
        session.add(cvs)
        session.commit()
    except Exception as e:
        # Interpolate the actual field values so the failing row can be
        # identified; logger.exception also records the traceback.
        logger.exception(
            f"Failed to save conversation info (id={id}, user_id={user_id}, "
            f"cvs_id={cvs_id}, question={question}) error {e}"
        )
        session.rollback()
    finally:
        session.close()
|
|
|
|
def get_conversation(cvs_id: str):
    """Return every Conversation row whose ``id`` column equals ``cvs_id``.

    Args:
        cvs_id: Value compared against ``Conversation.id``.

    Returns:
        A list of matching ``Conversation`` objects, or ``None`` if the query
        raised (the error is logged).
    """
    session = SqliteSqlalchemy().session
    try:
        # NOTE(review): this filters on ``Conversation.id`` although the
        # parameter is named ``cvs_id`` and a list is returned; confirm
        # whether ``Conversation.cvs_id`` was intended before changing it.
        results = session.query(Conversation).filter(Conversation.id == cvs_id).all()
        # Log what was actually fetched; logging the un-executed Query object
        # only printed its SQL text.
        logger.info(f"conversation {cvs_id} results count {len(results)}")
        return results
    except Exception as e:
        logger.error(f"get conversation with id {cvs_id} error {e}")
        return None
    finally:
        session.close()
|
|
|
|
|
|
def get_all_conversations_by_user(user_id, offset=0, page_size=100):
    """Return one page of a user's conversations together with the user's total row count.

    Rows are ordered by ``create_time`` newest first. ``offset`` rows are
    skipped, and at most ``page_size`` rows are returned.

    Args:
        user_id: Value matched against ``Conversation.user_id``.
        offset: Number of rows to skip before the page starts.
        page_size: Maximum number of rows in the returned page.

    Returns:
        ``{"total_count": ..., "data": [...]}``. ``total_count`` counts all of
        the user's rows, not just this page. ``None`` is returned if the query
        raised (the error is logged).
    """
    session = SqliteSqlalchemy().session
    try:
        # Build the user filter once; Query objects are generative, so the
        # count and the paged fetch can both derive from it.
        user_query = session.query(Conversation).filter(Conversation.user_id == user_id)
        total_count = user_query.count()
        data = (
            user_query
            .order_by(Conversation.create_time.desc())
            .offset(offset)
            .limit(page_size)
            .all()
        )
        logger.info(f"conversation {user_id} total_count {total_count}")
        return {"total_count": total_count, "data": data}
    except Exception as e:
        logger.error(f"get all conversation with user id {user_id} error {e}")
        return None
    finally:
        session.close()
|
|
|
|
|
|
def _dict_to_json(value):
    """JSON-encode ``value`` (keeping non-ASCII text) if it is a dict; return anything else unchanged."""
    if isinstance(value, dict):
        return json.dumps(value, ensure_ascii=False)
    return value


def _serialize_answer(answer):
    """Serialize a dict ``answer`` as a JSON array of row records; return anything else unchanged.

    The dict is loaded into a pandas DataFrame and written with
    ``orient='records'`` and ISO-formatted dates. If pandas cannot build a
    frame from it (e.g. a dict of only scalar values raises ``ValueError``),
    fall back to plain ``json.dumps``. This keeps the rest of the update from
    being lost.
    """
    if not isinstance(answer, dict):
        return answer
    logger.info("answer dict")
    try:
        answer_df = pd.DataFrame(answer)
    except ValueError as err:
        logger.warning(f"answer dict could not be converted to a DataFrame ({err}); storing it as plain JSON")
        return json.dumps(answer, ensure_ascii=False)
    return answer_df.to_json(orient='records', date_format='iso', force_ascii=False)


def update_conversation(id: str, sql=None, question=None, chart_cfg=None, meta=None, answer=None):
    """Update selected columns of the Conversation row whose ``id`` matches.

    Only arguments that are not ``None`` are written. If ``sql``,
    ``chart_cfg`` or ``meta`` is a dict, it is JSON-encoded first. A dict
    ``answer`` is converted to a JSON array of records via pandas. Other
    values are stored as given. On any error the failure is logged and the
    transaction rolled back; nothing is re-raised.

    Args:
        id: Value matched against ``Conversation.id``.
        sql: New value for the ``sql`` column.
        question: New value for the ``question`` column.
        chart_cfg: New value for the ``chart_cfg`` column.
        meta: New value for the ``meta`` column.
        answer: New value for the ``answer`` column.
    """
    session = SqliteSqlalchemy().session
    logger.info(f"update conversation with id {id} sql {sql} question {question} chart_cfg {chart_cfg} meta {meta} answer {answer}")
    try:
        update_data = {}
        if sql is not None:
            update_data['sql'] = _dict_to_json(sql)
        if question is not None:
            update_data['question'] = question
        if chart_cfg is not None:
            update_data['chart_cfg'] = _dict_to_json(chart_cfg)
        if meta is not None:
            update_data['meta'] = _dict_to_json(meta)
        if answer is not None:
            update_data['answer'] = _serialize_answer(answer)
        if update_data:
            # Query.update returns the number of matched rows; surface a miss
            # instead of silently committing a no-op.
            matched = session.query(Conversation).filter(Conversation.id == id).update(update_data)
            session.commit()
            if matched == 0:
                logger.warning(f"update conversation with id {id} matched no rows")
    except Exception as e:
        logger.error(f"update conversation with id {id} error {e}")
        session.rollback()
    finally:
        session.close()
|
|
|
|
|
|
def get_latest_question(cvs_id, user_id, limit_count):
    """Return up to ``limit_count`` question texts for one conversation of one user.

    Rows must match both ``cvs_id`` and ``user_id``. They are ordered by
    ``create_time`` descending, so the newest question comes first. If the
    query raises, the error is logged and ``None`` is returned.
    """
    logger.info("get latest question.........")
    db_session = SqliteSqlalchemy().session
    try:
        matching = db_session.query(Conversation).filter_by(cvs_id=cvs_id, user_id=user_id)
        newest_first = Conversation.create_time.desc()
        rows = matching.order_by(newest_first).limit(limit_count).all()
        return [row.question for row in rows]
    except Exception as err:
        logger.error(f"get_latest_question error {err}")
        return None
    finally:
        db_session.close()
|
|
|
|
|
|
def get_qa_by_id(id: str):
    """Fetch the question, SQL and answer stored on one Conversation row.

    Args:
        id: Value matched against ``Conversation.id``; the first matching row is used.

    Returns:
        ``{"question": ..., "sql": ..., "answer": ...}`` copied from the row's
        columns. Returns ``None`` when no row matches or the query raises (the
        error is logged).
    """
    session = SqliteSqlalchemy().session
    try:
        row = session.query(Conversation).filter_by(id=id).first()
        if row is None:
            return None
        return {"question": row.question, "sql": row.sql, "answer": row.answer}
    except Exception as e:
        # Log under this function's real name, with the id, so the failure is traceable.
        logger.error(f"get_qa_by_id with id {id} error {e}")
        return None
    finally:
        session.close()
|
|
|