news 2026/5/25 20:12:34

Day 39 模型可视化与推理

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 39 模型可视化与推理

@浙大疏锦行

一、nn.Module核心自带方法

nn.Module封装了模型的核心逻辑,以下是高频使用的自带方法,按功能分类:

1. 模型状态控制(训练 / 评估模式)

方法作用
model.train()切换为训练模式:启用 Dropout、BatchNorm 等层的训练行为(默认模式)
model.eval()切换为评估模式:关闭 Dropout、固定 BatchNorm 均值 / 方差,用于推理 / 验证
model.training属性,返回布尔值:True= 训练模式,False= 评估模式

示例

import torch import torch.nn as nn class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(3, 16, 3) self.dropout = nn.Dropout(0.5) # 训练时随机失活,评估时关闭 def forward(self, x): x = self.conv(x) x = self.dropout(x) return x model = SimpleCNN() print(model.training) # True(默认训练模式) model.eval() print(model.training) # False(评估模式,dropout失效) model.train() print(model.training) # True(切回训练模式)

2. 设备迁移(CPU/GPU)

方法作用
model.to(device)将模型所有参数 / 缓冲区移到指定设备(cuda/cpu/mps),返回模型实例
model.cuda()快捷方式:移到默认 GPU(等价于model.to('cuda')
model.cpu()快捷方式:移到 CPU(等价于model.to('cpu')

示例

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) # 模型移到GPU/CPU # 验证设备 print(next(model.parameters()).device) # 输出:cuda:0 或 cpu

3. 参数管理(查看 / 遍历参数)

方法作用
model.parameters()返回生成器:包含所有可训练参数(nn.Parameter类型)
model.named_parameters()返回生成器:(参数名,参数张量),便于定位参数
model.named_parameters()返回生成器:(参数名,参数张量),便于定位参数
model.state_dict()返回字典:{参数名:参数值},用于保存模型参数
model.load_state_dict()加载参数字典,用于恢复模型

示例

# 查看所有参数名称和形状 for name, param in model.named_parameters(): print(f"参数名:{name},形状:{param.shape},设备:{param.device}") # 统计总参数量(手动实现,无第三方库时用) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"总参数:{total_params},可训练参数:{trainable_params}")

4. 结构遍历(查看模型层)

方法作用
model.children()返回生成器:仅包含直接子层(如 Sequential 内的第一层),不递归
model.named_children()返回生成器:(层名,子层),仅直接子层
model.modules()返回生成器:递归包含所有层(包括嵌套层)
model.named_modules()返回生成器:(层名,层),递归所有层

示例

# 定义嵌套模型 class NestedModel(nn.Module): def __init__(self): super().__init__() self.block1 = nn.Sequential( nn.Conv2d(3, 16, 3), nn.ReLU() ) self.block2 = nn.Linear(16*30*30, 10) model = NestedModel() # children():仅直接子层(block1、block2) print("=== children() ===") for name, layer in model.named_children(): print(name, layer) # modules():递归所有层(包括Sequential内的Conv2d、ReLU) print("\n=== modules() ===") for name, layer in model.named_modules(): print(name, layer)

5. 前向传播与梯度

方法作用
model.forward(x)手动调用前向传播(不推荐),建议直接model(x)(调用__call__
model(x)等价于model.__call__(x),自动执行 forward + 钩子(hook)逻辑
model.zero_grad()清空所有参数的梯度(训练时反向传播前必须调用)

示例

x = torch.randn(1, 3, 32, 32).to(device) output = model(x) # 推荐:调用__call__,等价于model.forward(x) + 钩子 model.zero_grad() # 清空梯度 output.sum().backward() # 反向传播计算梯度

二、torchsummary库的summary方法

torchsummary是早期轻量库,核心功能是快速打印模型层结构、输出形状、总参数量,仅支持单输入模型,对嵌套模型 / 多输入支持差,维护较少。

1. 安装与基本用法

pip install torchsummary
from torchsummary import summary # 定义模型(输入:3通道32×32图像) class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(16, 32, 3, padding=1) self.fc1 = nn.Linear(32 * 8 * 8, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.pool(nn.functional.relu(self.conv1(x))) x = self.pool(nn.functional.relu(self.conv2(x))) x = x.view(-1, 32 * 8 * 8) x = nn.functional.relu(self.fc1(x)) x = self.fc2(x) return x # 设备配置 + 模型初始化 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleCNN().to(device) # 调用summary:参数(模型,输入形状(通道,高,宽),batch_size可选) summary(model, input_size=(3, 32, 32), batch_size=1)

2. 输出解读

---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [1, 16, 32, 32] 448 MaxPool2d-2 [1, 16, 16, 16] 0 Conv2d-3 [1, 32, 16, 16] 4,640 MaxPool2d-4 [1, 32, 8, 8] 0 Linear-5 [1, 128] 262,272 Linear-6 [1, 10] 1,290 ================================================================ Total params: 268,650 Trainable params: 268,650 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.01 Forward/backward pass size (MB): 0.29 Params size (MB): 1.02 Estimated Total Size (MB): 1.32 ----------------------------------------------------------------

3. 优缺点

优点缺点
极简、无多余依赖仅支持单输入模型
输出简洁、易理解对嵌套模型 / 多分支模型支持差
快速查看参数量 / 形状无批次维度、无内存占用细分
支持 GPU/CPU维护停滞,仅兼容 PyTorch 旧版本

三、torchinfo库的summary方法(推荐)

torchinfotorchsummary的升级版(原torchsummaryX),解决了多输入、嵌套模型、维度展示不清晰的问题,功能更全面,是当前 PyTorch 模型可视化的首选。

1. 安装与基本用法

pip install torchinfo
from torchinfo import summary # 复用上面的SimpleCNN模型 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleCNN().to(device) # 核心参数:model, input_size, batch_dim, device, col_width等 summary( model, input_size=(1, 3, 32, 32), # (batch_size, 通道, 高, 宽) batch_dim=0, # 批次维度的位置(默认0) device=device, # 模型设备 col_width=20, # 列宽 col_names=["input_size", "output_size", "num_params", "trainable"], # 显示列 row_settings=["var_names"] # 显示层变量名 )

2. 输出解读

========================================================================================== Layer (type (var_name)) Input Shape Output Shape Param # Trainable ========================================================================================== SimpleCNN (SimpleCNN) [1, 3, 32, 32] [1, 10] -- -- ├─Conv2d (conv1) [1, 3, 32, 32] [1, 16, 32, 32] 448 True ├─MaxPool2d (pool) [1, 16, 32, 32] [1, 16, 16, 16] -- -- ├─Conv2d (conv2) [1, 16, 16, 16] [1, 32, 16, 16] 4,640 True ├─MaxPool2d (pool) [1, 32, 16, 16] [1, 32, 8, 8] -- -- ├─Linear (fc1) [1, 2048] [1, 128] 262,272 True ├─Linear (fc2) [1, 128] [1, 10] 1,290 True ========================================================================================== Total params: 268,650 Trainable params: 268,650 Non-trainable params: 0 Total mult-adds (M): 2.15 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 0.29 Params size (MB): 1.02 Estimated Total Size (MB): 1.32 ==========================================================================================

四、推理的写法:评估模式

def evaluate_classification(model, dataloader, device): """ 分类模型评估:计算准确率、F1-score(宏平均)、混淆矩阵 """ # 1. 切换到评估模式(必须!) model.eval() # 2. 初始化指标容器 all_preds = [] all_labels = [] # 3. 关闭梯度计算(加速+省显存) with torch.no_grad(): for batch_idx, (x, y) in enumerate(dataloader): # 数据移到设备 x = x.to(device, dtype=torch.float32) y = y.to(device, dtype=torch.long) # 4. 推理(前向传播) outputs = model(x) # 输出:(batch_size, num_classes) preds = torch.argmax(outputs, dim=1) # 取概率最大的类别 # 5. 收集预测结果和真实标签(转回CPU便于计算指标) all_preds.extend(preds.cpu().numpy()) all_labels.extend(y.cpu().numpy()) # 可选:打印进度 if (batch_idx + 1) % 10 == 0: print(f"Batch [{batch_idx+1}/{len(dataloader)}] 完成") # 6. 计算评估指标 accuracy = accuracy_score(all_labels, all_preds) f1_macro = f1_score(all_labels, all_preds, average="macro") # 宏平均F1(适合类别均衡) f1_weighted = f1_score(all_labels, all_preds, average="weighted") # 加权F1(适合类别不均衡) # 7. 打印结果 print("="*50) print(f"分类模型评估结果:") print(f"准确率 (Accuracy): {accuracy:.4f}") print(f"宏平均F1-score: {f1_macro:.4f}") print(f"加权F1-score: {f1_weighted:.4f}") print("="*50) return { "accuracy": accuracy, "f1_macro": f1_macro, "f1_weighted": f1_weighted, "preds": all_preds, "labels": all_labels } # 执行评估 eval_results = evaluate_classification(model, test_loader, device)
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/25 12:55:15

计算机Java毕设实战-基于SpringBoot+Vue的智能驿站系统设计与实现基于Java Web的校园菜鸟驿站管理系统【完整源码+LW+部署说明+演示视频,全bao一条龙等】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华
网站建设 2026/5/25 10:23:51

动态删除表外键依赖

这是一个用于 Liquibase 的 SQL 脚本,它的核心功能是动态查找并删除指向某个特定表字段的所有外键约束。它通常用在数据库重构中,当你需要删除一个有外键引用的表或字段时,必须先解除这些依赖。 下面我将对脚本进行逐行详解,并举例…

作者头像 李华
网站建设 2026/5/26 7:13:12

openFuyao 容器平台快速入门:Nginx 应用部署全流程实操

这里写目录标题一、引言“核心扩展”轻量化设计,从基础编排到异构算力调度可插拔架构:自由定义您的容器平台二、环境准备与安装部署测试环境准备(一)前提条件确认(二)版本下载与安装脚本获取(三…

作者头像 李华
网站建设 2026/5/25 16:17:43

警惕Vibe Coding ,Agentic Coding认知升级与实践避坑指南

首先需要说明的一点是,我本身不认为自己是 AI 编程的资深专家,所谓的实践完全是基于自己使用了多款 AI 编程产品的切身感受,以及跟 Qoder 研发同学、其他互联网公司 AI IDE 研发同学的交流,如果分享中的观点或者认知有跟你违背的地…

作者头像 李华
网站建设 2026/5/25 8:53:06

Qoder 实战:AI 驱动的研发效率与质量提升

大家好,我是迎天下网络科技有限公司的技术负责人李芳。作为一名一线的 Java 后端开发工程师,今天想和大家分享一下我在实际项目中使用 Qoder 的一些经验。通过几个真实的小案例,我会展示 Qoder 是如何帮助我们提升开发效率、优化代码质量的。…

作者头像 李华
网站建设 2026/5/25 21:17:41

国产期刊被EI收录!首个影响因子12分,录用率67%,国人友好~

Carbon Neutralization《碳中和》近日正式被国际权威数据库EI Compendex收录。该刊2022年创刊,每年出版6期,由温州大学与Wiley联合出版,集高影响因子、高录用率、对国人友好、出版速度快等优势于一身,具有高起点、高包容性和高亲和…

作者头像 李华