276 lines
8.4 KiB
Python
276 lines
8.4 KiB
Python
"""
用户认证 API (user authentication API)

提供用户注册、登录、登出等功能 (user registration, login, logout, etc.)
"""
|
||
|
||
import logging
|
||
import hashlib
|
||
import time
|
||
import random
|
||
from typing import Optional
|
||
from fastapi import APIRouter, HTTPException, status
|
||
from pydantic import BaseModel, Field
|
||
|
||
from database import query, insert, update
|
||
|
||
# Router that collects the authentication endpoints defined in this module.
router = APIRouter()

# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
class UserRegisterRequest(BaseModel):
    """Request body for user registration.

    Only ``username`` and ``password`` are required; the length limits are
    enforced by pydantic through the ``Field`` constraints below.
    """

    # Login name, 3-50 characters.
    username: str = Field(default=..., min_length=3, max_length=50, description="用户名")
    # Plain-text password as submitted; at least 6 characters, no upper bound.
    password: str = Field(default=..., min_length=6, description="密码")
    # Optional contact details; their format is not validated here.
    email: Optional[str] = Field(default=None, description="邮箱")
    phone: Optional[str] = Field(default=None, description="手机号")
    # Display name; registration falls back to the username when omitted.
    nickname: Optional[str] = Field(default=None, description="昵称")
|
||
|
||
|
||
class UserLoginRequest(BaseModel):
    """Request body for username/password login."""

    # Both fields are required; no length constraints are applied at login.
    username: str = Field(default=..., description="用户名")
    password: str = Field(default=..., description="密码")
|
||
|
||
|
||
class UserResponse(BaseModel):
    """Public user profile returned to clients.

    Deliberately has no password field, so the stored hash is never
    serialized into a response.
    """

    # Required identity fields.
    id: int
    username: str
    # Optional profile fields; absent/NULL values serialize as null.
    nickname: Optional[str] = Field(default=None)
    email: Optional[str] = Field(default=None)
    phone: Optional[str] = Field(default=None)
    avatar: Optional[str] = Field(default=None)
    # Creation timestamp, already rendered as a string by the endpoint.
    created_at: str
|
||
|
||
|
||
class LoginResponse(BaseModel):
    """Response body shared by the registration and login endpoints."""

    success: bool
    message: str
    # Session token; populated on success.
    token: Optional[str] = Field(default=None)
    # Profile of the authenticated / newly registered user.
    user: Optional[UserResponse] = Field(default=None)
    # Whether a Xiaomi account is linked (set by login; registration leaves
    # the default False).
    xiaomi_bound: bool = Field(default=False)
|
||
|
||
|
||
# ==================== 辅助函数 ====================
|
||
|
||
def hash_password(password: str) -> str:
    """Return the hex SHA-256 digest of ``password`` encoded as UTF-8.

    SECURITY NOTE(review): this is a single, unsalted, fast hash, so identical
    passwords produce identical stored values and offline brute force is
    cheap. Migrating to a salted slow KDF (e.g. ``hashlib.scrypt`` or
    ``hashlib.pbkdf2_hmac``) requires changing login verification and
    re-hashing stored passwords at the same time.
    """
    digest = hashlib.sha256()
    digest.update(password.encode("utf-8"))
    return digest.hexdigest()
|
||
|
||
|
||
def generate_token(user_id: int) -> str:
    """Build an opaque token of the form ``"<user_id>_<32 hex chars>"``.

    The random part is a UUID4 in hex form. The token is not signed and is
    not stored by this function (production should use JWT or similar).
    """
    import uuid

    random_part = uuid.uuid4().hex
    return "{}_{}".format(user_id, random_part)
|
||
|
||
|
||
def generate_user_id() -> int:
    """Return a numeric user ID: epoch milliseconds plus a random 1000-9999 offset.

    NOTE(review): two IDs generated within ~9 seconds of each other can
    collide (t1 + r1 == t2 + r2), and nothing here checks for an existing
    row. Confirm the ``users`` table rejects duplicate IDs, or switch to a
    collision-free scheme.
    """
    millis = int(time.time() * 1000)
    jitter = random.randint(1000, 9999)
    return millis + jitter
|
||
|
||
|
||
# ==================== API 端点 ====================
|
||
|
||
@router.post("/register", response_model=LoginResponse)
async def register(data: UserRegisterRequest):
    """Register a new user and log them in immediately.

    Rejects a username or e-mail that already exists, stores the user with a
    hashed password (``hash_password``) and ``status = 1``, and returns a
    ``LoginResponse`` carrying a fresh token and the new profile.

    Raises:
        HTTPException: 400 if the username or e-mail is already registered;
            500 for any other failure. The 500 response carries a generic
            message. The underlying error is logged with its traceback but is
            not echoed to the client, because it can expose SQL/driver
            internals.

    NOTE(review): the duplicate checks and the INSERT are separate
    statements, so concurrent requests for the same username/e-mail can both
    pass the checks. Unless the ``users`` table enforces uniqueness, duplicates
    are possible; confirm against the schema.

    NOTE(review): ``query``/``insert`` are called without ``await``. If they
    perform blocking I/O, they stall the event loop while running.
    """
    try:
        existing_user = query(
            "SELECT id FROM users WHERE username = %s", (data.username,)
        )
        if existing_user:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已存在",
            )

        # E-mail is optional; only enforce uniqueness when one was supplied.
        if data.email:
            existing_email = query(
                "SELECT id FROM users WHERE email = %s", (data.email,)
            )
            if existing_email:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="邮箱已被注册",
                )

        user_id = generate_user_id()
        hashed_password = hash_password(data.password)
        # Computed once so the stored row and the response always agree.
        nickname = data.nickname or data.username

        insert_sql = """
            INSERT INTO users
            (id, username, password, email, phone, nickname, status, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, 1, NOW(), NOW())
        """
        insert(insert_sql, (
            user_id,
            data.username,
            hashed_password,
            data.email,
            data.phone,
            nickname,
        ))

        token = generate_token(user_id)

        user_info = UserResponse(
            id=user_id,
            username=data.username,
            nickname=nickname,
            email=data.email,
            phone=data.phone,
            avatar=None,
            # Server-local wall-clock time. The DB row uses NOW(), which may
            # differ slightly (or in timezone, if the DB server's differs).
            created_at=time.strftime("%Y-%m-%d %H:%M:%S"),
        )

        logger.info(f"用户注册成功: {data.username}")

        return LoginResponse(
            success=True,
            message="注册成功",
            token=token,
            user=user_info,
        )

    except HTTPException:
        # Deliberate 4xx responses pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"用户注册失败: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="注册失败",
        )
