Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions MRnet-MultiTask-Approach/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
*.npy
*.DS_Store

*.pth
*.csv

*.pyc

.vscode

runs/

*__pycache__*
21 changes: 21 additions & 0 deletions MRnet-MultiTask-Approach/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Big Vision LLC

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
98 changes: 98 additions & 0 deletions MRnet-MultiTask-Approach/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
<div align="center">
<img src="content/logo.jpg" width ="600" height="300"/>

# Stanford MRnet Challenge

**This repo contains code for the MRNet Challenge (A Multi-Task Approach)**


For more details refer to https://stanfordmlgroup.github.io/competitions/mrnet/

</div>

# Install dependencies
1. `pip install git+https://github.com/ncullen93/torchsample`
2. `pip install nibabel`
3. `pip install sklearn`
4. `pip install pandas`

Install other dependencies as per requirement

# Instructions to run the training
1. Clone the repository.

2. Download the dataset (~5.7 GB), and put `train` and `valid` folders along with all the the `.csv` files inside `images` folder at root directory.
```Shell
images/
train/
axial/
sagittal/
coronal/
val/
axial/
sagittal/
coronal/
train-abnormal.csv
train-acl.csv
train-meniscus.csv
valid-abnormal.csv
valid-acl.csv
valid-meniscus.csv
```

3. Make a new folder called `weights` at root directory, and inside the `weights` folder create three more folders namely `acl`, `abnormal` and `meniscus`.

4. All the hyperparameters are defined in `config.py` file. Feel free to play around those.

5. Now finally run the training using `python train.py`. All the logs for tensorboard will be stored in the `runs` directory at the root of the project.

# Understanding the Dataset

<div align="center">

<img src="content/mri_scan.png" width ="650" height="600"/>

</div>

The dataset contains MRIs of different people. Each MRI consists of multiple images.
Each MRI has data in 3 perpendicular planes. And each plane as variable number of slices.

Each slice is an `256x256` image

For example:

For `MRI 1` we will have 3 planes:

Plane 1- with 35 slices

Plane 2- with 34 slices

Place 3 with 35 slices

Each MRI has to be classisifed against 3 diseases

Major challenge with while selecting the model structure was the inconsistency in the data. Although the image size remains constant , the number of slices per plane are variable within a single MRI and varies across all MRIs.

# Model Specifications

<div align="center">

<img src="content/model.png" width ="700" height="490"/>

</div>

In the last attempt to MRNet challenge, we used 3 different models for each disease, but instead we can leverage the information that the model learns for each of the disease and make inferencing for other disease better.

We used Hard Parameter sharing in this approach.

We will be using 3 Alexnet pretrained as 3 feature extractors for each of the plane. We then combine these feature extractor layers as an input to a `global` fully connected layer for the final classification.

# Contributors
<p >
-- Neelabh Madan
<a href = https://github.com/neelabh17 target='blank'> <img src=https://github.com/edent/SuperTinyIcons/blob/master/images/svg/github.svg height='30' weight='30'/></a>
<br>

-- Jatin Prakash <a href = https://github.com/bicycleman15 target='blank'> <img src=https://github.com/edent/SuperTinyIcons/blob/master/images/svg/github.svg height='30' weight='30'/></a>


13 changes: 13 additions & 0 deletions MRnet-MultiTask-Approach/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
config = {
'max_epoch' : 50,
'log_train' : 50,
'lr' : 1e-5,
'starting_epoch' : 0,
'batch_size' : 1,
'log_val' : 10,
'task' : 'combined',
'weight_decay' : 0.01,
'patience' : 5,
'save_model' : 1,
'exp_name' : 'test'
}
Binary file added MRnet-MultiTask-Approach/content/logo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MRnet-MultiTask-Approach/content/model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MRnet-MultiTask-Approach/content/mri_scan.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions MRnet-MultiTask-Approach/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .dataset import MRData, load_data
153 changes: 153 additions & 0 deletions MRnet-MultiTask-Approach/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import os
import pandas as pd
import numpy as np

import torch
import torch.utils.data as data

from torchsample.transforms import RandomRotate, RandomTranslate, RandomFlip, ToTensor, Compose, RandomAffine
from torchvision import transforms

INPUT_DIM = 224
MAX_PIXEL_VAL = 255
MEAN = 58.09
STDDEV = 49.73

class MRData():
"""This class used to load MRnet dataset from `./images` dir
"""

def __init__(self,task = 'acl', train = True, transform = None, weights = None):
"""Initialize the dataset

Args:
plane : along which plane to load the data
task : for which task to load the labels
train : whether to load the train or val data
transform : which transforms to apply
weights (Tensor) : Give wieghted loss to postive class eg. `weights=torch.tensor([2.223])`
"""
self.planes=['axial', 'coronal', 'sagittal']
self.diseases = ['abnormal','acl','meniscus']
self.records = {'abnormal' : None, 'acl' : None, 'meniscus' : None}
# an empty dictionary
self.image_path={}

if train:
for disease in self.diseases:
self.records[disease] = pd.read_csv('./images/train-{}.csv'.format(disease),header=None, names=['id', 'label'])

