pytorch單元測(cè)試的實(shí)現(xiàn)示例
希望測(cè)試pytorch各種算子、block、網(wǎng)絡(luò)等在不同硬件平臺(tái),不同軟件版本下的計(jì)算誤差、耗時(shí)、內(nèi)存占用等指標(biāo).
本文基于torch.testing._internal
一.公共模塊[common.py]
import torch from torch import nn import math import torch.nn.functional as F import time import os import socket import sys from datetime import datetime import numpy as np import collections import math import json import copy import traceback import subprocess import unittest import torch import inspect from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests from torch.testing._internal.common_distributed import MultiProcessTestCase import torch.distributed as dist os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" os.environ["RANDOM_SEED"] = "0" device="cpu" device_type="cpu" device_name="cpu" try: if torch.cuda.is_available(): device_name=torch.cuda.get_device_name().replace(" ","") device="cuda:0" device_type="cuda" ccl_backend='nccl' except: pass host_name=socket.gethostname() sdk_version=os.getenv("SDK_VERSION","") #從環(huán)境變量中獲取sdk版本號(hào) metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data") #日志存放的目錄 device_count=torch.cuda.device_count() if not os.path.exists(metric_data_root): os.makedirs(metric_data_root) def device_warmup(device): '''設(shè)備warmup,確保設(shè)備已經(jīng)正常工作,排除設(shè)備初始化的耗時(shí)''' left = torch.rand([128,512], dtype = torch.float16).to(device) right = torch.rand([512,128], dtype = torch.float16).to(device) out=torch.matmul(left,right) torch.cuda.synchronize() torch.manual_seed(1) np.random.seed(1) def loop_decorator(loops,rank=0): '''循環(huán)裝飾器,用于統(tǒng)計(jì)函數(shù)的執(zhí)行時(shí)間,內(nèi)存占用等''' def decorator(func): def wrapper(*args,**kwargs): latency=[] memory_allocated_t0=torch.cuda.memory_allocated(rank) for _ in range(loops): input_copy=[x.clone() for x in args] beg= datetime.now().timestamp() * 1e6 pred= func(*input_copy) gt=kwargs["golden"] torch.cuda.synchronize() end=datetime.now().timestamp() * 1e6 mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item() latency.append(end-beg) memory_allocated_t1=torch.cuda.memory_allocated(rank) avg_latency=np.mean(latency[len(latency)//2:]).round(3) first_latency=latency[0] return { "first_latency":first_latency,"avg_latency":avg_latency, "memory_allocated":memory_allocated_t1-memory_allocated_t0, "mse":mse} return wrapper return decorator class TorchUtMetrics: '''用于統(tǒng)計(jì)測(cè)試結(jié)果,比較之前的最小值''' def __init__(self,ut_name,thresold=0.2,rank=0): self.ut_name=f"{ut_name}_{rank}" self.thresold=thresold self.rank=rank self.data={"ut_name":self.ut_name,"metrics":[]} self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon") try: with open(self.metrics_path,"r") as f: self.data=json.loads(f.read()) except: pass def __enter__(self): self.beg= datetime.now().timestamp() * 1e6 return self def __exit__(self, exc_type, exc_val, exc_tb): self.report() self.save_data() def save_data(self): with open(self.metrics_path,"w") as f: f.write(json.dumps(self.data,indent=4)) def set_metrics(self,metrics): self.end=datetime.now().timestamp() * 1e6 item=collections.OrderedDict() item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') item["sdk_version"]=sdk_version item["device_name"]=device_name item["host_name"]=host_name item["metrics"]=metrics item["metrics"]["e2e_time"]=self.end-self.beg self.cur_item=item self.data["metrics"].append(self.cur_item) def get_metric_names(self): return self.data["metrics"][0]["metrics"].keys() def get_min_metric(self,metric_name,devicename=None): min_value=0 min_value_index=-1 for idx,item in enumerate(self.data["metrics"]): if devicename and (devicename!=item['device_name']): continue val=float(item["metrics"][metric_name]) if min_value_index==-1 or val<min_value: min_value=val min_value_index=idx return min_value,min_value_index def get_metric_info(self,index): metrics=self.data["metrics"][index] return f'{metrics["device_name"]}@{metrics["sdk_version"]}' def report(self): assert len(self.data["metrics"])>0 for metric_name in self.get_metric_names(): min_value,min_value_index=self.get_min_metric(metric_name) min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name) cur_value=float(self.cur_item["metrics"][metric_name]) print(f"-------------------------------{metric_name}-------------------------------") print(f"{cur_value}#{device_name}@{sdk_version}") if min_value_index_same_dev>=0: print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}") if min_value_index>=0: print(f"{min_value}#{self.get_metric_info(min_value_index)}")
二.普通算子測(cè)試[test_clone.py]
from common import * class TestCaseClone(TestCase): #如果不滿足條件,則跳過(guò)這個(gè)測(cè)試 @unittest.skipIf(device_count>1, "Not enough devices") def test_todo(self): print(".TODO") #框架會(huì)自動(dòng)遍歷以下參數(shù)組合 @parametrize("shape", [(10240,20480),(128,256)]) @parametrize("dtype", [torch.float16,torch.float32]) def test_clone(self,shape,dtype): #讓這個(gè)函數(shù)循環(huán)執(zhí)行l(wèi)oops次,統(tǒng)計(jì)第一次執(zhí)行的耗時(shí)、后半段的平均時(shí)間、整個(gè)執(zhí)行過(guò)程總的GPU內(nèi)存使用量 @loop_decorator(loops=5) def run(input_dev): output=input_dev.clone() return output #記錄整個(gè)測(cè)試的總耗時(shí),保存統(tǒng)計(jì)量,輸出摘要(self._testMethodName:測(cè)試方法,result:函數(shù)返回值,metrics:統(tǒng)計(jì)量) with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m: input_host=torch.ones(shape,dtype=dtype)*np.random.rand() input_dev=input_host.to(device) metrics=run(input_dev,golden=input_host.cpu()) m.set_metrics(metrics) assert(metrics["mse"]==0) instantiate_parametrized_tests(TestCaseClone) if __name__ == "__main__": run_tests()
三.集合通信測(cè)試[test_ccl.py]
from common import * class TestCCL(MultiProcessTestCase): '''CCL測(cè)試用例''' def _create_process_group_vccl(self, world_size, store): dist.init_process_group( ccl_backend, world_size=world_size, rank=self.rank, store=store ) pg = dist.distributed_c10d._get_default_group() return pg def setUp(self): super().setUp() self._spawn_processes() def tearDown(self): super().tearDown() try: os.remove(self.file_name) except OSError: pass @property def world_size(self): return 4 #框架會(huì)自動(dòng)遍歷以下參數(shù)組合 @unittest.skipIf(device_count<4, "Not enough devices") @parametrize("op",[dist.ReduceOp.SUM]) @parametrize("shape", [(1024,8192)]) @parametrize("dtype", [torch.int64]) def test_allreduce(self,op,shape,dtype): if self.rank >= self.world_size: return store = dist.FileStore(self.file_name, self.world_size) pg = self._create_process_group_vccl(self.world_size, store) if not torch.distributed.is_initialized(): return torch.cuda.set_device(self.rank) device = torch.device(device_type,self.rank) device_warmup(device) #讓這個(gè)函數(shù)循環(huán)執(zhí)行l(wèi)oops次,統(tǒng)計(jì)第一次執(zhí)行的耗時(shí)、后半段的平均時(shí)間、整個(gè)執(zhí)行過(guò)程總的GPU內(nèi)存使用量 @loop_decorator(loops=5,rank=self.rank) def run(input_dev): dist.all_reduce(input_dev, op=op) return input_dev #記錄整個(gè)測(cè)試的總耗時(shí),保存統(tǒng)計(jì)量,輸出摘要(self._testMethodName:測(cè)試方法,result:函數(shù)返回值,metrics:統(tǒng)計(jì)量) with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m: input_host=torch.ones(shape,dtype=dtype)*(100+self.rank) gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)] gt_=gt[0] for i in range(1,self.world_size): gt_=gt_+gt[i] input_dev=input_host.to(device) metrics=run(input_dev,golden=gt_) m.set_metrics(metrics) assert(metrics["mse"]==0) dist.destroy_process_group(pg) instantiate_parametrized_tests(TestCCL) if __name__ == "__main__": run_tests()
四.測(cè)試命令
# 運(yùn)行所有的測(cè)試 pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./ # 運(yùn)行某一個(gè)測(cè)試 python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"
五.測(cè)試報(bào)告
到此這篇關(guān)于pytorch單元測(cè)試的實(shí)現(xiàn)示例的文章就介紹到這了,更多相關(guān)pytorch單元測(cè)試內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python Pydantic進(jìn)行數(shù)據(jù)驗(yàn)證的方法詳解
在 Python 中,有許多庫(kù)可用于數(shù)據(jù)驗(yàn)證和處理,其中一個(gè)流行的選擇是 Pydantic,下面就跟隨小編一起學(xué)習(xí)一下Pydantic 的基本概念和用法吧2024-01-01PyTorch實(shí)現(xiàn)聯(lián)邦學(xué)習(xí)的基本算法FedAvg
這篇文章主要為大家介紹了PyTorch實(shí)現(xiàn)聯(lián)邦學(xué)習(xí)的基本算法FedAvg,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-05-05python matplotlib如何給圖中的點(diǎn)加標(biāo)簽
這篇文章主要介紹了python matplotlib給圖中的點(diǎn)加標(biāo)簽,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-11-11Python實(shí)現(xiàn)數(shù)據(jù)庫(kù)與Excel文件之間的數(shù)據(jù)自動(dòng)化導(dǎo)入與導(dǎo)出
數(shù)據(jù)庫(kù)和Excel文件是兩種常見(jiàn)且重要的數(shù)據(jù)存儲(chǔ)方式,本文將介紹如何使用Python有效地實(shí)現(xiàn)數(shù)據(jù)庫(kù)與Excel文件之間的數(shù)據(jù)自動(dòng)化導(dǎo)入與導(dǎo)出,以SQLite數(shù)據(jù)庫(kù)為例,需要的朋友可以參考下2024-06-06numpy矩陣乘法中的multiply,matmul和dot的使用
本文主要介紹了numpy矩陣乘法中的multiply,matmul和dot的使用,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-02-02python利用faker庫(kù)批量生成測(cè)試數(shù)據(jù)
小編經(jīng)常需要批量測(cè)試一些數(shù)據(jù),有時(shí)候測(cè)試環(huán)境又暫時(shí)沒(méi)數(shù)據(jù),特意找了一下,發(fā)現(xiàn)有一個(gè)可批量生成數(shù)據(jù)的python庫(kù)—-faker,現(xiàn)在就介紹一下它的使用方法,如果你不想一行一行輸入代碼,小編提供了完整測(cè)試代碼,見(jiàn)文末代碼章節(jié)。2020-10-10用Python編寫(xiě)一個(gè)每天都在系統(tǒng)下新建一個(gè)文件夾的腳本
這篇文章主要介紹了用Python編寫(xiě)一個(gè)每天都在系統(tǒng)下新建一個(gè)文件夾的腳本,雖然這個(gè)實(shí)現(xiàn)聽(tīng)起來(lái)有點(diǎn)無(wú)聊...但卻是學(xué)習(xí)os和time模塊的一個(gè)小實(shí)踐,需要的朋友可以參考下2015-05-05在python3中pyqt5和mayavi不兼容問(wèn)題的解決方法
今天小編就為大家分享一篇在python3中pyqt5和mayavi不兼容問(wèn)題的解決方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01