Interpretable Sparsification of Brain Graphs: Better Practices and Effective Designs for Graph Neural Networks

Brain graphs, which model the structural and functional relationships between brain regions, are crucial in neuroscientific and clinical applications that can be formulated as graph classification tasks. However, dense brain graphs pose computational challenges such as large time and memory consumption and poor model interpretability. In this paper, we investigate effective designs in Graph Neural Networks (GNNs) to sparsify brain graphs by eliminating noisy edges. Many prior works select noisy edges based on explainability or task-irrelevant properties, but this does not guarantee performance improvement when using the sparsified graphs. Additionally, the selection of noisy edges is often tailored to each individual graph, making it challenging to sparsify multiple graphs collectively using the same approach. To address these issues, we first introduce an iterative framework to analyze the effectiveness of different sparsification models. Using this framework, we find that (i) methods that prioritize interpretability may not be suitable for graph sparsification, as the sparsified graphs may degrade the performance of GNN models; (ii) it is beneficial to learn the edge selection during the training of the GNN, rather than after the GNN has converged; (iii) learning a joint edge selection shared across all graphs achieves higher performance than generating a separate edge selection for each graph; and (iv) gradient information, which is task-relevant, helps with edge selection. Based on these insights, we propose a new model, Interpretable Graph Sparsification (IGS), which improves the graph classification performance by up to 5.1% with 55.0% fewer edges than the original graphs. The retained edges identified by IGS provide neuroscientific interpretations and are supported by well-established literature.


INTRODUCTION
Understanding how brain function emerges from the communication between neural elements remains a challenge in modern neuroscience [5]. Over the years, researchers have used brain graphs to encode the correlations of brain activities and uncover interesting connectivity patterns between brain regions. They find that the topological properties of brain graphs are useful in predicting various phenotypes and understanding brain activities [8,13,14,26,49], which accounts for the wide usage of brain graphs in neuroscientific research [39,55,70]. By adopting graph representations (often termed "connectomes"), researchers can cast many neuroscientific problems as graph problems. In this paper, we focus on end-to-end brain graph classification, since many brain graph classification tasks have meaningful real-life clinical significance, such as providing a non-invasive neuroimaging biomarker for the early identification of certain psychiatric/neurological disorders (e.g., autism, Alzheimer's disease) [48].
Despite the benefits of modeling brain data as graphs, even well-preprocessed brain graphs pose serious challenges. A brain graph based on functional MRI (fMRI), which is usually computed as pairwise correlations of fMRI time-series data, is fully connected. The resulting dense graph causes two unavoidable problems. First, it inhibits the use of efficient sparse operations, which leads to large time and memory consumption when the graphs are large [17,70]. Second, the dense graph suffers from fMRI-related noise, making it extremely hard to train a model that learns useful generalization rules and provides good interpretability [41]. It is therefore crucial to make brain graphs sparser and less noisy. The common practice in neuroscience is to remove the "weak" edges, whose weights are below a predefined threshold [52]. However, direct thresholding requires a wide search for the proper threshold [10], and the sparsified graphs may lack useful edges while preserving significant noise. To illustrate this, Table 1 shows the performance on the original graphs and on graphs sparsified by direct thresholding in a classification task. Direct thresholding may drop important edges and/or keep unimportant edges, which leads to a decrease in performance.
Prior work related to graph sparsification generally falls into two categories. The first line of work learns the relative importance of the edges, which can be used to remove unimportant edges in the graph sparsification process. These works usually focus explicitly on interpretability and are often referred to as "explainable graph neural networks (explainable GNNs)" [74]. The core idea embraced by this community is to identify small subgraphs that are most accountable for model predictions. The relevance of the edges to the final predictions is encoded into an edge importance mask, a matrix that reveals the relative importance of the edges and can be used to sparsify the graphs. These works show good interpretability under various measures [51]. However, it remains unclear whether better interpretability indicates better performance. The other line of work tackles unsupervised graph sparsification [42] without employing any label information. Some methods reduce the number of edges by approximating pairwise distances [50], cuts [33], or eigenvalues [58]. These task-irrelevant methods may discard task-specific edges that are useful for predictions. Fewer works are task-relevant, and they primarily focus on node classification [43,77]. As a result, these works produce a different edge importance mask for each graph. However, in graph classification, individual masks can lead to significantly longer training time and susceptibility to noise. In contrast, a joint mask is the preferred choice, as it offers robustness against noise and greater interpretability.
This work. To assess the quality of the sparsified graphs obtained from interpretable models in the graph classification task, we propose to evaluate sparsification algorithms under an iterative framework. At each iteration, the sparsification algorithm decides which edges to remove and feeds the sparsified graphs to the next iteration. We measure the effectiveness of a sparsification algorithm by the accuracy of the downstream graph classification task at each iteration. An effective sparsification algorithm should be able to identify and remove noisy edges, resulting in a performance boost in the graph classification task after several iterations (Section 4.2).
We use this iterative framework to evaluate two common practices in graph sparsification and graph explainability: (1) obtaining the edge importance mask from a trained model and (2) learning an edge importance mask for each graph individually [74]. For instance, GNNExplainer [72] learns a separate edge importance mask for each graph after the model is trained. Through our empirical analysis, we find that these practices are not helpful in graph sparsification, as the sparsified graphs may lead to lower classification accuracy. In contrast, we identify three key strategies that can improve the performance. Specifically, we find that (S1) learning a joint edge importance mask (S2) simultaneously with the training of the model helps improve the performance over the iterations, as it passes task-relevant information through back-propagation. Another strategy to incorporate task-relevant information is to (S3) initialize the mask with the gradient information from the immediately preceding iteration. This strategy is inspired by evidence in the computer vision domain that gradient information may encode data- and task-relevant information and may contribute to the explainability of the model [1,3,27].
Based on the identified strategies, we propose a new Interpretable model for brain Graph Sparsification, IGS. We evaluate IGS on real-world brain graphs under the iterative framework and find that it benefits from iterative sparsification. Compared to strong baselines, IGS achieves up to 5.1% improvement on graph classification tasks while using graphs with 55.0% fewer edges than the original ones.
Our main contributions are summarized as follows:
• General framework. We propose a general iterative framework to analyze the effectiveness of different graph sparsification models. We find that edge importance masks generated from interpretable models may not be suitable for graph sparsification because they may not improve the performance of graph classification tasks.
• New insights. We find that two practices commonly used in graph sparsification and graph explainability are not helpful under the iterative framework. Instead, we find that learning a joint edge importance mask along with the training of the model improves the classification performance during iterative graph sparsification. Furthermore, incorporating gradient information in mask learning also boosts the performance in iterative sparsification.

NOTATION AND PRELIMINARIES
In this section, we introduce key notations, provide a brief background on GNNs, and formally define the problem that we investigate.
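Notations. We consider a set of graphs $\mathcal{G}$. Each graph $G_i \in \mathcal{G}$ has $N$ nodes (brain regions) and a weighted adjacency matrix $\mathbf{A}_i \in \mathbb{R}^{N \times N}$. We use $\mathbf{H}^{(l)}$ to represent the node representations/output at the $l$-th layer of a GNN. Given our emphasis on graph classification problems, we denote the number of classes as $C$ and the set of labels as $\mathcal{Y}$, and we associate each graph $G_i$ with a corresponding label $y_i \in \mathcal{Y}$.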
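We also leverage gradient information [56] in this work: $\nabla f_j(G_i)$ denotes the gradients of the output in class $j$ with respect to the input graph $G_i$. These gradients are obtained through backpropagation and are referred to as the gradient map. For graph classification, a graph-level representation is obtained as $\mathbf{x}_{G_i} = \mathrm{Pooling}(\{\mathbf{h}_v : v \in G_i\})$, where $\mathbf{h}_v$ is the final-layer representation of node $v$ and the Pooling function operates on all node representations, including options like Global_mean, Global_max, or other complex pooling functions [37,73]. The loss is given by
$$\mathcal{L} = \frac{1}{N_{\text{train}}} \sum_{G_i \in \mathcal{G}_{\text{train}}} \mathrm{CrossEntropy}\big(\mathrm{Softmax}(\mathbf{x}_{G_i}),\, y_i\big),$$
where $\mathcal{G}_{\text{train}}$ represents the set of training graphs and $N_{\text{train}} = |\mathcal{G}_{\text{train}}|$. Although our framework does not rely on specific GNNs, we illustrate its effectiveness using the GCN model proposed in [35].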
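To make the pipeline above concrete, the following NumPy sketch runs a graph through GCN layers, applies Global_mean pooling, and averages the cross-entropy over training graphs. It is a minimal sketch under stated assumptions, not the authors' implementation: it assumes the standard symmetric-normalized GCN propagation $D^{-1/2}(\mathbf{A}+I)D^{-1/2}$, non-negative edge weights (so degrees stay positive), shared random node features, and a hypothetical linear readout `W_out` that maps the pooled vector to $C$ logits. The helper names are illustrative only.

```python
import numpy as np

def gcn_layer(A, H, W):
    """One GCN layer, ReLU(D^-1/2 (A + I) D^-1/2 H W).

    Assumes non-negative edge weights so every degree is positive.
    """
    A_hat = A + np.eye(A.shape[0])                  # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    A_norm = A_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    return np.maximum(A_norm @ H @ W, 0.0)

def graph_logits(A, X, layer_weights, W_out):
    """Stack GCN layers, mean-pool all node representations, map to C logits."""
    H = X
    for W in layer_weights:
        H = gcn_layer(A, H, W)
    x_graph = H.mean(axis=0)                        # Global_mean pooling
    return x_graph @ W_out                          # hypothetical linear readout

def mean_cross_entropy(logits_list, labels):
    """(1 / N_train) * sum_i CrossEntropy(Softmax(logits_i), y_i)."""
    total = 0.0
    for z, y in zip(logits_list, labels):
        z = z - z.max()                             # shift for numerical stability
        log_probs = z - np.log(np.exp(z).sum())
        total += -log_probs[y]
    return total / len(labels)

# Toy usage with random, hypothetical data: 3 graphs, 5 nodes, 2 classes.
rng = np.random.default_rng(0)
N, F, hidden, C = 5, 4, 8, 2
graphs = []
for _ in range(3):
    A = np.abs(rng.normal(size=(N, N)))
    graphs.append((A + A.T) / 2)                    # symmetric, non-negative weights
X = rng.normal(size=(N, F))                         # shared node features (assumption)
layer_weights = [rng.normal(size=(F, hidden)), rng.normal(size=(hidden, hidden))]
W_out = rng.normal(size=(hidden, C))
labels = [0, 1, 0]
logits = [graph_logits(A, X, layer_weights, W_out) for A in graphs]
loss = mean_cross_entropy(logits, labels)
print(f"mean cross-entropy: {loss:.4f}")
```

Brain graphs are signed, so a real implementation would need a normalization that tolerates negative weights; the non-negative toy weights here only keep the example short.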
The performance of GNN models heavily depends on the quality of the input graphs. Messages propagated through noisy edges can significantly affect the quality of the learned representations [70]. Inspired by this observation, we focus on the following problem: given a set of graphs $\mathcal{G}$, learn either a separate edge importance mask $\mathbf{M}_i \in \{0,1\}^{N \times N}$ for each graph $G_i$ or a joint edge importance mask $\mathbf{M} \in \{0,1\}^{N \times N}$ shared by all graphs, which can be used to remove the noisy edges and retain the most task-relevant ones. This should lead to enhanced classification performance on the sparsified graphs. Edge masks that effectively identify task-relevant edges are considered to be interpretable.

PROPOSED METHOD: IGS
In this section, we introduce our proposed iterative framework for evaluating various sparsification methods.Furthermore, we introduce IGS, a novel and interpretable graph sparsification approach that incorporates three key strategies: (S1) joint mask learning, (S2) simultaneous learning with the GNN model, and (S3) utilization of gradient information.We provide detailed explanations of these strategies in the following subsections.

Iterative Framework
Figure 1 illustrates the general iterative framework. At a high level, given a sparsification method, our framework iteratively removes unimportant edges based on the edge importance masks generated by the method at each iteration. In detail, the method can generate either a separate edge importance mask $\mathbf{M}_i$ for each input graph $G_i$ or a joint edge importance mask $\mathbf{M}$ shared by all input graphs $\mathcal{G} = \{G_1, G_2, \dots\}$. These edge importance masks indicate the relevance of edges to the task labels. In our setting, we also allow the masks to be trained simultaneously with the model. Ideal edge masks are binary, where zeros mark unimportant edges to be removed. In practice, many models (e.g., explainable GNNs [72,76]) learn soft edge importance masks with values in $[0, 1]$. In each iteration, our framework removes either the edges with zero values in the masks (if the masks are binary) or a fixed percentage $p\%$ of the edges with the lowest importance scores. We present the iterative sparsification framework in Algorithm 1, where $\mathcal{G}^{(t)}$ denotes the set of sparsified graphs at iteration $t$ and $G_i^{(t)}$ denotes the $i$-th graph in $\mathcal{G}^{(t)}$. Existing works [28,51] define the "importance" of an edge in different ways and thus generate different sparse graphs; we believe that a direct and effective way to evaluate these methods is to track the performance of their sparsified graphs under this iterative framework. The trend of the performance reveals the relevance of the remaining edges to the predicted labels.
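The loop described above can be sketched in NumPy for the joint-mask case. At each iteration the sketch evaluates the current sparsified graphs, asks a sparsification method for joint edge scores, and removes the fraction `p` (the $p\%$ above, e.g., 0.05) of still-present edges with the lowest scores. The callables `score_fn` and `evaluate_fn` are hypothetical placeholders for a sparsification method and for GNN training plus validation. We also assume that `p` applies to the edges remaining at each iteration and that all graphs share one node set, as the brain graphs here do.

```python
import numpy as np

def update_joint_mask(keep, scores, p):
    """Drop the fraction p of currently kept undirected edges with the lowest scores."""
    rows, cols = np.triu_indices(keep.shape[0], k=1)   # each undirected edge once
    alive = np.flatnonzero(keep[rows, cols])
    n_remove = int(p * alive.size)
    new_keep = keep.copy()
    if n_remove == 0:
        return new_keep
    order = np.argsort(scores[rows[alive], cols[alive]], kind="stable")
    drop = alive[order[:n_remove]]
    new_keep[rows[drop], cols[drop]] = False
    new_keep[cols[drop], rows[drop]] = False            # keep the mask symmetric
    return new_keep

def iterative_sparsification(adjs, score_fn, evaluate_fn, p=0.05, num_iters=55):
    """Return (iteration, metric, joint binary mask) for every iteration."""
    num_nodes = adjs.shape[1]
    keep = ~np.eye(num_nodes, dtype=bool)               # fully connected, no self-loops
    history = []
    for t in range(num_iters):
        sparsified = adjs * keep                        # M ⊙ A_i for every graph at once
        history.append((t, evaluate_fn(sparsified), keep.copy()))
        keep = update_joint_mask(keep, score_fn(sparsified), p)
    return history

# Toy usage with hypothetical callables: score edges by mean |weight| across graphs
# (plain thresholding, NOT IGS) and use the total remaining |weight| as a stand-in metric.
rng = np.random.default_rng(0)
A = rng.normal(size=(4, 10, 10))
adjs = (A + A.transpose(0, 2, 1)) / 2
history = iterative_sparsification(
    adjs,
    score_fn=lambda g: np.abs(g).mean(axis=0),
    evaluate_fn=lambda g: float(np.abs(g).sum()),
    p=0.05,
    num_iters=5,
)
for t, metric, keep in history:
    print(t, int(np.triu(keep, k=1).sum()), round(metric, 2))
```

Any method, per-graph or joint, can be plugged in through `score_fn`; a per-graph method would need one `keep` matrix per graph instead of the single shared one used here.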

Trained Mask (S1+S2).
We aim to learn a joint edge importance mask $\mathbf{M} \in \{0,1\}^{N \times N}$ along with the training of a GNN model, as shown in Figure 2. Each entry in $\mathbf{M}$ indicates whether the corresponding edge in the original input graph should be kept (value 1) or not (value 0). Directly learning the discrete edge mask is hard, as it cannot generate gradients to propagate back. Thus, at each iteration, we learn a soft version of $\mathbf{M}$, where each entry lies in $[0, 1]$ and reflects the relative importance of the corresponding edge. Since the adjacency matrices of undirected brain graphs are symmetric, we require the learned edge importance mask to be symmetric as well. We design the soft edge importance mask as $\sigma(\Phi^\top + \Phi)$, where $\Phi$ is a matrix to be learned and $\sigma$ is the Sigmoid function. A good initialization of $\Phi$ can boost the performance and accelerate training. Thus, we initialize this matrix with the gradient map (Section 3.2.2) from the previous iteration (Step 5 in Figure 2). Furthermore, following [72], we regularize the training of $\Phi$ by requiring $\sigma(\Phi^\top + \Phi)$ to be sparse, i.e., we apply an $\ell_1$ regularization on $\sigma(\Phi^\top + \Phi)$. In summary, we have the following training objective:
$$\min_{\Phi,\,\theta} \; \frac{1}{N_{\text{train}}} \sum_{G_i \in \mathcal{G}_{\text{train}}} \mathcal{L}\big(f_\theta(\sigma(\Phi^\top + \Phi) \odot \mathbf{A}_i),\, y_i\big) + \lambda \big\|\sigma(\Phi^\top + \Phi)\big\|_1,$$
where $f_\theta$ denotes the GNN with parameters $\theta$, $\odot$ denotes the Hadamard product, $\mathcal{L}$ is the Cross-Entropy loss, and $\lambda$ is the regularization coefficient. We optimize the joint mask together with the GNN. The resulting indicator matrix $\mathbf{M}$, which removes the edges with the lowest mask values, can then be used to sparsify each input graph through an element-wise multiplication, e.g., $\mathbf{A}'_i = \mathbf{M} \odot \mathbf{A}_i$.
Figure 2: At iteration $t$, IGS takes a set of input graphs and initializes its joint edge importance mask using the joint gradient map from the previous iteration. It trains the GNN model and the edge importance mask together, followed by sparsifying all input graphs using the obtained mask. Normal training is then conducted on the sparsified graphs. The gradient information is later extracted by computing a joint gradient map. Finally, IGS feeds the sparsified graphs to the next iteration and uses the joint gradient map to initialize the subsequent joint edge importance mask. IGS is model-agnostic and can be seamlessly integrated with existing GNN models.

Joint Gradient Information (S3).
Inspired by evidence in the computer vision domain that gradient information may encode data- and task-relevant information and may contribute to the explainability of the model [1,3,27], we use gradient information, i.e., gradient maps, to initialize and guide the learning of the edge importance mask.
Step 4 in Figure 2 illustrates the general idea of generating a joint gradient map by combining the gradient information from each training graph. Each training graph $G_i$ has $C$ gradient maps $\nabla f_j(G_i)$, $j = 1, 2, \dots, C$, each corresponding to the output in class $j$ (Section 2). Instead of using "saliency maps" [56], which consider only the gradient map of the predicted class, we leverage all the gradient maps, as they all provide meaningful information. For $G_1, \dots, G_{N_{\text{train}}} \in \mathcal{G}_{\text{train}}$, we compute the unified mask of class $j$ as the sum of the absolute values of the gradient maps:
$$\mathbf{T}_j = \sum_{i=1}^{N_{\text{train}}} \big|\nabla f_j(G_i)\big|.$$
By summing the unified masks of all classes, we obtain the joint edge gradient map $\mathbf{T} = \sum_{j=1}^{C} \mathbf{T}_j$.
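The aggregation above, together with the symmetric soft mask and its $\ell_1$ penalty from the trained-mask objective (S1+S2), can be sketched in NumPy as follows. The array `grad_maps[i, j]` is assumed to already hold $\nabla f_j(G_i)$ for each training graph and class; computing it requires backpropagation through the trained GNN, which the sketch leaves out. Dividing $\mathbf{T}$ by its maximum before using it to initialize $\Phi$ is our own hypothetical choice, since the text only states that $\Phi$ is initialized with the gradient map. Helper names are illustrative.

```python
import numpy as np

def joint_gradient_map(grad_maps):
    """grad_maps: shape (N_train, C, N, N); grad_maps[i, j] is class j's gradient map for graph i.

    Returns T = sum_j T_j with T_j = sum_i |grad_maps[i, j]|.
    """
    unified_per_class = np.abs(grad_maps).sum(axis=0)   # (C, N, N): T_1, ..., T_C
    return unified_per_class.sum(axis=0)                 # (N, N): joint gradient map T

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def soft_edge_mask(phi):
    """Symmetric soft mask sigma(Phi^T + Phi) with entries in (0, 1)."""
    return sigmoid(phi.T + phi)

def l1_regularizer(phi, lam):
    """lambda * ||sigma(Phi^T + Phi)||_1 (entries are positive, so a plain sum)."""
    return lam * soft_edge_mask(phi).sum()

# Toy usage with random stand-ins for real gradient maps: 6 graphs, 2 classes, 5 nodes.
rng = np.random.default_rng(0)
grad_maps = rng.normal(size=(6, 2, 5, 5))
T = joint_gradient_map(grad_maps)
phi0 = T / T.max()                  # hypothetical rescaling before initializing Phi
mask = soft_edge_mask(phi0)
print(np.allclose(mask, mask.T), float(l1_regularizer(phi0, lam=1e-4)))
```

Note that $\mathbf{T}$ itself need not be symmetric; the $\Phi^\top + \Phi$ construction symmetrizes the mask regardless of how $\Phi$ is initialized.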

EMPIRICAL ANALYSIS
In this section, we aim to answer the following research questions using our iterative framework: (Q1) Is learning a joint edge importance mask better than learning a separate mask for each graph? (Q2) Does training the edge importance mask simultaneously with the model yield better performance than obtaining the mask from an already trained model? (Q3) Does gradient information help with graph sparsification? (Q4) Is our method IGS interpretable?

Setup
Dataset. We use the WU-Minn Human Connectome Project (HCP) 1200 Subjects Data Release as our benchmark dataset to evaluate our method and the baselines [61]. The pre-processed brain graphs can be obtained from ConnectomeDB [45]. These brain graphs are derived from the resting-state functional magnetic resonance imaging (rs-fMRI) of 812 subjects, during which no explicit task is performed. Prediction from rs-fMRI is generally harder than from task-based fMRI [44]. The obtained brain graphs are fully connected, and the edge weights are computed from the correlation of the rs-fMRI time series between each pair of brain regions [57]. The brain is parcellated using Group-ICA with 100 components [9,20,22-24,54], which results in 100 brain regions comprising the nodes of our brain graphs. Additionally, a set of cognitive assessments was performed on each subject, and we use the resulting scores as cognitive labels in our prediction tasks. Specifically, we use the age-adjusted scores [45] from the following cognitive domains as our labels:
• PicVocab (Picture Vocabulary) assesses language/vocabulary comprehension. The respondent is presented with an audio recording of a word and four photographic images on the computer screen and is asked to select the picture that most closely matches the word's meaning.
• ReadEng (Oral Reading Recognition) assesses language/reading decoding. The participant is asked to read and pronounce letters and words as accurately as possible. The test administrator scores each response as right or wrong.
• PicSeq (Picture Sequence Memory) assesses episodic memory. It involves recalling an increasingly lengthy series of illustrated objects and activities presented in a particular order on the computer screen.
• ListSort (List Sorting) assesses working memory and requires the participant to sequence different visually and orally presented stimuli.
• CardSort (Dimensional Change Card Sort) assesses cognitive flexibility. Participants are asked to match a series of bivalent test pictures (e.g., yellow balls and blue trucks) to the target pictures according to color or shape. Scoring is based on a combination of accuracy and reaction time.
• Flanker (Flanker Task) measures a participant's attention and inhibitory control. The test requires the participant to focus on a given stimulus while inhibiting attention to the stimuli flanking it. Scoring is based on a combination of accuracy and reaction time.
More details can be found in ConnectomeDB [45]. These scores are continuous. To use them for graph classification, we assign the subjects with scores in the top third to the first class and those in the bottom third to the second class.
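The binarization of the continuous scores can be sketched as below. The sketch assumes that subjects in the middle third are discarded and that ties are broken by subject order (a stable sort); the text does not say how ties are handled, so this detail, like the helper name `tercile_labels`, is an assumption.

```python
import numpy as np

def tercile_labels(scores):
    """Return (subject indices, labels): top third -> class 0, bottom third -> class 1.

    The middle third is dropped. Ties are broken by subject order (stable sort).
    """
    scores = np.asarray(scores, dtype=float)
    k = len(scores) // 3
    order = np.argsort(scores, kind="stable")          # ascending scores
    bottom, top = order[:k], order[len(scores) - k:]
    indices = np.concatenate([top, bottom])
    labels = np.concatenate([np.zeros(k, dtype=int), np.ones(k, dtype=int)])
    return indices, labels

idx, y = tercile_labels([5.0, 1.0, 9.0, 3.0, 7.0, 2.0])
print(idx.tolist(), y.tolist())   # [4, 2, 1, 5] [0, 0, 1, 1]
```

With 812 subjects, this particular rule keeps 812 // 3 = 270 subjects per class, which yields balanced classes by construction.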

Baselines. We outline the baselines used in our experiments.
Grad-Indi [7]. This method obtains the edge importance mask for each individual graph from a trained GNN model. In contrast to the gradient information (Strategy S3) proposed in Section 3.2.2, a gradient map of each sample is generated only for the predicted class $\hat{y}_i$: $\mathbf{T}_i = \nabla f_{\hat{y}_i}(G_i) \odot \nabla f_{\hat{y}_i}(G_i)$ [7]. The edge importance mask $\mathbf{M}_i$ for $G_i$ is then generated based on Equation (2).
Grad-Joint.We adapt Grad-Indi [7] to incorporate our proposed strategies (S1+S3) and learn an edge importance mask shared by all graphs from a trained GNN model.Specifically, we leverage the method described in Section 3.2.2 that generates the joint gradient map to obtain the joint importance mask.
Grad-Trained. We further modify Grad-Indi [7] to train the joint edge mask concurrently with the GNN (S2). We also use the joint gradient map (Section 3.2.2) to initialize the edge importance mask (Strategies S1+S2+S3). Grad-Trained differs from IGS in two ways: (1) it does not require the edge mask to be symmetric; (2) it does not require the edge mask to be sparse (no $\ell_1$ regularization).
GNNExplainer-Indi [72]. This method trains an edge importance mask for each individual graph after the GNN model is trained. We follow the code provided by [40].
GNNExplainer-Joint. Adapted from [72], this model trains a joint edge importance mask for all graphs (Strategy S1).
GNNExplainer-Trained. Adapted from [72], this method simultaneously trains a joint edge importance mask and the GNN model (Strategies S1+S2). Compared with IGS, this method does not use gradient information.
BrainNNExplainer [18]. This method (also known as IBGNN) trains a joint edge importance mask for all graphs after the GNN is trained. It differs slightly from GNNExplainer-Joint in its objective function. We follow the original setup in [18].
BrainGNN [38].This method does not explicitly perform the graph sparsification task, but uses node pooling to identify important subgraphs.It learns to preserve important nodes and all the connections between them.We follow the original setup in [38].
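The difference between the per-graph maps of Grad-Indi and the single joint map used by Grad-Joint can be illustrated with a short NumPy sketch. As before, `grad_maps[i, j]` is an assumed precomputed array holding $\nabla f_j(G_i)$, and `predicted[i]` holds the predicted class $\hat{y}_i$; the random arrays are stand-ins, not real gradients. Turning either map into a binary mask would follow the same lowest-score removal as in the iterative framework.

```python
import numpy as np

def grad_indi_maps(grad_maps, predicted):
    """Grad-Indi: one map per graph, T_i = g_i ⊙ g_i with g_i the predicted-class gradient."""
    g = grad_maps[np.arange(grad_maps.shape[0]), predicted]    # (N_train, N, N)
    return g * g

def grad_joint_map(grad_maps):
    """Grad-Joint: a single map, the sum of |gradient| over all graphs and classes."""
    return np.abs(grad_maps).sum(axis=(0, 1))                   # (N, N)

rng = np.random.default_rng(1)
grad_maps = rng.normal(size=(4, 3, 6, 6))   # 4 graphs, 3 classes, 6 nodes (random stand-ins)
predicted = np.array([0, 2, 1, 2])
indi = grad_indi_maps(grad_maps, predicted)
joint = grad_joint_map(grad_maps)
# Grad-Indi ranks edges separately for every graph; Grad-Joint ranks them once for all graphs.
print(indi.shape, joint.shape)   # (4, 6, 6) (6, 6)
```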

Training Setup.
To fairly evaluate different methods under the iterative framework, we adopt the same GNN architecture [34], hyper-parameter settings, and training framework for all methods. We set the number of convolutional layers to four, the dimension of the hidden layers to 256, the dropout rate to 0.5, the batch size to 16, the optimizer to Adam, the learning rate to 0.001, and the regularization coefficient $\lambda$ to 0.0001. Note that although we use the GNN from [34], IGS is model-agnostic; we provide results with other backbone GNNs in Table 4. For each prediction task, we shuffle the data and take four different data splits, with a train/val/test ratio of 0.7/0.15/0.15. To reduce the influence of class imbalance, we manually ensure that each split has balanced labels. In each iteration, we adopt early stopping [53] with a patience of 100 epochs, i.e., we stop training if the validation loss has not decreased within the latest 100 epochs. We fix the removal ratio $p\%$ to 5% per iteration. In the iterative sparsification, we run a total of 55 iterations and use the validation loss of the sparsified graphs as the criterion to select the best iteration (Step 3 in Figure 2). We report the average and standard deviation of the test accuracies over the four splits, using the model obtained from the best iteration. The code is available at https://github.com/motivationss/IGS.git.
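Three pieces of this protocol are easy to state in code: patience-based early stopping within an iteration, choosing the best iteration by validation loss, and the number of edges left after $t$ iterations. The sketch below uses hypothetical helper names and assumes that the 5% ratio is applied to the edges remaining at each iteration of a 100-node graph (4,950 undirected edges), rounding the number of removed edges down.

```python
import math

def early_stopping_epoch(val_losses, patience=100):
    """Return (best_epoch, best_loss), stopping once `patience` epochs pass without improvement."""
    best_loss, best_epoch = math.inf, -1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            break
    return best_epoch, best_loss

def best_iteration(val_losses_per_iteration):
    """Index of the sparsification iteration with the lowest validation loss."""
    return min(range(len(val_losses_per_iteration)),
               key=val_losses_per_iteration.__getitem__)

def remaining_edges(t, num_nodes=100, p=0.05):
    """Undirected edges left after t iterations, removing floor(p * remaining) each time."""
    edges = num_nodes * (num_nodes - 1) // 2
    for _ in range(t):
        edges -= int(p * edges)
    return edges

print(early_stopping_epoch([3.0, 2.0, 2.5, 2.6, 2.7], patience=2))   # (1, 2.0)
print(best_iteration([0.70, 0.61, 0.65]))                          # 1
print(remaining_edges(0), remaining_edges(1))                      # 4950 4703
```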

(Q1-Q3) Graph Classification under the Iterative Framework
In Table 2, we compare IGS with the eight baselines described in Section 4.1.2. The first row lists the prediction tasks; the second row shows the performance on the original graphs, averaged across four different splits; and the remaining rows show the performance of the other methods. For easier comparison across methods, the last column shows the average rank of each method. We make the following observations from Table 2.
First, learning a joint mask leads to better performance than learning a separate mask for each graph. This can be seen by comparing GNNExplainer-Joint with GNNExplainer-Indi, and Grad-Joint with Grad-Indi. The performance gap within each pair is notable and consistent across all prediction tasks. Grad-Joint (rank: 4.33) outperforms Grad-Indi (rank: 7.67) by a considerable margin, and GNNExplainer-Joint (rank: 4.67) ranks significantly higher than GNNExplainer-Indi (rank: 8.33). Using a joint mask instead of individual masks provides up to a 6.7% boost in accuracy, validating our intuition (Section 3.2.2) that a joint mask is more robust to sample-wise noise.
Table 2: Test accuracies of different approaches on six prediction tasks (PicVocab, ReadEng, PicSeq, ListSort, CardSort, and Flanker) across four data splits generated with different random seeds. We report the mean and standard deviation for each task. The first row denotes the performance on the original graphs using GCN [34]; the last column denotes the average rank of each method. The best result is marked in bold.
Second, training the mask and the GNN model simultaneously yields better results than obtaining the mask from the trained model. This can be seen by comparing the Trained and Joint variants of Grad and GNNExplainer. Switching from post-training to joint training provides up to a 3.4% performance improvement, as demonstrated by the two GNNExplainer variants on the ReadEng task. Even though the post-training approach outperforms the trained approach on some tasks (e.g., Grad-Joint vs. Grad-Trained on PicVocab), the trained approach has a better (lower) average rank than the post-training approach (e.g., 3.83 vs. 4.33 for Grad and 3.17 vs. 4.67 for GNNExplainer). In addition, the better performance of IGS over BrainNNExplainer also demonstrates the effectiveness of obtaining the edge mask during training rather than after training.

Third, incorporating gradient information helps improve classification performance.We can see this by first comparing the performance of Grad-Joint and Grad-Trained against the original graphs.The use of gradient information can provide up to 5.1% higher accuracy, though the improvement depends on the task.Furthermore, since the main difference between GNNExplainer-Trained and IGS lies in the use of gradient information, the consistent superior performance of IGS strengthens this conclusion.
Fourth, we compare the performance of the baselines against the performance on the original graphs (second row). Grad-Indi [7] and GNNExplainer-Indi [72] are implementations that faithfully follow their original formulations or are provided directly by the authors. These two approaches fail to achieve any performance improvement through iterative sparsification, except for Grad-Indi on PicVocab and ReadEng. This raises the question of whether these existing instance-level approaches can identify the most meaningful edges in noisy graphs; they may be vulnerable to severe sample-wise noise. In contrast, with our suggested modifications, the joint and trained versions can remove noise and provide up to a 5.1% performance boost compared to the base GCN method applied to the original graphs. However, the improvement is task-dependent. For instance, GNNExplainer-Trained provides decent performance boosts on PicVocab, ReadEng, and Flanker, but degrades performance on PicSeq, ListSort, and CardSort.
Finally, our proposed approach, IGS, achieves the best overall performance across the prediction tasks, as demonstrated by its highest rank among all methods. Compared with the original graphs, IGS provides a consistent performance boost across all prediction tasks except ListSort, a challenging task on which no baseline surpasses the performance on the original graphs. Furthermore, using the sparsified graphs identified by IGS generally results in lower variance in accuracy, and thus better stability, than using the original graphs, except on the PicSeq task. In addition, the superior performance of IGS over BrainGNN demonstrates the effectiveness of using edge importance masks as opposed to node pooling.
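For readers reproducing the last column of Table 2, an average rank of this kind is the mean over tasks of each method's per-task rank, with rank 1 for the highest accuracy. The minimal sketch below uses made-up accuracies (not the values in Table 2) and a hypothetical helper name; ties are broken by listing order here, which may differ from how the table handles them.

```python
import numpy as np

def average_ranks(accuracies):
    """accuracies: dict mapping method -> list of per-task accuracies. Rank 1 = best."""
    methods = list(accuracies)
    scores = np.array([accuracies[m] for m in methods])        # (num_methods, num_tasks)
    ranks = np.empty(scores.shape, dtype=int)
    for task in range(scores.shape[1]):
        order = np.argsort(-scores[:, task], kind="stable")    # best method first
        ranks[order, task] = np.arange(1, len(methods) + 1)
    return {m: float(ranks[i].mean()) for i, m in enumerate(methods)}

toy = {"A": [0.70, 0.60], "B": [0.65, 0.62], "C": [0.60, 0.58]}   # hypothetical numbers
print(average_ranks(toy))   # {'A': 1.5, 'B': 1.5, 'C': 3.0}
```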
Graph Sparsity. In Table 3, we present the final average sparsity of the graphs obtained by IGS over the four data splits. With significantly fewer edges retained, IGS still achieves up to a 5.1% performance boost.

(Q4) Interpretability of IGS
We now evaluate the interpretability of the edge masks derived for each of our prediction tasks.
Setup. We assign anatomical labels to each of the 100 components comprising the nodes of our brain networks by computing the largest overlap with the regions identified in the Cole-Anticevic parcellation [30]. We then obtain the edge masks from the best-performing iteration of each prediction task and assess the highest-weighted edges in each mask.
Results. Since our IGS model performed best in the language-related prediction tasks, ReadEng and PicVocab, we focus our interpretability analysis on this domain. There is ample evidence in the neuroscience literature supporting the existence of an intrinsic language network that is perceptible during resting state [11,36,59]; thus, it is unsurprising that our rs-fMRI-based brain networks are predictive of language task performance. It has also been well established for over a century that the language centers (including Broca's area, Wernicke's area, the angular gyrus, etc.) are characteristically left-lateralized in the brain [12,65]. In both ReadEng and PicVocab, the majority of the highest-weighted edges retained in the masks involve brain regions localized to the left hemisphere, in line with the expectations for a language task.
PicVocab. Figures 3 and 4 depict the progression of the edge masks at both the node and subnetwork levels over the training iterations toward the optimal edge mask for the ReadEng and PicVocab tasks. Evaluating the edge masks at the subnetwork level offers valuable insights into which functional connections are most important for the prediction of each task. The PicVocab edge mask homed in on functional connections involving the Cingulo-Opercular (CO) network, specifically between CO and the Dorsal Attention (DA), Visual1 (V1), Visual2 (V2), and Frontoparietal (FP) networks. The CO network has been shown to be implicated in word recognition [60], and its synchrony with the other brain networks identified here may represent the stream of neural processing related to the PicVocab task, in which subjects hear an audio recording of a word and are prompted to choose the image that best represents it. Connectivity between the Auditory (AD) and V2 networks is also evident in the PicVocab edge mask, suggesting that the upstream integration of auditory and visual stimuli involved in the PicVocab task is also predictive of task performance.
ReadEng. The IGS model also found edge mask connections between the V1 network and the CO, Language (LA), and DA networks, as well as CO-LA and CO-AD connections, to be most predictive of ReadEng performance. This task involves the subject reading aloud words presented on a screen. Our results suggest that the ability of V1 to integrate with networks responsible for language processing (LA and CO) and attention (DA), as well as the capacity for functional synchrony between the language-related networks (CO-LA), would be predictive of overall ReadEng performance. The importance of the additional CO-AD connectivity identified by our model also suggests that the ability of the CO language network to integrate with auditory centers may be involved in the neural processes responsible for the proper pronunciation of words given by visual cues.
Key take-aways. Overall, in addition to the IGS model's superior classification performance, our results suggest that the iterative pruning of the IGS edge masks during training does indeed retain important and neurologically meaningful edges while removing noisy connections. While it has been shown in the literature that resting-state connectivity can be used to predict task performance [6,32,46], the ability of the IGS model to sparsify the resting-state brain graph down to clearly task-relevant edges for predicting task performance further underscores the interpretability of the resulting edge masks.
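The setup steps of this analysis (largest-overlap label assignment, picking the highest-weighted edges, and summarizing them at the subnetwork level) can be sketched in NumPy. All inputs below are hypothetical toy stand-ins: `component_masks` is a boolean voxel mask per ICA component (thresholding the continuous component maps is an assumption), `parcel_labels` gives a network index per voxel, and `mask` plays the role of a learned symmetric edge mask. The helper names are illustrative, not taken from the IGS code.

```python
import numpy as np
from collections import Counter

def assign_labels(component_masks, parcel_labels):
    """For each component (row of a boolean voxel mask), pick the label with the largest overlap."""
    num_labels = int(parcel_labels.max()) + 1
    return np.array([
        np.bincount(parcel_labels[m], minlength=num_labels).argmax()
        for m in component_masks
    ])

def top_edges(mask, k):
    """Return the k highest-weighted undirected edges (i, j, weight) of a symmetric mask."""
    rows, cols = np.triu_indices(mask.shape[0], k=1)
    weights = mask[rows, cols]
    best = np.argsort(-weights, kind="stable")[:k]
    return [(int(rows[e]), int(cols[e]), float(weights[e])) for e in best]

def subnetwork_pair_counts(edges, node_network):
    """Count how many of the given edges fall between each unordered pair of subnetworks."""
    return Counter(
        tuple(sorted((str(node_network[i]), str(node_network[j])))) for i, j, _ in edges
    )

parcel_labels = np.array([0, 0, 1, 1, 2, 2])             # network index per voxel (toy)
component_masks = np.array([
    [True, True, True, False, False, False],            # overlaps network 0 most
    [False, False, True, True, True, False],            # overlaps network 1 most
    [False, False, False, False, True, True],           # overlaps network 2 most
])
names = np.array(["CO", "DA", "V1"])
node_network = names[assign_labels(component_masks, parcel_labels)]
mask = np.array([[0.0, 0.9, 0.2],
                 [0.9, 0.0, 0.7],
                 [0.2, 0.7, 0.0]])
edges = top_edges(mask, k=2)
print(edges)                                           # [(0, 1, 0.9), (1, 2, 0.7)]
print(subnetwork_pair_counts(edges, node_network))
```

A hemisphere breakdown like the one reported above would follow the same pattern, with a per-node hemisphere label in place of `node_network`.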

Graph Explainability
Our work is related to explainable GNNs, given that we identify important edges/subgraphs that account for the model predictions. Some explainable GNNs are "perturbation-based", where the goal is to investigate the relation between output and input variations. GNNExplainer [72] learns a soft mask for the nodes and edges, which explains the predictions of a well-trained GNN model. SubgraphX [75] explains predictions by efficiently exploring different subgraphs with a Monte Carlo tree search. Another approach to explainable GNNs is surrogate-based; methods in this category generally construct a simple and interpretable surrogate model to approximate the output of the original model in certain neighborhoods [74]. For instance, GraphLIME [29] considers the N-hop neighboring nodes of the target node and then trains a nonlinear surrogate model to fit the local neighborhood predictions; RelEx [76] first uses a GNN to fit the BFS-generated datasets and then generates soft masks to explain the predictions; PGM-Explainer [64] generates local datasets based on the influence of randomly perturbing the node features, shrinks the size of the datasets via the Grow-Shrink algorithm, and employs a Bayesian network to fit the datasets. In general, most of these methods focus on the node classification task and make explanations for a single graph, which is not applicable to our setting. Others only apply to simple graphs and cannot handle signed and weighted brain graphs [29,75]. Additionally, most methods generate explanations after a GNN is trained. Although some methods achieve decent results on explainability-related metrics (e.g., fidelity scores [51]), it remains unclear whether their explanations can remove noise and retain the "important" part of the original graph in a way that improves classification accuracy.

Graph Sparsification
Compared to explainable GNN methods, graph sparsification methods explicitly aim to sparsify graphs. Most existing methods are unsupervised [77]. Conventional methods reduce the size of a graph by approximating pairwise distances [50], preserving various kinds of graph cuts [33] or node degree distributions [19,63], or using graph-spectrum-based approaches [2,15,16]. These methods aim to preserve the structural information of the original input graph without using label information, and they assume that the input graph is unweighted. Fewer supervised methods have been proposed. For example, NeuralSparse [77] builds a parameterized network to learn a k-neighbor subgraph by limiting each node to at most k edges. Building on NeuralSparse, PTDNet [43] removes the k-neighbor assumption and instead imposes a low-rank constraint on the learned subgraph to discourage edges connecting multiple communities. Graph Condensation [31] parameterizes the condensed graph structure as a function of condensed node features and optimizes a gradient-matching training objective. Despite the new insights these methods offer, most of them focus exclusively on node classification, and their training objectives are built around that task. The work most similar to our proposed method, IGS, is BrainNNExplainer [18] (also known as IBGNN). It is inspired by GNNExplainer [72] and obtains a joint edge mask in a post-training fashion. In contrast, IGS trains a joint edge mask along with the backbone model and incorporates gradient information in an iterative manner. Another line of work leverages node pooling to identify important subgraphs, learning to preserve important nodes and all the connections between them. One representative work is BrainGNN [38]. However, the connections between preserved nodes are not necessarily all informative, and some may contain noise.
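The k-neighbor constraint used by NeuralSparse can be illustrated without its learned sampler. The sketch below replaces the learned edge selection with a deterministic stand-in that keeps, for each node, the k incident edges with the largest absolute weight (our assumption, chosen so that strong negative correlations in signed brain graphs can survive). Rows are handled independently, so every row ends up with at most k nonzero entries, but the result may be asymmetric. The function name `top_k_per_node` is hypothetical.

```python
def top_k_per_node(adj, k):
    """For every node u, keep the k edges in row u with the largest |weight|.

    Deterministic stand-in for a learned k-neighbor sampler: NeuralSparse learns
    which edges to sample, whereas this sketch simply ranks by edge magnitude.
    Rows are processed independently, so the output may be asymmetric.
    """
    n = len(adj)
    out = [[0.0] * n for _ in range(n)]
    for u in range(n):
        neighbors = [v for v in range(n) if v != u and adj[u][v] != 0]
        # Largest magnitude first; ties broken by the smaller node index.
        neighbors.sort(key=lambda v: (-abs(adj[u][v]), v))
        for v in neighbors[:k]:
            out[u][v] = adj[u][v]
    return out


adj = [
    [0.0, 0.2, -0.9, 0.4],
    [0.2, 0.0, 0.5, 0.1],
    [-0.9, 0.5, 0.0, 0.3],
    [0.4, 0.1, 0.3, 0.0],
]
sparse = top_k_per_node(adj, k=1)
for row in sparse:
    print(row)
```

Note that this kind of per-node selection, like the unsupervised methods above, uses no label information; the supervised methods differ in how the ranking itself is learned.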

Saliency Maps
Saliency maps were first proposed to explain deep convolutional neural networks in image classification tasks [56]. Specifically, that method uses the gradients backpropagated from the predicted class as the explanation. More recently, [7] introduced saliency maps to graph neural networks, using squared gradients to explain the underlying model, and [4] suggested using graph saliency to identify regions of interest (ROIs). In general, the gradients backpropagated from the output logits can serve as indicators of input importance for model predictions. Inspired by this line of saliency-related work, we leverage gradient information to guide our model.
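A squared-gradient saliency map over the edges of a weighted graph can be sketched as follows. In practice the gradient of the class logit with respect to the adjacency matrix would come from automatic differentiation; here, to stay dependency-free, it is estimated with central finite differences, which is a swapped-in numerical technique rather than what [7] does. The logit function `score` is a hypothetical one-layer linear GNN (aggregate neighbor features, project, sum-pool), for which the exact gradient with respect to entry (u, v) is X[v]·w, so the estimate can be checked.

```python
def score(adj, feats, weights):
    """Hypothetical class logit: sum over u, v of adj[u][v] * (feats[v] . weights)."""
    n = len(adj)
    total = 0.0
    for u in range(n):
        for v in range(n):
            total += adj[u][v] * sum(x * w for x, w in zip(feats[v], weights))
    return total


def squared_gradient_saliency(f, adj, eps=1e-4):
    """Squared d f / d adj[u][v], estimated by central finite differences."""
    n = len(adj)
    sal = [[0.0] * n for _ in range(n)]
    for u in range(n):
        for v in range(n):
            plus = [row[:] for row in adj]
            minus = [row[:] for row in adj]
            plus[u][v] += eps
            minus[u][v] -= eps
            grad = (f(plus) - f(minus)) / (2 * eps)
            sal[u][v] = grad ** 2
    return sal


adj = [[0.0, 1.0, 0.5], [1.0, 0.0, 0.0], [0.5, 0.0, 0.0]]
feats = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
weights = [0.5, -1.0]
sal = squared_gradient_saliency(lambda a: score(a, feats, weights), adj)
for row in sal:
    print([round(x, 4) for x in row])
```

Squaring discards the sign of the gradient, so an edge whose removal would strongly increase or strongly decrease the logit is treated as equally salient.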

CONCLUSIONS
In this paper, we studied neural-network-based graph sparsification for brain graphs. By introducing an iterative sparsification framework, we identified several effective strategies for GNNs to filter out noisy edges and improve the graph classification performance. We combined these strategies into a new interpretable graph classification model, IGS, which improves the graph classification performance by up to 5.1% with 55% fewer edges than the original graphs. The retained edges identified by IGS provide neuroscientific interpretations and are supported by well-established literature.

Figure 1: General iterative framework of sparsification. This framework progressively eliminates noisy edges from input brain graphs by learning an edge importance mask either for each graph separately or for all graphs jointly. The edge importance mask(s) can be generated from a well-trained GNN model or trained simultaneously with a GNN model. Important edges are depicted in orange, while noisy edges are shown in grey. Dashed lines with purple crosses represent the edges removed in the sparsified graphs.
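The loop that this framework describes (train a mask, sparsify, retrain and validate, and keep the graph set with the smallest validation loss, as in Algorithm 1) can be sketched generically. In the sketch, `mask_train` and `train_and_val` are hypothetical callables standing in for the sparsification method and for GNN training; the toy versions below (drop the weakest edge; pretend the ideal graph has four nonzero entries) exist only to make the loop runnable. Treating the unsparsified input as a candidate is our assumption.

```python
def iterative_sparsification(graphs, mask_train, train_and_val, num_iters):
    """Generic iterative sparsification loop (a sketch, not the authors' code).

    graphs:        list of n x n weighted adjacency matrices (nested lists)
    mask_train:    hypothetical callable, graphs -> list of 0/1 masks (one per graph);
                   a joint mask is the special case where all masks are identical
    train_and_val: hypothetical callable, graphs -> validation loss
    """
    # Assumption: the input graphs are also a candidate, so the loop can fall
    # back to them if no sparsified version validates better.
    best_graphs, best_loss = graphs, train_and_val(graphs)
    current = graphs
    for _ in range(num_iters):
        masks = mask_train(current)
        # Sparsify each graph by an element-wise (Hadamard) product with its mask.
        current = [
            [[a * m for a, m in zip(row_a, row_m)] for row_a, row_m in zip(adj, mask)]
            for adj, mask in zip(current, masks)
        ]
        loss = train_and_val(current)
        if loss < best_loss:
            best_graphs, best_loss = current, loss
    return best_graphs, best_loss


def drop_weakest_edge(graphs):
    """Toy mask_train: in every graph, mask out the remaining edge with smallest |weight|."""
    masks = []
    for adj in graphs:
        n = len(adj)
        edges = [(abs(adj[u][v]), u, v)
                 for u in range(n) for v in range(u + 1, n) if adj[u][v] != 0]
        mask = [[1.0] * n for _ in range(n)]
        if edges:
            _, u, v = min(edges)
            mask[u][v] = mask[v][u] = 0.0
        masks.append(mask)
    return masks


def toy_val_loss(graphs):
    """Toy train_and_val: pretend the best graphs have exactly 4 nonzero entries."""
    nonzero = sum(1 for adj in graphs for row in adj for w in row if w != 0)
    return abs(nonzero - 4)


g = [[0.0, 0.9, 0.1], [0.9, 0.0, 0.5], [0.1, 0.5, 0.0]]
best, loss = iterative_sparsification([g], drop_weakest_edge, toy_val_loss, num_iters=3)
print(loss, best[0])
```

Selecting by validation loss rather than by the last iteration is what prevents over-pruning: later iterations in the toy run remove too many edges and are discarded.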

Figure 2: Training process of IGS. At iteration i, IGS takes a set of input graphs and initializes its joint edge importance mask using the joint gradient map from the previous iteration. It trains the GNN model and the edge importance mask together, then sparsifies all input graphs using the obtained mask. Normal training is then conducted on the sparsified graphs, after which gradient information is extracted by computing a joint gradient map. Finally, IGS feeds the sparsified graphs to the next iteration and uses the joint gradient map to initialize the subsequent joint edge importance mask. IGS is model-agnostic and can be integrated with existing GNN models.
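The caption does not specify how per-graph gradients are combined into the joint gradient map. The sketch below assumes the element-wise mean of squared gradients (echoing the squared-gradient saliency of [7]) over all training graphs, followed by a hypothetical rescaling to [0, 1] before it seeds the next mask. Both choices, and the function names, are illustrative assumptions; the per-graph gradients d loss_i / d A_i would come from backpropagation in practice and are simply passed in here.

```python
def joint_gradient_map(per_graph_grads):
    """Element-wise mean of squared gradients over all training graphs.

    per_graph_grads: list of n x n matrices, one gradient d loss_i / d A_i per
    training graph. Squaring and averaging is an assumption for illustration.
    """
    m = len(per_graph_grads)
    n = len(per_graph_grads[0])
    return [
        [sum(g[u][v] ** 2 for g in per_graph_grads) / m for v in range(n)]
        for u in range(n)
    ]


def init_mask_from_gradients(joint_map):
    """Hypothetical initialization: rescale the joint map to [0, 1]."""
    hi = max(max(row) for row in joint_map)
    if hi == 0:
        return [[0.0 for _ in row] for row in joint_map]
    return [[x / hi for x in row] for row in joint_map]


grads = [
    [[0.0, 1.0], [1.0, 0.0]],    # gradient for training graph 1
    [[0.0, -3.0], [-3.0, 0.0]],  # gradient for training graph 2
]
jm = joint_gradient_map(grads)
print(jm, init_mask_from_gradients(jm))
```

Because the map is averaged over every training graph, it yields a single matrix, which matches the single joint edge importance mask that IGS shares across graphs.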

Figure 3: Weighted brain network edge masks at the node level (top row) and the subnetwork level (bottom row, computed as the average of the corresponding edges) for the PicVocab task. Early, middle, and final phases of training are depicted from left to right, and high-importance subnetworks are highlighted in red. We find that IGS gradually removes noisy edges and that its final edge importance mask provides high-quality interpretations. Highlighted (orange) label names represent the regions that are meaningful for this task. Brain network labels and abbreviations: Auditory (AD), Cingulo-Opercular (CO), Dorsal Attention (DA), Default (DF), Frontoparietal (FP), Language (LA), Somatomotor (SO), Visual 1 (V1), Visual 2 (V2).

Figure 4: Weighted brain network edge masks at the node level (top row) and the subnetwork level (bottom row) for the ReadEng task, following the same setup as in Section 4.3.

Table 1: Brain graph classification performance (accuracy) on the original graphs (Original) and sparsified graphs (Direct thresholding). Direct thresholding may keep unimportant edges. Details about the data and experimental setup can be found in Section 4.1.
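The direct-thresholding baseline can be sketched as keeping a fixed fraction of the strongest edges in each graph. The caption does not state the exact rule; the sketch assumes edges are ranked per graph by absolute weight over the upper triangle (so a symmetric input stays symmetric), with a default keep ratio of 50% to mirror the initial sparsity reported in Table 3. The function name `direct_threshold` is hypothetical.

```python
def direct_threshold(adj, keep_ratio=0.5):
    """Keep the keep_ratio fraction of off-diagonal edges with the largest |weight|.

    Ranks the upper triangle by magnitude (so strong negative correlations can
    survive; an assumption) and mirrors the result to keep the graph symmetric.
    The number of kept edges is rounded down.
    """
    n = len(adj)
    pairs = [(u, v) for u in range(n) for v in range(u + 1, n)]
    pairs.sort(key=lambda p: -abs(adj[p[0]][p[1]]))
    num_keep = int(keep_ratio * len(pairs))
    out = [[0.0] * n for _ in range(n)]
    for u, v in pairs[:num_keep]:
        out[u][v] = adj[u][v]
        out[v][u] = adj[v][u]
    return out


adj = [
    [0.0, 0.9, -0.8, 0.1],
    [0.9, 0.0, 0.2, -0.05],
    [-0.8, 0.2, 0.0, 0.6],
    [0.1, -0.05, 0.6, 0.0],
]
for row in direct_threshold(adj):
    print(row)
```

The rule looks only at edge magnitudes and ignores the labels, which is why, as the caption notes, it may keep edges that are strong but irrelevant to the classification task.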
Each graph $G_i(\mathcal{V}, \mathcal{E}_i) \in \mathcal{G}$ in this set has $n$ nodes, and the corresponding node set and edge set are denoted as $\mathcal{V}$ and $\mathcal{E}_i$, respectively. The graphs share the same set of nodes. The set of neighboring nodes of node $v$ is denoted as $\mathcal{N}_v$. We focus on the setting where the input graphs are weighted, and we represent the weighted adjacency matrix of each input graph $G_i$ as $\mathbf{A}_i \in \mathbb{R}^{n \times n}$. The node features in $G_i$ are represented by a matrix $\mathbf{X}_i \in \mathbb{R}^{n \times d}$, whose $v$-th row $\mathbf{X}_i[v, :]$ represents the features of the $v$-th node, and $d$ is the dimensionality of the node features. For conciseness, we use $\mathbf{X}$ …

Algorithm 1 Iterative Sparsification Framework
INPUT: Sparsification Method $f$, Input Graph Set $\mathcal{G}^{(1)}$, Graph Labels $Y$, Training Set Index $\mathbb{1}_{\text{Train}}$, Validation Set Index $\mathbb{1}_{\text{Val}}$, …
…
14: $\mathbf{M}^{(i)} = f.\text{MaskTrain}(\text{GNN}, \mathcal{G}^{(i)}, Y, \mathbb{1}_{\text{Train}})$
…
16: Validation loss $\mathcal{L}^{(i)}_{\text{val}} = \text{Train\&Val}(\text{GNN}, \mathcal{G}^{(i+1)}, Y, \mathbb{1}_{\text{Train}}, \mathbb{1}_{\text{Val}})$
OUTPUT: $\mathcal{G}^{(i)}$ with the smallest $\mathcal{L}^{(i)}_{\text{val}}$

… across all training samples in a batch-training fashion to achieve our objective of learning a shared mask. Subsequently, we convert this soft mask into an indicator matrix by assigning zero values to the lowest $p\%$ of its elements:

$$\mathbf{M}[u, v] = \begin{cases} 0, & \text{if } \Phi[u, v] \text{ is among the lowest } p\% \text{ of the entries of } \Phi, \\ 1, & \text{otherwise.} \end{cases}$$

3.2.3 Algorithm. We incorporate these three strategies into IGS and outline our method in Algorithm 2.

Algorithm 2 Interpretable Graph Sparsification: IGS
INPUT: Input Graph Dataset $\mathcal{G}^{(1)}$, Training Set Index $\mathbb{1}_{\text{Train}}$, Validation Set Index $\mathbb{1}_{\text{Val}}$, Removing Percentage $p$, Number of Iterations $N$, GNN model, Regularization Coefficient $\lambda$
for i = 1, ..., N do
    // Step 1: GNN Training with Edge Importance Mask
    if i == 1 then
        Initialize Φ using Xavier normal initialization.
    else
        Initialize Φ using the previous joint gradient map $\mathbf{T}^{(i-1)}$.
    (Φ + Φ) ← Train(GNN, $\mathcal{G}^{(i)}$, $Y$, $\mathbb{1}_{\text{Train}}$, $\lambda$) (Equation (1))
    …
    // Step 3: Normal Training with Sparsified Graphs
    Validation loss $\mathcal{L}^{(i)}_{\text{val}}$, GNN_Trained = Train&Val(GNN, $\mathcal{G}^{(i+1)}$, $Y$, $\mathbb{1}_{\text{Train}}$, $\mathbb{1}_{\text{Val}}$)
    // Step 4: Leveraging Gradient Information
    …
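The binarization step described above, zeroing the lowest $p\%$ of the shared soft mask and applying the resulting indicator matrix to every graph, can be sketched as follows. Assumptions: $p$ is given as a fraction, all n x n entries are ranked together (the text says "elements"), ties are broken by position, and sparsification is an element-wise product. The function names are hypothetical, and the toy mask is asymmetric only to keep the example small; a symmetric mask would keep symmetric graphs symmetric.

```python
def binarize_joint_mask(phi, p):
    """Indicator matrix: 0 for the lowest p fraction of entries of phi, 1 otherwise."""
    n = len(phi)
    # Sort (value, row, col) so that ties are broken by position.
    entries = sorted((phi[u][v], u, v) for u in range(n) for v in range(n))
    num_zero = int(p * len(entries))
    mask = [[1.0] * n for _ in range(n)]
    for _, u, v in entries[:num_zero]:
        mask[u][v] = 0.0
    return mask


def apply_joint_mask(graphs, mask):
    """Sparsify every graph with the same shared mask via an element-wise product."""
    return [
        [[a * m for a, m in zip(row_a, row_m)] for row_a, row_m in zip(adj, mask)]
        for adj in graphs
    ]


phi = [[0.1, 0.9], [0.4, 0.2]]
mask = binarize_joint_mask(phi, p=0.5)
graphs = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
print(mask, apply_joint_mask(graphs, mask))
```

Because one mask is shared, every graph loses exactly the same edge positions, which is what makes the retained connections interpretable at the population level.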

Table 3: Final sparsity of the brain graphs sparsified by IGS, averaged over different splits. The initial sparsity is 50% (obtained by thresholding). IGS removes more than half of the edges while achieving up to a 5.1% performance boost.
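To see how repeated removals compound from the 50% starting point, suppose (as a simplifying assumption; the actual per-iteration removal in IGS is governed by its removing percentage and may be counted differently) that each iteration removes a fraction p of the edges still present. After k iterations the remaining fraction, relative to the thresholded graph, is $(1-p)^k$. The helper functions below are hypothetical and the value of p is illustrative, not taken from the experiments.

```python
def remaining_fraction(p, k):
    """Fraction of starting edges left after k rounds, each removing fraction p of the current edges."""
    return (1.0 - p) ** k


def rounds_needed(p, target_removed):
    """Smallest k for which at least target_removed of the starting edges are gone."""
    if not 0.0 < p < 1.0:
        raise ValueError("p must lie strictly between 0 and 1")
    if not 0.0 <= target_removed < 1.0:
        raise ValueError("target_removed must lie in [0, 1)")
    k = 0
    while 1.0 - remaining_fraction(p, k) < target_removed:
        k += 1
    return k


# Illustrative only: with p = 0.1, how many rounds remove at least half of the edges?
k = rounds_needed(0.1, 0.5)
print(k, remaining_fraction(0.1, k))
```

Under this assumption, removing more than half of the edges needs $k$ with $(1-p)^k < 0.5$, i.e. $k > \ln 0.5 / \ln(1-p)$, which for $p = 0.1$ gives 7 rounds.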

Table 4: Performance of IGS with different GNN backbones, following the same setup as in Section 4.1. IGS improves performance across backbones, i.e., its gains are model-agnostic.