142 lines
4.8 KiB
Python
142 lines
4.8 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
@author: yelinfeng
|
|
@contact: 1312954526@qq.com
|
|
@version: 1.0
|
|
@created: 2025/5/22 18:32
|
|
"""
|
|
|
|
import requests
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Dict, List, Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ModelClient:
    """
    A generic client for OpenAI-compatible chat-completion endpoints.

    Configuration comes from constructor arguments, falling back to the
    MODEL_BASE_URL, MODEL_API_KEY, MODEL_NAME and MODEL_HEADERS environment
    variables.
    """

    # Placeholder key meaning "no authentication needed"; never sent as a token.
    _NO_KEY_SENTINEL = 'NotRequiredSinceWeAreLocal'

    def __init__(
            self,
            base_url: Optional[str] = None,
            api_key: Optional[str] = None,
            headers: Optional[Dict[str, str]] = None,
            model_name: Optional[str] = None
    ):
        """
        Initialize the model client using environment variables or provided parameters.

        Environment Variables:
            MODEL_BASE_URL: Base URL of the model API (e.g., http://10.254.23.128:31205/v1)
            MODEL_API_KEY: API key (optional)
            MODEL_NAME: Model name (e.g., deepseek-chat)
            MODEL_HEADERS: JSON object of custom headers (e.g., '{"User-Agent": "Apifox/1.0.0"}')

        Args:
            base_url: Base URL (overrides env var)
            api_key: API key (overrides env var)
            headers: Custom headers (merged over the env var headers)
            model_name: Model name (overrides env var)

        Raises:
            ValueError: If the resolved base URL or model name is empty.
        """
        env_base_url = os.getenv('MODEL_BASE_URL', 'http://10.254.23.128:31205/v1')
        # Strip the trailing slash from whichever source won, not only the env value.
        self.base_url = (base_url or env_base_url).rstrip('/')
        self.api_key = api_key or os.getenv('MODEL_API_KEY', self._NO_KEY_SENTINEL)
        self.model_name = model_name or os.getenv('MODEL_NAME', 'deepseek-chat')

        # Load MODEL_HEADERS; anything that is not a JSON object is rejected.
        headers_env = os.getenv('MODEL_HEADERS', '{}')
        try:
            parsed_headers = json.loads(headers_env)
            if not isinstance(parsed_headers, dict):
                raise ValueError("MODEL_HEADERS must be a JSON object")
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            logger.error("Invalid MODEL_HEADERS format: %s. Using empty headers.", headers_env)
            parsed_headers = {}
        self.headers = parsed_headers

        # Default headers, then explicit overrides supplied by the caller.
        self.headers.setdefault('Content-Type', 'application/json')
        if headers:
            self.headers.update(headers)

        # Only derive an Authorization header from a real key, and never
        # replace one that was configured explicitly.
        if self.api_key and self.api_key != self._NO_KEY_SENTINEL and 'Authorization' not in self.headers:
            self.headers['Authorization'] = f'Bearer {self.api_key}'

        # Default to an empty User-Agent unless one was configured.
        self.headers.setdefault('User-Agent', '')

        # Validate the configuration.
        if not self.base_url:
            raise ValueError("MODEL_BASE_URL is required")
        if not self.model_name:
            raise ValueError("MODEL_NAME is required")

        # Keep credentials out of the logs.
        loggable_headers = {
            key: ('***' if str(key).lower() == 'authorization' else value)
            for key, value in self.headers.items()
        }
        logger.info(
            "Initialized ModelClient: base_url=%s, model_name=%s, headers=%s",
            self.base_url, self.model_name, loggable_headers)

    def _chat_completions_url(self) -> str:
        """Return the chat-completions endpoint without duplicating '/v1'."""
        base = self.base_url.rstrip('/')
        # The documented base URL already ends in '/v1'; only add it when missing.
        if not base.endswith('/v1'):
            base = f"{base}/v1"
        return f"{base}/chat/completions"

    @staticmethod
    def _first_choice(data) -> dict:
        """Return the first entry of data['choices'], or {} if absent or empty."""
        if not isinstance(data, dict):
            return {}
        choices = data.get("choices")
        if isinstance(choices, list) and choices and isinstance(choices[0], dict):
            return choices[0]
        return {}

    @classmethod
    def _iter_stream_content(cls, lines):
        """Yield the non-empty delta content of each 'data:' line in an event stream."""
        for raw_line in lines:
            if not raw_line:
                continue  # keep-alive blank lines
            text = raw_line.decode('utf-8') if isinstance(raw_line, bytes) else raw_line
            if not text.startswith("data:"):
                continue
            try:
                data = json.loads(text[len("data:"):])
            except json.JSONDecodeError:
                continue  # non-JSON payloads such as the "[DONE]" sentinel
            delta = cls._first_choice(data).get("delta") or {}
            content = delta.get("content")
            if content:
                yield content

    @classmethod
    def _content_from_body(cls, body) -> str:
        """Return the message content of a non-streaming response body, or ''."""
        try:
            data = json.loads(body)
        except ValueError:
            return ""
        message = cls._first_choice(data).get("message") or {}
        return message.get("content") or ""

    def create(self, messages: List[Dict[str, str]], stream=True, **kwargs):
        """
        Send a chat completion request to the model.

        This is a generator: the request is sent when iteration starts.

        Args:
            messages: List of messages [{'role': 'system', 'content': '...'}, ...]
            stream: Whether to ask the server for a streamed response
            **kwargs: Additional request-body parameters (e.g., temperature,
                max_tokens). They cannot override model, messages or stream.

        Yields:
            str: Content fragments when streaming; otherwise the complete
                message content as a single item (nothing if it is empty).

        Raises:
            ConnectionError: If the HTTP request fails
        """
        url = self._chat_completions_url()
        payload = {
            **kwargs,
            'model': self.model_name,
            'messages': messages,
            'stream': stream
        }

        try:
            with requests.post(
                    url,
                    headers=self.headers,
                    json=payload,
                    stream=True,
                    timeout=30
            ) as response:
                response.raise_for_status()

                if stream:
                    yield from self._iter_stream_content(response.iter_lines())
                else:
                    # A non-streaming reply is one JSON document, not an event stream.
                    content = self._content_from_body(response.content)
                    if content:
                        yield content

        except requests.exceptions.RequestException as e:
            raise ConnectionError(f"API请求失败: {str(e)}") from e

    def test_connection(self) -> bool:
        """
        Test connectivity to the model API.

        Issues a GET against the base URL with the configured headers.

        Returns:
            bool: True if the request succeeds with a non-error status,
                False if it raises a requests exception
        """
        try:
            response = requests.get(self.base_url, headers=self.headers, timeout=10)
            response.raise_for_status()
            return True
        except requests.RequestException as e:
            logger.error("Model connection test failed: %s", e)
            return False