Python利用DNN實現寶石識別
任務描述
本次實踐是一個多分類任務,需要識別照片中寶石的類別。
實踐平臺:百度AI實訓平臺-AI Studio、PaddlePaddle 1.8.0 動態圖(dygraph)
深度神經網絡(DNN)
深度神經網絡(Deep Neural Networks,簡稱DNN)是深度學習的基礎,其結構為input、hidden(可有多層)、output,每層均為全連接。
數據集介紹
- 數據集文件名為archive_train.zip、archive_test.zip。
- 該數據集包含25個類別的寶石圖像。
- 數據已經分為訓練集和測試集。
- 圖像大小不一,格式為JPEG(.jpg)。
# 查看當(dāng)前掛載的數(shù)據(jù)集目錄, 該目錄下的變更重啟環(huán)境后會(huì)自動(dòng)還原 # View dataset directory. This directory will be recovered automatically after resetting environment. !ls /home/aistudio/data
data55032 dataset
#導(dǎo)入需要的包 import os import zipfile import random import json import cv2 import numpy as np from PIL import Image import paddle import paddle.fluid as fluid from paddle.fluid.dygraph import Linear import matplotlib.pyplot as plt
1.數據準備
# Training configuration shared by all of the cells below.
train_parameters = dict(
    input_size=[3, 64, 64],                          # shape of one input image (C, H, W)
    class_dim=-1,                                    # number of classes; set by get_data_list
    augment_path='/home/aistudio/augment',           # directory holding augmented images
    src_path="data/data55032/archive_train.zip",     # original training archive
    target_path="/home/aistudio/data/dataset",       # directory the archive is extracted to
    train_list_path="./train_data.txt",              # training list file
    eval_list_path="./val_data.txt",                 # evaluation list file
    label_dict={},                                   # str(label) -> class name; set by get_data_list
    readme_path="/home/aistudio/data/readme.json",   # dataset summary JSON
    num_epochs=20,                                   # number of training epochs
    train_batch_size=64,                             # batch size
    learning_strategy=dict(lr=0.001),                # optimizer settings (learning rate)
)
def unzip_data(src_path, target_path):
    '''
    Extract the zip archive at ``src_path`` into ``target_path``.

    If ``target_path`` is already a directory, nothing is extracted and a
    message is printed instead, so re-running the cell is cheap.

    Args:
        src_path: path of the zip archive.
        target_path: directory to extract into.
    '''
    if os.path.isdir(target_path):
        print("文件已解壓")
        return
    # The context manager closes the archive even if extraction fails.
    with zipfile.ZipFile(src_path, 'r') as z:
        z.extractall(path=target_path)
def get_data_list(target_path, train_list_path, eval_list_path, augment_path, eval_every=15):
    '''
    Build the train/eval list files and the readme JSON for the dataset.

    Each sub-directory of ``target_path`` (except ``__MACOSX``) is one class.
    Class directories are processed in sorted order, so the labels are
    reproducible across runs and machines (``os.listdir`` order is
    arbitrary). Within a class, images (``.DS_Store`` excluded) are visited
    in sorted order. The 1st, (eval_every+1)-th, ... image goes to the eval
    list and the rest go to the train list. Every file in
    ``augment_path/<class_dir>`` is added to the train list only, if that
    directory exists.

    Each list line is "<image path><TAB><label><newline>". The lines are
    shuffled and APPENDED to ``train_list_path`` / ``eval_list_path``, so
    callers must truncate those files first if they want a fresh list.

    Side effects: fills ``train_parameters['label_dict']`` (str(label) ->
    class name), sets ``train_parameters['class_dim']``, prints
    ``train_parameters``, and writes a summary JSON to
    ``train_parameters['readme_path']``.

    NOTE(review): the augmentation cell in this notebook augments every
    image under the dataset directory, including the images this function
    later selects for evaluation. Augmented copies of eval images may
    therefore end up in the train list. Confirm whether this inflates the
    eval accuracy.

    Args:
        target_path: directory containing one sub-directory per class.
        train_list_path: file the training lines are appended to.
        eval_list_path: file the evaluation lines are appended to.
        augment_path: directory containing one sub-directory of augmented
            images per class.
        eval_every: one image out of every ``eval_every`` goes to the eval
            list (default 15).
    '''
    data_list_path = target_path
    # Only real class directories; macOS archive metadata is skipped.
    class_dirs = sorted(
        d for d in os.listdir(data_list_path)
        if d != '__MACOSX' and os.path.isdir(os.path.join(data_list_path, d))
    )

    class_detail = []     # per-class summary for the readme JSON
    all_class_images = 0  # total number of images listed (train + eval)
    trainer_list = []     # lines for the train list file
    eval_list = []        # lines for the eval list file

    for class_label, class_dir in enumerate(class_dirs):
        eval_sum = 0
        trainer_sum = 0
        path = os.path.join(data_list_path, class_dir)
        img_names = sorted(n for n in os.listdir(path) if n != '.DS_Store')
        for class_sum, img_name in enumerate(img_names):
            line = os.path.join(path, img_name) + "\t%d" % class_label + "\n"
            if class_sum % eval_every == 0:
                eval_sum += 1
                eval_list.append(line)
            else:
                trainer_sum += 1
                trainer_list.append(line)
        all_class_images += len(img_names)

        # ------------------------- data augmentation -------------------------
        # Augmented images are used for training only.
        aug_path = os.path.join(augment_path, class_dir)
        if os.path.isdir(aug_path):
            aug_names = sorted(n for n in os.listdir(aug_path) if n != '.DS_Store')
            trainer_list.extend(
                os.path.join(aug_path, n) + "\t%d" % class_label + "\n" for n in aug_names
            )
            trainer_sum += len(aug_names)
            all_class_images += len(aug_names)
        # ---------------------------------------------------------------------

        class_detail.append({
            'class_name': class_dir,               # class name
            'class_label': class_label,            # class label
            'class_eval_images': eval_sum,         # number of eval images in this class
            'class_trainer_images': trainer_sum,   # number of training images in this class
        })
        train_parameters['label_dict'][str(class_label)] = class_dir

    train_parameters['class_dim'] = len(class_dirs)
    print(train_parameters)

    random.shuffle(eval_list)
    with open(eval_list_path, 'a') as f:
        f.writelines(eval_list)
    random.shuffle(trainer_list)
    with open(train_list_path, 'a') as f2:
        f2.writelines(trainer_list)

    # Summary JSON describing the generated lists.
    readjson = {
        'all_class_name': data_list_path,  # parent directory of the class folders
        'all_class_images': all_class_images,
        'class_detail': class_detail,
    }
    jsons = json.dumps(readjson, sort_keys=True, indent=4, separators=(',', ': '))
    with open(train_parameters['readme_path'], 'w') as f:
        f.write(jsons)
    print ('生成數(shù)據(jù)列表完成!')