|
||
|
||
|
||
@router.post("/login", response_model=LoginResponse)
async def login(data: UserLoginRequest):
    """Authenticate a user by username and password.

    On success, returns a ``LoginResponse`` with a new token, the user's
    profile, and whether a Xiaomi account is linked.

    Raises:
        HTTPException:
            - 401 for an unknown username or a wrong password (same message
              for both).
            - 403 if the password is correct but the account's ``status``
              is not 1.
            - 500 for any other failure, with a generic message. The real
              error is logged, not returned.
    """
    import hmac  # function-scope import, following generate_token's style

    try:
        sql = """
            SELECT id, username, password, nickname, email, phone, avatar,
                   status, created_at
            FROM users
            WHERE username = %s
        """
        users = query(sql, (data.username,))

        if not users:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户名或密码错误",
            )

        user = users[0]

        # Verify the password BEFORE looking at the account status. Checking
        # status first told anyone who merely knew a username that the
        # account was disabled. compare_digest avoids a timing side channel.
        stored_hash = user.get("password")
        password_ok = isinstance(stored_hash, str) and hmac.compare_digest(
            stored_hash.encode("utf-8"),
            hash_password(data.password).encode("utf-8"),
        )
        if not password_ok:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户名或密码错误",
            )

        if user.get("status") != 1:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="账号已被禁用",
            )

        # NOTE: StarRocks DUPLICATE KEY tables do not support UPDATE, so the
        # last-login time is not recorded.
        # TODO: to record login times, switch to a PRIMARY KEY table or a
        # separate login-log table.

        token = generate_token(user["id"])
        xiaomi_bound = _is_xiaomi_bound(user["id"])

        user_info = UserResponse(
            id=user["id"],
            username=user["username"],
            nickname=user.get("nickname"),
            email=user.get("email"),
            phone=user.get("phone"),
            avatar=user.get("avatar"),
            # `or ""` keeps a NULL created_at from being rendered as "None".
            created_at=str(user.get("created_at") or ""),
        )

        logger.info(f"用户登录成功: {data.username}, 小米绑定状态: {xiaomi_bound}")

        return LoginResponse(
            success=True,
            message="登录成功",
            token=token,
            user=user_info,
            xiaomi_bound=xiaomi_bound,
        )

    except HTTPException:
        # Deliberate 4xx responses pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"用户登录失败: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="登录失败",
        )


def _is_xiaomi_bound(user_id) -> bool:
    """Return True if ``xiaomi_account`` has a row for ``user_id``.

    Tries rows with ``is_active = 1`` first, then falls back to any row for
    the user (the most recently updated one).

    NOTE(review): because of the fallback, the result is effectively "any
    row exists", so an inactive binding still counts as bound. Confirm this is
    intended (e.g. ``is_active`` not reliably populated) before relying on it.
    """
    active_sql = "SELECT id, xiaomi_username, is_active FROM xiaomi_account WHERE system_user_id = %s AND is_active = 1"
    active_rows = query(active_sql, (user_id,))
    bound = len(active_rows) > 0
    logger.info(f"🔍 用户 {user_id} 绑定检查 (is_active=1): {active_rows}, xiaomi_bound={bound}")
    if bound:
        return True

    any_sql = "SELECT id, xiaomi_username, is_active FROM xiaomi_account WHERE system_user_id = %s ORDER BY updated_at DESC LIMIT 1"
    any_rows = query(any_sql, (user_id,))
    bound = len(any_rows) > 0
    logger.info(f"🔍 用户 {user_id} 绑定检查 (不带is_active): {any_rows}, xiaomi_bound={bound}")
    return bound
|
||
|
||
|
||
@router.post("/logout")
async def logout():
    """Log the current user out.

    Only returns a success payload. No server-side token or session state is
    touched here, so the client is expected to discard its token.

    NOTE(review): if tokens are validated anywhere else, they remain valid
    after this call.
    """
    return dict(success=True, message="登出成功")
|
||
|
||
|
||
@router.get("/check-username/{username}")
async def check_username(username: str):
    """Report whether ``username`` is still free to register.

    Returns:
        dict: ``{"available": bool, "message": str}``.

    Raises:
        HTTPException: 500 if the lookup fails. The response carries a
            generic message; the underlying error and traceback are logged
            but no longer echoed to the client, because they can expose
            database internals.
    """
    try:
        rows = query("SELECT id FROM users WHERE username = %s", (username,))
        available = len(rows) == 0
    except Exception as e:
        logger.error(f"检查用户名失败: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="检查用户名失败",
        )

    return {
        "available": available,
        "message": "用户名可用" if available else "用户名已存在",
    }
|
||
|