CNNs with Multi-Level Attention for Domain Generalization

In the past decade, deep convolutional neural networks have achieved significant success in image classification and ranking, and have therefore found numerous applications in multimedia content retrieval. Still, these models suffer from performance degradation when tested on out-of-distribution scenarios or on data originating from previously unseen domains. In the present work, we focus on this problem of Domain Generalization and propose an alternative neural network architecture for robust, out-of-distribution image classification. We aim to produce a model that focuses on the causal features of the depicted class for robust image classification in the Domain Generalization setting. To achieve this, we propose attending to multiple levels of information throughout a Convolutional Neural Network and leveraging the most important attributes of an image through trainable attention mechanisms. To validate our method, we evaluate our model on four widely accepted Domain Generalization benchmarks; our model surpasses previously reported baselines on three of the four datasets and achieves the second-best score on the fourth.


INTRODUCTION
One of the most fundamental prerequisites for training robust and generalizable machine learning (ML) models is the ability to learn representations which adequately encapsulate the underlying generating processes of a data distribution [3, 30, 31]. One way of approaching the above problem is to guide a model to learn disentangled representations from the training data and uncover the ones which remain invariant [1] under distribution shift. For example, a photograph of a dog shares similar traits with an image of a cartoon dog, or even a sketch of a dog. A generalizable model should be able to recognize the same class despite it appearing in separate domains. ML models are trained under the assumption that the training and test data distributions are independent and identically distributed. In practice however, Deep Learning (DL) models are expected to mitigate, or not be affected by, the distribution shift between their training data and data they have not been presented with before. This is often not the case, as DL models often learn representations which entangle class-discriminative attributes with correlated, though irrelevant, features of images. They therefore fail to produce informative features and to generalize to unseen data domains [28]. To this end, Domain Generalization (DG) [39, 49] methods aim at developing robust models which can generalize to unseen test data domains. Such methods attempt to address this problem by leveraging multiple domains in the training set, simulating biases found in real-world settings, synthesizing samples through augmentation, and learning invariant representations through self-supervision (see Section 2).

ICMR '23, June 12-15, 2023, Thessaloniki, Greece. © 2023 Association for Computing Machinery.
In 2017, Transformer networks [36] were proposed as a model for Natural Language Processing (NLP). Transformers introduced a self-attention mechanism for providing additional contextual information to each word in an embedded sentence. Since their outstanding success in several NLP tasks, Transformers and self-attention mechanisms have slowly but steadily gained ground in the Computer Vision community [13], achieving significant advances in the field [17, 42, 44]. In this work, we argue that by attending to features extracted from multiple layers of a convolutional neural network via multi-head self-attention mechanisms, a model can be trained to learn representations which reflect class-specific, domain-invariant attributes of an image. As a result, the trained model will be less affected by out-of-distribution data samples, as it will base its predictions on the causal characteristics of the depicted class. Our contributions can be summarized in the following points:
• We introduce a novel neural network architecture that utilizes self-attention [36] to attend to representations extracted throughout a Convolutional Neural Network (CNN), for robust out-of-distribution image classification in the DG setting.
• We evaluate our proposed method on the widely adopted DG benchmarks of PACS [22], VLCS [34], Terra Incognita [2] and Office-Home [37] and demonstrate its effectiveness.
• We provide qualitative visual results of our model's inference process and its ability to focus on the invariant and causal features of a class via saliency maps.
In the next section we briefly present the most important contributions in DG, along with related previous work in visual attention, from which we drew inspiration for our proposed algorithm.

RELATED WORK
Domain Generalization
There have been numerous efforts to address challenges related to domain shift in the past [7, 40, 41]; however, DG methods differ in that the model does not see any samples from the target domain(s) during training.
DG problems can be broadly categorized into two main settings, namely multi-source and single-source DG [49]. In multi-source DG, all algorithms assume that their training data originate from S (where S > 1) distinct but known data domains. These algorithms take advantage of domain labels in order to discover invariant representations among the separate marginal distributions. Most previously proposed methods fall under this category. The authors of [33] propose deep CORAL, a method which aligns the second-order statistics between source and target domains in order to minimize the domain shift among their distributions. In [26], Style-Agnostic networks, or SagNets, use an adversarial learning paradigm to disentangle the style encodings of each domain and reduce style-biased predictions. With a different approach, the authors of [50] investigate the usage of data augmentation and style-mixing techniques for producing robust models. Another popular approach in multi-source DG is meta-learning, which focuses on learning the optimal parameters for a source model from previous experiment metadata. [10, 23] and Adaptive Risk Minimization (ARM) [46] all propose meta-learning algorithms for adapting to unseen domains. Finally, [9] uses episodic training in the meta-learning setting to extract invariant representations across source domains. On the other hand, single-source DG methods hold no information about the presence of separate domains in their training data, but assume that it originates from a single distribution. Therefore, all single-source DG algorithms, such as our own, operate in a domain-agnostic manner and do not take advantage of domain labels. In [5], the authors combine self-supervised learning with a jigsaw-solving objective in order to encourage the model to learn global, semantic features instead of superficial statistical regularities. Additionally, in [47] the authors attempt to remove feature dependencies in their model via sample weighting. Finally, RSC [18] is a self-challenging training heuristic that discards representations associated with very high gradients, which forces the network to activate features correlated with the class and not the domain.

Visual Attention
Attention mechanisms have long been studied in Computer Vision (CV) [19], inspired by the human visual system's ability to efficiently analyze complex scenes. More recently, attention mechanisms have been proposed for the interpretation of the output of Convolutional Neural Networks (CNNs), where they act as dynamic re-weighting processes which attend to the most important features of the input image. In [48], the authors propose CAM, a post-hoc model interpretation algorithm for estimating attention maps in classification CNNs. Methods incorporating attention mechanisms into CNNs for image classification have also been proposed in the past [6, 29, 38, 45]. In [20], the authors introduce an end-to-end trainable attention mechanism for CNNs, computing compatibility scores between intermediate features of the network and a global feature map. In [42], the Convolutional Block Attention Module, or CBAM, leverages both spatial and channel attention modules for adaptive feature refinement. Recently, several methods have been proposed which replace convolutions with self-attention and multi-head attention mechanisms [36] applied directly on the image pixels [4, 8, 25], leading to transformer-based methods for CV [14].

METHODOLOGY
Information passed through popular Convolutional Neural Network architectures, such as ResNets [15], tends to get entangled with non-causal attributes of an image due to correlations in the data distribution [28]. Our method is built around the hypothesis that this problem can be mitigated if we allow the network to select intermediate feature maps throughout a CNN for representation learning. We therefore extract feature maps at multiple network layers and pass them through a multi-head attention mechanism (Figure 2). In our implementation we consider scaled dot-product self-attention with 3 heads. Given an intermediate feature map M ∈ R^(b×c×h×w), where b is the batch size, c is the number of channels and h and w are the height and width of the feature map, we aim to attend to each of the channels. As a first step, we flatten the feature maps M into a tensor of dimension (b, c, h×w). We follow by linearly projecting the flattened feature maps into a (b, c, d_e)-dimensional tensor, where d_e is the size of each channel's embedded feature map. Each channel can be thought of as a token in the classic Transformer architecture. Given the embedded feature maps X ∈ R^(b×c×d_e) and trainable weight matrices W_Q, W_K, W_V ∈ R^(d_e×d_i) (with d_i the inner self-attention layer dimension), we create the query, key and value matrices Q = XW_Q, K = XW_K, V = XW_V ∈ R^(b×c×d_i), which are fed to the multi-head attention block. The self-attention layer is defined as:

Attention(Q, K, V) = softmax(QKᵀ / √d_i) V,

while the multi-head attention is:

MultiHead(X) = Concat(head_1, ..., head_h) W_O,

where:

head_j = Attention(XW_Q^j, XW_K^j, XW_V^j).

After the extracted feature maps have been attended to and re-weighted, we pass them through a Multi-Layer Perceptron (MLP) in order to allow our model to learn a mapping between the processed features. The MLP consists of two Linear layers, activated by the GELU function [16]. Finally, the projected features are flattened, concatenated and passed through a fully connected classification layer for the final decision. Our proposed framework is visualized in Figure 1. We propose attending to each channel of the extracted feature maps. For the compatibility metric in the self-attention module, we select to use the scaled dot-product.
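As a concrete sketch, one branch of the channel-wise embedding, multi-head scaled dot-product attention and GELU MLP described above could be implemented in PyTorch roughly as follows. The embedding size d_e = 32 and h = 3 heads follow the text; the inner dimension d_i and the MLP hidden width are our assumptions:

```python
import torch
import torch.nn as nn


class ChannelAttentionBlock(nn.Module):
    """Sketch of one multi-level attention branch: each channel of an
    intermediate feature map is treated as a token, embedded to d_e
    dimensions, attended over with per-head scaled dot-product
    self-attention, and refined by a two-layer GELU MLP."""

    def __init__(self, spatial_size, d_e=32, d_i=32, heads=3):
        super().__init__()
        self.d_i, self.heads = d_i, heads
        self.embed = nn.Linear(spatial_size, d_e)    # (b, c, h*w) -> (b, c, d_e)
        self.qkv = nn.Linear(d_e, 3 * d_i * heads)   # per-head W_Q, W_K, W_V
        self.proj = nn.Linear(d_i * heads, d_e)      # output projection W_O
        self.mlp = nn.Sequential(nn.Linear(d_e, 2 * d_e), nn.GELU(),
                                 nn.Linear(2 * d_e, d_e))

    def forward(self, fmap):                         # fmap: (b, c, h, w)
        b, c = fmap.shape[:2]
        x = self.embed(fmap.flatten(2))              # channels as tokens
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape to (b, heads, c, d_i) for per-head attention
        split = lambda t: t.view(b, c, self.heads, self.d_i).transpose(1, 2)
        q, k, v = split(q), split(k), split(v)
        att = torch.softmax(q @ k.transpose(-2, -1) / self.d_i ** 0.5, dim=-1)
        x = self.proj((att @ v).transpose(1, 2).reshape(b, c, -1))
        x = self.mlp(x)
        return x.flatten(1)                          # (b, c * d_e), ready to concat
```

The flattened outputs of one such block per tapped layer are then concatenated (together with the backbone's final features) and passed to the classification layer.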

EXPERIMENTAL SETUP
In our experiments, we build our method on a vanilla ResNet-50 [15] model, pre-trained on ImageNet. We choose to extract intermediate feature maps from the 3rd, 7th and 13th bottleneck blocks of the backbone ResNet-50 model, as shown in Figure 1. We train our model with the SGD optimizer for 30 epochs and a batch size of 32 images. The learning rate is set at 0.001 and decays by a factor of 0.1 at epoch 24. The proposed framework was implemented with the PyTorch library [27] and trained on an NVIDIA RTX A5000 GPU.
We evaluate our method against 8 previous state-of-the-art algorithms, all of which use a ResNet-50 as their base model. Specifically, the baseline models we select are: ERM [35], RSC [18], MIXUP [43], CORAL [33], MMD [24], SagNet [26], SelfReg [21] and ARM [46]. The above algorithms are a mix of both multi-source and single-source methods, allowing us to demonstrate the effectiveness of our proposed method. The hyperparameters of each algorithm are set to reflect the ones in the original papers. All baselines are implemented and executed using the DomainBed [12] codebase for a fair comparison. The presented experimental results are averaged over 3 runs.

Datasets
To evaluate the robustness of our method, we experiment on four well-known and publicly available DG benchmark datasets, namely PACS [22], VLCS [34], Terra Incognita [2] and Office-Home [37]. For each dataset we follow the standard leave-one-domain-out cross-validation DG protocol, as described in [11, 22]. In this setting, a target domain is selected and held out from the model's training data split. The generalizability of the trained model is then measured by its accuracy on the unseen data originating from the target domain. For example, in the first experiment with the PACS dataset, the domains of Photo, Cartoon and Sketch are selected as source domains, while the Art Painting domain is held out as the target. Therefore, the model is trained on data from the source domains and evaluated on previously unseen art images.
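The leave-one-domain-out protocol itself is simple to express; a minimal sketch over the PACS domain names:

```python
def leave_one_domain_out(domains):
    """Yield (source_domains, target_domain) splits: each domain is held
    out once as the unseen target while the rest form the training set."""
    for target in domains:
        yield [d for d in domains if d != target], target


pacs_domains = ["Photo", "Art Painting", "Cartoon", "Sketch"]
for sources, target in leave_one_domain_out(pacs_domains):
    print(f"train on {sources} -> evaluate on {target}")
```

Each benchmark thus yields as many experiments as it has domains, and the reported score is typically the per-target accuracy (and their average).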

Results
The results of our experiments are presented in Table 1. The effectiveness of our method is demonstrated in the experimental outcome, as our model is able to surpass previously proposed state-of-the-art algorithms on the PACS, Terra Incognita and Office-Home datasets, while achieving the second-best performance on VLCS. In PACS, our model surpasses the previous best model by 1.06%, while in Terra Incognita and Office-Home our implementation exceeds the baselines by 0.98% and 1.33% respectively. What's more, even though our algorithm is not able to achieve the top score in VLCS, it remains highly competitive and ranks second best among the baselines.
To further support our claims, we also provide visual examples of our model's inference process via saliency maps. Specifically, we implement the Image-Specific Class Saliency method proposed in [32]. In this method, a visual map of the pixels contributing the most to the model's prediction is produced by computing and visualizing the gradient of the loss function with respect to the input image. As depicted in Figure 3, the darker a pixel, the more significant it is to the model. We choose to visualize 4 images of the "elephant" class from the four different domains in PACS. When compared to the baseline ERM model, our method seems to base its decisions on features of the depicted object (e.g. the tusk of the elephant in the Art image) and to pay less attention to irrelevant attributes, such as noisy backgrounds (e.g. the tree leaves in the Photo domain). This visual evidence appears promising for researching alternative architectures containing both convolutional and attention layers in the DG setting.
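A minimal sketch of the saliency computation in the spirit of [32] is given below; here the gradient of the predicted class score with respect to the input pixels is taken and max-reduced over colour channels, and the tiny model is purely an illustrative stand-in for a trained classifier:

```python
import torch
import torch.nn as nn


def class_saliency(model, image, target_class):
    """Image-specific class saliency: |d score / d pixel|,
    reduced over channels to a single (H, W) map."""
    model.eval()
    x = image.unsqueeze(0).clone().requires_grad_(True)
    model(x)[0, target_class].backward()     # gradient of the class score
    return x.grad.abs().amax(dim=1).squeeze(0)


# illustrative stand-in for a trained classifier (7 classes, as in PACS)
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 7))
sal = class_saliency(model, torch.randn(3, 32, 32), target_class=2)
print(sal.shape)   # one saliency value per pixel: torch.Size([32, 32])
```

The resulting map can then be rendered as a grayscale image, with larger values marking pixels that influence the prediction most.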

CONCLUSIONS
In this paper, we introduced a novel approach for image classification in the Domain Generalization setting. The basic idea behind our implementation was to allow the model to select the most class-discriminative and domain-invariant representations via multi-head self-attention mechanisms which attend to intermediate feature maps extracted from multiple layers of a convolutional neural network. The generalization ability of our model is supported by extensive experimental results on four publicly available and well-known DG benchmarks, in which our model either surpasses previously proposed algorithms or remains highly competitive. In addition, we provide visual qualitative examples of our model's inference process through saliency maps. The visual results demonstrate that our model tends to disregard spurious correlations in its input images, such as background noise, and is able to base its predictions on class-specific attributes. However, our method still has room for improvement. The employment of multiple multi-head attention mechanisms and the concatenation of embedded feature maps add a significant computation and memory overhead, which is reflected in the relatively small image batch size in our experiments. For future work, we aim to further research the intersection between visual attention and fully convolutional networks in order to propose mechanisms which can explicitly attend to the causal features of a class.

Figure 1: Visualization of our proposed framework on a ResNet-50 model. The feature maps after the 3rd, 7th and 13th bottleneck blocks are passed through multi-head attention layers with 3 heads each. Each re-weighted channel of the extracted feature maps is embedded into a vector of length 32 via a Multi-Layer Perceptron and concatenated into a larger vector, along with the output of the last convolutional layer.

Figure 2: Visualization of the Multi-Head Attention mechanism. In our implementation, intermediate feature maps are extracted from a backbone ResNet-50 model and passed through a multi-head attention layer with 3 heads (h = 3). We propose attending to each channel of the extracted feature maps. For the compatibility metric in the self-attention module, we select to use the scaled dot-product.
• PACS contains images originating from the Photo, Art Painting, Cartoon and Sketch domains, with a total of 9,991 images and 7 class labels.
• VLCS incorporates 10,729 real-world images from the PASCAL VOC, LabelMe, Caltech 101 and SUN09 datasets (or domains) and depicts 5 classes in total.
• Terra Incognita contains photographs of wild animals taken by trap cameras at 4 different locations (L100, L38, L43 and L46). This dataset contains 10 classes and 24,788 images in total.
• Office-Home comprises the four domains of Art, Clipart, Product and Real-World images. The dataset contains 15,588 examples and 65 classes in total.

Figure 3: Saliency map visualization of the "elephant" class in 4 different domains, produced by a baseline ERM model and our method. The selected images are from the PACS dataset. By producing saliency maps, one is able to gain intuition into a model's inference process. The darker a pixel in the map, the more important it is to the model's prediction. The background noise and spurious correlations in the above images tend to contribute to the baseline model's decisions. In contrast, our framework seems to pay attention to the invariant attributes of the class.

Table 1: Top-1 accuracy (%) results, averaged over 3 runs, on the PACS, VLCS, Terra Incognita (denoted as Terra) and Office-Home (denoted as Office) datasets. The top results are highlighted in bold while the second best are underlined.