pytorch版本PSEnet訓練并部署方式
概述
torch版本
訓練環境沒有按照torch版本readme中給出的環境來配置,自己部署的環境為:
torch==1.9.1 torchvision==0.10.1 python==3.8.0 cuda==10.2 mmcv==0.2.12 editdistance==0.5.3 Polygon3==3.0.9.1 pyclipper==1.3.0 opencv-python==3.4.2.17 Cython==0.29.24
./compile.sh
制作數據集
1、訓練的數據集
采用的是rolabelimg進行標注,需要轉換為ic2015格式的數據。
轉換代碼:
import os
import math

import numpy as np

try:
    from lxml import etree
except ImportError:
    # lxml is optional here: the stdlib parser provides the same
    # parse()/findall()/find()/.text calls this script relies on.
    import xml.etree.ElementTree as etree

src_xml = "ANN"  # directory holding the roLabelImg XML annotation files
txt_dir = "gt"   # directory that receives the generated gt_<name>.txt files


def xml_out(xml_path):
    """Convert one roLabelImg XML file into IC15-style ground-truth lines.

    Each ``<object>`` that carries a ``<robndbox>`` (cx, cy, w, h, angle)
    becomes one line ``x1,y1,x2,y2,x3,y3,x4,y4,name\\n``.  The four corners
    are rotated about the box centre and emitted in outline order
    (top-left, top-right, bottom-right, bottom-left of the unrotated box),
    so consecutive points are adjacent corners and the polygon does not
    cross itself.

    :param xml_path: path of the XML annotation file to read
    :return: list of text lines, one per rotated box found in the file
    """
    gt_lines = []
    tree = etree.parse(xml_path)
    for obj in tree.findall("object"):
        name = obj.find("name").text
        robox = obj.find("robndbox")
        if robox is None:
            # Not a rotated box (e.g. a plain <bndbox> object): there is
            # nothing to rotate, so it cannot be converted by this script.
            continue
        cx = int(float(robox.find("cx").text))
        cy = int(float(robox.find("cy").text))
        w = int(float(robox.find("w").text))
        h = int(float(robox.find("h").text))
        # The angle is fed to cos/sin unchanged, i.e. it is treated as radians.
        angle = float(robox.find("angle").text)

        half_w = int(0.5 * w)
        half_h = int(0.5 * h)
        cos_a = np.cos(angle)
        sin_a = np.sin(angle)

        # Corner offsets from the centre, walking around the outline.
        offsets = (
            (-half_w, -half_h),
            (half_w, -half_h),
            (half_w, half_h),
            (-half_w, half_h),
        )
        coords = []
        for dx, dy in offsets:
            # Standard 2-D rotation: x' = dx*cos - dy*sin, y' = dx*sin + dy*cos.
            coords.append(int(dx * cos_a - dy * sin_a + cx))
            coords.append(int(dx * sin_a + dy * cos_a + cy))

        gt_lines.append(",".join(str(v) for v in coords) + "," + str(name) + "\n")
    return gt_lines


def main():
    """Convert every ``*.xml`` file in ``src_xml`` into ``txt_dir``/gt_<stem>.txt."""
    os.makedirs(txt_dir, exist_ok=True)
    count = 0
    for xml_name in sorted(os.listdir(src_xml)):
        stem, ext = os.path.splitext(xml_name)
        if ext.lower() != ".xml":
            # Ignore stray non-annotation files in the source directory.
            continue
        gt_lines = xml_out(os.path.join(src_xml, xml_name))
        txt_path = os.path.join(txt_dir, "gt_" + stem + ".txt")
        # "w" rather than "a+": re-running the script must not append a
        # second copy of every annotation to an existing file.
        with open(txt_path, "w", encoding="utf-8") as fd:
            fd.writelines(gt_lines)
        count += 1
        print("Write file %s" % str(count))


if __name__ == "__main__":
    main()
rolabelimg標注后的xml文件和labelimg的xml有些區別,根據不同的標注軟件,轉換代碼略有區別。
轉換后的格式為x1,y1,x2,y2,x3,y3,x4,y4,"classes",此處classes為檢測的類別,如果是模糊訓練的話,classes為“###”。
但是重點,這個源代碼對于模糊訓練,loss一直為1。
2、將數據集分成訓練集和測試集
這里可以按照源碼路徑存放數據集,也可以修改源碼中的存放位置。
PSENet-python3\dataset\psenet\psenet_ic15.py
修改下述代碼為自己文件夾
3、訓練
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py
其中根據(jù)源碼中的readme,
可以根據(jù)自己的需要,自行選擇配置文件。
4、部署測試
import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
import cv2
from torchvision import transforms

from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter

# Default checkpoint loaded by load_model() when no path is given.
CHECKPOINT_PATH = "psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"
# Size handed to cv2.resize, whose dsize argument is (width, height).
TARGET_SIZE = (1376, 1024)


def prepare_image(image, target_size):
    """Do image preprocessing before prediction on any data.

    :param image: original image as read by cv2.imread (BGR channel order)
    :param target_size: size passed to cv2.resize, i.e. (width, height)
    :return: preprocessed image as a batched tensor on device "cuda:0"
    """
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, target_size)
    # Turn the image array into a (1, img_channel, h, w) tensor.
    tensor = transforms.ToTensor()(img)
    tensor = tensor.unsqueeze_(0)
    tensor = tensor.to(torch.device("cuda:0"))
    return tensor


