How PyTorch Calls the forward() Function
Calling forward() in PyTorch
Module, provided in the nn package, is a model-construction class and the base class of all neural network modules; we can subclass it to define the model we want.
Below we subclass Module to build the multilayer perceptron (MLP) mentioned at the start of this section.
The MLP class defined here overrides Module's __init__ and forward functions.
They are used to create the model parameters and to define the forward computation, respectively.
The forward computation is also known as forward propagation.
import torch
from torch import nn

class MLP(nn.Module):
    # Declare layers that carry model parameters; here, two fully connected layers
    def __init__(self, **kwargs):
        # Call the constructor of MLP's parent class Module to perform the necessary
        # initialization. This also lets other keyword arguments be passed when
        # constructing an instance, such as the model parameter params introduced in
        # the section "Access, Initialization, and Sharing of Model Parameters"
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Linear(784, 256)  # hidden layer
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)   # output layer

    # Define the model's forward computation, i.e., how to compute the required
    # model output from the input x
    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

X = torch.rand(2, 784)
net = MLP()
print(net)
net(X)
Output:

MLP(
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (act): ReLU()
  (output): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[-0.1798, -0.2253,  0.0206, -0.1067, -0.0889,  0.1818, -0.1474,  0.1845,
         -0.1870,  0.1970],
        [-0.1843, -0.1562, -0.0090,  0.0351, -0.1538,  0.0992, -0.0883,  0.0911,
         -0.2293,  0.2360]], grad_fn=<ThAddmmBackward>)
Why does forward() get called at all? Because Module defines a __call__() method, and that method calls forward(); when you run net(X), Python automatically invokes __call__().
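Conceptually, the dispatch works like the sketch below. This is a heavily simplified stand-in for nn.Module, not the actual PyTorch source (the real __call__ also runs registered hooks and other bookkeeping):

class SimpleModule:
    # Calling an instance routes through __call__, which invokes forward()
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError

class Double(SimpleModule):
    def forward(self, x):
        return 2 * x

print(Double()(21))  # 42: the call goes through __call__, then forward()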
PyTorch Function-Call Variants and a Source-Code Walkthrough
I recently needed the softmax function and noticed it is written in many different ways; here is a record of them:
# torch._C._VariableFunctions
torch.softmax(x, dim=-1)

# class
softmax = torch.nn.Softmax(dim=-1)
x = softmax(x)

# function
x = torch.nn.functional.softmax(x, dim=-1)
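A quick sanity check (a minimal sketch with an arbitrary input tensor) confirms that all three spellings compute the same result:

import torch
import torch.nn.functional as F

x = torch.randn(2, 5)

y1 = torch.softmax(x, dim=-1)        # torch._C._VariableFunctions
y2 = torch.nn.Softmax(dim=-1)(x)     # class
y3 = F.softmax(x, dim=-1)            # function

print(torch.allclose(y1, y2) and torch.allclose(y2, y3))  # True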
In a quick benchmark, the torch.nn.Softmax class was the slowest of the three; the other two were about the same.
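The exact test setup isn't shown above, but a timeit sketch along these lines reproduces the comparison (the input shape and repeat count here are arbitrary assumptions):

import timeit
import torch
import torch.nn.functional as F

x = torch.randn(64, 1000)

# Constructing nn.Softmax inside the timed call mirrors the "class" spelling;
# its extra cost is mostly Python-level object creation and dispatch.
t_class   = timeit.timeit(lambda: torch.nn.Softmax(dim=-1)(x), number=1000)
t_func    = timeit.timeit(lambda: F.softmax(x, dim=-1), number=1000)
t_builtin = timeit.timeit(lambda: torch.softmax(x, dim=-1), number=1000)

print(f"class: {t_class:.3f}s  function: {t_func:.3f}s  built-in: {t_builtin:.3f}s")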
The source of torch.nn.Softmax is shown below. You can see that it is a class, and that its return F.softmax(input, self.dim, _stacklevel=5) calls torch.nn.functional.softmax:
class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor
    rescaling them so that the elements of the n-dimensional output Tensor
    lie in the range [0,1] and sum to 1.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    When the input Tensor is a sparse tensor then the unspecifed
    values are treated as ``-inf``.

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Args:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).

    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).

    Examples::

        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(Softmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)
The source of the torch.nn.functional.softmax function is shown below; you can see that ret = input.softmax(dim) in turn calls the softmax function in torch._C._VariableFunctions:
def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[DType] = None) -> Tensor:
    r"""Applies a softmax function.

    Softmax is defined as:

    :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`

    It is applied to all slices along dim, and will re-scale them so that the elements
    lie in the range `[0, 1]` and sum to 1.

    See :class:`~torch.nn.Softmax` for more details.

    Args:
        input (Tensor): input
        dim (int): A dimension along which softmax will be computed.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
          If specified, the input tensor is casted to :attr:`dtype` before the operation
          is performed. This is useful for preventing data type overflows. Default: None.

    .. note::
        This function doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use log_softmax instead (it's faster and has better numerical properties).

    """
    if has_torch_function_unary(input):
        return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    if dim is None:
        dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
    if dtype is None:
        ret = input.softmax(dim)
    else:
        ret = input.softmax(dim, dtype=dtype)
    return ret
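In other words, there is effectively a fourth spelling, the Tensor method, which is what F.softmax dispatches to in the end (a minimal sketch):

import torch

x = torch.randn(2, 5)

# Tensor.softmax is the method that torch.nn.functional.softmax ultimately calls
y = x.softmax(dim=-1)
print(torch.allclose(y, torch.softmax(x, dim=-1)))  # True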
So why not just call the built-in C function directly?
However, the blog post A selective excursion into the internals of PyTorch notes:
Note: That bilinear is exported as torch.bilinear is somewhat accidental. Do use the documented interfaces, here torch.nn.functional.bilinear whenever you can!
In other words, the fact that the built-in C functions are reachable directly as torch.xxx is somewhat accidental; you are strongly encouraged to use documented interfaces such as torch.nn.functional.xxx instead.
The latest official transformer code also uses torch.nn.functional.softmax, so it seems best to stay consistent with it (even though it previously used the class...).
Summary
The above is my personal experience; I hope it gives you a useful reference, and I hope you will continue to support 腳本之家.