欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch單元測試的實現(xiàn)示例

 更新時間:2024年04月18日 10:57:40   作者:Hi20240217  
單元測試是一種軟件測試方法,本文主要介紹了pytorch單元測試的實現(xiàn)示例,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧

希望測試pytorch各種算子、block、網(wǎng)絡(luò)等在不同硬件平臺,不同軟件版本下的計算誤差、耗時、內(nèi)存占用等指標.

本文基于torch.testing._internal

一.公共模塊[common.py]

import torch
from torch import nn
import math
import torch.nn.functional as F
import time
import os
import socket
import sys
from datetime import datetime
import numpy as np
import collections
import math
import json
import copy
import traceback
import subprocess
import unittest
import torch
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
import torch.distributed as dist

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANDOM_SEED"] = "0" 

device="cpu"
device_type="cpu"
device_name="cpu"

try:
    if torch.cuda.is_available():     
        device_name=torch.cuda.get_device_name().replace(" ","")
        device="cuda:0"
        device_type="cuda"
        ccl_backend='nccl'
except:
    pass

host_name=socket.gethostname()    
sdk_version=os.getenv("SDK_VERSION","")   						 #從環(huán)境變量中獲取sdk版本號
metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data")  #日志存放的目錄
device_count=torch.cuda.device_count()

if not os.path.exists(metric_data_root):
    os.makedirs(metric_data_root)

def device_warmup(device):
    '''設(shè)備warmup,確保設(shè)備已經(jīng)正常工作,排除設(shè)備初始化的耗時'''
    left = torch.rand([128,512], dtype = torch.float16).to(device)
    right = torch.rand([512,128], dtype = torch.float16).to(device)
    out=torch.matmul(left,right)
    torch.cuda.synchronize()

torch.manual_seed(1) 
np.random.seed(1)

def loop_decorator(loops,rank=0):
    '''循環(huán)裝飾器,用于統(tǒng)計函數(shù)的執(zhí)行時間,內(nèi)存占用等'''
    def decorator(func):
        def wrapper(*args,**kwargs):
            latency=[]
            memory_allocated_t0=torch.cuda.memory_allocated(rank)
            for _ in range(loops):
                input_copy=[x.clone() for x in args]
                beg= datetime.now().timestamp() * 1e6
                pred= func(*input_copy)
                gt=kwargs["golden"]
                torch.cuda.synchronize()
                end=datetime.now().timestamp() * 1e6
                mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()
                latency.append(end-beg)
            memory_allocated_t1=torch.cuda.memory_allocated(rank)
            avg_latency=np.mean(latency[len(latency)//2:]).round(3)
            first_latency=latency[0]
            return { "first_latency":first_latency,"avg_latency":avg_latency,
                      "memory_allocated":memory_allocated_t1-memory_allocated_t0,
                      "mse":mse}
        return wrapper
    return decorator

class TorchUtMetrics:
    '''用于統(tǒng)計測試結(jié)果,比較之前的最小值'''
    def __init__(self,ut_name,thresold=0.2,rank=0):
        self.ut_name=f"{ut_name}_{rank}"
        self.thresold=thresold
        self.rank=rank
        self.data={"ut_name":self.ut_name,"metrics":[]}
        self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")
        try:
            with open(self.metrics_path,"r") as f:
                self.data=json.loads(f.read())
        except:
            pass

    def __enter__(self):
        self.beg= datetime.now().timestamp() * 1e6
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):        
        self.report()
        self.save_data()

    def save_data(self):
        with open(self.metrics_path,"w") as f:
            f.write(json.dumps(self.data,indent=4))

    def set_metrics(self,metrics):
        self.end=datetime.now().timestamp() * 1e6
        item=collections.OrderedDict()
        item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        item["sdk_version"]=sdk_version
        item["device_name"]=device_name
        item["host_name"]=host_name
        item["metrics"]=metrics
        item["metrics"]["e2e_time"]=self.end-self.beg
        self.cur_item=item
        self.data["metrics"].append(self.cur_item)

    def get_metric_names(self):
        return self.data["metrics"][0]["metrics"].keys()

    def get_min_metric(self,metric_name,devicename=None):
        min_value=0
        min_value_index=-1
        for idx,item in enumerate(self.data["metrics"]):
            if devicename and (devicename!=item['device_name']):                
                continue            
            val=float(item["metrics"][metric_name])
            if min_value_index==-1 or val<min_value:
                min_value=val
                min_value_index=idx
        return min_value,min_value_index

    def get_metric_info(self,index):
        metrics=self.data["metrics"][index]
        return f'{metrics["device_name"]}@{metrics["sdk_version"]}'

    def report(self):
        assert len(self.data["metrics"])>0
        for metric_name in self.get_metric_names():
            min_value,min_value_index=self.get_min_metric(metric_name)
            min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)
            cur_value=float(self.cur_item["metrics"][metric_name])
            print(f"-------------------------------{metric_name}-------------------------------")
            print(f"{cur_value}#{device_name}@{sdk_version}")
            if min_value_index_same_dev>=0:
                print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")
            if min_value_index>=0:
                print(f"{min_value}#{self.get_metric_info(min_value_index)}")

