yj_room_agent/yj_room_agent/LLM/modelclient.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: yelinfeng
@contact: 1312954526@qq.com
@version: 1.0
@created: 2025/5/22 18:32
"""
import requests
import json
import logging
import os
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)


class ModelClient:
    """
    A generic model client for interacting with various large language models.
    Configured via environment variables or constructor arguments.
    """

    def __init__(
            self,
            base_url: Optional[str] = None,
            api_key: Optional[str] = None,
            headers: Optional[Dict[str, str]] = None,
            model_name: Optional[str] = None
    ):
        """
        Initialize the model client from environment variables or the provided parameters.

        Environment Variables:
            MODEL_BASE_URL: Base URL of the model API (e.g., http://10.254.23.128:31205/v1)
            MODEL_API_KEY: API key (optional)
            MODEL_NAME: Model name (e.g., deepseek-chat)
            MODEL_HEADERS: JSON string of custom headers (e.g., '{"User-Agent": "Apifox/1.0.0"}')

        Args:
            base_url: Base URL (overrides env var)
            api_key: API key (overrides env var)
            headers: Custom headers (overrides env var)
            model_name: Model name (overrides env var)
        """
        self.base_url = base_url or os.getenv('MODEL_BASE_URL', 'http://10.254.23.128:31205/v1').rstrip('/')
        self.api_key = api_key or os.getenv('MODEL_API_KEY', 'NotRequiredSinceWeAreLocal')
        self.model_name = model_name or os.getenv('MODEL_NAME', 'deepseek-chat')

        # Load MODEL_HEADERS from the environment
        headers_env = os.getenv('MODEL_HEADERS', '{}')
        try:
            self.headers = json.loads(headers_env)
        except json.JSONDecodeError:
            logger.error(f"Invalid MODEL_HEADERS format: {headers_env}. Using empty headers.")
            self.headers = {}

        # Apply default headers
        self.headers.setdefault('Content-Type', 'application/json')

        # Headers passed to the constructor override the env-var headers, as documented above
        if headers:
            self.headers.update(headers)

        # Attach the API key unless the caller already supplied an Authorization header
        if self.api_key and self.api_key != 'NotRequiredSinceWeAreLocal' and 'Authorization' not in self.headers:
            self.headers['Authorization'] = f'Bearer {self.api_key}'

        # Ensure a User-Agent header is always present
        self.headers.setdefault('User-Agent', '')

        # Validate configuration
        if not self.base_url:
            raise ValueError("MODEL_BASE_URL is required")
        if not self.model_name:
            raise ValueError("MODEL_NAME is required")

        logger.info(
            f"Initialized ModelClient: base_url={self.base_url}, model_name={self.model_name}, headers={self.headers}")

    def create(self, messages: List[Dict[str, str]], stream: bool = True):
        """
        Send a chat completion request to the model and stream the reply.

        Args:
            messages: List of messages [{'role': 'system', 'content': '...'}, ...]
            stream: Whether to request a streamed (SSE) response

        Yields:
            str: Content fragments extracted from the streamed response

        Raises:
            ConnectionError: If the request fails
        """
        # base_url already includes the API version path (e.g. ".../v1")
        url = f"{self.base_url}/chat/completions"
        payload = {
            'model': self.model_name,
            'messages': messages,
            'stream': stream
        }
        try:
            with requests.post(
                    url,
                    headers=self.headers,
                    json=payload,
                    stream=True,
                    timeout=30
            ) as response:
                response.raise_for_status()
                # Parse the SSE stream: each "data: " line carries a JSON chunk with a delta
                for chunk in response.iter_lines():
                    if chunk:
                        decoded = chunk.decode('utf-8')
                        if decoded.startswith("data: "):
                            json_str = decoded[6:]
                            try:
                                data = json.loads(json_str)
                                if "choices" in data:
                                    content = data["choices"][0].get("delta", {}).get("content", "")
                                    if content:
                                        yield content
                            except json.JSONDecodeError:
                                # Skip non-JSON lines such as the "[DONE]" sentinel
                                continue
        except requests.exceptions.RequestException as e:
            raise ConnectionError(f"API request failed: {str(e)}") from e

    def test_connection(self) -> bool:
        """
        Test connectivity to the model API.

        Returns:
            bool: True if connection is successful
        """
        try:
            response = requests.get(self.base_url, headers=self.headers, timeout=10)
            response.raise_for_status()
            return True
        except requests.RequestException as e:
            logger.error(f"Model connection test failed: {str(e)}")
            return False
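

# Illustrative usage sketch (not part of the original module): the base URL and
# model name below are the module defaults and assume an OpenAI-compatible
# endpoint; adjust them, or rely on the MODEL_BASE_URL / MODEL_NAME env vars.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    client = ModelClient(
        base_url='http://10.254.23.128:31205/v1',  # placeholder endpoint
        model_name='deepseek-chat',                # placeholder model name
    )

    if client.test_connection():
        # Stream the reply fragment-by-fragment and print it as it arrives
        for fragment in client.create([
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': 'Hello!'}
        ]):
            print(fragment, end='', flush=True)
        print()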