news 2026/7/4 9:29:17

DAY 39 早停策略和模型权重的保存

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
DAY 39 早停策略和模型权重的保存

一、过拟合的判断

在机器学习中,过拟合(Overfitting) 的核心定义是:模型在训练集上表现极佳(损失极低、准确率极高),但在未见过的测试集 / 验证集上表现大幅下降,本质是模型 “死记硬背” 了训练数据的噪声和细节,而非学习到数据的通用规律。

核心判断依据:通过训练损失(Train Loss) 和测试损失(Test Loss) 的曲线对比,是判断过拟合的核心手段。

二、模型的保存和加载

1.仅保存模型参数

- 原理:保存模型的权重参数,不保存模型结构代码。加载时需提前定义与训练时一致的模型类。

- 优点:文件体积小(仅含参数),跨框架兼容性强(需自行定义模型结构)。

# 保存模型参数 torch.save(model.state_dict(), "model_weights.pth") # 加载参数(需先定义模型结构) model = MLP() # 初始化与训练时相同的模型结构 model.load_state_dict(torch.load("model_weights.pth")) # model.eval() # 切换至推理模式(可选)

2.保存权重和模型

- 原理:保存模型结构及参数

- 优点:加载时无需提前定义模型类

- 缺点:文件体积大,依赖训练时的代码环境(如自定义层可能报错)。

# 保存整个模型 torch.save(model, "full_model.pth") # 加载模型(无需提前定义类,但需确保环境一致) model = torch.load("full_model.pth") model.eval() # 切换至推理模式(可选)

3.保存全部信息checkpoint,还包含训练状态

- 原理:保存模型参数、优化器状态(学习率、动量)、训练轮次、损失值等完整训练状态,用于中断后继续训练。

- 适用场景:长时间训练任务(如分布式训练、算力中断)。

# 保存训练状态 checkpoint = { "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch, "loss": best_loss, } torch.save(checkpoint, "checkpoint.pth") # 加载并续训 model = MLP() optimizer = torch.optim.Adam(model.parameters()) checkpoint = torch.load("checkpoint.pth") model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) start_epoch = checkpoint["epoch"] + 1 # 从下一轮开始训练 best_loss = checkpoint["loss"] # 继续训练循环 for epoch in range(start_epoch, num_epochs): train(model, optimizer, ...)

三、早停法

早停法是缓解过拟合最常用、最简洁的策略,核心逻辑是:训练过程中持续监控「验证集 / 测试集损失(或准确率)」,当模型在未见过的数据上的性能不再提升(甚至开始下降)时,提前终止训练,避免模型过度拟合训练集的噪声;同时保存训练过程中 “验证集表现最好” 的模型参数,保证最终使用的是泛化能力最优的模型。

早停法的核心要素

要素作用
监控指标优先选「验证集损失」(损失越低越好),也可选「验证集准确率」(越高越好)
耐心值(Patience)允许 “验证集性能不提升” 的最大轮数(比如 patience=50:连续 50 轮没提升就停)
最小改进值(Min_delta)忽略微小波动(比如 min_delta=0.0001:损失下降小于这个值,视为 “无提升”)
最优模型保存训练中实时保存 “验证集性能最好” 的模型参数,避免停在最后一轮的差模型

作业:对信贷数据集训练后保存权重,加载权重后继续训练50轮,并采取早停策略

