본문 바로가기

카테고리 없음

CHAOS Challenge Dataset Preprocessing ( Pytorch Dataset Class) 카오스 데이터 셋 전처리 및 파이토치 데이터셋 클라스 설정

카오스 데이터셋은 Chaos 대회에서 제공하는 데이터이다. trainset은  20명의 환자로 구성되어 있으며, testset도 20명의 환자로 구성되어있다. 근데 다운로드를 받아보면 testset은 CT영상에 대한 마스크 영상이 없을 것이다. 이건 뭐 당연한 말이겠지만, 제공을 안해준다. trainset으로 학습을 시키고 test set을 predict 한 후에 사이트에 제출하면 채점을 해주는 방식이다.

 

무튼!!

 

오늘은 CHAOS 데이터셋을 Pytorch Dataset Class로 만드는 코드를 제공할 것이다.

 

DCM을 PNG로 만드는 방식도 여기 안에 다 들어있으니 참고하면 쉽게 이해할 수 있을 것이다.

 

전체적인 코드는 역시 가독성을 우선적으로 두고 하였고, 중복된 부분이 있을 수 있다.

 

 

1. chaos 데이터셋에서 받은 폴더가 'TrainSet'이 있을 것이다. 그것을 나는 이 코드 기준으로 "../../liver_dataset"이라는 폴더 안에 넣고 시작을 하였다.

2. 밑에 보면 mode = train/val/test가 있을 것이다. 이거는 데이터셋을 나눌라고 한 것이다. 임의적으로 7:2:1로 설정하였다.

3. 그리고 permute 라는 행과 열을 바꾸는 함수를 활용하였는 데, 이거는 각자 모델에 맞게끔 수정하면 될 것이다. 파이토치는 (N,C,H,W) 순을 따르지만, 우리가 읽어들일 때 256,256로 되기 때문에 np.newaxis로 채널을 추가한 후, 행과 열을 바꾸어 1,256,256로 넣게 하였다.

 

4. 여기 안에 들어간 전처리는 , resize 가 사용되었다. 알아서 바꾸시면 되고, arr[arr<-100]=-100  <= 이거는 HU 값을 임의로 설정한 것이고 데이터셋에서 아무런 설정도 건드리고 싶지 않으면 그 윗줄에 있는 주석을 풀고 arr[arr>max] = max / arr[arr<min] = min 을 사용하면 된다.

 

 

import os
import pydicom
import numpy as np
import cv2
from glob import glob
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torch

class CHAOS(Dataset):
    def __init__(self, root_dir, image_size = 512,mode='train'):
        self.root_dir = root_dir
        self.image_size = image_size
        self.mode = mode
        self.images , self.masks = self.get_paths(self.root_dir)
        
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.get_image(self.images[idx],self.image_size)
        mask = self.get_mask(self.masks[idx],self.image_size)
        
        image = torch.FloatTensor(image)
        mask = torch.FloatTensor(mask)
        image =image.permute(2,0,1)
        mask = mask.permute(2,0,1)
        return [image, mask]
    
    def get_paths(self,root_dir):
        # get paths of files
        dir_path = os.path.join(root_dir,'Train_Sets/CT/')
        dir_list = os.listdir(dir_path)

        # add paths into list
        image_paths =[]
        label_paths =[]
        for dirs in dir_list :
            if dirs =='notes.txt':
                continue
            # get image paths
            image_folder_path = dir_path+dirs+'/DICOM_anon/'
            images = glob(image_folder_path + '*.dcm')

            if 'IMG' in images[0] :
                images = sorted(images,key=lambda x : int(os.path.basename(x).split('.')[0].split('-')[2][2:]))
            else :
                images = sorted(images,key=lambda x : int(os.path.basename(x).split('.')[0].split(',')[0][2:]))

            # get label paths
            label_folder_path = dir_path+dirs+'/Ground/'
            labels = glob(label_folder_path+'*.png')
            labels = sorted(labels, key=lambda x: int(os.path.basename(x).split('_')[-1].split('.png')[0]))

            for i in images:
                image_paths.append(i)
            for l in labels:
                label_paths.append(l)
        
        if self.mode == 'train' :
            validation_ratio = int(len(image_paths) / 10 * 7)
            image_paths = image_paths[:validation_ratio]
            label_paths = label_paths[:validation_ratio]
        elif self.mode == 'val' : # validation
            validation_ratio = int(len(image_paths) / 10 * 7)
            test_ratio = int(len(image_paths)/10 * 9)
            image_paths = image_paths[validation_ratio:test_ratio]
            label_paths = label_paths[validation_ratio:test_ratio]
        elif self.mode == 'test' : # test
            test_ratio = int(len(image_paths)/10 * 9)
            image_paths = image_paths[test_ratio:]
            label_paths = label_paths[test_ratio:]

        print(self.mode + ' image length = ' , len(image_paths))
        print(self.mode + ' label length = ' , len(label_paths))
        return image_paths,label_paths
    
    def get_image(self,path,image_size):
        # get dcm image
        dcm = pydicom.read_file(path)
        arr = dcm.pixel_array

        arr = arr*dcm.RescaleSlope + dcm.RescaleIntercept
        min = int ( dcm.WindowCenter[0]) - int(dcm.WindowWidth[0]/2)
        max = int(dcm.WindowCenter[0]) + int(dcm.WindowWidth[0]/2)
        
        arr = cv2.resize(arr, dsize=(image_size, image_size), interpolation=cv2.INTER_AREA) # reshape image size
        #arr [ arr        #arr [ arr>max ] = max
        arr [arr < -200] = -200
        arr [arr > 250] = 250
        #arr = arr[np.newaxis,...] #  add axis
        arr = arr[...,np.newaxis]
        return arr
    
    def get_mask(self,path,image_size):
        # label
        label_image = cv2.imread(path)
        label_image = cv2.cvtColor(label_image,cv2.COLOR_BGR2GRAY)
        label_image = cv2.resize(label_image,dsize=(image_size,image_size),interpolation=cv2.INTER_AREA) # reshape
        label_image[ label_image > 0 ] = 1
        #label_image = label_image[np.newaxis,...]
        label_image = label_image[...,np.newaxis]
        return label_image

train_set = CHAOS(root_dir='../../liver_dataset',image_size=256,mode='train')
val_set = CHAOS(root_dir='../../liver_dataset',image_size=256,mode='val')
test_set = CHAOS(root_dir='../../liver_dataset',image_size = 256, mode='test')