Notes

 

Highlights

  • The objective of this paper is to develop a self-supervised learning paradigm that can learn cross-modal (vision and language) domain knowledge from medical data.
  • The self-supervised strategy is based on the reconstruction of missing pixels (vision) and missing text tokens (language) from randomly masked images and texts.
  • The evaluation is based on a medical vision-and-language benchmark which includes three tasks.

 

Motivations

  • Medical vision-and-language pre-training (Med-VLP) aims to learn generic representations from large-scale medical image-text data

  • This representation can be transferred to various tasks relevant for medical vision-and-language analysis, such as visual question answering, image-text classification, image-text retrieval (the corresponding definitions are given below)

 

Method

Architecture

 

Key aspects

  • Use of transformers to encode image and language features

  • Use of transformers to perform multi-modal fusion

  • Use of a transformer to decode the image and a simple MLP to decode the text

  • Pre-training is performed using medical image-text pairs

  • Masks random patches of the input image and random tokens of the input text and reconstructs the missing pixels and tokens

    This makes pre-training a self-supervised process

  • Uses different masking rates for input images and text due to the different information densities of vision and language

 

Formalism

Loss function
\theta, \theta_1, \theta_2 \,=\, \underset{\theta,\, \theta_1,\, \theta_2}{\arg\min} \; \sum_{s=1}^{2} L_s\!\left( Y_s,\, D_{\theta_s}\!\left( M_{\theta}(I, T) \right) \right)
  • L_s are the loss functions of the pretext tasks, i.e. the MSE between the reconstructed and original images and the negative log-likelihood of the masked tokens; Y_s denotes the corresponding reconstruction target (a minimal sketch of combining the two losses is given after this list)

  • D_{\theta_s} are the decoders with their parameters \theta_1, \theta_2

  • M_{\theta} is the backbone model (vision/language encoders plus the multi-modal fusion module) with its parameters \theta; I and T denote the masked input image and text.
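Below is a minimal PyTorch sketch (not the authors' code) of how the two pretext losses could be combined into the objective above; the tensor names and the per-patch pixel layout are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def pretraining_loss(recon_pixels, target_pixels, patch_mask,
                     token_logits, target_tokens, ignore_index=-100):
    """Sum of the two pretext losses (hypothetical tensor layout).

    recon_pixels / target_pixels: (B, N, P*P*C) per-patch pixel values
    patch_mask:                   (B, N) boolean mask of the masked patches
    token_logits:                 (B, M, V) logits of the text decoder
    target_tokens:                (B, M) token ids, unmasked positions set to ignore_index
    """
    # Masked image modeling: MSE on the masked patches only
    l_mim = F.mse_loss(recon_pixels[patch_mask], target_pixels[patch_mask])
    # Masked language modeling: negative log-likelihood over the vocabulary
    l_mlm = F.cross_entropy(token_logits.flatten(0, 1), target_tokens.flatten(),
                            ignore_index=ignore_index)
    return l_mim + l_mlm
```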

 

Vision encoder
X^{\nu} \in \mathbb{R}^{(N+1) \times D} \,=\, \left[ p_I; p_1 E^{\nu}; \cdots; p_N E^{\nu} \right]\,+\,E^{\nu}_{pos}
  • Each image I \in \mathbb{R}^{H \times W \times C} is divided into N patches \{ p_1,\cdots,p_N \} of size P \times P, each flattened into a vector of dimension P^2 \cdot C

  • E^{\nu} \in \mathbb{R}^{(P^2 \cdot C) \times D} is the projection matrix that maps the flattened patches to patch embeddings

  • p_I \in \mathbb{R}^{D} is used for the aggregation of visual information

  • X^{\nu} is fed into a transformer model with N_{\nu} transformer blocks to obtain the contextualized image representation H^{\nu} \in \mathbb{R}^{(N+1) \times D} \,=\, \left[ h^{\nu}_I; h^{\nu}_1; \cdots; h^{\nu}_N \right] (a code sketch of the embedding step is given below)
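A minimal PyTorch sketch of the patch-embedding step in the formula above; the image/patch sizes and the use of a strided convolution as the projection E^{\nu} are assumptions for illustration, not necessarily the paper's exact configuration.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Patchify, project with E^v, prepend p_I, and add positional embeddings."""
    def __init__(self, img_size=288, patch_size=16, in_chans=3, dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2                 # N
        self.proj = nn.Conv2d(in_chans, dim, kernel_size=patch_size,
                              stride=patch_size)                         # plays the role of E^v
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))            # p_I
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, dim))  # E^v_pos

    def forward(self, images):                                           # (B, C, H, W)
        x = self.proj(images).flatten(2).transpose(1, 2)                 # (B, N, D) patch embeddings
        cls = self.cls_token.expand(x.size(0), -1, -1)                   # (B, 1, D)
        return torch.cat([cls, x], dim=1) + self.pos_embed               # X^v: (B, N+1, D)
```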

 

Language encoder
X^{l} \in \mathbb{R}^{(M+2) \times D} \,=\, \left[ w_T; w_1 E^{l}; \cdots; w_M E^{l}; w_{SEP} \right]\,+\,E^{l}_{pos}
  • Each input text is tokenized into subword tokens \{ w_1,\cdots,w_M \} by WordPiece, where each token w_m \in \mathbb{R}^{V} is represented in one-hot form and V is the vocabulary size

  • E^{l} \in \mathbb{R}^{V \times D} is the projection matrix into the text embeddings

  • w_T \in \mathbb{R}^{D} and w_{SEP} \in \mathbb{R}^{D} correspond to a start-of-sequence token embedding and a special boundary token embedding, respectively

  • X^{l} is fed into a transformer model with N_{l} transformer blocks to obtain the contextualized text representation H^{l} \in \mathbb{R}^{(M+2) \times D} \,=\, \left[ h^{l}_T; h^{l}_1; \cdots; h^{l}_M; h^{l}_{SEP} \right] (a code sketch of the embedding step is given below)
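An analogous sketch for the language input layer; the vocabulary size, maximum length, and special-token ids are placeholders, not the paper's values.

```python
import torch
import torch.nn as nn

class TextEmbedding(nn.Module):
    """Token embedding (E^l) plus start-of-sequence / boundary tokens and positional embeddings."""
    def __init__(self, vocab_size=30522, max_len=64, dim=768, cls_id=101, sep_id=102):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, dim)                   # one-hot w_m times E^l
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len + 2, dim))  # E^l_pos
        self.cls_id, self.sep_id = cls_id, sep_id

    def forward(self, token_ids):                                        # (B, M) subword token ids
        b = token_ids.size(0)
        cls = torch.full((b, 1), self.cls_id, device=token_ids.device)   # w_T
        sep = torch.full((b, 1), self.sep_id, device=token_ids.device)   # w_SEP
        ids = torch.cat([cls, token_ids, sep], dim=1)                    # [w_T; w_1..w_M; w_SEP]
        return self.tok_embed(ids) + self.pos_embed[:, : ids.size(1)]    # X^l: (B, M+2, D)
```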

 

Masking scheme
  • The authors use random sampling with a much higher masking ratio for images (i.e. 75\%) than for texts (i.e. 15\%). This is justified by the fact that image data is highly redundant whereas language is information-dense (a masking sketch is given below)
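A minimal sketch of per-sample random masking with the two ratios quoted above; the batch size and sequence lengths in the usage example are arbitrary.

```python
import torch

def random_mask(batch_size, seq_len, mask_ratio, device="cpu"):
    """Return a (batch_size, seq_len) boolean mask with `mask_ratio` of positions set to True."""
    num_mask = int(seq_len * mask_ratio)
    noise = torch.rand(batch_size, seq_len, device=device)    # uniform random scores
    ranks = noise.argsort(dim=1)                               # random permutation per sample
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
    mask.scatter_(1, ranks[:, :num_mask], True)                # mask the first num_mask positions
    return mask

# 75% of 324 image patches (288/16 = 18, 18^2 = 324) and 15% of 40 text tokens
patch_mask = random_mask(8, 324, 0.75)
token_mask = random_mask(8, 40, 0.15)
```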

 

Representation selection for reconstruction
  • Images and texts are abstracted at different levels, with pixels having a lower semantic level than text tokens.

  • The outputs of the k-th transformer block of the multi-modal fusion module (Z^{\nu k}) are used to compute the image reconstruction loss (red part in the architecture figure)

  • The final output Z^{l} is used for the prediction of text tokens since predicting missing words requires richer semantic information

 

Decoder designs
  • A transformer model is used to perform the reconstruction task from Z^{\nu k}

  • A simple MLP is used to predict the missing text tokens (see the sketch of both decoders below)
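A hedged sketch of the two decoder heads: a shallow transformer reconstructs per-patch pixels from the intermediate fused features Z^{\nu k}, and a plain MLP predicts the masked tokens from the final fused features Z^{l}. The depths and sizes are placeholders, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

class Decoders(nn.Module):
    def __init__(self, dim=768, patch_pixels=16 * 16 * 3, vocab_size=30522):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=12, batch_first=True)
        self.image_decoder = nn.TransformerEncoder(layer, num_layers=2)   # small transformer
        self.pixel_head = nn.Linear(dim, patch_pixels)                    # per-patch pixel values
        self.text_decoder = nn.Sequential(                                # simple MLP head
            nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, vocab_size))

    def forward(self, z_v_k, z_l):
        recon_pixels = self.pixel_head(self.image_decoder(z_v_k))         # (B, N, P*P*C)
        token_logits = self.text_decoder(z_l)                             # (B, M+2, V)
        return recon_pixels, token_logits
```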

 

Results

ROCO dataset - repo

  • 81,000 medical images with their captions and the corresponding UMLS Semantic Types, which are useful for classification purposes

    UMLS (Unified Medical Language System): provides a standardized way of categorizing biomedical concepts based on their semantic characteristics.

  • Contains several medical imaging modalities, with the corresponding text automatically extracted from the PubMed Central Open Access FTP mirror

  • There are 16 times more radiological images than images from the other modalities

  • The dataset is randomly split into training/validation/test subsets with an 80/10/10 ratio

MedICaT dataset - repo

  • 217,000 medical images with their captions, plus inline textual references for 74% of the figures

  • Contains several medical imaging modalities, with the corresponding text automatically extracted from the PubMed Central Open Access FTP mirror

  • 1,000 images are randomly sampled for validation, 1,000 for testing, and the remaining images are used for training

 

Implementation details

  • Vision encoder: CLIP-ViT-B

  • Language encoder: RoBERTa-base

  • N_m = 6 transformer blocks for the multi-modal fusion module, with 12 attention heads per block

  • Pre-training with the AdamW optimizer for 100,000 iterations

  • Each image is resized and center-cropped to 288×288 (a preprocessing/optimizer sketch is given below)
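A small sketch of the preprocessing and optimization setup listed above; the learning rate, weight decay, and the placeholder model are assumptions, since the notes do not give these values.

```python
import torch
from torchvision import transforms

# Resize then center-crop each image to 288x288, as stated above
image_transform = transforms.Compose([
    transforms.Resize(288),
    transforms.CenterCrop(288),
    transforms.ToTensor(),
])

# `model` stands in for the full pre-training model (encoders + fusion + decoders)
model = torch.nn.Linear(10, 10)  # placeholder module for illustration only
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)  # hyperparameters assumed
```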

 

Downstream tasks

  • Medical Visual Question Answering (Med-VQA) - Answering natural language questions about medical images. The VQA-RAD, SLAKE, and VQA-2019 datasets were used for evaluation

  • Medical Image-Text Classification - Produce the label given an image-text pair. The MELINDA dataset was used for evaluation

  • Medical Image-Caption Retrieval - Two subtasks: image-to-text (I2T) retrieval requires retrieving the most relevant texts from a large pool of texts given an image and vice versa for text-to-image (T2I). The ROCO dataset was used for evaluation

  • Accuracy is used as the metric for the Med-VQA and Medical Image-Text Classification tasks

  • Recall@K (K=1, 5, 10) is used for the Medical Image-Caption Retrieval task (a minimal Recall@K sketch is given below)
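A minimal sketch of Recall@K for the retrieval task, assuming a query-candidate similarity matrix in which the ground-truth candidate for query i sits at index i (a common evaluation convention; the paper's exact protocol is not restated in these notes).

```python
import numpy as np

def recall_at_k(similarity, k):
    """Fraction of queries whose ground-truth candidate appears in the top-k retrieved items."""
    ranks = np.argsort(-similarity, axis=1)                              # candidates sorted by score
    hits = (ranks[:, :k] == np.arange(len(similarity))[:, None]).any(axis=1)
    return hits.mean()

# Example with a random 100x100 score matrix
scores = np.random.randn(100, 100)
print(recall_at_k(scores, 1), recall_at_k(scores, 5), recall_at_k(scores, 10))
```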

Unfortunately, nothing is said about how the pretrained model is fine-tuned for the different downstream tasks :(

 

Results for Med-VQA task

 

Results for Medical Image-Text Classification

 

Results for Medical Image-Caption Retrieval

(ZS) means zero-shot and (FT) means fine-tuning

 

Ablation study

(MIM) stands for Masked Image Modeling and (MLM) stands for Masked Language Modeling

 

Qualitative results

 

Conclusions

  • Image/text coupling for medical data analysis looks like a promising way forward

  • Pre-training in a self-supervised way using a masking strategy appears to be relevant

  • Exploiting embeddings at different levels of abstraction for images and text would seem to be a good approach