"""
对话列表管理服务

负责管理用户的对话列表(会话列表)
"""
import logging
import random
import threading
import time
from datetime import datetime
from typing import Optional, Dict, Any, List

from database import query, insert, update, get_db_type
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class ConversationService:
    """Manage the per-user conversation list stored in ``conversation_list``.

    Every operation is a static method; the class acts as a namespace. Two
    backends are distinguished via ``get_db_type()``:

    * ``'starrocks'`` - rows are modified by DELETE followed by INSERT, since
      UPDATE is treated as unsupported there.
    * anything else (treated as MySQL) - rows are modified with UPDATE.
    """

    # Maximum number of characters of the latest message kept as a preview.
    LAST_MESSAGE_PREVIEW_LEN = 200

    # Process-wide state used by generate_id() so that IDs handed out by this
    # process are strictly increasing and therefore never repeat.
    _id_lock = threading.Lock()
    _last_generated_id = 0

    @staticmethod
    def generate_id() -> int:
        """Return a numeric ID that is unique within the current process.

        The value keeps the original layout - a microsecond Unix timestamp
        plus a random offset in ``[1000, 9999]`` - but is bumped when needed
        so it is always strictly greater than the previously returned ID.
        The bare timestamp+random scheme could return the same value twice
        for calls made within a few milliseconds of each other.

        Note: IDs generated by *different* processes can still collide.
        """
        candidate = int(time.time() * 1_000_000) + random.randint(1000, 9999)
        with ConversationService._id_lock:
            new_id = max(candidate, ConversationService._last_generated_id + 1)
            ConversationService._last_generated_id = new_id
        return new_id

    @staticmethod
    def _make_preview(message: str) -> str:
        """Truncate *message* to at most ``LAST_MESSAGE_PREVIEW_LEN`` characters."""
        return message[:ConversationService.LAST_MESSAGE_PREVIEW_LEN]

    @staticmethod
    async def create_conversation(
        system_user_id: int,
        context_id: str,
        title: str = "新对话",
        description: Optional[str] = None
    ) -> Dict[str, Any]:
        """Insert a new, active conversation row and return its data.

        Args:
            system_user_id: ID of the owning system user.
            context_id: Session context ID the conversation is keyed by.
            title: Conversation title.
            description: Optional conversation description.

        Returns:
            A dict describing the created conversation. ``created_at`` and
            ``updated_at`` are ISO-8601 strings from the application clock;
            they are not read back from the database (whose ``NOW()`` may
            differ).

        Raises:
            Exception: If the insert fails. The original error is chained as
                ``__cause__``.
        """
        try:
            conv_id = ConversationService.generate_id()

            sql = """
                INSERT INTO conversation_list
                (id, context_id, system_user_id, title, description,
                 message_count, created_at, updated_at, is_active)
                VALUES (%s, %s, %s, %s, %s, 0, NOW(), NOW(), TRUE)
            """
            params = (conv_id, context_id, system_user_id, title, description)

            insert(sql, params)
            logger.info(f"✅ 创建对话成功: user={system_user_id}, context={context_id}, title={title}")

            # One timestamp for both fields, mirroring the NOW(), NOW() pair
            # in the INSERT, so a fresh conversation reports equal values.
            now = datetime.now().isoformat()
            return {
                "id": conv_id,
                "context_id": context_id,
                "system_user_id": system_user_id,
                "title": title,
                "description": description,
                "message_count": 0,
                "created_at": now,
                "updated_at": now,
                "is_active": True
            }

        except Exception as e:
            logger.error(f"❌ 创建对话失败: {e}", exc_info=True)
            # Chain the original error so its traceback is preserved.
            raise Exception(f"创建对话失败: {str(e)}") from e

    @staticmethod
    async def get_conversation_by_context(
        context_id: str
    ) -> Optional[Dict[str, Any]]:
        """Fetch the conversation row for *context_id*.

        Args:
            context_id: Session context ID.

        Returns:
            The first matching row as returned by ``query`` (used as a dict
            elsewhere in this class), or ``None`` when no row exists *or* the
            query raised. Errors are logged, not propagated.
        """
        try:
            sql = """
                SELECT id, context_id, system_user_id, title, description,
                       message_count, last_message, created_at, updated_at, is_active
                FROM conversation_list
                WHERE context_id = %s
                LIMIT 1
            """
            results = query(sql, (context_id,))
            return results[0] if results else None

        except Exception as e:
            logger.error(f"❌ 获取对话信息失败: {e}", exc_info=True)
            return None

    @staticmethod
    async def update_conversation(
        context_id: str,
        title: Optional[str] = None,
        last_message: Optional[str] = None,
        increment_message_count: bool = False
    ) -> bool:
        """Update a conversation's title, last-message preview and/or count.

        ``updated_at`` is always refreshed. ``last_message`` is stored as a
        preview truncated to ``LAST_MESSAGE_PREVIEW_LEN`` characters.

        Args:
            context_id: Session context ID of the conversation to update.
            title: New title, or ``None`` to keep the current one.
            last_message: Latest message, or ``None`` to keep the preview.
            increment_message_count: Whether to add 1 to ``message_count``.

        Returns:
            ``True`` if the row was updated; ``False`` if it does not exist
            or an error occurred (errors are logged, not raised).
        """
        try:
            preview = (
                ConversationService._make_preview(last_message)
                if last_message is not None
                else None
            )

            if get_db_type() == 'starrocks':
                return await ConversationService._update_via_rewrite(
                    context_id, title, preview, increment_message_count
                )
            return ConversationService._update_via_sql_update(
                context_id, title, preview, increment_message_count
            )

        except Exception as e:
            logger.error(f"❌ 更新对话失败: {e}", exc_info=True)
            return False

    @staticmethod
    async def _update_via_rewrite(
        context_id: str,
        title: Optional[str],
        preview: Optional[str],
        increment_message_count: bool
    ) -> bool:
        """StarRocks path: apply the changes by deleting and re-inserting the row."""
        conv = await ConversationService.get_conversation_by_context(context_id)
        if not conv:
            logger.warning(f"⚠️ 对话不存在: context={context_id}")
            return False

        if title is not None:
            conv['title'] = title
        if preview is not None:
            conv['last_message'] = preview
        if increment_message_count:
            # Treat a NULL/missing count as 0 instead of failing on None + 1.
            conv['message_count'] = (conv.get('message_count') or 0) + 1

        # NOTE: DELETE + INSERT is not atomic. If the INSERT below fails, the
        # row has already been removed.
        delete_sql = "DELETE FROM conversation_list WHERE context_id = %s"
        update(delete_sql, (context_id,))

        insert_sql = """
            INSERT INTO conversation_list
            (id, system_user_id, updated_at, context_id, title, description,
             message_count, last_message, created_at, is_active)
            VALUES (%s, %s, NOW(), %s, %s, %s, %s, %s, %s, %s)
        """
        insert(insert_sql, (
            conv['id'],
            conv['system_user_id'],
            conv['context_id'],
            conv['title'],
            conv.get('description'),
            conv['message_count'],
            conv.get('last_message'),
            conv['created_at'],
            conv.get('is_active', True)
        ))

        logger.info(f"✅ 更新对话成功(StarRocks): context={context_id}")
        return True

    @staticmethod
    def _update_via_sql_update(
        context_id: str,
        title: Optional[str],
        preview: Optional[str],
        increment_message_count: bool
    ) -> bool:
        """Non-StarRocks path: apply the changes with a single UPDATE statement."""
        set_clauses: List[str] = []
        params: List[Any] = []

        if title is not None:
            set_clauses.append("title = %s")
            params.append(title)
        if preview is not None:
            set_clauses.append("last_message = %s")
            params.append(preview)
        if increment_message_count:
            # Increment inside SQL rather than read-modify-write in Python.
            set_clauses.append("message_count = message_count + 1")
        # updated_at is always refreshed, so the SET list is never empty.
        set_clauses.append("updated_at = NOW()")

        params.append(context_id)

        # Only fixed column fragments are interpolated; values are bound.
        sql = f"""
            UPDATE conversation_list
            SET {', '.join(set_clauses)}
            WHERE context_id = %s
        """
        affected = update(sql, tuple(params))

        if affected > 0:
            logger.info(f"✅ 更新对话成功: context={context_id}")
            return True
        # NOTE(review): depending on driver flags (CLIENT_FOUND_ROWS), MySQL
        # may report 0 affected rows when no value actually changed, which
        # would be reported here as "not found" - confirm against `update`.
        logger.warning(f"⚠️ 对话不存在: context={context_id}")
        return False

    @staticmethod
    async def get_user_conversations(
        system_user_id: int,
        limit: int = 50,
        only_active: bool = True
    ) -> List[Dict[str, Any]]:
        """Return a user's conversations, most recently updated first.

        Args:
            system_user_id: ID of the owning system user.
            limit: Maximum number of rows to return.
            only_active: If true, only rows with ``is_active = TRUE``.

        Returns:
            The rows as a list; an empty list when there are none or the
            query failed (errors are logged, not raised).
        """
        try:
            # The optional filter is a constant SQL fragment; user-supplied
            # values are still passed as bound parameters.
            active_filter = " AND is_active = TRUE" if only_active else ""
            sql = f"""
                SELECT id, context_id, system_user_id, title, description,
                       message_count, last_message, created_at, updated_at, is_active
                FROM conversation_list
                WHERE system_user_id = %s{active_filter}
                ORDER BY updated_at DESC
                LIMIT %s
            """
            results = query(sql, (system_user_id, limit))
            # Normalize a None/empty driver result to an empty list.
            return list(results) if results else []

        except Exception as e:
            logger.error(f"❌ 获取对话列表失败: {e}", exc_info=True)
            return []

    @staticmethod
    async def delete_conversation(
        context_id: str,
        system_user_id: int,
        soft_delete: bool = True
    ) -> bool:
        """Delete a conversation owned by *system_user_id*.

        On StarRocks the row is always removed physically, even when
        ``soft_delete`` is true. On other backends a soft delete sets
        ``is_active = FALSE`` and refreshes ``updated_at``; a hard delete
        removes the row.

        Args:
            context_id: Session context ID of the conversation.
            system_user_id: Owner ID; rows of other users are not touched.
            soft_delete: Whether to only mark the row inactive.

        Returns:
            ``True`` on success; ``False`` if the row does not exist, belongs
            to another user, or an error occurred (errors are logged).
        """
        try:
            db_type = get_db_type()

            # StarRocks only supports hard deletes here (no UPDATE).
            if db_type == 'starrocks' or not soft_delete:
                sql = """
                    DELETE FROM conversation_list
                    WHERE context_id = %s AND system_user_id = %s
                """
                affected = update(sql, (context_id, system_user_id))

                if affected > 0:
                    logger.info(f"✅ 删除对话成功: context={context_id}")
                    return True
                logger.warning(f"⚠️ 对话不存在或无权限: context={context_id}")
                return False

            # Soft delete: verify existence and ownership first.
            conv = await ConversationService.get_conversation_by_context(context_id)
            if not conv or conv['system_user_id'] != system_user_id:
                logger.warning(f"⚠️ 对话不存在或无权限: context={context_id}")
                return False

            # A single UPDATE (as update_conversation uses for this backend)
            # instead of DELETE + re-INSERT, so the row cannot be lost if the
            # second statement fails.
            sql = """
                UPDATE conversation_list
                SET is_active = FALSE, updated_at = NOW()
                WHERE context_id = %s AND system_user_id = %s
            """
            update(sql, (context_id, system_user_id))

            logger.info(f"✅ 软删除对话成功: context={context_id}")
            return True

        except Exception as e:
            logger.error(f"❌ 删除对话失败: {e}", exc_info=True)
            return False

    @staticmethod
    async def ensure_conversation_exists(
        system_user_id: int,
        context_id: str,
        title: str = "新对话"
    ) -> Dict[str, Any]:
        """Return the conversation for *context_id*, creating it if missing.

        Args:
            system_user_id: Owner used when a new conversation is created.
            context_id: Session context ID.
            title: Title used when a new conversation is created.

        Returns:
            The existing row, or the dict returned by ``create_conversation``.

        Note:
            The lookup and the create are separate steps without locking, so
            concurrent callers may both attempt a create. Because the lookup
            returns ``None`` on query errors, a failing lookup also leads to
            a create attempt.
        """
        conv = await ConversationService.get_conversation_by_context(context_id)
        if conv:
            return conv

        return await ConversationService.create_conversation(
            system_user_id=system_user_id,
            context_id=context_id,
            title=title
        )
|
||
|
||
|
||
# Shared module-level instance. All ConversationService methods are static,
# so this is a convenience handle; callers may equally use the class itself.
conversation_service: ConversationService = ConversationService()
|
||
|