news 2026/5/26 2:07:18

day37简单的神经网络@浙大疏锦行

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
day37简单的神经网络@浙大疏锦行

day37简单的神经网络@浙大疏锦行

使用 sklearn 的 load_digits 数据集 (8x8 像素的手写数字) 进行 MLP 训练。

importtorchimporttorch.nnasnnimporttorch.optimasoptimfromsklearn.datasetsimportload_digitsfromsklearn.model_selectionimporttrain_test_splitfromsklearn.preprocessingimportMinMaxScalerimportnumpyasnpimportmatplotlib.pyplotasplt# 1. 加载数据digits=load_digits()X=digits.data y=digits.targetprint(f"数据形状:{X.shape}")print(f"标签形状:{y.shape}")# 查看一张图片plt.imshow(digits.images[0],cmap='gray')plt.title(f"Label:{y[0]}")plt.show()

数据形状: (1797, 64) 标签形状: (1797,)

# 2. 数据预处理# 划分训练集和测试集X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=42)# 归一化scaler=MinMaxScaler()X_train=scaler.fit_transform(X_train)X_test=scaler.transform(X_test)# 转换为 TensorX_train=torch.FloatTensor(X_train)y_train=torch.LongTensor(y_train)X_test=torch.FloatTensor(X_test)y_test=torch.LongTensor(y_test)print("训练集 Tensor 形状:",X_train.shape)print("测试集 Tensor 形状:",X_test.shape)

训练集 Tensor 形状: torch.Size([1437, 64])

测试集 Tensor 形状: torch.Size([360, 64])

# 3. 定义模型classMLP(nn.Module):def__init__(self):super(MLP,self).__init__()# 输入层 64 (8*8像素) -> 隐藏层 32 -> 输出层 10 (0-9数字)self.fc1=nn.Linear(64,32)self.relu=nn.ReLU()self.fc2=nn.Linear(32,10)defforward(self,x):out=self.fc1(x)out=self.relu(out)out=self.fc2(out)returnout model=MLP()print(model)

MLP(

(fc1): Linear(in_features=64, out_features=32, bias=True) (relu): ReLU()

(fc2): Linear(in_features=32, out_features=10, bias=True)

)

# 4. 定义损失函数和优化器criterion=nn.CrossEntropyLoss()optimizer=optim.SGD(model.parameters(),lr=0.1)# 学习率稍微调大一点,或者增加epoch
# 5. 训练模型num_epochs=2000losses=[]forepochinrange(num_epochs):# 前向传播outputs=model(X_train)loss=criterion(outputs,y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.item())if(epoch+1)%100==0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss:{loss.item():.4f}')

# 6. 可视化损失plt.plot(range(num_epochs),losses)plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss')plt.show()

# 7. 模型评估withtorch.no_grad():# 训练集准确率outputs_train=model(X_train)_,predicted_train=torch.max(outputs_train,1)accuracy_train=(predicted_train==y_train).sum().item()/y_train.size(0)# 测试集准确率outputs_test=model(X_test)_,predicted_test=torch.max(outputs_test,1)accuracy_test=(predicted_test==y_test).sum().item()/y_test.size(0)print(f'训练集准确率:{accuracy_train:.4f}')print(f'测试集准确率:{accuracy_test:.4f}')

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/25 6:20:22

【time-rs】Duration 结构体详解

这是一个 Rust 时间库中的 Duration 结构体实现,提供高精度的时间跨度表示。 1. 主要特性 纳秒级精度:由整秒和纳秒部分组成支持负值:与标准库的 std::time::Duration 不同,支持负时间间隔安全边界检查:使用 RangedI32…

作者头像 李华
网站建设 2026/5/25 11:47:59

10398_基于SSM的教学评价管理系统

1、项目包含项目源码、项目文档、数据库脚本、软件工具等资料;带你从零开始部署运行本套系统。2、项目介绍教学评价系统是以Java平台作为开发环境,采用MySQL数据库作为后台,使用Eclipse作为开发工具进行设计。本系统主要实现了教学评价模块、…

作者头像 李华
网站建设 2026/5/26 4:17:43

Go语言变量

Go变量声明的核心机制 静态类型语言要求变量在使用前必须声明,明确内存边界。Go作为静态语言,通过变量声明实现这一机制: 变量绑定特定内存区域,类型信息确定操作边界声明形式为:var 变量名 类型 值未显式初始化时自动…

作者头像 李华
网站建设 2026/5/26 4:37:41

【高可用系统架构】

系统高可用实现手段 冗余与无单点设计 部署关键节点时避免单点故障,例如负载均衡采用双节点Keepalived方案(如Nginx/HAProxy/LVS),通过虚拟IP实现故障自动切换。网络通信配置多线路(如移动电信双线)&#x…

作者头像 李华
网站建设 2026/5/26 5:14:18

高频软件测试基础面试题

在软件测试的面试过程中,面试官会问些基础的软件测试知识,下面为大家整理了一些高频软件测试面试必备的基础题,拿走不谢~ 一、什么是软件测试 为了发现程序中的错误而执行程序的过程。 二、软件测试的原则 1、完全测试程序是不可能的 2、…

作者头像 李华
网站建设 2026/5/26 5:14:19

如何准确判断json文件并且拿到我想要的信息

写在前面,自从发现拿到json解析后的文件中有我们想要的信息后,我稍微有点迷上这种方法,但是拿到内容后要怎么拿到想要的信息呢,字典列表相互嵌套,我头都晕了方法:首先就是把json解析后的文本保存成.json的形…

作者头像 李华