'''
self.image_path[<plane>]= dictionary {<plane>: path to folder containing
image for that plane}
'''
for plane in self.planes:
self.image_path[plane] = './images/train/{}/'.format(plane)
else:
for disease in self.diseases:
self.records[disease] = pd.read_csv('./images/valid-{}.csv'.format(disease),header=None, names=['id', 'label'])

'''
self.image_path[<plane>]= dictionary {<plane>: path to folder containing
image for that plane}
'''
for plane in self.planes:
self.image_path[plane] = './images/valid/{}/'.format(plane)


self.transform = transform

for disease in self.diseases:
self.records[disease]['id'] = self.records[disease]['id'].map(
lambda i: '0' * (4 - len(str(i))) + str(i))

# empty dictionary
self.paths={}
for plane in self.planes:
self.paths[plane] = [self.image_path[plane] + filename +
'.npy' for filename in self.records['acl']['id'].tolist()]

self.labels = {'abnormal' : None, 'acl' : None, 'meniscus' : None}
for disease in self.diseases:
self.labels[disease] = self.records[disease]['label'].tolist()

weights_ = []
for disease in self.diseases:
pos = sum(self.labels[disease])
neg = len(self.labels[disease]) - pos
weights_.append(neg/pos)

# Find the wieghts of pos and neg classes
if weights:
self.weights = torch.FloatTensor(weights)
else:
self.weights = torch.FloatTensor(weights_)

print('Weights for loss is : ', self.weights)

def __len__(self):
"""Return the total number of images in the dataset."""
return len(self.records['acl'])

def __getitem__(self, index):
"""
Returns `(images,labels)` pair
where image is a list [imgsPlane1,imgsPlane2,imgsPlane3]
and labels is a list [gt,gt,gt]
"""
img_raw = {}

for plane in self.planes:
img_raw[plane] = np.load(self.paths[plane][index])
img_raw[plane] = self._resize_image(img_raw[plane])

label = []
for disease in self.diseases:
label.append(self.labels[disease][index])

label = torch.FloatTensor(label)

return [img_raw[plane] for plane in self.planes], label

def _resize_image(self, image):
"""Resize the image to `(3,224,224)` and apply
transforms if possible.
"""
# Resize the image
pad = int((image.shape[2] - INPUT_DIM)/2)
image = image[:,pad:-pad,pad:-pad]
image = (image-np.min(image))/(np.max(image)-np.min(image))*MAX_PIXEL_VAL
image = (image - MEAN) / STDDEV

if self.transform:
image = self.transform(image)
else:
image = np.stack((image,)*3, axis=1)

image = torch.FloatTensor(image)
return image

def load_data(task : str):

# Define the Augmentation here only
augments = Compose([
transforms.Lambda(lambda x: torch.Tensor(x)),
RandomRotate(25),
RandomTranslate([0.11, 0.11]),
RandomFlip(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
])

print('Loading Train Dataset of {} task...'.format(task))
train_data = MRData(task, train=True, transform=augments)
train_loader = data.DataLoader(
train_data, batch_size=1, num_workers=4, shuffle=True
)

print('Loading Validation Dataset of {} task...'.format(task))
val_data = MRData(task, train=False)
val_loader = data.DataLoader(
val_data, batch_size=1, num_workers=4, shuffle=False
)

return train_loader, val_loader, train_data.weights, val_data.weights
56 changes: 56 additions & 0 deletions MRnet-MultiTask-Approach/models/MRnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
from torchvision import models
import os

class MRnet(nn.Module):
"""MRnet uses pretrained resnet50 as a backbone to extract features, this is multilabel classifying model
"""

def __init__(self): # add conf file

super(MRnet,self).__init__()

# init three backbones for three axis
self.axial = models.alexnet(pretrained=True).features
self.coronal = models.alexnet(pretrained=True).features
self.saggital = models.alexnet(pretrained=True).features

self.pool_axial = nn.AdaptiveAvgPool2d(1)
self.pool_coronal = nn.AdaptiveAvgPool2d(1)
self.pool_saggital = nn.AdaptiveAvgPool2d(1)

self.fc = nn.Sequential(
nn.Linear(in_features=3*256,out_features=3)
)

def forward(self,x):
""" Input is given in the form of `[image1, image2, image3]` where
`image1 = [1, slices, 3, 224, 224]`. Note that `1` is due to the
dataloader assigning it a single batch.
"""

# squeeze the first dimension as there
# is only one patient in each batch
images = [torch.squeeze(img, dim=0) for img in x]

image1 = self.axial(images[0])
image2 = self.coronal(images[1])
image3 = self.saggital(images[2])

image1 = self.pool_axial(image1).view(image1.size(0), -1)
image2 = self.pool_coronal(image2).view(image2.size(0), -1)
image3 = self.pool_saggital(image3).view(image3.size(0), -1)

image1 = torch.max(image1,dim=0,keepdim=True)[0]
image2 = torch.max(image2,dim=0,keepdim=True)[0]
image3 = torch.max(image3,dim=0,keepdim=True)[0]

output = torch.cat([image1,image2,image3], dim=1)

output = self.fc(output)
return output

def _load_wieghts(self):
"""load pretrained weights"""
pass
1 change: 1 addition & 0 deletions MRnet-MultiTask-Approach/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .MRnet import MRnet
1 change: 1 addition & 0 deletions MRnet-MultiTask-Approach/src/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This contains the submission code for MRnet challenge.
Loading