Pytorch

Module

1
2
Module作为模块封装的父类,可以是一段逻辑,也可以是模型的一个块「block」或一层。
Pytorch中自定义模型只需要继承Module,保存好Param并提供forward方法,backward被tensor的自动微分自动完成。

以一个简单的MLP的代码示例:

1
2
3
4
5
6
7
8
9
10
11
class MLP(nn.Module):
def __init__(self):
# Call the constructor of the parent class nn.Module to perform
# the necessary initialization
super().__init__()
self.hidden = nn.LazyLinear(256)
self.out = nn.LazyLinear(10)
# Define the forward propagation of the model, that is, how to return the
# required model output based on the input X
def forward(self, X):
return self.out(F.relu(self.hidden(X)))

1. 回调「Hook」函数介绍

hook 函数机制:不改变主体,实现额外功能,像一个挂件、挂钩 ➡️ hook

1.1 为什么会有 hook 函数这个机制:参考文章1

这与 PyTorch 动态图运行机制有关:

在动态图运行机制中,当运算结束后,一些中间变量是会被释放掉的,比如特征图、非叶子节点的梯度。但有时候我们又想要继续关注这些中间变量,那么就可以使用 hook 函数在主体代码中提取中间变量。

主体代码主要是模型的前向传播「forward」和反向传播「backward」,额外的功能就是对模型的中间变量进行操作如:

  1. 提取/修改张量梯度
  2. 提取/保留非叶子张量的梯度
  3. 查看模型的层与层之间的数据传递情况(数据维度、数据大小等)
  4. 在不修改原始模型代码的基础上可视化各个卷积特征图
  5. ……
1.2 演示Hook的作用参考文章2

一般来说,“hook”是在特定事件之后自动执行的函数。

PyTorch 为nn.Module 对象 / 每个张量注册 hook。hook 由对象的向前或向后传播触发。它们具有以下函数签名:

1
2
3
4
5
6
7
8
from torch import nn, Tensor

# For nn.Module objects only.
def module_hook(module: nn.Module, input: Tensor, output: Tensor):

# For Tensor objects only.
# Only executed during the *backward* pass!
def tensor_hook(grad: Tensor):
  • 例子1:假如你想知道每个层输出的形状。我们可以创建一个简单的 wrapper,使用 hook 打印输出形状:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    import torch
    from torch import nn, Tensor
    from torchvision.models import resnet50
    import warnings

    warnings.filterwarnings('ignore')

    class VerboseExecution(nn.Module):
    def __init__(self, model: nn.Module):
    super().__init__()
    # 传入resnet50模型
    self.model = model

    for name, layer in self.model.named_children():
    layer.__name__ = name
    # Register a hook for each layer
    layer.register_forward_hook(
    lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
    )

    def forward(self, x: Tensor) -> Tensor:
    return self.model(x)


    verbose_resnet = VerboseExecution(resnet50())
    dummy_input = torch.ones(10, 3, 224, 224)

    _ = verbose_resnet(dummy_input)

    # --------输出
    # conv1: torch.Size([10, 64, 112, 112])
    # bn1: torch.Size([10, 64, 112, 112])
    # relu: torch.Size([10, 64, 112, 112])
    # maxpool: torch.Size([10, 64, 56, 56])
    # layer1: torch.Size([10, 256, 56, 56])
    # layer2: torch.Size([10, 512, 28, 28])
    # layer3: torch.Size([10, 1024, 14, 14])
    # layer4: torch.Size([10, 2048, 7, 7])
    # avgpool: torch.Size([10, 2048, 1, 1])
    # fc: torch.Size([10, 1000])
  • 例子2特征提取:通常,我们希望从一个预先训练好的网络中生成特性,然后用它们来完成另一个任务(例如分类等)。使用 hook,我们可以提取特征,而不需要重新创建现有模型或以任何方式修改它。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    import torch
    from torch import nn, Tensor
    from torchvision.models import resnet50
    import warnings

    warnings.filterwarnings('ignore')

    from typing import Dict, Iterable, Callable

    class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
    super().__init__()
    self.model = model
    self.layers = layers
    self._features = {layer: torch.empty(0) for layer in layers}

    for layer_id in layers:
    layer = dict([*self.model.named_modules()])[layer_id]
    # Register a hook
    layer.register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
    def fn(_, __, output):
    self._features[layer_id] = output
    return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
    _ = self.model(x)
    return self._features

    resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
    dummy_input = torch.ones(10, 3, 224, 224)
    features = resnet_features(dummy_input)

    print({name: output.shape for name, output in features.items()})

    # 输出
    # {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}
  • 例子3:梯度裁剪:梯度裁剪是处理梯度爆炸的一种著名方法。PyTorch 已经提供了梯度裁剪的工具方法,但是我们也可以很容易地使用 hook 来实现它。其他任何用于梯度裁剪/归一化/修改的方法都可以用同样的方式实现。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
    for parameter in model.parameters():
    # Register a hook for each parameter

    # register_hook 方法的主要作用是允许用户在张量的梯度计算中注册一个自定义函数
    # 以便在反向传播期间对梯度进行操作或记录信息。
    # 这对于实现自定义梯度处理、梯度剪裁、可视化梯度信息以及梯度的修改等任务非常有用。
    parameter.register_hook(lambda grad: grad.clamp_(-val, val))

    return model

    clipped_resnet = gradient_clipper(resnet50(), 0.01)
    pred = clipped_resnet(dummy_input)
    loss = pred.log().mean()
    loss.backward()

    print(clipped_resnet.fc.bias.grad[:25])

    # 输出
    # tensor([-0.0010, -0.0047, -0.0010, -0.0009, -0.0015, 0.0027, 0.0017, -0.0023,
    # 0.0051, -0.0007, -0.0057, -0.0010, -0.0039, -0.0100, -0.0018, 0.0062,
    # 0.0034, -0.0010, 0.0052, 0.0021, 0.0010, 0.0017, -0.0100, 0.0021,
    # 0.0020])