import numpy as np import pandas as pd from sklearn.model_selection import train_test_split from sklearn.preprocessing import OneHotEncoder from sklearn.impute import SimpleImputer from sklearn.neural_network import MLPRegressor from sklearn.metrics import mean_squared_error import joblib # 1. 数据加载与预处理 df = pd.read_csv(r"D:\Study\PythonStudy\housing.csv") #信贷数据集路径 target_col = 'median_house_value' #目标变量名 # 缺失值填充 imputer = SimpleImputer(strategy='most_frequent') df_imputed = pd.DataFrame(imputer.fit_transform(df), columns=df.columns) # 独热编码 ocean_proximity 列(唯一的类别型特征) categorical_col = 'ocean_proximity' ohe = OneHotEncoder(sparse_output=False) cat_ohe = ohe.fit_transform(df_imputed[[categorical_col]]) cat_df = pd.DataFrame(cat_ohe, columns=ohe.get_feature_names_out([categorical_col]), index=df_imputed.index) # 拼接回去,删掉原来的 ocean_proximity df_encoded = pd.concat([df_imputed.drop(columns=[categorical_col]), cat_df], axis=1) # 分割数据 X = df_encoded.drop(columns=[target_col]) y = df_encoded[target_col] X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 2. 初始训练、保存 mlp = MLPRegressor(hidden_layer_sizes=(64, 32), max_iter=100, random_state=42, warm_start=True) mlp.fit(X_train, y_train) joblib.dump(mlp, 'mlp_credit_model.pkl') print("初始模型保存完毕。") # 3. 加载权重 mlp2 = joblib.load('mlp_credit_model.pkl') # 4. 继续训练50轮,早停 best_loss = np.inf patience = 5 wait = 0 for i in range(50): mlp2.max_iter += 1 # 每次多训练一轮 mlp2.fit(X_train, y_train) y_pred = mlp2.predict(X_test) loss = mean_squared_error(y_test, y_pred) print(f"第{i+1}轮,测试集MSE: {loss:.4f}") if loss < best_loss: best_loss = loss wait = 0 joblib.dump(mlp2, 'mlp_credit_model_best.pkl') else: wait += 1 if wait >= patience: print("早停触发,训练提前终止。") break # 5. 加载最优模型 best_model = joblib.load('mlp_credit_model_best.pkl') print("最优模型已加载。")

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/7/3 4:00:17

终极指南:如何在VMware中免费解锁macOS虚拟机支持

终极指南&#xff1a;如何在VMware中免费解锁macOS虚拟机支持 【免费下载链接】unlocker VMware Workstation macOS 项目地址: https://gitcode.com/gh_mirrors/un/unlocker 你是否曾经想在Windows或Linux系统上体验macOS的流畅操作&#xff0c;却发现VMware Workstati…

作者头像 李华
网站建设 2026/7/4 4:54:19

Linux网络层核心技术揭秘: 从IP协议到内核实现深度剖析

Linux网络层核心技术揭秘: 从IP协议到内核实现深度剖析 在当今的互联网世界中, Linux凭借其稳定、高效的网络协议栈实现, 成为服务器、云计算和网络设备领域的基石. 理解Linux网络层的核心原理不仅有助于我们优化网络应用性能, 更能深入掌握现代网络通信的本质 1. 网络层的基础…

作者头像 李华
网站建设 2026/7/2 21:23:05

简单线程池实现(单例模式)

1.概念 基本概念 线程池是一种多线程处理形式&#xff0c;它预先创建一组线程并管理它们&#xff0c;避免频繁创建和销毁线程带来的性能开销。 在 Linux 环境下&#xff0c;线程池&#xff08;Thread Pool&#xff09;是一种常用的并发编程模型&#xff0c;用于复用线程资源&…

作者头像 李华
网站建设 2026/7/3 0:24:36

类与对象三大核心函数:构造、析构、拷贝构造详解

类与对象三大核心函数&#xff1a;构造、析构、拷贝构造详解 一、引言 在C面向对象编程中&#xff0c;构造函数、析构函数和拷贝构造函数被称为"三大件"&#xff08;Rule of Three&#xff09;。它们是类设计的基石&#xff0c;决定了对象的创建、拷贝和销毁行为。…

作者头像 李华
网站建设 2026/7/3 12:06:26

UiPath在金融行业的5个高价值应用案例

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容&#xff1a; 开发一个UiPath自动化流程&#xff0c;模拟银行对账单处理场景。流程应包括&#xff1a;1)自动登录网银系统下载对账单&#xff1b;2)使用OCR技术识别对账单内容&#xff1b;3)与内…

作者头像 李华
网站建设 2026/7/4 9:21:02

docker安装Qwen3-32B容器化方案提升运维效率

Docker安装Qwen3-32B容器化方案提升运维效率 在AI基础设施快速演进的今天&#xff0c;一个典型的技术团队可能正面临这样的困境&#xff1a;开发环境里流畅运行的大模型服务&#xff0c;一旦部署到生产集群就频频崩溃&#xff1b;不同版本的PyTorch、CUDA驱动和Python库相互冲突…

作者头像 李华