
Tackling Data Heterogeneity in Federated Learning

FLOCO leverages the concept of linear mode connectivity of deep neural networks

In the traditional centralized machine learning setting, data collected by clients are gathered on a central server to train a machine learning model. In contrast, the Federated Learning (FL) setting assumes that clients do not share their data with the central server, usually because of data privacy concerns. Although many methods have been proposed for training models in a privacy-preserving manner, significant challenges remain in FL. A persistent one is statistical heterogeneity, that is, the situation in which the clients’ data distributions differ from one another.
For example, consider training a shared global model for diagnosing a viral disease across hospitals in different countries. The dominant viral strain may differ from country to country. As a consequence, typical symptoms would differ as well, depending on the hospital’s location, so the conditional distributions of symptoms relevant for diagnosis could vary widely. A single global model would fail to accurately learn such diverse local distributions, while independently trained models could fail to generalize globally.
This heterogeneity causes two major issues: slow global training due to conflicting gradient signals, and the need for personalization to adapt to local distributions. Previous personalized FL approaches address the latter issue by training a local model on each client’s distribution after a well-performing global model has been trained. However, when client heterogeneity is strong, conflicting gradients can prevent a well-trained global model from being obtained in the first place.
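To make the heterogeneity problem concrete, here is a minimal toy sketch of FedAvg-style training, one of the standard FL methods mentioned in this article. It is an illustration under stated assumptions, not a reference implementation: the quadratic client losses, the learning rate, the number of local steps, and the client dataset sizes are all hypothetical choices. Each client runs a few local gradient steps starting from the current global model, and the server averages the results weighted by client data size.

```python
import numpy as np

def local_update(theta, target, lr=0.1, steps=5):
    """Hypothetical client update: a few gradient steps on 0.5 * ||theta - target||^2."""
    theta = theta.copy()
    for _ in range(steps):
        theta -= lr * (theta - target)  # gradient of the toy quadratic loss
    return theta

def fedavg_round(theta, targets, sizes):
    """One FedAvg-style round: local training, then a data-size-weighted average."""
    updates = [local_update(theta, t) for t in targets]
    weights = np.asarray(sizes, dtype=float) / sum(sizes)
    return sum(w * u for w, u in zip(weights, updates))

# Two hypothetical clients with conflicting optima and a 3:1 data ratio.
targets = [np.array([1.0, 0.0]), np.array([-1.0, 0.0])]
sizes = [300, 100]

theta = np.zeros(2)
for _ in range(50):
    theta = fedavg_round(theta, targets, sizes)

# Per-client loss of the single global model.
losses = [0.5 * float((theta - t) @ (theta - t)) for t in targets]
print("global model:", theta, "per-client losses:", losses)
```

With identical quadratic curvature on every client, the global model converges to the data-weighted average of the client optima, here (0.5, 0), which is optimal for neither client. This is the compromise a single global model is forced into when local distributions diverge.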
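BIFOLD research group lead Shinichi Nakajima and his team propose FLOCO (Federated Learning over Connected Modes) to tackle both issues. FLOCO leverages the recently explored concept of linear mode connectivity in deep neural networks: the observation that different well-trained solutions can often be joined by straight paths in parameter space along which the loss stays low.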

Neural networks are usually trained by searching for a single point in parameter space that minimizes the training loss. Simplex learning instead learns a whole region of parameters shaped as a simplex: the convex hull of a few points (its vertices) in parameter space, such as a line segment for two vertices or a triangle for three. The goal is to place the vertices so that the training loss is low everywhere within the simplex. “It has been shown that simplex learning improves the performance of the trained network in terms of robustness and uncertainty estimation. Our FLOCO adapts this learning method for the FL setting by assigning subregions in the solution simplex so that similar clients use (possibly overlapping) subregions close to each other”, explains Shinichi Nakajima. Each client uses its assigned subregion for both training and inference, working with models sampled uniformly from that subregion. In this way, all clients share a single global solution simplex, while the heterogeneity among them is captured by the simplex’s degrees of freedom. The approach also mitigates the conflicting gradient signals that arise when the simplex vertices are optimized during training.
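The core mechanics can be illustrated with a toy sketch in NumPy. This is not the authors’ implementation, and it does not reproduce FLOCO’s actual subregion-assignment rule. The quadratic losses, the two clients, the learning rate, and the definition of a subregion (the full simplex shrunk by a hypothetical factor `radius` toward a hypothetical client-specific `center`) are all assumptions made for illustration. A model is a point theta = sum_i alpha_i * v_i inside the simplex with vertices v_i. Each client samples the weights alpha uniformly from its own subregion and computes gradients with respect to the shared vertices, and a server averages those gradients.

```python
import numpy as np

def sample_uniform_simplex(k, rng):
    """Uniform sample from the probability simplex with k vertices.

    Dirichlet(1, ..., 1) is exactly the uniform distribution over the simplex.
    """
    return rng.dirichlet(np.ones(k))

def sample_subregion(center, radius, rng):
    """Uniform sample from a hypothetical client subregion.

    The subregion is the full simplex shrunk by `radius` (0 < radius <= 1)
    toward `center`. The map u -> (1 - radius) * center + radius * u is affine,
    so a uniform u gives a uniform point in the shrunk simplex, and the result
    is a convex combination, so it stays inside the simplex.
    """
    u = sample_uniform_simplex(len(center), rng)
    return (1.0 - radius) * center + radius * u

def combine(vertices, alpha):
    """Model parameters at simplex weights alpha: sum_i alpha_i * v_i."""
    return alpha @ vertices

def vertex_grads(vertices, alpha, target):
    """Gradient of the toy loss 0.5 * ||theta - target||^2 w.r.t. every vertex.

    By the chain rule, dL/dv_i = alpha_i * dL/dtheta.
    """
    g = combine(vertices, alpha) - target
    return np.outer(alpha, g)

def loss_at(vertices, alpha, target):
    diff = combine(vertices, alpha) - target
    return 0.5 * float(diff @ diff)

rng = np.random.default_rng(0)
vertices = rng.normal(scale=0.1, size=(3, 2))  # 3 vertices in a 2-D parameter space

# Two hypothetical clients with conflicting optima, each assigned its own subregion.
clients = [
    {"center": np.array([0.8, 0.1, 0.1]), "target": np.array([1.0, 0.0])},
    {"center": np.array([0.1, 0.8, 0.1]), "target": np.array([-1.0, 0.0])},
]
radius, lr = 0.2, 0.5

initial = [loss_at(vertices, c["center"], c["target"]) for c in clients]
for step in range(500):
    # Each client samples a model from its subregion and returns vertex gradients;
    # the server averages them and updates the shared vertices.
    grads = [
        vertex_grads(vertices, sample_subregion(c["center"], radius, rng), c["target"])
        for c in clients
    ]
    vertices -= lr * np.mean(grads, axis=0)
final = [loss_at(vertices, c["center"], c["target"]) for c in clients]
print("loss at each client's center, before:", initial, "after:", final)
```

Because each client trains and evaluates a different region of the same shared simplex, both conflicting targets can be fitted reasonably well at once, whereas a single shared parameter vector would have to sit between them.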
Assuming a cross-silo FL scenario, the researchers evaluated FLOCO on multiple synthetic non-IID splits of CIFAR-10 and on the natural non-IID split of FEMNIST, across various neural network architectures and with both random and pre-trained initialization. They observed that FLOCO consistently outperformed standard FL methods such as FedAvg and FedProx, as well as state-of-the-art personalized FL techniques.
In summary, FLOCO improves both global and local optimization in FL with minimal computational overhead. Simplex learning may also prove useful beyond FL, potentially benefiting multi-task learning, continual learning, and domain generalization, directions the team plans to pursue in future work.


NeurIPS’24 Publication
Open-source implementation (GitHub)