def data_reader(file_list):
    '''
    Create a sample reader over a list file written by ``get_data_list``.

    Each non-blank line of ``file_list`` is "<image path><TAB><label>".
    The returned generator function yields ``(img, label)`` pairs:

    - ``img`` is a float32 array of shape (3, 64, 64): RGB, CHW layout,
      pixel values scaled to [0, 1].
    - ``label`` is an int.

    Args:
        file_list: path of the list file.

    Returns:
        A zero-argument generator function, as consumed by ``paddle.batch``.
    '''
    def reader():
        with open(file_list, 'r') as f:
            lines = [line.strip() for line in f]
        for line in lines:
            if not line:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            img_path, lab = line.split('\t')
            # Close the image file as soon as the pixels have been read.
            with Image.open(img_path) as src:
                rgb = src if src.mode == 'RGB' else src.convert('RGB')
                resized = rgb.resize((64, 64), Image.BILINEAR)
            img = np.array(resized).astype('float32')
            img = img.transpose((2, 0, 1))  # HWC to CHW
            img = img / 255  # normalize pixel values to [0, 1]
            yield img, int(lab)
    return reader
# Install Augmentor, the image-augmentation library used by the augmentation cell.
!pip install Augmentor
Looking in indexes: https://mirror.baidu.com/pypi/simple/ Requirement already satisfied: Augmentor in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.2.8) Requirement already satisfied: tqdm>=4.9.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (4.36.1) Requirement already satisfied: future>=0.16.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (0.18.0) Requirement already satisfied: numpy>=1.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (1.16.4) Requirement already satisfied: Pillow>=5.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (7.1.2)
# Pull the frequently used settings out of the configuration dict.
src_path, target_path = train_parameters['src_path'], train_parameters['target_path']
train_list_path, eval_list_path = (
    train_parameters['train_list_path'],
    train_parameters['eval_list_path'],
)
batch_size = train_parameters['train_batch_size']
augment_path = train_parameters['augment_path']

# Extract the raw training archive (skipped if target_path already exists).
unzip_data(src_path, target_path)
文件已解壓
def proc_img(src):
    '''
    Convert every non-RGB image under ``src`` to RGB, in place.

    Directories whose path contains ``__MACOSX`` (macOS archive metadata) and
    ``.DS_Store`` files are skipped, matching what ``get_data_list`` ignores.
    Images that are already RGB are left untouched, so re-running this step
    does not re-encode them (for JPEG, every re-save loses quality).

    Args:
        src: root directory to walk recursively.
    '''
    for root, _dirs, files in os.walk(src):
        if '__MACOSX' in root:
            continue
        for file_name in files:
            if file_name == '.DS_Store':
                continue
            img_path = os.path.join(root, file_name)
            with Image.open(img_path) as img:
                if img.mode == 'RGB':
                    continue
                rgb = img.convert('RGB')
            # Save only after the source file handle has been closed.
            rgb.save(img_path)


if __name__ == '__main__':
    proc_img(r"data/dataset")
import os, Augmentor
import shutil, glob

