COMET: Learning Cardinality Constrained Mixture of Experts with Trees and Local Search

The sparse Mixture-of-Experts (Sparse-MoE) framework efficiently scales up model capacity in various domains, such as natural language processing and vision. Sparse-MoEs select a subset of the "experts" (thus, only a portion of the overall network) for each input sample using a sparse, trainable gate. Existing sparse gates are prone to convergence and performance issues when trained with first-order optimization methods. In this paper, we introduce two improvements to current MoE approaches. First, we propose a new sparse gate: COMET, which relies on a novel tree-based mechanism. COMET is differentiable, can exploit sparsity to speed up computation, and outperforms state-of-the-art gates. Second, due to the challenging combinatorial nature of sparse expert selection, first-order methods are typically prone to low-quality solutions. To deal with this challenge, we propose a novel, permutation-based local search method that can complement first-order methods in training any sparse gate, e.g., Hash routing, Top-k, DSelect-k, and COMET. We show that local search can help networks escape bad initializations or solutions. We performed large-scale experiments on various domains, including recommender systems, vision, and natural language processing. On standard vision and recommender systems benchmarks, COMET+ (COMET with local search) achieves up to 13% improvement in ROC AUC over popular gates, e.g., Hash routing and Top-k, and up to 9% over prior differentiable gates, e.g., DSelect-k. When Top-k and Hash gates are combined with local search, we see up to $100\times$ reduction in the budget needed for hyperparameter tuning. Moreover, for language modeling, our approach improves over the state-of-the-art MoEBERT model for distilling BERT on 5/7 GLUE benchmarks as well as on the SQuAD dataset.

The literature on Sparse-MoE has traditionally focused on Top-k gating, which selects k out of n experts using a Top-k operation [14,47,58]. Top-k gating is simple and efficient because it allows sparse training. However, as highlighted by prior literature [14,20,58], the non-continuous nature of Top-k makes it susceptible to stability and convergence issues. Alternative gating strategies exist in the literature, based on reinforcement learning [4] or post-processing via linear assignment [8,34]. However, these strategies also face challenges in terms of efficiency and interpretability; see the related work in Section 2 for more details. Random routing strategies [44,59] alternatively bypass learning of the gating function altogether. Although computationally efficient, these strategies lead to performance degradation [8]. Recent work [20] demonstrates that differentiable gating in Sparse-MoE can improve stability and performance compared to popular non-differentiable gates. However, it suffers from expert collapse in some cases, as we observed in our experiments.
In this paper, we propose two new approaches for improving routing in Sparse-MoE. First, we introduce a novel differentiable sparse gate, COMET, that improves over existing state-of-the-art sparse gates [14,20,44,47,58]. Second, we argue that the combinatorial nature of expert selection in Sparse-MoE presents a serious challenge for first-order methods. In particular, the performance of these methods is highly dependent on initialization, and they can get stuck in low-quality routing solutions. Thus, we propose a new permutation-based local search method for Sparse-MoEs, which can help first-order methods escape low-quality initializations or solutions. Our local search approach is general and can be applied to any sparse gate, including Top-k [47], Hash routing [44], DSelect-k [20], and our proposed gate COMET.
COMET. Our proposed COMET gate is the first decision-tree-based selection mechanism for sparse expert selection; decision trees naturally perform per-sample routing (i.e., each sample follows a root-to-leaf path). Our gate has several advantages: (i) it is differentiable and can be optimized using first-order optimization methods, e.g., stochastic gradient descent; (ii) it allows (partially) conditional training, i.e., dense-to-sparse training; (iii) it enforces a cardinality constraint, i.e., it selects (at most) k out of the n experts; (iv) it has superior predictive performance over state-of-the-art gates such as Hash routing, Top-k, and DSelect-k.
Local Search. The learning problem underlying Sparse-MoEs is combinatorial in nature, which poses additional challenges compared to non-MoE machine learning models. Popularly used optimization methods, such as SGD, may lead to low-quality solutions in Sparse-MoE, as we demonstrate in our numerical experiments in Section 5. To this end, we propose a permutation-based local search method, which can help first-order methods escape bad initializations and lead to better sample routing for any sparse gate, e.g., Top-k, Hash routing, DSelect-k, and even COMET. To the best of our knowledge, we are the first to explore local search methods in the context of Sparse-MoE. We provide empirical evidence through ablation studies and large-scale experiments to demonstrate that permutation-based local search (i) pushes learning towards better gate/expert initializations in early optimization stages (see Section 4.4); (ii) effectively reduces the budget needed for hyperparameter tuning by up to 100× for some popular gates, e.g., Hash routing and Top-k (see Section 5.1.3); (iii) leads to SOTA performance in terms of prediction and expert selection when combined with COMET, across various applications (see Section 5).
Contributions. As discussed earlier, it is well-known in the literature that popular sparse gates are challenging to train and may suffer from stability and performance issues. In this context, our contributions can be summarized as follows: • We propose COMET, a novel tree-based sparse gate that simultaneously has the following desirable properties: (a) it is differentiable; (b) it allows (partially) conditional training, i.e., dense-to-sparse training, and sparse inference; (c) it satisfies a per-sample cardinality constraint (it selects at most k out of the n experts per sample, where k is a user-specified parameter).

RELATED WORK
Sparse-Mixture-of-Experts. The MoE framework was introduced by [27] and has since been extensively studied; see, e.g., [26,28,29]. More recently, [47] proposed a Sparse-MoE framework based on the Top-k gate and showed good performance on natural language processing tasks. It was further improved upon by [14,56,58]. However, the Top-k gate does not optimize the core expert selection problem, as pointed out by [8]. Additionally, as highlighted by prior literature [14,20,58], the non-continuous nature of Top-k makes it vulnerable to training stability and convergence issues.
With BASE Layers, [8,34] formulate Sparse-MoE as an assignment problem where the gate output is post-processed for balanced expert selection. [4] formulates expert selection as a reinforcement learning problem. Others [44,59] proposed random routing strategies that do not learn the gating function during training. These methods are also promising, as they have been shown to outperform models that learn routing through Top-k, e.g., in Switch Transformers [14]. Lastly, [20] introduced DSelect-k, a differentiable gate based on binary encodings, which improves over Top-k in terms of stability and statistical performance.
Conditional Computation. In addition to the Sparse-MoE framework, other related works also study conditional computation, i.e., the setup where only some parts of the neural network are activated based on the input; see, e.g., [4,5,23,50]. These works rely on heuristics where the training and inference models are different. More recently, [19] introduced conditional computation in differentiable (a.k.a. soft) trees [15,19,21,22,24,27]. Their proposal allows routing samples through small parts of the tree, thus allowing for conditional computation with customized algorithms. Our work builds upon this approach to solve the cardinality-constrained expert selection problem in Sparse-MoE. Note that [19] does not address sparse expert selection in Sparse-MoE.
Local Search and Permutation Learning. There is an extensive optimization literature on local search, e.g., [3,18]. However, such methods have not been used in Sparse-MoE. Here, we survey the permutation learning methods that are most relevant to our proposal. This work uses differentiable relaxations of permutations via Sinkhorn operators [1,37]. These earlier works use such relaxations in other contexts, e.g., ranking in [1] and sorting in [37]. We use permutation learning as a local search to complement first-order optimization methods and improve sample routing in Sparse-MoE.

LEARNING SPARSE MIXTURE OF EXPERTS WITH DECISION TREES
Problem Setup of Sparse-MoE. We first review the Sparse-MoE objective. We assume that the task has an input space X ⊆ R^p and an output space Y ⊆ R^u. Denote the n-dimensional simplex by Δ^n = {w ∈ R^n : w ≥ 0, Σ_{i∈[n]} w_i = 1}. The goal of the Sparse-MoE paradigm is to develop a gate that selects a convex combination of at most k out of the n experts f_1, …, f_n. The output of the gate can be thought of as a probability vector w(x) with at most k nonzero entries, where w(x)_i is the weight assigned to expert f_i. The underlying optimization problem (also in [20]) is:

min_{gate, experts}  (1/N) Σ_{j=1}^{N} ℓ( y_j, Σ_{i∈[n]} w(x_j)_i f_i(x_j) )    (1a)
s.t.  w(x_j) ∈ Δ^n,  ∥w(x_j)∥_0 ≤ k,  for all j ∈ [N],    (1b)

where ∥w(·)∥_0 denotes the number of nonzero entries in the vector w(·), ℓ(·, ·) is the associated loss function, and N denotes the number of training samples D = {(x_j, y_j) ∈ X × Y}_{j=1}^{N}. The cardinality constraint in (1b) ensures that the gate selects at most k experts. Some popular gates, e.g., Top-k, impose an exact cardinality constraint instead of the inequality constraint in (1b). However, the inequality constraint can allow for sparser expert selection, as observed in prior work [20] and in our experiments (Section 5). Problem (1) is a combinatorial optimization problem that is not amenable to stochastic gradient descent due to the cardinality constraint in (1b). In the next sections, we discuss our formulation, which ensures that the cardinality and simplex constraints are satisfied despite optimization with gradient-based methods.
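To make the objective concrete, the following NumPy sketch (illustrative only; the function and variable names are ours, not part of the paper's implementation) combines expert outputs with a sparse gate vector and checks the constraints in (1b):

```python
import numpy as np

def moe_output(gate_weights, expert_outputs):
    """Combine expert outputs with sparse gate weights: y_hat = sum_i w_i * f_i(x).

    gate_weights: (n,) vector on the simplex with at most k nonzeros.
    expert_outputs: (n, out_dim) stacked expert predictions f_i(x).
    """
    return gate_weights @ expert_outputs

def feasible(gate_weights, k, tol=1e-8):
    """Check the constraints in (1b): simplex membership and ||w||_0 <= k."""
    on_simplex = np.all(gate_weights >= -tol) and abs(gate_weights.sum() - 1.0) < 1e-6
    sparse_enough = int(np.sum(gate_weights > tol)) <= k
    return bool(on_simplex and sparse_enough)

# Example: n = 4 experts, k = 2 selected.
w = np.array([0.7, 0.0, 0.3, 0.0])
f = np.arange(8.0).reshape(4, 2)  # stand-in expert outputs
y = moe_output(w, f)
```

In a real Sparse-MoE, only the experts with nonzero gate weight need to be evaluated, which is the source of the computational savings.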
The rest of this section is organized as follows. In Section 3.1, we give a high-level overview of our novel tree-based framework, which equivalently sets up the cardinality-constrained objective in problem (1) as a weighted sum of decision trees. Next, in Section 3.2, we provide background on a single decision tree that selects a single expert per sample while (i) allowing for smooth optimization, and (ii) supporting conditional computation, i.e., routing samples to a single leaf. In Section 3.3, we dive deeper into our novel tree-based framework, which combines such trees to satisfy the cardinality constraint for k ≥ 1 without violating the simplex constraint. We additionally highlight important aspects of leaf parameterization and regularization. Next, in Section 3.4, we discuss how our method handles settings where the number of experts is not a power of 2. We then discuss in Section 3.5 an implementation of COMET for numerically stable training.

Sparse-MoE with 𝑘 decision trees
The cardinality-constrained MoE objective (1) can be formulated equivalently using a set of decision trees. Classical decision trees are naturally suited to route each sample to a single leaf via a chain of hierarchical decisions. In the case of k = 1, we propose a single decision tree to route samples, where each leaf node is associated with an expert. In cases where k > 1, we instantiate k different decision trees and combine their outputs in a way that enforces the cardinality and simplex constraints in (1b).
Given that classical decision trees are not amenable to differentiable training with first-order methods, we use a variant [19] of differentiable (a.k.a. soft) decision trees [19,21,22,27,30]. We build upon this work to solve the cardinality-constrained problem (1). We first provide a summary of a single soft tree (with conditional computation support) in Section 3.2. This serves as a building block for selecting a single expert per sample.

Preliminaries: Differentiable Decision Tree with Conditional Computation
In this section, we provide a brief summary of a variant [19] of the differentiable (a.k.a. soft) tree [21,24,27,29,30], which we use to enable single-expert selection in Sparse-MoEs. We extend it in the next section to solve the cardinality-constrained problem for the general case k ≥ 1. Differentiable decision trees are similar to classical decision trees with hyperplane splits [38]. However, they route each sample both left and right in different proportions, i.e., each sample reaches all leaves. Traditionally, differentiable decision trees have not been amenable to conditional computation, as they cannot route a sample exclusively to the left or to the right. Recent work [19] introduced a variant of the differentiable tree model that supports conditional computation. Here, we give a brief summary of this variant.
We denote a single tree by T : X → Δ^n, which maps an input sample x ∈ X to a probability vector over the n leaves. Here, n corresponds to the number of root-to-leaf paths (also equal to the number of experts in the MoE paradigm). Let T be a binary tree with depth d; note that our framework naturally supports cases where the number of experts is not a power of 2, see Section 3.4 for more details. Let I and L denote the sets of internal (split) nodes and leaves of the tree, respectively. For any node i ∈ I ∪ L, we define A(i) as its set of ancestors, and let {x → i} denote the event that a sample x ∈ R^p reaches node i. Sample Routing. Following prior work [19,21,30], we discuss sample routing using a probabilistic model; although routing is described in terms of probabilities, differentiable trees are deterministic. Differentiable trees are based on hyperplane splits [38], where a linear combination of the features is used to make routing decisions. In particular, we associate a trainable weight vector a_i ∈ R^p with each internal node i, which parameterizes the node's hyperplane split. Let h : R → [0, 1] be an activation function. Given a sample x ∈ R^p, the probability that internal node i routes x to the left is defined as h(⟨a_i, x⟩).
Now we summarize how to model the probability that x reaches a certain leaf l [19,21,30]. Let [l ↙ i] (resp. [i ↘ l]) denote the event that leaf l belongs to the left (resp. right) subtree of node i ∈ I. The probability that x reaches l is given by:

Pr({x → l}) = ∏_{i ∈ A(l)} r_{i,l}(x),    (2)

where r_{i,l}(x) is the probability of node i routing x towards the subtree containing leaf l, i.e., r_{i,l}(x) = h(⟨a_i, x⟩) if [l ↙ i], and r_{i,l}(x) = 1 − h(⟨a_i, x⟩) if [i ↘ l]. Note that the vector T(x) given by (2) defines a per-sample probability distribution over the n leaves (or experts). Next, we discuss how the split probabilities {h(⟨a_i, x⟩), 1 − h(⟨a_i, x⟩)} can achieve a binary state with a particular choice of the activation function h; this is crucial for achieving sparse expert selection (and conditional computation) in the Sparse-MoE paradigm.
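The routing model above can be sketched as follows for a perfect binary tree. This is an illustrative NumPy implementation: a sigmoid stands in for the activation h, and the heap-style node layout is our own choice, not the paper's.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def leaf_probs(x, A, h=sigmoid):
    """Per-sample distribution over the leaves of a perfect binary soft tree.

    A: (num_internal, p) hyperplane weights in heap order (root = 0;
       children of node i are 2i+1 and 2i+2). num_internal = 2**d - 1.
    Returns a vector of length 2**d whose l-th entry is Pr({x -> l}),
    the product of branch probabilities along the root-to-leaf path.
    """
    num_internal = A.shape[0]
    probs = np.ones(2 * num_internal + 1)      # reach probability of every heap node
    for i in range(num_internal):
        left = h(A[i] @ x)                     # prob. node i routes x left
        probs[2 * i + 1] = probs[i] * left
        probs[2 * i + 2] = probs[i] * (1.0 - left)
    return probs[num_internal:]                # the 2**d leaves

rng = np.random.default_rng(0)
P = leaf_probs(rng.normal(size=5), rng.normal(size=(3, 5)))  # depth-2 tree, 4 leaves
```

Because every split conserves probability mass, the returned vector always lies on the simplex, matching the claim that T(x) is a per-sample distribution over leaves.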

Smooth-Step Activation Function. The common choice for the activation function h in the soft tree literature is the logistic function [15,21,29,30]. However, it cannot perform hard routing, i.e., output exact zeros. This implies that any sample x will reach every node in the tree with a positive probability, leading to a dense T(x). [19] proposed a smooth-step activation function for a variant of soft trees; see Appendix A for details. Despite being continuously differentiable, the smooth-step activation function can produce a sparse T(x) for hard routing (after an initial warm-up period of soft routing). This is crucial for sparse expert selection in the Sparse-MoE paradigm. Additionally, this choice of activation function allows for (partially) conditional training with customized sparse backpropagation algorithms in soft trees (as shown in [19]), which is an important consideration for training large-scale Sparse-MoE models.
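For concreteness, one common parameterization of the smooth-step function from [19] is sketched below. Since the text defers the exact definition to Appendix A, treat the cubic form and the width parameter gamma here as an assumption: the key property is that the function is continuously differentiable yet outputs exact 0/1 outside a finite interval, enabling hard routing.

```python
import numpy as np

def smooth_step(t, gamma=1.0):
    """Smooth-step activation (one common form, per [19]): 0 for
    t <= -gamma/2, 1 for t >= gamma/2, and a cubic interpolant in
    between, chosen so the function is continuously differentiable
    at the boundaries while saturating to exact 0/1 outside them."""
    t = np.asarray(t, dtype=float)
    mid = (-2.0 / gamma**3) * t**3 + (3.0 / (2.0 * gamma)) * t + 0.5
    return np.where(t <= -gamma / 2, 0.0, np.where(t >= gamma / 2, 1.0, mid))
```

Exact zeros in the split probabilities are what let a sample skip entire subtrees, which is the basis of conditional computation here.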
For cardinality-constrained Sparse-MoE learning with trees (not studied in [19]), the goal for each tree is to perform hard routing for all samples. Therefore, we add additional regularization on {h(⟨a_i, x⟩), 1 − h(⟨a_i, x⟩)} to encourage convergence of T(x) to a one-hot state (discussed in more detail in Section 3.3).

Cardinality constraint with 𝑘 trees
Next, we discuss how to achieve the cardinality constraint (k ≥ 1) in Sparse-MoE with decision trees in the presence of the simplex constraint. The key ideas are as follows: • We consider k decision trees, where each tree j selects a single expert via T^(j)(·) as defined in (2).
• With the experts selected as above, we need to decide the relative weights assigned to each expert. This is done through auxiliary functions g^(j)(·), where g_i^(j) is a linear weighting function (in log space) for the i-th expert (or leaf) in the j-th tree. See Figure 1 for an example. Next, we define the prediction function for Sparse-MoE with k decision trees to form COMET.
COMET Prediction with k Out of n Experts. The prediction function for Sparse-MoE with k ≥ 1 is a weighted sum of the expert predictions across the k trees. To this end, we define the weight for the i-th expert as

w(x)_i = ( Σ_{j ∈ [k]} T^(j)(x)_i exp(g_i^(j)(x)) ) / ( Σ_{i' ∈ [n]} Σ_{j ∈ [k]} T^(j)(x)_{i'} exp(g_{i'}^(j)(x)) ),    (3)

where T^(j)(x)_i is the probability that a sample x will reach expert f_i in the j-th tree. Using (3), the prediction function for Sparse-MoE with k ≥ 1 is given by ŷ = Σ_{i ∈ [n]} w(x)_i f_i(x).
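A minimal sketch of this combination rule (illustrative only; the exact parameterization is given by the paper's equation (3), and the normalization shown here is our assumption): each expert's raw score aggregates, over trees, its reach probability scaled by an exponentiated leaf weight, and normalizing puts the result on the simplex.

```python
import numpy as np

def comet_weights(tree_probs, log_leaf_weights):
    """Combine k trees into expert weights (illustrative sketch).

    tree_probs: (k, n) array; row j is T^(j)(x), the reach probabilities
        of tree j over the n experts.
    log_leaf_weights: (k, n) array; row j holds the auxiliary log-space
        leaf weights g^(j).

    When every T^(j) is one-hot, at most k entries of the result are
    nonzero, so the cardinality constraint holds by construction.
    """
    raw = (tree_probs * np.exp(log_leaf_weights)).sum(axis=0)
    return raw / raw.sum()

T = np.array([[1.0, 0.0, 0.0, 0.0],   # tree 1 selects expert 0
              [0.0, 0.0, 1.0, 0.0]])  # tree 2 selects expert 2
g = np.zeros((2, 4))                  # equal leaf weights
w = comet_weights(T, g)
```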
We present the following proposition (proof in Appendix B): Proposition 3.1. For any x, if T^(j)(x) is a binary (one-hot) vector for every j ∈ [k], then the weight vector w(x) satisfies the cardinality and simplex constraints in (1b).
Accelerating Convergence of T^(j) to a One-Hot Encoding with Entropic Regularization. In the Sparse-MoE setup, the goal is to reach a one-hot state for each T^(j) quickly; this ensures the cardinality constraint (i.e., selecting at most k experts) is respected by the k trees. To encourage faster convergence towards a one-hot vector, we add a per-tree entropy regularizer Ω(T^(j)(x)) to the loss objective, where Ω(T^(j)(x)) = − Σ_{i ∈ [n]} T^(j)(x)_i log T^(j)(x)_i. Entropy regularizers are used in [20,37] to obtain binary representations.
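The per-tree entropy regularizer can be sketched as follows (the small eps guard is our addition for numerical safety); it is zero exactly when the distribution is one-hot and maximal on the uniform distribution, so penalizing it pushes each tree toward selecting a single expert:

```python
import numpy as np

def entropy_reg(p, eps=1e-12):
    """Per-tree entropy Omega(T^(j)(x)) = -sum_i p_i log p_i.
    Zero iff p is one-hot; log(n) for the uniform distribution."""
    p = np.asarray(p, dtype=float)
    return float(-np.sum(p * np.log(p + eps)))
```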
Dense-to-Sparse Learning. COMET supports conditional training only partially. At the start of training, it uses all the available experts, as each T^(j) is completely dense, so conditional training is not possible. As training proceeds, T^(j) becomes sparser due to the smooth-step activation function and entropic regularization, eventually reaching a binary state. From this stage onwards, the gate satisfies the cardinality constraint per sample, i.e., each sample gets routed to at most k experts. Hence, sparse training can proceed to refine the solution quality. Empirically, we observe that a small number of epochs is sufficient for the optimizer to reach the sparse training phase.

Non-powers of 2
Typically, in Sparse-MoE, each expert is assigned to a separate machine for efficiency [14,58]. This means the number of experts may be dictated by the number of available machines, which is not necessarily a power of 2. Our gate naturally handles cases where the number of experts is not a power of 2: we propose merging child nodes at the leaf level. In such instances, we have imperfect binary decision trees (Fig. 4 in Appendix) with n leaves: 2^d − n leaves in the (d−1)-th level and 2n − 2^d leaves in the d-th level. Additional details are in Appendix C. In contrast to other differentiable gates (e.g., DSelect-k [20]), our proposed gate COMET does not require any additional regularization to encourage the simplex constraint in (1b).
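The leaf counts for such an imperfect tree follow directly from d = ⌈log₂ n⌉; a small sketch (function name is ours):

```python
import math

def leaf_layout(n):
    """Leaf counts for an imperfect binary tree with n leaves
    (2**(d-1) < n <= 2**d): 2**d - n leaves sit on level d-1 and
    2*n - 2**d on level d, which together sum to n. For n a power
    of 2 the tree is perfect and all leaves sit on level d."""
    d = math.ceil(math.log2(n))
    return 2**d - n, 2 * n - 2**d
```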

Stable numerical implementation
Next, we discuss a numerically stable implementation of the COMET gate. COMET introduces additional exponential functions in the expert weights (or leaf nodes of the decision trees); see (3). Exponential functions are known to cause instabilities in Sparse-MoE models. For example, [58] introduced the router z-loss in Switch Transformers to encourage smaller logits; however, this may come with a performance tradeoff. In our implementation of COMET, we mitigate the instability arising from the additional exponential functions as follows: (i) convert the root-to-leaf probabilities to log-space, log T(x); (ii) subtract the maximum of the log-space weights from each element; (iii) apply a two-way softmax operation to obtain w(x).
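The steps above amount to the standard log-sum-exp trick; a minimal sketch (names are ours, and a plain softmax stands in for the two-way variant):

```python
import numpy as np

def stable_weights(log_scores):
    """Log-sum-exp trick: work in log-space, subtract the max so the
    largest exponent is exp(0) = 1, then normalize. Mathematically
    identical to a softmax over log_scores, but immune to overflow
    for large logits."""
    z = log_scores - np.max(log_scores)
    e = np.exp(z)
    return e / e.sum()

# Logits this large would overflow a naive exp-then-normalize.
w = stable_weights(np.array([1000.0, 1001.0]))
```

Subtracting the max changes nothing mathematically (it cancels in the ratio) but keeps every intermediate value in a representable range, which is why no z-loss-style penalty on logit magnitude is needed.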

LOCAL SEARCH
Expert selection is a challenging combinatorial problem that is known to be NP-hard.Although first-order heuristics can usually provide fast solutions, they rely heavily on initialization and are sometimes prone to arriving at low-quality solutions.To this end, we propose a permutation-based local search method that complements first-order methods in optimizing Sparse-MoEs.In both large-scale experiments and ablation studies, we see that the incorporation of local search can improve the performance of any gating method and can significantly reduce the number of tuning trials.
Our approach draws inspiration from local search methods commonly used alongside first-order methods to help escape local minima in sparse linear models [3,18]. We note that this is the first attempt in the literature to incorporate local search methods in the context of Sparse-MoE. Moreover, unlike common local search methods in the literature, our proposed search method is differentiable. We highlight that our local search method is useful for any existing sparse gate, e.g., Hash routing, Top-k, and our proposed COMET. We hypothesize that our permutation-based approach can help navigate the optimization loss surface for various gates.
The rest of the section is organized as follows. In Section 4.1, we formulate a refined cardinality-constrained Sparse-MoE objective with additional binary variables to support permutation-based local search. Then, in Section 4.2, we provide background on permutations and their differentiable relaxation. Next, in Section 4.3, we outline our differentiable optimization approach for the refined Sparse-MoE objective and some additional practical considerations for computational efficiency. Later, in Section 4.4, we provide an ablation study to support our hypothesis that local search can help escape bad initializations.

Permutation-based Local Search
In this section, we formulate a refined objective for the cardinality-constrained Sparse-MoE problem that adds support for permutation-based local search.
Let us denote by S_n the set of all permutations of [n]. Given any permutation π ∈ S_n, we permute the n experts accordingly and assign the i-th weight w(x)_i to the π(i)-th expert instead of the i-th expert. With this permutation, the prediction for Sparse-MoE can be written as: ŷ = Σ_{i ∈ [n]} f_{π(i)}(x) w(x)_i. We note that, due to the symmetry between experts and weights, permuting the experts is essentially the same as permuting the weights. To see this, we can write Σ_{i ∈ [n]} f_{π(i)}(x) w(x)_i = Σ_{i ∈ [n]} f_i(x) w(x)_{π^{-1}(i)}, where π^{-1} is the inverse map of π, which is also a permutation.
For a permutation π, we can define a corresponding permutation matrix P_π by setting P_π[i, j] = 1{π(j) = i}, where 1{·} is the indicator function. Then it is easy to see that Σ_{i ∈ [n]} f_{π(i)}(x) w(x)_i = Σ_{i ∈ [n]} f_i(x) (P_π w(x))_i. Replacing w(x) in (1) with P w(x) and additionally optimizing over P yields our refined formulation (4), where P is constrained to lie in P_n^local, a localized subset of the full set of permutation matrices, which we denote by P_n. For example, one may only allow P_n^local = P_n^2, which only permits interchanging (swapping) two columns, similar to the "swap" operations shown to be useful in the sparse regression literature [18]. Besides optimizing the gates and experts, formulation (4) performs local search by optimizing over the permutation matrix. Specifically, the goal of local search here is to find a permutation π that leads to a better solution, i.e., one with a lower objective. Intuitively, if SGD is stuck at a low-quality solution, the permutation may be able to escape it by a better reordering of the experts. Standard local search, e.g., brute-force search, may be computationally expensive. Therefore, we resort to a differentiable method that can be optimized efficiently.
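The equivalence between permuting experts and permuting weights via P_π can be checked numerically; a small sketch with scalar expert outputs (all names are ours):

```python
import numpy as np

def perm_matrix(pi):
    """P_pi[i, j] = 1{pi(j) = i}, so (P_pi @ w)[i] = w[pi^{-1}(i)]."""
    n = len(pi)
    P = np.zeros((n, n))
    for j, i in enumerate(pi):
        P[i, j] = 1.0
    return P

pi = [2, 0, 1]                       # weight slot i feeds expert pi(i)
w = np.array([0.7, 0.2, 0.1])
f = np.array([1.0, 10.0, 100.0])     # stand-in scalar expert outputs

# permuting the experts ...
lhs = sum(f[pi[i]] * w[i] for i in range(3))
# ... equals applying P_pi to the weights
rhs = float(f @ (perm_matrix(pi) @ w))
```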

Preliminaries: Permutation and a differentiable relaxation
In this section, we briefly summarize how the permutation learning problem is parameterized and optimized. A natural way to parametrize the permutation matrix in the problem is through the linear assignment problem [32]. To illustrate, suppose n people are to complete n tasks and we are given a matrix X ∈ R_{≥0}^{n×n}; the goal is to assign each task to one person so as to maximize the total utility, given that the utility of assigning task j to person i is X_{ij}. This leads to the following optimization problem:

M(X) = argmax_{P ∈ P_n} ⟨P, X⟩_F.    (5)

The operator M here is called the matching operator, which maps a nonnegative matrix X to a permutation matrix. Problem (5) is a combinatorial optimization problem, which admits the following linear relaxation [6]:

argmax_{P ∈ B_n} ⟨P, X⟩_F,    (6)

where B_n denotes the set of doubly stochastic matrices, the convex hull of the set of permutation matrices P_n. However, this is still not a differentiable parametrization, as problem (6) might admit multiple solutions. To this end, Mena et al. [37] propose a smooth version of the permutation learning objective in (6):

argmax_{P ∈ B_n} ⟨P, X⟩_F + τ h(P),    (7)

where τ > 0 and h(P) is the entropy of P, and solve it using the Sinkhorn operator S(·) [1], defined by the following recursion:

S^0(X) = exp(X),    (8a)
S^l(X) = T_c(T_r(S^{l−1}(X))),    (8b)
S(X) = lim_{l→∞} S^l(X),    (8c)

where T_r(X) = X ⊘ (X 1_n 1_n^⊤) and T_c(X) = X ⊘ (1_n 1_n^⊤ X) are the row-wise and column-wise normalization operators of a matrix, with ⊘ denoting element-wise division and 1_n a column vector of ones. The Sinkhorn procedure in (8) allows differentiable training with first-order methods, making it appealing as a local search method for Sparse-MoE.
As shown in [37], M(X) can be obtained as lim_{τ→0+} S(X/τ), and thus M(X) = lim_{τ→0+, l→∞} S^l(X/τ). In practice, we set a maximum number of iterations L for the normalization in (8b) as well as a small positive temperature τ > 0, and use S^L(X/τ) to approximate the limit (8c). In this way, we can parametrize the permutation matrix P in (4) as a differentiable function S^L(X/τ) of a learnable matrix X. However, additional considerations are needed to ensure that a hard permutation matrix is reached within a few epochs; this is important in the Sparse-MoE paradigm for computational reasons and for a well-defined measure of sparsity. We discuss these in the next section.
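A minimal NumPy sketch of the truncated Sinkhorn approximation S^L(X/τ); the example matrix and hyperparameter values are ours:

```python
import numpy as np

def sinkhorn(X, tau=0.1, n_iters=50):
    """Approximate the matching operator M(X) by S^L(X/tau):
    exponentiate, then alternate row normalization (T_r) and column
    normalization (T_c). As tau -> 0+ and L -> infinity, the result
    approaches the hard permutation matrix solving (5)."""
    P = np.exp(X / tau)
    for _ in range(n_iters):
        P = P / P.sum(axis=1, keepdims=True)   # T_r: rows sum to 1
        P = P / P.sum(axis=0, keepdims=True)   # T_c: columns sum to 1
    return P

# Utility matrix whose best assignment is the identity permutation.
X = np.array([[2.0, 0.1], [0.1, 1.0]])
P = sinkhorn(X, tau=0.05, n_iters=200)
```

With a small temperature, the output is already numerically close to the hard identity permutation, illustrating why τ and L control the "hardness" of the relaxation.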

Practical considerations for optimization
Next, we discuss some empirical considerations for the end-to-end learning approach that are important for Sparse-MoE.
Need for a hard permutation matrix. We would like to have a hard permutation matrix at inference time, and ideally during the course of training, for exact sparsity and computational efficiency. First, the gate does not perform sparse inference if the learnt permutation matrix is not hard: even if w(·) is sparse, the refined weights P w(·) are not a sparse vector if P is not a binary matrix. This would result in a dense mixture of experts. Second, some sparse gates perform dense-to-sparse training (partially conditional training), e.g., DSelect-k, COMET, or variants of Top-k [41]. If the learnt permutation matrix is not hard, then sparse training cannot proceed in the later stages of optimization. To this end, we employ a two-stage optimization approach: (i) in the first stage, we simultaneously train the network (experts and gates) and the permutation for a small number of epochs; (ii) in the second stage, the permutation matrix is fixed and only the remaining network (experts and gate) is trained. Therefore, local search is only used in the early stages of training. Empirically, we observe that a small number of epochs (1−10) is sufficient to learn a good permutation in the first stage and improve solution quality. Since local search is restricted to the first stage, the computational efficiency of gates that perform dense-to-sparse training is not affected by much; please refer to Appendix D.3 for additional discussion.
In the two-stage approach outlined above, there is a transition from a soft to a hard matrix between the two stages. As mentioned earlier, we use S^L(X/τ) to approximate M(X) in the limit L → ∞, τ → 0+. In practice, the transition may not be continuous, as this approximation does not always reach a hard permutation matrix given that L is finite and τ is nonzero. Therefore, at the transition point, we convert the "soft" permutation matrix S^L(X/τ) to a hard one via the linear assignment problem in (5), by invoking M on S^L(X/τ). In addition, empirically, small τ can lead to numerical instabilities for small L [37]. Therefore, to decrease the deviation of S^L(X/τ) from the closest hard permutation matrix, we introduce two schedulers that increase L while decreasing τ: (i) ramp up L linearly from 20 to 150; (ii) ramp down τ linearly in log-scale from 10^−3 to 10^−7.
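The two schedulers and the hard rounding at the transition point can be sketched as follows. The brute-force assignment solver is our stand-in for a proper linear assignment routine and is only viable for tiny n; all names are ours:

```python
import itertools
import numpy as np

def schedules(step, total_steps, L0=20, L1=150, tau0=1e-3, tau1=1e-7):
    """Ramp the iteration count L up linearly and the temperature tau
    down linearly in log-scale, as described in the text."""
    frac = step / max(total_steps - 1, 1)
    L = int(round(L0 + frac * (L1 - L0)))
    tau = 10.0 ** (np.log10(tau0) + frac * (np.log10(tau1) - np.log10(tau0)))
    return L, tau

def harden(soft_P):
    """Round a soft (near doubly stochastic) matrix to a hard permutation
    by brute-force search over the assignment problem (5); a real
    linear-assignment solver would be used in practice."""
    n = soft_P.shape[0]
    best = max(itertools.permutations(range(n)),
               key=lambda pi: sum(soft_P[i, pi[i]] for i in range(n)))
    hard = np.zeros_like(soft_P)
    for i, j in enumerate(best):
        hard[i, j] = 1.0
    return hard

L_start, tau_start = schedules(0, 100)
L_end, tau_end = schedules(99, 100)
H = harden(np.array([[0.6, 0.4], [0.4, 0.6]]))
```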
Although the above schedulers decrease the deviation between the soft permutation matrix and its closest hard permutation matrix at the transition point, the method can still suffer from pseudo-convergence. In particular, we observed that some rows and columns can converge to fractional entries, e.g., a 2×2 sub-block with all entries equal to 0.5. Therefore, we introduce small separate row-wise and column-wise entropic regularizations to mitigate such degenerate cases: λ Σ_{i ∈ [n]} ( Ω(S^L(X/τ)_{i,·}) + Ω(T_c(S^L(X/τ))_{·,i}) ), where λ ≥ 0.
Implicit localization. In the spirit of common local search approaches, a potential optimization strategy could alternate between optimizing the network (experts and gates) and the permutation matrix. However, this is unnecessary because the differentiable relaxation of the permutation is also amenable to first-order methods. Therefore, our approach jointly optimizes both the network and the permutation matrix. We noted earlier that the search space for permutations is "localized" within the full set of permutation matrices P_n. This localization is implicitly imposed through the smooth optimization of the permutation matrix via Sinkhorn: permutation learning relies on the initialization of X, and at each gradient step P^(t) is naturally expected not to deviate drastically from P^(t−1). Since the permutation matrix is updated for a limited number of steps in the first stage, intuitively it cannot deviate significantly from the initial permutation matrix. This defines an implicit neighborhood.

Ablation study for local search
In this section, we provide an ablation study as evidence that permutation-based local search can complement first-order optimization methods for routing in Sparse-MoE. The study highlights that local search can improve solution quality by escaping bad initializations in the early stages of optimization for different types of routing strategies: (a) fixed gates, (b) trainable gates. We perform this study on a subsampled (200k) MovieLens dataset and use the same MoE architecture with 16 experts as described in Supplemental Section S1.2. We trained models for only 10 epochs without and with local search; in the latter case, we fixed the number of epochs for permutation learning to 5 and set λ = 10^−5. We used a batch size of 512 and a learning rate of 2.5×10^−5. We repeat the training with 100 different random initializations and report averages along with their standard errors.
Fixed Gates. In fixed gating strategies, e.g., random hash routing (Hash-r), the samples are pre-assigned to experts. For example, in natural language processing tasks, tokens or words in the vocabulary are clustered randomly into groups before training begins [44], and each group of words is assigned to a random expert. In our experiments on recommender systems, we randomly pre-assigned samples to experts based on the user index for Hash-r (and Hash-r+). It is possible that the same group of users could be better aligned with another expert, depending on the expert and user embedding initializations. Permutation-based local search can potentially find a better assignment of each group to a more suitable expert. We provide empirical evidence that local search can indeed achieve a better loss. We report the average out-of-sample loss achieved by both Hash-r and Hash-r+ in Table 1. Learning a permutation appears to help map each pre-assigned cluster of users to a more suitable expert, based on the expert initialization, for the second stage of optimization.
Trainable Gates. For trainable gates, we study the effect of local search on a non-differentiable gate (Top-k) and a differentiable gate (COMET). We fixed k = 2 for both types of gates and followed the same training protocol for 10 epochs. For COMET (and COMET+), we fixed the smooth-step parameter at 0.01 and the entropic regularization parameter at 1. For Top-k+ and COMET+, we fixed the number of epochs for permutation learning at 5. We repeated this exercise for 100 different random initializations of the experts and gates. We report the average out-of-sample objective achieved by both types of gates in Table 1. We observe that local search appears to complement first-order optimization methods by finding better initializations in the first stage of Sparse-MoE optimization.
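For reference, the smooth-step function (introduced for DSelect-k [20]; this is a sketch, not the authors' code) is a cubic polynomial that transitions from 0 to 1 over an interval of width gamma, which is what keeps the gate differentiable:

```python
# Sketch of the smooth-step function used by differentiable gates such as
# DSelect-k [20]: 0 below -gamma/2, 1 above gamma/2, and a smooth cubic in between.

def smooth_step(t: float, gamma: float = 0.01) -> float:
    """Cubic smooth-step with a transition region of width gamma."""
    if t <= -gamma / 2:
        return 0.0
    if t >= gamma / 2:
        return 1.0
    return (-2.0 / gamma ** 3) * t ** 3 + (3.0 / (2.0 * gamma)) * t + 0.5
```

Smaller values of gamma make the transition sharper (closer to a hard step), which is why the parameter is tuned rather than fixed.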
The practical significance of local search achieving a better test objective across many initializations and gates can be seen in the reduced hyperparameter tuning overhead, as discussed in Section 5.1.3.

EXPERIMENTS
We study the performance of COMET and COMET+ on recommender systems and image datasets in Section 5.1, and of COMET-BERT on natural language processing tasks in Section 5.2. We also study the effect of local search on various gates. We denote our methods in italics.

Experiments on Recommender Systems and Image Datasets
We study the performance of COMET and COMET+ on recommender systems and image datasets. We compare with state-of-the-art gates and baselines, including Softmax, Top-k, DSelect-k, and Hash routing (Hash-r), on recommender systems (MovieLens [17], Jester [16], Books [57]) and image datasets (Digits [10, 40], MultiMNIST [46], MultiFashionMNIST [20], CelebA [35]). We also include an ablation study in Section 5.1.2 showing that COMET achieves good performance with far fewer trials than existing popular gates, e.g., Hash routing and Top-k. Additionally, in Section 5.1.3, we show that Hash-r+, Top-k+, and COMET+ with local search can achieve good performance with far fewer trials than Hash-r, Top-k, and COMET, respectively.

Implementation.
We provide an open-source implementation of COMET and COMET+: https://github.com/mazumder-lab/COMET.

Experimental Setup. Although our exposition in Section 3 was for a single-task setting, the same gate can also be used in multi-task learning: this requires the multi-gate MoE architecture [36], where each task has a separate trainable gate but all tasks select from a common set of experts. We briefly summarize the key aspects of each dataset. For MovieLens/Books/Jester, we have two tasks: a classification task that predicts whether a user watches/reads/rates a particular movie/book/joke, and a regression task that predicts the user's rating. The loss is a convex combination of binary cross-entropy (for classification) and mean squared error (for regression) with task weights {α, 1 − α}. We separately present results for two different values: α ∈ {0.1, 0.9}. For MultiMNIST/MultiFashionMNIST, there are two multi-class classification tasks, which are equally weighted. For CelebA, there are 10 binary classification tasks, which are equally weighted. Lastly, for the Digits dataset, we have a single-task multi-class classification cross-entropy objective. Full details about the datasets and MoE architectures are in Supplement Section S1. We used Adam for optimization and tuned the key hyperparameters using random grid search. Note that for Hash-r+, COMET+, and Top-k+, we allocate only a very small portion of the epochs (1-10) to permutation learning. Full details about hyperparameter tuning are given in Supplement Section S1.
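As a concrete sketch of the two-task objective described above for MovieLens/Books/Jester (illustrative code, not the released implementation), the per-sample loss is a convex combination of a binary cross-entropy term and a squared-error term with weights {α, 1 − α}:

```python
# Hypothetical sketch of the multi-task objective: a convex combination of
# binary cross-entropy (classification) and squared error (regression)
# with task weights {alpha, 1 - alpha}, as used for MovieLens/Books/Jester.
import math

def binary_cross_entropy(p: float, y: float) -> float:
    """BCE for a single prediction p in (0, 1) and label y in {0, 1}."""
    eps = 1e-12  # numerical safety for log
    return -(y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps))

def squared_error(pred: float, target: float) -> float:
    return (pred - target) ** 2

def multitask_loss(p_watch, y_watch, pred_rating, y_rating, alpha=0.1):
    """Convex combination: alpha * BCE + (1 - alpha) * squared error."""
    return alpha * binary_cross_entropy(p_watch, y_watch) + \
           (1 - alpha) * squared_error(pred_rating, y_rating)
```

Setting α = 0.1 emphasizes the regression task and α = 0.9 the classification task, matching the two settings reported in the tables.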

Performance of COMET and COMET+.
In Tables 2 and 3, we report the (average) test loss and the average number of selected experts per sample across multiple recommender and vision datasets. The results indicate that COMET and COMET+ lead on many datasets, outperforming popular state-of-the-art gating methods, e.g., Hash-r, Top-k, and DSelect-k, in test loss. Our proposed gate COMET can outperform standard routing techniques even without local search; task-specific metrics are reported in Table 7 in Appendix E. We observe that COMET+ can improve AUC by up to 13% over Hash routing and Top-k, and by 9% over DSelect-k. We observe that the Top-k gate does not uniformly outperform Softmax across datasets. However, Top-k+ significantly improves the performance of Top-k across multiple datasets. In fact, with the permutation module, Top-k+ outperforms Softmax in all cases, so sparsity in gating seems to be beneficial on all these datasets.

Inference Sparsity. We see that COMET and COMET+ can sometimes select fewer experts than Top-k. This leads to a smaller number of FLOPs at inference time (see Appendix D.2). In some settings, DSelect-k appears to arrive at a sparser selection than COMET+; however, in these cases, DSelect-k loses significantly in terms of performance. We observed expert collapse for DSelect-k in such cases.
Timing Discussion. For the cost complexity of COMET, see Appendix D.1. The computational aspects of the local search are discussed in Appendix D.3.

COMET can achieve comparable performance with far fewer hyperparameter trials, which indicates that COMET is not heavily dependent on a very restricted set of hyperparameter values. We visualize this for various datasets in Fig. 2: we see a tuning reduction by a factor of 5×-100× for COMET over popular gates.

Effect of Local Search on Hyperparameter Tuning.
Here, we study how local search can be beneficial for hyperparameter tuning. We study this effect for Hash-r, Top-k, and COMET, and visualize it in Fig. 3 for MovieLens for Hash-r+, Top-k+, and COMET+. We observe that comparable performance can be achieved with a much smaller number of trials: we see a tuning reduction by a factor of 3×-100× for Hash-r+, 20×-100× for Top-k+, and 2×-5× for COMET+. This suggests that permutation-based local search helps escape bad initializations. Such favorable properties of local search in reducing the hyperparameter tuning load for existing gates can be beneficial for Large Language Models.

Experiments on NLP Tasks
In this section, we consider a setting where a large pretrained (non-MoE) model needs to be distilled for more efficient inference while preserving or improving performance. Following [60], we study a distillation setting where BERT [11] is distilled into a Sparse-MoE-based variant. Specifically, the FFN layers are replaced with MoE layers; this can result in a ~2× smaller number of (effective) parameters with per-sample sparse routing (for k = 1), thus allowing for more efficient inference.
Following [60], we use an importance-weight guided distillation strategy: (i) finetune BERT on a downstream task; (ii) compute importance weights in the FFN layers to construct an MoE-based variant of BERT; (iii) distill BERT into the MoE-based variant on the downstream task with a layer-wise discrepancy loss. [60] used Hash routing in their MoEBERT model. We propose COMET-BERT (an MoE-based BERT model with COMET/COMET+ gating) and evaluate its performance on the GLUE benchmarks [49] and the SQuAD benchmark [42]. More details about the benchmarks are given in Supplement Section S2.1.
Implementation. We implemented COMET-BERT in HuggingFace [54] and adapted the codebase of [60]. Unlike Hash routing, our gates can also cater to k ≥ 1. However, for a consistent comparison in terms of inference, we set k = 1. Tuning details are outlined in Supplement Section S2.2. Code for COMET-BERT is available at https://github.com/mazumder-lab/COMET-BERT.
Results. We report performance metrics in Table 4 for 7 GLUE datasets and the SQuAD dataset. COMET-BERT outperforms MoEBERT on 5/7 GLUE benchmarks. COMET-BERT also outperforms MoEBERT significantly on SQuAD v2.0. Notably, on 5 of these datasets (CoLA, MRPC, QNLI, MNLI, and SQuAD v2.0), COMET-BERT achieves state-of-the-art performance for distilling BERT (when compared with all distillation methods in the literature with the same number of effective parameters for inference). We show that gates that learn sparse routing decisions per sample, e.g., Top-k, DSelect-k, and COMET, significantly reduce the number of FLOPs (3×-6×) at inference time in comparison to dense gates, e.g., Softmax. Additionally, we see that in all 4 cases, COMET has a smaller number of FLOPs (1.1×-1.6×) than the highly popular Top-k gate. We also outperform DSelect-k in some cases in the number of FLOPs; while in some cases we have a larger number of FLOPs than DSelect-k, our AUC is higher (by up to 9%) in those cases.

D.3 Effect of local search on computation
Inference. Note that the permutation matrix is global and not sample-specific. At inference time, multiplying the permutation matrix P with the gate output g(x) amounts to a reordering of the expert indices; hence, the additional cost of this permutation is negligible compared to the evaluation of the experts and the gate.
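This observation can be made concrete with a small sketch (illustrative only): applying a permutation matrix P to the gate output g(x) is just an index shuffle, so no matrix multiplication is actually needed at inference.

```python
# Applying a permutation matrix to the gate output is only a reordering of
# expert indices, so its inference cost is negligible.

def apply_permutation(gate_weights: list, perm: list) -> list:
    """Equivalent to P @ g(x), where row i of P has its single 1 in column perm[i]."""
    return [gate_weights[perm[i]] for i in range(len(perm))]
```

Because the permutation is global, `perm` can be precomputed once after training and reused for every sample.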
Training. In the first stage of COMET training (a few epochs, ~5), the training is dense (requiring all experts per sample). For COMET+, we also learn the permutation matrix during this stage. There is a small additional computational cost: (a) a permutation matrix of size n × n, where n is the number of experts, e.g., 16; (b) the cost of the Sinkhorn operator, which consists of row/column sum normalizations. This cost is marginal compared to the cost of evaluating the experts, each of which is an MLP/CNN. In the second stage of training, where the samples are routed to a small subset of k (= 2) experts per sample, there is no additional cost for COMET+ over COMET. As an example, for MovieLens 200k, where we learn the permutation matrix in the first 5 epochs, the total time for 50 epochs (on 4 GPUs) is 494s for COMET and 496s for COMET+. Note that 50 epochs were sufficient to achieve convergence for both gates.
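To make the Sinkhorn cost concrete, here is a minimal sketch (ours, not the released code) of the operator: alternating row and column normalizations that push a positive matrix toward a doubly stochastic matrix, which serves as a soft relaxation of a permutation matrix.

```python
# Sketch of the Sinkhorn operator: alternating row/column sum normalizations
# of exp(logits) converge toward a doubly stochastic matrix, a continuous
# relaxation of a permutation. Cost is O(n_iters * n^2) for n experts.
import math

def sinkhorn(logits, n_iters=20):
    """logits: n x n list of lists; returns an approximately doubly stochastic matrix."""
    m = [[math.exp(v) for v in row] for row in logits]
    n = len(m)
    for _ in range(n_iters):
        # Row-sum normalization
        m = [[v / sum(row) for v in row] for row in m]
        # Column-sum normalization
        col_sums = [sum(m[i][j] for i in range(n)) for j in range(n)]
        m = [[m[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return m
```

For n = 16 experts this is a 16 × 16 computation per step, which is why the cost is marginal next to evaluating MLP/CNN experts.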

E TASK-SPECIFIC METRICS CORRESPONDING TO TABLES 2 AND 3
We provide task-specific metrics for all recommender systems and image datasets in Table 7. We observe that COMET+ can give superior AUC performance by up to 13% over Hash routing and Top-k, and by 9% over DSelect-k.

SQuAD. We evaluate our sparse routing approaches on the question answering dataset SQuAD v2.0 [42]. This task is treated as a sequence labeling problem, where we predict the probability of each token being the start and end of the answer span. Statistics of SQuAD v2.0 are summarized in Table S2.

Note that this step matched the performance numbers reported for BERT-base in Table 1 of [60]. We used the best model (for each dataset) for the remaining steps below.
• Compute importance weights in the FFN layers to construct an MoEBERT/COMET-BERT model, where FFN layers are replaced with MoE layers using the weight assignment strategy in [60].
• Distill BERT into MoEBERT or COMET-BERT on the downstream task with a layer-wise discrepancy loss.
For MoEBERT, we used the optimal hyperparameters reported (based on ~1000 trials per dataset) in Table 7 of the Supplement in [60]. For COMET-BERT, we performed 100 tuning trials via random search with each of COMET and COMET+ and picked the best results based on the development sets. The hyperparameters were randomly selected from the following sets: Learning Rate: Discrete uniform over the set {1 × 10 ≥ 0}.

In the MoE framework, the prediction function has two components: (i) a set of n experts (parametrized by neural networks) f_i : X → R^d for i ∈ [n] := {1, 2, ..., n}, and (ii) a gate g : X → Δ_n that outputs weights in the probability simplex. Given a sample x ∈ X, the corresponding output of the MoE is a convex combination of the experts with weights g(x): Σ_{i ∈ [n]} g(x)_i f_i(x).
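The MoE prediction function can be sketched in a few lines (illustrative code, not the paper's implementation): the output is a convex combination of the expert outputs with the gate weights.

```python
# Sketch of the MoE prediction: y = sum_i g(x)_i * f_i(x), where the gate
# weights g(x) lie on the probability simplex. A sparse gate makes most
# weights exactly zero, so only the selected experts need to be evaluated.

def moe_output(gate_weights: list, expert_outputs: list) -> list:
    """gate_weights: length-n list; expert_outputs: n equal-length vectors."""
    dim = len(expert_outputs[0])
    return [sum(w * f[d] for w, f in zip(gate_weights, expert_outputs))
            for d in range(dim)]
```

With a sparse gate (e.g., k = 2 nonzero weights out of n = 16), experts with zero weight can be skipped entirely, which is the source of the inference FLOP savings reported above.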

Figure 2 :
Figure 2: Sensitivity of COMET to hyperparameter tuning. COMET can achieve the same level of performance as popular gates (e.g., Hash-r and Top-k) with a significantly smaller number of hyperparameter trials. We see a tuning reduction of 5×-100× for COMET over Top-k and Hash routing.

Figure 3 :
Figure 3: Effect of local search on hyperparameter tuning. Comparison of Hash-r+ vs. Hash-r, Top-k+ vs. Top-k, and COMET+ vs. COMET on MovieLens with two different task-weight settings. Local search appears to achieve the same level of performance with a much smaller number of hyperparameter trials. We see a tuning reduction by a factor of 3×-100× for Hash-r+, 20×-100× for Top-k+, and 2×-5× for COMET+.
Our local search method is general and can be applied to any gate, e.g., Hash routing, Top-k, and COMET.
• We perform extensive experiments on recommender systems, vision, and natural language processing tasks to highlight that COMET and COMET+ (COMET combined with local search) can boost predictive performance. In particular, on recommender and image datasets, we observed that COMET+ can improve AUC performance by up to 13% over existing sparse gates, e.g., Top-k and Hash routing. It can also reduce tuning by up to 100× over popular gates, e.g., Hash routing and Top-k. Similarly, in natural language processing applications, our COMET-BERT model (an MoE-based variant of BERT with COMET/COMET+ gating) can outperform the state-of-the-art Hash-routing-based MoEBERT model [60] on 5/7 GLUE benchmarks as well as the SQuAD dataset for distilling the pretrained BERT model [11].
• Popular first-order methods used to optimize Sparse-MoEs are heavily influenced by expert and gate initializations and may get stuck in low-quality solutions. Hence, we introduce a novel permutation-based local search method that complements first-order methods by helping them escape bad initializations or solutions.

Table 1: Test loss (×10⁻²) achieved by different gates without and with (marked with +) local search in the early stages of optimization. An asterisk (*) indicates statistical significance (p-value < 0.05) over the corresponding gate without permutation, using a one-sided unpaired t-test.

Table 2: Test loss (×10⁻², smaller is better) and number of experts selected per sample for COMET, COMET+, and benchmark gates across recommender system datasets. An asterisk (*) indicates statistical significance (p-value < 0.05) over the best existing gate, using a one-sided unpaired t-test.

Table 3: Test loss (×10⁻², smaller is better) and number of experts selected per sample for COMET, COMET+, and benchmark gates across image datasets. An asterisk (*) indicates statistical significance (p-value < 0.05) over the best existing gate, using a one-sided unpaired t-test.

Even without local search, COMET obtains relatively good solutions. We hypothesize that the good performance of COMET is due to a combination of factors, including differentiability and the tree-based formulation. With local search, COMET+ can sometimes further enhance solution quality. We also provide task-specific metrics (AUC/Accuracy/MSE) in Table 7.

Effect of Dense-to-Sparse Training with COMET. Here, we study how our differentiable COMET gate (which performs dense-to-sparse training) can be beneficial for hyperparameter tuning compared to popular gates such as Hash routing and Top-k. We perform a large set of tuning trials and use a bootstrapping procedure (discussed in Appendix F) to assess whether COMET reduces the hyperparameter tuning overload. COMET can achieve the same level of performance as popular gates with a much smaller number of hyperparameter trials.

Table 4: Performance metrics on the GLUE and SQuAD development sets. Models are trained without data augmentation. Both models have 66M (effective) parameters for inference. The number of FLOPs that a model performs per sample is a standard measure for evaluating the inference speed of Sparse-MoE models, e.g., in [14].

Table 7: Test AUC/Accuracy/MSE for COMET+ and benchmark gates on recommender systems and image datasets.

The bootstrapping procedure (Appendix F) is as follows:
• Randomly sample m trials (m ∈ {1, 2, 5, 10, 15, ..., 250}) from a larger bag of 500 trials.
• Find the trial with the best validation loss.
• Compute the test loss for that trial.
• Repeat this exercise 1000 times.
• Compute the average test loss across the best selected trials.
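The steps above can be sketched as follows (a hypothetical re-implementation, assuming each tuning trial is recorded as a (validation loss, test loss) pair):

```python
# Sketch of the bootstrapping procedure used to estimate test loss as a
# function of the tuning budget m: repeatedly subsample m trials from the
# full bag, pick the best by validation loss, and average the test losses.
import random

def bootstrap_test_loss(trials, m, n_repeats=1000, seed=0):
    """trials: list of (val_loss, test_loss) pairs from a larger tuning bag."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_repeats):
        sample = rng.sample(trials, m)          # m trials without replacement
        best = min(sample, key=lambda t: t[0])  # best validation loss
        total += best[1]                        # record its test loss
    return total / n_repeats
```

Plotting this estimate against m produces curves like those in Figs. 2 and 3: a gate that reaches a given test loss at a smaller m needs a smaller tuning budget.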

Table S1: Summary of the GLUE benchmark.

Table S2: Summary of the SQuAD benchmark.