Explanation of the Paper "Processing Megapixel Images with Deep Attention-Sampling Models"

Welcome! Recently, the paper Processing Megapixel Images with Deep Attention-Sampling Models by Angelos Katharopoulos and Francois Fleuret, published at ICML (International Conference on Machine Learning) 2019, caught my eye. I find it very interesting: it can be of great help for deep learning tasks on large datasets of high-resolution images, especially for small groups or individual researchers. It is also well written and thoroughly proven mathematically. I will present this paper and try to make it easier and friendlier to read. (The mathematical formulations and proofs will not be repeated here; for those who are interested, I recommend reading the paper, since they are very well presented there.) Presenting is also a way for me to learn.

1. Background

The motivation of this paper is to help with computer vision tasks on megapixel images. Traditional CNN architectures usually cannot operate directly on such high-resolution images because of limited computational and memory resources.

1.1 Down-sampling

One common approach is to down-sample the original images directly, but this may lose important information. For example, in the following picture, the traffic sign becomes unrecognizable in (b) after down-sampling the original image (a), compared with the traffic sign at the original scale in (c).

Illustration of down-sampling

1.2 Deep MIL

Another approach, more relevant to this paper, is to first split the original image into patches of a fixed size and process them separately. This approach is called deep multiple instance learning, or Deep MIL (Ilse et al., 2018).

Illustration of DeepMIL

Each patch is a part of the input image at the original scale, so the details are retained. Another good aspect of this approach is that it does not need patch-level annotations or labels, which are usually unavailable or very expensive to obtain. However, it has to operate on every patch, which wastes computational resources on unimportant ones.

Thus, the aim of this paper is to propose a model that handles large images more efficiently, with far lower computational requirements. More specifically, for the patch-based approach above, it would be more efficient if we knew which patches are informative and processed only those few patches at the original scale instead of all of them. Intuitively, we need a way to put more attention, or weight, on the patches of interest in order to select them.

2. Aim

Propose an end-to-end trainable model able to handle multi-megapixel images using a single GPU/CPU.

3. Method - Outline

Now I will introduce the method proposed in this paper, called attention sampling, or ATS for short. The following shows the outline of the method.

The outline of the ATS method

Given a large image, it is, as mentioned before, computationally cheap to down-sample it first. But don't worry: the down-sampled image is not used to solve the task itself, so we are not throwing away the details yet. It is only used to compute an attention distribution that helps decide which locations are the most interesting. This is done by a rather lightweight convolutional neural network, which assigns a weight, or a probability (since it is a distribution), to every location of the image. The result may look like the following image. If we are to select two interesting patches out of all locations, we can simply choose the top two (in fact, as described below, the method samples from this distribution rather than taking a fixed top-k).

An example of the attention distribution
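To make this concrete, here is a minimal PyTorch sketch of what such an attention network could look like. This is my own illustrative sketch, not the paper's exact architecture; the layer sizes and channel counts are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionNetwork(nn.Module):
    """Maps a down-sampled image to a probability distribution over
    its spatial locations (illustrative sketch, not the paper's exact net)."""
    def __init__(self, in_channels=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 1, kernel_size=3, padding=1),  # one score per location
        )

    def forward(self, x_low):
        # x_low: (B, C, h, w) down-sampled image
        scores = self.features(x_low)             # (B, 1, h, w)
        b, _, h, w = scores.shape
        a = F.softmax(scores.view(b, -1), dim=1)  # distribution over h*w locations
        return a.view(b, h, w)
```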

The selected locations are then used to extract patches from the original image. We get two patches at the original scale, with all the details they can have. The two patches are then fed into another convolutional neural network, the feature network, which performs the downstream task, for example classification.
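Continuing the sketch above (same imports), extracting the full-resolution patches at the chosen grid locations might look like the following. The `scale` factor (the ratio between full and down-sampled resolution), the grid layout, and the boundary clamping are my own assumptions:

```python
def extract_patches(x_high, idx, grid_w, patch_size=50, scale=10):
    """Cut full-resolution patches at sampled grid locations.
    x_high: (C, H, W) full-resolution image
    idx:    (Q,) flat indices into the low-resolution attention grid
    grid_w: width of the attention grid
    """
    c, H, W = x_high.shape
    patches = []
    for i in idx.tolist():
        r, col = divmod(i, grid_w)             # grid coordinates
        top = min(r * scale, H - patch_size)   # clamp to image bounds
        left = min(col * scale, W - patch_size)
        patches.append(x_high[:, top:top + patch_size, left:left + patch_size])
    return torch.stack(patches)                # (Q, C, patch_size, patch_size)
```

The resulting batch of patches then goes straight into the feature network.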

Rather than using all locations, the authors perform a Monte Carlo estimate: they sample a small number of indices from the attention distribution, without replacement so that the same index is not selected twice. They show that feeding only the sampled patches to the feature CNN yields an unbiased estimator of the prediction computed from all the patches. More specifically, the sampling is proved to have minimum variance when the features are normalized to equal L2 norm, so that the prediction is not distorted by the sampling. For this reason, in the implementation an L2-normalization layer is added as the final layer of the feature network.
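In simplified notation, with attention weights $a_i$ and per-patch features $f(x_i)$, the quantity being approximated and its Monte Carlo estimate from $Q$ sampled indices are:

$$\sum_i a_i\, f(x_i) \;=\; \mathbb{E}_{i \sim a}\big[f(x_i)\big] \;\approx\; \frac{1}{Q} \sum_{q=1}^{Q} f\big(x_{i_q}\big), \qquad i_q \sim a.$$

This plain average is the unbiased estimator for sampling with replacement; the paper derives the corresponding estimator for sampling without replacement, and shows the variance is minimized when all the feature vectors have equal L2 norm, which is exactly what the final L2-normalization layer enforces.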

The two CNNs are trained together, and the model is proved to be differentiable and trainable. The detailed formulas and proofs can be found in the paper if you are interested.

One thing worth mentioning is that the proposed model scales up easily, because its computational and memory requirements depend only on the predefined patch size and the number of selected patches, not on the input resolution.
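As an illustration (my numbers, not the paper's): with a 50x50 patch size and 10 sampled patches, the feature network always processes 10 · 50 · 50 = 25,000 pixels per image, whether the input is one megapixel or one hundred; only the lightweight attention network ever sees the (down-sampled) full view.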

4. Evaluation

To evaluate the proposed ATS model, it was compared against a few baseline methods:

  • Traditional CNNs with different scales of down-sampling
  • Deep multiple instance learning (Deep MIL), which processes every patch, as introduced previously
  • Uniform sampling (U-XX), where the attention distribution is replaced by a uniform one and XX is the number of selected patches

These were compared with the proposed attention-sampling model, ATS-XX, where XX is again the number of sampled patches.

The metrics used to compare performance are test error and computational and memory cost.

As mentioned before, the computational and memory cost of the proposed model depends only on the predefined patch size and the number of selected patches, while the baseline models (the traditional CNNs and Deep MIL) scale linearly with the size of the input images. This difference is verified in the results below.

5. Results

The attention-sampling method was evaluated on three tasks.

5.1 Megapixel MNIST

The first task is MNIST digit classification on an automatically generated dataset. Each image contains 50 patches, each the size of an MNIST digit, placed at random positions. 5 of the patches are digits: 3 belong to the same class and the other 2 to random classes. The remaining patches are random noise. The following picture shows an example of the data samples:

Data example - Megapixel MNIST

The goal is to identify the digit with the most occurrences. In the case above, it should be 7.
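For concreteness, here is a rough sketch of this dataset construction; the image size, cell size, and sampling details are my guesses rather than the paper's generation script:

```python
import numpy as np

def make_megapixel_mnist(digits_by_class, H=1500, W=1500, cell=28,
                         n_patches=50, n_digits=5, rng=np.random):
    """Place 5 MNIST digits (3 of a majority class, 2 from other classes)
    and 45 noise patches at random positions on an empty canvas."""
    img = np.zeros((H, W), dtype=np.float32)
    target = rng.randint(10)              # the majority class
    others = []
    while len(others) < n_digits - 3:     # 2 digits from other classes
        c = rng.randint(10)
        if c != target:
            others.append(c)
    classes = [target] * 3 + others
    for k in range(n_patches):
        top, left = rng.randint(H - cell), rng.randint(W - cell)
        if k < n_digits:                  # paste a digit
            pool = digits_by_class[classes[k]]
            patch = pool[rng.randint(len(pool))]
        else:                             # paste random noise
            patch = rng.rand(cell, cell)
        img[top:top + cell, left:left + cell] = patch
    return img, target                    # label = majority digit
```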

The methods used here are:

  • A traditional CNN on the full-size images
  • The attention-sampling method (ATS-XX) with 180x180 as the down-sampled size (which gives 180 · 180 = 32,400 candidate locations) and a patch size of 50x50

Uniform sampling was omitted because it cannot be expected to perform well given the large sampling space (32,400 locations). Deep MIL was also omitted because it exceeds GPU memory.

The results are shown below. The ATS method was tested with 5, 10, 25, 50, and 100 selected patches processed at full scale.

Result - Megapixel MNIST

ATS achieves lower error and requires less time than the CNN model in every patch-count setting. Within ATS itself, the more patches it selects, the lower the error and the more time it needs. Overall, attention sampling processes high-resolution images both faster and more accurately than the CNN baseline.

Another important benefit is the interpretability that ATS provides. The following shows the evolution of the attention distribution during training.

The evolution of the attention distribution - Megapixel MNIST

The example image contains 3 digits and 3 noise patches; the three digits are marked in red in the first frame, and yellow means higher attention. At first, the attention is roughly uniform, as every position is equally colored. The network then learns to distinguish the digits and the noise from empty space, and from epoch 20 onwards the attention finds the three digits and keeps following them.

5.2 Histopathology images

The second experiment used a colon cancer dataset, with the task of detecting whether epithelial cells exist in a hematoxylin-and-eosin-stained image. It is a binary classification problem: an image is classified as positive if at least one cell belongs to the epithelial class.

Data example - Histopathology image

Uniform sampling, CNN, Deep MIL, and attention sampling were tested and compared. The results are shown in the following table:

Result - Histopathology image

As the table shows, uniform sampling with 10 and 50 patches is clearly better than random guessing, with 15.6% and 12.4% error rates on the test dataset. This is because a single epithelial cell is enough for an image to be classified as positive, and a positive image probably contains many more than one. With the same number of selected patches, the proposed ATS achieves about 35% lower test error, as expected, since it focuses only on the informative parts. Compared with the CNN and Deep MIL models, the ATS models perform equally well but are much faster and require much less memory.

The following shows a visualization of the learned attention distribution on a test image. Image (a) is the raw image, (b) shows the ground-truth positions of the epithelial cells, and (c) and (d) are the learned attention distributions of the Deep MIL and ATS models, computed by multiplying each patch with a normalized attention value w_i.

Learned attention distribution - Histopathology image

Both methods identify epithelial cells without having access to per-patch annotations. We can also see that the attention of the ATS method matches the ground truth less closely than that of the Deep MIL model. Since finding at least one epithelial cell is enough to classify correctly, this may not affect classification performance; however, it suggests that ATS may be less helpful for detecting regions of interest.
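As a side note, the patch-weighting construction used for these visualizations (each patch scaled by its normalized attention value w_i) can be sketched as follows. This is a rough reconstruction of what the figure description says, not the authors' code; the nearest-neighbour upsampling is my own choice:

```python
import numpy as np

def attention_overlay(image, attention, patch_size):
    """Dim each patch of a grayscale image by its normalized attention.
    image:     (H, W) array
    attention: (H // patch_size, W // patch_size) attention values
    """
    w = attention / attention.max()                          # normalize to [0, 1]
    w_full = np.kron(w, np.ones((patch_size, patch_size)))   # upsample to image grid
    return image * w_full                                    # low attention -> dark
```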

5.3 Speed limit sign detection

The last task is to classify whether an image contains no speed limit sign or a sign limiting the speed to 50, 70, or 80 km/h.

Data example - Speed limit sign detection

Similarly, uniform sampling, CNN, Deep MIL, and attention sampling were tested and compared. The results are shown below:

Result - Speed limit sign detection

Because only one or very few patches are informative, the uniform models are, as expected, unable to perform well. The CNN models on down-sampled images fit the training dataset very well but fail to generalize to the test dataset; this is also expected, since the traffic signs are hard to see in the down-sampled images. The full-scale CNN also fails to generalize. The paper does not explain why, but it could be that the network overfits. The Deep MIL and ATS models achieve comparable test results, yet ATS uses only 5 or 10 patches instead of the 192 used by Deep MIL. Regarding memory and time cost, ATS is much more efficient than Deep MIL.

The following picture shows the learned attention distribution of Deep MIL and ATS on a test image.

Learned attention distribution - Speed limit sign detection

The proposed ATS method locates both signs with high probability, whereas the Deep MIL model locates both but assigns high attention to only one of them.

6. Conclusion

After going through what the paper explains and achieves, its main conclusions can be drawn as follows:

  • The proposed attention-sampling method can efficiently process megapixel images on a single CPU/GPU by processing only a few parts of the input image, chosen according to an attention distribution.
  • Sampling from the attention distribution is proved to be the optimal approximation, with minimum variance.
  • The attention-sampling model can also effectively identify the informative regions without patch-level annotations.

7. Personal Thoughts

Dealing with the computational problems arising in deep learning tasks is a very popular topic, and this paper provides an interesting and helpful perspective. For those who want to implement or reproduce the work, the authors also provide a git repository.

There are also things that need more consideration when solving more complicated real-world problems. For example, in object detection tasks with targets of multiple sizes (traffic lights, pedestrians, nearby cars, faraway cars, etc.), more thought is needed on how to choose the down-sampling scale for learning the attention distribution and the patch size, whether the results can be obtained in a single pass, or whether a separate attention network could be designed to better suit targets of different sizes. If we down-sample too much, the smaller targets may disappear from the down-sampled image; if we down-sample too little, the attention network may select many locations belonging to the same region of a larger target, which would defeat the purpose of reducing computational cost.