-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path utils.py
35 lines (26 loc) · 882 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""Utility functions for model analysis."""
import torch
def param_matrix(model):
    """Print the size of every parameter tensor of a model.

    One line is printed per tensor, in the order ``model.parameters()``
    yields them. Nothing is returned.

    Args:
        model: neural network built with pytorch.
    """
    # Walk the parameter generator a single time instead of indexing into a
    # freshly built list on each iteration (which made this quadratic).
    for param in model.parameters():
        print(param.size())
def total_num_param(model):
    """Print the total count of scalar entries across all model parameters.

    The count is the sum of ``numel()`` over every tensor yielded by
    ``model.parameters()``; a model with no parameters prints ``0``.

    Args:
        model: neural network built with pytorch.
    """
    element_count = sum(tensor.numel() for tensor in model.parameters())
    print(element_count)
def param_trainable(model):
    """Print name, values and shape of each parameter that requires grad.

    Parameters whose ``requires_grad`` flag is False are skipped. For every
    remaining parameter, four lines are printed: its name, its ``.data``
    tensor, its ``size()``, and a separator line of ``=`` characters.

    Args:
        model: neural network built with pytorch.
    """
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue  # frozen parameter: nothing to report
        # One print with newline separators emits exactly the same text as
        # four consecutive print() calls.
        print(
            param_name,
            tensor.data,
            tensor.size(),
            "===============================================",
            sep="\n",
        )