198 lines
6.1 KiB
Python
198 lines
6.1 KiB
Python
|
|
"""
|
|||
|
|
对话记录服务
|
|||
|
|
负责存储和查询用户与管家agent的对话记录
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
import json
|
|||
|
|
from typing import Optional, Dict, Any, List
|
|||
|
|
from datetime import datetime
|
|||
|
|
|
|||
|
|
from database import query, insert
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
# 延迟导入以避免循环依赖
|
|||
|
|
_conversation_service = None
|
|||
|
|
def get_conversation_service():
|
|||
|
|
global _conversation_service
|
|||
|
|
if _conversation_service is None:
|
|||
|
|
from services.conversation_service import conversation_service
|
|||
|
|
_conversation_service = conversation_service
|
|||
|
|
return _conversation_service
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ChatHistoryService:
|
|||
|
|
"""对话记录服务"""
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def generate_id() -> int:
|
|||
|
|
"""生成唯一ID(使用时间戳+随机数)"""
|
|||
|
|
import time
|
|||
|
|
import random
|
|||
|
|
return int(time.time() * 1000000) + random.randint(1000, 9999)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
async def save_message(
|
|||
|
|
system_user_id: int,
|
|||
|
|
context_id: str,
|
|||
|
|
role: str,
|
|||
|
|
content: str,
|
|||
|
|
message_id: Optional[str] = None,
|
|||
|
|
task_id: Optional[str] = None,
|
|||
|
|
status: str = "success",
|
|||
|
|
error_message: Optional[str] = None,
|
|||
|
|
metadata: Optional[Dict[str, Any]] = None
|
|||
|
|
) -> bool:
|
|||
|
|
"""
|
|||
|
|
保存对话消息
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
system_user_id: 系统用户ID
|
|||
|
|
context_id: 会话上下文ID
|
|||
|
|
role: 角色 (user, agent, system)
|
|||
|
|
content: 消息内容
|
|||
|
|
message_id: 消息ID(可选)
|
|||
|
|
task_id: 任务ID(可选)
|
|||
|
|
status: 状态 (success, failed, error)
|
|||
|
|
error_message: 错误信息(如果失败)
|
|||
|
|
metadata: 元数据(可选)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
是否成功保存
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
msg_id = ChatHistoryService.generate_id()
|
|||
|
|
|
|||
|
|
# 将元数据转为JSON字符串
|
|||
|
|
metadata_json = json.dumps(metadata, ensure_ascii=False) if metadata else None
|
|||
|
|
|
|||
|
|
sql = """
|
|||
|
|
INSERT INTO chat_history
|
|||
|
|
(id, system_user_id, context_id, message_id, task_id, role, content,
|
|||
|
|
status, error_message, metadata, created_at)
|
|||
|
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
params = (
|
|||
|
|
msg_id,
|
|||
|
|
system_user_id,
|
|||
|
|
context_id,
|
|||
|
|
message_id,
|
|||
|
|
task_id,
|
|||
|
|
role,
|
|||
|
|
content,
|
|||
|
|
status,
|
|||
|
|
error_message,
|
|||
|
|
metadata_json
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
insert(sql, params)
|
|||
|
|
logger.info(f"✅ 保存对话记录成功: user={system_user_id}, role={role}, context={context_id}")
|
|||
|
|
|
|||
|
|
# 更新对话列表
|
|||
|
|
try:
|
|||
|
|
conv_service = get_conversation_service()
|
|||
|
|
|
|||
|
|
# 确保对话列表存在
|
|||
|
|
await conv_service.ensure_conversation_exists(
|
|||
|
|
system_user_id=system_user_id,
|
|||
|
|
context_id=context_id
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 如果是用户或agent消息,更新最后一条消息和消息计数
|
|||
|
|
if role in ["user", "agent"]:
|
|||
|
|
await conv_service.update_conversation(
|
|||
|
|
context_id=context_id,
|
|||
|
|
last_message=content,
|
|||
|
|
increment_message_count=True
|
|||
|
|
)
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.warning(f"⚠️ 更新对话列表失败(不影响消息保存): {e}")
|
|||
|
|
|
|||
|
|
return True
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"❌ 保存对话记录失败: {e}", exc_info=True)
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
async def get_conversation_history(
|
|||
|
|
system_user_id: int,
|
|||
|
|
context_id: str,
|
|||
|
|
limit: int = 50
|
|||
|
|
) -> List[Dict[str, Any]]:
|
|||
|
|
"""
|
|||
|
|
获取会话历史记录
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
system_user_id: 系统用户ID
|
|||
|
|
context_id: 会话上下文ID
|
|||
|
|
limit: 返回记录数量限制
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
对话记录列表
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
sql = """
|
|||
|
|
SELECT id, system_user_id, context_id, message_id, task_id, role,
|
|||
|
|
content, status, error_message, metadata, created_at
|
|||
|
|
FROM chat_history
|
|||
|
|
WHERE system_user_id = %s AND context_id = %s
|
|||
|
|
ORDER BY created_at DESC
|
|||
|
|
LIMIT %s
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
results = query(sql, (system_user_id, context_id, limit))
|
|||
|
|
|
|||
|
|
# 解析metadata JSON字符串
|
|||
|
|
for row in results:
|
|||
|
|
if row.get('metadata'):
|
|||
|
|
try:
|
|||
|
|
row['metadata'] = json.loads(row['metadata'])
|
|||
|
|
except:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
return results
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"❌ 获取会话历史失败: {e}", exc_info=True)
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
async def get_user_recent_conversations(
|
|||
|
|
system_user_id: int,
|
|||
|
|
limit: int = 20
|
|||
|
|
) -> List[Dict[str, Any]]:
|
|||
|
|
"""
|
|||
|
|
获取用户最近的对话记录(跨会话)
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
system_user_id: 系统用户ID
|
|||
|
|
limit: 返回记录数量限制
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
对话记录列表
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
sql = """
|
|||
|
|
SELECT id, system_user_id, context_id, message_id, task_id, role,
|
|||
|
|
content, status, error_message, created_at
|
|||
|
|
FROM chat_history
|
|||
|
|
WHERE system_user_id = %s
|
|||
|
|
ORDER BY created_at DESC
|
|||
|
|
LIMIT %s
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
results = query(sql, (system_user_id, limit))
|
|||
|
|
return results
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"❌ 获取用户对话历史失败: {e}", exc_info=True)
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 创建全局服务实例
|
|||
|
|
chat_history_service = ChatHistoryService()
|
|||
|
|
|