AI Agent 记忆系统设计与实现:让 AI 记住一切
前言
记忆系统是 AI Agent 能否长期有效工作的关键。一个没有记忆的 Agent 每次交互都像是与陌生人对话,而有完善记忆系统的 Agent 则可以像老朋友一样理解你的偏好、记住你的请求历史、提供连贯的服务。
我之前设计过一个客服 Agent,最初没有完善的记忆系统,导致每次对话都是独立的,用户需要反复说明背景信息。加入记忆系统后,用户体验有了质的飞跃。今天分享一些记忆系统的设计经验和实现方法。
记忆系统的分层架构
三层记忆模型
┌─────────────────────────────────────────────────────┐ │ Semantic Memory (语义记忆) │ │ 存储长期知识:事实、概念、通用规则 │ │ 特点:持久存储,很少变化 │ ├─────────────────────────────────────────────────────┤ │ Episodic Memory (情景记忆) │ │ 存储过去事件:对话摘要、已完成任务 │ │ 特点:定期更新,可遗忘旧内容 │ ├─────────────────────────────────────────────────────┤ │ Working Memory (工作记忆) │ │ 当前任务上下文:当前对话、最近状态 │ │ 特点:临时存储,快速访问 │ └─────────────────────────────────────────────────────┘记忆类型对比
| 类型 | 容量 | 访问频率 | 更新频率 | 持久性 |
|---|---|---|---|---|
| 工作记忆 | 几 KB | 每轮 | 每轮 | 会话级 |
| 情景记忆 | 几 MB | 每次新会话 | 每天 | 长期 |
| 语义记忆 | 无限制 | 偶尔 | 很少 | 永久 |
核心组件实现
1. 工作记忆
from dataclasses import dataclass, field from typing import List, Dict, Optional from datetime import datetime import json @dataclass class Message: """消息记录""" role: str # "user" | "assistant" | "system" content: str timestamp: datetime = field(default_factory=datetime.now) metadata: Dict = field(default_factory=dict) class WorkingMemory: """工作记忆 - 当前会话的上下文""" def __init__(self, max_messages: int = 50): self.messages: List[Message] = [] self.max_messages = max_messages self.session_id: Optional[str] = None self.metadata: Dict = {} def add_message(self, role: str, content: str, metadata: Dict = None): """添加消息""" message = Message( role=role, content=content, metadata=metadata or {} ) self.messages.append(message) # 清理超出限制的消息 if len(self.messages) > self.max_messages: self.messages = self.messages[-self.max_messages:] def get_context(self, max_tokens: int = 3000) -> str: """获取当前上下文""" context_parts = [] current_tokens = 0 # 从最近的消息开始 for msg in reversed(self.messages): msg_text = f"{msg.role}: {msg.content}" msg_tokens = self._count_tokens(msg_text) if current_tokens + msg_tokens > max_tokens: break context_parts.insert(0, msg_text) current_tokens += msg_tokens return "\n\n".join(context_parts) def get_recent(self, n: int = 5) -> List[Message]: """获取最近 n 条消息""" return self.messages[-n:] def clear(self): """清空工作记忆""" self.messages = [] def _count_tokens(self, text: str) -> int: """简单 token 计数""" return len(text) // 4 # 粗略估计2. 情景记忆
import sqlite3 from typing import List, Optional from datetime import datetime, timedelta @dataclass class Episode: """情景记忆条目""" id: str session_id: str summary: str key_points: List[str] entities: List[str] # 提到的实体 timestamp: datetime importance: float # 重要性评分 access_count: int = 0 last_accessed: Optional[datetime] = None class EpisodicMemory: """情景记忆 - 跨会话的记忆""" def __init__(self, db_path: str = "./episodic_memory.db"): self.db_path = db_path self._init_db() def _init_db(self): """初始化数据库""" conn = sqlite3.connect(self.db_path) conn.execute(""" CREATE TABLE IF NOT EXISTS episodes ( id TEXT PRIMARY KEY, session_id TEXT, summary TEXT, key_points TEXT, entities TEXT, timestamp DATETIME, importance REAL, access_count INTEGER, last_accessed DATETIME ) """) conn.commit() conn.close() def add_episode( self, session_id: str, messages: List[Message], summary: str = None ): """添加情景记忆""" if summary is None: summary = self._generate_summary(messages) key_points = self._extract_key_points(messages) entities = self._extract_entities(messages) episode = Episode( id=f"{session_id}_{datetime.now().timestamp()}", session_id=session_id, summary=summary, key_points=key_points, entities=entities, timestamp=datetime.now(), importance=1.0 ) self._save_episode(episode) return episode def _generate_summary(self, messages: List[Message]) -> str: """生成摘要""" # 简化实现,实际可用 LLM contents = [m.content[:100] for m in messages if len(m.content) > 20] return " | ".join(contents[:3]) def _extract_key_points(self, messages: List[Message]) -> List[str]: """提取关键点""" # 简化实现 return [] def _extract_entities(self, messages: List[Message]) -> List[str]: """提取实体""" # 简化实现 return [] def search( self, query: str, max_results: int = 5, recency_days: int = 30 ) -> List[Episode]: """搜索相关记忆""" conn = sqlite3.connect(self.db_path) cursor = conn.execute(""" SELECT * FROM episodes WHERE timestamp > datetime('now', '-' || ? || ' days') ORDER BY importance DESC, timestamp DESC LIMIT ? """, (recency_days, max_results)) rows = cursor.fetchall() conn.close() episodes = [] for row in rows: episodes.append(Episode( id=row[0], session_id=row[1], summary=row[2], key_points=json.loads(row[3]), entities=json.loads(row[4]), timestamp=datetime.fromisoformat(row[5]), importance=row[6], access_count=row[7], last_accessed=datetime.fromisoformat(row[8]) if row[8] else None )) return episodes def _save_episode(self, episode: Episode): """保存到数据库""" conn = sqlite3.connect(self.db_path) conn.execute(""" INSERT INTO episodes (id, session_id, summary, key_points, entities, timestamp, importance, access_count, last_accessed) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( episode.id, episode.session_id, episode.summary, json.dumps(episode.key_points), json.dumps(episode.entities), episode.timestamp.isoformat(), episode.importance, episode.access_count, episode.last_accessed.isoformat() if episode.last_accessed else None )) conn.commit() conn.close()3. 语义记忆
from typing import Dict, List, Optional import hashlib @dataclass class Knowledge: """知识条目""" id: str content: str source: str confidence: float tags: List[str] created_at: datetime updated_at: datetime class SemanticMemory: """语义记忆 - 长期知识库""" def __init__(self, vector_store, embedding_model): self.vector_store = vector_store self.embedding_model = embedding_model self.knowledge_graph: Dict[str, dict] = {} def add_knowledge( self, content: str, source: str, tags: List[str] = None, embedding: Optional[List[float]] = None ): """添加知识""" knowledge_id = hashlib.md5(content.encode()).hexdigest() if embedding is None: embedding = self.embedding_model.encode([content])[0] knowledge = Knowledge( id=knowledge_id, content=content, source=source, confidence=1.0, tags=tags or [], created_at=datetime.now(), updated_at=datetime.now() ) # 存储到向量数据库 self.vector_store.add( id=knowledge_id, vector=embedding, payload={"content": content, "source": source} ) # 更新知识图谱 self.knowledge_graph[knowledge_id] = { "content": content, "tags": tags, "related": [] } return knowledge def retrieve(self, query: str, top_k: int = 5) -> List[Knowledge]: """检索相关知识""" query_embedding = self.embedding_model.encode([query])[0] results = self.vector_store.search( vector=query_embedding, top_k=top_k ) return [ Knowledge( id=r["id"], content=r["payload"]["content"], source=r["payload"]["source"], confidence=r["score"], tags=r["payload"].get("tags", []), created_at=datetime.now(), updated_at=datetime.now() ) for r in results ] def update_knowledge(self, knowledge_id: str, content: str): """更新知识""" if knowledge_id in self.knowledge_graph: self.knowledge_graph[knowledge_id]["content"] = content self.knowledge_graph[knowledge_id]["updated_at"] = datetime.now()完整记忆系统
class UnifiedMemory: """统一记忆系统""" def __init__( self, vector_store=None, embedding_model=None, llm=None ): self.working_memory = WorkingMemory() self.episodic_memory = EpisodicMemory() self.semantic_memory = SemanticMemory(vector_store, embedding_model) self.llm = llm def add_user_message(self, content: str): """添加用户消息""" self.working_memory.add_message("user", content) def add_assistant_message(self, content: str): """添加助手消息""" self.working_memory.add_message("assistant", content) def get_full_context(self, include_semantic: bool = True) -> str: """获取完整上下文""" parts = [] # 1. 语义记忆(长期知识) if include_semantic and self.llm: semantic_context = self._get_semantic_context() if semantic_context: parts.append(f"【背景知识】\n{semantic_context}") # 2. 情景记忆(历史经验) episodic_context = self._get_episodic_context() if episodic_context: parts.append(f"【相关历史】\n{episodic_context}") # 3. 工作记忆(当前对话) working_context = self.working_memory.get_context() if working_context: parts.append(f"【当前对话】\n{working_context}") return "\n\n".join(parts) def _get_semantic_context(self) -> str: """获取语义记忆上下文""" if not self.llm: return "" # 从当前对话提取关键信息 recent = self.working_memory.get_recent(3) if not recent: return "" query = " ".join([m.content for m in recent]) # 检索相关知识 relevant_knowledge = self.semantic_memory.retrieve(query, top_k=3) if not relevant_knowledge: return "" return "\n".join([ f"- {k.content} (来源: {k.source})" for k in relevant_knowledge ]) def _get_episodic_context(self) -> str: """获取情景记忆上下文""" recent = self.working_memory.get_recent(3) if not recent: return "" # 搜索相关情景 query = " ".join([m.content for m in recent]) episodes = self.episodic_memory.search(query, max_results=2) if not episodes: return "" return "\n".join([ f"- {e.summary}" for e in episodes ]) def save_session(self, session_id: str): """保存当前会话到情景记忆""" messages = self.working_memory.messages if not messages: return self.episodic_memory.add_episode(session_id, messages) self.working_memory.clear() def clear_session(self): """清空当前会话""" self.working_memory.clear()记忆检索优化
基于重要性的衰减
class MemoryWithDecay(UnifiedMemory): """带重要性衰减的记忆系统""" def __init__(self, decay_rate: float = 0.95): super().__init__() self.decay_rate = decay_rate def decay_importance(self, days_passed: int) -> float: """计算衰减后的重要性""" return self.decay_rate ** days_passed def prune_old_memories(self, threshold: float = 0.1): """清理不重要的记忆""" # 删除重要性低于阈值的历史记录 pass主动记忆增强
class ProactiveMemory(UnifiedMemory): """主动记忆增强""" def __init__(self, *args, summary_threshold: int = 20, **kwargs): super().__init__(*args, **kwargs) self.summary_threshold = summary_threshold def should_summarize(self) -> bool: """判断是否需要总结""" return len(self.working_memory.messages) >= self.summary_threshold def proactive_summarize(self, session_id: str): """主动生成总结""" if not self.should_summarize(): return messages = self.working_memory.messages # 使用 LLM 生成总结 summary_prompt = f"""请总结以下对话的关键信息: {chr(10).join([f"{m.role}: {m.content}" for m in messages])} 请提取: 1. 对话主题 2. 关键决策或结论 3. 用户偏好或需求 4. 待处理事项 总结:""" summary = self.llm.generate(summary_prompt) # 保存到情景记忆 self.episodic_memory.add_episode(session_id, messages, summary) # 清空并保留摘要 self.working_memory.clear() self.working_memory.add_message( "system", f"[对话摘要] {summary}" )实际应用
class MemoryfulAgent: """带有记忆的 Agent""" def __init__(self): # 初始化各组件 self.memory = UnifiedMemory() self.llm = OpenAILLM() # 加载长期偏好 self._load_preferences() def _load_preferences(self): """加载用户偏好""" prefs = self.semantic_memory.retrieve("user preferences", top_k=5) for p in prefs: self.memory.working_memory.metadata["preferences"] = json.loads(p.content) def chat(self, user_input: str) -> str: """聊天""" # 1. 记录用户消息 self.memory.add_user_message(user_input) # 2. 获取完整上下文 context = self.memory.get_full_context() # 3. 构建 prompt prompt = f"""{context} 用户:{user_input} 请基于以上上下文回答用户问题。 """ # 4. 生成回答 response = self.llm.chat(prompt) # 5. 记录回答 self.memory.add_assistant_message(response) return response def end_session(self, session_id: str): """结束会话""" # 保存情景记忆 self.memory.save_session(session_id) # 提取并保存用户偏好 self._extract_and_save_preferences()总结
记忆系统是 AI Agent 的重要组成部分:
- 三层记忆:工作记忆、情景记忆、语义记忆各有分工
- 工作记忆:处理当前对话上下文
- 情景记忆:跨会话存储历史经验
- 语义记忆:长期知识和偏好
- 主动优化:定期总结和清理
关键要点:
- 合理的分层设计提高效率
- 定期总结防止上下文溢出
- 重要性衰减保证记忆时效性
- 与 LLM 结合实现智能检索