PyTorch知识点

本文最后更新于 2025年8月7日 晚上

本文主要分享了PyTorch的基础知识。

管理模型的行为

  • 训练模式model.train()
  • 评估模式:model.eval()
  • 禁用梯度计算torch.no_grad()
  • 启用梯度计算torch.enable_grad()
  • 冻结或解冻参数model.requires_grad_(True/False)
  • 切换设备model.to()
  • 半精度模式model.half()torch.cuda.amp.autocast()
  • 分层模式切换:对特定层使用 train()eval()

保存模型

保存模型的命名

.pt.pth 都是 PyTorch 预训练模型常见的文件后缀,实际在 PyTorch 官方和社区中没有严格的技术区别,只是约定俗成的命名习惯:

  • .pt 后缀

    • 全称:PyTorch

    • 常见用法:通常用于保存完整的模型(包括结构和参数),也用于保存模型参数(state_dict)。

    • 示例torch.save(model, 'model.pt')torch.save(model.state_dict(), 'model.pt')

    • 官方文档:PyTorch 官方示例常用 .pt 作为后缀。

  • .pth 后缀

    • 全称:PyTorch

    • 常见用法:更早期 PyTorch 社区中流行的后缀,经常用于保存模型参数字典(state_dict)。

    • 示例torch.save(model.state_dict(), 'model.pth')

    • 很多第三方项目、开源项目常用 .pth

保存模型的位置

  1. 在训练完成后保存模型。当模型训练完成后,可以保存最终的完整模型。这样可以在推理或部署时直接加载完整的模型,而无需重新定义网络结构。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    # 假设 model 是你的神经网络
    model = MyModel()

    # 定义损失函数和优化器
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # 模拟训练循环
    for epoch in range(num_epochs):
    for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # 训练结束后保存完整模型
    torch.save(model, "model_complete.pt")
  2. 在中断训练前保存(检查点保存)。如果训练可能会中断(如在分布式训练中),你可以定期保存模型,以便从中断处恢复。

    1
    2
    # 保存模型检查点
    torch.save(model, f"checkpoint_epoch_{epoch}.pt")
  3. 在验证性能后保存最佳模型。如果你在训练过程中验证模型性能,可以设置一个条件,只保存性能最好的模型。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    best_accuracy = 0.0

    for epoch in range(num_epochs):
    for inputs, labels in train_loader:
    # 训练逻辑
    pass

    # 验证模型性能
    accuracy = validate(model, val_loader)

    # 如果当前模型表现更好,则保存
    if accuracy > best_accuracy:
    best_accuracy = accuracy
    torch.save(model, "best_model.pt")
  4. 在模型导出或部署前保存。训练完成后,如果需要将模型导出到生产环境,通常在推理或部署前保存完整模型。

    1
    2
    3
    4
    5
    # 切换到评估模式
    model.eval()

    # 保存模型用于生产部署
    torch.save(model, "model_for_deployment.pt")
  5. 等等。

保存模型的方式

PyTorch 保存模型通常以下几种方式:

