Notes

Link to Tutorial VAE

No GitHub repository.

Figures illustrating two multi-modal VAE variants: one with a Product-of-Experts loss and the other with a Mixture-of-Experts loss.

  • Multi-modal VAE (Product of Experts loss):

Figure adapted from A. Salvador

  • Multi-modal VAE (Mixture of Experts loss):

Figure adapted from A. Salvador

 

Highlights

In this paper, the authors introduce two contributions:

  • They present two multi-modal normative modelling frameworks (MoE-normVAE, gPoE-normVAE).
  • They use a deviation metric that is based on the latent space.

 

Introduction

  • The authors study heterogeneous brain disorders using normative models. These models assume that disease cohorts are located at the extremes of the healthy population distribution.

  • However, it is often unclear which imaging modality will be the most sensitive in detecting deviations from the norm caused by brain disorders. Hence, they choose to develop normative models that are suitable for multiple modalities.

  • Multi-modal VAE frameworks usually learn separate encoder and decoder networks for each modality and aggregate the encoding distributions to learn a joint latent representation (cf. figure in Notes). One approach is the Product-of-Experts (PoE) method, which considers all experts to be equally credible and assigns a uniform contribution to each modality. However, the joint distribution can then be biased by overconfident experts.

Fig 1. (b) Example PoE and gPoE joint distributions.

  • The authors propose a generalised Product-of-Experts (gPoE) that adds a weight for each modality and each latent dimension. They also use the Mixture-of-Experts (MoE) model and compare both with other methods.

  • Finally, to exploit this joint latent space, they develop a deviation metric from the latent space instead of the feature space.

 

Method

Product of Experts

  • $M$ : number of modalities
  • $\mathbf{X} = \{\mathbf{x}_m\}_{m=1}^{M}$ : observations
  • $p(\mathbf{z})$ : prior
  • $p_\theta(\mathbf{X}, \mathbf{z}) = p(\mathbf{z}) \prod_{m=1}^{M} p_{\theta_m}(\mathbf{x}_m \mid \mathbf{z})$ : generative (joint) distribution
  • $\theta = \{\theta_1, \ldots, \theta_M\}$ : decoder parameters
  • $q_\phi(\mathbf{z} \mid \mathbf{X}) = \frac{1}{K} \prod_{m=1}^{M} q_{\phi_m}(\mathbf{z} \mid \mathbf{x}_m)$ : joint approximate posterior ($K$ is a normalisation constant)
  • $\phi = \{\phi_1, \ldots, \phi_M\}$ : encoder parameters

The loss (ELBO) is:

$$\mathcal{L} = \mathbb{E}_{q_\phi(\mathbf{z} \mid \mathbf{X})}\left[\sum_{m=1}^{M} \log p_{\theta_m}(\mathbf{x}_m \mid \mathbf{z})\right] - D_{KL}\left(q_\phi(\mathbf{z} \mid \mathbf{X}) \,\|\, p(\mathbf{z})\right)$$

They assume that each encoder outputs a Gaussian distribution:

$$q_{\phi_m}(\mathbf{z} \mid \mathbf{x}_m) = \mathcal{N}(\boldsymbol{\mu}_m, \boldsymbol{\sigma}_m^2 \mathbf{I})$$

Therefore, the joint posterior is also Gaussian, with closed-form parameters:

$$\boldsymbol{\mu} = \frac{\sum_{m=1}^{M} \boldsymbol{\mu}_m / \boldsymbol{\sigma}_m^2}{\sum_{m=1}^{M} 1 / \boldsymbol{\sigma}_m^2}, \qquad \boldsymbol{\sigma}^2 = \frac{1}{\sum_{m=1}^{M} 1 / \boldsymbol{\sigma}_m^2}$$
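The closed-form PoE fusion is just precision-weighted averaging of the per-modality Gaussians. A minimal NumPy sketch (illustrative only; function and variable names are our own, not the paper's):

```python
import numpy as np

def poe_fusion(mus, sigma2s):
    """Precision-weighted fusion of M Gaussian experts (PoE closed form).

    mus, sigma2s: arrays of shape (M, D) holding each expert's mean and
    variance per latent dimension.
    """
    precisions = 1.0 / sigma2s                    # 1 / sigma_m^2
    sigma2 = 1.0 / precisions.sum(axis=0)         # joint variance
    mu = sigma2 * (mus * precisions).sum(axis=0)  # precision-weighted mean
    return mu, sigma2

# Two 1-D experts: a confident one at 0 (var 1) and a diffuse one at 4 (var 4).
mu, sigma2 = poe_fusion(np.array([[0.0], [4.0]]),
                        np.array([[1.0], [4.0]]))
# The joint mean is pulled toward the confident expert: mu = 0.8, sigma2 = 0.8
```

This makes the overconfidence problem concrete: an expert with an erroneously small variance dominates the joint mean regardless of how informative its modality actually is.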

 

Mixture of Experts

In the case of MoE, the joint posterior becomes a uniform mixture:

$$q_\phi(\mathbf{z} \mid \mathbf{X}) = \sum_{m=1}^{M} \frac{1}{M}\, q_{\phi_m}(\mathbf{z} \mid \mathbf{x}_m)$$

and the loss:

$$\mathcal{L} = \sum_{m=1}^{M} \left[ \mathbb{E}_{q_{\phi_m}(\mathbf{z} \mid \mathbf{x}_m)}\left[\sum_{n=1}^{M} \log p_{\theta_n}(\mathbf{x}_n \mid \mathbf{z})\right] - D_{KL}\left(q_{\phi_m}(\mathbf{z} \mid \mathbf{x}_m) \,\|\, p(\mathbf{z})\right) \right]$$
  • Disadvantage: the model only considers each uni-modal encoding distribution independently and does not explicitly combine information from multiple modalities in the latent representations.
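The independence drawback is visible when sampling from the mixture: each draw comes from a single modality's encoder, never from a fused distribution. A minimal NumPy sketch of two-stage mixture sampling (our own illustration, not the paper's implementation):

```python
import numpy as np

def moe_sample(mus, sigma2s, rng):
    """Draw one z from the uniform mixture q(z|X) = (1/M) sum_m q_m(z|x_m).

    Two-stage sampling: pick an expert uniformly at random, then sample
    from its Gaussian. Note the sample uses only one modality at a time.
    """
    M, D = mus.shape
    m = rng.integers(M)  # choose one expert uniformly
    return mus[m] + np.sqrt(sigma2s[m]) * rng.standard_normal(D)

rng = np.random.default_rng(0)
z = moe_sample(np.array([[0.0, 0.0], [5.0, 5.0]]),  # two 2-D experts
               np.ones((2, 2)), rng)
```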

 

Generalised Product-of-Experts joint posterior

To overcome the problem of overconfident experts, they add a weight $\alpha_m$ for each modality and each latent dimension in the joint posterior distribution.

$$q_\phi(\mathbf{z} \mid \mathbf{X}) = \frac{1}{K} \prod_{m=1}^{M} q_{\phi_m}^{\alpha_m}(\mathbf{z} \mid \mathbf{x}_m)$$

with $\sum_{m=1}^{M} \alpha_m = 1$ and $0 < \alpha_m < 1$ ($\alpha$ is learned during training).

Example of α:

Just like in the PoE approach, the parameters of the joint posterior distribution have a closed form:

$$\boldsymbol{\mu} = \frac{\sum_{m=1}^{M} \alpha_m \boldsymbol{\mu}_m / \boldsymbol{\sigma}_m^2}{\sum_{m=1}^{M} \alpha_m / \boldsymbol{\sigma}_m^2}, \qquad \boldsymbol{\sigma}^2 = \frac{1}{\sum_{m=1}^{M} \alpha_m / \boldsymbol{\sigma}_m^2}$$
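The gPoE closed form differs from plain PoE only in that each expert's precision is scaled by its learned weight, so a down-weighted overconfident expert loses its pull on the joint mean. A minimal NumPy sketch (names are illustrative; in the paper the α values are learned, here they are fixed by hand):

```python
import numpy as np

def gpoe_fusion(mus, sigma2s, alphas):
    """gPoE closed form: each expert's precision is scaled by alpha_m,
    defined per modality and per latent dimension.

    mus, sigma2s, alphas: shape (M, D); alphas sum to 1 over modalities.
    """
    w_prec = alphas / sigma2s                  # alpha_m / sigma_m^2
    sigma2 = 1.0 / w_prec.sum(axis=0)          # joint variance
    mu = sigma2 * (mus * w_prec).sum(axis=0)   # weighted-precision mean
    return mu, sigma2

# Same two experts as the PoE example, but the overconfident expert at 0
# gets alpha = 0.2: its weighted precision now equals the other's, so the
# joint mean moves to the midpoint.
mu, sigma2 = gpoe_fusion(np.array([[0.0], [4.0]]),
                         np.array([[1.0], [4.0]]),
                         np.array([[0.2], [0.8]]))
# mu = 2.0, sigma2 = 2.5
```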

 

Multi-modal latent deviation metric

  • Previous work used the following distance (a univariate feature-space metric) to highlight subjects that are out of distribution:

$$D_{uf} = \frac{d_{ij} - \mu(d_{ij}^{norm})}{\sigma(d_{ij}^{norm})}$$

where $\mu(d_{ij}^{norm})$ and $\sigma(d_{ij}^{norm})$ are the mean and standard deviation of the reconstruction error (region $i$, subject $j$) over the holdout healthy control cohort.

  • The authors suggest that latent-space deviation metrics would more accurately capture deviations from normative behaviour across multiple modalities. They measure the Mahalanobis distance from the encoding distribution of the training cohort:

$$D_{ml} = \sqrt{(\mathbf{z}_j - \boldsymbol{\mu}(\mathbf{z}^{norm}))^{T} \, \Sigma(\mathbf{z}^{norm})^{-1} \, (\mathbf{z}_j - \boldsymbol{\mu}(\mathbf{z}^{norm}))}$$

where $\mathbf{z}_j \sim q(\mathbf{z}_j \mid \mathbf{X}_j)$ is a sample from the joint posterior distribution for subject $j$, and $\boldsymbol{\mu}(\mathbf{z}^{norm})$ and $\Sigma(\mathbf{z}^{norm})$ are respectively the mean and covariance of the healthy control cohort's latent positions.

  • Finally, for a closer comparison with $D_{ml}$, they derive the analogous metric in the multivariate feature space:

$$D_{mf} = \sqrt{(\mathbf{d}_j - \boldsymbol{\mu}(\mathbf{d}^{norm}))^{T} \, \Sigma(\mathbf{d}^{norm})^{-1} \, (\mathbf{d}_j - \boldsymbol{\mu}(\mathbf{d}^{norm}))}$$

where $\mathbf{d}_j = \{d_{1j}, \ldots, d_{Ij}\}$ is the vector of reconstruction errors for subject $j$ across brain regions $i = 1, \ldots, I$.

 

Assessing deviation metric performance

To evaluate the performance of their models, they use the significance ratio:

$$\text{significance ratio} = \frac{\text{True positive rate}}{\text{False positive rate}} = \frac{N_{disease}(\text{outliers}) \, / \, N_{disease}}{N_{holdout}(\text{outliers}) \, / \, N_{holdout}}$$

Ideally, we want a model which correctly identifies pathological individuals as outliers and healthy individuals as sitting within the normative distribution.
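The ratio is a one-liner; a quick sketch with made-up outlier counts (the cohort sizes match the Architecture section below, but the flagged counts are purely hypothetical):

```python
def significance_ratio(n_disease_out, n_disease, n_holdout_out, n_holdout):
    """TPR / FPR: fraction of disease subjects flagged as outliers divided
    by the fraction of healthy holdout subjects flagged."""
    tpr = n_disease_out / n_disease
    fpr = n_holdout_out / n_holdout
    return tpr / fpr

# Hypothetical example: 61 of 122 patients flagged vs. 128 of 2,568
# healthy controls flagged.
ratio = significance_ratio(61, 122, 128, 2568)  # -> ~10.03
```

A ratio well above 1 means the metric flags patients far more often than it flags healthy controls, which is exactly the behaviour the authors want.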

 

Architecture

  • Dataset used: UK Biobank
  • 10,276 healthy subjects to train their neural networks
  • At test time:
    • 2,568 healthy controls from the holdout cohort
    • 122 individuals with one of several neurodegenerative disorders
  • Also evaluated on another dataset: the Alzheimer's Disease Neuroimaging Initiative (ADNI), with 213 subjects
  • (the same imaging modalities, T1 and DTI features, were extracted for both datasets)

 

Results

For the UK Biobank dataset:

For the ADNI dataset:

 

Conclusions

  • Their models provide a better joint representation than baseline methods.
  • They propose a latent deviation metric to detect deviations in the multivariate latent space.