# Generate augmented images for every class directory under data/dataset and
# move them into augment_path/<class name>. The whole step is skipped when
# augment_path already exists, so the augmentation is not repeated on re-runs.
if not os.path.exists(augment_path): # avoid repeating the augmentation
    # Bottom-up walk; the entry for "data/dataset" itself lists every class
    # directory in `dirs` (plus __MACOSX, which is skipped below).
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in dirs:
            path_ = os.path.join(root, name)
            if '__MACOSX' in path_:continue
            print('數(shù)據(jù)增強(qiáng):',os.path.join(root, name))
            print('image:',os.path.join(root, name))
            # Results are written to <class dir>/output (see the log: "Output
            # directory set to data/dataset/<class>/output").
            p = Augmentor.Pipeline(os.path.join(root, name),output_directory='output')
            p.rotate(probability=0.6, max_left_rotation=2, max_right_rotation=2)
            p.zoom(probability=0.6, min_factor=0.9, max_factor=1.1)
            p.random_distortion(probability=0.4, grid_height=2, grid_width=2, magnitude=1)
            # Number of samples = 1000 minus the class's *.jpg originals
            # (e.g. 32 originals -> 968 samples in the log).
            # NOTE(review): p.sample presumably rejects count <= 0, i.e. a class
            # with >= 1000 originals — confirm if larger classes are possible.
            count = 1000 - len(glob.glob(pathname=path_+'/*.jpg'))
            p.sample(count, multi_threaded=False)
            # NOTE(review): judging by the log ("32/32" after the 968 samples),
            # process() additionally runs the pipeline once per original image.
            p.process()
    # Move the generated files out of every <class>/output directory.
    print('將生成的圖片拷貝到正確的目錄')
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in files:
            path_ = os.path.join(root, name)
            # For ".../<class>/output/<file>", rsplit('/', 3) gives
            # [..., <class>, 'output', <file>]: index 2 is the parent
            # directory and index 1 the class name.
            if path_.rsplit('/',3)[2] == 'output':
                type_ = path_.rsplit('/',3)[1]
                dest_dir = os.path.join(augment_path ,type_)
                if not os.path.exists(dest_dir):os.makedirs(dest_dir)
                dest_path_ = os.path.join(augment_path ,type_, name)
                shutil.move(path_, dest_path_)
    # Remove the now-empty output directories.
    print('刪除所有output目錄')
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in dirs:
            if name == 'output':
                path_ = os.path.join(root, name)
                shutil.rmtree(path_)
    print('完成數(shù)據(jù)增強(qiáng)')
