개발/AI(python)

입문 - Dataset클래스 정의하기(__getitem__, __len__())

ebang 2022. 11. 30. 23:00
반응형

데이터셋 클래스 정의

파이토치로 신경망 모델을 구축하려면 데이터셋도 일정한 형식에 맞게 정의해줘야 한다.

먼저 데이터셋 생성에 필요한 라이브러리를 임포트한다.

import cv2
from torch.utils.data import Dataset

파이토치에서 제공하는 Dataset클래스를 활용해서 데이터셋 객체를 만들 수 있다.

Dataset은 추상 클래스이며, 우리는 Dataset을 상속받은 다음 특수 메서드인

__len__() , __getitem__()

재정의(오버라이딩)해야 한다.

 

 

💡 추상 클래스: 곧바로 객체를 생성할 수 없고 상속만 할 수 있는 클래스.       
    추상 클래스를 사용하는 이유는 상속받는 클래스들의 메서드를 규격화하기 위해서이다.
    상속을 강제해 메서드 시그니처를 일치시키기 위해서이다.                                                                                                                                                  

 

 

 

import cv2 # OpenCV 라이브러리
from torch.utils.data import Dataset # 데이터 생성을 위한 클래스

class ImageDataset(Dataset):
    # 초기화 메서드(생성자)
    def __init__(self, df, img_dir='./', transform=None):
        super().__init__() # 상속받은 Dataset의 생성자 호출
        # 전달받은 인수들 저장
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    # 데이터셋 크기 반환 메서드 
    def __len__(self):
        return len(self.df)

    # 인덱스(idx)에 해당하는 데이터 반환 메서드 
    def __getitem__(self, idx):
        img_id = self.df.iloc[idx, 0]    # 이미지 ID
        img_path = self.img_dir + img_id # 이미지 파일 경로 
        image = cv2.imread(img_path)     # 이미지 파일 읽기 
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 이미지 색상 보정
        label = self.df.iloc[idx, 1]     # 이미지 레이블(타깃값)

        if self.transform is not None:
            image = self.transform(image) # 변환기가 있다면 이미지 변환
        return image, label
  1. init
__init__() 

ImageDataset 클래스의 초기화 메서드이다. 

상속 받은 Dataset의 초기화 메서드를 호출한 후, 파라미터로 받은 인수들을 저장한다.

파라미터들의 역할은 다음과 같다.

 

(이 파라미터들은 상속 받은 Dataset내에 존재하는 멤버변수이다. )

 

    -df : DataFrame 객체. 앞서 labels를 train과 valid로 나누었는데,

train 혹은 valid를 df 파라미터로 전달한다.

   -img_dir : 이미지 데이터를 포함하는 경로

   -transform : 이미지 변환기. 이미지 데이터셋을 만들 때 기본적인 전처리를 할 수 있는데,
                            전처리를 하려면 이미지 변환기를 넘겨
주면 된다.

  1. len
__len__()

데이터셋의 크기를 반환하는 메서드. Dataset 클래스에 이미 정의되어 있는 메서드를 정의하는 것이다.

  1. getitem
__getitem__()

지정한 인덱스에 해당하는 데이터를 반환하는 메서드.

idx번째 이미지와 레이블(타깃값)을 반환한다.

초기화 메서드에서 받은 경로에 이미지 ID를 합쳐 이미지 위치를 알아내고

초기화 메서드에서 이미지 변환기를 받아두었다면 변환 작업까지 수행한 후 반환한다 .

 

 

 

💡 위 메서드 처럼 이름 앞 뒤로 이중 언더바가 붙는 메서드는 호출 방식이 일반적인 메서드와 다르다.
     len은 len(imageDataset)형태로 호출하고,getitem은 ImageDataset[idx]형태로 호출한다.                                                   파이토치로 딥러닝 모델링을 하려면 이 특별 메서드 형식에 맞게 데이터셋 생성 클래스를 정의해야 한다.    

 

💡 사용자 정의 데이터셋 클래스를 만들 때 Dataseet추상 클래스를 ‘꼭 상속받아야 할까?                                                
      필수는 아니다. len 과 getitem 메서드만 동일하게 정의했다면 상속받지 않아도 문제없이 동작한다.
      그렇지만 코드의 의도를 명확하게 하려면 Dataset를 상속받는게 바람직하다.
      공식문서에서도 상속받으라고 권고한다.                      

 

 

 

 

 

참고

*python 공식문서 - Dataset, Dataloader 활용하기(상속받은 클래스로 customize 하기)

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

 

 

* 이 게시물은 '머신러닝 딥러닝 문제해결 전략'을 읽고 쓴 리뷰입니다. 

반응형