二.普通算子測試[test_clone.py]

from common import *
class TestCaseClone(TestCase):
    #如果不滿足條件,則跳過這個測試
    @unittest.skipIf(device_count>1, "Not enough devices") 
    def test_todo(self):
        print(".TODO")

    #框架會自動遍歷以下參數(shù)組合
    @parametrize("shape", [(10240,20480),(128,256)])
    @parametrize("dtype", [torch.float16,torch.float32])
    def test_clone(self,shape,dtype):
        
        #讓這個函數(shù)循環(huán)執(zhí)行l(wèi)oops次,統(tǒng)計第一次執(zhí)行的耗時、后半段的平均時間、整個執(zhí)行過程總的GPU內(nèi)存使用量
        @loop_decorator(loops=5)
        def run(input_dev):
            output=input_dev.clone()
            return output
        
        #記錄整個測試的總耗時,保存統(tǒng)計量,輸出摘要(self._testMethodName:測試方法,result:函數(shù)返回值,metrics:統(tǒng)計量)
        with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:
            input_host=torch.ones(shape,dtype=dtype)*np.random.rand()
            input_dev=input_host.to(device)
            metrics=run(input_dev,golden=input_host.cpu())
            m.set_metrics(metrics)
            assert(metrics["mse"]==0)
        
instantiate_parametrized_tests(TestCaseClone)

if __name__ == "__main__":
    run_tests()

三.集合通信測試[test_ccl.py]

from common import *
class TestCCL(MultiProcessTestCase):
    '''CCL測試用例'''
    def _create_process_group_vccl(self, world_size, store):
        dist.init_process_group(
            ccl_backend, world_size=world_size, rank=self.rank, store=store
        )        
        pg = dist.distributed_c10d._get_default_group()
        return pg

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 4
      
    #框架會自動遍歷以下參數(shù)組合
    @unittest.skipIf(device_count<4, "Not enough devices") 
    @parametrize("op",[dist.ReduceOp.SUM])
    @parametrize("shape", [(1024,8192)])
    @parametrize("dtype", [torch.int64])
    def test_allreduce(self,op,shape,dtype):
        if self.rank >= self.world_size:
            return
        
        store = dist.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_vccl(self.world_size, store)
        if not torch.distributed.is_initialized():
            return
    
        torch.cuda.set_device(self.rank)
        device = torch.device(device_type,self.rank)
        device_warmup(device)
        #讓這個函數(shù)循環(huán)執(zhí)行l(wèi)oops次,統(tǒng)計第一次執(zhí)行的耗時、后半段的平均時間、整個執(zhí)行過程總的GPU內(nèi)存使用量
        @loop_decorator(loops=5,rank=self.rank)
        def run(input_dev):
            dist.all_reduce(input_dev, op=op)
            return input_dev
        
        #記錄整個測試的總耗時,保存統(tǒng)計量,輸出摘要(self._testMethodName:測試方法,result:函數(shù)返回值,metrics:統(tǒng)計量)
        with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:
            input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)
            gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]
            gt_=gt[0]
            for i in range(1,self.world_size):
                gt_=gt_+gt[i]
            input_dev=input_host.to(device)
            metrics=run(input_dev,golden=gt_)
            m.set_metrics(metrics)
            assert(metrics["mse"]==0)
        dist.destroy_process_group(pg)
    
instantiate_parametrized_tests(TestCCL)

