Brain age regression with deep learning

Predict age from brain grey matter (regression) using Deep Learning (DL). Aging is associated with grey matter (GM) atrophy: each year, an adult loses 0.1% of GM. We will try to learn a predictor of the chronological age (true age) from GM measurements in a population of healthy control participants.

Such a predictor provides the expected brain age of a subject. A deviation from this expected brain age indicates an acceleration or a slowdown of the aging process, which may be associated with a pathological neurobiological process or with a protective factor of aging.
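As a minimal sketch of this idea, the deviation can be summarized as the difference between the predicted age and the chronological age. The values below are made up for illustration, and the name brain_age_gap is our own choice, not something defined by the starting kit.

```python
import numpy as np

# Hypothetical values, for illustration only (not taken from the dataset).
chronological_age = np.array([25.0, 40.0, 63.0])
predicted_age = np.array([27.5, 38.0, 70.0])

# Positive gap: the brain "looks older" than expected; negative: younger.
brain_age_gap = predicted_age - chronological_age
print(brain_age_gap)
```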

Dataset

There are 357 samples in the training set and 90 samples in the test set.

Input data

Voxel-based morphometry (VBM), computed with the CAT12 software, provides:

  • Regions Of Interest (rois) of Grey Matter (GM) scaled for the Total Intracranial Volume (TIV): [train|test]_rois.csv, 284 features.

  • VBM GM 3D maps or images (vbm3d) of voxels in the MNI space: [train|test]_vbm.npz contains 3D images of shape (121, 145, 121). This npz file also contains the 3D mask and the affine transformation to the MNI reference space. Masking the brain yields a flat vector of 331 695 input features (masked voxels) for each participant.
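The masking step can be sketched with plain NumPy boolean indexing. This is only an illustration on synthetic arrays: the tiny shapes and the rectangular "brain" region are assumptions made to keep the example fast, and the key names inside the npz file are deliberately not used here.

```python
import numpy as np

# Synthetic stand-ins: a tiny 3D "image" and a boolean "brain mask".
# The real maps have shape (121, 145, 121).
rng = np.random.default_rng(0)
img = rng.random((4, 5, 4))
mask = np.zeros((4, 5, 4), dtype=bool)
mask[1:3, 1:4, 1:3] = True  # hypothetical "brain" region (2 * 3 * 2 voxels)

# Boolean indexing keeps only the in-mask voxels, as a flat feature vector.
flat = img[mask]
print(flat.shape)  # (12,)

# Inverse operation: put the flat features back into a 3D volume.
restored = np.zeros(mask.shape)
restored[mask] = flat
```

The inverse step is handy when you want to look at a per-voxel quantity (for instance model weights) as a 3D map again.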

By default, problem.get_[train|test]_data() returns the concatenation of the 284 ROI features of Grey Matter (GM) with the 331 695 features (voxels) within the brain mask. Those two blocks are highly redundant. To select only rois features do:

X[:, :284]

To select only vbm features do:

X[:, 284:]

Target

The target is the age column of the [test|train]_participants.csv files; it is used as the output of the regression problem.
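As a rough sketch, the target can be read with pandas. The in-memory CSV below stands in for a participants file: only the age column is documented above, so the participant_id column and all values are hypothetical.

```python
import io
import pandas as pd

# In-memory stand-in for a participants file. Only the "age" column is
# documented; "participant_id" and the values are hypothetical.
csv_text = "participant_id,age\nsub-01,23.0\nsub-02,57.5\nsub-03,41.0\n"
participants = pd.read_csv(io.StringIO(csv_text))

# Regression target: the age column as a 1D float array.
y = participants["age"].to_numpy(dtype=float)
print(y)
```

In practice problem.get_[train|test]_data() already returns the target, so this is only useful if you want to inspect the files yourself.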

Evaluation metrics

sklearn metrics

The main evaluation metric is the root-mean-square error (RMSE). We will also look at the R-squared (R2).
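Both metrics can be computed with sklearn.metrics, in the same way as the utility function later in this notebook. The ages below are made-up values chosen so that the result is easy to check by hand.

```python
import numpy as np
from sklearn import metrics

# Hypothetical ages, for illustration only.
y_true = np.array([20.0, 30.0, 40.0, 50.0])
y_pred = np.array([22.0, 28.0, 43.0, 49.0])

# RMSE: square root of the mean squared error, expressed in years.
rmse = np.sqrt(metrics.mean_squared_error(y_true, y_pred))
# R2: fraction of the age variance explained by the predictions.
r2 = metrics.r2_score(y_true, y_pred)
print(round(rmse, 3), round(r2, 3))
```

Here the squared errors sum to 18 over 4 samples, so the RMSE is sqrt(4.5), and the total sum of squares around the mean age (35) is 500, so R2 = 1 - 18/500.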

Installation

This starting kit requires Python 3 and the following dependencies:

  • numpy
  • scipy
  • pandas
  • scikit-learn
  • matplotlib
  • seaborn
  • jupyter
  • torch
  • ramp-workflow

You can install the dependencies with the following command-line:

pip install -U -r requirements.txt

If you are using conda, we provide an environment.yml file for similar usage.

conda env create -n ramp-brainage -f environment.yml

Then, you can activate/deactivate the conda environment using:

conda activate ramp-brainage
conda deactivate

Getting started

  1. Download the data
python download_data.py

The train/test data will be available in the data directory.

  2. Execute the jupyter notebook
jupyter notebook brain_age_deep_starting_kit.ipynb

Play with this notebook to create your new model.

  3. Test Submission

The submissions need to be located in the submissions folder. For instance, to create a linear_regression_rois submission, start by copying the starting kit:

cp -r submissions/starting_kit submissions/linear_regression_rois

Tune the estimator in the submissions/linear_regression_rois/estimator.py file. This file must contain a function get_estimator() that returns a scikit-learn Pipeline.
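A minimal sketch of what such a file could look like is given below. It assumes a plain linear regression on the ROI block; the actual estimator.py shipped in the starting kit may be organized differently (for instance, it may load network weights), so treat this only as an illustration of the get_estimator() contract.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


class ROIsFeatureExtractor(BaseEstimator, TransformerMixin):
    """Keep only the first 284 columns (the ROI block)."""
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[:, :284]


def get_estimator():
    """Return a scikit-learn Pipeline: ROI selection, scaling, regression."""
    return make_pipeline(
        ROIsFeatureExtractor(), StandardScaler(), LinearRegression())


# Smoke test on random data shaped like a (much narrower) input matrix.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 284 + 16))  # 16 columns stand in for the voxel block
y = rng.uniform(18, 80, size=50)
estimator = get_estimator().fit(X, y)
print(estimator.predict(X[:3]).shape)
```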

Then, test your submission locally:

ramp-test --submission linear_regression_rois

Note that the weights of the model are expected in a file called weights.pth located in your submission folder.
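One common way to produce such a file with PyTorch is to save the model state dict. The sketch below is an assumption about how you might do it, not a description of what the RAMP workflow does internally: the small network and the temporary directory are stand-ins for your own model and submission folder.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Small stand-in network; a real submission would use its own model.
model = nn.Sequential(nn.Linear(284, 8), nn.ReLU(), nn.Linear(8, 1))

# Hypothetical location: a temporary directory stands in for the
# submission folder (e.g. submissions/linear_regression_rois).
submission_dir = tempfile.mkdtemp()
weights_file = os.path.join(submission_dir, "weights.pth")

# Save only the parameters (the state dict), not the whole Python object.
torch.save(model.state_dict(), weights_file)

# Reload into a freshly built model with the same architecture.
restored = nn.Sequential(nn.Linear(284, 8), nn.ReLU(), nn.Linear(8, 1))
restored.load_state_dict(torch.load(weights_file, map_location="cpu"))
restored.eval()
```

Saving the state dict rather than the full model keeps the file independent of where the model class is defined, but the architecture used at load time must match the one used for training.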

  4. Submission on RAMP.studio

Connect to your RAMP.studio account, select the brain_age_deep event, and submit your estimator in the sandbox section. Note that no training will be performed on the server side: you need to upload the weights of the model in a file called weights.pth along with your estimator.

The MLP estimator in detail

Let's play with the data:

In [ ]:
%matplotlib inline
import time
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import sklearn.metrics as metrics
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_validate
from sklearn.base import BaseEstimator, TransformerMixin
from joblib.externals.loky.backend.context import get_context
import problem

X_train, y_train = problem.get_train_data()
X_test, y_test = problem.get_test_data()
assert X_train.shape[1] == 284 + 331695

Let's define here some utility functions to compute scores:

In [ ]:
def cv_train_test_scores(rmse_cv_test, rmse_cv_train, r2_cv_test, r2_cv_train,
                         y_train, y_pred_train, y_test, y_pred_test):
    """ Compute CV score, train and test score from a cv grid search model.

    Parameters
    ----------
    rmse_cv_test : array
        Test rmse across CV folds.
    rmse_cv_train : array
        Train rmse across CV folds.
    r2_cv_test : array
        Test R2 across CV folds.
    r2_cv_train : array
        Train R2 across CV folds.
    y_train : array
        True train values.
    y_pred_train : array
        Predicted train values.
    y_test : array
        True test values.
    y_pred_test : array
        Predicted test values.

    Returns
    -------
    scores : pandas.DataFrame
        One-row DataFrame with the mean and standard deviation of the CV
        RMSE and R2 (train and test), plus the train and test RMSE and R2.
    """
    # CV scores
    rmse_cv_test_mean, rmse_cv_test_sd = np.mean(rmse_cv_test), np.std(rmse_cv_test)
    rmse_cv_train_mean, rmse_cv_train_sd = np.mean(rmse_cv_train), np.std(rmse_cv_train)

    r2_cv_test_mean, r2_cv_test_sd = np.mean(r2_cv_test), np.std(r2_cv_test)
    r2_cv_train_mean, r2_cv_train_sd = np.mean(r2_cv_train), np.std(r2_cv_train)

    # Test scores
    rmse_test = np.sqrt(metrics.mean_squared_error(y_test, y_pred_test))
    r2_test = metrics.r2_score(y_test, y_pred_test)

    # Train scores
    rmse_train = np.sqrt(metrics.mean_squared_error(y_train, y_pred_train))
    r2_train = metrics.r2_score(y_train, y_pred_train)

    # One row, one column per score (the R2 CV test SD has its own column).
    scores = pd.DataFrame(
        [[rmse_cv_test_mean, rmse_cv_test_sd, rmse_cv_train_mean, rmse_cv_train_sd,
          r2_cv_test_mean, r2_cv_test_sd, r2_cv_train_mean, r2_cv_train_sd,
          rmse_test, r2_test, rmse_train, r2_train]],
        columns=('rmse_cv_test_mean', 'rmse_cv_test_sd', 'rmse_cv_train_mean', 'rmse_cv_train_sd',
                 'r2_cv_test_mean', 'r2_cv_test_sd', 'r2_cv_train_mean', 'r2_cv_train_sd',
                 'rmse_test', 'r2_test', 'rmse_train', 'r2_train'))

    return scores

Let's select the input features of our analysis: rois or vbm. This can be achieved using the ROIsFeatureExtractor or VBMFeatureExtractor extractors.

In [ ]:
class ROIsFeatureExtractor(BaseEstimator, TransformerMixin):
    """Select only the 284 ROIs features:"""
    def fit(self, X, y):
        return self

    def transform(self, X):
        return X[:, :284]

    
class VBMFeatureExtractor(BaseEstimator, TransformerMixin):
    """Select only the 331 695 VBM (masked voxel) features:"""
    def fit(self, X, y):
        return self

    def transform(self, X):
        return X[:, 284:]


fe = ROIsFeatureExtractor()
print(fe.transform(X_train).shape)

fe = VBMFeatureExtractor()
print(fe.transform(X_train).shape)

Let's design a simple MLP age predictor. The framework is evaluated with a cross-validation approach, using the root-mean-square error (RMSE) as the main metric. In this example, the predictor is trained on the ROI features.

In [ ]:
class MLP(nn.Module):
    """ Define a simple MLP with two hidden layers (120 and 84 units).
    """
    def __init__(self, in_features):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
          nn.Linear(in_features, 120),
          nn.ReLU(),
          nn.Linear(120, 84),
          nn.ReLU(),
          nn.Linear(84, 1))

    def forward(self, x):
        return self.layers(x)


class Dataset(torch.utils.data.Dataset):
    """ A torch dataset for regression.
    """
    def __init__(self, X, y=None):
        """ Init class.

        Parameters
        ----------
        X: array-like (n_samples, n_features)
            training data.
        y: array-like (n_samples, ), default None
            target values.
        """
        self.X = torch.from_numpy(X)
        if y is not None:
            self.y = torch.from_numpy(y)
        else:
            self.y = None

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        if self.y is not None:
            return self.X[i], self.y[i]
        else:
            return self.X[i]


class RegressionModel(object):
    """ Base class for Regression models.
    """
    def __init__(self, model, batch_size=10, n_epochs=10, print_freq=40):
        """ Init class.

        Parameters
        ----------
        model: nn.Module
            the input model.
        batch_size: int, default 10
            the mini_batch size.
        n_epochs: int, default 10
            the number of epochs.
        print_freq: int, default 40
            the print frequency (in training steps).
        """
        self.model = model
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.print_freq = print_freq

    def fit(self, X, y):
        """ Fit model.

        Parameters
        ----------
        X: array-like (n_samples, n_features)
            training data.
        y: array-like (n_samples, )
            target values.
        """
        self.model.train()
        self.reset_weights()
        print("-- training model...")
        dataset = Dataset(X, y)
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True,
            num_workers=1, multiprocessing_context=get_context("loky"))
        loss_function = nn.L1Loss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        start_time = time.time()
        current_loss = 0.
        for epoch in range(self.n_epochs):
            for step, data in enumerate(loader, start=epoch * len(loader)):
                inputs, targets = data
                inputs, targets = inputs.float(), targets.float()
                targets = targets.reshape((targets.shape[0], 1))
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = loss_function(outputs, targets)
                loss.backward()
                optimizer.step()
                current_loss += loss.item()
                if step % self.print_freq == 0:
                    stats = dict(epoch=epoch, step=step,
                                 lr=optimizer.param_groups[0]["lr"],
                                 loss=loss.item(),
                                 time=int(time.time() - start_time))
                    print(json.dumps(stats))
        current_loss /= (len(loader) * self.n_epochs)

    def predict(self, X):
        """ Predict using the input model.

        Parameters
        ----------
        X: array-like (n_samples, n_features)
            samples.

        Returns
        -------
        C: array (n_samples, )
            returns predicted values.
        """
        self.model.eval()
        dataset = Dataset(X)
        testloader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=False, num_workers=1,
            multiprocessing_context=get_context("loky"))
        with torch.no_grad():
            C = []
            for i, inputs in enumerate(testloader):
                inputs = inputs.float()
                C.append(self.model(inputs))
            C = torch.cat(C, dim=0)
        return C.numpy().squeeze()

    def reset_weights(self):
        """ Reset all the weights of the model.
        """
        def weight_reset(m):
            if hasattr(m, "reset_parameters"):
                m.reset_parameters()
        self.model.apply(weight_reset)


cv = problem.get_cv(X_train, y_train)
mlp = MLP(284)
estimator = make_pipeline(ROIsFeatureExtractor(), StandardScaler(), RegressionModel(mlp))

cv_results = cross_validate(
    estimator, X_train, y_train, scoring=['neg_root_mean_squared_error', 'r2'],
    cv=cv, verbose=1, return_train_score=True, n_jobs=5)

# Refit on all train
estimator.fit(X_train, y_train)

# Apply on test
y_pred_train = estimator.predict(X_train)
y_pred_test = estimator.predict(X_test)

print("Important scores are rmse_cv_test_mean and rmse_test:")
scores = cv_train_test_scores(
    rmse_cv_test=-cv_results['test_neg_root_mean_squared_error'],
    rmse_cv_train=-cv_results['train_neg_root_mean_squared_error'],
    r2_cv_test=cv_results['test_r2'],
    r2_cv_train=cv_results['train_r2'],
    y_train=y_train, y_pred_train=y_pred_train, y_test=y_test, y_pred_test=y_pred_test).T.round(3)
print(scores)