fix: store conversation history in Redis

This commit is contained in:
lei_y601
2025-07-10 16:18:52 +08:00
parent f9f57af405
commit 16d513be6f
5 changed files with 102 additions and 151 deletions

.env
View File

@@ -37,3 +37,5 @@ EDIT_MEETING = '/yonbip/uspace/external/access/edit'
DEFAULT_PAGESIZE = 30
DEFAULT_QUERY_SIZE = 30
REDIS_HOST=127.0.0.1
REDIS_PORT=6379
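For reference, a minimal connectivity sketch for the two settings above, assuming python-decouple and the redis client pinned in requirements.txt (illustrative only, not part of this commit):

# Sketch: read the new settings with python-decouple and ping Redis.
import redis
from decouple import config

r = redis.Redis(
    host=config('REDIS_HOST', default='127.0.0.1'),
    port=config('REDIS_PORT', cast=int, default=6379),
    decode_responses=True,
)
print(r.ping())  # True when the configured Redis instance is reachable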

View File

@@ -3,4 +3,5 @@ requests==2.32.3
python-decouple==3.8
APScheduler==3.11.0
pydantic-ai==0.3.6
Hypercorn==0.17.3
Hypercorn==0.17.3
redis==6.2.0

View File

@@ -10,6 +10,7 @@ import threading, re
from pydantic_ai.settings import ModelSettings
from ..tools import getinfo, params_filter
from ..tools import redis_message_manage
import logging
MODEL_NAME = config('MODEL_NAME', default="")
@@ -396,7 +397,7 @@ def map_meetingname_to_id(params: dict):
'''
def build_prompt(params):
def build_prompt(params,map_meetingname_to_id:list):
"""构建增强提示词"""
qry_room_info_for_mart_str = '''
@@ -471,7 +472,7 @@ def build_prompt(params):
##ROLE##:
You are a professional OA meeting-room booking assistant. Provide your service based on the information below.
The current time is: {time_now}
The mapping from meeting-room names to meeting-room IDs (unrelated to meeting IDs) is: {map_meetingname_to_id(params=params)}
The mapping from meeting-room names to meeting-room IDs (unrelated to meeting IDs) is: {map_meetingname_to_id_list}
##TASK##:
Handle the request in the following steps:
1. Book a meeting room
@@ -706,7 +707,7 @@ def process_query_book_room(**kwargs) -> tuple:
book_promot = f'''
The system has called the API to query the meeting rooms already booked under the current tenant; the result is as follows (if there are multiple booked meetings, number them in the reply):
{result}
Strictly following the requirements of TASK step 6, parse the existing meeting-room booking result for the user and list every booked meeting (the user may use this for further queries or to cancel a booking); do not reason from history messages, focus on parsing the result, omit nothing, include the meeting ID in the output, do not trigger any other operation, and give the user a corresponding natural-language reply
Strictly following the requirements of TASK step 7, parse the existing meeting-room booking result for the user and list every booked meeting (the user may use this for further queries or to cancel a booking); do not reason from history messages, focus on parsing the result, omit nothing, include the meeting ID in the output, do not trigger any other operation, and give the user a corresponding natural-language reply
'''
return False, book_promot
@@ -719,7 +720,7 @@ def process_user_query_book_room(**kwargs) -> tuple:
book_promot = f'''
The system has called the API to query the meeting rooms already booked under the current tenant; the result is as follows (if there are multiple booked meetings, number them in the reply):
{result}
Strictly following the requirements of TASK step 6, parse the existing meeting-room booking result for the user and list every booked meeting (the user may use this for further queries or to cancel a booking); do not reason from history messages, focus on parsing the result, omit nothing, include the meeting ID in the output, do not trigger any other operation, and give the user a corresponding natural-language reply
Strictly following the requirements of TASK step 7, parse the existing meeting-room booking result for the user and list every booked meeting (the user may use this for further queries or to cancel a booking); do not reason from history messages, focus on parsing the result, omit nothing, include the meeting ID in the output, do not trigger any other operation, and give the user a corresponding natural-language reply
'''
return False, book_promot
@@ -818,7 +819,8 @@ from pydantic_ai.providers.openai import OpenAIProvider
provider = OpenAIProvider(api_key=config('MODEL_API_KEY'), base_url=BASE_URL)
model = OpenAIModel(MODEL_NAME, provider=provider)
agent = Agent(model, system_prompt=build_prompt({'tenantId':config('TEMP_TENANT_ID')}))
map_meetingname_to_id_list=map_meetingname_to_id({'tenantId':config('TEMP_TENANT_ID')})
agent = Agent(model)
# seed: keeps the output as consistent as possible for the same question; the larger the value, the stronger the consistency
model_setting=ModelSettings(
temperature=config('MODEL_TEMPERATURE',cast=float,default=0.2),
@@ -826,14 +828,18 @@ model_setting=ModelSettings(
)
@agent.system_prompt(dynamic=True)
def gen_system_prompt()->str:
return build_prompt({},map_meetingname_to_id_list)
def process_chat(covers_id:str,user_id: str, user_input: str, params: dict):
history = []
query_history = dialog_manager.get_history(covers_id)
query_history = redis_message_manage.get_history(covers_id)
history.extend(query_history)
req = agent.run_sync(user_prompt=user_input,model_settings=model_setting, message_history=history if len(history) > 0 else None)
dialog_manager.set_msg(covers_id, req.all_messages())
redis_message_manage.set_history(covers_id, req.all_messages())
content = req.output
logger.info(f"process chat content is : {content}")
if "func_name" in content:
@@ -873,7 +879,7 @@ def process_chat(covers_id:str,user_id: str, user_input: str, params: dict):
content=resp.output
logger.info("final content => {0}".format(content))
new_content = check_and_process_think(content)
dialog_manager.add_message2(covers_id, resp.new_messages()[-1])
redis_message_manage.append_message(covers_id, resp.new_messages()[-1])
return {'response': new_content}
return {'response': content}
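For clarity, the serialization round-trip that process_chat now relies on, sketched with the same pydantic-ai helpers the new redis_message_manage module uses (the function names below are illustrative):

# Sketch of the JSON round-trip between pydantic-ai messages and Redis list entries.
import json
from pydantic_ai.messages import ModelMessagesTypeAdapter
from pydantic_core import to_jsonable_python

def dump_messages(messages) -> list:
    # One JSON string per message, mirroring how set_history/append_message store entries.
    return [json.dumps(m, ensure_ascii=False) for m in to_jsonable_python(messages)]

def load_messages(raw_entries: list):
    # Rebuild typed ModelRequest/ModelResponse objects, as get_history does,
    # so they can be passed back to agent.run_sync(message_history=...).
    return ModelMessagesTypeAdapter.validate_python([json.loads(s) for s in raw_entries])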

View File

@@ -1,142 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: yelinfeng
@contact: 1312954526@qq.com
@version: 1.0
@created: 2025/5/22 18:32
"""
import requests
import json
import logging
import os
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class ModelClient:
"""
A generic model client to interact with various large language models.
Configured via environment variables.
"""
def __init__(
self,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
model_name: Optional[str] = None
):
"""
Initialize the model client using environment variables or provided parameters.
Environment Variables:
MODEL_BASE_URL: Base URL of the model API (e.g., http://10.254.23.128:31205/v1)
MODEL_API_KEY: API key (optional)
MODEL_NAME: Model name (e.g., deepseek-chat)
MODEL_HEADERS: JSON string of custom headers (e.g., '{"User-Agent": "Apifox/1.0.0"}')
Args:
base_url: Base URL (overrides env var)
api_key: API key (overrides env var)
headers: Custom headers (overrides env var)
model_name: Model name (overrides env var)
"""
self.base_url = base_url or os.getenv('MODEL_BASE_URL', 'http://10.254.23.128:31205/v1').rstrip('/')
self.api_key = api_key or os.getenv('MODEL_API_KEY', 'NotRequiredSinceWeAreLocal')
self.model_name = model_name or os.getenv('MODEL_NAME', 'deepseek-chat')
# Load MODEL_HEADERS
headers_env = os.getenv('MODEL_HEADERS', '{}')
try:
self.headers = json.loads(headers_env)
except json.JSONDecodeError:
logger.error(f"Invalid MODEL_HEADERS format: {headers_env}. Using empty headers.")
self.headers = {}
# Add default headers
self.headers.setdefault('Content-Type', 'application/json')
# if headers:
# self.headers.update(headers)
# Handle api_key
if self.api_key and self.api_key != 'NotRequiredSinceWeAreLocal' and 'Authorization' not in self.headers:
self.headers['Authorization'] = f'Bearer {self.api_key}'
# Add a default User-Agent
self.headers.setdefault('User-Agent', '')
# Validate configuration
if not self.base_url:
raise ValueError("MODEL_BASE_URL is required")
if not self.model_name:
raise ValueError("MODEL_NAME is required")
logger.info(
f"Initialized ModelClient: base_url={self.base_url}, model_name={self.model_name}, headers={self.headers}")
def create(self, messages: List[Dict[str, str]], stream=True):
"""
Send a chat completion request to the model.
Args:
messages: List of messages [{'role': 'system', 'content': '...'}, ...]
stream: Whether to stream the response (default: True)
Yields:
str: Content chunks decoded from the streamed response
Raises:
ConnectionError: If the request fails
"""
url = f"{self.base_url}/v1/chat/completions"
payload = {
'model': self.model_name,
'messages': messages,
'stream': stream
}
try:
with requests.post(
url,
headers=self.headers,
json=payload,
stream=True,
timeout=30
) as response:
response.raise_for_status()
for chunk in response.iter_lines():
if chunk:
decoded = chunk.decode('utf-8')
if decoded.startswith("data: "):
json_str = decoded[6:]
try:
data = json.loads(json_str)
if "choices" in data:
content = data["choices"][0].get("delta", {}).get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except requests.exceptions.RequestException as e:
raise ConnectionError(f"API request failed: {str(e)}") from e
def test_connection(self) -> bool:
"""
Test connectivity to the model API.
Returns:
bool: True if connection is successful
"""
try:
response = requests.get(self.base_url, headers=self.headers, timeout=10)
response.raise_for_status()
return True
except requests.RequestException as e:
logger.error(f"Model connection test failed: {str(e)}")
return False

View File

@@ -0,0 +1,84 @@
import json, logging
import redis
from decouple import config
from pydantic_ai.messages import ModelMessagesTypeAdapter, ModelRequest, ModelResponse
from pydantic_core import to_jsonable_python
logger = logging.getLogger('django')
pool = redis.ConnectionPool(
host=config('REDIS_HOST', cast=str, default='localhost'),
port=config('REDIS_PORT', cast=int, default=6379),
max_connections=20, # maximum number of connections
decode_responses=True,
encoding='utf-8',
)
default_redis_covers_id_prefix = 'yj_room:agent:'
def print_his(key,r):
es = r.lrange(key, 0, -1)
data=[json.loads(i) for i in es]
logger.info(f"----------{len(data)}-----------")
for i in data:
logger.info(f"data is {i}")
logger.info(f"----------{len(data)}-----------")
'''
Trim history messages: when there are more than 20, keep the first message and the last five by default.
'''
def cut_history(covers_id: str, r):
key = default_redis_covers_id_prefix + covers_id
if r.llen(key) > 20:
logger.info("消息超过设定阈值,开始裁剪消息。。。")
script='''
local msg = redis.call('lindex', KEYS[1],0)
redis.call('ltrim', KEYS[1],-5,-1)
redis.call('lpush', KEYS[1],msg)
'''
r.eval(script,1,key)
'''
Get the list of history messages for a conversation ID.
'''
def get_history(covers_id: str) -> list:
key = default_redis_covers_id_prefix + covers_id
r = redis.Redis(connection_pool=pool)
res = r.lrange(key, 0, -1)
if res and len(res) > 0:
return ModelMessagesTypeAdapter.validate_python([json.loads(i) for i in res])
else:
return []
'''
Set (overwrite) the history messages for a conversation ID.
'''
def set_history(covers_id: str, data):
key = default_redis_covers_id_prefix + covers_id
r = redis.Redis(connection_pool=pool)
script='''
redis.call('del', KEYS[1])
redis.call('RPUSH', KEYS[1], unpack(ARGV))
redis.call('expire', KEYS[1], 172800)
'''
r.eval(script,1,key,*[json.dumps(i, ensure_ascii=False) for i in to_jsonable_python(data)])
cut_history(covers_id, r)
'''
Append a single history message for a conversation ID.
'''
def append_message(covers_id: str, data):
key = default_redis_covers_id_prefix + covers_id
r = redis.Redis(connection_pool=pool)
r.rpush(key, json.dumps(to_jsonable_python(data), ensure_ascii=False))
cut_history(covers_id=covers_id, r=r)
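A rough usage sketch of the new helper; the import path and conversation ID are placeholders, and the commented-out agent calls mirror process_chat above:

# Hypothetical usage; adjust the import path to wherever tools/redis_message_manage.py lives.
from tools import redis_message_manage

covers_id = 'demo-conversation-001'                       # placeholder conversation id
history = redis_message_manage.get_history(covers_id)     # [] on first use

# req = agent.run_sync(user_prompt='book a room', message_history=history or None)
# redis_message_manage.set_history(covers_id, req.all_messages())        # overwrite list, 2-day expiry
# redis_message_manage.append_message(covers_id, req.new_messages()[-1]) # append one entry
# Both write paths call cut_history, which trims long conversations to the first message plus the last five.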