强化学习实战:用Python可视化不同策略下的状态访问分布(附代码)
在强化学习的研究和应用中,理解智能体如何探索环境是至关重要的。想象一下,你正在训练一个机器人穿越迷宫——有些策略会让它反复在同一个区域徘徊,而另一些策略则能引导它高效到达目标。这种差异不仅影响学习效率,更直接决定了最终策略的质量。本文将带你用Python构建一个可视化实验,直观展示不同策略如何塑造智能体的探索行为。
1. 环境搭建与基础概念
我们先从创建一个简单的网格世界开始。这个5x5的方格环境将作为我们的实验场地,智能体可以在其中移动,目标是到达右上角的终止状态。使用Gymnasium库可以快速实现这个环境:
import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap import gymnasium as gym from gymnasium import spaces class GridWorldEnv(gym.Env): def __init__(self, size=5): self.size = size self.observation_space = spaces.Discrete(size*size) self.action_space = spaces.Discrete(4) # 上:0, 右:1, 下:2, 左:3 self.goal = size*size - 1 # 右上角为终止状态 self.state = 0 def reset(self): self.state = 0 return self.state def step(self, action): x, y = self.state // self.size, self.state % self.size if action == 0: x = max(0, x-1) # 上 elif action == 1: y = min(self.size-1, y+1) # 右 elif action == 2: x = min(self.size-1, x+1) # 下 elif action == 3: y = max(0, y-1) # 左 self.state = x * self.size + y done = (self.state == self.goal) reward = 10 if done else -0.1 # 稀疏奖励设置 return self.state, reward, done, {}状态访问分布的核心公式为:
v^π(s) = (1-γ)∑_{t=1}^∞ γ^t P_t^π(s)其中γ是折扣因子,P_t^π(s)表示在策略π下时刻t处于状态s的概率。这个分布反映了智能体在长期交互中访问各个状态的频率。
2. 策略实现与访问统计
我们将对比三种典型策略:随机策略、目标导向策略和ε-贪婪策略。首先实现这些策略:
def random_policy(env): return env.action_space.sample() def goal_oriented_policy(env): x, y = env.state // env.size, env.state % env.size goal_x, goal_y = env.goal // env.size, env.goal % env.size # 优先减少x和y与目标的距离 if x < goal_x and y < goal_y: return np.random.choice([1, 2]) # 随机选择向右或向下 elif x < goal_x: return 2 # 向下 elif y < goal_y: return 1 # 向右 else: return env.action_space.sample() def epsilon_greedy_policy(env, Q, epsilon=0.1): if np.random.random() < epsilon: return env.action_space.sample() else: return np.argmax(Q[env.state])接下来是计算状态访问分布的函数。由于无限时间步不实际,我们通过足够多的episode来近似:
def compute_visitation_distribution(env, policy, episodes=1000, gamma=0.99): visitation = np.zeros(env.size * env.size) for _ in range(episodes): state = env.reset() t = 0 done = False while not done: action = policy(env) next_state, _, done, _ = env.step(action) # 更新访问分布 visitation[state] += (gamma ** t) state = next_state t += 1 # 归一化处理 visitation = (1 - gamma) * visitation / episodes return visitation3. 可视化对比分析
现在我们可以生成并可视化不同策略下的状态访问分布。使用Matplotlib的热力图能够直观展示差异:
def plot_visitation(visitation, size=5, title=""): plt.figure(figsize=(8, 6)) visitation_2d = visitation.reshape((size, size)) # 自定义颜色映射 colors = [(0, 'white'), (0.5, 'lightblue'), (1, 'darkblue')] cmap = LinearSegmentedColormap.from_list('custom', colors) plt.imshow(visitation_2d, cmap=cmap, interpolation='nearest') plt.colorbar(label='访问概率') plt.title(title) # 标注坐标 for i in range(size): for j in range(size): plt.text(j, i, f"{visitation_2d[i, j]:.3f}", ha="center", va="center", color="black") plt.xticks([]) plt.yticks([]) plt.show()运行三种策略的对比实验:
env = GridWorldEnv(size=5) # 随机策略 random_visitation = compute_visitation_distribution(env, random_policy) plot_visitation(random_visitation, title="随机策略状态访问分布") # 目标导向策略 goal_visitation = compute_visitation_distribution(env, goal_oriented_policy) plot_visitation(goal_visitation, title="目标导向策略状态访问分布") # ε-贪婪策略 (需要先训练Q表) Q = np.random.rand(env.size*env.size, env.action_space.n) for _ in range(1000): state = env.reset() done = False while not done: action = epsilon_greedy_policy(env, Q) next_state, reward, done, _ = env.step(action) Q[state][action] += 0.1 * (reward + 0.99 * np.max(Q[next_state]) - Q[state][action]) state = next_state epsilon_visitation = compute_visitation_distribution(env, lambda env: epsilon_greedy_policy(env, Q)) plot_visitation(epsilon_visitation, title="ε-贪婪策略状态访问分布")4. 结果解读与优化建议
从可视化结果中我们可以观察到几个关键现象:
随机策略:
- 访问分布相对均匀
- 中心区域访问频率略高(因为从起点出发更容易到达)
- 平均访问概率约0.04(1/25)
目标导向策略:
- 右下到左上的对角线区域访问频率显著提高
- 远离路径的状态几乎不被访问
- 最高访问概率可达0.15以上
ε-贪婪策略:
- 呈现折中特征
- 保留了一定的探索性
- 关键路径上的状态访问概率明显提升
优化探索的策略技巧:
对于稀疏奖励环境,可以尝试以下方法调整状态访问:
# 基于计数的探索奖励 def count_based_exploration(env, visitation_counts, beta=0.1): def reward_fn(state): return beta / np.sqrt(visitation_counts[state] + 1) return reward_fn调整折扣因子γ的影响:
当γ接近1时,智能体更关注长期访问分布 当γ较小时,近期访问状态权重更大
实际项目中,我们可以通过访问分布分析发现策略的探索不足问题。例如,如果某些关键状态从未被访问,可能需要:
- 增加探索率ε
- 引入内在好奇心机制
- 调整奖励函数鼓励探索
5. 高级应用:占用度量的可视化扩展
状态-动作对的占用度量ρ^π(s,a)提供了更细粒度的分析视角。我们可以扩展之前的代码来可视化这一度量:
def compute_occupancy_measure(env, policy, episodes=1000, gamma=0.99): occupancy = np.zeros((env.size * env.size, env.action_space.n)) for _ in range(episodes): state = env.reset() t = 0 done = False while not done: action = policy(env) next_state, _, done, _ = env.step(action) occupancy[state, action] += (gamma ** t) state = next_state t += 1 occupancy = (1 - gamma) * occupancy / episodes return occupancy def plot_occupancy(occupancy, size=5): fig, axes = plt.subplots(1, 4, figsize=(20, 5)) action_names = ['上', '右', '下', '左'] for a in range(4): occupancy_2d = occupancy[:, a].reshape((size, size)) axes[a].imshow(occupancy_2d, cmap='Blues', interpolation='nearest') axes[a].set_title(f"动作 {action_names[a]}") for i in range(size): for j in range(size): axes[a].text(j, i, f"{occupancy_2d[i, j]:.3f}", ha="center", va="center", color="black") plt.tight_layout() plt.show() # 示例:目标导向策略的占用度量 goal_occupancy = compute_occupancy_measure(env, goal_oriented_policy) plot_occupancy(goal_occupancy)这种可视化可以清晰展示策略在特定状态下偏好的动作,例如:
- 在目标路径上的状态会高频选择特定方向动作
- 随机策略则显示各动作分布均匀
- 训练良好的ε-贪婪策略会在关键状态表现出明显的动作偏好
6. 实际项目中的经验分享
在真实场景应用这些可视化技术时,有几个实用技巧值得注意:
大规模环境的采样优化:
# 使用稀疏矩阵存储大规模状态空间的访问计数 from scipy.sparse import dok_matrix class SparseVisitationCounter: def __init__(self, env): self.counts = dok_matrix((env.observation_space.n,), dtype=np.float32) def update(self, state, gamma, t): self.counts[state] += (gamma ** t)动态策略的实时监控:
- 在训练过程中定期保存访问分布快照
- 比较不同训练阶段的状态覆盖变化
- 检测策略是否过早收敛到局部最优
可视化优化技巧:
- 对访问概率使用对数尺度增强对比度
- 叠加环境的关键特征(如障碍物、奖励位置)
- 制作动态gif展示访问分布的演变过程
# 示例:带对数尺度的访问分布图 def plot_log_visitation(visitation, size=5): log_vis = np.log(visitation + 1e-10) # 避免log(0) plt.imshow(log_vis.reshape(size, size), cmap='viridis') plt.colorbar(label='log(访问概率)')通过这些方法,我们不仅能诊断策略问题,还能直观理解强化学习算法的探索特性。例如,在某个自动驾驶仿真项目中,访问分布可视化曾帮助我们发现策略总是避开某个特定路口,进而发现该区域的状态表征存在问题。