딥러닝/Pytorch

[Deep Learning / Pytorch] Dataset & DataLoader

bluetag_boy 2024. 3. 12. 18:03
반응형

 파이토치(PyTorch)에서 Dataset과 DataLoader를 사용하는 주된 이유는 데이터를 효율적으로 처리하고, 모델 학습을 위한 데이터 로딩을 보다 유연하고 관리하기 쉽게 만들기 위해서 사용된다. 이러한 추상화가 불편해 보일 수 있지만, 실제로는 대규모 데이터셋을 다루거나 복잡한 데이터 전처리 과정을 통합할 때 매우 유용하다.


Dataset

  • 데이터 추상화: Dataset 클래스는 데이터셋에 대한 추상화를 제공하며, 사용자는 이를 상속받아 __getitem__과 __len__ 메서드를 구현함으로써 데이터의 로딩 방식을 커스터마이징할 수 있다. 이는 다양한 데이터 소스(이미지, 텍스트 파일, CSV 등)를 동일한 인터페이스로 처리할 수 있게 해준다.

 

  • 유연성: 사용자 정의 데이터셋을 만들 수 있어, 데이터 전처리나 샘플링 로직을 데이터셋 클래스 내에 캡슐화할 수 있다. 이로 인해 코드가 더 정돈되고 재사용성이 높아진다.

 

Dataset 클래스를 사용하기 위해서는 아래와 같이 3개의 메서드가 필수적으로 구현되어야 한다.

 

 

1) __init__

  • 데이터셋의 파일 경로, 변환(transforms) 등 데이터셋에 필요한 초기 설정을 수행
  • 데이터셋 객체가 생성될 때 초기화를 담당

 

2) __len__

  • 데이터셋의 전체 크기를 반환
  • len(dataset) 호출 시 자동으로 사용됨

 

3) __getitem__

  • 특정 인덱스에 해당하는 샘플을 데이터셋에서 불러오고 반환
  • 데이터셋의 특정 원소에 접근할 때 사용되며, 인덱스(index)를 받아 해당 인덱스의 샘플을 가공(예: 변환 적용) 후 반환
  • DataLoader를 통해 미니 배치를 구성할 때 내부적으로 호출됨

 

Example

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, transforms=None):
        self.data = data
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        if self.transforms:
            sample = self.transforms(sample)
        return sample
        
# 데이터셋 생성
train_dataset = CustomDataset(train_csv_file)
val_dataset = CustomDataset(val_csv_file)
test_dataset = CustomDataset(pred_csv_file)

 

 

DataLoader

  • 배치 처리: DataLoader는 데이터셋에서 미니 배치를 자동으로 생성해다. 이는 모델 학습 시 배치 단위로 데이터를 공급하는 일반적인 요구사항을 쉽게 충족시키는 방법을 제공한다.

 

  • 멀티스레딩/멀티프로세싱 데이터 로딩: DataLoader는 멀티스레드 또는 멀티프로세싱을 사용하여 데이터 로딩을 병렬화할 수 있어, 학습 과정 중 CPU/GPU가 데이터 로딩으로 인해 유휴 상태에 빠지는 것을 최소화한다. 이는 특히 큰 데이터셋을 사용할 때 학습 속도를 크게 향상시킬 수 있다.

 

  • 자동 셔플링과 샘플링 전략: 학습 데이터를 셔플링하여 모델이 특정 순서에 의존하지 않도록 하거나, 사용자 정의 샘플링 전략을 적용할 수 있다. 이는 모델의 일반화 능력을 향상시키는 데 도움이 된다.

 

Examples

# DataLoader 생성
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)