news 2026/6/20 15:22:06

关于transformer的注意力权重可视化

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
关于transformer的注意力权重可视化

可视化

"""Visualize how strongly text-query tokens attend to each video frame.

The script locates the query tokens and the per-frame vision-token spans in a
tokenized multimodal prompt, averages the query->frame attention for every
layer, renders summary plots and stores the raw numbers next to the figure.
"""

import json
import os
from typing import List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns  # noqa: F401  (part of the original import list; unused here)
import torch

# Default output directory, shared by both public entry points.
DEFAULT_SAVE_DIR = "/home/share/svmd5vm0/home/scut_czy1/attn_map"


def _find_subsequence(sequence: Sequence[int], pattern: Sequence[int]) -> Optional[int]:
    """Return the index of the first occurrence of ``pattern`` in ``sequence``.

    Returns ``None`` when ``pattern`` is empty or does not occur.
    """
    width = len(pattern)
    if width == 0:
        # An empty pattern would trivially "match" at 0 and later produce an
        # empty attention slice (NaN scores), so treat it as "not found".
        return None
    for start in range(len(sequence) - width + 1):
        if sequence[start:start + width] == pattern:
            return start
    return None


def _safe_mean(values: np.ndarray) -> float:
    """Mean of ``values``; NaN (without a NumPy warning) when it is empty."""
    return float(values.mean()) if values.size else float("nan")


def _safe_std(values: np.ndarray) -> float:
    """Standard deviation of ``values``; NaN when it is empty."""
    return float(values.std()) if values.size else float("nan")


def _frame_attention_per_layer(
    attentions,
    query_span: Tuple[int, int],
    frame_spans: Sequence[Tuple[int, int]],
) -> np.ndarray:
    """Average query->frame attention for every layer.

    Args:
        attentions: Per-layer attention tensors; each is indexed as
            ``[batch, head, query_pos, key_pos]`` and only batch item 0 is used.
        query_span: Inclusive ``(start, end)`` token positions of the query.
        frame_spans: ``(vision_start_idx, vision_end_idx)`` per frame; the
            tokens strictly between the two markers belong to the frame.

    Returns:
        Array of shape ``[num_layers, num_frames]``.
    """
    q_start, q_end = query_span
    rows = []
    for layer_attn in attentions:
        # Reduce in float32: the model may run in bf16/fp16, where a mean over
        # many tiny probabilities loses precision.
        query_rows = layer_attn[0][:, q_start:q_end + 1, :].float()
        rows.append(
            [query_rows[..., v_start + 1:v_end].mean().item() for v_start, v_end in frame_spans]
        )
    return np.array(rows, dtype=np.float64)


def _draw_frame_bars(ax, avg_attention: np.ndarray, gt_start: int, gt_end: int) -> None:
    """Draw the per-frame bar chart with the ground-truth range highlighted."""
    num_frames = len(avg_attention)
    colors = [
        'lightcoral' if gt_start <= i <= gt_end else 'lightblue'
        for i in range(num_frames)
    ]
    bars = ax.bar(np.arange(num_frames), avg_attention, color=colors,
                  edgecolor='black', linewidth=0.5)
    for i in range(gt_start, gt_end + 1):
        bars[i].set_edgecolor('red')
        bars[i].set_linewidth(2)

    ax.set_xlabel('Frame Index', fontsize=12, fontweight='bold')
    ax.set_ylabel('Average Attention Score', fontsize=12, fontweight='bold')
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.axvspan(gt_start - 0.5, gt_end + 0.5, alpha=0.2, color='red',
               label=f'GT Frames [{gt_start}, {gt_end}]')
    ax.legend(loc='upper right')


