欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch實現(xiàn)常用乘法算子TensorRT的示例代碼

 更新時間:2022年06月01日 15:04:34   作者:極智視界  
pytorch 用于訓練,TensorRT用于推理是很多AI應用開發(fā)的標配。大家往往更加熟悉 pytorch 的算子,而不太熟悉TensorRT的算子。本文介紹了Pytorch中常用乘法的TensorRT實現(xiàn),感興趣的可以了解一下

本文介紹一下 Pytorch 中常用乘法的 TensorRT 實現(xiàn)。

pytorch 用于訓練,TensorRT 用于推理是很多 AI 應用開發(fā)的標配。大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,這里拿比較常用的乘法運算在兩種框架下的實現(xiàn)做一個對比,可能會有更加直觀一些的認識。

1.乘法運算總覽

先把 pytorch 中的一些常用的乘法運算進行一個總覽:

  • torch.mm:用于兩個矩陣 (不包括向量) 的乘法,如維度 (m, n) 的矩陣乘以維度 (n, p) 的矩陣;
  • torch.bmm:用于帶 batch 的三維向量的乘法,如維度 (b, m, n) 的矩陣乘以維度 (b, n, p) 的矩陣;
  • torch.mul:用于同維度矩陣的逐像素點相乘,也即點乘,如維度 (m, n) 的矩陣點乘維度 (m, n) 的矩陣。該方法支持廣播,也即支持矩陣和元素點乘;
  • torch.mv:用于矩陣和向量的乘法,矩陣在前,向量在后,如維度 (m, n) 的矩陣乘以維度為 (n) 的向量,輸出維度為 (m);
  • torch.matmul:用于兩個張量相乘,或矩陣與向量乘法,作用包含 torch.mm、torch.bmm、torch.mv;
  • @:作用相當于 torch.matmul;
  • *:作用相當于 torch.mul;

如上進行了一些具體羅列,可以歸納出,常用的乘法無非兩種:矩陣乘 和 點乘,所以下面分這兩類進行介紹。

2.乘法算子實現(xiàn)

2.1矩陣乘算子實現(xiàn)

先來看看矩陣乘法的 pytorch 的實現(xiàn) (以下實現(xiàn)在終端):

>>> import torch
>>> # torch.mm
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99, 88)
>>> c = torch.mm(a, b)
>>> c.shape
torch.size([66, 88])
>>>
>>> # torch.bmm
>>> a = torch.randn(3, 66, 99)
>>> b = torch.randn(3, 99, 77)
>>> c = torch.bmm(a, b)
>>> c.shape
torch.size([3, 66, 77])
>>>
>>> # torch.mv
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99)
>>> c = torch.mv(a, b)
>>> c.shape
torch.size([66])
>>>
>>> # torch.matmul
>>> a = torch.randn(32, 3, 66, 99)
>>> b = torch.randn(32, 3, 99, 55)
>>> c = torch.matmul(a, b)
>>> c.shape
torch.size([32, 3, 66, 55])
>>>
>>> # @
>>> d = a @ b
>>> d.shape
torch.size([32, 3, 66, 55])

來看 TensorRT 的實現(xiàn),以上乘法都可使用 addMatrixMultiply 方法覆蓋,對應 torch.matmul,先來看該方法的定義:

//!
//! \brief Add a MatrixMultiply layer to the network.
//!
//! \param input0 The first input tensor (commonly A).
//! \param op0 The operation to apply to input0.
//! \param input1 The second input tensor (commonly B).
//! \param op1 The operation to apply to input1.
//!
//! \see IMatrixMultiplyLayer
//!
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new matrix multiply layer, or nullptr if it could not be created.
//!
IMatrixMultiplyLayer* addMatrixMultiply(
  ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept
{
  return mImpl->addMatrixMultiply(input0, op0, input1, op1);
}

可以看到這個方法有四個傳參,對應兩個張量和其 operation。來看這個算子在 TensorRT 中怎么添加:

// 構(gòu)造張量 Tensor0
nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0);
// 構(gòu)造張量 Tensor1
nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1);

// 添加矩陣乘法
nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type);

// 獲取輸出
matmulOutput = Matmul_layer->getOputput(0);

2.2點乘算子實現(xiàn)

再來看看點乘的 pytorch 的實現(xiàn) (以下實現(xiàn)在終端):

>>> import torch
>>> # torch.mul
>>> a = torch.randn(66, 99)
>>> b = torch.randn(66, 99)
>>> c = torch.mul(a, b)
>>> c.shape
torch.size([66, 99])
>>> d = 0.125
>>> e = torch.mul(a, d)
>>> e.shape
torch.size([66, 99])
>>> # *
>>> f = a * b
>>> f.shape
torch.size([66, 99])

來看 TensorRT 的實現(xiàn),以上乘法都可使用 addScale 方法覆蓋,這在圖像預處理中十分常用,先來看該方法的定義:

