Pytorch中的torch.where函數(shù)使用
使用torch.where函數(shù)
首先我們看一下Pytorch中torch.where函數(shù)是怎樣定義的:
@overload def where(condition: Tensor) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...
torch.where函數(shù)的功能如下:
torch.where(condition, x, y)
- condition:判斷條件
- x:若滿足條件,則取x中元素
- y:若不滿足條件,則取y中元素
以具體實(shí)例看一下torch.where函數(shù)的效果:
import torch # 條件 condition = torch.rand(3, 2) print(condition) # 滿足條件則取x中對(duì)應(yīng)元素 x = torch.ones(3, 2) print(x) # 不滿足條件則取y中對(duì)應(yīng)元素 y = torch.zeros(3, 2) print(y) # 條件判斷后的結(jié)果 result = torch.where(condition > 0.5, x, y) print(result)
結(jié)果如下:
tensor([[0.3224, 0.5789],
[0.8341, 0.1673],
[0.1668, 0.4933]])
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
tensor([[0., 0.],
[0., 0.],
[0., 0.]])
tensor([[0., 1.],
[1., 0.],
[0., 0.]])
可以看到torch.where函數(shù)會(huì)對(duì)condition中的元素逐一進(jìn)行判斷,根據(jù)判斷的結(jié)果選取x或y中的值,所以要求x和y應(yīng)該與condition形狀相同。
torch.where(),np.where()兩種用法,及np.argwhere()尋找張量(tensor)和數(shù)組中為0的索引
1.torch.where()
torch.where()有兩種用法,
- 當(dāng)輸入?yún)?shù)為三個(gè)時(shí),即torch.where(condition, x, y),返回滿足 x if condition else y的tensor,注意x,y必須為tensor
- 當(dāng)輸入?yún)?shù)為一個(gè)時(shí),即torch.where(condition),返回滿足condition的tensor索引的元組(tuple)
代碼示例
torch.where(condition, x, y)
代碼
import torch import numpy as np # 初始化兩個(gè)tensor x = torch.tensor([ [1,2,3,0,6], [4,6,2,1,0], [4,3,0,1,1] ]) y = torch.tensor([ [0,5,1,4,2], [5,7,1,2,9], [1,3,5,6,6] ]) # 尋找滿足x中大于3的元素,否則得到y(tǒng)對(duì)應(yīng)位置的元素 arr0 = torch.where(x>=3, x, y) #輸入?yún)?shù)為3個(gè) print(x, '\n', y) print(arr0, '\n', type(arr0))
結(jié)果
>>> x
tensor([[1, 2, 3, 0, 6],
[4, 6, 2, 1, 0],
[4, 3, 0, 1, 1]])
>>> y
tensor([[0, 5, 1, 4, 2],
[5, 7, 1, 2, 9],
[1, 3, 5, 6, 6]])
>>> arr0
tensor([[0, 5, 3, 4, 6],
[4, 6, 1, 2, 9],
[4, 3, 5, 6, 6]])
>>> type(arr0)
<class 'torch.Tensor'>
arr0的類型為<class 'torch.Tensor'>
torch.where(condition)
以尋找tensor中為0的索引為例
代碼
import torch import numpy as np x = torch.tensor([ [1,2,3,0,6], [4,6,2,1,0], [4,3,0,1,1] ]) y = torch.tensor([ [0,5,1,4,2], [5,7,1,2,9], [1,3,5,6,6] ]) # 返回x中0元素的索引 index0 = torch.where(x==0) # 輸入?yún)?shù)為1個(gè) print(index0,'\n', type(index0))
結(jié)果
>>> index0
(tensor([0, 1, 2]), tensor([3, 4, 2]))
>>> type(index0)
<class 'tuple'>
其中[0, 1, 2]是0元素坐標(biāo)的行索引,[3, 4, 2]是0元素坐標(biāo)的列索引,注意,最終得到的是tuple類型的返回值,元組中包含了tensor
2.np.where()
np.where()用法與torch.where()用法類似,也包括兩種用法,但是不同的是輸入值類型和返回值的類型
代碼示例
np.where(condition, x, y)和np.where(condition),輸入x,y可以為非tensor
代碼
import torch import numpy as np x = torch.tensor([ [1,2,3,0,6], [4,6,2,1,0], [4,3,0,1,1] ]) y = torch.tensor([ [0,5,1,4,2], [5,7,1,2,9], [1,3,5,6,6] ]) arr1 = np.where(x>=3, x, y) # 輸入?yún)?shù)為3個(gè) index0 = torch.where(x==0) # 輸入?yún)?shù)為1個(gè) print(arr1,'\n',type(arr1)) print(index1,'\n', type(index1))
結(jié)果
>>> arr1
[[0 5 3 4 6]
[4 6 1 2 9]
[4 3 5 6 6]]
>>> type(arr1)
<class 'numpy.ndarray'>
>>> index1
(array([0, 1, 2]), array([3, 4, 2]))
>>> type(index1)
<class 'tuple'>
注意,np.where()和torch.where()的返回值類型不同
3.np.argwhere(condition)
尋找符合contion的元素索引
代碼示例
代碼
import torch import numpy as np x = torch.tensor([ [1,2,3,0,6], [4,6,2,1,0], [4,3,0,1,1] ]) y = torch.tensor([ [0,5,1,4,2], [5,7,1,2,9], [1,3,5,6,6] ]) index2 = np.argwhere(x==0) # 尋找元素為0的索引 print(index2,'\n', type(index2))
結(jié)果
>>> index2
tensor([[0, 1, 2],
[3, 4, 2]])
>>> type(index2)
<class 'torch.Tensor'>
注意返回值的類型
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- PyTorch中torch.load()的用法和應(yīng)用
- python中torch.load中的map_location參數(shù)使用
- Pytorch中的torch.nn.Linear()方法用法解讀
- python中的List sort()與torch.sort()
- 關(guān)于torch.scatter與torch_scatter庫(kù)的使用整理
- PyTorch函數(shù)torch.cat與torch.stac的區(qū)別小結(jié)
- pytorch.range()和pytorch.arange()的區(qū)別及說(shuō)明
- 使用with torch.no_grad():顯著減少測(cè)試時(shí)顯存占用
- PyTorch中torch.save()的用法和應(yīng)用小結(jié)
相關(guān)文章
10個(gè)python爬蟲入門實(shí)例(小結(jié))
這篇文章主要介紹了10個(gè)python爬蟲入門實(shí)例(小結(jié)),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-11-11Python實(shí)現(xiàn)企業(yè)微信通知機(jī)器人的方法詳解
這篇文章主要為大家詳細(xì)介紹了如何使用Python實(shí)現(xiàn)對(duì)企業(yè)微信進(jìn)行群通知的功能,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下2023-02-02pytorch實(shí)現(xiàn)好萊塢明星識(shí)別的示例代碼
本文主要介紹了pytorch實(shí)現(xiàn)好萊塢明星識(shí)別,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-01-01從np.random.normal()到正態(tài)分布的擬合操作
這篇文章主要介紹了從np.random.normal()到正態(tài)分布的擬合操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-06-06Python如何保留float類型小數(shù)點(diǎn)后3位
這篇文章主要介紹了Python如何保留float類型小數(shù)點(diǎn)后3位,具有很好的參考價(jià)值,希望對(duì)的大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-05-05