欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch backward報錯2次訪問計算圖需要retain_graph=True的情況詳解

 更新時間:2024年02月20日 09:47:04   作者:培之  
這篇文章主要介紹了Pytorch backward報錯2次訪問計算圖需要retain_graph=True的情況,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教

backward報錯2次訪問計算圖需要 retain_graph=True 的一種情況

錯誤代碼

錯誤的原因在于

y1 = 0.5*x*2-1.2*x
y2 = x**3

沒有放到循環(huán)里面,沒有隨著 x 的優(yōu)化而相應變化。

import torch
import numpy as np
import torch.optim as optim

torch.autograd.set_detect_anomaly(True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.tensor([1.0, 2.0, 3.0,4.5], dtype=torch.float32, requires_grad=True, device=device)


y_GT= torch.tensor([10, -20, -30,45], dtype=torch.float32,  device=device)

print(f'x{x}')


optimizer = optim.Adam([x], lr=1)
y1 = 0.5*x*2-1.2*x
y2 = x**3

for i in range(10):

    print(f'{i}: x{x}')
    optimizer.zero_grad()


    loss = (y1+y2-y_GT).mean()
    loss.backward()
    optimizer.step()
    print(f'{i}: x{x}')

正確代碼

import torch
import numpy as np
import torch.optim as optim

torch.autograd.set_detect_anomaly(True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.tensor([1.0, 2.0, 3.0,4.5], dtype=torch.float32, requires_grad=True, device=device)


y_GT= torch.tensor([10, -20, -30,45], dtype=torch.float32,  device=device)

print(f'x{x}')


optimizer = optim.Adam([x], lr=1)


for i in range(10):

    print(f'{i}: x{x}')
    optimizer.zero_grad()
    y1 = 0.5*x*2-1.2*x
    y2 = x**3

    loss = (y1+y2-y_GT).mean()
    loss.backward()
    optimizer.step()
    print(f'{i}: x{x}')

總結(jié)

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關文章

最新評論