1
- ## [ torch 参数更多 ] torch.nn.modules.module.register_module_forward_hook
1
+ ## [ 组合替代实现 ] torch.nn.modules.module.register_module_forward_hook
2
2
### [ torch.nn.modules.module.register_module_forward_hook] ( https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html )
3
3
4
4
``` python
5
- torch.nn.modules.module.register_module_forward_hook(hook, * , prepend = False , with_kwargs = False , always_call = False )
5
+ torch.nn.modules.module.register_module_forward_hook(hook, * , always_call = False )
6
6
```
7
7
8
8
### [ paddle.nn.Layer.register_forward_post_hook] ( https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/nn/Layer_cn.html#register-forward-post-hook-hook )
@@ -11,12 +11,28 @@ torch.nn.modules.module.register_module_forward_hook(hook, *, prepend=False, wit
11
11
paddle.nn.Layer.register_forward_post_hook(hook)
12
12
```
13
13
14
- PyTorch 相比 Paddle 支持更多其他参数,具体如下:
14
+ 其中,PyTorch 为给全局所有 module 注册 hook,而 Paddle 为给单个 Layer 注册 hook。 PyTorch 相比 Paddle 支持更多其他参数,具体如下:
15
15
### 参数映射
16
16
17
17
| PyTorch | PaddlePaddle | 备注 |
18
18
| ------------- | ------------ | ------------------------------------------------------ |
19
- | hook | hook | 被注册为 forward pre-hook 的函数。 |
20
- | prepend | - | 钩子执行顺序控制,Paddle 无此参数,暂无转写方式。 |
21
- | with_kwargs | - | 是否传递关键字参数,Paddle 无此参数,暂无转写方式。 |
22
- | always_call | - | 是否强制调用钩子,Paddle 无此参数,暂无转写方式。 |
19
+ | hook | hook | 被注册为 forward post-hook 的函数。 |
20
+ | always_call | - | 是否强制调用钩子,Paddle 无此参数,一般对训练结果影响不大,可直接删除。 |
21
+
22
+ ### 转写示例
23
+
24
+ ``` python
25
+ # PyTorch 写法
26
+ Linear = torch.nn.Linear(2 , 4 )
27
+ Conv2d = torch.nn.Conv2d(3 , 16 , 3 )
28
+ Batch2d = torch.nn.BatchNorm2d(10 )
29
+ torch.nn.modules.module.register_module_forward_hook(hook)
30
+
31
+ # Paddle 写法
32
+ Linear = paddle.nn.Linear(2 , 4 )
33
+ Conv2d = paddle.nn.Conv2d(3 , 16 , 3 )
34
+ Batch2d = paddle.nn.BatchNorm2D(10 )
35
+ Linear.register_forward_post_hook(hook)
36
+ Conv2d.register_forward_post_hook(hook)
37
+ Batch2d.register_forward_post_hook(hook)
38
+ ```
0 commit comments