//!
//! \brief Add a Scale layer to the network.
//!
//! \param input The input tensor to the layer.
//!              This tensor is required to have a minimum of 3 dimensions in implicit batch mode
//!              and a minimum of 4 dimensions in explicit batch mode.
//! \param mode The scaling mode.
//! \param shift The shift value.
//! \param scale The scale value.
//! \param power The power value.
//!
//! If the weights are available, then the size of weights are dependent on the ScaleMode.
//! For ::kUNIFORM, the number of weights equals 1.
//! For ::kCHANNEL, the number of weights equals the channel dimension.
//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.
//!
//! \see addScaleNd
//! \see IScaleLayer
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new Scale layer, or nullptr if it could not be created.
//!
IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept
{
  return mImpl->addScale(input, mode, shift, scale, power);
}

 可以看到有三個模式:

  • kUNIFORM:weights 為一個值,對應張量乘一個元素;
  • kCHANNEL:weights 維度和輸入張量通道的 c 維度對應,可以做一些以通道為基準的預處理;
  • kELEMENTWISE:weights 維度和輸入張量的 c、h、w 對應,不考慮 batch,所以是輸入的后三維;

再來看這個算子在 TensorRT 中怎么添加:

// 構(gòu)造張量 input
nvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);

// scalemode選擇,kUNIFORM、kCHANNEL、kELEMENTWISE
scalemode = kUNIFORM;

// 構(gòu)建 Weights 類型的 shift、scale、power,其中 volume 為元素數(shù)量
nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };

// !! 注意這里還需要對 shift、scale、power 的 values 進行賦值,若只是乘法只需要對 scale 進行賦值就行

// 添加張量乘法
nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);

// 獲取輸出
scaleOutput = Scale_layer->getOputput(0);

有一點你可能會比較疑惑,既然是點乘,那么輸入只需要兩個張量就可以了,為啥這里有 input、shift、scale、power 四個張量這么多呢。解釋一下,input 不用說,就是輸入張量,而 shift 表示加法參數(shù)、scale 表示乘法參數(shù)、power 表示指數(shù)參數(shù),說到這里,你應該能發(fā)現(xiàn),這個函數(shù)除了我們上面講的點乘外還有其他更加豐富的運算功能。

到此這篇關(guān)于Pytorch實現(xiàn)常用乘法算子TensorRT的示例代碼的文章就介紹到這了,更多相關(guān)Pytorch乘法算子TensorRT內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • Django利用cookie保存用戶登錄信息的簡單實現(xiàn)方法

    Django利用cookie保存用戶登錄信息的簡單實現(xiàn)方法

    這篇文章主要介紹了Django利用cookie保存用戶登錄信息的簡單實現(xiàn)方法,結(jié)合實例形式分析了Django框架使用cookie保存用戶信息的相關(guān)操作技巧,需要的朋友可以參考下
    2019-05-05
  • Python中遍歷字典過程中更改元素導致異常的解決方法

    Python中遍歷字典過程中更改元素導致異常的解決方法

    這篇文章主要介紹了Python中遍歷字典過程中更改元素導致錯誤的解決方法,針對增刪元素后出現(xiàn)dictionary changed size during iteration的異常解決做出討論和解決,需要的朋友可以參考下
    2016-05-05
  • python PIL中ImageFilter模塊圖片濾波處理和使用方法

    python PIL中ImageFilter模塊圖片濾波處理和使用方法

    這篇文章主要介紹PIL中ImageFilter模塊幾種圖片濾波處理和使用方法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2023-11-11
  • 如何用Python實現(xiàn)八數(shù)碼問題

    如何用Python實現(xiàn)八數(shù)碼問題

    這篇文章主要給大家介紹了關(guān)于如何用Python實現(xiàn)八數(shù)碼問題的相關(guān)資料,八數(shù)碼問題是一種經(jīng)典的搜索問題,它的目標是將一個亂序的八數(shù)碼序列變成一個有序的八數(shù)碼序列,通常使用 A* 算法來解決,需要的朋友可以參考下
    2023-10-10
  • 一文秒懂python讀寫csv xml json文件各種騷操作

    一文秒懂python讀寫csv xml json文件各種騷操作

    多年來,數(shù)據(jù)存儲的可能格式顯著增加,但是,在日常使用中,還是以 CSV 、 JSON 和 XML 占主導地位。 在本文中,我將與你分享在Python中使用這三種流行數(shù)據(jù)格式及其之間相互轉(zhuǎn)換的最簡單方法,需要的朋友可以參考下
    2019-07-07
  • ubuntu20.04運用startup application開機自啟動python程序的腳本寫法

    ubuntu20.04運用startup application開機自啟動python程序的腳本寫法

    這篇文章主要介紹了ubuntu20.04運用startup application開機自啟動python程序的腳本寫法,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2023-10-10
  • Python中import的用法陷阱解決盤點小結(jié)

    Python中import的用法陷阱解決盤點小結(jié)

    這篇文章主要為大家介紹了Python中import的用法陷阱解決盤點小結(jié),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2023-10-10
  • python生成密碼字典的方法

    python生成密碼字典的方法

    今天小編就為大家分享一篇python生成密碼字典的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-07-07
  • 淺析Python語言自帶的數(shù)據(jù)結(jié)構(gòu)有哪些

    淺析Python語言自帶的數(shù)據(jù)結(jié)構(gòu)有哪些

    Python已經(jīng)廣泛的應用于數(shù)據(jù)分析、數(shù)據(jù)挖掘、機器學習等眾多科學計算領(lǐng)域,這篇文章主要介紹了Python語言自帶的數(shù)據(jù)結(jié)構(gòu)有哪些?需要的朋友可以參考下
    2019-08-08
  • django實現(xiàn)用戶登陸功能詳解

    django實現(xiàn)用戶登陸功能詳解

    這篇文章主要介紹了django實現(xiàn)用戶登陸功能詳解,具有一定借鑒價值,需要的朋友可以參考下。
    2017-12-12

最新評論