Multi-Modal Masked Autoencoders for Medical Vision-and-Language Pre-Training
Notes
- Here are some useful links: repo, submission process
Highlights
- The objective of this paper is to develop a self-supervised learning paradigm that can learn cross-modal domain knowledge (vision and language) from medical data.
- The self-supervised strategy is based on the reconstruction of missing pixels (vision) and missing tokens (language) from randomly masked images and texts.
- The evaluation is based on a medical vision-and-language benchmark which includes three tasks.
Motivations
- Medical vision-and-language pre-training (Med-VLP) aims to learn generic representations from large-scale medical image-text data
- These representations can be transferred to various medical vision-and-language tasks, such as visual question answering, image-text classification, and image-text retrieval (the corresponding definitions are given below)
Method
Architecture
Key aspects
- Use of transformers to encode image and language features
- Use of transformers to perform multi-modal fusion
- Use of a transformer to decode the image and a simple MLP to decode the text
- Pre-training is performed using medical image-text pairs
- Masks random patches of the input image and random tokens of the input text, then reconstructs the missing pixels and tokens; this makes pre-training a self-supervised process
- Uses different masking rates for input images and text due to the different information densities of vision and language
Formalism
Loss function
\theta^{*}, \theta_1^{*}, \theta_2^{*} = \underset{\theta, \theta_1, \theta_2}{\arg\min} \; \sum_{s=1}^{2} \mathcal{L}_s\left( y_s, D_{\theta_s}\left( M_{\theta}(I, T) \right) \right)
- \mathcal{L}_s are the loss functions of the pretext tasks, i.e. the MSE between the reconstructed and original images and the negative log-likelihood for the masked tokens; y_s denotes the corresponding targets (original pixels or tokens). A sketch combining the two losses is given below
- D_{\theta_s} are the decoders with their parameters \theta_1, \theta_2
- M_{\theta} is the backbone model with its parameters \theta, taking the image-text pair (I, T) as input
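A minimal PyTorch-style sketch of how the two pretext losses could be combined; all tensor names and shapes are assumptions for illustration, not the paper's implementation.

```python
import torch.nn.functional as F

def pretraining_loss(pred_patches, target_patches, image_mask,
                     token_logits, target_token_ids, text_mask):
    """Combine the two pretext losses: MSE on the masked image patches and
    negative log-likelihood (cross-entropy) on the masked text tokens.
    Masks are boolean tensors marking the masked positions (illustrative)."""
    # Image reconstruction: MSE computed only on the masked patches
    mse = F.mse_loss(pred_patches[image_mask], target_patches[image_mask])
    # Masked language modeling: cross-entropy only on the masked tokens
    nll = F.cross_entropy(token_logits[text_mask], target_token_ids[text_mask])
    return mse + nll
```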
Vision encoder
X^{\nu} \in \mathbb{R}^{(N+1) \times D} \,=\, \left[ p_I; p_1 E^{\nu}; \cdots; p_N E^{\nu} \right]\,+\,E^{\nu}_{pos}
- Each image I \in \mathbb{R}^{H \times W \times C} is divided into N patches \{ p_1,\cdots,p_N \}, each flattened into a vector of size P^2 C
- E^{\nu} \in \mathbb{R}^{(P^2 C) \times D} is the projection matrix into the patch embeddings
- p_I \in \mathbb{R}^{D} is used for the aggregation of visual information
- X^{\nu} is fed into a transformer model with N_{\nu} transformer blocks to obtain the contextualized image representation H^{\nu} \in \mathbb{R}^{(N+1) \times D} \,=\, \left[ h^{\nu}_I; h^{\nu}_1; \cdots; h^{\nu}_N \right] (a minimal sketch of this embedding step is given below)
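A minimal sketch of the patch-embedding step described above, written in PyTorch; the sizes (patch size 16, embedding dimension 768) and module names are assumptions, not the paper's exact code.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into N patches, project them to D dimensions (E^v),
    prepend the aggregation token p_I, and add positional embeddings E^v_pos."""
    def __init__(self, img_size=288, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A strided convolution is equivalent to flattening each P x P x C patch
        # and multiplying it by the projection matrix E^v
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))   # p_I
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim))          # E^v_pos

    def forward(self, images):                 # images: (B, C, H, W)
        x = self.proj(images)                  # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)       # (B, N, D)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)         # (B, N + 1, D)
        return x + self.pos_embed              # X^v
```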
Language encoder
X^{l} \in \mathbb{R}^{(M+2) \times D} \,=\, \left[ w_T; w_1 E^{l}; \cdots; w_M E^{l}; w_{SEP} \right]\,+\,E^{l}_{pos}
- Each input text is tokenized into subword tokens \{ w_1, \cdots, w_M \} by WordPiece, where tokens w_m \in \mathbb{R}^{V} are represented in one-hot form and V is the vocabulary size
- E^{l} \in \mathbb{R}^{V \times D} is the projection matrix into the text embeddings
- w_T \in \mathbb{R}^{D} and w_{SEP} \in \mathbb{R}^{D} correspond to a start-of-sequence token embedding and a special boundary token embedding, respectively
- X^{l} is fed into a transformer model with N_{l} transformer blocks to obtain the contextualized text representation H^{l} \in \mathbb{R}^{(M+2) \times D} \,=\, \left[ h^{l}_T; h^{l}_1; \cdots; h^{l}_M; h^{l}_{SEP} \right] (a minimal sketch of this step is given below)
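And a matching sketch for the text side; the vocabulary size and maximum length are placeholders, and the token ids are assumed to already contain the start-of-sequence and boundary tokens added by the tokenizer.

```python
import torch
import torch.nn as nn

class TextEmbedding(nn.Module):
    """Map subword token ids to D-dimensional embeddings (E^l) and add
    positional embeddings (E^l_pos). Placeholder sizes, not the paper's."""
    def __init__(self, vocab_size=30522, embed_dim=768, max_len=128):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)             # E^l
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, embed_dim))  # E^l_pos

    def forward(self, token_ids):          # (B, M + 2), incl. start/boundary ids
        x = self.token_embed(token_ids)    # (B, M + 2, D)
        return x + self.pos_embed[:, : x.size(1)]   # X^l
```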
Masking scheme
- The authors use random sampling with a much higher masking ratio for images (i.e. 75%) than for texts (i.e. 15%). This is justified by the fact that images are information-redundant while language is information-dense (see the sketch below)
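A sketch of the random-sampling masking under the two reported ratios; the helper name and batch layout are assumptions.

```python
import torch

def random_mask(num_items, mask_ratio, batch_size=1):
    """Boolean mask with `mask_ratio` of the positions set to True,
    chosen uniformly at random for each sample in the batch."""
    num_masked = int(num_items * mask_ratio)
    ids_shuffle = torch.argsort(torch.rand(batch_size, num_items), dim=1)
    mask = torch.zeros(batch_size, num_items, dtype=torch.bool)
    rows = torch.arange(batch_size).unsqueeze(1)
    mask[rows, ids_shuffle[:, :num_masked]] = True
    return mask

image_mask = random_mask(num_items=324, mask_ratio=0.75)  # 18 x 18 patches, 75% masked
text_mask = random_mask(num_items=64, mask_ratio=0.15)    # 15% of the tokens masked
```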
Representation selection for reconstruction
- Images and texts are abstracted at different levels, with pixels having a lower semantic level than text tokens
- The outputs from the k-th transformer block (Z^{\nu k}) are used to compute the image reconstruction loss (red part in the figure of the architecture)
- The final output Z^{l} is used for the prediction of the masked text tokens, since predicting missing words requires richer semantic information (see the sketch below)
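A toy illustration of the representation-selection idea: keep every fusion block's output so an intermediate (k-th) layer can feed the image decoder while the last layer feeds the text decoder. Layer counts, names, and the dummy input are assumptions.

```python
import torch
import torch.nn as nn

class FusionWithIntermediates(nn.Module):
    """Multi-modal fusion stack that returns the output of every block."""
    def __init__(self, dim=768, num_blocks=6, num_heads=12):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads, batch_first=True)
            for _ in range(num_blocks)
        ])

    def forward(self, x):          # x: fused image/text token sequence (B, L, D)
        outputs = []
        for block in self.blocks:
            x = block(x)
            outputs.append(x)
        return outputs

fusion = FusionWithIntermediates()
tokens = torch.randn(2, 64, 768)   # dummy fused sequence
states = fusion(tokens)
k = 2                              # illustrative choice of intermediate block
z_vk = states[k]    # lower-level output -> image reconstruction
z_l = states[-1]    # final output       -> masked-token prediction
```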
Decoder designs
- A transformer model is used to perform the image reconstruction task from Z^{\nu k}
- A simple MLP is used to predict the missing text tokens (a sketch of both decoders is given below)
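A hedged sketch of the two decoder heads; layer counts, patch dimension, and vocabulary size are placeholders rather than the paper's exact configuration.

```python
import torch.nn as nn

class ImagePatchDecoder(nn.Module):
    """Small transformer that maps fused visual tokens back to flattened pixel patches."""
    def __init__(self, dim=768, num_layers=2, num_heads=8, patch_dim=16 * 16 * 3):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.to_pixels = nn.Linear(dim, patch_dim)

    def forward(self, z_vk):                      # (B, N, D)
        return self.to_pixels(self.blocks(z_vk))  # (B, N, P*P*C)

class TextTokenDecoder(nn.Module):
    """Simple MLP head that predicts vocabulary logits for the masked tokens."""
    def __init__(self, dim=768, vocab_size=30522):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(),
                                 nn.Linear(dim, vocab_size))

    def forward(self, z_l):                       # (B, M, D)
        return self.mlp(z_l)                      # (B, M, V)
```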
Results
ROCO dataset - repo
- 81,000 medical images with their captions and the corresponding UMLS Semantic Types useful for classification purposes
UMLS (Unified Medical Language System): provides a standardized way of categorizing biomedical concepts based on their semantic characteristics.
- Contains several medical imaging modalities with the corresponding text automatically extracted from the PubMed Central Open Access FTP mirror
- There are 16 times more radiological images than images from the other modalities
- The dataset is randomly split 80/10/10 into training/validation/test sets
MedICaT dataset - repo
- 217,000 medical images with their captions, and inline textual references for 74% of the figures
- Contains several medical imaging modalities with the corresponding text automatically extracted from the PubMed Central Open Access FTP mirror
- 1,000 images are randomly sampled for validation, 1,000 for testing, and the remaining images are used for training
Implementation details
- Vision encoder: CLIP-ViT-B
- Language encoder: RoBERTa-base
- N_m = 6 transformer blocks for the multi-modal fusion module, with 12 attention heads per block
- AdamW optimizer during pre-training, for 100,000 iterations
- Each image is resized and center-cropped to 288x288 (see the preprocessing sketch below)
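A sketch of the stated preprocessing and optimizer setup; the normalization, learning rate, and weight decay are placeholders, as the notes only mention the resize/center-crop size, AdamW, and the 100,000 iterations.

```python
import torch
from torchvision import transforms

# Resize then center-crop each image to 288 x 288
preprocess = transforms.Compose([
    transforms.Resize(288),
    transforms.CenterCrop(288),
    transforms.ToTensor(),
])

# AdamW for pre-training (hyperparameters below are placeholders)
model = torch.nn.Linear(768, 768)   # stand-in for the full backbone + decoders
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
```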
Downstream tasks
- Medical Visual Question Answering (Med-VQA) - answering natural language questions about medical images. The VQA-RAD, SLAKE, and VQA-2019 datasets were used for evaluation
- Medical Image-Text Classification - produce the label given an image-text pair. The MELINDA dataset was used for evaluation
- Medical Image-Caption Retrieval - two subtasks: image-to-text (I2T) retrieval requires retrieving the most relevant texts from a large pool of texts given an image, and vice versa for text-to-image (T2I) retrieval. The ROCO dataset was used for evaluation
- Accuracy is used as the metric for the Med-VQA and Medical Image-Text Classification tasks
- Recall@K (K = 1, 5, 10) is used for the Medical Image-Caption Retrieval task (a sketch of this metric is given below)
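A minimal sketch of Recall@K for the retrieval task, assuming a square query-by-candidate similarity matrix whose ground-truth pairs lie on the diagonal.

```python
import torch

def recall_at_k(similarity, k):
    """Fraction of queries whose ground-truth match appears among the top-K
    retrieved candidates (ground truth assumed on the diagonal)."""
    topk = similarity.topk(k, dim=1).indices                 # (Q, K)
    targets = torch.arange(similarity.size(0)).unsqueeze(1)  # (Q, 1)
    return (topk == targets).any(dim=1).float().mean().item()

sim = torch.randn(8, 8)   # dummy image-to-text similarity scores
print(recall_at_k(sim, k=1), recall_at_k(sim, k=5))
```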
Unfortunately, nothing is said about how the pre-trained model is fine-tuned for the different downstream tasks :(
Results for Med-VQA task
Results for Medical Image-Text Classification
Results for Medical Image-Caption Retrieval
(ZS) means zero-shot and (FT) means fine-tuning
Ablation study
(MIM) stands for Masked Image Modeling and (MLM) stands for Masked Language Modeling
Qualitative results
Conclusions
- Image/text coupling for medical data analysis looks like a promising way forward
- Pre-training in a self-supervised way using a masking strategy appears to be relevant
- Exploiting embeddings at different levels of abstraction for images and text seems to be a good approach