方法 优点 缺点 使用场景
保存权重 (state_dict) 文件小,灵活性高 需要重新定义网络结构 训练、部署、迁移
保存整个模型 加载简单,包含网络结构和权重 文件较大,可能因环境变化导致加载失败 快速加载模型进行推理
TorchScript 独立于 Python 环境,适合 C++ 部署 不支持某些动态操作 部署到生产环境或 C++ 使用
保存优化器状态 可恢复训练 仅适用于训练场景 训练中断后继续训练
  1. 只保存模型的状态字典 (state_dict)

    1
    2
    3
    4
    5
    6
    7
    # 保存
    torch.save(model.state_dict(), "model_weights.pth") # 这种方式只保存了模型的权重参数,而不包含网络结构。加载时需要重新定义网络结构。

    # 加载
    model = MyModel() # 重新定义网络结构
    model.load_state_dict(torch.load("model_weights.pth"))
    model.eval() # 切换到推理模式

    (可选)在C++中加载:

    1
    2
    3
    #include <torch/torch.h>

    torch::load(model, "model_weights.pth");
  2. 保存整个模型

    1
    2
    3
    4
    5
    6
    # 保存
    torch.save(model, "model_complete.pt")

    # 加载
    model = torch.load("model_complete.pt")
    model.eval() # 切换到推理模式

    保存完整模型时,由于网络结构直接序列化为 Python 对象,跨环境加载时可能会出现问题(如类定义路径不同)。

  3. 使用 TorchScript 脚本化模型。TorchScript 是 PyTorch 提供的一种将模型转换为静态图的方法,适用于部署场景。

    1
    2
    3
    4
    5
    6
    7
    # 保存
    scripted_model = torch.jit.script(model) # 或 torch.jit.trace(model, example_input)
    scripted_model.save("model_scripted.pt")

    # 加载
    model = torch.jit.load("model_scripted.pt")
    model.eval() # 切换到推理模式

    (可选)在C++中加载:

    1
    2
    3
    4
    #include <torch/torch.h>
    #include <torch/script.h> // 用于 torch::jit::load

    torch::jit::script::Module module = torch::jit::load("model_scripted.pt");

    torch.jit.script()torch.jit.trace() 是 PyTorch 中将模型转换为 TorchScript 的两种方法。两者在原理和适用场景上有很大的区别。

    特性 torch.jit.script() torch.jit.trace()
    工作原理 基于代码解析,捕获模型的完整逻辑 基于跟踪,通过示例输入记录计算图
    动态控制流支持 支持(包括 if-else、循环等动态逻辑) 不支持(只能捕获固定的计算逻辑)
    依赖输入数据 不依赖 依赖具体的示例输入
    模型适用性 适用于动态和静态模型 仅适用于静态模型
    复杂模型支持 支持复杂的动态模型 对动态模型支持有限
    使用难度 需要检查代码是否可解析 简单,只需提供示例输入
    输出 TorchScript 的行为 静态图,完全保留模型的动态行为 静态图,固定行为,不能动态调整
  4. 保存优化器状态。在训练过程中,可以同时保存优化器的状态,以便在中断后继续训练。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    # 保存
    torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    }, "checkpoint.pth")

    # 加载
    checkpoint = torch.load("checkpoint.pth")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  5. 转ONNX

  6. 等等。

保存模型的内容

  1. 使用 state_dict 保存的权重文件

    • 模型参数: 权重文件中保存的是一个字典(state_dict),其键为模型中每一层的名称(如 layer1.weight),值为相应的张量(torch.Tensor)。

    • 格式

      • 基于 PyTorch 的序列化格式,底层使用 Python 的 pickle 序列化。
      • 文件本质是一个二进制文件。
    • 示例:

      1
      2
      3
      4
      5
      6
      {
      "conv1.weight": tensor([[...], [...]]), # 卷积层的权重
      "conv1.bias": tensor([...]), # 卷积层的偏置
      "fc.weight": tensor([[...], [...]]), # 全连接层的权重
      "fc.bias": tensor([...]) # 全连接层的偏置
      }
    • 等等。

  2. 保存完整模型

    • 网络结构: 包括模型的类(nn.Module)定义和网络的具体拓扑结构。
    • 模型参数: 包括网络中每一层的权重和偏置。
    • 格式
      • 同样基于 PyTorch 的 pickle 序列化。
      • 文件本质是一个二进制文件。
  3. 使用 TorchScript 保存的模型

    • 网络结构: 使用 TorchScript 表示网络结构为静态计算图。
    • 模型参数: 包括网络中每一层的权重和偏置。
    • 格式
      • 文件以 TorchScript 的序列化格式保存。
      • 支持跨环境加载(如在 C++ 中使用)。
  4. 等等。

推理

常见的推理流程

在 PyTorch 中,模型的权重(如 .pt 文件)通常只包含网络的参数(权重和偏置),而不包含模型的网络结构定义。因此,在加载权重并进行推理之前,必须重新定义模型的架构,以便 PyTorch 知道如何将权重加载到对应的网络结构中。

  1. 定义模型的网络结构。在 PyTorch 中,网络结构通常是一个继承自 torch.nn.Module 的类。例如:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    import torch
    import torch.nn as nn

    class MyModel(nn.Module):
    def __init__(self):
    super(MyModel, self).__init__()
    self.fc = nn.Linear(10, 2)

    def forward(self, x):
    return self.fc(x)
  2. 加载预训练的权重。使用 torch.loadmodel.load_state_dict 方法加载保存的权重到定义好的模型中:

    1
    2
    3
    model = MyModel()  # 定义网络结构
    model.load_state_dict(torch.load("model_weights.pt")) # 加载权重
    model.eval() # 设置为推理模式
  3. 进行推理。将输入传入模型,进行前向传播以获得推理结果:

    1
    2
    3
    input_tensor = torch.randn(1, 10)  # 示例输入
    output = model(input_tensor)
    print(output)
  4. 完成。

