카오스 데이터셋은 Chaos 대회에서 제공하는 데이터이다. trainset은 20명의 환자로 구성되어 있으며, testset도 20명의 환자로 구성되어있다. 근데 다운로드를 받아보면 testset은 CT영상에 대한 마스크 영상이 없을 것이다. 이건 뭐 당연한 말이겠지만, 제공을 안해준다. trainset으로 학습을 시키고 test set을 predict 한 후에 사이트에 제출하면 채점을 해주는 방식이다.
무튼!!
오늘은 CHAOS 데이터셋을 Pytorch Dataset Class로 만드는 코드를 제공할 것이다.
DCM을 PNG로 만드는 방식도 여기 안에 다 들어있으니 참고하면 쉽게 이해할 수 있을 것이다.
전체적인 코드는 역시 가독성을 우선적으로 두고 하였고, 중복된 부분이 있을 수 있다.
1. chaos 데이터셋에서 받은 폴더가 'TrainSet'이 있을 것이다. 그것을 나는 이 코드 기준으로 "../../liver_dataset"이라는 폴더 안에 넣고 시작을 하였다.
2. 밑에 보면 mode = train/val/test가 있을 것이다. 이거는 데이터셋을 나눌라고 한 것이다. 임의적으로 7:2:1로 설정하였다.
3. 그리고 permute 라는 행과 열을 바꾸는 함수를 활용하였는 데, 이거는 각자 모델에 맞게끔 수정하면 될 것이다. 파이토치는 (N,C,H,W) 순을 따르지만, 우리가 읽어들일 때 256,256로 되기 때문에 np.newaxis로 채널을 추가한 후, 행과 열을 바꾸어 1,256,256로 넣게 하였다.
4. 여기 안에 들어간 전처리는 , resize 가 사용되었다. 알아서 바꾸시면 되고, arr[arr<-100]=-100 <= 이거는 HU 값을 임의로 설정한 것이고 데이터셋에서 아무런 설정도 건드리고 싶지 않으면 그 윗줄에 있는 주석을 풀고 arr[arr>max] = max / arr[arr<min] = min 을 사용하면 된다.
import os
import pydicom
import numpy as np
import cv2
from glob import glob
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torch
class CHAOS(Dataset):
def __init__(self, root_dir, image_size = 512,mode='train'):
self.root_dir = root_dir
self.image_size = image_size
self.mode = mode
self.images , self.masks = self.get_paths(self.root_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = self.get_image(self.images[idx],self.image_size)
mask = self.get_mask(self.masks[idx],self.image_size)
image = torch.FloatTensor(image)
mask = torch.FloatTensor(mask)
image =image.permute(2,0,1)
mask = mask.permute(2,0,1)
return [image, mask]
def get_paths(self,root_dir):
# get paths of files
dir_path = os.path.join(root_dir,'Train_Sets/CT/')
dir_list = os.listdir(dir_path)
# add paths into list
image_paths =[]
label_paths =[]
for dirs in dir_list :
if dirs =='notes.txt':
continue
# get image paths
image_folder_path = dir_path+dirs+'/DICOM_anon/'
images = glob(image_folder_path + '*.dcm')
if 'IMG' in images[0] :
images = sorted(images,key=lambda x : int(os.path.basename(x).split('.')[0].split('-')[2][2:]))
else :
images = sorted(images,key=lambda x : int(os.path.basename(x).split('.')[0].split(',')[0][2:]))
# get label paths
label_folder_path = dir_path+dirs+'/Ground/'
labels = glob(label_folder_path+'*.png')
labels = sorted(labels, key=lambda x: int(os.path.basename(x).split('_')[-1].split('.png')[0]))
for i in images:
image_paths.append(i)
for l in labels:
label_paths.append(l)
if self.mode == 'train' :
validation_ratio = int(len(image_paths) / 10 * 7)
image_paths = image_paths[:validation_ratio]
label_paths = label_paths[:validation_ratio]
elif self.mode == 'val' : # validation
validation_ratio = int(len(image_paths) / 10 * 7)
test_ratio = int(len(image_paths)/10 * 9)
image_paths = image_paths[validation_ratio:test_ratio]
label_paths = label_paths[validation_ratio:test_ratio]
elif self.mode == 'test' : # test
test_ratio = int(len(image_paths)/10 * 9)
image_paths = image_paths[test_ratio:]
label_paths = label_paths[test_ratio:]
print(self.mode + ' image length = ' , len(image_paths))
print(self.mode + ' label length = ' , len(label_paths))
return image_paths,label_paths
def get_image(self,path,image_size):
# get dcm image
dcm = pydicom.read_file(path)
arr = dcm.pixel_array
arr = arr*dcm.RescaleSlope + dcm.RescaleIntercept
min = int ( dcm.WindowCenter[0]) - int(dcm.WindowWidth[0]/2)
max = int(dcm.WindowCenter[0]) + int(dcm.WindowWidth[0]/2)
arr = cv2.resize(arr, dsize=(image_size, image_size), interpolation=cv2.INTER_AREA) # reshape image size
#arr [ arr #arr [ arr>max ] = max
arr [arr < -200] = -200
arr [arr > 250] = 250
#arr = arr[np.newaxis,...] # add axis
arr = arr[...,np.newaxis]
return arr
def get_mask(self,path,image_size):
# label
label_image = cv2.imread(path)
label_image = cv2.cvtColor(label_image,cv2.COLOR_BGR2GRAY)
label_image = cv2.resize(label_image,dsize=(image_size,image_size),interpolation=cv2.INTER_AREA) # reshape
label_image[ label_image > 0 ] = 1
#label_image = label_image[np.newaxis,...]
label_image = label_image[...,np.newaxis]
return label_image
train_set = CHAOS(root_dir='../../liver_dataset',image_size=256,mode='train')
val_set = CHAOS(root_dir='../../liver_dataset',image_size=256,mode='val')
test_set = CHAOS(root_dir='../../liver_dataset',image_size = 256, mode='test')