Python tensorflow與pytorch的浮點(diǎn)運(yùn)算數(shù)如何計(jì)算
1. 引言
FLOPs 是 floating point operations 的縮寫,指浮點(diǎn)運(yùn)算數(shù),可以用來衡量模型/算法的計(jì)算復(fù)雜度。本文主要討論如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相關(guān)工具計(jì)算對(duì)應(yīng)模型的 FLOPs。
2. 模型結(jié)構(gòu)
為了說明方便,先搭建一個(gè)簡單的神經(jīng)網(wǎng)絡(luò)模型,其模型結(jié)構(gòu)以及主要參數(shù)如表1 所示。
表 1 模型結(jié)構(gòu)及主要參數(shù)
Layers | channels | Kernels | Strides | Units | Activation |
---|---|---|---|---|---|
Conv2D | 32 | (4,4) | (1,2) | \ | relu |
GRU | \ | \ | \ | 96 | \ |
Dense | \ | \ | \ | 256 | sigmoid |
用 tensorflow(實(shí)際使用 tensorflow 中的 keras 模塊)實(shí)現(xiàn)該模型的代碼為:
from tensorflow.keras.layers import * from tensorflow.keras.models import load_model, Model def test_model_tf(Input_shape): # shape: [B, C, T, F] main_input = Input(batch_shape=Input_shape, name='main_inputs') conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation='relu', data_format='channels_first', name='conv')(main_input) # shape: [B, T, FC] gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv) gru = GRU(units=96, reset_after=True, return_sequences=True, name='gru')(gru) output = Dense(256, activation='sigmoid', name='output')(gru) model = Model(inputs=[main_input], outputs=[output]) return model
用 pytorch 實(shí)現(xiàn)該模型的代碼為:
import torch import torch.nn as nn class test_model_torch(nn.Module): def __init__(self): super(test_model_torch, self).__init__() self.conv2d = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4,4), stride=(1,2)) self.relu = nn.ReLU() self.gru = nn.GRU(input_size=4064, hidden_size=96) self.fc = nn.Linear(96, 256) self.sigmoid = nn.Sigmoid() def forward(self, inputs): # shape: [B, C, T, F] out = self.conv2d(inputs) out = self.relu(out) # shape: [B, T, FC] batch, channel, frame, freq = out.size() out = torch.reshape(out, (batch, frame, freq*channel)) out, _ = self.gru(out) out = self.fc(out) out = self.sigmoid(out) return out
3. 計(jì)算模型的 FLOPs
本節(jié)討論的版本具體為:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102。
3.1. tensorflow 1.12.0
在 tensorflow 1.12.0 環(huán)境中,可以使用以下代碼計(jì)算模型的 FLOPs:
import tensorflow as tf import tensorflow.keras.backend as K def get_flops(model): run_meta = tf.RunMetadata() opts = tf.profiler.ProfileOptionBuilder.float_operation() flops = tf.profiler.profile(graph=K.get_session().graph, run_meta=run_meta, cmd='op', options=opts) return flops.total_float_ops if __name__ == "__main__": x = K.random_normal(shape=(1, 1, 100, 256)) model = test_model_tf(x.shape) print('FLOPs of tensorflow 1.12.0:', get_flops(model))
3.2. tensorflow 2.3.1
在 tensorflow 2.3.1 環(huán)境中,可以使用以下代碼計(jì)算模型的 FLOPs :
import tensorflow.compat.v1 as tf import tensorflow.compat.v1.keras.backend as K tf.disable_eager_execution() def get_flops(model): run_meta = tf.RunMetadata() opts = tf.profiler.ProfileOptionBuilder.float_operation() flops = tf.profiler.profile(graph=K.get_session().graph, run_meta=run_meta, cmd='op', options=opts) return flops.total_float_ops if __name__ == "__main__": x = K.random_normal(shape=(1, 1, 100, 256)) model = test_model_tf(x.shape) print('FLOPs of tensorflow 2.3.1:', get_flops(model))
3.3. pytorch 1.10.1+cu102
在 pytorch 1.10.1+cu102 環(huán)境中,可以使用以下代碼計(jì)算模型的 FLOPs(需要安裝 thop):
import thop x = torch.randn(1, 1, 100, 256) model = test_model_torch() flops, _ = thop.profile(model, inputs=(x,)) print('FLOPs of pytorch 1.10.1:', flops * 2)
需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代碼有乘 2 2 2 操作。
3.4. 結(jié)果對(duì)比
三者計(jì)算出的 FLOPs 分別為:
tensorflow 1.12.0:
tensorflow 2.3.1:
pytorch 1.10.1:
可以看到 tensorflow 1.12.0 和 tensorflow 2.3.1 的結(jié)果基本在同一個(gè)量級(jí),而與 pytorch 1.10.1 計(jì)算出來的相差甚遠(yuǎn)。但如果將上述模型結(jié)構(gòu)改為只包含第一層 Conv2D,三者計(jì)算出來的 FLOPs 卻又是一致的。所以推斷差異主要來自于 GRU 的 FLOPs。如讀者知道其中詳情,還請(qǐng)不吝賜教。
4. 總結(jié)
本文給出了在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相關(guān)工具計(jì)算模型 FLOPs 的方法,但從本文所使用的測(cè)試模型來看, tensorflow 與 pytorch 統(tǒng)計(jì)出的結(jié)果相差甚遠(yuǎn)。當(dāng)然,也可以根據(jù)網(wǎng)絡(luò)層的類型及其對(duì)應(yīng)的參數(shù),推導(dǎo)計(jì)算出每個(gè)網(wǎng)絡(luò)層所需的 FLOPs。
到此這篇關(guān)于Python tensorflow與pytorch的浮點(diǎn)運(yùn)算數(shù)如何計(jì)算的文章就介紹到這了,更多相關(guān)Python tensorflow與pytorch內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
- python深度學(xué)習(xí)tensorflow實(shí)例數(shù)據(jù)下載與讀取
- 使用Python、TensorFlow和Keras來進(jìn)行垃圾分類的操作方法
- Python基于TensorFlow接口實(shí)現(xiàn)深度學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)回歸
- Python基于Tensorflow2.X實(shí)現(xiàn)汽車油耗預(yù)測(cè)
- python深度學(xué)習(xí)tensorflow訓(xùn)練好的模型進(jìn)行圖像分類
- python深度學(xué)習(xí)tensorflow1.0參數(shù)和特征提取
- python如何下載指定版本TensorFlow
相關(guān)文章
Python常見字符串操作函數(shù)小結(jié)【split()、join()、strip()】
這篇文章主要介紹了Python常見字符串操作函數(shù),結(jié)合實(shí)例形式總結(jié)分析了split()、join()及strip()的常見使用技巧與注意事項(xiàng),需要的朋友可以參考下2018-02-02python基礎(chǔ)入門詳解(文件輸入/輸出 內(nèi)建類型 字典操作使用方法)
這篇文章主要介紹了python基礎(chǔ)入門,包括文件輸入/輸出、內(nèi)建類型、字典操作等使用方法2013-12-12python3格式化字符串 f-string的高級(jí)用法(推薦)
從Python 3.6開始,f-string是格式化字符串的一種很好的新方法。與其他格式化方式相比,它們不僅更易讀,更簡潔,不易出錯(cuò),而且速度更快!本文重點(diǎn)給大家介紹python3格式化字符串 f-string的高級(jí)用法,一起看看吧2020-03-03Python編碼時(shí)應(yīng)該注意的幾個(gè)情況
對(duì)于Python程序員,你需要注意一下本文所提到的這些事情。你也可以看看Zen of Python(Python之禪),這里面提到了一些注意事項(xiàng),并配以示例,可以幫助你快速提高2013-03-03python中resample函數(shù)實(shí)現(xiàn)重采樣和降采樣代碼
今天小編就為大家分享一篇python中resample函數(shù)實(shí)現(xiàn)重采樣和降采樣代碼,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-02-02用Python做個(gè)自動(dòng)化彈鋼琴腳本實(shí)現(xiàn)天空之城彈奏
突然靈機(jī)一動(dòng),能不能用Python自動(dòng)化腳本彈奏一曲美妙的鋼琴曲呢?今天就一起帶大家如何用Python實(shí)現(xiàn)自動(dòng)化彈出一首《天空之城》有需要的朋友可以借鑒參考下2021-09-09Python+tkinter使用40行代碼實(shí)現(xiàn)計(jì)算器功能
這篇文章主要為大家詳細(xì)介紹了Python+tkinter使用40行代碼實(shí)現(xiàn)計(jì)算器功能,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-01-01