欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch中的torch.nn.Linear()方法用法解讀

 更新時間:2024年02月26日 10:09:14   作者:擁抱晨曦之溫暖  
這篇文章主要介紹了Pytorch中的torch.nn.Linear()方法用法,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教

Pytorch torch.nn.Linear()方法

torch.nn.Linear()作為深度學(xué)習(xí)中最簡單的線性變換方法,其主要作用是對輸入數(shù)據(jù)應(yīng)用線性轉(zhuǎn)換

看一下官方的解釋及介紹

class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`
    Examples::
        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor
 
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
 
    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
 
    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)
 
    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
 
 
# This class exists solely for Transformer; it has an annotation stating
# that bias is never None, which appeases TorchScript

這里我們主要看__init__()方法,很容易知道,當(dāng)我們使用這個方法時一般需要傳入2~3個參數(shù),分別是in_features: int, out_features: int, bias: bool = True,第三個參數(shù)是說是否加偏置(bias),簡單來講,這個函數(shù)其實(shí)就是一個'一次函數(shù)':y = xA^T + b,(T表示張量A的轉(zhuǎn)置),首先super(Linear, self).__init__()就是老生常談的方法,之后初始化in_features和out_features,接下來就是比較重要的weight的設(shè)置,我們可以很清晰的看到weight的shape是(out_features,in_features)的,而我們在做xA^T時,并不是x和A^T相乘的,而是x和A.weight^T相乘的,這里需要大大留意,也就是說先對A做轉(zhuǎn)置得到A.weight,然后在丟入y = xA^T + b中,得出結(jié)果。

接下來奉上一個小例子來實(shí)踐一下:

import torch
 
# 隨機(jī)初始化一個shape為(128,20)的Tensor
x = torch.randn(128,20)
# 構(gòu)造線性變換函數(shù)y = xA^T + b,且參數(shù)(20,30)指的是A的shape,則A.weight的shape就是(30,20)了
y= torch.nn.Linear(20,30)
output = y(x)
# 按照以上邏輯使用torch中的簡單乘法函數(shù)進(jìn)行檢驗(yàn),結(jié)果很顯然與上述符合
# 下面的y.weight可以理解為一個shape為(30,20)的一個可學(xué)習(xí)的矩陣,.t()表示轉(zhuǎn)置
# y.bias若為TRUE,則bias是一個Tensor,且其shape為out_features,在該程序中應(yīng)為30
# 更加細(xì)致的表達(dá)一下y = (128 * 20) * (30 * 20)^T + (if bias (1,30) ,else: 0)
ans = torch.mm(x,y.weight.t())+y.bias
print('ans.shape:\n',ans.shape)
print(torch.equal(ans,output))

對torch.nn.Linear的理解

torch.nn.Linear是pytorch的線性變換層

定義如下:

Linear(in_features: int, out_features: int, bias: bool = True, device: Any | None = None, dtype: Any | None = None)

全連接層 Fully Connect 一般就就用這個函數(shù)來實(shí)現(xiàn)。

因此在潛意識里,變換的輸入張量的 shape 為 (batchsize, in_features),輸出張量的 shape 為 (batchsize, out_features)。

當(dāng)然這是常用的方式,但是 Linear 的輸入張量的維度其實(shí)并不需要必須為上述的二維,多維也是完全可以的,Linear 僅是對輸入的最后一維做線性變換,不影響其他維。

可以看下官網(wǎng)的解釋

Linear — PyTorch 1.11.0 documentation

一個例子

如下:

import torch
input = torch.randn(30, 20, 10)  # [30, 20, 10]
linear = torch.nn.Linear(10, 15)  # (*, 10) --> (*, 15)
output = linear(input)
print(output.size()) # 輸出 [30, 20, 15]

總結(jié)

以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python3.5集合及其常見運(yùn)算實(shí)例詳解

    Python3.5集合及其常見運(yùn)算實(shí)例詳解

    這篇文章主要介紹了Python3.5集合及其常見運(yùn)算,結(jié)合實(shí)例形式分析了Python3.5集合的定義、功能、交集、并集、差集等常見操作技巧與相關(guān)注意事項(xiàng),需要的朋友可以參考下
    2019-05-05
  • Pytorch配置GPU環(huán)境方式

    Pytorch配置GPU環(huán)境方式

    這篇文章主要介紹了Pytorch配置GPU環(huán)境方式,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
    2024-02-02
  • Django項(xiàng)目如何配置Memcached和Redis緩存?選擇哪個更有優(yōu)勢?

    Django項(xiàng)目如何配置Memcached和Redis緩存?選擇哪個更有優(yōu)勢?

    這篇文章主要介紹了Django項(xiàng)目如何配置Memcached和Redis緩存,幫助大家更好的理解和學(xué)習(xí)使用django框架,感興趣的朋友可以了解下
    2021-04-04
  • Python內(nèi)置函數(shù)map()的具體使用

    Python內(nèi)置函數(shù)map()的具體使用

    Python中的map()函數(shù)是一個高效的內(nèi)置函數(shù),用于將指定函數(shù)應(yīng)用于序列的每個元素,通過接收一個函數(shù)和一個或多個序列,本文就來詳細(xì)的介紹一下如何使用,感興趣的可以了解一下
    2024-09-09
  • NumPy庫中np.mean的具體使用

    NumPy庫中np.mean的具體使用

    np.mean?是 NumPy 庫中的一個函數(shù),用于計算給定數(shù)組或數(shù)組元素的算術(shù)平均值,本文主要介紹了NumPy庫中np.mean的具體使用,具有一定的參考價值,感興趣的可以了解一下
    2025-04-04
  • Django項(xiàng)目中表的查詢的操作

    Django項(xiàng)目中表的查詢的操作

    這篇文章主要介紹了Django項(xiàng)目中表的查詢的操作,文中給大家提到了Django項(xiàng)目 ORM常用的十三種查詢方法,結(jié)合實(shí)例代碼給大家介紹的非常詳細(xì),需要的朋友可以參考下
    2022-09-09
  • Python運(yùn)行異常管理解決方案

    Python運(yùn)行異常管理解決方案

    這篇文章主要介紹了Python運(yùn)行異常管理解決方案,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2020-03-03
  • python 遍歷可迭代對象的實(shí)現(xiàn)方法

    python 遍歷可迭代對象的實(shí)現(xiàn)方法

    本文主要介紹了python 遍歷可迭代對象的實(shí)現(xiàn)方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2023-02-02
  • 簡單實(shí)例帶你了解Python的編譯和執(zhí)行全過程

    簡單實(shí)例帶你了解Python的編譯和執(zhí)行全過程

    python 是一種解釋型的編程語言,所以不像編譯型語言那樣需要顯式的編譯過程。然而,在 Python 代碼執(zhí)行之前,它需要被解釋器轉(zhuǎn)換成字節(jié)碼,這個過程就是 Python 的編譯過程,還不知道的朋友快來看看吧
    2023-04-04
  • 詳解python連接telnet和ssh的兩種方式

    詳解python連接telnet和ssh的兩種方式

    本文主要介紹了python連接telnet和ssh的兩種方式,文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2021-10-10

最新評論