def _plot_all_layers(
    figsize: tuple,
    layer_frame_attention: np.ndarray,
    avg_attention: np.ndarray,
    gt_start: int,
    gt_end: int,
    non_target_mask: np.ndarray,
    query: str,
):
    """Build the four-panel figure (heatmap, bars, box plot, layer trend)."""
    num_layers = layer_frame_attention.shape[0]
    fig = plt.figure(figsize=figsize)
    grid = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

    # ---- Panel 1: layer x frame attention heatmap ----
    ax1 = fig.add_subplot(grid[0, :])
    im = ax1.imshow(layer_frame_attention, aspect='auto', cmap='YlOrRd',
                    interpolation='nearest')
    ax1.set_xlabel('Frame Index', fontsize=12, fontweight='bold')
    ax1.set_ylabel('Layer Index', fontsize=12, fontweight='bold')
    ax1.set_title(f'Attention Heatmap: Query → Frames (All Layers)\nQuery: "{query}"',
                  fontsize=14, fontweight='bold', pad=20)
    # Mark the ground-truth range on the cell borders.
    ax1.axvline(x=gt_start - 0.5, color='blue', linestyle='--', linewidth=2, label='GT Start')
    ax1.axvline(x=gt_end + 0.5, color='blue', linestyle='--', linewidth=2, label='GT End')
    cbar = fig.colorbar(im, ax=ax1)
    cbar.set_label('Attention Score', fontsize=10, fontweight='bold')
    ax1.legend(loc='upper right')

    # ---- Panel 2: attention averaged over layers, per frame ----
    ax2 = fig.add_subplot(grid[1, :])
    _draw_frame_bars(ax2, avg_attention, gt_start, gt_end)
    ax2.set_title('Average Attention Distribution (All Layers & Heads)',
                  fontsize=14, fontweight='bold', pad=15)

    # ---- Panel 3: target vs non-target frames ----
    ax3 = fig.add_subplot(grid[2, 0])
    target_attention = avg_attention[gt_start:gt_end + 1]
    non_target_attention = avg_attention[non_target_mask]
    box = ax3.boxplot([target_attention, non_target_attention],
                      patch_artist=True, showmeans=True)
    # Set tick labels explicitly instead of boxplot(labels=...), which newer
    # matplotlib releases deprecate in favour of ``tick_labels``.
    ax3.set_xticklabels(['Target Frames', 'Non-Target Frames'])
    box['boxes'][0].set_facecolor('lightcoral')
    box['boxes'][1].set_facecolor('lightblue')
    ax3.set_ylabel('Attention Score', fontsize=12, fontweight='bold')
    ax3.set_title('Target vs Non-Target Frames', fontsize=13, fontweight='bold', pad=15)
    ax3.grid(axis='y', alpha=0.3, linestyle='--')

    target_mean = _safe_mean(target_attention)
    non_target_mean = _safe_mean(non_target_attention)
    ratio = target_mean / (non_target_mean + 1e-7)
    stats_text = f'Target Mean: {target_mean:.4f}\n'
    stats_text += f'Non-Target Mean: {non_target_mean:.4f}\n'
    stats_text += f'Ratio: {ratio:.2f}x'
    ax3.text(0.02, 0.98, stats_text, transform=ax3.transAxes,
             fontsize=10, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # ---- Panel 4: per-layer mean attention, target vs non-target ----
    ax4 = fig.add_subplot(grid[2, 1])
    layer_target_mean = layer_frame_attention[:, gt_start:gt_end + 1].mean(axis=1)
    if non_target_mask.any():
        layer_non_target_mean = layer_frame_attention[:, non_target_mask].mean(axis=1)
    else:
        # Ground truth covers every frame: there is nothing to compare against.
        layer_non_target_mean = np.full(num_layers, np.nan)
    layers = np.arange(num_layers)
    ax4.plot(layers, layer_target_mean, 'o-', color='red', linewidth=2,
             markersize=6, label='Target Frames')
    ax4.plot(layers, layer_non_target_mean, 's-', color='blue', linewidth=2,
             markersize=6, label='Non-Target Frames')
    ax4.set_xlabel('Layer Index', fontsize=12, fontweight='bold')
    ax4.set_ylabel('Mean Attention Score', fontsize=12, fontweight='bold')
    ax4.set_title('Layer-wise Attention Trend', fontsize=13, fontweight='bold', pad=15)
    ax4.legend(loc='best')
    ax4.grid(alpha=0.3, linestyle='--')
    return fig


def _plot_average_only(avg_attention: np.ndarray, gt_start: int, gt_end: int, query: str):
    """Build the simplified single-panel figure (layer-averaged bars only)."""
    fig, ax = plt.subplots(figsize=(12, 6))
    _draw_frame_bars(ax, avg_attention, gt_start, gt_end)
    ax.set_title(f'Attention Distribution\nQuery: "{query}"',
                 fontsize=14, fontweight='bold', pad=20)
    return fig


def visualize_attention_distribution(
    attentions,
    input_ids,
    processor,
    gt_start_frame,
    gt_end_frame,
    query_text,
    video_id: str,
    save_dir: str = DEFAULT_SAVE_DIR,
    show_all_layers: bool = True,
    figsize: tuple = (20, 12),
) -> Optional[dict]:
    """Plot and save the attention the query tokens pay to each video frame.

    Args:
        attentions: Model attention output, one tensor per layer, documented by
            the original author as ``(batch, num_heads, seq_len, seq_len)``.
            Only batch item 0 is read.
        input_ids: Input token ids; ``input_ids[0]`` is the sequence inspected.
        processor: Object exposing ``processor.tokenizer`` (callable, with
            ``convert_tokens_to_ids``).
        gt_start_frame: Ground-truth first frame index (clamped to ``>= 0``).
        gt_end_frame: Ground-truth last frame index, inclusive (clamped to the
            last detected frame).
        query_text: Query text; surrounding whitespace and one trailing ``.``
            are stripped before it is tokenized and searched for.
        video_id: Identifier used in the output file names.
        save_dir: Output directory (created if missing).
        show_all_layers: When true and more than one layer is present, render
            the four-panel figure; otherwise only the averaged bar chart.
        figsize: Figure size of the four-panel figure.

    Returns:
        A dict of summary statistics, or ``None`` (after printing a warning)
        when the query tokens, the vision tokens, the attention tensors or a
        valid ground-truth range cannot be determined.

    Side effects:
        Writes ``{video_id}_attention_distribution.png`` and
        ``{video_id}_attention_data.npz`` into ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)

    # 1. Ids of the special tokens delimiting each frame's vision tokens.
    #    NOTE(review): assumes every frame is wrapped in its own
    #    <|vision_start|>/<|vision_end|> pair - confirm for the model in use.
    tokenizer = processor.tokenizer
    vision_start_token_id = tokenizer.convert_tokens_to_ids('<|vision_start|>')
    vision_end_token_id = tokenizer.convert_tokens_to_ids('<|vision_end|>')

    # 2. Locate the query tokens inside the prompt.
    input_ids_list = input_ids[0].tolist()
    query = query_text.strip()
    if query.endswith('.'):
        query = query[:-1]
    query_ids = list(tokenizer(query, add_special_tokens=False)["input_ids"])

    query_start_idx = _find_subsequence(input_ids_list, query_ids)
    if query_start_idx is None:
        print(f"Warning: Query tokens not found for video {video_id}")
        return None
    query_end_idx = query_start_idx + len(query_ids) - 1

    # 3. Locate the vision-token span of every frame.
    vision_start_indices = [i for i, tok in enumerate(input_ids_list) if tok == vision_start_token_id]
    vision_end_indices = [i for i, tok in enumerate(input_ids_list) if tok == vision_end_token_id]
    num_frames = len(vision_start_indices)

    if num_frames == 0:
        print(f"Warning: No vision tokens found for video {video_id}")
        return None
    if len(vision_end_indices) != num_frames:
        # Unpaired markers would otherwise raise IndexError or mis-pair spans.
        print(f"Warning: Unbalanced vision start/end tokens for video {video_id}")
        return None

    # Attention tensors can be missing (or per-layer None) when the model does
    # not run with an attention implementation that returns weights.
    if attentions is None or len(attentions) == 0 or any(a is None for a in attentions):
        print(f"Warning: No attention weights available for video {video_id}")
        return None
    num_layers = len(attentions)

    # Clamp the ground-truth range to the frames that actually exist.
    gt_start_frame = max(int(gt_start_frame), 0)
    gt_end_frame = min(int(gt_end_frame), num_frames - 1)
    if gt_start_frame > gt_end_frame:
        print(f"Warning: Invalid ground-truth frame range for video {video_id}")
        return None

    # 4. Per-layer, per-frame attention scores: [num_layers, num_frames].
    layer_frame_attention = _frame_attention_per_layer(
        attentions,
        (query_start_idx, query_end_idx),
        list(zip(vision_start_indices, vision_end_indices)),
    )

    # 5. Average over layers: [num_frames].
    avg_attention = layer_frame_attention.mean(axis=0)

    non_target_mask = np.ones(num_frames, dtype=bool)
    non_target_mask[gt_start_frame:gt_end_frame + 1] = False
    target_attention = avg_attention[gt_start_frame:gt_end_frame + 1]
    non_target_attention = avg_attention[non_target_mask]

    # 6./7. Build the figure and save it; always release the figure.
    if show_all_layers and num_layers > 1:
        fig = _plot_all_layers(figsize, layer_frame_attention, avg_attention,
                               gt_start_frame, gt_end_frame, non_target_mask, query)
    else:
        fig = _plot_average_only(avg_attention, gt_start_frame, gt_end_frame, query)

    save_path = os.path.join(save_dir, f"{video_id}_attention_distribution.png")
    try:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"Saved attention visualization to: {save_path}")

    # 8. Save the raw numbers (NumPy .npz archive).
    save_data_path = os.path.join(save_dir, f"{video_id}_attention_data.npz")
    np.savez(
        save_data_path,
        layer_frame_attention=layer_frame_attention,
        avg_attention=avg_attention,
        gt_start_frame=gt_start_frame,
        gt_end_frame=gt_end_frame,
        query=query,
    )
    print(f"Saved attention data to: {save_data_path}")

    # 9. Summary statistics.
    target_mean = _safe_mean(target_attention)
    non_target_mean = _safe_mean(non_target_attention)
    total_attention = float(avg_attention.sum())
    stats = {
        'video_id': video_id,
        'query': query,
        'num_frames': num_frames,
        'num_layers': num_layers,
        'gt_range': (gt_start_frame, gt_end_frame),
        'target_attention_mean': target_mean,
        'target_attention_std': _safe_std(target_attention),
        'non_target_attention_mean': non_target_mean,
        'non_target_attention_std': _safe_std(non_target_attention),
        'attention_ratio': float(target_mean / (non_target_mean + 1e-7)),
        'attention_concentration': (
            float(target_attention.sum()) / total_attention
            if total_attention > 0 else float("nan")
        ),
    }
    return stats


def batch_visualize_attention(
    model,
    processor,
    data_list: List[dict],
    save_dir: str = DEFAULT_SAVE_DIR,
    device: str = "cuda",
):
    """Run the attention visualization for several videos.

    Args:
        model: Model called as ``model(**inputs, output_attentions=True)``.
        processor: Processor providing ``apply_chat_template`` and a tokenizer.
        data_list: One dict per video with the keys ``video_path``, ``query``,
            ``start_frame``, ``end_frame`` and ``video_id``.
        save_dir: Output directory (created if missing).
        device: Device the model inputs are moved to.

    Returns:
        A list with one entry per element of ``data_list``, in order. An entry
        is ``None`` when that video could not be visualized. The same list is
        written to ``all_stats.json`` in ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    all_stats = []

    for data in data_list:
        print(f"\nProcessing video: {data['video_id']}")

        # Build the chat-style prompt: one video followed by the text query.
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": data['video_path'],
                        "fps": 1
                    },
                    {"type": "text", "text": data['query']},
                ],
            }
        ]

        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Forward pass that also returns the attention weights.
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)

        stats = visualize_attention_distribution(
            attentions=outputs.attentions,
            input_ids=inputs['input_ids'],
            processor=processor,
            gt_start_frame=data['start_frame'],
            gt_end_frame=data['end_frame'],
            query_text=data['query'],
            video_id=data['video_id'],
            save_dir=save_dir,
        )
        all_stats.append(stats)

        # Drop the references to the per-layer attention tensors before the
        # next video is processed.
        del outputs, inputs

    # Persist the statistics of all videos.
    stats_path = os.path.join(save_dir, "all_stats.json")
    with open(stats_path, 'w', encoding="utf-8") as f:
        json.dump(all_stats, f, indent=4, ensure_ascii=False)

    print(f"\nSaved all statistics to: {stats_path}")
    return all_stats