判断.pt文件中是否包含网络结构

要判断 .pt 文件中是否包含网络结构,可以尝试加载文件并观察行为:

  1. 使用 torch.load 尝试直接加载模型

    1
    2
    3
    4
    5
    6
    7
    8
    import torch

    model = torch.load("model.pt") # 尝试直接加载
    print(model) # 如果文件中包含网络结构,torch.load 将直接返回一个完整的模型对象。
    if isinstance(model, dict):
    print("该文件只包含权重参数(state_dict)")
    else:
    print("该文件包含完整的模型(网络结构和权重)")
    • 如果成功加载并可以直接推理,说明文件中包含了网络结构。
    • 如果报错或返回的是一个字典(state_dict),说明文件中只包含权重而没有网络结构。
  2. 使用 model.load_state_dict 尝试加载权重

    如果你知道模型的网络结构,可以先定义模型,然后尝试加载权重:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    import torch
    import torch.nn as nn

    class MyModel(nn.Module):
    def __init__(self):
    super(MyModel, self).__init__()
    self.fc = nn.Linear(10, 2)

    def forward(self, x):
    return self.fc(x)

    model = MyModel() # 定义网络结构
    state_dict = torch.load("model.pt") # 尝试加载
    model.load_state_dict(state_dict) # 尝试将权重加载到定义的模型中
    • 如果 .pt 文件中只包含权重,这种方式会正常工作。
  3. 等等。

如果没有网络结构定义怎么办?

  1. 使用 PyTorch 的脚本化模型或 TorchScript。如果模型在保存时已经被脚本化(TorchScript),那么 .pt 文件中不仅包含权重,还包含网络结构。可以直接加载脚本化模型并进行推理:

    1
    2
    3
    4
    scripted_model = torch.jit.load("scripted_model.pt")
    scripted_model.eval()
    output = scripted_model(torch.randn(1, 10))
    print(output)

    如果模型不是脚本化的,你需要找到模型的代码定义(通常是一个 Python 文件或类)。如果没有现成定义,可以尝试从文档或原始开发者处获取。

  2. 使用 ONNX。如果模型已经被导出为 ONNX 格式,则可以绕过 PyTorch,直接使用 ONNX Runtime 进行推理,无需重新定义网络结构。

只有(/使用).pt文件中的权重并自定义网络结构

基本流程

  1. 加载权重文件。即使 .pt 文件中包含完整的模型(网络结构和权重),你仍然可以加载文件并从中提取权重(state_dict)。

    1
    2
    3
    4
    5
    6
    7
    import torch

    # 加载包含网络结构和权重的模型
    saved_model = torch.load("model.pt")

    # 提取权重(state_dict)
    state_dict = saved_model.state_dict()
  2. 自定义网络结构。你需要在代码中定义你希望使用的自定义网络结构。确保网络结构的参数名称和形状与提取的权重匹配。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    import torch.nn as nn

    class MyCustomModel(nn.Module):
    def __init__(self):
    super(MyCustomModel, self).__init__()
    self.fc = nn.Linear(10, 2) # 根据你的需求定义结构

    def forward(self, x):
    return self.fc(x)

    # 实例化你的自定义模型
    custom_model = MyCustomModel()
  3. 加载权重到自定义网络结构。使用 load_state_dict 方法将提取的权重加载到自定义的网络结构中。如果权重名称和形状不完全匹配,你可以设置 strict=False 来跳过不匹配的部分(通常会打印相关警告)。

    1
    2
    # 加载权重
    custom_model.load_state_dict(state_dict, strict=False)
  4. 切换到推理模式并进行测试。在加载权重后,切换模型到推理模式,并进行测试。

    1
    2
    3
    4
    5
    6
    7
    8
    custom_model.eval()  # 切换到推理模式

    # 创建测试输入
    input_tensor = torch.randn(1, 10) # 假设输入形状是 [1, 10]

    # 推理
    output = custom_model(input_tensor)
    print("推理输出:", output)
  5. (可选)保存模型

  6. 完成。

