Neural Network/CNN

CIFAR100 - CNN으로 학습하기 (1) 데이터 로드 및 증강

힐안 2024. 1. 28. 01:55

 

1. 데이터 준비 - CIFAR100

https://www.cs.toronto.edu/~kriz/cifar.html

 

CIFAR-10 and CIFAR-100 datasets

< Back to Alex Krizhevsky's home page The CIFAR-10 and CIFAR-100 are labeled subsets of the 80 million tiny images dataset. They were collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. The CIFAR-10 dataset The CIFAR-10 dataset consists of 60000

www.cs.toronto.edu

 

데이터셋을 직접 다운 받거나,

 

torchvision 라이브러리에서 CIFAR100 데이터셋 다운 받을 수 있습니다.

 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as tt

import matplotlib.pyplot as plt
import numpy as np

train_set1 = torchivsion.datasets.CIFAR100("./", train=True, download=True, transform = tt.Compose([
	tt.ToTensor(),]))

train_loader1 = torch.utils.data.DataLoader(train_set1, batch_size = 16)

 

* 이미지 크기 : 3*32*32 , RGB 3채널의 컬러 이미지

* 데이터 분할 : 총 60,000개의 이미지

  - 훈련 데이터 : 50,000개

  - 테스트 데이터 : 10,000개

* 클래스 : 100개, 100개의 클래스에 각각 600장의 이미지가 포함됨 => 총 60,000개

 

 

2. Augmentation - 데이터 증강

 

: 주어진 이미지 데이터셋에서 새로운 변형된 버전의 이미지를 생성하여 다양성을 높이는 것

 

 

(1) 원본 이미지 출력해보기

 

train_set1 = torchvision.datasets.CIFAR100("./", train=True, download=False, transform = tt.Compose([
    tt.ToTensor(),
]))

train_loader1 = torch.utils.data.DataLoader(train_set1, batch_size=16)

for i in train_loader1 :
    s = 1
    plt.figure(figsize=(16,10))
    for img in i[0][:4] :
        plt.subplot(1, int(len(i[0])/4),s)
        plt.imshow(np.transpose(img, (1,2,0)))
        s+=1
        
    break
    
plt.show()

 

(2) Augmentation 적용

 

- torchvision.transforms

 

train_set = torchvision.datasets.CIFAR100("./", train=True, download=False, transform = tt.Compose([
    tt.ToTensor(),
    tt.RandomCrop(32, padding=4, padding_mode='reflect'),
    tt.RandomHorizontalFlip()
]))

train_loader = torch.utils.data.DataLoader(train_set, batch_size=16)

for i in train_loader :
    s = 1
    plt.figure(figsize=(16,10))
    for img in i[0][:4] :
        plt.subplot(1, int(len(i[0])/4),s)
        plt.imshow(np.transpose(img, (1,2,0)))
        s+=1
        
    break
    
plt.show()

 

 

- albumentation

 

: albumentation은 이미지 데이터 증강을 위한 라이브러리입니다.

 

import albumentation as A
import cv2

images, _ = next(iter(train_loader1))
image = np.array(images[0])

transform = A.Compose([
    A.RandomBrightnessContrast(brightness_limit=1, contrast_limit=1, p=1.0),
    A.HorizontalFlip(p=0.5),
])

transformed = transform(image=image)
transformed_image = transformed['image']

plt.imshow(np.transpose(transformed_image, (1, 2, 0)))
plt.show()

 

https://albumentations.ai/docs/api_reference/augmentations/transforms/

위 링크를 참고하여 albumentation 기법들을 학습해봅시다!

 

Albumentations Documentation - Transforms (augmentations.transforms)

Albumentations: fast and flexible image augmentations

albumentations.ai

 

이미지 데이터셋 준비를 마쳤으니,

이 다음에는 CNN의 구조를 알아보겠습니다 :)