더 이상 tistory 블로그를 운영하지 않습니다. glanceyes.github.io에서 새롭게 시작합니다.

새소식

AI/AI 기본

PyTorch 프로젝트 구조와 클래스 속성 활용하기

  • -

 

2022년 1월 24일(월)부터 28일(금)까지 네이버 부스트캠프(boostcamp) AI Tech 강의를 들으면서 개인적으로 중요하다고 생각되거나 짚고 넘어가야 할 핵심 내용들만 간단하게 메모한 내용입니다. 틀리거나 설명이 부족한 내용이 있을 수 있으며, 이는 학습을 진행하면서 꾸준히 내용을 수정하거나 추가해 나갈 예정입니다.

 

 

PyTorch 프로젝트 구조

 

PyTorch Project Template

초기 단계에서는 학습과정과 디버깅 등을 지속적으로 확인할 수 있는 대화식 개발 과정이 유리하지만, 배포 및 공유 단계에서는 실행 순서가 꼬일 수 있는 등 여러 이유로 인해 notebook 파일로 공유하는 건 어려움이 있다.

DL(Deep Learning) 코드도 하나의 프로그램이므로 개발 용이성을 확보하고 유지보수를 향상시킬 필요가 있다.

 

OOP + 모듈 → 프로젝트

 

실행, 데이터, 모델, 설정, logging, 지표, 유틸리티 등 다양한 모듈을 분리하여 프로젝트를 템플릿화할 필요가 있다.

이는 사실 모든 프로그램을 개발하는 데 있어서 필요한 부분이 아닌가 하는 생각이 든다.

마치 Node.js에서처럼 코드를 refactoring하여 유지 보수가 용이하게끔 구조화시키는 것처럼 말이다.

 

아래는 강의에서 추천한 템플릿의 repo이다.

 

부스트캠프 Level 1에서의 마스크 분류 모델을 구현하는 프로젝트에서도 아래처럼 코드를 구조화하고 정리하는 데 집중했다.

유사한 기능을 가지거나 동일한 역할을 하는 코드를 한 폴더로 묶어서 정리하고, 파일 간의 의존성을 줄이고자 노력했다.

https://github.com/boostcampaitech3/level1-image-classification-level1-recsys-08

 

GitHub - boostcampaitech3/level1-image-classification-level1-recsys-08: level1-image-classification-level1-recsys-08 created by

level1-image-classification-level1-recsys-08 created by GitHub Classroom - GitHub - boostcampaitech3/level1-image-classification-level1-recsys-08: level1-image-classification-level1-recsys-08 creat...

github.com

 

code/
├─ dataset/             
│   ├─ augmentation/
│   │   ├─ BaseAugmentation.py
│   │   ├─ CustomAugmentation.py
│   │   └─ TrainAugmentation.py
│   ├─ BaseDataset.py
│   ├─ SplitByProfileDataset.py
│   └─ TestDataset.py
├─ inference/        
│   └─ Inferrer.py
├─ loss/    
│   └─ loss.py
├─ models/    
│   ├─ BaseModel.py
│   ├─ EfficientNetB3.py
│   ├─ EfficientNetB4.py
│   ├─ EfficientNetB4T.py
│   ├─ ResNet152.py
│   ├─ ResNet18.py
│   ├─ ResNet50.py
│   └─ VisionTransformer.py
├─ schedulers/     
│   ├─ CosineAnnealing.py        
│   ├─ CosineAnnealingWarmRestarts.py        
│   └─ StepLR.py        
├─ train/    
│   ├─ Trainer.py  
│   └─ Validator.py   
├─ utils/   
│   ├─ setConfig.py  
│   └─ util.py   
├─ config.ini     
└─ run.py

 

 

실행과 관련된 파일

  • train.py
  • test.py

 

 

설정과 관련된 파일

  • config.json
  • parse_config.py

 

Config.json은 data_loader, optimizer, epoch, loss function 등 다양한 학습 설정을 포함한다.

train.py에 argument로 config.json 파일에서 설정한 파라미터를 던저주면서 실행한다.

 

 

 

프로젝트에서 클래스 속성 활용하기

__getitem__() 함수를 정의하면 마치 dictionary처럼 class를 활용할 수 있다.

이는 PyTorch에서 Dataset을 정의할 때 유용한데, 몇 번째 아이템을 가져올 때 dictionary 또는 배열 접근 방식처럼 가져올 수 있어서다.

# __getitem__() 함수 예시
class Test(object):
    def __getitem__(self, items):
		print(type(items),items)

test = Test()
test[5] # <class 'int'> 5
test[5:10:2] # <class 'slice'> slice(5, 10, 2)
test['example'] # <class 'str'> example

 

__getitem__() 함수를 정의하지 않아도 getattr()를 통해 원하는 정보를 dictionary와 유사하게 가져올 수 있다. 

attribute가 있으면 해당 값을 반환하고 없으면 default 값을 반환한다.

class Point:
    def __init__(self):
        self.x = 10 # 인스턴스의 attribute인 x를 선언하고 10을 할당한다.
        self.y = 20 # 인스턴스의 attribute인 y를 선언하고 20을 할당한다.

p = Point()

print( getattr(p, 'x') )          # 10
print( getattr(p, 'z', 'Hello') ) # p의 attribute으로 z가 없는 대신에 default 값을 'Hello'로 설정했다.
print( getattr(p, 'a') )          # Error

 

 

Contents

글 주소를 복사했습니다

부족한 글 끝까지 읽어주셔서 감사합니다.
보충할 내용이 있으면 언제든지 댓글 남겨주세요.