def main() -> None:
    """Usage example: visualize the attention for a single video."""
    # Example 1: single video.
    from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

    model = Qwen3VLForConditionalGeneration.from_pretrained(
        "/home/share/svmd5vm0/home/scut_czy1/Qwen3-VL-2B-Instruct",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        # Eager attention is requested so that attention weights are returned.
        attn_implementation="eager"
    )
    processor = AutoProcessor.from_pretrained("/home/share/svmd5vm0/home/scut_czy1/Qwen3-VL-2B-Instruct")
    query_text = "A person is reading a book"

    # Build the prompt.
    messages = [{
        "role": "user",
        "content": [
            {"type": "video", "video": "/home/share/svmd5vm0/home/scut_czy1/datasets/Charadesfps/videos_1FPS/0A8CF.mp4", "fps": 1},
            {"type": "text", "text": query_text},
        ],
    }]

    # Follow the model's own device instead of hard-coding "cuda", since the
    # model is placed with device_map="auto".
    inputs = processor.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True,
        return_dict=True, return_tensors="pt"
    ).to(model.device)

    # Forward pass with attention weights.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    stats = visualize_attention_distribution(
        attentions=outputs.attentions,
        input_ids=inputs['input_ids'],
        processor=processor,
        gt_start_frame=5,
        gt_end_frame=9,
        query_text=query_text,
        video_id="video_001",
        save_dir="/home/share/svmd5vm0/home/scut_czy1/attn_map"
    )

    print("Statistics:", stats)

    # Example 2: batch processing.
    #
    # data_list = [
    #     {
    #         'video_path': 'video1.mp4',
    #         'query': 'person drinking water',
    #         'start_frame': 5,
    #         'end_frame': 9,
    #         'video_id': 'video_001'
    #     },
    #     {
    #         'video_path': 'video2.mp4',
    #         'query': 'person opening door',
    #         'start_frame': 10,
    #         'end_frame': 15,
    #         'video_id': 'video_002'
    #     },
    # ]
    #
    # all_stats = batch_visualize_attention(
    #     model=model,
    #     processor=processor,
    #     data_list=data_list,
    #     save_dir="./visualizations"
    # )

    print("可视化工具已准备就绪!")


