Enhancing Heterogeneous Federated Learning with Knowledge Extraction and Multi-Model Fusion

Motivated by concerns over user data privacy, this paper presents a new federated learning (FL) method that trains machine learning models on edge devices without accessing sensitive data. Traditional FL methods, although privacy-preserving, fail to manage model heterogeneity and incur high communication costs due to their reliance on weight aggregation. To address these limitations, we propose a resource-aware FL method that aggregates local knowledge from edge models and distills it into robust global knowledge through knowledge distillation. This method allows efficient multi-model knowledge fusion and the deployment of resource-aware models while preserving model heterogeneity. Compared to existing FL algorithms, our method reduces communication cost and improves performance under heterogeneous data and models. Notably, it reduces the communication cost of ResNet-32 by up to 50% and that of VGG-11 by up to 10× while delivering superior performance.


INTRODUCTION
Federated learning (FL) has emerged as a novel machine learning paradigm for distributed clients to participate in the collaborative training of a centralized model. FL brings synchronous model training to the network edge, where devices (e.g., mobile phones and IoT devices) extract knowledge from privacy-sensitive training data and then upload the learned models to the cloud for aggregation. FL stores user data locally and restricts direct access to it from cloud servers; this paradigm therefore not only enhances privacy preservation but also introduces several inherent advantages, including model accuracy, cost efficiency, and diversity. Given the massive demand for data in today's machine learning models and the social considerations of artificial intelligence (e.g., privacy and security [40,4]), federated learning has great potential to balance this trade-off.
Federated learning has already shown its potential in practical applications, including health care [32], environmental protection [14], and electric vehicles [30], to name a few. Intel, Google, Apple, and NVIDIA already use FL in their applications (e.g., Intel's OpenFL [6], Google's Keyboard [41,3], Apple's Siri [7,29], NVIDIA's medical imaging [22]). Consequently, designing efficient FL models and deploying them effectively and fairly on edge devices is crucial for improving the performance of edge computing in the future.
Traditional FL, like FedAvg [25] in Figure 1 (b), broadcasts global model parameters to selected edge devices, then averages the local model parameters trained on local data to update the global model. This iterative process, while effective, faces limitations [26]. Firstly, sharing model weights between network edges and servers introduces significant communication overhead. Secondly, the increasing computational and memory demands of AI models, combined with the heterogeneous computing power of edge devices, complicate model deployment on resource-constrained devices. Additionally, real-world local data is often imbalanced or not independent and identically distributed (Non-IID), which can result in training failures in decentralized settings. Lastly, over-parameterized deep learning models can cause overfitting when aggregating heterogeneous local models [15,27], leading to high learning and prediction variance. Hence, existing FL methods may result in an unfair and ineffective global model.
We propose a resource-aware federated learning approach, Federated learning using Knowledge Extraction and Multi-model fusion (FedKEM), as illustrated in Figure 1 (a), to overcome the limitations of previous FL methods. In FedKEM, each local client model is trained together with a server-side knowledge network. After local training, each local model's knowledge is distilled into a tiny-size neural network via deep mutual learning, allowing robust global knowledge acquisition from the server. We then ensemble the tiny neural networks, distill the ensemble into global knowledge, and transfer it to the clients for further learning. This reduces communication costs, as only the tiny network is exchanged between edge and cloud during training and inference.
FedKEM offers several advantages. It mitigates the over-parameterization of large edge models by distilling client knowledge before server aggregation. It cuts communication costs significantly by exchanging distilled tiny-size networks instead of the original large models. By ensembling knowledge from the edges, it strengthens the global model, reducing overfitting and variance risks and improving FL generalization. Additionally, it accounts for resource limitations, using multi-model fusion to deploy models fairly on edge devices, which makes FL more realistic.
Our experiments on non-IID data settings and heterogeneous client models show FedKEM significantly reduces communication cost and achieves better performance using fewer communication rounds.

RELATED WORK

Federated Learning
FedAvg [25] is the original implementation for training on decentralized data while preserving privacy in FL. Based on FedAvg, numerous variants have been proposed to optimize FL [21,38,16], especially to tackle the heterogeneity issue in FL [16,23,46]. For example, FedProx [21] improves local client training by adding a proximal term to the local loss, FedNova [38] introduces weight modification to avoid gradient bias by normalizing and scaling local updates, and SCAFFOLD [16] corrects the update direction to prevent client drift by maintaining control variates. It is worth mentioning that these FL methods aggregate single-model weights from the edges at the server and incur extra communication overhead. Other works [34,5,48,47] optimize the communication cost but do not consider the heterogeneous computation power of edge devices. Unlike prior works, we aggregate knowledge from the edge and only communicate tiny knowledge networks between the edge and the server.
Another line of work in FL is personalized FL, which focuses on the problem of statistical heterogeneity. Personalized FL aims to personalize the global model for each client and to develop improved personalized models that can benefit a large majority of clients [18]. SPATL [45] introduces a knowledge transfer local predictor that transfers the shared encoder to each client. It further leverages network pruning [44,43] to select salient parameters, communicating only the selected parameters between the server and clients. Although we share the same consideration of device heterogeneity (memory storage and computation power), data heterogeneity (non-IID data), and model heterogeneity (model structure and size), we focus on how to extract knowledge from different types of models and their corresponding training devices to build robust global knowledge.

Ensemble Learning and Knowledge Distillation in Federated Learning
Knowledge distillation was first introduced as a model compression technique for neural networks, transferring knowledge from a large teacher model to a small student model [2,12]. Ensemble learning is a promising technique that combines several individual models for better generalization performance [9]. Building on these two ideas, several methods [42,8,37,28,1] propose efficient knowledge distillation with an ensemble of teachers to further improve student performance.
Recently, ensemble learning and knowledge distillation have emerged as an effective approach to address heterogeneity issues, resource-constrained edge devices, and communication efficiency in FL [19,10,31,35,24]. For example, Fed-ensemble [35] ensembles the prediction output of all client models; FedKD [39] proposes an adaptive mutual distillation to learn a student and a teacher model simultaneously on the client side; FedDF [24] distills the ensemble of client teacher models to a server student model. In contrast, our approach utilizes knowledge distillation to encode the ensemble knowledge from clients into global knowledge. The novelty of our approach is that FedKEM extracts the local model knowledge, which is itself learned from the global knowledge, into a tiny-size neural network by deep mutual learning [49], and then ensembles the tiny-size neural networks for multi-model fusion.

METHODOLOGY
Figure 2 shows the local updates of FedKEM and Figure 3 shows the cloud updates of FedKEM. In local updates, the client first downloads the knowledge network from the server, then mutually trains the knowledge network with the local model to extract an updated knowledge network, and finally uploads it back to the server. In cloud updates, the server first collects local knowledge (tiny-size networks) uploaded from clients, then ensembles all the tiny-size networks, distills them into global knowledge, and finally transfers it to clients. In this section, we explain our approach in detail.

Knowledge Extraction using Deep Mutual Learning
Traditional FL and its variants [21,38,16] keep the model up-to-date by sharing model weights/gradients between the server and edge clients. Simply aggregating weights may cause unexpected training failures. We instead aim to fuse the local models' knowledge to keep the model updated. Hence, we use deep mutual learning [49] to extract the knowledge.
The key idea of deep mutual learning is to train multiple neural networks (NNs) synchronously while minimizing the Kullback-Leibler (KL) divergence among the outputs of the networks. The KL divergence measures how much one distribution differs from another, so by minimizing it among the NNs, the NNs learn knowledge from each other. Therefore, to extract knowledge from the local model, we introduce a knowledge network (a tiny-size network compared to the local model) locally and optimize the knowledge network and local model simultaneously using deep mutual learning.
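To explain the knowledge extraction process intuitively, we describe it for an image classification task. Formally, an edge client has a local model $M$ and a knowledge network (tiny-size network) $M_k$. We first update the local model $M$. For any input batch of data $x$, we calculate the cross-entropy loss between the predictions and the ground truth as in equation 1:
$$L_{CE} = -\frac{1}{n} \sum_{i=1}^{n} y_i^{\top} \log \sigma(M(x_i)) \tag{1}$$
where $n$ is the mini-batch size, $y$ is the ground truth label, and $\sigma$ is the softmax function.
Then, we compute the KL divergence from $\sigma(M(x))$ to $\sigma(M_k(x))$ as in equation 2:
$$D_{KL}\big(\sigma(M_k(x)) \,\|\, \sigma(M(x))\big) = \frac{1}{n} \sum_{i=1}^{n} \sigma(M_k(x_i))^{\top} \log \frac{\sigma(M_k(x_i))}{\sigma(M(x_i))} \tag{2}$$
where the division $\frac{\sigma(M_k(x_i))}{\sigma(M(x_i))}$ is element-wise. We update $M$ using the total loss in equation 3:
$$L_{M} = L_{CE} + D_{KL}\big(\sigma(M_k(x)) \,\|\, \sigma(M(x))\big) \tag{3}$$
Similar steps are followed to update $M_k$.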
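As a minimal sketch of the loss described above, the following pure-Python snippet computes the cross-entropy term, the KL term toward the peer network, and their sum for one mini-batch. The function names, the toy logits, and the unweighted sum of the two terms are assumptions made for illustration; logits are given as plain lists rather than produced by real networks.

```python
import math

def log_softmax(z):
    """Numerically stable log-softmax of one logit vector."""
    m = max(z)
    log_sum = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - log_sum for v in z]

def cross_entropy(logits_batch, labels):
    """Mean cross-entropy between softmax(logits) and integer class labels."""
    losses = [-log_softmax(z)[y] for z, y in zip(logits_batch, labels)]
    return sum(losses) / len(losses)

def kl_divergence(peer_logits_batch, own_logits_batch):
    """Mean KL(softmax(peer) || softmax(own)) over the mini-batch."""
    total = 0.0
    for z_peer, z_own in zip(peer_logits_batch, own_logits_batch):
        lp, lq = log_softmax(z_peer), log_softmax(z_own)
        total += sum(math.exp(a) * (a - b) for a, b in zip(lp, lq))
    return total / len(own_logits_batch)

def mutual_loss(own_logits_batch, peer_logits_batch, labels):
    """Supervised term plus mimicry of the peer (unweighted sum is an assumption)."""
    return (cross_entropy(own_logits_batch, labels)
            + kl_divergence(peer_logits_batch, own_logits_batch))

# Hypothetical logits for a batch of two samples and three classes.
local_logits = [[2.0, 0.5, -1.0], [0.1, 1.2, 0.3]]
knowledge_logits = [[1.5, 0.7, -0.5], [0.0, 0.9, 0.6]]
labels = [0, 1]

loss_local = mutual_loss(local_logits, knowledge_logits, labels)
loss_knowledge = mutual_loss(knowledge_logits, local_logits, labels)
```

Each network gets its own loss with the roles of "own" and "peer" swapped, which is what lets the small network absorb the large network's predictions and vice versa.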

Local Updates Through Deep Mutual Learning
As depicted in Figure 2, we mutually train the local model and the knowledge network, then transmit the knowledge network back to the server, where it is used to update the global knowledge. Knowledge extraction and communication allow edge clients to deploy resource-aware models for the application while maintaining communication efficiency and model heterogeneity. Algorithm 1 shows the local update process.
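The local update can be sketched as follows. This is not Algorithm 1 itself but a toy stand-in under stated assumptions: both "networks" are linear softmax classifiers, the larger local model simply sees one extra hypothetical feature, the two loss terms are summed without weighting, and all names (`local_update`, `mutual_step`, `local_features`) are illustrative.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def logits(weights, x):
    """Linear stand-in for a network: one row per class, last entry is the bias."""
    return [sum(w * v for w, v in zip(row[:-1], x)) + row[-1] for row in weights]

def mutual_step(weights, x, peer_probs, label, lr):
    """One SGD step on cross-entropy(label) + KL(peer || own) for a linear model."""
    own_probs = softmax(logits(weights, x))
    for c, row in enumerate(weights):
        target = 1.0 if c == label else 0.0
        # Gradient w.r.t. logit c: (p - y) from cross-entropy, (p - peer) from KL.
        grad = (own_probs[c] - target) + (own_probs[c] - peer_probs[c])
        for j, v in enumerate(x):
            row[j] -= lr * grad * v
        row[-1] -= lr * grad

def local_features(x):
    """Hypothetical richer input for the larger local model."""
    return x + [x[0] * x[1]]

def local_update(knowledge, local, data, epochs=200, lr=0.1):
    """Train both models mutually; return only the small knowledge network."""
    for _ in range(epochs):
        for x, label in data:
            # Peer predictions are computed before either model is updated.
            p_local = softmax(logits(local, local_features(x)))
            p_knowledge = softmax(logits(knowledge, x))
            mutual_step(local, local_features(x), p_knowledge, label, lr)
            mutual_step(knowledge, x, p_local, label, lr)
    return knowledge

def predict(weights, x):
    scores = logits(weights, x)
    return scores.index(max(scores))

# Toy private data: two features, two classes (illustrative only).
data = [([0.0, 1.0], 0), ([1.0, 0.0], 1), ([0.2, 0.9], 0), ([0.9, 0.1], 1)]
downloaded = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]             # knowledge network from the server
local_model = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]  # never leaves the device
uploaded = local_update(downloaded, local_model, data)
```

Only `uploaded` crosses the network; the local model and the data stay on the client, which is why the two models are free to differ in size and structure.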

Multi-model Knowledge Fusion
In FedKEM, we provide two model fusion methods for the server to fuse the knowledge from the edge. The first is similar to traditional FL in that we aggregate the weights. In the second, inspired by FedDF, we ensemble all received client models and distill the ensemble knowledge into a global knowledge network. In this section and the experiments, we mainly focus on ensembling the clients' knowledge. However, FedKEM can also use traditional fusion methods, such as FedAvg and SCAFFOLD, to aggregate the model.
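We define the ensemble model as $\Theta = \{M_k^i\}_{i \in S}$, where $M_k^i$ is the $i$-th client's knowledge network and $S$ is the set of clients that communicate with the server in the current communication round. Then we distill the knowledge of the ensemble $\Theta$ into a global knowledge network $M_g$ by using unlabeled data, generated data, or public data on the server. The distillation loss for $M_g$ is defined in equation 4:
$$L_{KD} = D_{KL}\big(\sigma(\Theta(x)) \,\|\, \sigma(M_g(x))\big) \tag{4}$$
where $\Theta(x)$ is the ensemble output for input $x$ and $\sigma$ is the softmax function.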
The server update process is shown in Algorithm 2.
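A minimal sketch of the server-side step, assuming linear softmax classifiers as stand-ins for the knowledge networks, element-wise maximum as the ensemble rule, and a KL distillation loss on unlabeled inputs. It is not Algorithm 2; the names `server_update` and `distillation_loss` and all numbers are hypothetical.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def logits(weights, x):
    """Linear stand-in for a network: one row per class, last entry is the bias."""
    return [sum(w * v for w, v in zip(row[:-1], x)) + row[-1] for row in weights]

def ensemble_logits(client_networks, x):
    """Element-wise maximum over the clients' logit vectors."""
    return [max(col) for col in zip(*(logits(w, x) for w in client_networks))]

def distillation_loss(global_net, client_networks, inputs):
    """Mean KL(softmax(ensemble) || softmax(global)) over unlabeled inputs."""
    total = 0.0
    for x in inputs:
        teacher = softmax(ensemble_logits(client_networks, x))
        student = softmax(logits(global_net, x))
        total += sum(t * math.log(t / s) for t, s in zip(teacher, student) if t > 0)
    return total / len(inputs)

def server_update(global_net, client_networks, inputs, epochs=100, lr=0.1):
    """Distill the client ensemble into the global knowledge network by SGD."""
    for _ in range(epochs):
        for x in inputs:
            teacher = softmax(ensemble_logits(client_networks, x))
            student = softmax(logits(global_net, x))
            for c, row in enumerate(global_net):
                grad = student[c] - teacher[c]  # derivative of the KL term w.r.t. logit c
                for j, v in enumerate(x):
                    row[j] -= lr * grad * v
                row[-1] -= lr * grad
    return global_net

# Two hypothetical client knowledge networks and a small unlabeled set.
clients = [
    [[1.0, -1.0, 0.0], [-1.0, 1.0, 0.0]],
    [[0.5, -2.0, 0.1], [-0.5, 2.0, -0.1]],
]
unlabeled = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5], [0.2, 0.8]]
global_net = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]

loss_before = distillation_loss(global_net, clients, unlabeled)
server_update(global_net, clients, unlabeled)
loss_after = distillation_loss(global_net, clients, unlabeled)
```

No labels are needed on the server: the ensemble's soft predictions act as the targets, which is what allows unlabeled, generated, or public data to be used for this step.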

Ensemble Knowledge
In FedKEM, we investigate three ensemble strategies, i.e., max logits, average logits, and majority vote. We adopt max logits as the ensemble strategy since it gives the best results in practice. For a given input instance $x$, the ensemble output is obtained by equation 5:
$$\Theta(x) = \max\big(\{M_k^i(x)\}_{i \in S}\big) \tag{5}$$
where $M_k^i(x)$ is the output vector of the $i$-th client's knowledge network, $S$ is the set of participating clients, and $\max$ compares all output vectors and returns a new vector containing the element-wise maxima.
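The three strategies can be illustrated with a short sketch. The function names and the example logits are hypothetical, and the tie-breaking rule for the majority vote (first class encountered wins) is an assumption not stated in the text.

```python
from collections import Counter

def argmax(vec):
    return vec.index(max(vec))

def max_logits(outputs):
    """Element-wise maximum across the clients' output vectors."""
    return [max(col) for col in zip(*outputs)]

def average_logits(outputs):
    """Element-wise mean across the clients' output vectors."""
    return [sum(col) / len(col) for col in zip(*outputs)]

def majority_vote(outputs):
    """Class predicted by most clients; ties go to the class seen first."""
    votes = [argmax(vec) for vec in outputs]
    return Counter(votes).most_common(1)[0][0]

# Hypothetical logits from three clients for one input and three classes.
client_outputs = [
    [2.0, 0.5, 0.1],
    [0.3, 1.5, 0.2],
    [0.1, 1.2, 0.4],
]

fused_max = max_logits(client_outputs)
fused_avg = average_logits(client_outputs)
voted_class = majority_vote(client_outputs)
```

In this example the strategies disagree: max logits follows the single most confident client (class 0), while averaging and voting follow the two clients that prefer class 1. Which behavior is preferable is an empirical question.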

EXPERIMENTS
We conduct comprehensive experiments to evaluate the performance of FedKEM. We separate our experiments into three sections: learning efficiency, communication cost, and multi-model federated learning. In addition, we perform an ablation study with different ensemble methods for FedKEM.

Learning Efficiency
In our evaluation, we assess FedKEM's learning efficiency and optimization, in particular the relationship between communication rounds and the target model's accuracy. The universal models trained for the baselines include VGG-11, ResNet-20, and ResNet-32, while ResNet-20 is used as the knowledge network for FedKEM, as it is a commonly used, minimal model in the ResNet family. Figure 4 shows that FedKEM delivers superior results in most benchmark settings compared to strong FL baselines while exhibiting a stable training process. It significantly surpasses the baselines on over-parameterized networks such as VGG-11 and is particularly adept at handling heterogeneous settings. For example, with 30 clients, FedKEM achieves 70% accuracy after 110 rounds, while all baselines fail to reach this accuracy even after 200 rounds. This gap widens with 50 clients, where baselines fail to attain 50% accuracy after 20000 total local updates. These results demonstrate FedKEM's advantages in stability and consistency, particularly in highly heterogeneous FL environments. SPATL, although it performs well with 30 and 50 clients, does not compete effectively with the other baselines when the number of clients increases to 100.

Figure 5: Comparison of FedKEM with FedAvg [25], FedNova [38], FedProx [21], SPATL [45]: the convergence accuracy (N = 50, f = 70%; panels: ResNet-20 and VGG-11). N is the total number of clients, f is the client sample ratio. The model type is the network at local devices and also the network used for communication in the case of the baseline methods. The knowledge network is ResNet-20 in all cases for FedKEM. The higher the better.

We observe that the knowledge model used in FedKEM does not significantly impact its performance. It consistently maintains stable training processes and high final convergence (over 70%) across all settings. The final converged accuracies for 30 and 50 clients are documented in Figure 5. Notably, FedKEM avoids gradient explosions during training, an issue observed in other methods [45].
Baseline methods primarily concentrate on parameter and gradient aggregation for model fusion in the cloud. However, their aggregation approach, such as FedAvg's weighted averaging, can introduce biases. This is primarily because the contribution of individual edge models to the FL system is a black box.
In contrast, FedKEM uses ensemble distillation for model fusion. This approach generalizes heterogeneous edge models effectively, guiding the model towards an optimal direction and resulting in a much more stable optimization process across various non-IID FL settings.

Communication Efficiency
A key attribute of FedKEM is that it introduces a knowledge network independent of the edge models. Hence, it only communicates the tiny-size knowledge network during training and inference, resulting in lower communication costs than SoTA FL algorithms that exchange full models for aggregation and fusion. For example, FedNova can achieve stable training but costs double the average communication cost of FedAvg as a result of sharing extra gradient information. We evaluate the communication cost in two ways. First, we train all the models to a target accuracy and calculate the communication cost. Second, we train all the models to convergence and then compare the communication cost of each FL algorithm. The communication cost is given by equation 6:
$$\text{Communication cost} = \#\text{rounds} \times \text{round cost} \tag{6}$$
where the round cost is 2 × the size of the exchanged model (for downloading and uploading the global model in the case of the baselines or the knowledge model in the case of FedKEM). Figure 6 shows the number of communication rounds needed to reach the target accuracy. FedKEM requires far fewer rounds to reach 60% accuracy than the baselines. For example, with N = 30 and ResNet-20 as the knowledge network, FedKEM needs 49 rounds while the next best, FedNova, requires 123 rounds.
When factoring in the communication cost, the difference is even more pronounced. With the model sizes of VGG-11, ResNet-20, ResNet-32, and ResNet-44 given as 21, 1.05, 1.6, and 2.7 MB respectively, we can calculate the transmission overhead of each method using equation 6. For instance, with ResNet-32 as the knowledge network, the communication burden of SPATL is more than double that of FedKEM to reach the same 60% accuracy (4.76 GB compared to 2.06 GB).
The reduction in communication cost is even more notable when an over-parameterized network like VGG-11 is combined with higher heterogeneity (larger N). With N = 100 and ResNet-20 as the knowledge network, FedKEM needs 0.074 GB of transfer while FedProx needs 0.71 GB to reach the same result, a reduction of roughly one order of magnitude.
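To make the arithmetic concrete, the sketch below computes the cost as rounds times round cost, with the round cost taken as twice the size of the exchanged model. It assumes the sizes quoted above are in megabytes and ignores any per-client multiplier or extra payload (such as FedNova's gradient information), so it does not reproduce the totals reported in Figure 6. The equal round count is a hypothetical choice that isolates the effect of model size.

```python
def round_cost_mb(model_size_mb):
    """Download plus upload of the exchanged model in one round."""
    return 2 * model_size_mb

def communication_cost_mb(rounds, model_size_mb):
    """Total traffic for a given number of rounds (single client, no extras)."""
    return rounds * round_cost_mb(model_size_mb)

# Model sizes as quoted in the text, assumed to be megabytes.
MODEL_SIZE_MB = {"VGG-11": 21.0, "ResNet-20": 1.05, "ResNet-32": 1.6, "ResNet-44": 2.7}

# Hypothetical: same number of rounds, different exchanged model.
rounds = 100
cost_full_model = communication_cost_mb(rounds, MODEL_SIZE_MB["VGG-11"])
cost_knowledge_net = communication_cost_mb(rounds, MODEL_SIZE_MB["ResNet-20"])
size_ratio = cost_full_model / cost_knowledge_net
```

At an equal number of rounds, exchanging a ResNet-20 knowledge network instead of a VGG-11 model is about 20 times cheaper; any reduction in the number of rounds compounds this.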

Multi-Model Federated Learning
Federated learning systems often face challenges in data and resource heterogeneity. While existing FL works primarily focus on addressing data heterogeneity and reducing training performance overhead, resource heterogeneity remains a significant challenge. A uniform model deployed across all resource-heterogeneous edge clients may constrain the whole system to the capacity of its most resource-poor clients.
FedKEM tackles these challenges through knowledge extraction, enabling multi-model deployment on heterogeneous edge devices. This allows models to be deployed more effectively to edge clients based on their computational resources, rather than sharing an identical model after optimization as in traditional FL methods. Depending on the clients' memory and multiply-accumulate operations (MACs), we can allocate suitable models for better utilization.
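One simple way to picture such an allocation is to give each client the largest model that fits its memory budget. The rule, the client names, and the budgets below are purely hypothetical; the text does not specify an allocation policy, and a MACs budget could be added as a second filter in the same way. Model sizes are the ones quoted in the communication cost discussion, assumed to be megabytes.

```python
def allocate_model(memory_budget_mb, catalog):
    """Return the largest catalog model that fits the budget, or None if none fits."""
    fitting = [(size, name) for name, size in catalog.items() if size <= memory_budget_mb]
    return max(fitting)[1] if fitting else None

# Candidate local models and their sizes (assumed MB).
catalog = {"ResNet-20": 1.05, "ResNet-32": 1.6, "ResNet-44": 2.7}

# Hypothetical clients with different memory budgets for model weights.
client_budgets_mb = {"sensor": 1.2, "phone": 2.0, "gateway": 8.0, "tag": 0.5}

assignment = {
    client: allocate_model(budget, catalog)
    for client, budget in client_budgets_mb.items()
}
```

Because only the knowledge network is exchanged, the clients in `assignment` can train different architectures and still participate in the same FL system.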
We evaluated FedKEM's performance on multi-model deployment using ResNet-20/32/44 in the same FL system, updating the multi-model edge clients with ResNet-20 as the knowledge model. This approach is beneficial when clients have different computational capabilities, as it enables local network customization. FedKEM showed stable training and quickly achieved high accuracy despite the different client models.
In Figure 7, we compare FedKEM's multi-model performance with ResNet-20 and ResNet-32 as the knowledge network, the smallest candidates, chosen to minimize communication overhead. The results show accuracy (over 70%) comparable to training a single universal model. The performance difference between ResNet-20 and ResNet-32 as knowledge networks becomes more pronounced as the heterogeneity (N) increases.

CONCLUSION
This paper presents FedKEM, a novel federated learning paradigm that addresses the challenges of limited resources and heterogeneous computing power in edge devices. FedKEM uses a compact network architecture that is trained locally to extract knowledge from the local model and integrate global knowledge. This "tiny-size" network is then sent to a central server for multi-model fusion and global knowledge distillation. Experimental results show FedKEM's superiority over other federated learning algorithms in accuracy, efficiency, and stability, underlining its potential as a scalable solution for federated learning. Future research will aim to enhance multi-model fusion efficiency and explore the approach's applicability in other machine learning areas. Ultimately, we strive to position FedKEM as a leading solution for federated learning problems.

Figure 1: (a) The proposed knowledge extraction and multi-model fusion FL method builds global knowledge by extracting local knowledge from different models, each matched to the computing power of its device, fuses it on the server, then transfers the fused knowledge to the edge. It can thus be aware of resource constraints and serve as a robust general-purpose FL solution for practical applications. (b) In contrast, traditional FL aggregates local model weights and distributes the global model to the edges. It treats all edge devices as having the same model and computing power and incurs high communication costs.

Figure 2: Local updates. The client downloads the knowledge network from the server, trains mutually with the local model, extracts the updated knowledge network, and uploads it to the server.

Figure 3: Cloud updates. The cloud collects local knowledge (tiny-size networks) from clients, ensembles all tiny-size networks, distills them into global knowledge, and transfers it to clients.

Figure 4: Comparison of FedKEM with FedAvg [25], FedNova [38], FedProx [21], SPATL [45]: the top-1 test accuracy vs. communication rounds. N is the total number of clients, f is the client sample ratio. The model type is the network at local devices and also the network used for communication in the case of the baseline methods. The knowledge network is ResNet-20 in all cases for FedKEM.

Figure 6: Comparison of FedKEM with FedAvg [25], FedNova [38], FedProx [21], SPATL [45]: the communication rounds and overhead needed to reach 60% and 30% top-1 accuracy. N is the total number of clients, f is the client sample ratio. The model type is the network at local devices and also the network used for communication in the case of the baseline methods. The knowledge network is ResNet-20 in all cases for FedKEM.

Figure 7: Comparison of FedKEM with multi-model: the top-1 test accuracy vs. communication rounds. N is the total number of clients, f is the client sample ratio, and the model type is the network used for communication in the baselines. ResNet-20 and ResNet-32 are used as the knowledge network.

Table 1: Model distribution at local clients in the multi-model setting.