
Training and deploying PSENet (PyTorch version)

Updated: 2023-05-10 08:36:39   Author: __JDM__
This article describes how to train and deploy the PyTorch version of PSENet. It is intended as a practical reference; if anything is wrong or incomplete, corrections are welcome.

Overview

Source code

PyTorch version

My training environment does not follow the repo's README exactly; the environment I set up is:

torch==1.9.1
torchvision==0.10.1
python==3.8.0
cuda==10.2
mmcv==0.2.12
editdistance==0.5.3
Polygon3==3.0.9.1
pyclipper==1.3.0
opencv-python==3.4.2.17
Cython==0.29.24
./compile.sh
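
After installing the packages above and running ./compile.sh, I find a quick sanity check useful; this short snippet is my own and not part of the repo:

# quick check that the installed versions match the list above
import torch
import torchvision
import mmcv
import cv2

print("torch:", torch.__version__)              # expect 1.9.1
print("torchvision:", torchvision.__version__)  # expect 0.10.1
print("cuda:", torch.version.cuda, "available:", torch.cuda.is_available())  # expect 10.2 / True
print("mmcv:", mmcv.__version__)                # expect 0.2.12
print("opencv:", cv2.__version__)               # expect 3.4.2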

Preparing the dataset

1. Training data

The training data were annotated with roLabelImg and need to be converted into the ICDAR 2015 (IC15) format.

Conversion code:

import os
from lxml import etree
import numpy as np

src_xml = "ANN"   # directory containing the roLabelImg XML annotations
txt_dir = "gt"    # output directory for the IC15-style gt txt files
xml_listdir = os.listdir(src_xml)

def xml_out(xml_path):
    gt_lines = []
    ET = etree.parse(xml_path)
    objs = ET.findall("object")
    for obj in objs:
        name = obj.find("name").text
        robox = obj.find("robndbox")
        cx = int(float(robox.find("cx").text))
        cy = int(float(robox.find("cy").text))
        w = int(float(robox.find("w").text))
        h = int(float(robox.find("h").text))
        angle = float(robox.find("angle").text)  # roLabelImg stores the angle in radians
        # un-rotated corners, clockwise: top-left, top-right, bottom-right, bottom-left
        wx1 = cx - int(0.5 * w)
        wy1 = cy - int(0.5 * h)
        wx2 = cx + int(0.5 * w)
        wy2 = cy - int(0.5 * h)
        wx3 = cx + int(0.5 * w)
        wy3 = cy + int(0.5 * h)
        wx4 = cx - int(0.5 * w)
        wy4 = cy + int(0.5 * h)
        # rotate each corner around the box centre (cx, cy)
        x1 = int((wx1 - cx) * np.cos(angle) - (wy1 - cy) * np.sin(angle) + cx)
        y1 = int((wx1 - cx) * np.sin(angle) + (wy1 - cy) * np.cos(angle) + cy)
        x2 = int((wx2 - cx) * np.cos(angle) - (wy2 - cy) * np.sin(angle) + cx)
        y2 = int((wx2 - cx) * np.sin(angle) + (wy2 - cy) * np.cos(angle) + cy)
        x3 = int((wx3 - cx) * np.cos(angle) - (wy3 - cy) * np.sin(angle) + cx)
        y3 = int((wx3 - cx) * np.sin(angle) + (wy3 - cy) * np.cos(angle) + cy)
        x4 = int((wx4 - cx) * np.cos(angle) - (wy4 - cy) * np.sin(angle) + cx)
        y4 = int((wx4 - cx) * np.sin(angle) + (wy4 - cy) * np.cos(angle) + cy)
        lines = ",".join(str(v) for v in (x1, y1, x2, y2, x3, y3, x4, y4)) + "," + str(name) + "\n"
        gt_lines.append(lines)
    # return only after every object in the file has been processed
    return gt_lines

def main():
    count = 0
    for xml_dir in xml_listdir:
        gt_lines = xml_out(os.path.join(src_xml, xml_dir))
        txt_path = "gt_" + xml_dir[:-4] + ".txt"
        with open(os.path.join(txt_dir, txt_path), "w") as fd:
            fd.writelines(gt_lines)
        count += 1
        print("Write file %s" % str(count))

if __name__ == "__main__":
    main()
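
The script assumes the roLabelImg XML files sit in an ANN folder and that an (initially empty) gt folder already exists next to it; for every <name>.xml it writes a gt_<name>.txt, which matches the gt_img_*.txt naming used by the ICDAR 2015 ground-truth files.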

The XML files written by roLabelImg differ somewhat from those of labelImg, so the conversion code varies slightly depending on which annotation tool you use.

After conversion, each line has the format x1,y1,x2,y2,x3,y3,x4,y4,"classes", where classes is the detected category; for don't-care (ignored) regions, classes is "###".
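
As a made-up example (the coordinates and labels below are invented), two lines of a converted gt file could look like this, with the last field being the class name or "###" for an ignored region:

377,117,463,117,465,130,378,130,text
374,155,409,155,409,170,374,170,###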

An important caveat, though: with this source code the loss stays at 1 when training with don't-care ("###") regions.

2. Splitting the dataset into training and test sets

(Figure: dataset)

You can store the dataset at the paths the source code expects, or modify the source to point at your own location.

PSENet-python3\dataset\psenet\psenet_ic15.py

Change the dataset paths in this file to your own folders.
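
As a rough sketch (the variable names and default values here are my assumptions and may differ in your checkout), the dataset paths near the top of psenet_ic15.py look something like this; point them at your own image and gt folders:

# illustrative sketch of the dataset path settings in dataset/psenet/psenet_ic15.py;
# treat names and defaults as assumptions and adapt them to your own checkout
ic15_root_dir = './data/ICDAR2015/Challenge4/'
ic15_train_data_dir = ic15_root_dir + 'ch4_training_images/'
ic15_train_gt_dir = ic15_root_dir + 'ch4_training_localization_transcription_gt/'
ic15_test_data_dir = ic15_root_dir + 'ch4_test_images/'
ic15_test_gt_dir = ic15_root_dir + 'Challenge4_Test_Task1_GT/'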

3. Training

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py

According to the README in the source repo, you can pick whichever config file suits your needs; a simplified sketch of what such a config contains follows.
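
These configs are plain Python files read by mmcv. The following is a heavily simplified sketch of the kind of fields they expose; the names follow the psenet_r50_ic15_736.py pattern, but treat every value as an assumption and check the file in your own checkout:

# simplified, illustrative PSENet config sketch; not a verbatim copy of the repo file
model = dict(
    type='PSENet',
    backbone=dict(type='resnet50', pretrained=True),
    neck=dict(type='FPN', in_channels=(256, 512, 1024, 2048), out_channels=128),
    detection_head=dict(type='PSENet_Head', kernel_num=7),
)
data = dict(
    batch_size=16,
    train=dict(type='PSENET_IC15', split='train', is_transform=True,
               img_size=736, short_size=736, kernel_num=7),
    test=dict(type='PSENET_IC15', split='test', short_size=736),
)
train_cfg = dict(lr=1e-3, epoch=600, optimizer='SGD')
test_cfg = dict(min_score=0.85, min_area=16, kernel_num=7, bbox_type='rect',
                result_path='outputs/submit_ic15.zip')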

4. Deployment and testing

import torch
import numpy as np
import os
import sys
from mmcv import Config
import cv2
from torchvision import transforms
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
def prepare_image(image, target_size):
    """Do image preprocessing before prediction on any data.
    :param image:       original image
    :param target_size: target image size
    :return:
                        preprocessed image
    """
    #assert os.path.exists(img), 'file is not exists'
    #img = cv2.imread(img)
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # h, w = image.shape[:2]
    # scale = long_size / max(h, w)
    img = cv2.resize(img, target_size)
    # convert the (h, w, c) image into a (1, c, h, w) float tensor
    tensor = transforms.ToTensor()(img)
    tensor = tensor.unsqueeze_(0)
    tensor = tensor.to(torch.device("cuda:0"))
    return tensor
def report_speed(outputs, speed_meters):
    total_time = 0
    for key in outputs:
        if 'time' in key:
            total_time += outputs[key]
            speed_meters[key].update(outputs[key])
            print('%s: %.4f' % (key, speed_meters[key].avg))
    speed_meters['total_time'].update(total_time)
    print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))
def load_model(cfg):
    model = build_model(cfg.model)
    model = model.cuda()
    model.eval()
    checkpoint = "psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"
    if checkpoint is not None:
        if os.path.isfile(checkpoint):
            print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
            sys.stdout.flush()
            checkpoint = torch.load(checkpoint)
            d = dict()
            # strip the 'module.' prefix left by nn.DataParallel checkpoints
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)
        else:
            print("No checkpoint found at '{}'".format(checkpoint))
            raise FileNotFoundError(checkpoint)
        # fuse conv and bn
    model = fuse_module(model)
    return model
if __name__ == '__main__':
    src_dir = "testimg/"
    save_dir = "test_save/"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
    for d in [cfg, cfg.data.test]:
        d.update(dict(
            report_speed=False
        ))
    if cfg.report_speed:
        speed_meters = dict(
            backbone_time=AverageMeter(500),
            neck_time=AverageMeter(500),
            det_head_time=AverageMeter(500),
            det_pse_time=AverageMeter(500),
            rec_time=AverageMeter(500),
            total_time=AverageMeter(500)
        )
    model = load_model(cfg)
    model.eval()
    count = 0
    for img_name in os.listdir(src_dir):
        img = cv2.imread(src_dir + img_name)
        tensor = prepare_image(img, target_size=(1376, 1024))  # cv2.resize takes (width, height)
        data = dict()
        img_metas = dict()
        data['imgs'] = tensor
        img_metas['org_img_size'] = torch.tensor([[img.shape[0], img.shape[1]]])  # (h, w) of the original image
        img_metas['img_size'] = torch.tensor([[1024, 1376]])  # (h, w) after resizing, same order as org_img_size
        data['img_metas'] = img_metas
        data.update(dict(
            cfg=cfg
        ))
        with torch.no_grad():
            outputs = model(**data)
        if cfg.report_speed:
            report_speed(outputs, speed_meters)
        # each bbox is a flattened set of 4 corner points (x1,y1,...,x4,y4);
        # corners 0 and 2 are diagonal, which is enough for an axis-aligned rectangle
        for bboxes in outputs['bboxes']:
            x1 = int(bboxes[0])
            y1 = int(bboxes[1])
            x2 = int(bboxes[4])
            y2 = int(bboxes[5])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
        count = count + 1
        cv2.imwrite(save_dir + img_name, img)
        print("img test:", count)
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter

These modules come with the training code in the PSENet repo.

Summary

The above is my personal experience; I hope it can serve as a useful reference, and I hope you will continue to support 腳本之家.
