# Files
# moss-ai/app/backend-python/services/conversation_service.py
# 雷雨 8635b84b2d init
# 2025-12-15 22:05:56 +08:00
#
# 372 lines
# 13 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
对话列表管理服务
负责管理用户的对话列表(会话列表)
"""
import logging
from typing import Optional, Dict, Any, List
from datetime import datetime
from database import query, insert, update, get_db_type
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class ConversationService:
    """Manage the per-user conversation list stored in ``conversation_list``.

    Every method is a ``staticmethod``. The backend is chosen at call time via
    ``get_db_type()``. For ``'starrocks'``, this service avoids ``UPDATE`` and
    emulates row modification with DELETE followed by INSERT. Every other
    backend (MySQL) uses ordinary ``UPDATE`` statements.
    """

    # Maximum number of characters of a message stored as the list preview.
    LAST_MESSAGE_PREVIEW_LENGTH = 200

    # Column list shared by every SELECT against conversation_list.
    _SELECT_COLUMNS = (
        "id, context_id, system_user_id, title, description, "
        "message_count, last_message, created_at, updated_at, is_active"
    )

    @staticmethod
    def generate_id() -> int:
        """Generate a numeric row ID from the current time plus a random offset.

        The value is the current Unix time in microseconds plus a random
        integer in [1000, 9999].

        NOTE(review): this is not collision-free. Two IDs generated less than
        ~9 ms apart can coincide. Consider a DB sequence or a snowflake-style
        generator if insert volume grows.
        """
        import random
        import time

        return int(time.time() * 1000000) + random.randint(1000, 9999)

    @staticmethod
    def _make_preview(message: str) -> str:
        """Return at most ``LAST_MESSAGE_PREVIEW_LENGTH`` leading characters of *message*."""
        # Slicing already returns the whole string when it is shorter.
        return message[:ConversationService.LAST_MESSAGE_PREVIEW_LENGTH]

    @staticmethod
    async def create_conversation(
        system_user_id: int,
        context_id: str,
        title: str = "新对话",
        description: Optional[str] = None
    ) -> Dict[str, Any]:
        """Insert a new, active conversation row and return its data.

        Args:
            system_user_id: ID of the owning system user.
            context_id: Session context ID the conversation is bound to.
            title: Conversation title.
            description: Optional conversation description.

        Returns:
            A dict with the same keys as the rows returned by
            ``get_conversation_by_context``. The timestamps are generated in
            Python and may differ slightly from the ``NOW()`` values stored
            in the database.

        Raises:
            Exception: If the insert fails. The original error is chained as
                ``__cause__``.
        """
        conv_id = ConversationService.generate_id()
        try:
            sql = """
                INSERT INTO conversation_list
                (id, context_id, system_user_id, title, description,
                 message_count, created_at, updated_at, is_active)
                VALUES (%s, %s, %s, %s, %s, 0, NOW(), NOW(), TRUE)
            """
            insert(sql, (conv_id, context_id, system_user_id, title, description))
            logger.info(f"✅ 创建对话成功: user={system_user_id}, context={context_id}, title={title}")
        except Exception as e:
            logger.error(f"❌ 创建对话失败: {e}", exc_info=True)
            raise Exception(f"创建对话失败: {str(e)}") from e

        # Use a single timestamp so created_at == updated_at, mirroring the
        # two NOW() calls in the same INSERT statement.
        now = datetime.now().isoformat()
        return {
            "id": conv_id,
            "context_id": context_id,
            "system_user_id": system_user_id,
            "title": title,
            "description": description,
            "message_count": 0,
            # Present so the shape matches rows read back from the table.
            "last_message": None,
            "created_at": now,
            "updated_at": now,
            "is_active": True
        }

    @staticmethod
    async def get_conversation_by_context(
        context_id: str
    ) -> Optional[Dict[str, Any]]:
        """Fetch the conversation row for *context_id*.

        Args:
            context_id: Session context ID.

        Returns:
            The first matching row as returned by ``query``, or ``None`` when
            no row exists. Database errors are logged and also yield ``None``.
        """
        try:
            sql = f"""
                SELECT {ConversationService._SELECT_COLUMNS}
                FROM conversation_list
                WHERE context_id = %s
                LIMIT 1
            """
            results = query(sql, (context_id,))
            if results:
                return results[0]
            return None
        except Exception as e:
            logger.error(f"❌ 获取对话信息失败: {e}", exc_info=True)
            return None

    @staticmethod
    async def update_conversation(
        context_id: str,
        title: Optional[str] = None,
        last_message: Optional[str] = None,
        increment_message_count: bool = False
    ) -> bool:
        """Update title, last-message preview and/or message count.

        ``updated_at`` is always refreshed. ``last_message`` is truncated to
        ``LAST_MESSAGE_PREVIEW_LENGTH`` characters.

        Args:
            context_id: Session context ID of the conversation.
            title: New title, or ``None`` to leave unchanged.
            last_message: Latest message text, or ``None`` to leave unchanged.
            increment_message_count: Whether to add one to ``message_count``.

        Returns:
            ``True`` if the conversation was updated. ``False`` if it does not
            exist or a database error occurred (the error is logged).
        """
        try:
            if get_db_type() == 'starrocks':
                return await ConversationService._update_via_reinsert(
                    context_id, title, last_message, increment_message_count
                )
            return ConversationService._update_in_place(
                context_id, title, last_message, increment_message_count
            )
        except Exception as e:
            logger.error(f"❌ 更新对话失败: {e}", exc_info=True)
            return False

    @staticmethod
    async def _update_via_reinsert(
        context_id: str,
        title: Optional[str],
        last_message: Optional[str],
        increment_message_count: bool
    ) -> bool:
        """StarRocks path: read the row, modify it in Python, then DELETE + INSERT it."""
        conv = await ConversationService.get_conversation_by_context(context_id)
        if not conv:
            logger.warning(f"⚠️ 对话不存在: context={context_id}")
            return False

        if title is not None:
            conv['title'] = title
        if last_message is not None:
            conv['last_message'] = ConversationService._make_preview(last_message)
        if increment_message_count:
            # The key is always present but may hold NULL/None; treat that as 0
            # instead of failing on `None + 1`.
            conv['message_count'] = (conv.get('message_count') or 0) + 1

        # NOTE(review): DELETE and INSERT are separate statements. If the
        # INSERT fails, the conversation row is lost. Confirm whether the
        # `database` helpers run these inside one transaction.
        update("DELETE FROM conversation_list WHERE context_id = %s", (context_id,))
        insert_sql = """
            INSERT INTO conversation_list
            (id, system_user_id, updated_at, context_id, title, description,
             message_count, last_message, created_at, is_active)
            VALUES (%s, %s, NOW(), %s, %s, %s, %s, %s, %s, %s)
        """
        insert(insert_sql, (
            conv['id'],
            conv['system_user_id'],
            conv['context_id'],
            conv['title'],
            conv.get('description'),
            conv['message_count'],
            conv.get('last_message'),
            conv['created_at'],
            conv.get('is_active', True)
        ))
        logger.info(f"✅ 更新对话成功(StarRocks): context={context_id}")
        return True

    @staticmethod
    def _update_in_place(
        context_id: str,
        title: Optional[str],
        last_message: Optional[str],
        increment_message_count: bool
    ) -> bool:
        """MySQL path: apply the changes with one dynamically built UPDATE."""
        set_clauses: List[str] = []
        params: List[Any] = []
        if title is not None:
            set_clauses.append("title = %s")
            params.append(title)
        if last_message is not None:
            set_clauses.append("last_message = %s")
            params.append(ConversationService._make_preview(last_message))
        if increment_message_count:
            set_clauses.append("message_count = message_count + 1")
        # updated_at is always bumped, so the SET list is never empty.
        set_clauses.append("updated_at = NOW()")
        params.append(context_id)

        # Only fixed column fragments are interpolated; all values stay
        # parameterized.
        sql = f"""
            UPDATE conversation_list
            SET {', '.join(set_clauses)}
            WHERE context_id = %s
        """
        affected = update(sql, tuple(params))
        if affected > 0:
            logger.info(f"✅ 更新对话成功: context={context_id}")
            return True
        logger.warning(f"⚠️ 对话不存在: context={context_id}")
        return False

    @staticmethod
    async def get_user_conversations(
        system_user_id: int,
        limit: int = 50,
        only_active: bool = True
    ) -> List[Dict[str, Any]]:
        """List a user's conversations, most recently updated first.

        Args:
            system_user_id: ID of the owning system user.
            limit: Maximum number of rows to return.
            only_active: If ``True``, return only rows with ``is_active = TRUE``.

        Returns:
            The rows returned by ``query``. Database errors are logged and
            yield an empty list.
        """
        try:
            # Fixed SQL fragment chosen by a boolean, never user input.
            active_filter = " AND is_active = TRUE" if only_active else ""
            sql = f"""
                SELECT {ConversationService._SELECT_COLUMNS}
                FROM conversation_list
                WHERE system_user_id = %s{active_filter}
                ORDER BY updated_at DESC
                LIMIT %s
            """
            results = query(sql, (system_user_id, limit))
            return results
        except Exception as e:
            logger.error(f"❌ 获取对话列表失败: {e}", exc_info=True)
            return []

    @staticmethod
    async def delete_conversation(
        context_id: str,
        system_user_id: int,
        soft_delete: bool = True
    ) -> bool:
        """Delete a conversation owned by *system_user_id*.

        On StarRocks, or when ``soft_delete`` is ``False``, the row is
        physically deleted. On MySQL with ``soft_delete=True``, the row is kept
        and marked ``is_active = FALSE``.

        Args:
            context_id: Session context ID of the conversation.
            system_user_id: Owner ID, used as a permission check.
            soft_delete: Whether to only mark the row inactive (ignored on
                StarRocks).

        Returns:
            ``True`` on success. ``False`` if the conversation does not exist,
            belongs to another user, or a database error occurred (logged).
        """
        try:
            # StarRocks is treated as DELETE-only for this table (no UPDATE),
            # so soft_delete cannot be honoured there.
            if get_db_type() == 'starrocks' or not soft_delete:
                sql = """
                    DELETE FROM conversation_list
                    WHERE context_id = %s AND system_user_id = %s
                """
                affected = update(sql, (context_id, system_user_id))
                if affected > 0:
                    logger.info(f"✅ 删除对话成功: context={context_id}")
                    return True
                logger.warning(f"⚠️ 对话不存在或无权限: context={context_id}")
                return False

            # MySQL soft delete. Check existence and ownership first, so the
            # result does not depend on the driver's affected-rows semantics.
            conv = await ConversationService.get_conversation_by_context(context_id)
            if not conv or conv['system_user_id'] != system_user_id:
                logger.warning(f"⚠️ 对话不存在或无权限: context={context_id}")
                return False

            # A single UPDATE is atomic, unlike the previous DELETE + INSERT,
            # which could lose the row if the INSERT failed.
            soft_delete_sql = """
                UPDATE conversation_list
                SET is_active = FALSE, updated_at = NOW()
                WHERE context_id = %s AND system_user_id = %s
            """
            update(soft_delete_sql, (context_id, system_user_id))
            logger.info(f"✅ 软删除对话成功: context={context_id}")
            return True
        except Exception as e:
            logger.error(f"❌ 删除对话失败: {e}", exc_info=True)
            return False

    @staticmethod
    async def ensure_conversation_exists(
        system_user_id: int,
        context_id: str,
        title: str = "新对话"
    ) -> Dict[str, Any]:
        """Return the conversation for *context_id*, creating it if missing.

        Args:
            system_user_id: Owner ID used when a new conversation is created.
            context_id: Session context ID.
            title: Title used when a new conversation is created.

        Returns:
            The existing row, or the dict returned by ``create_conversation``.

        Raises:
            Exception: Propagated from ``create_conversation`` if creation fails.

        NOTE(review): check-then-create is not atomic. Concurrent callers (or a
        lookup that failed and returned ``None``) can create duplicate rows for
        the same context_id unless the table enforces uniqueness.
        """
        conv = await ConversationService.get_conversation_by_context(context_id)
        if conv:
            return conv
        return await ConversationService.create_conversation(
            system_user_id=system_user_id,
            context_id=context_id,
            title=title
        )
# Shared module-level instance. Every method is static, so callers may use
# either this instance or the class itself.
conversation_service: ConversationService = ConversationService()