fix: store conversation history messages in Redis
.env
@@ -37,3 +37,5 @@ EDIT_MEETING = '/yonbip/uspace/external/access/edit'
 DEFAULT_PAGESIZE = 30
 DEFAULT_QUERY_SIZE = 30
 
+REDIS_HOST=127.0.0.1
+REDIS_PORT=6379
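These two variables are read with python-decouple and feed the Redis connection pool created in yj_room_agent/tools/redis_message_manage.py further down. A minimal standalone sketch of that wiring (assuming redis-py and python-decouple are installed and a Redis instance is listening on the configured host/port):

    import redis
    from decouple import config

    # Build a shared connection pool from the new .env entries.
    pool = redis.ConnectionPool(
        host=config('REDIS_HOST', default='localhost'),
        port=config('REDIS_PORT', cast=int, default=6379),
        decode_responses=True,
    )
    r = redis.Redis(connection_pool=pool)
    print(r.ping())  # True when the configured Redis instance is reachable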
@@ -3,4 +3,5 @@ requests==2.32.3
 python-decouple==3.8
 APScheduler==3.11.0
 pydantic-ai==0.3.6
-Hypercorn==0.17.3
+Hypercorn==0.17.3
+redis==6.2.0
@@ -10,6 +10,7 @@ import threading, re
 from pydantic_ai.settings import ModelSettings
 
 from ..tools import getinfo, params_filter
+from ..tools import redis_message_manage
 import logging
 
 MODEL_NAME = config('MODEL_NAME', default="")
@@ -396,7 +397,7 @@ def map_meetingname_to_id(params: dict):
 '''
 
 
-def build_prompt(params):
+def build_prompt(params,map_meetingname_to_id:list):
     """Build the enhanced prompt"""
 
     qry_room_info_for_mart_str = '''
@@ -471,7 +472,7 @@ def build_prompt(params):
 ##ROLE##:
 You are a professional OA meeting-room booking assistant; provide service based on the following information:
 The current time is: {time_now}
-The mapping from meeting-room names to meeting-room IDs (unrelated to meeting IDs) is: {map_meetingname_to_id(params=params)}
+The mapping from meeting-room names to meeting-room IDs (unrelated to meeting IDs) is: {map_meetingname_to_id_list}
 ##TASK##:
 Process the request in the following steps:
 1. Book a meeting room
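The mapping is now interpolated from a list computed once at startup instead of calling map_meetingname_to_id inside the template. A minimal sketch of the intended shape, with hypothetical sample data (the real list comes from map_meetingname_to_id(params) and the real template is far longer):

    # Hypothetical mapping entries for illustration only.
    map_meetingname_to_id_list = [
        {"name": "Meeting Room A", "id": "room-001"},
        {"name": "Meeting Room B", "id": "room-002"},
    ]

    def build_prompt(params: dict, map_meetingname_to_id_list: list) -> str:
        time_now = "2025-06-01 09:00"  # placeholder; the agent uses the actual current time
        # The precomputed list is dropped straight into the prompt text.
        return (
            "##ROLE##:\n"
            "You are a professional OA meeting-room booking assistant.\n"
            f"The current time is: {time_now}\n"
            f"Meeting-room name to ID mapping: {map_meetingname_to_id_list}\n"
        )

    print(build_prompt({}, map_meetingname_to_id_list))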
@@ -706,7 +707,7 @@ def process_query_book_room(**kwargs) -> tuple:
     book_promot = f'''
 The system has called the API to query the meeting rooms already booked under the current tenant; the result is as follows (if there are multiple booked meetings, number them in the reply):
 {result}
-Strictly following the requirements of TASK step 6, parse the existing booking results for the user and list all booked meetings (the user may use them for further queries or to cancel a booking). Do not reason from history messages; focus on parsing the result. Do not omit anything, and include the meeting ID in the result. Do not trigger any other operation, and give the user corresponding natural-language feedback based on the result.
+Strictly following the requirements of TASK step 7, parse the existing booking results for the user and list all booked meetings (the user may use them for further queries or to cancel a booking). Do not reason from history messages; focus on parsing the result. Do not omit anything, and include the meeting ID in the result. Do not trigger any other operation, and give the user corresponding natural-language feedback based on the result.
     '''
     return False, book_promot
 
@@ -719,7 +720,7 @@ def process_user_query_book_room(**kwargs) -> tuple:
     book_promot = f'''
 The system has called the API to query the meeting rooms already booked under the current tenant; the result is as follows (if there are multiple booked meetings, number them in the reply):
 {result}
-Strictly following the requirements of TASK step 6, parse the existing booking results for the user and list all booked meetings (the user may use them for further queries or to cancel a booking). Do not reason from history messages; focus on parsing the result. Do not omit anything, and include the meeting ID in the result. Do not trigger any other operation, and give the user corresponding natural-language feedback based on the result.
+Strictly following the requirements of TASK step 7, parse the existing booking results for the user and list all booked meetings (the user may use them for further queries or to cancel a booking). Do not reason from history messages; focus on parsing the result. Do not omit anything, and include the meeting ID in the result. Do not trigger any other operation, and give the user corresponding natural-language feedback based on the result.
     '''
     return False, book_promot
 
@@ -818,7 +819,8 @@ from pydantic_ai.providers.openai import OpenAIProvider
 
 provider = OpenAIProvider(api_key=config('MODEL_API_KEY'), base_url=BASE_URL)
 model = OpenAIModel(MODEL_NAME, provider=provider)
-agent = Agent(model, system_prompt=build_prompt({'tenantId':config('TEMP_TENANT_ID')}))
+map_meetingname_to_id_list=map_meetingname_to_id({'tenantId':config('TEMP_TENANT_ID')})
+agent = Agent(model)
 # seed: keeps output for the same question as consistent as possible; the larger the value, the stronger the consistency
 model_setting=ModelSettings(
     temperature=config('MODEL_TEMPERATURE',cast=float,default=0.2),
@@ -826,14 +828,18 @@ model_setting=ModelSettings(
 
 )
 
+@agent.system_prompt(dynamic=True)
+def gen_system_prompt()->str:
+    return build_prompt({},map_meetingname_to_id_list)
+
 
 def process_chat(covers_id:str,user_id: str, user_input: str, params: dict):
 
     history = []
-    query_history = dialog_manager.get_history(covers_id)
+    query_history = redis_message_manage.get_history(covers_id)
     history.extend(query_history)
     req = agent.run_sync(user_prompt=user_input,model_settings=model_setting, message_history=history if len(history) > 0 else None)
-    dialog_manager.set_msg(covers_id, req.all_messages())
+    redis_message_manage.set_history(covers_id, req.all_messages())
     content = req.output
     logger.info(f"process chat content is : {content}")
     if "func_name" in content:
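Taken together, these hunks wire a dynamic system prompt to a Redis-backed message history. A minimal sketch of the same pattern, using a placeholder model string and prompt text (the commit itself builds an OpenAIModel from MODEL_NAME and BASE_URL, and the import path for redis_message_manage assumes the package layout shown below):

    from pydantic_ai import Agent
    from pydantic_ai.settings import ModelSettings

    from yj_room_agent.tools import redis_message_manage

    agent = Agent('openai:gpt-4o')  # placeholder model for the sketch

    @agent.system_prompt(dynamic=True)
    def gen_system_prompt() -> str:
        # Re-evaluated on every run, so it can embed data computed after import time.
        return "You are a meeting-room booking assistant."

    settings = ModelSettings(temperature=0.2)

    def chat_once(covers_id: str, user_input: str) -> str:
        history = redis_message_manage.get_history(covers_id)   # [] on the first turn
        result = agent.run_sync(
            user_prompt=user_input,
            model_settings=settings,
            message_history=history if history else None,
        )
        # Persist the whole transcript so the next turn can see it.
        redis_message_manage.set_history(covers_id, result.all_messages())
        return result.output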
@@ -873,7 +879,7 @@ def process_chat(covers_id:str,user_id: str, user_input: str, params: dict):
         content=resp.output
         logger.info("final content => {0}".format(content))
         new_content = check_and_process_think(content)
-        dialog_manager.add_message2(covers_id, resp.new_messages()[-1])
+        redis_message_manage.append_message(covers_id, resp.new_messages()[-1])
         return {'response': new_content}
     return {'response': content}
 
@@ -1,142 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-"""
-@author: yelinfeng
-@contact: 1312954526@qq.com
-@version: 1.0
-@created: 2025/5/22 18:32
-"""
-
-import requests
-import json
-import logging
-import os
-from typing import Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-class ModelClient:
-    """
-    A generic model client to interact with various large language models.
-    Configured via environment variables.
-    """
-
-    def __init__(
-        self,
-        base_url: Optional[str] = None,
-        api_key: Optional[str] = None,
-        headers: Optional[Dict[str, str]] = None,
-        model_name: Optional[str] = None
-    ):
-        """
-        Initialize the model client using environment variables or provided parameters.
-
-        Environment Variables:
-            MODEL_BASE_URL: Base URL of the model API (e.g., http://10.254.23.128:31205/v1)
-            MODEL_API_KEY: API key (optional)
-            MODEL_NAME: Model name (e.g., deepseek-chat)
-            MODEL_HEADERS: JSON string of custom headers (e.g., '{"User-Agent": "Apifox/1.0.0"}')
-
-        Args:
-            base_url: Base URL (overrides env var)
-            api_key: API key (overrides env var)
-            headers: Custom headers (overrides env var)
-            model_name: Model name (overrides env var)
-        """
-        self.base_url = base_url or os.getenv('MODEL_BASE_URL', 'http://10.254.23.128:31205/v1').rstrip('/')
-        self.api_key = api_key or os.getenv('MODEL_API_KEY', 'NotRequiredSinceWeAreLocal')
-        self.model_name = model_name or os.getenv('MODEL_NAME', 'deepseek-chat')
-
-        # Load MODEL_HEADERS
-        headers_env = os.getenv('MODEL_HEADERS', '{}')
-        try:
-            self.headers = json.loads(headers_env)
-        except json.JSONDecodeError:
-            logger.error(f"Invalid MODEL_HEADERS format: {headers_env}. Using empty headers.")
-            self.headers = {}
-
-        # Add default headers
-        self.headers.setdefault('Content-Type', 'application/json')
-        # if headers:
-        #     self.headers.update(headers)
-
-        # Handle api_key
-        if self.api_key and self.api_key != 'NotRequiredSinceWeAreLocal' and 'Authorization' not in self.headers:
-            self.headers['Authorization'] = f'Bearer {self.api_key}'
-
-        # # Add default User-Agent
-        self.headers.setdefault('User-Agent', '')
-
-        # Validate configuration
-        if not self.base_url:
-            raise ValueError("MODEL_BASE_URL is required")
-        if not self.model_name:
-            raise ValueError("MODEL_NAME is required")
-
-        logger.info(
-            f"Initialized ModelClient: base_url={self.base_url}, model_name={self.model_name}, headers={self.headers}")
-
-    def create(self, messages: List[Dict[str, str]], stream=True):
-        """
-        Send a chat completion request to the model.
-
-        Args:
-            messages: List of messages [{'role': 'system', 'content': '...'}, ...]
-            **kwargs: Additional parameters (e.g., temperature, max_tokens)
-
-        Returns:
-            Dict: Model response
-
-        Raises:
-            Exception: If the request fails
-        """
-        url = f"{self.base_url}/v1/chat/completions"
-        payload = {
-            'model': self.model_name,
-            'messages': messages,
-            'stream': stream
-        }
-
-        try:
-            with requests.post(
-                url,
-                headers=self.headers,
-                json=payload,
-                stream=True,
-                timeout=30
-            ) as response:
-                response.raise_for_status()
-
-                for chunk in response.iter_lines():
-                    if chunk:
-                        decoded = chunk.decode('utf-8')
-                        if decoded.startswith("data: "):
-                            json_str = decoded[6:]
-                            try:
-                                data = json.loads(json_str)
-                                if "choices" in data:
-                                    content = data["choices"][0].get("delta", {}).get("content", "")
-                                    if content:
-                                        yield content
-                            except json.JSONDecodeError:
-                                continue
-
-        except requests.exceptions.RequestException as e:
-            raise ConnectionError(f"API request failed: {str(e)}") from e
-
-    def test_connection(self) -> bool:
-        """
-        Test connectivity to the model API.
-
-        Returns:
-            bool: True if connection is successful
-        """
-        try:
-            response = requests.get(self.base_url, headers=self.headers, timeout=10)
-            response.raise_for_status()
-            return True
-        except requests.RequestException as e:
-            logger.error(f"Model connection test failed: {str(e)}")
-            return False
yj_room_agent/tools/redis_message_manage.py (new file)
@@ -0,0 +1,84 @@
+import json, logging
+
+import redis
+from decouple import config
+from pydantic_ai.messages import ModelMessagesTypeAdapter, ModelRequest, ModelResponse
+from pydantic_core import to_jsonable_python
+
+logger = logging.getLogger('django')
+pool = redis.ConnectionPool(
+    host=config('REDIS_HOST', cast=str, default='localhost'),
+    port=config('REDIS_PORT', cast=int, default=6379),
+    max_connections=20,  # maximum number of connections
+    decode_responses=True,
+    encoding='utf-8',
+
+)
+default_redis_covers_id_prefix = 'yj_room:agent:'
+
+def print_his(key,r):
+    es = r.lrange(key, 0, -1)
+    data=[json.loads(i) for i in es]
+    logger.info(f"----------{len(data)}-----------")
+    for i in data:
+        logger.info(f"data is {i}")
+    logger.info(f"----------{len(data)}-----------")
+
+'''
+Trim the history: when there are more than 20 messages, keep the first message and the last five by default
+'''
+
+
+def cut_history(covers_id: str, r):
+    key = default_redis_covers_id_prefix + covers_id
+    if r.llen(key) > 20:
+        logger.info("Message count exceeds the configured threshold, starting to trim messages...")
+        script='''
+        local msg = redis.call('lindex', KEYS[1],0)
+        redis.call('ltrim', KEYS[1],-5,-1)
+        redis.call('lpush', KEYS[1],msg)
+        '''
+        r.eval(script,1,key)
+
+
+'''
+Get the list of history messages for a conversation ID
+'''
+
+
+def get_history(covers_id: str) -> list:
+    key = default_redis_covers_id_prefix + covers_id
+    r = redis.Redis(connection_pool=pool)
+    res = r.lrange(key, 0, -1)
+    if res and len(res) > 0:
+        return ModelMessagesTypeAdapter.validate_python([json.loads(i) for i in res])
+    else:
+        return []
+
+
+'''
+Set the history messages
+'''
+
+
+def set_history(covers_id: str, data):
+    key = default_redis_covers_id_prefix + covers_id
+    r = redis.Redis(connection_pool=pool)
+    script='''
+    redis.call('del', KEYS[1])
+    redis.call('RPUSH', KEYS[1], unpack(ARGV))
+    redis.call('expire', KEYS[1], 172800)
+    '''
+    r.eval(script,1,key,*[json.dumps(i, ensure_ascii=False) for i in to_jsonable_python(data)])
+    cut_history(covers_id, r)
+'''
+Append a history message
+'''
+
+
+def append_message(covers_id: str, data):
+    key = default_redis_covers_id_prefix + covers_id
+    r = redis.Redis(connection_pool=pool)
+    r.rpush(key, json.dumps(to_jsonable_python(data), ensure_ascii=False))
+    cut_history(covers_id=covers_id, r=r)
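The trimming rule encoded in cut_history's Lua script (once the list grows past 20 entries, keep the head entry plus the last five) can be reproduced with plain redis-py calls. A small, deliberately non-atomic sketch against a throwaway key, assuming a local Redis on the default port:

    import redis

    r = redis.Redis(host='127.0.0.1', port=6379, decode_responses=True)
    key = 'yj_room:agent:demo-trim'                # throwaway key for illustration
    r.delete(key)
    r.rpush(key, *[f'msg{i}' for i in range(25)])  # 25 > 20, so trimming applies

    first = r.lindex(key, 0)                       # remember the first entry ('msg0')
    r.ltrim(key, -5, -1)                           # keep only the last five entries
    r.lpush(key, first)                            # put the first entry back at the head

    print(r.lrange(key, 0, -1))                    # ['msg0', 'msg20', 'msg21', 'msg22', 'msg23', 'msg24']

In process_chat above, set_history rewrites the whole list from req.all_messages() and refreshes the two-day TTL (expire 172800), while append_message pushes only the newest message; both call cut_history afterwards so the stored history stays bounded.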