1.3一些常见的Hook函数
  • register_forward_hook 是 PyTorch 中用于在神经网络的前向传播过程中注册钩子的一个函数。这个钩子函数会在模块执行其 forward 方法时被调用,可以用来检查或修改中间输出。

    module.register_forward_hook(hook):

    register_forward_hook

  • register_hook 是 PyTorch 中用于在神经网络的反向传播过程中注册钩子的函数。这个钩子函数会在张量的梯度计算过程中被调用,主要用于调试和修改梯度。

    tensor.register_hook(hook):

    register_hook

  • 回调「Hook」函数注册

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    # 三个全局变量,dict类型,存储回调函数(即hook),用于net中的所有module
    # 用于输入输出tensor
    _global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
    # 用于module定义
    _global_module_registration_hooks: Dict[int, Callable] = OrderedDict()
    # 用于模型参数
    _global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict()


    """
    This tracks hooks common to all modules that are executed before/after
    calling forward and backward. This is global state used for debugging/profiling
    purposes
    """
    # 用于在module的forward和backward接口前后注册回调函数,例如dump出每个op的输入输出结果
    _global_backward_pre_hooks: Dict[int, Callable] = OrderedDict()
    _global_backward_hooks: Dict[int, Callable] = OrderedDict()
    _global_is_full_backward_hook: Optional[bool] = None
    _global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
    _global_forward_hooks: Dict[int, Callable] = OrderedDict()

    # 提供reg接口在完成回调函数注册
    register_module_buffer_registration_hook()
    register_module_module_registration_hook()
    register_module_parameter_registration_hook()

    register_module_forward_pre_hook()
    register_module_forward_hook()
    register_module_backward_hook()
    register_module_full_backward_pre_hook()
    register_module_full_backward_hook()
1.4 Module成员变量分析
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 版本号,一个内部使用的属性,用于跟踪模块的版本。这一机制主要用于在序列化(serialization)和反序列化(deserialization)过程中管理模型的兼容性。
_version: int = 1
# 一个布尔值,表示模块是否处于训练模式。可以通过 model.train() 和 model.eval() 方法切换。
training: bool
# 存储模块的所有参数(Parameter 对象),类型为 OrderedDict,如conv的weight、bias等
_parameters: Dict[str, Optional[Parameter]]
# 存储模块中的所有缓冲区(Tensor 对象),类型为 OrderedDict。缓冲区是模型状态的一部分,但不是参数,比如 BatchNorm 的running mean 和 running variance。
_buffers: Dict[str, Optional[Tensor]]
# 存储模块的子模块,类型为 OrderedDict。每个子模块在模型中都有一个唯一的名称。
_modules: Dict[str, Optional['Module']]


# 存储反向传播前的钩子,类型为 OrderedDict。这些钩子在反向传播前的过程中被调用。
_backward_pre_hooks: Dict[int, Callable]
# 存储反向传播钩子,类型为 OrderedDict。这些钩子在反向传播过程中被调用。
_backward_hooks: Dict[int, Callable]
# 存储前向传播钩子,类型为 OrderedDict。这些钩子在前向传播过程中被调用。
_forward_hooks: Dict[int, Callable]


# 存储 state_dict 钩子,类型为 OrderedDict。这些钩子在调用 state_dict 时被调用。
_state_dict_hooks: Dict[int, Callable] # 模型加载时,op的参数加载相关的回调函数
_load_state_dict_pre_hooks: Dict[int, Callable]
_state_dict_pre_hooks: Dict[int, Callable]
_load_state_dict_post_hooks: Dict[int, Callable]

1.5 Module方法分析

Pytorch
https://adzuki23.github.io/2024/05/22/Pytorch/
作者
Hongyu Li
发布于
2024年5月22日
更新于
2024年8月3日
许可协议