# ========== Usage example ==========
if __name__ == "__main__":
    main()
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/21 9:49:11

YOLOv12架构突破:通过IAFF注意力融合机制实现多尺度特征自适应优化

购买即可解锁300+YOLO优化文章,并且还有海量深度学习复现项目,价格仅需两杯奶茶的钱,别人有的本专栏也有! 文章目录 **YOLOv12架构突破:通过IAFF注意力融合机制实现多尺度特征自适应优化** **IAFF核心模块完整实现** 代码链接与详细流程 YOLOv12架构突破:通过IAFF注意力…

作者头像 李华
网站建设 2026/6/21 9:28:01

XLeRobot强化学习训练:5步掌握ManiSkill仿真平台实战技巧

XLeRobot强化学习训练:5步掌握ManiSkill仿真平台实战技巧 【免费下载链接】XLeRobot XLeRobot: Practical Household Dual-Arm Mobile Robot for ~$660 项目地址: https://gitcode.com/GitHub_Trending/xl/XLeRobot 还在为实体机器人训练的高成本和复杂调试而…

作者头像 李华
网站建设 2026/6/21 8:16:01

从零构建Q#-Python同步系统:手把手教你搭建可靠数据通道

第一章:Q#-Python 变量同步概述 在量子计算与经典计算混合编程的场景中,Q# 与 Python 的协同工作成为实现高效算法设计的关键。变量同步是这一协作模式中的核心环节,它确保量子操作的结果能够被经典程序正确读取和处理,同时允许经典…

