Application of Vision Transformer for Medical Image Classification

2024-01-10

The focus of this article is the Vision Transformer (ViT) and its practical application to a real-life problem. It revisits a medical image classification task that I have already solved with Convolutional Neural Networks (CNNs) and introduces a solution based on ViT.

The Transformer architecture has become the de facto standard for natural language processing tasks. So, what is the Vision Transformer (ViT)? The ViT architecture represents an image as a set of patches: non-overlapping image blocks of 16x16 pixels. For example, an image with a resolution of 224x224 yields (224/16) * (224/16) = 14 * 14 = 196 patches. The patches are treated like tokens (words) in natural language processing (NLP) applications. ViT represents each patch as a flattened linear projection of its pixels and operates on patch embedding vectors of length 768 (16x16x3 = 768).

The following image shows the complete architecture of ViT. Its main parts are the patch + position embedding preparation, the encoder, and the pooling (MLP) head.

1. The patch + position embedding is formed from the input image pixels as a matrix of size 196 x 768 (a vector of 768 values for each patch position; 196 patches for a 224 x 224 image). A learnable, randomly initialized class-token vector of 768 values is prepended at position zero, so the patch + position embedding becomes a matrix of size 197 x 768.

2. The encoder consists of a series of multi-head attention modules, followed by normalization layers and multi-layer perceptron (MLP) modules. The Transformer encoder is the main part of ViT: it learns the similarity between patches based on their class membership. It includes a series of linear, normalization, and activation layers. The 197 x 768 embedding matrix is transformed so that it expresses the interactions between patches and represents their class information. The zero-position row of this matrix is the class token (a vector of 768 values), which is used as input to the pooling block.

3. The pooling block finally transforms the class token (a vector of 768 values) into an output embedding that represents the predicted class. This module also consists of linear and activation layers.

Hugging Face's ViT: Understanding Implementation in Practice

Let's use the following code to examine the base ViT model from Hugging Face.

Installation:

    !pip install torchvision
    !pip install torchinfo
    !pip install -q git+https://github.com/huggingface/transformers.git

Imports:

    from PIL import Image
    from torchinfo import summary
    import torch
    import torch.nn as nn                  # used by the classifier models below
    from torchvision import transforms     # used to resize the ViT output in Model 3

Google Drive mounting (for Google Colab):

    from google.colab import drive
    drive.mount('/content/gdrive')

CUDA device setup:

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In the code below, I examine the base ViT model:

    from transformers import ViTConfig, ViTModel

    configuration = ViTConfig()
    print(configuration)

The default configuration of the base model is as follows:

    ViTConfig {
      "attention_probs_dropout_prob": 0.0,
      "encoder_stride": 16,
      "hidden_act": "gelu",
      "hidden_dropout_prob": 0.0,
      "hidden_size": 768,
      "image_size": 224,
      "initializer_range": 0.02,
      "intermediate_size": 3072,
      "layer_norm_eps": 1e-12,
      "model_type": "vit",
      "num_attention_heads": 12,
      "num_channels": 3,
      "num_hidden_layers": 12,
      "patch_size": 16,
      "qkv_bias": true,
      "transformers_version": "4.37.0.dev0"
    }

By changing the fields of the configuration, we can create a custom ViT model.
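For example, a smaller custom ViT could be defined by overriding a few of these fields. The values below are illustrative choices of mine and do not correspond to any pre-trained checkpoint:

    from transformers import ViTConfig, ViTModel

    # Hypothetical smaller configuration, for illustration only
    custom_config = ViTConfig(
        hidden_size=384,          # patch embedding length (default 768)
        num_hidden_layers=6,      # encoder depth (default 12)
        num_attention_heads=6,    # attention heads per layer (default 12)
        intermediate_size=1536,   # MLP width in each encoder block (default 3072)
        image_size=224,
        patch_size=16,
    )
    custom_model = ViTModel(custom_config)
    print(sum(p.numel() for p in custom_model.parameters()))  # far fewer parameters than the base model

Such a model has no pre-trained weights and would have to be trained from scratch, so the rest of this article stays with the default base configuration and the pre-trained checkpoints.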
Let's try the default base ViT model:

    model = ViTModel(configuration).to(device)
    model.eval()

In the output, we see all the layers of the base ViT model.

Model overview:

    summary(model=model, input_size=(1, 3, 224, 224), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

The base ViT model has a large number of parameters - over 86 million. Let's look at the structure of the model's output. I sent a randomly generated fake image to the model:

    x = torch.randn((3, 224, 224))
    x = torch.unsqueeze(x, 0)
    y = model(x.to(device))
    print(y.pooler_output.shape)
    print(y.last_hidden_state.shape)

We can see in the output:

    torch.Size([1, 768])
    torch.Size([1, 197, 768])

The final output of the base ViT model consists of two parts: last_hidden_state, with shape (batch_size, 197, 768), is the output of model.embeddings + model.encoder + model.layernorm, i.e., everything before the pooling block; pooler_output, with shape (batch_size, 768), is the output of model.pooler. The input to model.pooler is the zero-position row (the class token) of the normalized last_hidden_state matrix obtained in the previous step.

The following image illustrates that calling these blocks step by step (as described above) is equivalent to calling the entire model at once: if we run the code on the left and right sides of the image with the same input tensor x, we see the same output tensor when printing (a code sketch reproducing this check is given at the end of this section).

Understanding the ViT blocks and their output structure is important for building transfer learning solutions based on ViT: the model.pooler block is replaced with a custom block that takes the inference of the preceding ViT blocks as its training input.

Hugging Face provides two pre-trained ViT image classification models:

1. Pre-trained on ImageNet-21k (a collection of 14 million images and 21k categories);
2. Fine-tuned on ImageNet (also known as ILSVRC 2012, a collection of 1.3 million images and 1000 categories).

The fine-tuning on ImageNet uses the architecture for 1000-category classification (ViTForImageClassification), which contains a model.classifier block instead of the model.pooler block, consisting of a single linear layer:

    (classifier): Linear(in_features=768, out_features=1000, bias=True)

The input to this layer is the zero-position row of the normalized last_hidden_state matrix.

Comparison of ViT and CNN:

- A CNN model extracts local features across the image and uses the collection of these features over the whole image as the basis for classification; it is trained to predict the class label of the image from all the features. ViT treats the image as a set of image blocks and takes the positions of these blocks into account; it is trained to predict the class label from "similar" blocks, i.e., the ViT architecture incorporates a notion of segmentation.
- ViT models have a large number of parameters (over 86 million, as summarized above) and require large datasets to perform well. CNN models can adapt to datasets of different sizes and may need relatively few parameters to achieve good performance. Trained from scratch, ViT does not perform well on small custom datasets. For small custom datasets, one practical option is transfer learning: run inference with a ViT pre-trained on a large dataset and train only a small model on top of its outputs.

ViT for X-ray Chest Image Classification:

I am using the same dataset of chest X-ray images as in my CNN experiments. The dataset consists of three classes of images, and I am using uniformly cropped images that include the chest region.
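As promised above, here is a minimal sketch that reproduces the block-by-block equivalence. The submodule names (model.embeddings, model.encoder, model.layernorm, model.pooler) are the ones listed in the model summary; the check simply confirms that chaining them matches a single forward pass:

    # Check that calling the ViT blocks step by step matches a single forward pass
    model.eval()
    x = torch.randn((1, 3, 224, 224)).to(device)

    with torch.no_grad():
        # step-by-step
        emb = model.embeddings(x)          # (1, 197, 768) patch + position embeddings
        enc = model.encoder(emb)[0]        # (1, 197, 768) encoder output
        hidden = model.layernorm(enc)      # normalized last_hidden_state
        pooled = model.pooler(hidden)      # (1, 768) output of the pooling block

        # single call of the whole model
        out = model(x)

    print(torch.allclose(hidden, out.last_hidden_state))  # expected: True
    print(torch.allclose(pooled, out.pooler_output))       # expected: True

Both comparisons print True, which is the equivalence relied on in the experiments below.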
Examples of cropped images (from left to right: "Normal (No Pneumonia)", "Bacterial Pneumonia", "Viral Pneumonia"):

The dataset is divided into a training set and a test set. The training set contains 3000 images: 1000 "Normal (No Pneumonia)", 1000 "Pneumonia-Bacterial", and 1000 "Pneumonia-Viral" images randomly selected from their respective categories. The remaining images make up the test set, which therefore contains 2908 images: 576 "Normal (No Pneumonia)", 1777 "Pneumonia-Bacterial", and 555 "Pneumonia-Viral" images.

Comparison of CNN and ViT for a 2-class classifier, "Normal (No Pneumonia)" / "Pneumonia (Bacterial or Viral)":

I am solving the task of building a system that determines whether an input chest X-ray image belongs to the category "Normal (No Pneumonia)" or "Pneumonia (Bacterial or Viral)", i.e., a 2-class classifier, using ViT. For comparison I use the CNN model with three convolutional blocks, which demonstrates the best results among the CNN models on this dataset. Here is its summary: the model contains 348,050 parameters, far fewer than the ViT model. Note that for the CNN models I used images with a resolution of 256x256.

Here, I take the base ViT model pre-trained on the ImageNet-21k dataset and fine-tune it for the X-ray images.

Model 1: ViT with a "Small" Linear Classifier after processing the input image:

First, I try the simplest solution: a linear layer whose input is the zero-position row of the last_hidden_state matrix, a vector of 768 values. This is the same kind of head that is used for 1000-category classification on the ImageNet dataset.

Load the pre-trained ViT model + image processor:

    from transformers import ViTConfig, ViTModel
    from transformers import AutoImageProcessor

    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

The following code transforms a single input image (initially a PIL Image) into the class-token vector of 768 values that is the input to my linear classifier:

    # the input PIL image is loaded into img here (the loading code is omitted in the original)
    img =
    inputs = image_processor(img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    img = outputs.last_hidden_state
    img = img[:, 0, :]

A batch of processed input images with a shape of (batch_size, 1, 768) is sent to the following model:

    class ChestClassifier(nn.Module):
        def __init__(self, num_classes):
            super(ChestClassifier, self).__init__()
            self.num_classes = num_classes
            self.ln = nn.Linear(768, self.num_classes)

        def forward(self, x):
            x = nn.Flatten()(x)   # (batch, 1, 768) -> (batch, 768)
            x = self.ln(x)
            return x

    model1 = ChestClassifier(2).to(device)

The summary of this model is:

    summary(model=model1, input_size=(1, 1, 768), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

The small classifier model contains only 1,538 parameters. I use the Adam optimizer with a learning rate of 0.001. There are 3000 images for training and 2908 images for testing, and I keep the training batches balanced (approximately 50% of images per category); a sketch of this training setup is given at the end of this section. The image below compares the CNN architecture results (which have been obtained and presented here) and model1 described above. For both models, I selected the best checkpoint:

The ViT fine-tuned with the "Small" linear classifier performs worse than the CNN architecture.
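For reference, here is a minimal sketch of the training setup described above (Adam with a learning rate of 0.001, cross-entropy loss, best-checkpoint selection). The DataLoader objects train_loader and test_loader, which are assumed to yield precomputed ViT features of shape (batch_size, 1, 768) with integer labels, and the checkpoint path are my assumptions; the article does not show its data pipeline:

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model1.parameters(), lr=0.001)

    best_acc = 0.0
    for epoch in range(20):                          # number of epochs is an arbitrary choice here
        model1.train()
        for features, labels in train_loader:
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model1(features), labels)
            loss.backward()
            optimizer.step()

        # evaluate on the test set and keep the best checkpoint
        model1.eval()
        correct = total = 0
        with torch.no_grad():
            for features, labels in test_loader:
                preds = model1(features.to(device)).argmax(dim=1).cpu()
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        acc = correct / total
        if acc > best_acc:
            best_acc = acc
            torch.save(model1.state_dict(), 'best_model1.pt')   # hypothetical checkpoint path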
I believe the reasons for these results are as follows: medical images are fundamentally different from the ImageNet data used to train the ViT model, and the number of trainable parameters in my "Small" linear classifier is not sufficient to outperform the CNN models. How can we improve the model? First, nothing stops me from using the full pre-trained patch-position state when fine-tuning the classifier - i.e., the complete last_hidden_state output by ViT. Second, I can try a more complex classifier model with more trainable parameters.

Model 2: ViT with a "Large" Linear Classifier after processing the input image:

Compared to model 1, I changed the preprocessing of the input PIL image to obtain the complete transposed last_hidden_state matrix of the pre-trained ViT model. This matrix forms the input to my classifier model:

    # the input PIL image is loaded into img here (the loading code is omitted in the original)
    img =
    inputs = image_processor(img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    img = outputs.last_hidden_state.permute(0, 2, 1)
    img = img.squeeze()

A batch of processed input images with a shape of (batch_size, 768, 197) is sent to the following model:

    class ChestClassifierL(nn.Module):
        def __init__(self, num_classes):
            super(ChestClassifierL, self).__init__()
            self.num_classes = num_classes
            self.ln1 = nn.Linear(197, 256)
            self.relu = nn.ReLU(inplace=True)
            self.ln2 = nn.Linear(768*256, self.num_classes)

        def forward(self, x):
            x = self.ln1(x)        # (batch, 768, 197) -> (batch, 768, 256)
            x = self.relu(x)
            x = nn.Flatten()(x)    # (batch, 768*256)
            x = self.ln2(x)
            return x

    model2 = ChestClassifierL(2).to(device)

The summary of this model is:

    summary(model=model2, input_size=(1, 768, 197), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

The "Large" classifier model contains 443,906 parameters. I use the Adam optimizer with a learning rate of 0.001. The image below compares the CNN architecture results (which have been obtained and presented here) and model2 described above. For both models, I selected the best checkpoint:

ViT fine-tuning with the "Large" classifier performs better than the CNN architecture! The reason is not only the increased number of trainable parameters but also the use of the entire patch-position information. The notion of segmentation matters for medical images, since they may contain regions that are specifically abnormal for a given pathology.

Below, I demonstrate the same positive trend for another classifier - one that distinguishes the two types of pneumonia, "Pneumonia-Bacterial" and "Pneumonia-Viral". I have 1000 "Pneumonia-Bacterial" and 1000 "Pneumonia-Viral" images in the training set, and 1777 "Pneumonia-Bacterial" and 555 "Pneumonia-Viral" images in the test set; the training set therefore contains 2000 images and the test set 2332 images. I compare the same CNN architecture with three convolutional blocks against the combination of ViT with model2 described above. In the results below, "Class 0" represents "Pneumonia-Bacterial" and "Class 1" represents "Pneumonia-Viral". For both models, I selected the best checkpoint:

Model 3: Fine-tuning ViT for Custom Input Resolution:

In all the examples discussed above, I compared CNN models trained on 256x256 input images with fine-tuned ViT models, which require input images with a resolution of 224x224.
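Before looking at the solution, it helps to spell out the patch arithmetic behind the shapes used in the next section: at 256x256 with 16x16 patches there are (256/16)^2 = 256 patches, plus the class token, i.e. 257 positions instead of the 197 produced at 224x224. A quick sanity check (the helper below is just for illustration):

    # number of patch-position rows for a given input resolution (16x16 patches + 1 class token)
    def num_positions(resolution: int, patch_size: int = 16) -> int:
        return (resolution // patch_size) ** 2 + 1

    print(num_positions(224))  # 197 rows in last_hidden_state at 224x224
    print(num_positions(256))  # 257 rows when adapting to 256x256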
For this article, I found a solution for transfer learning at higher resolutions: resize the output of the pre-trained model according to the number of embedding positions at the higher resolution, and continue training on this resized matrix. For fine-tuning ViT at an input resolution of 256x256, I need to resize the last_hidden_state matrix to 257x768 and continue training with this matrix. So, let's try it in practice. The preprocessing of the input PIL image now consists of the following steps:

    # the input PIL image is loaded into img here (the loading code is omitted in the original)
    img =
    inputs = image_processor(img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    img = outputs.last_hidden_state.permute(0, 2, 1)
    # new patch-position embeddings resolution
    img = transforms.Resize((768, 257))(img)
    img = img.squeeze()

A batch of processed input images with a shape of (batch_size, 768, 257) is sent to the following model:

    class ChestClassifierL256(nn.Module):
        def __init__(self, num_classes):
            super(ChestClassifierL256, self).__init__()
            self.num_classes = num_classes
            self.ln1 = nn.Linear(257, 256)
            self.relu = nn.ReLU(inplace=True)
            self.ln2 = nn.Linear(768*256, self.num_classes)

        def forward(self, x):
            x = self.ln1(x)        # (batch, 768, 257) -> (batch, 768, 256)
            x = self.relu(x)
            x = nn.Flatten()(x)    # (batch, 768*256)
            x = self.ln2(x)
            return x

    model3 = ChestClassifierL256(2).to(device)

The summary of this model is:

    summary(model=model3, input_size=(1, 768, 257), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

I tried model3 for the binary classification "Normal (No Pneumonia)" / "Pneumonia (Bacterial or Viral)". The image below compares the results of model2 with an input resolution of 224x224 and the results of model3 with an input resolution of 256x256. In the results, "Class 0" represents "Normal (No Pneumonia)" and "Class 1" represents "Pneumonia (Bacterial or Viral)". For both models, I selected the best checkpoint:

The effect of changing the resolution should be more significant when the target resolution is farther from 224x224.

Conclusion:

A proper combination of ViT (Vision Transformer) inference and fine-tuned models can improve classifier performance, even on specific datasets such as medical images.