Processing kunzite_20.jpg: 1%| | 11/968 [00:00<00:14, 65.61 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Kunzite image: data/dataset/Kunzite Initialised with 32 image(s) found. Output directory set to data/dataset/Kunzite/output. Processing kunzite_14.jpg: 2%|▏ | 24/968 [00:00<00:17, 54.43 Samples/s]Processing kunzite_15.jpg: 100%|██████████| 968/968 [00:15<00:00, 61.57 Samples/s] Processing <PIL.Image.Image image mode=RGB size=350x366 at 0x7F7060EB06D0>: 100%|██████████| 32/32 [00:00<00:00, 269.33 Samples/s] Processing almandine_5.jpg: 1%| | 6/969 [00:00<00:20, 45.91 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Almandine image: data/dataset/Almandine Initialised with 31 image(s) found. Output directory set to data/dataset/Almandine/output. Processing almandine_2.jpg: 1%|▏ | 14/969 [00:00<00:27, 34.12 Samples/s] Processing almandine_25.jpg: 100%|██████████| 969/969 [00:22<00:00, 42.25 Samples/s] Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705E020C90>: 100%|██████████| 31/31 [00:00<00:00, 173.21 Samples/s] Processing emerald_2.jpg: 1%| | 10/964 [00:00<00:16, 58.72 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Emerald image: data/dataset/Emerald Initialised with 36 image(s) found. Output directory set to data/dataset/Emerald/output. Processing emerald_36.jpg: 2%|▏ | 20/964 [00:00<00:17, 54.08 Samples/s]Processing emerald_15.jpg: 100%|██████████| 964/964 [00:26<00:00, 36.49 Samples/s] Processing <PIL.Image.Image image mode=RGB size=460x460 at 0x7F705DED0110>: 100%|██████████| 36/36 [00:00<00:00, 149.48 Samples/s] Processing sapphire blue_9.jpg: 1%| | 10/966 [00:00<00:13, 68.91 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Sapphire Blue image: data/dataset/Sapphire Blue Initialised with 34 image(s) found. Output directory set to data/dataset/Sapphire Blue/output. 
Processing sapphire blue_16.jpg: 2%|▏ | 22/966 [00:00<00:16, 56.52 Samples/s]Processing sapphire blue_30.jpg: 100%|██████████| 966/966 [00:18<00:00, 53.08 Samples/s] Processing <PIL.Image.Image image mode=RGB size=450x450 at 0x7F706885B810>: 100%|██████████| 34/34 [00:00<00:00, 177.29 Samples/s] Processing malachite_2.jpg: 1%| | 10/972 [00:00<00:20, 47.64 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Malachite image: data/dataset/Malachite Initialised with 28 image(s) found. Output directory set to data/dataset/Malachite/output. Processing malachite_16.jpg: 2%|▏ | 18/972 [00:00<00:20, 47.14 Samples/s]Processing malachite_22.jpg: 100%|██████████| 972/972 [00:18<00:00, 52.32 Samples/s] Processing <PIL.Image.Image image mode=RGB size=376x262 at 0x7F7060E93D10>: 100%|██████████| 28/28 [00:00<00:00, 173.34 Samples/s] Processing alexandrite_0.jpg: 1%| | 6/966 [00:00<00:24, 39.61 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Alexandrite image: data/dataset/Alexandrite Initialised with 34 image(s) found. Output directory set to data/dataset/Alexandrite/output. Processing alexandrite_23.jpg: 2%|▏ | 18/966 [00:00<00:21, 44.52 Samples/s]Processing alexandrite_20.jpg: 100%|██████████| 966/966 [00:20<00:00, 48.06 Samples/s] Processing <PIL.Image.Image image mode=RGB size=500x500 at 0x7F705E025B10>: 100%|██████████| 34/34 [00:00<00:00, 129.49 Samples/s] Processing zircon_8.jpg: 1%| | 5/967 [00:00<00:33, 28.43 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Zircon image: data/dataset/Zircon Initialised with 33 image(s) found. Output directory set to data/dataset/Zircon/output. 
Processing zircon_23.jpg: 1%| | 6/967 [00:00<00:33, 28.43 Samples/s]Processing zircon_24.jpg: 100%|██████████| 967/967 [00:24<00:00, 38.88 Samples/s] Processing <PIL.Image.Image image mode=RGB size=500x500 at 0x7F705DEAC3D0>: 100%|██████████| 33/33 [00:00<00:00, 134.76 Samples/s] Processing onyx black_16.jpg: 1%| | 8/972 [00:00<00:13, 69.17 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Onyx Black image: data/dataset/Onyx Black Initialised with 28 image(s) found. Output directory set to data/dataset/Onyx Black/output. Processing onyx black_6.jpg: 2%|▏ | 18/972 [00:00<00:18, 51.84 Samples/s] Processing onyx black_2.jpg: 100%|██████████| 972/972 [00:18<00:00, 53.19 Samples/s] Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DEE1910>: 100%|██████████| 28/28 [00:00<00:00, 131.50 Samples/s] Processing rhodochrosite_29.jpg: 1%| | 10/971 [00:00<00:18, 53.20 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Rhodochrosite image: data/dataset/Rhodochrosite Initialised with 29 image(s) found. Output directory set to data/dataset/Rhodochrosite/output. Processing rhodochrosite_21.jpg: 2%|▏ | 21/971 [00:00<00:16, 58.01 Samples/s]Processing rhodochrosite_15.jpg: 100%|██████████| 971/971 [00:20<00:00, 46.42 Samples/s] Processing <PIL.Image.Image image mode=RGB size=373x356 at 0x7F705E011910>: 100%|██████████| 29/29 [00:00<00:00, 243.76 Samples/s] Processing diamond_16.jpg: 1%| | 5/969 [00:00<00:28, 34.31 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Diamond image: data/dataset/Diamond Initialised with 31 image(s) found. Output directory set to data/dataset/Diamond/output. 
Processing diamond_6.jpg: 1%| | 11/969 [00:00<00:26, 35.79 Samples/s] Processing diamond_20.jpg: 100%|██████████| 969/969 [00:24<00:00, 40.22 Samples/s] Processing <PIL.Image.Image image mode=RGB size=400x400 at 0x7F705DE6CCD0>: 100%|██████████| 31/31 [00:00<00:00, 150.83 Samples/s] Processing benitoite_29.jpg: 1%| | 7/969 [00:00<00:15, 63.04 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Benitoite image: data/dataset/Benitoite Initialised with 31 image(s) found. Output directory set to data/dataset/Benitoite/output. Processing benitoite_2.jpg: 2%|▏ | 24/969 [00:00<00:16, 57.15 Samples/s] Processing benitoite_12.jpg: 100%|██████████| 969/969 [00:17<00:00, 55.09 Samples/s] Processing <PIL.Image.Image image mode=RGB size=472x433 at 0x7F705DFE9290>: 100%|██████████| 31/31 [00:00<00:00, 178.70 Samples/s] Processing pearl_0.jpg: 1%| | 6/967 [00:00<00:25, 38.13 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Pearl image: data/dataset/Pearl Initialised with 33 image(s) found. Output directory set to data/dataset/Pearl/output. Processing pearl_32.jpg: 2%|▏ | 21/967 [00:00<00:20, 47.09 Samples/s]Processing pearl_12.jpg: 100%|██████████| 967/967 [00:17<00:00, 54.49 Samples/s] Processing <PIL.Image.Image image mode=RGB size=301x301 at 0x7F705E020A50>: 100%|██████████| 33/33 [00:00<00:00, 205.47 Samples/s] Processing beryl golden_39.jpg: 1%| | 11/964 [00:00<00:12, 79.36 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Beryl Golden image: data/dataset/Beryl Golden Initialised with 36 image(s) found. Output directory set to data/dataset/Beryl Golden/output. 
Processing beryl golden_29.jpg: 2%|▏ | 22/964 [00:00<00:14, 63.92 Samples/s]Processing beryl golden_2.jpg: 100%|██████████| 964/964 [00:16<00:00, 58.61 Samples/s] Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DE6F910>: 100%|██████████| 36/36 [00:00<00:00, 273.71 Samples/s] Processing labradorite_16.jpg: 1%| | 9/960 [00:00<00:17, 55.49 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Labradorite image: data/dataset/Labradorite Initialised with 40 image(s) found. Output directory set to data/dataset/Labradorite/output. Processing labradorite_17.jpg: 2%|▏ | 20/960 [00:00<00:18, 52.03 Samples/s]Processing labradorite_11.jpg: 100%|██████████| 960/960 [00:21<00:00, 45.63 Samples/s] Processing <PIL.Image.Image image mode=RGB size=400x400 at 0x7F705DE70F10>: 100%|██████████| 40/40 [00:00<00:00, 117.40 Samples/s] Processing fluorite_23.jpg: 1%| | 11/968 [00:00<00:14, 65.24 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Fluorite image: data/dataset/Fluorite Initialised with 32 image(s) found. Output directory set to data/dataset/Fluorite/output. Processing fluorite_4.jpg: 1%|▏ | 14/968 [00:00<00:19, 49.03 Samples/s] Processing fluorite_4.jpg: 100%|██████████| 968/968 [00:21<00:00, 44.39 Samples/s] Processing <PIL.Image.Image image mode=RGB size=500x442 at 0x7F705DE87CD0>: 100%|██████████| 32/32 [00:00<00:00, 169.43 Samples/s] Processing iolite_2.jpg: 1%| | 7/968 [00:00<00:24, 39.15 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Iolite image: data/dataset/Iolite Initialised with 32 image(s) found. Output directory set to data/dataset/Iolite/output. 
Processing iolite_35.jpg: 2%|▏ | 23/968 [00:00<00:18, 51.39 Samples/s]Processing iolite_23.jpg: 100%|██████████| 968/968 [00:16<00:00, 57.22 Samples/s] Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DE764D0>: 100%|██████████| 32/32 [00:00<00:00, 373.16 Samples/s] Processing quartz beer_24.jpg: 1%| | 12/965 [00:00<00:16, 57.87 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Quartz Beer image: data/dataset/Quartz Beer Initialised with 35 image(s) found. Output directory set to data/dataset/Quartz Beer/output. Processing quartz beer_28.jpg: 2%|▏ | 24/965 [00:00<00:14, 65.30 Samples/s]Processing quartz beer_30.jpg: 100%|██████████| 965/965 [00:16<00:00, 59.48 Samples/s] Processing <PIL.Image.Image image mode=RGB size=300x300 at 0x7F705DE82DD0>: 100%|██████████| 35/35 [00:00<00:00, 173.58 Samples/s] Processing garnet red_21.jpg: 1%| | 7/964 [00:00<00:34, 27.76 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Garnet Red image: data/dataset/Garnet Red Initialised with 36 image(s) found. Output directory set to data/dataset/Garnet Red/output. Processing garnet red_2.jpg: 2%|▏ | 17/964 [00:00<00:28, 33.50 Samples/s] Processing garnet red_2.jpg: 100%|██████████| 964/964 [00:20<00:00, 46.97 Samples/s] Processing <PIL.Image.Image image mode=RGB size=301x301 at 0x7F705E020090>: 100%|██████████| 36/36 [00:00<00:00, 197.00 Samples/s] Processing danburite_35.jpg: 1%| | 8/968 [00:00<00:16, 58.65 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Danburite image: data/dataset/Danburite Initialised with 32 image(s) found. Output directory set to data/dataset/Danburite/output. 
Processing danburite_32.jpg: 2%|▏ | 17/968 [00:00<00:19, 49.88 Samples/s]Processing danburite_23.jpg: 100%|██████████| 968/968 [00:19<00:00, 50.58 Samples/s] Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705DE78390>: 100%|██████████| 32/32 [00:00<00:00, 144.25 Samples/s] Processing cats eye_7.jpg: 1%| | 8/969 [00:00<00:24, 39.01 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Cats Eye image: data/dataset/Cats Eye Initialised with 31 image(s) found. Output directory set to data/dataset/Cats Eye/output. Processing cats eye_26.jpg: 2%|▏ | 15/969 [00:00<00:23, 41.33 Samples/s]Processing cats eye_33.jpg: 100%|██████████| 969/969 [00:25<00:00, 38.19 Samples/s] Processing <PIL.Image.Image image mode=RGB size=401x401 at 0x7F706AF09510>: 100%|██████████| 31/31 [00:00<00:00, 214.03 Samples/s] Processing hessonite_1.jpg: 0%| | 3/970 [00:00<00:33, 28.84 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Hessonite image: data/dataset/Hessonite Initialised with 30 image(s) found. Output directory set to data/dataset/Hessonite/output. Processing hessonite_19.jpg: 1%|▏ | 13/970 [00:00<00:31, 30.34 Samples/s]Processing hessonite_33.jpg: 100%|██████████| 970/970 [00:20<00:00, 47.73 Samples/s] Processing <PIL.Image.Image image mode=RGB size=301x301 at 0x7F705E020610>: 100%|██████████| 30/30 [00:00<00:00, 162.33 Samples/s] Processing carnelian_12.jpg: 1%| | 5/967 [00:00<00:28, 34.19 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Carnelian image: data/dataset/Carnelian Initialised with 33 image(s) found. Output directory set to data/dataset/Carnelian/output. 
Processing carnelian_32.jpg: 1%| | 12/967 [00:00<00:29, 32.65 Samples/s]Processing carnelian_31.jpg: 100%|██████████| 967/967 [00:24<00:00, 39.93 Samples/s] Processing <PIL.Image.Image image mode=RGB size=425x425 at 0x7F705DE840D0>: 100%|██████████| 33/33 [00:00<00:00, 147.85 Samples/s] Processing jade_26.jpg: 1%| | 9/972 [00:00<00:25, 38.24 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Jade image: data/dataset/Jade Initialised with 28 image(s) found. Output directory set to data/dataset/Jade/output. Processing jade_20.jpg: 2%|▏ | 22/972 [00:00<00:19, 47.93 Samples/s]Processing jade_18.jpg: 100%|██████████| 972/972 [00:18<00:00, 51.18 Samples/s] Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DE8B050>: 100%|██████████| 28/28 [00:00<00:00, 331.02 Samples/s] Processing variscite_22.jpg: 1%| | 5/970 [00:00<00:25, 37.31 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Variscite image: data/dataset/Variscite Initialised with 30 image(s) found. Output directory set to data/dataset/Variscite/output. Processing variscite_10.jpg: 1%|▏ | 13/970 [00:00<00:26, 35.70 Samples/s]Processing variscite_31.jpg: 100%|██████████| 970/970 [00:21<00:00, 45.58 Samples/s] Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705DE7BE50>: 100%|██████████| 30/30 [00:00<00:00, 157.22 Samples/s] Processing tanzanite_2.jpg: 1%| | 5/964 [00:00<00:31, 30.52 Samples/s] 數(shù)據(jù)增強(qiáng): data/dataset/Tanzanite image: data/dataset/Tanzanite Initialised with 36 image(s) found. Output directory set to data/dataset/Tanzanite/output. Processing tanzanite_15.jpg: 2%|▏ | 15/964 [00:00<00:25, 36.60 Samples/s]Processing tanzanite_37.jpg: 100%|██████████| 964/964 [00:25<00:00, 38.41 Samples/s] Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705E00E4D0>: 100%|██████████| 36/36 [00:00<00:00, 144.18 Samples/s] 將生成的圖片拷貝到正確的目錄 刪除所有output目錄 完成數(shù)據(jù)增強(qiáng)
# get_data_list appends to the list files, so every run starts from empty
# files. Opening in 'w' mode already truncates; no seek/truncate is needed.
for _list_file in (train_list_path, eval_list_path):
    with open(_list_file, 'w'):
        pass

# Generate the train/eval list files.
get_data_list(target_path, train_list_path, eval_list_path, augment_path)

# Build the batched data providers.
train_reader = paddle.batch(data_reader(train_list_path), batch_size=batch_size, drop_last=True)
eval_reader = paddle.batch(data_reader(eval_list_path), batch_size=batch_size, drop_last=True)
{'input_size': [3, 64, 64], 'class_dim': 25, 'augment_path': '/home/aistudio/augment', 'src_path': 'data/data55032/archive_train.zip', 'target_path': '/home/aistudio/data/dataset', 'train_list_path': './train_data.txt', 'eval_list_path': './val_data.txt', 'label_dict': {'0': 'Kunzite', '1': 'Almandine', '2': 'Emerald', '3': 'Sapphire Blue', '4': 'Malachite', '5': 'Alexandrite', '6': 'Zircon', '7': 'Onyx Black', '8': 'Rhodochrosite', '9': 'Diamond', '10': 'Benitoite', '11': 'Pearl', '12': 'Beryl Golden', '13': 'Labradorite', '14': 'Fluorite', '15': 'Iolite', '16': 'Quartz Beer', '17': 'Garnet Red', '18': 'Danburite', '19': 'Cats Eye', '20': 'Hessonite', '21': 'Carnelian', '22': 'Jade', '23': 'Variscite', '24': 'Tanzanite'}, 'readme_path': '/home/aistudio/data/readme.json', 'num_epochs': 20, 'train_batch_size': 64, 'learning_strategy': {'lr': 0.001}} 生成數(shù)據(jù)列表完成!
# Batch counter and metric histories, filled in by the training loop and
# plotted after it.
Batch = 0
Batchs = []
all_train_accs = []


def _plot_curve(batches, values, title, ylabel, color):
    """Plot one training curve against the batch counter and show it."""
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.plot(batches, values, color=color, label=title)
    plt.legend()
    plt.grid()
    plt.show()


def draw_train_acc(Batchs, train_accs):
    """Plot training accuracy against the batch counter."""
    _plot_curve(Batchs, train_accs, "training accs", "acc", 'green')


all_train_loss = []


def draw_train_loss(Batchs, train_loss):
    """Plot training loss against the batch counter."""
    _plot_curve(Batchs, train_loss, "training loss", "loss", 'red')
2.定義模型
###在以下cell中完成DNN網絡的定義###
# Network definition
class MyDNN(fluid.dygraph.Layer):
    '''
    Fully connected deep neural network (DNN) classifier.

    Architecture: flatten -> Linear(1000, ReLU) -> Linear(500, ReLU)
    -> Linear(100, ReLU) -> Linear(class_dim, softmax).
    Because the last layer applies softmax, the output is a per-class
    probability vector. The training cell passes it to
    ``fluid.layers.cross_entropy``.

    Args:
        class_dim: number of output classes. Defaults to 25, the number of
            gem classes in this dataset.
        input_shape: (C, H, W) of one input image. Defaults to (3, 64, 64),
            the size produced by the data readers.
    '''
    def __init__(self, class_dim=25, input_shape=(3, 64, 64)):
        super(MyDNN, self).__init__()
        channels, height, width = input_shape
        # Number of features in one flattened image.
        self._in_features = channels * height * width
        # NOTE: keep these attribute names unchanged; they likely form the
        # parameter keys of the saved 'MyDNN' checkpoint loaded with load_dict.
        self.hidden1 = fluid.dygraph.Linear(self._in_features, 1000, act='relu')
        self.hidden2 = fluid.dygraph.Linear(1000, 500, act='relu')
        self.hidden3 = fluid.dygraph.Linear(500, 100, act='relu')
        self.out = fluid.dygraph.Linear(input_dim=100, output_dim=class_dim, act='softmax')

    def forward(self, input):
        '''Flatten ``input`` (N, C, H, W) and return (N, class_dim) class probabilities.'''
        x = fluid.layers.reshape(input, shape=[-1, self._in_features])
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)
        return x
3.訓練模型
# Train on GPU 0 when Paddle is built with CUDA; otherwise fall back to the
# CPU instead of failing on a machine without CUDA support.
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
with fluid.dygraph.guard(place=place):
    print(train_parameters['class_dim'])
    print(train_parameters['label_dict'])
    model = MyDNN()  # instantiate the model
    model.train()  # training mode
    # Plain SGD with the learning rate from the config (0.001).
    opt = fluid.optimizer.SGDOptimizer(learning_rate=train_parameters['learning_strategy']['lr'], parameter_list=model.parameters())
    epochs_num = train_parameters['num_epochs']  # number of epochs
    for pass_num in range(epochs_num):
        for batch_id, data in enumerate(train_reader()):
            images = np.array([x[0] for x in data]).astype('float32').reshape(-1, 3, 64, 64)
            labels = np.array([x[1] for x in data]).astype('int64')
            labels = labels[:, np.newaxis]  # (N,) -> (N, 1) as expected by cross_entropy/accuracy
            image = fluid.dygraph.to_variable(images)
            label = fluid.dygraph.to_variable(labels)
            predict = model(image)  # forward pass
            loss = fluid.layers.cross_entropy(predict, label)
            avg_loss = fluid.layers.mean(loss)  # mean loss over the batch
            acc = fluid.layers.accuracy(predict, label)  # batch accuracy
            # Every 5 batches, record loss/accuracy for the plots and log them.
            if batch_id != 0 and batch_id % 5 == 0:
                Batch = Batch + 5
                Batchs.append(Batch)
                all_train_loss.append(avg_loss.numpy()[0])
                all_train_accs.append(acc.numpy()[0])
                print("train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(pass_num, batch_id, avg_loss.numpy(), acc.numpy()))
            avg_loss.backward()
            opt.minimize(avg_loss)  # update the parameters
            model.clear_gradients()  # reset gradients for the next batch
    fluid.save_dygraph(model.state_dict(), 'MyDNN')  # save the model
draw_train_acc(Batchs, all_train_accs)
draw_train_loss(Batchs, all_train_loss)
train_pass:19,batch_id:400,train_loss:[0.24890603],train_acc:[0.96875]
4.模型評估
# Model evaluation
with fluid.dygraph.guard():
    model_dict, _ = fluid.load_dygraph('MyDNN')
    model = MyDNN()
    model.load_dict(model_dict)  # load the trained parameters
    model.eval()  # evaluation mode
    # Batch the eval list WITHOUT dropping the last partial batch, so that
    # every eval image is scored (eval_reader is built with drop_last=True).
    full_eval_reader = paddle.batch(data_reader(eval_list_path), batch_size=batch_size, drop_last=False)
    correct = 0.0
    total = 0
    for batch_id, data in enumerate(full_eval_reader()):
        images = np.array([x[0] for x in data]).astype('float32').reshape(-1, 3, 64, 64)
        labels = np.array([x[1] for x in data]).astype('int64')
        labels = labels[:, np.newaxis]
        image = fluid.dygraph.to_variable(images)
        label = fluid.dygraph.to_variable(labels)
        predict = model(image)
        acc = fluid.layers.accuracy(predict, label)
        # Weight each batch's accuracy by its size, so a smaller final batch
        # does not count as much as a full one.
        correct += acc.numpy()[0] * len(data)
        total += len(data)
    avg_acc = correct / total if total else float('nan')
    print(avg_acc)
0.96875
5.模型預測
import os
import zipfile


def unzip_infer_data(src_path, target_path):
    '''
    Extract the prediction archive at ``src_path`` into ``target_path``.

    Does nothing if ``target_path`` already exists as a directory.
    '''
    if not os.path.isdir(target_path):
        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile(src_path, 'r') as z:
            z.extractall(path=target_path)


def load_image(img_path):
    '''
    Load and preprocess one image for prediction.

    Applies the same steps as the training data reader: convert to RGB,
    resize to 64x64 (bilinear), HWC -> CHW, scale pixel values to [0, 1].

    Returns:
        float32 numpy array of shape (3, 64, 64).
    '''
    # Close the image file as soon as the pixels have been read.
    with Image.open(img_path) as src:
        rgb = src if src.mode == 'RGB' else src.convert('RGB')
        resized = rgb.resize((64, 64), Image.BILINEAR)
    img = np.array(resized).astype('float32')
    img = img.transpose((2, 0, 1))  # HWC to CHW
    img = img / 255  # normalize pixel values to [0, 1]
    return img


infer_src_path = '/home/aistudio/data/data55032/archive_test.zip'
infer_dst_path = '/home/aistudio/data/archive_test'
unzip_infer_data(infer_src_path, infer_dst_path)
label_dic = train_parameters['label_dict']
# Model prediction
with fluid.dygraph.guard():
    model_dict, _ = fluid.load_dygraph('MyDNN')
    model = MyDNN()
    model.load_dict(model_dict)  # load the trained parameters
    model.eval()  # evaluation mode
    # Images to classify. The true label is the file-name prefix before "_".
    infer_paths = ['data/archive_test/alexandrite_3.jpg']
    for i, infer_path in enumerate(infer_paths):
        # Show the image being classified.
        img = Image.open(infer_path)
        plt.imshow(img)  # draw the image from its pixel array
        plt.show()
        # Preprocess like the training data and add a batch axis: (1, 3, 64, 64).
        dy_x_data = np.array(load_image(infer_path)).astype('float32')
        dy_x_data = dy_x_data[np.newaxis, :, :, :]
        out = model(fluid.dygraph.to_variable(dy_x_data))
        lab = np.argmax(out.numpy())  # index of the most probable class
        true_label = infer_path.split('/')[-1].split("_")[0]
        print("第{}個(gè)樣本,被預(yù)測為:{},真實(shí)標(biāo)簽為:{}".format(i+1, label_dic[str(lab)], true_label))
print("結(jié)束")
第1個(gè)樣本,被預(yù)測為:Malachite,真實(shí)標(biāo)簽為:alexandrite 結(jié)束
以上就是Python利用DNN實現寶石識別的詳細內容,更多關於Python DNN寶石識別的資料請關注腳本之家其它相關文章!
相關文章
Python字符編碼轉碼之GBK,UTF8互轉
說到python的編碼,一句話總結,說多了都是淚啊,這個在以後的python的開發中絕對是一件令人頭疼的事情。所以有必要深入理解。2020-02-02
Python私有屬性私有方法應用實例解析
這篇文章主要介紹了Python私有屬性私有方法應用場景解析,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下。2020-09-09
Python爬蟲實現抓取京東店鋪信息及下載圖片功能示例
這篇文章主要介紹了Python爬蟲實現抓取京東店鋪信息及下載圖片功能,涉及Python頁面請求、響應、解析等相關操作技巧,需要的朋友可以參考下。2018-08-08
用python寫一個定時提醒程序的實現代碼
今天小編就為大家分享一篇用python寫一個定時提醒程序的實現代碼,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧。2019-07-07
python修改linux中文件(文件夾)的權限屬性操作
這篇文章主要介紹了python修改linux中文件(文件夾)的權限屬性操作,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧。2020-03-03
在tensorflow實現直接讀取網絡的參數(weight and bias)的值
這篇文章主要介紹了在tensorflow實現直接讀取網絡的參數(weight and bias)的值,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧。2020-06-06
Python實現中文數字轉換為阿拉伯數字的方法示例
這篇文章主要介紹了Python實現中文數字轉換為阿拉伯數字的方法,涉及Python字符串遍歷、轉換相關操作技巧,需要的朋友可以參考下。2017-05-05