if __name__ == "__main__":
    run_tests()

四.測試命令

# 運行所有的測試
pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./

# 運行某一個測試
python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"

五.測試報告

在這里插入圖片描述

到此這篇關(guān)于pytorch單元測試的實現(xiàn)示例的文章就介紹到這了,更多相關(guān)pytorch單元測試內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家! 

相關(guān)文章

  • Python Pydantic進行數(shù)據(jù)驗證的方法詳解

    Python Pydantic進行數(shù)據(jù)驗證的方法詳解

    在 Python 中,有許多庫可用于數(shù)據(jù)驗證和處理,其中一個流行的選擇是 Pydantic,下面就跟隨小編一起學習一下Pydantic 的基本概念和用法吧
    2024-01-01
  • PyTorch實現(xiàn)聯(lián)邦學習的基本算法FedAvg

    PyTorch實現(xiàn)聯(lián)邦學習的基本算法FedAvg

    這篇文章主要為大家介紹了PyTorch實現(xiàn)聯(lián)邦學習的基本算法FedAvg,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2022-05-05
  • Python輕松管理與操作文件的技巧分享

    Python輕松管理與操作文件的技巧分享

    在日常開發(fā)中,我們經(jīng)常會遇到需要對文件進行操作的場景,如讀寫文件、文件夾操作等。本文將為大家介紹一些 Python 中處理文件的實用技巧,讓你的工作更高效
    2023-05-05
  • python matplotlib如何給圖中的點加標簽

    python matplotlib如何給圖中的點加標簽

    這篇文章主要介紹了python matplotlib給圖中的點加標簽,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2019-11-11
  • Python實現(xiàn)數(shù)據(jù)庫與Excel文件之間的數(shù)據(jù)自動化導(dǎo)入與導(dǎo)出

    Python實現(xiàn)數(shù)據(jù)庫與Excel文件之間的數(shù)據(jù)自動化導(dǎo)入與導(dǎo)出

    數(shù)據(jù)庫和Excel文件是兩種常見且重要的數(shù)據(jù)存儲方式,本文將介紹如何使用Python有效地實現(xiàn)數(shù)據(jù)庫與Excel文件之間的數(shù)據(jù)自動化導(dǎo)入與導(dǎo)出,以SQLite數(shù)據(jù)庫為例,需要的朋友可以參考下
    2024-06-06
  • numpy矩陣乘法中的multiply,matmul和dot的使用

    numpy矩陣乘法中的multiply,matmul和dot的使用

    本文主要介紹了numpy矩陣乘法中的multiply,matmul和dot的使用,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2023-02-02
  • python利用faker庫批量生成測試數(shù)據(jù)

    python利用faker庫批量生成測試數(shù)據(jù)

    小編經(jīng)常需要批量測試一些數(shù)據(jù),有時候測試環(huán)境又暫時沒數(shù)據(jù),特意找了一下,發(fā)現(xiàn)有一個可批量生成數(shù)據(jù)的python庫—-faker,現(xiàn)在就介紹一下它的使用方法,如果你不想一行一行輸入代碼,小編提供了完整測試代碼,見文末代碼章節(jié)。
    2020-10-10
  • 深入理解Javascript中的this關(guān)鍵字

    深入理解Javascript中的this關(guān)鍵字

    這篇文章主要介紹了深入理解Javascript中的this關(guān)鍵字,本文講解了方法調(diào)用模式、函數(shù)調(diào)用模式、構(gòu)造器調(diào)用模式、apply調(diào)用模式 中this的不同之處,需要的朋友可以參考下
    2015-03-03
  • 用Python編寫一個每天都在系統(tǒng)下新建一個文件夾的腳本

    用Python編寫一個每天都在系統(tǒng)下新建一個文件夾的腳本

    這篇文章主要介紹了用Python編寫一個每天都在系統(tǒng)下新建一個文件夾的腳本,雖然這個實現(xiàn)聽起來有點無聊...但卻是學習os和time模塊的一個小實踐,需要的朋友可以參考下
    2015-05-05
  • 在python3中pyqt5和mayavi不兼容問題的解決方法

    在python3中pyqt5和mayavi不兼容問題的解決方法

    今天小編就為大家分享一篇在python3中pyqt5和mayavi不兼容問題的解決方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-01-01

最新評論