How do I save the global model after training?

A frequently asked question we get is:

How do I save the global model after training?

There are different ways of saving the global model so you can make use of it once the training is completed.

I often implement this feature by extending the functionality of a Flower strategy. For example, you could extend the evaluate() method of a strategy (e.g. the one in FedAvg) so that a checkpoint of the model is saved each time the method is executed, which by default happens at the end of each round. Let’s see how that looks in code.

Let’s create a new strategy that inherits from FedAvg (you can do this for any other Flower strategy) and then modify the logic in the evaluate() method. In this example I’m saving the parameters using Python’s pickle module.

from logging import INFO
import pickle
from pathlib import Path
import flwr as fl
from flwr.common.logger import log
from flwr.common import parameters_to_ndarrays

class FedAvgWithModelSaving(fl.server.strategy.FedAvg):
    """A custom strategy that behaves exactly like FedAvg,
    except that it stores the state of the global model
    to disk after each round.
    """

    def __init__(self, save_path: str, *args, **kwargs):
        self.save_path = Path(save_path)
        # Create directory if needed
        self.save_path.mkdir(exist_ok=True, parents=True)
        super().__init__(*args, **kwargs)

    def _save_global_model(self, server_round: int, parameters):
        """A new method to save the parameters to disk."""

        # convert parameters to list of NumPy arrays
        # this will make things easy if you want to load them into a
        # PyTorch or TensorFlow model later
        ndarrays = parameters_to_ndarrays(parameters)
        data = {'global_parameters': ndarrays}
        filename = str(self.save_path/f"parameters_round_{server_round}.pkl")
        with open(filename, 'wb') as h:
            pickle.dump(data, h, protocol=pickle.HIGHEST_PROTOCOL)
        log(INFO, f"Checkpoint saved to: {filename}")    

    def evaluate(self, server_round: int, parameters):
        """Evaluate model parameters using an evaluation function."""
        # save the parameters to disk using a custom method
        self._save_global_model(server_round, parameters)

        # call the parent method so evaluation is performed as
        # FedAvg normally does.
        return super().evaluate(server_round, parameters)

You can use this strategy directly in your project or in any of the Flower Examples. For example, take the quickstart-pytorch example and replace its strategy with the one above.

[copy the strategy from above]

# strategy = fl.server.strategy.FedAvg(
#                     evaluate_metrics_aggregation_fn=weighted_average)
# use the new strategy and specify a path where to save your model
strategy = FedAvgWithModelSaving(save_path='my_checkpoints',
                    evaluate_metrics_aggregation_fn=weighted_average)

# run `start_server` or `start_simulation`

# now you can load any of the stored checkpoints

Run the example and you’ll see checkpoints start appearing in the path you specify.
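To load one of those checkpoints back, something like the sketch below should work. The function name `load_checkpoint` is my own (hypothetical, not part of Flower); it only assumes the file naming used by the strategy above and that each file holds a dict with a single entry, the list of NumPy arrays.

```python
import pickle
from pathlib import Path


def load_checkpoint(save_path, server_round):
    """Return the list of arrays saved for the given round.

    Hypothetical helper: assumes files are named
    parameters_round_<round>.pkl, as in the strategy above.
    """
    filename = Path(save_path) / f"parameters_round_{server_round}.pkl"
    with open(filename, "rb") as h:
        data = pickle.load(h)
    # The checkpoint is a dict with a single entry holding the list of
    # arrays, so take that one value regardless of the key name.
    return next(iter(data.values()))
```

For a PyTorch model, the returned arrays follow the same order as `model.state_dict().keys()`, so you can zip the two together to rebuild a state dict. If you want to resume training instead, the list can be converted back with `flwr.common.ndarrays_to_parameters` and passed to the strategy as `initial_parameters`. Keep in mind that pickle files should only be loaded from sources you trust.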

One of the great things about Flower strategies is that you can modify them to suit your needs. With small modifications you can save a checkpoint every N rounds, keep only the most recent M checkpoints, keep the best-performing models according to some metric, and so on.
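As a rough sketch of the first two ideas, here are two small helpers you could call from `_save_global_model()`. Both names (`should_save`, `prune_checkpoints`) and their arguments are hypothetical, not Flower APIs, and they assume the `parameters_round_<round>.pkl` naming used above.

```python
import re
from pathlib import Path


def should_save(server_round: int, every_n: int) -> bool:
    """True on rounds N, 2N, 3N, ... and False on round 0."""
    return server_round > 0 and server_round % every_n == 0


def prune_checkpoints(save_path, keep_last: int) -> list:
    """Delete all but the `keep_last` highest-round checkpoints.

    Returns the list of paths that were removed.
    """
    pattern = re.compile(r"parameters_round_(\d+)\.pkl$")
    found = []
    for path in Path(save_path).glob("parameters_round_*.pkl"):
        match = pattern.search(path.name)
        if match:
            found.append((int(match.group(1)), path))
    # Sort numerically by round so that round 10 comes after round 9
    found.sort(key=lambda item: item[0])
    cutoff = max(len(found) - keep_last, 0)
    to_remove = [path for _, path in found[:cutoff]]
    for path in to_remove:
        path.unlink()
    return to_remove

# Possible use inside _save_global_model() (untested against Flower itself):
#     if not should_save(server_round, every_n=5):
#         return
#     ... pickle.dump(...) as before ...
#     prune_checkpoints(self.save_path, keep_last=3)
```

Keeping the best-performing model would follow the same pattern, except the decision would have to be made after calling `super().evaluate(...)`, since that is where the metric becomes available.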
