PyTorch 打印网络模型结构
·
🤵 Author :Horizon John
✨ 编程技巧篇:各种操作小结
🎇 机器视觉篇:会变魔术 OpenCV
💥 深度学习篇:简单入门 PyTorch
🏆 神经网络篇:经典网络模型
💻 算法篇:再忙也别忘了 LeetCode
PyTorch 打印网络模型结构
使用 Print() 函数打印网络
我们在使用PyTorch打印模型结构时都是这样操作的:
model = simpleNet()
print(model)
打印结果:
simpleNet(
(layer1): Sequential(
(0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): ReLU()
)
(layer2): Sequential(
(0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): ReLU()
)
(layer3): Sequential(
(0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): ReLU()
)
(dropout): Dropout(p=0.5, inplace=False)
(fc): Linear(in_features=1024, out_features=10, bias=True)
(out): Linear(in_features=10, out_features=10, bias=True)
)
可以很容易发现这样打印出来的网络结构 不清晰
,参数看起来都很 乱
!
如果是一个简单一点的网络可能影响不是很大,但当随着网络层数加深、结构复杂、参数量变大时,就会看得很难受 !
Tensorflow / Keras 打印网络
使用 model.summary() 函数打印出网络结构:
model = MyNet()
model.summary()
对比上面可以看到网络结构 很清晰
:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_4 (Dense) (None, 256) 25856
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 256) 0
_________________________________________________________________
batch_normalization_1 (Batch (None, 256) 1024
_________________________________________________________________
dense_5 (Dense) (None, 512) 131584
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 512) 0
_________________________________________________________________
batch_normalization_2 (Batch (None, 512) 2048
_________________________________________________________________
dense_6 (Dense) (None, 1024) 525312
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 1024) 0
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024) 4096
_________________________________________________________________
dense_7 (Dense) (None, 784) 803600
_________________________________________________________________
reshape_1 (Reshape) (None, 28, 28, 1) 0
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584
_________________________________________________________________
PyTorch summary打印网络结构的方法
首先需要安装一个库文件 torchinfo
pip install torchinfo
conda install -c conda-forge torchinfo
然后使用 summary
函数打印网络结构:
model = simpleNet()
batch_size = 64
summary(model, input_size=(batch_size, 3, 32, 32))
网络结构输出结果如下:
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
simpleNet -- --
├─Sequential: 1-1 [64, 16, 16, 16] --
│ └─Conv2d: 2-1 [64, 16, 32, 32] 448
│ └─BatchNorm2d: 2-2 [64, 16, 32, 32] 32
│ └─MaxPool2d: 2-3 [64, 16, 16, 16] --
│ └─ReLU: 2-4 [64, 16, 16, 16] --
├─Sequential: 1-2 [64, 32, 8, 8] --
│ └─Conv2d: 2-5 [64, 32, 16, 16] 4,640
│ └─BatchNorm2d: 2-6 [64, 32, 16, 16] 64
│ └─MaxPool2d: 2-7 [64, 32, 8, 8] --
│ └─ReLU: 2-8 [64, 32, 8, 8] --
├─Sequential: 1-3 [64, 64, 4, 4] --
│ └─Conv2d: 2-9 [64, 64, 8, 8] 18,496
│ └─BatchNorm2d: 2-10 [64, 64, 8, 8] 128
│ └─MaxPool2d: 2-11 [64, 64, 4, 4] --
│ └─ReLU: 2-12 [64, 64, 4, 4] --
├─Dropout: 1-4 [64, 1024] --
├─Linear: 1-5 [64, 10] 10,250
├─Linear: 1-6 [64, 10] 110
==========================================================================================
Total params: 34,168
Trainable params: 34,168
Non-trainable params: 0
Total mult-adds (M): 181.82
==========================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 29.37
Params size (MB): 0.14
Estimated Total Size (MB): 30.29
==========================================================================================
更多详情可以参考 github 源码:torchinfo
阅读全文
AI总结
更多推荐
相关推荐
查看更多
A2A

谷歌开源首个标准智能体交互协议Agent2Agent Protocol(A2A)
adk-python

一款开源、代码优先的Python工具包,用于构建、评估和部署灵活可控的复杂 AI agents
Second-Me

开源 AI 身份系统,通过本地训练和部署,模仿用户思维和学习风格,创建专属AI替身,保护隐私安全。
目录
所有评论(0)