作者头像 李华
网站建设 2026/6/20 9:36:36

39、Linux系统编程知识全解析

Linux系统编程知识全解析 1. 相关书籍推荐 在学习Linux系统编程时,有不少优秀的书籍可供参考: | 书名 | 作者 | 出版信息 | 简介 | | — | — | — | — | | Managing Projects with GNU Make, 3rd ed. | Robert Mecklenburg | O’Reilly Media, 2004 | 对GNU Make这一在…

作者头像 李华
网站建设 2026/6/21 8:15:57

21、Linux 系统实用软件与游戏全攻略

Linux 系统实用软件与游戏全攻略 1. 系统自带小游戏 Linux 系统中可能预装了许多小游戏,以下是一些从标准 Linux 发行版 CD 安装的示例: | 游戏名称 | 游戏类型 | 运行方式 | 备注 | | ---- | ---- | ---- | ---- | | kpat | 耐心纸牌游戏 | 在 X 终端运行 | sol(快速)…

作者头像 李华
网站建设 2026/6/20 20:20:04

Wan2.2 AI视频生成终极指南:从入门到精通

想象一下,只需几句描述,AI就能为你创作出专业级的720P视频,这不再是科幻电影中的场景。Wan2.2-TI2V-5B作为业界领先的开源视频生成模型,将这一梦想变为现实。本指南将带你从零开始,掌握这一革命性技术的完整应用流程。…

作者头像 李华