注意事项

  1. 权重匹配问题:自定义网络结构的参数名称和形状必须与原始模型的权重一致。如果不一致,你需要手动修改网络结构或转换权重的名称。例如,如果原始模型的层名称是 layer1.weight,但你的自定义模型中是 fc.weight,需要映射它们。示例:

    1
    2
    3
    4
    5
    6
    7
    8
    # 修改 state_dict 的键名
    new_state_dict = {}
    for key, value in state_dict.items():
    new_key = key.replace("layer1", "fc") # 根据实际情况调整
    new_state_dict[new_key] = value

    # 加载修改后的权重
    custom_model.load_state_dict(new_state_dict, strict=False)
  2. 参数形状问题:如果自定义网络结构的形状和原始模型不一致(例如,层的输入/输出大小不同),需要手动调整网络代码或重新训练部分参数。

  3. 删除原始模型的网络结构:如果你希望 .pt 文件中只包含权重,可以重新保存提取的 state_dict

    1
    2
    # 保存权重为新的文件
    torch.save(state_dict, "model_weights_only.pt")
  4. 等等。

网络结构可以有区别的情况

虽然权重文件中的参数是固定的,但在以下情况下,自定义网络结构可以与原始网络结构有一定的区别:

  1. 网络中新增的部分不加载权重。如果你在网络中新增了某些层,而这些层不需要加载预训练权重,PyTorch 会忽略这些层,并只加载权重文件中匹配的部分。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    class CustomModelWithExtraLayer(nn.Module):
    def __init__(self):
    super(CustomModelWithExtraLayer, self).__init__()
    self.fc = nn.Linear(10, 2) # 与原始网络一致
    self.extra_layer = nn.Linear(2, 3) # 新增的层。extra_layer 的权重不会从文件中加载,而是会被随机初始化。

    def forward(self, x):
    x = self.fc(x)
    return self.extra_layer(x) # 使用新层

    # 加载权重(strict=False 会忽略不匹配的部分)
    model = CustomModelWithExtraLayer()
    model.load_state_dict(torch.load("model_weights.pt"), strict=False)
  2. 对部分参数重命名或调整。如果你需要加载的网络结构与训练时的网络结构有一些参数名称不同,可以手动修改权重文件的 state_dict,以适配自定义网络结构。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    # 假设原始权重文件的参数名称是 "fc.weight" 和 "fc.bias"
    state_dict = torch.load("model_weights.pt")

    # 修改参数名称为自定义网络结构的名称
    new_state_dict = {}
    for key, value in state_dict.items():
    new_key = key.replace("fc", "custom_fc") # 假设新网络结构的名称是 "custom_fc"
    new_state_dict[new_key] = value

    # 自定义网络结构
    class CustomModel(nn.Module):
    def __init__(self):
    super(CustomModel, self).__init__()
    self.custom_fc = nn.Linear(10, 2)

    def forward(self, x):
    return self.custom_fc(x)

    # 加载权重
    model = CustomModel()
    model.load_state_dict(new_state_dict)
  3. 部分层的形状可以不匹配(微调)。如果你希望调整某些层(例如改变输出类别数目),可以通过加载匹配的部分权重,然后对新层进行随机初始化或微调。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    class CustomModel(nn.Module):
    def __init__(self):
    super(CustomModel, self).__init__()
    self.fc = nn.Linear(10, 5) # 假设原始网络输出为 2,这里调整为 5

    def forward(self, x):
    return self.fc(x)

    # 加载权重时忽略形状不匹配的部分
    model = CustomModel()
    pretrained_dict = torch.load("model_weights.pt")

    # 过滤掉不匹配的参数
    # fc 层(模型最后一层)的权重会被随机初始化,而其他层的权重会从文件中加载。
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}

    # 更新现有模型的权重
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
  4. 等等。


PyTorch知识点
http://zeyulong.com/posts/ea0c57f4/
作者
龙泽雨
发布于
2025年8月7日
许可协议