본문 바로가기
AI 기초이론

Pytorch Dataloader에 대해 알아보자 - 2

by 피라냐콜라다 2023. 3. 18.

저번 포스팅에서 dataloader의 내장 변수들과 그 기본적인 활용 방법에 알아보았다. 이번에는 dataloader에 대한 설명에 들어가기 앞서 torchvision 에서 제공하는 몇가지 transform 함수를 먼저 알아보려 한다.

먼저 이번에 사용할 이미지를 가져와보자

url = 'https://images.unsplash.com/photo-1583160247711-2191776b4b91?ixid=MnwxMjA3fDB8MHxzZWFyY2h8MTN8fGdvbGRlbnJldHJpZXZlcnxlbnwwfHwwfHw%3D&ixlib=rb-1.2.1&auto=format&fit=crop&w=500&q=60'
im = Image.open(requests.get(url, stream=True).raw)  # torchvision은 항상 PIL 객체로 받아야합니다!
im

위 코드로 로딩된 이미지

1. Resize

해당 함수는 이미지의 크기를 재조정한다.

im.size

>>> (500, 333)

위 사진의 크기는 500 x 333이다. 이 사진을 200 x 200의 크기로 바꿔보자

transforms.Resize((200,200))(im)

resize된 사진

2. RandomCrop()

사진을 임의의 위치에서 주어진 크기만큼 자른다.

transforms.RandomCrop((100,100))(im)

crop된 사진

3. RandomRotation()

사진을 주어진 각도 내에서 임의로 회전시킨다.

transforms.RandomRotation(30)(im)

rotate된 사진

이런 transform을 한번에 적용해보자

def get_transforms_img(im):
    transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomVerticalFlip(p=1),
    transforms.CenterCrop((150,150))])
    im=transform(im)
    return im
def get_transforms_img(im):
    im=transforms.Resize((224,224))(im)
    im=transforms.RandomVerticalFlip(p=1)(im)
    im=transforms.CenterCrop((150,150))(im)
    return im

위의 코드처럼 transforms.Compose로 하나의 list로 묶은 다음 transform을 적용해도 되고, 아래 코드처럼 하나씩 적용해도 된다.

get_transforms_img(im)

4. ToTensor()

이미지의 각 픽셀의 값을 수치로 전환해 Tensor값으로 반환한다.

transforms.ToTensor()(im)
tensor([[[0.3176, 0.3176, 0.3176,  ..., 0.2667, 0.2667, 0.2667],
         [0.3216, 0.3216, 0.3216,  ..., 0.2667, 0.2667, 0.2667],
         [0.3255, 0.3255, 0.3255,  ..., 0.2667, 0.2667, 0.2667],
         ...,
         [0.1647, 0.1725, 0.1882,  ..., 0.1922, 0.1922, 0.1882],
         [0.1569, 0.1608, 0.1608,  ..., 0.1922, 0.1843, 0.1765],
         [0.1490, 0.1529, 0.1529,  ..., 0.1922, 0.1804, 0.1725]],

        [[0.4118, 0.4118, 0.4118,  ..., 0.2706, 0.2706, 0.2706],
         [0.4157, 0.4157, 0.4157,  ..., 0.2706, 0.2706, 0.2706],
         [0.4235, 0.4235, 0.4235,  ..., 0.2706, 0.2706, 0.2706],
         ...,
         [0.2314, 0.2392, 0.2431,  ..., 0.1843, 0.1843, 0.1804],
         [0.2118, 0.2157, 0.2157,  ..., 0.1843, 0.1765, 0.1686],
         [0.2000, 0.2039, 0.2000,  ..., 0.1843, 0.1725, 0.1647]],

        [[0.4196, 0.4196, 0.4196,  ..., 0.2471, 0.2471, 0.2471],
         [0.4235, 0.4235, 0.4235,  ..., 0.2471, 0.2471, 0.2471],
         [0.4392, 0.4392, 0.4392,  ..., 0.2471, 0.2471, 0.2471],
         ...,
         [0.1686, 0.1765, 0.1765,  ..., 0.1882, 0.1882, 0.1843],
         [0.1608, 0.1647, 0.1647,  ..., 0.1882, 0.1804, 0.1725],
         [0.1647, 0.1647, 0.1608,  ..., 0.1882, 0.1765, 0.1686]]])

이번에는 그 유명한 MNIST datase를 불러와보자. 혹시 처음 들어본다면 다음의 글을 읽어보길 바란다.

https://ko.wikipedia.org/wiki/MNIST_%EB%8D%B0%EC%9D%B4%ED%84%B0%EB%B2%A0%EC%9D%B4%EC%8A%A4

 

MNIST 데이터베이스 - 위키백과, 우리 모두의 백과사전

위키백과, 우리 모두의 백과사전. MNIST 데이터베이스 (Modified National Institute of Standards and Technology database)는 손으로 쓴 숫자들로 이루어진 대형 데이터베이스이며, 다양한 화상 처리 시스템을 트레

ko.wikipedia.org

MNIST dataset

dataset_train_MNIST = torchvision.datasets.MNIST('data/MNIST/', # 다운로드 경로 지정
                                                 train=True, # True를 지정하면 훈련 데이터로 다운로드
                                                 transform=transforms.ToTensor(), # 텐서로 변환
                                                 download=True, 
                                                )
dataset_train_MNIST

>>>
Dataset MNIST
    Number of datapoints: 60000
    Root location: data/MNIST/
    Split: Train
    StandardTransform
Transform: ToTensor()

len(dataset_train_MNIST)

>>> 60000

dataset_train_MNIST.classes

>>>
['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']
 
image, label = next(iter(dataset_train_MNIST))
image, label

>>>
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          ...
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000]]]), 5)

위처럼 torchvision.dataset으로 MNIST dataset을 불러올수도 있다. 

이번에는 DataLoader를 이용해서 MNIST dataset을 불러보자

dataloader_train_MNIST = DataLoader(dataset=dataset_train_MNIST,
                                    batch_size=16,
                                    shuffle=True,
                                    num_workers=4,
                                   )
                                 
images, labels = next(iter(dataloader_train_MNIST))

plt.figure(figsize=(12,12))
for n, (image, label) in enumerate(zip(images, labels), start=1):
    plt.subplot(2,2,n)
    plt.imshow(image.numpy().squeeze(), cmap='gray')
    plt.title("{}".format(dataset_train_MNIST.classes[label]))
    plt.axis('off')
plt.tight_layout()
plt.show()

이 외에도 CIFAR-10 이나 AG_NEWS와 같은 여러 dataset이 있으니 관심이 있다면 직접 해보도록 하자

댓글