def report_speed(outputs, speed_meters):
    """Update and print the running timing averages for one forward pass.

    :param outputs: model output dict; every entry whose key contains
        'time' is added to the total and to the meter of the same key
    :param speed_meters: dict of meters keyed like the timing entries,
        plus a 'total_time' meter
    """
    total_time = 0
    for key in outputs:
        if 'time' in key:
            total_time += outputs[key]
            speed_meters[key].update(outputs[key])
            print('%s: %.4f' % (key, speed_meters[key].avg))

    speed_meters['total_time'].update(total_time)
    print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))


def load_model(cfg, checkpoint=CHECKPOINT_PATH):
    """Build the model described by ``cfg.model`` and load its weights.

    :param cfg: configuration object; ``cfg.model`` is passed to build_model
    :param checkpoint: path of the checkpoint file, or None to skip loading
    :return: the model on GPU, in eval mode, after fuse_module
    :raises FileNotFoundError: if ``checkpoint`` is given but is not a file
    """
    model = build_model(cfg.model)
    model = model.cuda()
    model.eval()

    if checkpoint is not None:
        if not os.path.isfile(checkpoint):
            # A bare ``raise`` here has no active exception to re-raise and
            # would only produce an unrelated RuntimeError.
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(checkpoint))

        print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
        sys.stdout.flush()

        state = torch.load(checkpoint)

        # Presumably the checkpoint was saved from a DataParallel-wrapped
        # model, whose keys start with "module." -- verify against the
        # training code.  Strip the prefix only where it is present so a
        # checkpoint saved without the wrapper still loads.
        prefix = "module."
        weights = dict()
        for key, value in state['state_dict'].items():
            if key.startswith(prefix):
                key = key[len(prefix):]
            weights[key] = value
        model.load_state_dict(weights)

    # fuse conv and bn
    model = fuse_module(model)
    return model


if __name__ == '__main__':
    src_dir = "testimg/"
    save_dir = "test_save/"
    os.makedirs(save_dir, exist_ok=True)

    cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
    for d in [cfg, cfg.data.test]:
        d.update(dict(
            report_speed=False
        ))
    if cfg.report_speed:
        speed_meters = dict(
            backbone_time=AverageMeter(500),
            neck_time=AverageMeter(500),
            det_head_time=AverageMeter(500),
            det_pse_time=AverageMeter(500),
            rec_time=AverageMeter(500),
            total_time=AverageMeter(500)
        )

    model = load_model(cfg)
    model.eval()

    count = 0
    for img_name in os.listdir(src_dir):
        img_path = os.path.join(src_dir, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals an unreadable / non-image file with None.
            print("skip unreadable image:", img_path)
            continue

        tensor = prepare_image(img, target_size=TARGET_SIZE)

        data = dict()
        img_metas = dict()
        data['imgs'] = tensor
        img_metas['org_img_size'] = torch.tensor([[img.shape[0], img.shape[1]]])
        # NOTE(review): org_img_size above is (rows, cols) while this value
        # is in the (width, height) order given to cv2.resize -- confirm
        # which order the model expects before changing it.
        img_metas['img_size'] = torch.tensor([[1376, 1024]])
        data['img_metas'] = img_metas
        data.update(dict(
            cfg=cfg
        ))

        with torch.no_grad():
            outputs = model(**data)

        if cfg.report_speed:
            report_speed(outputs, speed_meters)

        for bboxes in outputs['bboxes']:
            # Entries 0/1 and 4/5 are used as two opposite corners; cast to
            # plain ints because cv2.rectangle expects integer points.
            x1, y1 = int(bboxes[0]), int(bboxes[1])
            x2, y2 = int(bboxes[4]), int(bboxes[5])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)

        count = count + 1
        cv2.imwrite(os.path.join(save_dir, img_name), img)
        print("img test:", count)
from dataset import build_data_loader from models import build_model from models.utils import fuse_module from utils import ResultFormat, AverageMeter
以上導入的 dataset、models、utils 模塊均包含在訓練代碼里。
總結
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python利用pynimate實現制作動態排序圖
這篇文章主要為大家詳細介紹了Python如何利用pynimate實現制作動態排序圖,文中的示例代碼講解詳細,感興趣的小伙伴可以跟隨小編一起學習一下
2023-02-02
Python內置random模塊生成隨機數的方法
這篇文章主要介紹了Python內置random模塊生成隨機數的方法,本文給大家介紹的非常詳細,具有一定的參考借鑒價值,需要的朋友可以參考下
2019-05-05
Python中搜索和替換文件中的文本的實現(四種)
本文主要介紹了Python中搜索和替換文件中的文本的實現,文中通過示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
2021-10-10
pytorch環境配置及安裝圖文詳解(包括anaconda的安裝)
這篇文章主要介紹了pytorch環境配置及安裝圖文詳解(包括anaconda的安裝),本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
2023-04-04
python登錄WeChat 實現自動回復實例詳解
在本篇內容里小編給大家整理的是關于python登錄WeChat 實現自動回復的相關實例內容以及知識點總結,有興趣的朋友們參考下。
2019-05-05