How do I start from a pre-trained model

A frequently asked question we get is:

How do I start from a pre-trained model? 🤔

Great question!!
Currently, the easiest way to start from a pre-trained model is to pass its parameters to your strategy. All strategies in Flower accept an optional argument called initial_parameters which, if set, is used to initialise the global model. Recall that if you don't set it, Flower samples one of the connected clients at random and uses that client's parameters as the initial state of the global model.

Let's say you want to use FedAvg and initialise the global model with the weights of a model you already have. This is how to do it:

from flwr.server.strategy import FedAvg
from flwr.common import ndarrays_to_parameters

model = ...  # your normal PyTorch model, with the pre-trained weights already loaded

# Convert state_dict to list of NumPy arrays
# this should look familiar (clients do something 
# similar when sending parameters back to the server)
ndarrays = [val.cpu().numpy() for val in model.state_dict().values()]

# Now convert them to the Parameters type
initial_parameters = ndarrays_to_parameters(ndarrays)

# Now you can pass them to your strategy
strategy = FedAvg(..., initial_parameters=initial_parameters)

# then pass the strategy to start_server or start_simulation.

In TensorFlow the process is identical, except that obtaining the ndarrays representation of your model is often as simple as:

from flwr.common import ndarrays_to_parameters

model = ...  # a TF/Keras model, with the pre-trained weights already loaded
ndarrays = model.get_weights()

# Now convert them to the Parameters type
initial_parameters = ndarrays_to_parameters(ndarrays)

# Then construct your strategy and launch the server/simulation