Pytorch
Module
1 |
|
以一个简单的MLP的代码示例:
1 |
|
1. 回调「Hook」函数介绍
hook 函数机制:不改变主体,实现额外功能,像一个挂件、挂钩 ➡️ hook
1.1 为什么会有 hook 函数这个机制:参考文章1
这与 PyTorch 动态图运行机制有关:
在动态图运行机制中,当运算结束后,一些中间变量是会被释放掉的,比如特征图、非叶子节点的梯度。但有时候我们又想要继续关注这些中间变量,那么就可以使用 hook 函数在主体代码中提取中间变量。
主体代码主要是模型的前向传播「forward」和反向传播「backward」,额外的功能就是对模型的中间变量进行操作如:
- 提取/修改张量梯度
- 提取/保留非叶子张量的梯度
- 查看模型的层与层之间的数据传递情况(数据维度、数据大小等)
- 在不修改原始模型代码的基础上可视化各个卷积特征图
- ……
1.2 演示Hook的作用:参考文章2
一般来说,“hook”是在特定事件之后自动执行的函数。
PyTorch 为nn.Module 对象 / 每个张量注册 hook。hook 由对象的向前或向后传播触发。它们具有以下函数签名:
1
2
3
4
5
6
7
8
from torch import nn, Tensor
# For nn.Module objects only.
def module_hook(module: nn.Module, input: Tensor, output: Tensor):
# For Tensor objects only.
# Only executed during the *backward* pass!
def tensor_hook(grad: Tensor):
例子1:假如你想知道每个层输出的形状。我们可以创建一个简单的 wrapper,使用 hook 打印输出形状:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40import torch
from torch import nn, Tensor
from torchvision.models import resnet50
import warnings
warnings.filterwarnings('ignore')
class VerboseExecution(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
# 传入resnet50模型
self.model = model
for name, layer in self.model.named_children():
layer.__name__ = name
# Register a hook for each layer
layer.register_forward_hook(
lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
)
def forward(self, x: Tensor) -> Tensor:
return self.model(x)
verbose_resnet = VerboseExecution(resnet50())
dummy_input = torch.ones(10, 3, 224, 224)
_ = verbose_resnet(dummy_input)
# --------输出
# conv1: torch.Size([10, 64, 112, 112])
# bn1: torch.Size([10, 64, 112, 112])
# relu: torch.Size([10, 64, 112, 112])
# maxpool: torch.Size([10, 64, 56, 56])
# layer1: torch.Size([10, 256, 56, 56])
# layer2: torch.Size([10, 512, 28, 28])
# layer3: torch.Size([10, 1024, 14, 14])
# layer4: torch.Size([10, 2048, 7, 7])
# avgpool: torch.Size([10, 2048, 1, 1])
# fc: torch.Size([10, 1000])例子2:特征提取:通常,我们希望从一个预先训练好的网络中生成特性,然后用它们来完成另一个任务(例如分类等)。使用 hook,我们可以提取特征,而不需要重新创建现有模型或以任何方式修改它。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38import torch
from torch import nn, Tensor
from torchvision.models import resnet50
import warnings
warnings.filterwarnings('ignore')
from typing import Dict, Iterable, Callable
class FeatureExtractor(nn.Module):
def __init__(self, model: nn.Module, layers: Iterable[str]):
super().__init__()
self.model = model
self.layers = layers
self._features = {layer: torch.empty(0) for layer in layers}
for layer_id in layers:
layer = dict([*self.model.named_modules()])[layer_id]
# Register a hook
layer.register_forward_hook(self.save_outputs_hook(layer_id))
def save_outputs_hook(self, layer_id: str) -> Callable:
def fn(_, __, output):
self._features[layer_id] = output
return fn
def forward(self, x: Tensor) -> Dict[str, Tensor]:
_ = self.model(x)
return self._features
resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
dummy_input = torch.ones(10, 3, 224, 224)
features = resnet_features(dummy_input)
print({name: output.shape for name, output in features.items()})
# 输出
# {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}例子3:梯度裁剪:梯度裁剪是处理梯度爆炸的一种著名方法。PyTorch 已经提供了梯度裁剪的工具方法,但是我们也可以很容易地使用 hook 来实现它。其他任何用于梯度裁剪/归一化/修改的方法都可以用同样的方式实现。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
for parameter in model.parameters():
# Register a hook for each parameter
# register_hook 方法的主要作用是允许用户在张量的梯度计算中注册一个自定义函数
# 以便在反向传播期间对梯度进行操作或记录信息。
# 这对于实现自定义梯度处理、梯度剪裁、可视化梯度信息以及梯度的修改等任务非常有用。
parameter.register_hook(lambda grad: grad.clamp_(-val, val))
return model
clipped_resnet = gradient_clipper(resnet50(), 0.01)
pred = clipped_resnet(dummy_input)
loss = pred.log().mean()
loss.backward()
print(clipped_resnet.fc.bias.grad[:25])
# 输出
# tensor([-0.0010, -0.0047, -0.0010, -0.0009, -0.0015, 0.0027, 0.0017, -0.0023,
# 0.0051, -0.0007, -0.0057, -0.0010, -0.0039, -0.0100, -0.0018, 0.0062,
# 0.0034, -0.0010, 0.0052, 0.0021, 0.0010, 0.0017, -0.0100, 0.0021,
# 0.0020])
1.3一些常见的Hook函数:
register_forward_hook
是 PyTorch 中用于在神经网络的前向传播过程中注册钩子的一个函数。这个钩子函数会在模块执行其forward
方法时被调用,可以用来检查或修改中间输出。module.register_forward_hook(hook)
:register_hook
是 PyTorch 中用于在神经网络的反向传播过程中注册钩子的函数。这个钩子函数会在张量的梯度计算过程中被调用,主要用于调试和修改梯度。tensor.register_hook(hook)
:回调「Hook」函数注册
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31# 三个全局变量,dict类型,存储回调函数(即hook),用于net中的所有module
# 用于输入输出tensor
_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
# 用于module定义
_global_module_registration_hooks: Dict[int, Callable] = OrderedDict()
# 用于模型参数
_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict()
"""
This tracks hooks common to all modules that are executed before/after
calling forward and backward. This is global state used for debugging/profiling
purposes
"""
# 用于在module的forward和backward接口前后注册回调函数,例如dump出每个op的输入输出结果
_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_backward_hooks: Dict[int, Callable] = OrderedDict()
_global_is_full_backward_hook: Optional[bool] = None
_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks: Dict[int, Callable] = OrderedDict()
# 提供reg接口在完成回调函数注册
register_module_buffer_registration_hook()
register_module_module_registration_hook()
register_module_parameter_registration_hook()
register_module_forward_pre_hook()
register_module_forward_hook()
register_module_backward_hook()
register_module_full_backward_pre_hook()
register_module_full_backward_hook()
1.4 Module成员变量分析
1 |
|