Getting intermediate results

The IntermediateLayerGetter utility makes it very convenient to grab intermediate results, for example the intermediate features of a resnet. This is useful because many models use a resnet as the backbone to extract features.

import torch
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter

input = torch.rand(size=(1, 3, 224, 224))

model = resnet18()
output = model(input)
print(output.shape)
# torch.Size([1, 1000])

'''extract layer1, layer2, layer3 and layer4, giving as names feat1, feat2, feat3 and feat4 respectively'''
return_layers = {'layer1':'feat1', 'layer2':'feat2', 'layer3':'feat3', 'layer4':'feat4'}
new_model = IntermediateLayerGetter(model, return_layers)
output = new_model(input)
for k, v in output.items():
    print(k, v.shape)
# feat1 torch.Size([1, 64, 56, 56])
# feat2 torch.Size([1, 128, 28, 28])
# feat3 torch.Size([1, 256, 14, 14])
# feat4 torch.Size([1, 512, 7, 7])

How IntermediateLayerGetter works

The source code of IntermediateLayerGetter is shown below:

Source: IntermediateLayerGetter(nn.ModuleDict), torchvision on GitHub

In its implementation, IntermediateLayerGetter rebuilds the model's OrderedDict by iterating over model.named_children(): it copies modules into a new OrderedDict until the modules for every key in return_layers have been added, and any modules remaining in the original model after that point are discarded. During forward, new_model also runs the modules in the order they appear in this OrderedDict. That is why the documentation of IntermediateLayerGetter stresses:

It has a strong assumption that the modules have been registered into the model in the same order as they are used.

Its return value is an OrderedDict mapping the user-specified names to the corresponding activations.

class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).

    Examples::

        >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>> [('feat1', torch.Size([1, 64, 56, 56])),
        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

Note

One extra caveat: IntermediateLayerGetter discards the unneeded trailing modules of the original model's OrderedDict, so when loading pretrained weights you must first edit the contents of the pretrained state_dict before load_state_dict() on the new model will succeed.

import torch
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter

input = torch.rand(size=(1, 3, 224, 224))

model = resnet18()
output = model(input)
print(output.shape)
# torch.Size([1, 1000])

'''extract layer1, layer2, layer3 and layer4, giving as names feat1, feat2, feat3, feat4 respectively'''
return_layers = {'layer1':'feat1', 'layer2':'feat2', 'layer3':'feat3', 'layer4':'feat4'}
new_model = IntermediateLayerGetter(model, return_layers)
output = new_model(input)
for k, v in output.items():
    print(k, v.shape)
# feat1 torch.Size([1, 64, 56, 56])
# feat2 torch.Size([1, 128, 28, 28])
# feat3 torch.Size([1, 256, 14, 14])
# feat4 torch.Size([1, 512, 7, 7])

from torch.utils import model_zoo

resnet18_url = "https://download.pytorch.org/models/resnet18-f37072fd.pth"
pretrained_dict = model_zoo.load_url(url=resnet18_url)
del pretrained_dict["fc.weight"], pretrained_dict["fc.bias"]
new_model.load_state_dict(pretrained_dict)

Of course, the most convenient approach is to call load_state_dict() first and only then apply IntermediateLayerGetter; this avoids mismatches between the pretrained state_dict and new_model.state_dict().