Training and deploying PSENet (PyTorch version)
Overview
This guide uses the PyTorch implementation of PSENet.
My training environment differs from the one described in the README of the PyTorch repository. I used:
torch==1.9.1
torchvision==0.10.1
python==3.8.0
cuda==10.2
mmcv==0.2.12
editdistance==0.5.3
Polygon3==3.0.9.1
pyclipper==1.3.0
opencv-python==3.4.2.17
Cython==0.29.24
Then build the repository's compiled extensions with:
./compile.sh
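Because the code is sensitive to these versions, it can help to check an environment against the list before training. The sketch below uses the standard library's importlib.metadata (available from Python 3.8). PINNED and find_mismatches are my own names, and ignoring a local build suffix such as "+cu102" (how CUDA builds of torch usually report their version) is an assumption.

```python
from importlib import metadata

# pip-installable pins from the list above. Python 3.8.0 and CUDA 10.2 are not
# pip packages; check them with sys.version and torch.version.cuda instead.
PINNED = {
    "torch": "1.9.1",
    "torchvision": "0.10.1",
    "mmcv": "0.2.12",
    "editdistance": "0.5.3",
    "Polygon3": "3.0.9.1",
    "pyclipper": "1.3.0",
    "opencv-python": "3.4.2.17",
    "Cython": "0.29.24",
}


def find_mismatches(pinned, lookup=metadata.version):
    """Return {package: (expected, installed or None)} for pins that do not match."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = lookup(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        # Assumption: a local build suffix such as "+cu102" should be ignored.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches
```

Usage: call find_mismatches(PINNED) in the training environment; an empty dict means every listed package matches.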
Preparing the dataset
1. Training dataset
The images were annotated with roLabelImg, so the annotations have to be converted to the ICDAR 2015 (IC15) format.
Conversion code:
```python
import os

import numpy as np
from lxml import etree

src_xml = "ANN"  # folder containing the roLabelImg XML files
txt_dir = "gt"   # output folder for the IC15-style gt_*.txt files


def xml_out(xml_path):
    """Convert one roLabelImg XML file into IC15 lines: x1,y1,...,x4,y4,label."""
    gt_lines = []
    tree = etree.parse(xml_path)
    for obj in tree.findall("object"):
        name = obj.find("name").text
        robox = obj.find("robndbox")
        if robox is None:  # skip plain (non-rotated) bndbox objects
            continue
        cx = float(robox.find("cx").text)
        cy = float(robox.find("cy").text)
        w = float(robox.find("w").text)
        h = float(robox.find("h").text)
        angle = float(robox.find("angle").text)  # roLabelImg stores radians

        # Corners of the unrotated box relative to its centre, in clockwise
        # order (top-left, top-right, bottom-right, bottom-left) as IC15 expects.
        corners = [(-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2)]
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        coords = []
        for dx, dy in corners:
            # Standard 2-D rotation about (cx, cy).
            # If the boxes come out rotated the wrong way, use -angle instead.
            x = dx * cos_a - dy * sin_a + cx
            y = dx * sin_a + dy * cos_a + cy
            coords += [int(round(x)), int(round(y))]
        gt_lines.append(",".join(map(str, coords)) + "," + str(name) + "\n")
    return gt_lines


def main():
    os.makedirs(txt_dir, exist_ok=True)
    count = 0
    for xml_name in os.listdir(src_xml):
        if not xml_name.endswith(".xml"):
            continue
        gt_lines = xml_out(os.path.join(src_xml, xml_name))
        txt_path = "gt_" + os.path.splitext(xml_name)[0] + ".txt"
        # "w" rather than "a+", so re-running the script does not duplicate lines
        with open(os.path.join(txt_dir, txt_path), "w") as fd:
            fd.writelines(gt_lines)
        count += 1
        print("Write file %s" % count)


if __name__ == "__main__":
    main()
```
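IC15 lists the four corners in clockwise order, and a wrong vertex order (for example top-left, top-right, bottom-left, bottom-right) silently produces a self-intersecting "bow-tie" quadrilateral. A quick sanity check on the generated files is the signed (shoelace) area: in image coordinates, where y points down, a clockwise quadrilateral has a positive area, a counter-clockwise one a negative area, and a bow-tie rectangle an area of zero. signed_area and check_gt_line are hypothetical helpers, not part of PSENet.

```python
def signed_area(coords):
    """Shoelace formula on a flat [x1, y1, ..., xn, yn] list (image coords, y down)."""
    pts = list(zip(coords[0::2], coords[1::2]))
    total = 0.0
    for (x1, y1), (x2, y2) in zip(pts, pts[1:] + pts[:1]):
        total += x1 * y2 - x2 * y1
    return total / 2.0


def check_gt_line(line):
    """True if an IC15 line describes a clockwise, non-degenerate quadrilateral."""
    # Only the first 8 fields are coordinates; the label may itself contain commas.
    coords = [float(v) for v in line.strip().split(",")[:8]]
    return signed_area(coords) > 0
```

Usage: run check_gt_line over every line of every gt_*.txt file and inspect the lines that return False.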
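The XML files produced by roLabelImg differ from LabelImg's: rotated boxes are stored under robndbox as cx, cy, w, h and angle. The conversion code therefore has to be adapted to whichever annotation tool you use.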
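For comparison, LabelImg stores axis-aligned boxes under bndbox as xmin, ymin, xmax and ymax, so no rotation is needed. Below is a minimal sketch of the corresponding conversion. It uses the standard library's xml.etree.ElementTree instead of lxml, and the function name labelimg_to_ic15 is hypothetical.

```python
import xml.etree.ElementTree as ET


def labelimg_to_ic15(xml_text):
    """Convert LabelImg (axis-aligned bndbox) XML text to IC15 lines.

    Corners are emitted clockwise: top-left, top-right, bottom-right, bottom-left.
    """
    root = ET.fromstring(xml_text)
    lines = []
    for obj in root.findall("object"):
        name = obj.find("name").text
        box = obj.find("bndbox")
        if box is None:  # e.g. a rotated robndbox object; handled by the script above
            continue
        xmin, ymin, xmax, ymax = (int(float(box.find(tag).text))
                                  for tag in ("xmin", "ymin", "xmax", "ymax"))
        coords = [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]
        lines.append(",".join(map(str, coords)) + "," + name + "\n")
    return lines
```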
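After conversion, each line has the format x1,y1,x2,y2,x3,y3,x4,y4,"classes", where the four points are the corners of the box and classes is the detected category. For "fuzzy" training, where the text content is not distinguished, classes is written as "###".
The important catch: with this source code, fuzzy training leaves the loss stuck at 1.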
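A possible explanation for the loss staying at 1, offered as a hedged sketch rather than a reading of the repository's code: PSENet is trained with a Dice loss, and in IC15-style data "###" marks illegible text that is masked out of the loss as a "don't care" region. If every box is "###", the mask removes all text pixels, the intersection term becomes zero and the loss is exactly 1 whatever the network predicts. masked_dice_loss below is a simplified illustration; its name and eps value are my own choices.

```python
import numpy as np


def masked_dice_loss(pred, gt, mask, eps=1e-3):
    """1 - 2*sum(P*G) / (sum(P^2) + sum(G^2)), computed only where mask == 1.

    Simplified sketch in the spirit of PSENet's Dice loss, not the repo's code.
    """
    p = pred * mask  # predictions outside the mask are ignored
    g = gt * mask    # "###" regions are zeroed out of the target
    inter = np.sum(p * g)
    denom = np.sum(p * p) + np.sum(g * g) + 2 * eps
    return 1.0 - 2.0 * inter / denom
```

If this is the cause, labeling the boxes you want detected with any string other than "###" should let the loss decrease.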
2. Split the dataset into a training set and a test set
You can either store the dataset under the paths the source code expects, or change those paths in the source code, in the file:
PSENet-python3\dataset\psenet\psenet_ic15.py
Edit the dataset directory settings in that file so they point to your own folders.
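The article does not include a splitting script. Below is a minimal sketch, assuming the images and the gt_<image name>.txt files produced by the conversion script sit in two flat folders. split_dataset, copy_split, the extension list and all folder names are hypothetical, so point the output folders at whatever paths you configure in psenet_ic15.py.

```python
import random
import shutil
from pathlib import Path

IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def split_dataset(img_dir, gt_dir, test_ratio=0.2, seed=0):
    """Pair each image with gt_<stem>.txt and split the pairs into (train, test).

    Raises FileNotFoundError if an image has no matching annotation file.
    """
    pairs = []
    for img in sorted(Path(img_dir).iterdir()):
        if img.suffix.lower() not in IMG_EXTS:
            continue
        gt = Path(gt_dir) / ("gt_" + img.stem + ".txt")
        if not gt.is_file():
            raise FileNotFoundError("missing annotation for %s: %s" % (img.name, gt))
        pairs.append((img, gt))
    random.Random(seed).shuffle(pairs)  # fixed seed -> reproducible split
    n_test = int(round(len(pairs) * test_ratio))
    return pairs[n_test:], pairs[:n_test]


def copy_split(pairs, img_out, gt_out):
    """Copy (image, gt) pairs into separate image and annotation folders."""
    Path(img_out).mkdir(parents=True, exist_ok=True)
    Path(gt_out).mkdir(parents=True, exist_ok=True)
    for img, gt in pairs:
        shutil.copy2(img, Path(img_out) / img.name)
        shutil.copy2(gt, Path(gt_out) / gt.name)
```

Usage: train, test = split_dataset("images", "gt"), then copy_split(train, "train_images", "train_gt") and copy_split(test, "test_images", "test_gt").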
3. Training
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py
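Following the README in the source repository, choose whichever config file suits your needs; the command above trains with config/psenet/psenet_r50_ic15_736.py on four GPUs.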
4. Deployment and testing
```python
import os
import sys

import cv2
import numpy as np
import torch
from mmcv import Config
from torchvision import transforms

# These packages come from the PSENet repository.
from models import build_model
from models.utils import fuse_module
from utils import AverageMeter


def prepare_image(image, target_size):
    """Preprocess a BGR image (as read by cv2) for PSENet inference.

    :param image: original BGR image
    :param target_size: (width, height) passed to cv2.resize
    :return: float tensor of shape (1, 3, height, width) on cuda:0
    """
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, target_size)  # cv2.resize expects (width, height)
    tensor = transforms.ToTensor()(img)  # (H, W, C) uint8 -> (C, H, W) float in [0, 1]
    # ImageNet mean/std normalization; this assumes the training dataset class
    # normalizes the same way (check psenet_ic15.py), otherwise drop this line.
    tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])(tensor)
    tensor = tensor.unsqueeze(0)  # add the batch dimension -> (1, C, H, W)
    return tensor.to(torch.device("cuda:0"))


def report_speed(outputs, speed_meters):
    total_time = 0
    for key in outputs:
        if 'time' in key:
            total_time += outputs[key]
            speed_meters[key].update(outputs[key])
            print('%s: %.4f' % (key, speed_meters[key].avg))
    speed_meters['total_time'].update(total_time)
    print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))


def load_model(cfg, checkpoint_path):
    model = build_model(cfg.model)
    model = model.cuda()
    model.eval()
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))
    print("Loading model from checkpoint '{}'".format(checkpoint_path))
    sys.stdout.flush()
    checkpoint = torch.load(checkpoint_path, map_location="cuda:0")
    # Weights saved from nn.DataParallel carry a "module." prefix; strip it.
    state_dict = {(k[len("module."):] if k.startswith("module.") else k): v
                  for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    # fuse conv and bn layers for faster inference
    model = fuse_module(model)
    return model


if __name__ == '__main__':
    src_dir = "testimg/"
    save_dir = "test_save/"
    checkpoint_path = "psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"
    target_w, target_h = 1376, 1024
    os.makedirs(save_dir, exist_ok=True)

    cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
    for d in [cfg, cfg.data.test]:
        d.update(dict(report_speed=False))
    if cfg.report_speed:
        speed_meters = dict(
            backbone_time=AverageMeter(500),
            neck_time=AverageMeter(500),
            det_head_time=AverageMeter(500),
            det_pse_time=AverageMeter(500),
            rec_time=AverageMeter(500),
            total_time=AverageMeter(500),
        )

    model = load_model(cfg, checkpoint_path)
    count = 0
    for img_name in os.listdir(src_dir):
        img = cv2.imread(os.path.join(src_dir, img_name))
        if img is None:  # skip files OpenCV cannot read
            continue
        tensor = prepare_image(img, target_size=(target_w, target_h))
        img_metas = dict(
            # both sizes are (height, width), the same order as img.shape[:2]
            org_img_size=torch.tensor([[img.shape[0], img.shape[1]]]),
            img_size=torch.tensor([[target_h, target_w]]),
        )
        data = dict(imgs=tensor, img_metas=img_metas, cfg=cfg)
        with torch.no_grad():
            outputs = model(**data)
        if cfg.report_speed:
            report_speed(outputs, speed_meters)

        # Each box is a flat list of x, y coordinates; draw it as a closed polygon.
        for bbox in outputs['bboxes']:
            pts = np.array(bbox).reshape(-1, 2).astype(np.int32)
            cv2.polylines(img, [pts], True, (0, 0, 255), 3)
        count += 1
        cv2.imwrite(os.path.join(save_dir, img_name), img)
        print("img test:", count)
```
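The deployment script passes two sizes in img_metas: the original image size and the size of the resized input. As far as I can tell, PSENet's post-processing uses their ratio to map the detected boxes back onto the original image, which is why both must use the same (height, width) order. The helper below is a hypothetical illustration of that rescaling, not the repository's implementation.

```python
import numpy as np


def rescale_boxes(boxes, img_size, org_img_size):
    """Map flat [x1, y1, ..., xn, yn] boxes from the resized input to the original image.

    Both sizes are (height, width), the same order as img.shape[:2].
    """
    (h, w), (org_h, org_w) = img_size, org_img_size
    scale = np.array([org_w / w, org_h / h])  # (x scale, y scale)
    out = []
    for box in boxes:
        pts = np.asarray(box, dtype=np.float64).reshape(-1, 2) * scale
        out.append(np.round(pts).astype(np.int32).reshape(-1).tolist())
    return out
```

Swapping the two entries of either size would scale x by the height ratio and y by the width ratio, which distorts the boxes whenever the image is not square.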
The dataset, models and utils packages imported by the deployment script (for example from models import build_model and from utils import AverageMeter) are part of the training code, i.e. the PSENet repository, so Python must be able to import them from there.
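The config path PSENet/config/... in the deployment script suggests it is run from the repository's parent directory, in which case those packages are not importable by default. One way to handle this is to prepend the checkout to sys.path before the repository imports. A minimal sketch; add_repo_to_path and the "PSENet" directory name are assumptions.

```python
import sys
from pathlib import Path


def add_repo_to_path(repo_dir="PSENet"):
    """Put the PSENet checkout at the front of sys.path (safe to call twice)."""
    root = str(Path(repo_dir).resolve())
    if root in sys.path:
        sys.path.remove(root)  # avoid duplicate entries
    sys.path.insert(0, root)
    return root
```

Call add_repo_to_path() at the top of the script, before from models import build_model.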
Summary
The above is based on my personal experience; I hope it serves as a useful reference.