Getting a PyTorch model's parameter count and FLOPs
Define the model:
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(28 * 28, 400),
    nn.ReLU(),
    nn.Linear(400, 200),
    nn.ReLU(),
    nn.Linear(200, 100),
    nn.ReLU(),
    nn.Linear(100, 10)
).cuda()
1. Counting parameters with PyTorch itself
total = sum(param.nelement() for param in net.parameters())
print("Number of parameters: %.2fM" % (total / 1e6))
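The figure reported above can be checked by hand from the layer shapes alone: each nn.Linear(in_f, out_f) holds in_f * out_f weights plus out_f biases. A minimal sketch in plain Python, assuming only the layer sizes defined in the model above:

```python
# Per-layer parameter count of nn.Linear(in_f, out_f): in_f * out_f weights + out_f biases
layers = [(28 * 28, 400), (400, 200), (200, 100), (100, 10)]

total = sum(in_f * out_f + out_f for in_f, out_f in layers)
print(total)                     # 415310 parameters in total
print("%.2fM" % (total / 1e6))   # 0.42M, matching the net.parameters() sum
```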
2. Getting the parameter count with the torchsummary library
from torchsummary import summary
summary(net, input_size=(784,))
3. Getting the parameter count and FLOPs with the thop library
from thop import profile, clever_format

myinput = torch.zeros((1, 1, 784)).cuda()
# profile expects the inputs as a tuple, even for a single tensor
flops, params = profile(net, inputs=(myinput,))
flops, params = clever_format([flops, params], "%.3f")
print(flops, params)
Analysis:
The model has 415.31K parameters but only 414.6K FLOPs. The numbers differ because thop counts one multiply-accumulate per weight-matrix element for each Linear layer: the 710 bias parameters add nothing to the FLOP count, and layers such as ReLU involve no floating-point multiplies at all.
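The gap between the two numbers can be reproduced in plain Python under that counting rule (one MAC per weight element for a Linear layer, biases and ReLU excluded):

```python
layers = [(28 * 28, 400), (400, 200), (200, 100), (100, 10)]

# MACs under thop's Linear rule: one multiply-accumulate per weight element
macs = sum(in_f * out_f for in_f, out_f in layers)
# Total parameters include the bias vectors as well
params = sum(in_f * out_f + out_f for in_f, out_f in layers)

print(macs)           # 414600 -> the 414.6K reported by thop
print(params - macs)  # 710, exactly the number of bias terms (400+200+100+10)
```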