Spatiotemporal Modeling Encounters 3D Medical Image Analysis: Slice-Shift UNet with Multi-View Fusion

As a fundamental part of computational healthcare, Computed Tomography (CT) and Magnetic Resonance Imaging (MRI) provide volumetric data, making the development of algorithms for 3D image analysis a necessity. Despite being computationally cheap, 2D Convolutional Neural Networks can only extract in-plane spatial information. In contrast, 3D CNNs can extract three-dimensional features, but they have higher computational costs and latency, which is a limitation for clinical practice that requires fast and efficient models. Inspired by the field of video action recognition, we propose a new 2D-based model, dubbed Slice SHift UNet (SSH-UNet), which encodes three-dimensional features at the complexity of a 2D CNN. More precisely, multi-view features are collaboratively learned by performing 2D convolutions along the three orthogonal planes of a volume and imposing a weight-sharing mechanism. The third dimension, which is neglected by the 2D convolution, is reincorporated by shifting a portion of the feature maps along the slices' axis. The effectiveness of our approach is validated on the Multi-Modality Abdominal Multi-Organ Segmentation (AMOS) and Multi-Atlas Labeling Beyond the Cranial Vault (BTCV) datasets, showing that SSH-UNet is more efficient while on par in performance with state-of-the-art architectures.


Introduction
Identifying organs through semantic segmentation is a crucial step in several clinical workflows, including diagnosis, intervention, therapy planning, treatment delivery, and tumour growth monitoring. However, the volumetric data generated by medical acquisition systems, such as Computed Tomography (CT), Magnetic Resonance Imaging (MRI), or Ultrasound, can make the segmentation task labour-intensive and time-consuming. For instance, a single 3D CT scan can contain hundreds of 2D slices (images). Therefore, developing robust and accurate automatic segmentation tools is a fundamental necessity in medical image analysis [22,23].
Figure 1. Overview of the proposed framework. An SSH-UNet layer is a Residual Block receiving an input tensor with dimension (3B, Cin, S, H, W), where 3B is the concatenation on the batch B of the features from the three orthogonal planes of the CT volume (Ixy, Iyz, Ixz). Spatial features are extracted by a 2D convolution from (H, W), and are then shifted forward and backward along the slices' axis S. The operation is performed for each tensor Ixy, Iyz, Ixz in the batch. Since we are interested in how features are mixed, we represent the shift with three axes, explicitly depicting the slices and channels dimensions, while the spatial dimensions H and W are condensed on a single axis.
With the advent of deep learning, Convolutional Neural Networks (CNNs) have proved to be extremely effective at solving vision tasks due to their powerful representation learning capabilities. In particular, "U-shaped" encoder-decoder architectures have achieved state-of-the-art results in various medical semantic segmentation tasks [4,6,8]. More recently, Vision Transformers (ViT) [5] have achieved comparable results to CNN-based methods, and as a result, many transformer-based models have been proposed for both 2D and 3D medical image segmentation [2,3,26]. Although 3D CNNs are designed to learn three-dimensional features, they require higher computation costs, resulting in higher inference latency compared to 2D CNNs. Besides, the large number of parameters may result in a higher risk of overfitting, especially when encountering small datasets [28]. This is very common in the medical field, as it is challenging to collect 3D medical datasets due to accessibility issues for ethical reasons and limited time and budget for annotations. To process volumetric data more efficiently, two main strategies can be used. The first one is cutting the volume into slices and training 2D CNNs to segment each slice separately [2,3]. Despite the computational efficiency, the information between adjacent slices is neglected, leading to segmentation results that are prone to discontinuity in 3D space [28]. The second is using 2.5D segmentation methods (or pseudo-3D methods). A very common 2.5D strategy is "multi-view fusion", where three 2D CNNs are trained on the sagittal, coronal, and axial planes separately [27]; after that, the segmentation results from each plane are fused to get the final result.
In this work, we propose a bi-dimensional UNet for segmentation of volumetric medical data that extracts multi-view and multi-slice information thanks to a Slice SHift mechanism (SSH-UNet). To extract multi-view features, as in [12], we impose weight sharing between the 2D convolutions that process the slices from the three orthogonal planes. While shifting is a well-established technique in video processing, we wondered whether it could also be transferred to volumetric data, since there is no inherent preferential direction like time in videos. As shown in Figure 1, inter-slice features are extracted by shifting a portion of the feature maps along the slices' axis, following the work in [13]. SSH-UNet is evaluated on two publicly available benchmark datasets: the Multi-Modality Abdominal Multi-Organ Segmentation (AMOS) [10] and Multi-Atlas Labeling Beyond the Cranial Vault (BTCV) [11]. To the best of our knowledge, no previous work in the medical field has explored the combination of shifting and shared weights across multiple views within a single model.
To be more specific, the contributions of our work are as follows: • We propose the first network that repurposes spatiotemporal modelling from video tasks to segment medical data. By interpreting the slices' axis as time, we address the problem of 2D CNNs neglecting information between adjacent slices by shifting a portion of the feature maps along the slices' axis.
• We revisit and extend the 2.5D multi-view fusion method by processing slices from the three orthogonal planes of a volume using a 2D UNet with shared weights rather than three separate networks, allowing multi-view features to be learned collaboratively while maintaining a light computational cost.
• We instantiate these ideas into the Slice-Shift (SSH) layer, a 2D convolution layer operating on 3D tensors. We validate the effectiveness of the proposed framework by training a UNet built of SSH layers on two publicly available benchmark datasets, AMOS and BTCV, showing that our approach, with the same model complexity as 2D CNNs, achieves the same performance as a fully 3D network with a similar architecture, and achieves comparable results to other popular state-of-the-art approaches with less than 1/5 of the parameters.
Our code will be released to facilitate follow-up research.

Related Work
Segmentation on medical data with U-Net
U-Net [19] was proposed for biomedical image segmentation back in 2015. Afterwards, a new class of models was developed based on U-Net-like architectures, which established the state of the art in segmentation. One promising approach was proposed by Isensee et al. in [8], where nnU-Net was introduced. nnU-Net is a deep learning-based segmentation method that automatically configures itself for any new task. Its performance is not attained through a new architecture (thus the name nnU-Net, 'no new net'), as it only comprises minor modifications to the original U-Net. Rather, it automates the complicated process of manually configuring the method. Hatamizadeh et al. reformulate in [7] the task of volumetric medical image segmentation as a sequence-to-sequence prediction problem by leveraging the power of self-attention and transformer architectures. They introduce a novel architecture, dubbed UNEt TRansformers (UNETR), that utilizes a transformer as the encoder. The extracted representations are merged with a CNN-based decoder via skip connections at multiple resolutions. The ensemble of UNETR models has shown promising results on the BTCV dataset. Tang et al. introduced in [24] a novel 3D transformer-based model dubbed Swin UNEt TRansformers (Swin UNETR). Swin UNETR comprises a Swin Transformer [14] encoder and a CNN-based decoder. The transformer encoder is pre-trained with tailored, self-supervised tasks over 5,050 images. Overall, the ensemble of 20 Swin UNETR models achieved, at the time of publication, the top-ranking performance on the BTCV challenge, showing distinct improvements for the segmentation of organs that are smaller in size.

Video action recognition
Spatio-temporal representation learning refers to the process of learning meaningful representations of both spatial and temporal information in a given dataset. In computer vision, this is particularly important for tasks such as video analysis and action recognition, where the goal is to accurately model the spatial and temporal evolution of objects and subjects over time. In particular, video action recognition has received increasing attention due to its potential applications, such as video surveillance, human-computer interaction, and social video recommendation. However, this field presents a fundamental challenge due to the space-time nature of the data. For years, many efforts were made to trade off between temporal modelling and computation ([25], [15], [13], [21]). Conventional 2D CNNs are computationally cheap but cannot capture temporal relationships. Since a video can be seen as a temporally densely sampled sequence of images, expanding the 2D convolution operation to 3D convolution is an intuitive approach to spatiotemporal feature learning. While 3D CNN-based methods can achieve strong results, they require significant computational resources. Lin et al. proposed in their work [13] a Temporal Shift Module (TSM) that can achieve the performance of a 3D CNN while maintaining 2D CNN complexity. TSM shifts a fixed amount of the channels along the temporal dimension, facilitating information exchange among neighbouring frames and yielding a 2D CNN that can learn spatiotemporal features. Li et al. [12] propose an operation that encodes spatiotemporal features by imposing a weight-sharing constraint. In particular, they perform 2D convolution by sharing the convolution kernels of three orthogonal views of a video, allowing multi-view features to be learned collaboratively.

Slice-Shift UNet
We base our model design on the UNet architecture proposed by Isensee et al. in [8] and optimized by Futrega et al. in [6]. SSH-UNet, whose detailed illustration is found in Figure 2, is a CNN-based architecture designed to capture the global connections between multi-plane (axial, coronal, and sagittal) and multi-slice images. This is obtained through weight sharing and by shifting the feature maps along the slices' axis. The overall framework is characterized by: 1) 2D residual blocks used to extract spatial features from the slices of the input volume, 2) slice shifting to incorporate information between adjacent slices neglected by 2D convolutions, and 3) a multi-view fusion block to obtain the final segmentation predictions from the three orthogonal planes.

2D residual block
Let us assume that the input to the encoder is a subvolume V ∈ R^(Cin×S×H×W), with Cin channels and a patch resolution of (S, H, W). V lies in the Euclidean space, thus it has three mutually perpendicular coordinate axes x, y, and z and three mutually perpendicular coordinate planes: the xy-plane, yz-plane, and xz-plane. For clarity, we use the notation V ∈ R^(Cin×X×Y×Z). We modify the input tensor by placing the plane of interest on the last two dimensions. More precisely, from V we generate three volumes Vxy, Vyz, and Vxz by permuting the axes so that the plane of interest occupies the last two dimensions:

Vxy ∈ R^(Cin×Z×X×Y), Vyz ∈ R^(Cin×X×Y×Z), Vxz ∈ R^(Cin×Y×X×Z). (1)

The three volumes are stored in the batch dimension, obtaining the final input I = [Vxy; Vyz; Vxz] ∈ R^(3B×Cin×S×H×W). We apply 2D convolution with a kernel size of 1 × k × k, extracting spatial features from the three orthogonal planes stored in I. Methods like [18,20] treat images from the xy, yz, and xz planes as three channels of 2D images. This is empirically effective and memory efficient, but the weakness of the approach is that the three channels are not spatially aligned [9], which is why we chose to concatenate the three views in the batch, leaving the network to learn multi-view features through weight sharing. Overall, our residual block is composed of two convolutional layers with kernel size 1 × 3 × 3, followed by instance normalization and LeakyReLU activation. A residual skip connects the input of the block with the output of the second convolution.
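The multi-view creation described above can be sketched as follows. NumPy stands in for the actual framework tensors, and the specific permutations are an assumption consistent with Eq. 1; the method only requires that each view places its plane of interest on the last two axes and that the patch is isotropic so the three views share one shape:

```python
import numpy as np

def make_multiview_batch(v: np.ndarray) -> np.ndarray:
    """Stack the three orthogonal views of a volume on the batch axis.

    v: tensor of shape (B, C, X, Y, Z), with X == Y == Z (isotropic
    patch, e.g. 96x96x96). Returns a (3B, C, S, H, W) tensor where the
    plane of interest sits on the last two axes of each view.
    """
    v_xy = v.transpose(0, 1, 4, 2, 3)  # (B, C, Z, X, Y): xy-plane last
    v_yz = v                           # (B, C, X, Y, Z): yz-plane last
    v_xz = v.transpose(0, 1, 3, 2, 4)  # (B, C, Y, X, Z): xz-plane last
    return np.concatenate([v_xy, v_yz, v_xz], axis=0)
```

Concatenating on the batch axis (rather than on channels, as in [18,20]) is what lets a single shared-weight 2D convolution process all three views.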

Slice shifting
Given a volume V ∈ R^(C×S×H×W), perceived as a sequence of S images (or slices) with resolution (H, W), applying a 2D convolution does not extract features between adjacent slices. In SSH-UNet we apply a shift operation to re-integrate the third dimension and mingle the information in neighbouring slices. The intuition behind the shift operation, adapted from [13], is the following: if we consider a 1D convolution with kernel size 3 and weights W = (w1, w2, w3), and a 1D input tensor X, then the convolution operation can be written as

Yi = w1 X(i−1) + w2 Xi + w3 X(i+1).

The operation can be decoupled into a shift and a multiply-accumulate, where X is shifted by −1, 0, +1 and multiplied by w1, w2, w3 respectively. The shift operation is:

X^(−1)_i = X(i−1), X^0_i = Xi, X^(+1)_i = X(i+1),

which can be conducted separately from the multiplication. The multiply-accumulate operation is:

Y = w1 X^(−1) + w2 X^0 + w3 X^(+1),

which in our case is computed by the previously mentioned 2D convolution. The shift operation does not introduce any extra computational cost to the 2D CNN model. The overall framework is described in Figure 1, where an intermediate residual layer of SSH-UNet with Cin input channels and Cout output channels is depicted. The slices' axis S changes based on the plane we are considering: axis Z, X, and Y for the axial, sagittal, and coronal planes respectively. The feature maps of the different slices are denoted with different shades of colour in each row. Along the slices' axis, we shift part of the channels forward and backward by +1 and −1, leaving the rest un-shifted. We shift a proportion of 1/4 of the channels forward and 1/4 backward.
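A minimal NumPy sketch of the slice shift, following the TSM design in [13]; the function name and the zero-padding of vacated slice positions are our assumptions:

```python
import numpy as np

def slice_shift(x: np.ndarray, fold: int = 4) -> np.ndarray:
    """Shift a fraction of the channels along the slice axis S.

    x: feature tensor of shape (N, C, S, H, W). A fraction 1/fold of
    the channels is shifted forward (+1) along S, another 1/fold is
    shifted backward (-1), and the remaining channels stay in place.
    Vacated positions are filled with zeros.
    """
    n, c, s, h, w = x.shape
    f = c // fold
    out = np.zeros_like(x)
    out[:, :f, 1:] = x[:, :f, :-1]         # shift forward (+1)
    out[:, f:2 * f, :-1] = x[:, f:2 * f, 1:]  # shift backward (-1)
    out[:, 2 * f:] = x[:, 2 * f:]          # un-shifted channels
    return out
```

With the default fold of 4, half of the channels move in total (1/4 forward, 1/4 backward), matching the proportion reported above; the shift itself is pure data movement, so no multiply-accumulate cost is added to the 2D convolution.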

Multi-view fusion
The last Residual block of the decoder outputs the tensor O = [Oxy; Oyz; Oxz]. The final segmentation mask is obtained by fusing the three output tensors stored in the batch of O. First, the operation computed in Eq. 1 is reversed, ensuring that information isn't wrongly mixed when the tensors are fused. Then, Oxy, Oxz, and Oyz are summed, followed by two convolutions with kernel size 1 × 1 × 1 generating the final segmentation mask.
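The fusion step can be sketched as below, assuming the inverse of the permutations given for Eq. 1; the two final 1 × 1 × 1 convolutions are omitted for brevity:

```python
import numpy as np

def fuse_multiview(o: np.ndarray) -> np.ndarray:
    """Reverse the multi-view creation and fuse the three outputs.

    o: decoder output of shape (3B, C, S, H, W), stacked as
    [O_xy; O_yz; O_xz] on the batch axis. The inverse permutations
    realign corresponding voxels before the views are summed; in
    SSH-UNet the sum is then passed through two 1x1x1 convolutions.
    """
    b = o.shape[0] // 3
    o_xy, o_yz, o_xz = o[:b], o[b:2 * b], o[2 * b:]
    o_xy = o_xy.transpose(0, 1, 3, 4, 2)  # (B,C,Z,X,Y) -> (B,C,X,Y,Z)
    o_xz = o_xz.transpose(0, 1, 3, 2, 4)  # (B,C,Y,X,Z) -> (B,C,X,Y,Z)
    return o_xy + o_yz + o_xz             # aligned per-voxel sum
```

Reversing the permutations before summing is the step that prevents features from different spatial locations being mixed across views.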

Datasets
AMOS: the Multi-Modality Abdominal Multi-Organ Segmentation dataset [10] was introduced as part of the MICCAI 2022 challenge. AMOS is a large-scale, diverse, clinical dataset for abdominal organ segmentation that provides 500 CT and 100 MRI scans accompanied by voxel-level annotations for 15 organs. The data were collected from Longgang District Central Hospital (Shenzhen, China). With over 74k annotated slices, AMOS is 20× larger than the BTCV [11] dataset (3.6k annotated slices). For our experiments, we use the AMOS-CT subset, where all 500 CT scans are interpolated into the isotropic voxel spacing of 1.0 × 1.0 × 1.0 mm³. Following [10], we first truncate the HU values to [−991, 362] and normalize to [0, 1]. Data augmentations of random flip, rotation, intensity scaling, and shifting are used with probabilities set to 0.2, 0.2, 0.5, and 0.5 respectively. The multi-organ segmentation problem is formulated as a 16-class segmentation task with 1-channel input.
BTCV: For the ablation analysis (Section 6), we utilize the popular Multi-Atlas Labeling Beyond the Cranial Vault dataset [11]. BTCV contains 30 subjects with abdominal CT scans where 13 organs are annotated by interpreters under the supervision of radiologists at Vanderbilt University Medical Center. All CT scans were interpolated into the isotropic voxel spacing of 1.0 × 1.0 × 1.0 mm³ as a pre-processing step. The intensity was truncated to [−175, 250] and normalized to [0, 1]. We used the same data augmentation implemented for AMOS.
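The intensity preprocessing shared by both datasets can be sketched as below; the function name and signature are illustrative:

```python
import numpy as np

def preprocess_ct(volume: np.ndarray, hu_min: float = -991.0,
                  hu_max: float = 362.0) -> np.ndarray:
    """Truncate HU values to a dataset-specific window and rescale
    to [0, 1]: [-991, 362] for AMOS-CT, [-175, 250] for BTCV."""
    v = np.clip(volume.astype(np.float64), hu_min, hu_max)
    return (v - hu_min) / (hu_max - hu_min)
```

The window bounds map to 0 and 1 exactly, so out-of-range air and metal voxels saturate rather than stretching the usable intensity range.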

Implementation details
The network architecture was created using the DynUNet class from MONAI as a baseline. We extended the original class by inserting the slice shifting in its building blocks and by adding our Multi-View Creation step and Multi-View Fusion Block. For a fair comparison, the results in Table 1 are obtained by training for 1000 epochs using the SGD optimizer with a momentum of 0.99, a warm-up cosine scheduler for 50 iterations, an initial learning rate of 0.01, and a batch size of 2, recreating the same training conditions of the benchmark created in [10]. Following the official AMOS-CT challenge data split, we used 200 CT scans for training and 100 CT scans for the validation set. With the BTCV dataset, we trained for 5000 epochs and stopped the training after 1000 epochs if the validation accuracy did not improve. An AdamW optimizer with a warm-up cosine scheduler for 50 iterations was used, with a batch size of 2, an initial learning rate of 4e-4, a momentum of 0.9, and a decay rate of 1e-5. We used 24 CT scans for training and 6 for testing.
Each training was conducted with a patch resolution of 96 × 96 × 96 on an NVIDIA A100.
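The warm-up cosine schedule used in both training setups might look as follows; this is a sketch of the generic schedule (linear warm-up for 50 iterations, then cosine decay), not the exact MONAI/benchmark implementation:

```python
import math

def warmup_cosine_lr(step: int, total_steps: int,
                     base_lr: float = 0.01, warmup: int = 50) -> float:
    """Learning rate at a given step: linear ramp from ~0 to base_lr
    over `warmup` steps, then cosine decay from base_lr to 0."""
    if step < warmup:
        return base_lr * (step + 1) / warmup
    t = (step - warmup) / max(1, total_steps - warmup)  # progress in [0, 1]
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * t))
```

Swapping `base_lr` to 4e-4 reproduces the BTCV setting; the decoupled weight decay of AdamW is handled by the optimizer, not the schedule.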

Evaluation metric
We used the Dice Similarity Coefficient (DSC) and the Normalized Surface Dice (NSD) [17] metrics to evaluate the segmentation accuracy in our experiments. While DSC measures the overlap between two volumes, the NSD score provides information on the segmentation quality at the boundaries. Given the ground truth Y and the prediction Ŷ, with i indexing the voxels, the Dice score is defined as:

DSC(Y, Ŷ) = 2 Σi Yi Ŷi / (Σi Yi + Σi Ŷi).

Using the above two metrics, we calculate category-wise performance. The DSC, used to gauge model performance, ranges from 0 to 1, where 1 corresponds to a pixel-perfect match between the deep learning model output Ŷ and the ground truth annotation Y. The NSD is used to determine which fraction of a segmentation boundary is correctly predicted, with values ranging between 0 and 1.
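A minimal implementation of the DSC for binary masks; the small epsilon guarding against empty masks is our addition:

```python
import numpy as np

def dice_score(y_true: np.ndarray, y_pred: np.ndarray,
               eps: float = 1e-8) -> float:
    """Dice Similarity Coefficient between two binary masks:
    DSC = 2 * sum(Y * Y_hat) / (sum(Y) + sum(Y_hat))."""
    inter = np.sum(y_true * y_pred)          # voxel-wise overlap
    return 2.0 * inter / (np.sum(y_true) + np.sum(y_pred) + eps)
```

For the multi-class setting, the score is computed per organ label (one binary mask per class) and the class-wise values are averaged into the reported mDSC.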

Results
We compare our model with six state-of-the-art medical segmentation methods present in the benchmark in [10]. For the training stage, Ji et al. randomly cropped sub-volumes of size 64 × 160 × 160; we instead cropped sub-volumes of size 96 × 96 × 96 as input for our network, because the multi-view creation described in Section 3.1 requires an isotropic volume size. Publicly available implementations were used for the state-of-the-art methods: UNet, VNet, CoTr, nnFormer, UNETR, and Swin UNETR.
The class-wise Dice scores on the AMOS-CT validation set are shown in Table 1. By training with 96 × 96 × 96 patches, we achieve an overall accuracy of 87.28%, ranking second in the benchmark right after UNet [8], trained with 64 × 160 × 160 patches, which outperforms SSH-UNet with a +1.6% gain in accuracy. However, our model has almost 80% fewer parameters. Comparing SSH-UNet with Swin UNETR [24] (previously ranked first on the MSD [1] and BTCV leaderboards), our model offers a substantial improvement in segmenting: right kidney +2.2%, gallbladder +5.8%, liver +1.7%, stomach +3.4%, and prostate/uterus +4.2%. In Table 3, the overall results from the AMOS-CT test benchmark are shown. SSH-UNet also confirmed its second position on the test set, with an average DSC of 87.75% and NSD of 77.16%. The class-wise DSC and NSD can be found in Table 2, while Figure 3 shows some representative samples of our predictions.
In Table 4 we can see the results of 5-fold cross-validation on the BTCV dataset. On average, our model is able to reach 84.35% accuracy without the help of any ensemble. From the table, we can observe that the fourth-fold segmentation of the spleen shows a significant drop in performance. The gallbladder and adrenal glands are segmented poorly by the first and second folds compared to the others. The first fold also led to a bad segmentation mask for the esophagus, liver, and stomach. We want to highlight that the official BTCV webpage emphasizes that some patients may not have the right kidney or gallbladder, and these are thus not labelled; nevertheless, our network is capable of segmenting the right kidney independently of the folds, while the drop in performance on the gallbladder in the second fold may be related to the lack of annotated data.

Models             CT-Test mDSC(%)   mNSD(%)
UNet [8]           89.04             78.32
VNet [16]          82.92             67.56
CoTr [26]          80.86             66.31
nnFormer [29]      85.61             72.48
UNETR [7]          79.43             60.84
Swin-UNETR [24]    86.32             73.83
SSH-UNet           87.75             77.16

Table 3. Overall results of six state-of-the-art methods taken from the official AMOS-CT test benchmark in [10] and SSH-UNet.

Ablation study

Model components
We perform an ablation study to validate the effectiveness of the individual components of our model. Table 5 shows the results of the different configurations trained on the BTCV and AMOS datasets. A UNet with only 2D convolutions resulted in the lowest mDSC score. By introducing only the shift operation, referred to as "shift" in the table, performance improved compared to the plain 2D case. By combining multi-view with the shift operation (m.v. + shift), we are able to achieve results comparable to a fully 3D UNet with the same architecture, with less than half of its parameters. In Figure 4 we can see qualitative results on the BTCV validation set.

Shift operation
We investigate the impact of the proportion of shifted channels on performance. Table 6 shows that shifting 1/4 of the feature maps forward and 1/4 backward (i.e., shifting half of the channels in total) gives the best result. The last column reports the 2D case without shifting.

Model complexity
In this section, we examine the model complexity. The FLOPs and the number of parameters of SSH-UNet and other baselines are presented in Table 7, with a graphical representation in Figure 5: the efficiency plot shows that SSH-UNet is computationally more efficient than other state-of-the-art models (on average less than 1/5 of the parameters) while maintaining the second-highest DSC score of 87.28%.

Table 7. Overall results of six state-of-the-art methods taken from the official AMOS-CT validation benchmark in [10] and SSH-UNet.

Conclusions
Organ segmentation is a fundamental task in the medical field. The volumetric data that characterize CT and MRI acquisitions, however, make the segmentation task computationally expensive. On the one hand, 2D CNNs provide a low-latency solution unable to capture inter-slice information; on the other hand, 3D CNNs extract three-dimensional features at the price of high computation costs and a risk of overfitting. Moreover, popular 2.5D multi-view fusion methods train three separate networks where the features of the orthogonal planes are learned independently, despite being part of the same volume. In SSH-UNet this is addressed by imposing weight sharing between convolutions, so that only one network needs to be trained and multi-view features are collaboratively learned. In this work, we introduced a novel approach for the segmentation of volumetric medical data. Inspired by works in the field of video action recognition, we interpret the slices of a volume as the frames of a video. Given a 2D backbone, we re-integrate the information between features belonging to adjacent slices by leveraging a shifting mechanism inspired by the TSM module. Spatio-temporal modelling instantiated as pseudo-3D operators, despite being well known in the video understanding field, had never been used before in medical image analysis to extract and mingle multi-slice features. Our network, by using 2D convolutions with a weight-sharing mechanism and slice shift, can extract 3D features while keeping computational complexity low. In comparison to other popular state-of-the-art methods, SSH-UNet achieves an accuracy of 87.28% on the AMOS validation set while providing the smallest model in terms of parameters (6.48M) compared to the best network, which has a +1.6% improvement in accuracy but a 5× increase in parameters.

Figure 2 .
Figure 2. Detailed architecture and components of our proposed SSH-UNet. (a) Input block of our network that processes the concatenation in the batch of the three created views. (b) Residual block that can extract multi-view and multi-slice features thanks to 2D convolutions with shared weights and the slice shift mechanism. (c) Output block where multi-view creation is reversed and multi-view fusion is performed to obtain the final segmentation mask. (d) Overview of the SSH-UNet architecture.

Figure 3 .
Figure 3. In this qualitative visualization we can see the predictions of SSH-UNet for three samples, identified by their ID number, from the AMOS-CT test set.

Figure 4 .
Figure 4. Qualitative results with representative samples from the BTCV dataset. The first row highlights the segmentation results for the portal and splenic veins. Fully 3D UNet achieves qualitatively the best result, while we can observe that the last three columns miss the segmentation of a small left portion. The second row focuses on the segmentation of the pancreas (yellow) and stomach (green). We can see that the 2D implementation of UNet and the 2D UNet with the shift operation (last two columns) are not able to segment a portion of the stomach, while our network (third column) and 3D UNet segment it perfectly. In the last row, the pancreas is highlighted again. In this case, it is segmented properly by both 3D UNet and our implementation, while a portion is completely missed by the fully 2D model, even when integrated with the shift operation.

Figure 5 .
Figure 5. Efficiency: FLOPs vs. DSC. We plot the average DSC on the validation set of AMOS-CT. The FLOPs and parameters are estimated using [1 × 128 × 128 × 128] as model input. The size of each circle indicates the number of parameters (Params.). We can observe that SSH-UNet has the lowest number of parameters and small FLOPs compared to other implementations while maintaining the second-highest DSC of 87.28%.

Table 2 .
The class-wise Dice score (DSC) and Normalized Surface Dice (NSD) of SSH-UNet on the AMOS-CT test set.

Table 5 .
Ablation analysis of the introduced components. The term "shift" stands for the slice shift operation, while "m.v." stands for multi-view fusion. The last two columns report the average Dice score on the validation sets of the BTCV and AMOS datasets.

Table 6 .
Performance comparison on the proportion of channels shifted forward and backward on the AMOS-CT validation set. Proportion 0 is the fully 2D case without shift. Proportion 1/2 is the case where all the channels are shifted, half forward and half backward.