*This question was migrated from GitHub Discussions.*
Original Question:
I’m saving a model according to the guidelines here: [old doc]
This saves the “aggregated_weights” into an “.npz” file. However, it is not clear to me how to load this file and its contents into a PyTorch model afterwards for inference. I’m also having a hard time understanding the “aggregated_weights” variable: it appears to be a flwr.common.Parameters object that is not printable (my code hangs whenever I try to print it).
So my question is: how can I instead convert these “aggregated_weights” into the standard PyTorch state_dict format and save it as a “.pt” file?
Answer:
The internals of Flower, including the strategies, work with either the Parameters or the NDArrays type. NDArrays is simply a type used in Flower to denote a list of NumPy arrays, and Parameters is a serialized version of such an NDArrays list.
This is why several places (e.g. the strategies) make use of utility functions like parameters_to_ndarrays or ndarrays_to_parameters. You can find those defined here.
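To make "serialized" more concrete, here is a minimal sketch of the idea using only NumPy. It assumes, based on our reading of Flower 1.x, that Parameters is a small dataclass with a tensors field (one bytes blob per array) and a tensor_type string, and that each array is serialized in NumPy's .npy format. SketchParameters and the sketch_* functions are hypothetical stand-ins for illustration, not Flower's API. Since tensors holds the raw bytes of every weight, printing a Parameters object dumps all of them, which may explain why printing appears to hang; inspecting len(parameters.tensors) is much cheaper.

```python
import io
from dataclasses import dataclass
from typing import List

import numpy as np


# Hypothetical stand-in for flwr.common.Parameters (assumed to have these
# two fields); this is NOT Flower's class, only a sketch of the idea.
@dataclass
class SketchParameters:
    tensors: List[bytes]  # one serialized blob per NumPy array
    tensor_type: str      # describes how the blobs were produced


def ndarray_to_bytes(arr: np.ndarray) -> bytes:
    """Serialize one array in NumPy's .npy format (no pickling)."""
    buf = io.BytesIO()
    np.save(buf, arr, allow_pickle=False)
    return buf.getvalue()


def bytes_to_ndarray(blob: bytes) -> np.ndarray:
    """Deserialize one array produced by ndarray_to_bytes."""
    return np.load(io.BytesIO(blob), allow_pickle=False)


def sketch_ndarrays_to_parameters(ndarrays: List[np.ndarray]) -> SketchParameters:
    return SketchParameters(
        tensors=[ndarray_to_bytes(a) for a in ndarrays],
        tensor_type="numpy.ndarray",
    )


def sketch_parameters_to_ndarrays(params: SketchParameters) -> List[np.ndarray]:
    return [bytes_to_ndarray(b) for b in params.tensors]


if __name__ == "__main__":
    weights = [np.ones((2, 3), dtype=np.float32), np.zeros(3, dtype=np.float32)]
    params = sketch_ndarrays_to_parameters(weights)
    # Inspect the number and size of the blobs instead of printing raw bytes
    print(len(params.tensors), [len(b) for b in params.tensors])
    restored = sketch_parameters_to_ndarrays(params)
    print(all(np.array_equal(a, b) for a, b in zip(weights, restored)))
```

The round trip is lossless for numeric arrays: shapes and dtypes are stored in the .npy header of each blob.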
Going back to the question: the state_dict of a PyTorch model is an (ordered) Python dictionary whose values are of type torch.Tensor. We can therefore represent those values as an NDArrays list like this:
```python
import flwr

# `model` is your PyTorch model (a torch.nn.Module instance)
ndarrays = [val.cpu().numpy() for val in model.state_dict().values()]
# which can then be converted to the Parameters type with
parameters = flwr.common.ndarrays_to_parameters(ndarrays)
# and can be converted back to a list of NumPy arrays like this:
ndarrays = flwr.common.parameters_to_ndarrays(parameters)
```
You can save those ndarrays to disk in any way you want (for example, with Python's pickle module, or as an .npz file like the guide you followed does).
But how do you load them back into a PyTorch model? Easy! You just need to construct a state_dict whose values are the arrays in ndarrays, converted to tensors.
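As a minimal sketch of two ways to persist the list, the helpers below use NumPy's .npz format and pickle. They assume the .npz file was written by passing the arrays positionally to np.savez, so NumPy names them arr_0, arr_1, and so on; the function names and file paths are hypothetical.

```python
import os
import pickle
import tempfile
from typing import List

import numpy as np


def save_ndarrays_npz(path: str, ndarrays: List[np.ndarray]) -> None:
    # Arrays passed positionally are stored as arr_0, arr_1, ...
    # (NumPy appends ".npz" to the path if it is missing).
    np.savez(path, *ndarrays)


def load_ndarrays_npz(path: str) -> List[np.ndarray]:
    with np.load(path) as data:
        # Index explicitly instead of sorting the key names: as strings,
        # "arr_10" sorts before "arr_2", which would scramble the layer order.
        return [data[f"arr_{i}"] for i in range(len(data.files))]


def save_ndarrays_pickle(path: str, ndarrays: List[np.ndarray]) -> None:
    with open(path, "wb") as f:
        pickle.dump(ndarrays, f)


def load_ndarrays_pickle(path: str) -> List[np.ndarray]:
    # Only unpickle files you trust: unpickling can execute arbitrary code.
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    weights = [np.random.rand(4, 3), np.random.rand(3)]
    out_dir = tempfile.mkdtemp()
    npz_path = os.path.join(out_dir, "weights.npz")
    save_ndarrays_npz(npz_path, weights)
    restored = load_ndarrays_npz(npz_path)
    print([a.shape for a in restored])
```

Whichever format you pick, what matters for the next step is that the list comes back in the same order it had when it was extracted from the state_dict.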
```python
from collections import OrderedDict

import torch

# let's initialize the model (Model is your own torch.nn.Module subclass)
my_model = Model()
# load the file containing the list of NumPy arrays; how you do this
# depends on how you saved them (e.g. pickle.load or numpy.load)
ndarrays = ...  # replace with your own loading code
# Pair each key of the model's state_dict with the array at the same
# position in the list you just loaded (the order must match the order
# in which the arrays were extracted from the state_dict)
p_dict = zip(my_model.state_dict().keys(), ndarrays)
# Now construct the state_dict, converting the values
# into standard torch tensors
state_dict = OrderedDict({k: torch.tensor(v) for k, v in p_dict})
# Finally load the state_dict into your model
my_model.load_state_dict(state_dict, strict=True)
```
And that’s it!
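To address the last part of the question directly (saving in the standard state_dict format as a .pt file), the steps above can be wrapped in two small helpers followed by torch.save. This is a minimal sketch: ndarrays_to_state_dict and save_as_pt are hypothetical names, nn.Linear(3, 2) stands in for your own Model, and it assumes the arrays are in the same order as model.state_dict().

```python
import os
import tempfile
from collections import OrderedDict
from typing import List

import numpy as np
import torch
import torch.nn as nn


def ndarrays_to_state_dict(
    model: nn.Module, ndarrays: List[np.ndarray]
) -> "OrderedDict[str, torch.Tensor]":
    keys = list(model.state_dict().keys())
    # zip() silently stops at the shorter input, so compare lengths first
    if len(keys) != len(ndarrays):
        raise ValueError(f"model expects {len(keys)} arrays, got {len(ndarrays)}")
    return OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, ndarrays))


def save_as_pt(model: nn.Module, ndarrays: List[np.ndarray], path: str) -> None:
    # strict=True rejects missing or unexpected keys; a shape mismatch
    # raises a RuntimeError either way
    model.load_state_dict(ndarrays_to_state_dict(model, ndarrays), strict=True)
    torch.save(model.state_dict(), path)


if __name__ == "__main__":
    # Stand-in for the aggregated weights: arrays taken from another model
    source = nn.Linear(3, 2)
    ndarrays = [v.cpu().numpy() for v in source.state_dict().values()]
    target = nn.Linear(3, 2)
    pt_path = os.path.join(tempfile.mkdtemp(), "model.pt")
    save_as_pt(target, ndarrays, pt_path)

    # Later, for inference: rebuild the architecture and load the file
    model = nn.Linear(3, 2)
    model.load_state_dict(torch.load(pt_path))
    model.eval()
```

Saving model.state_dict() rather than the whole module object is the approach PyTorch's own documentation recommends: loading the file later only requires the model class, not a pickled instance of it.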
The process above is essentially what Flower clients do when working with PyTorch models, as you can see in several of the PyTorch examples in the GitHub repo. In fact, the snippet above comes directly from the quickstart-pytorch/client.py code. The same approach is used in examples that implement a global (centralized) evaluation stage, where a model (e.g. a standard PyTorch model) must be loaded with the parameters aggregated by the strategy.