MLX send / get model parameters of complex model

Hello, I’m using MLX with Flower to do federated learning. I found the example from the Flower GitHub repository very helpful to get started :smile:

The problem I have is that instead of the MNIST dataset (like in the example), I want to use CIFAR10 and a more complex model like ResNet. MLX provides an example for CIFAR (mlx-examples/cifar at main · ml-explore/mlx-examples · GitHub) with a ResNet model. However, the model is implemented differently from the simple Flower MLP example, and I can’t figure out how to transfer the MNIST example to my use case :sweat_smile:

I found out that the model is structured differently:

print(resnet_model.parameters().keys())
print(mnist_model.parameters().keys())

# Outputs:
dict_keys(['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'linear'])
dict_keys(['layers'])

As you can see, the mnist_model stores its parameters like this:

{
    "layers": [
        {"weight": mlx.core.array, "bias": mlx.core.array},
        {"weight": mlx.core.array, "bias": mlx.core.array},
        ...,
        {"weight": mlx.core.array, "bias": mlx.core.array}
    ]
}

But it seems that the resnet model has different layer types like conv1, bn1 and so on, and their structure is different too.

print(resnet_model.parameters()["conv1"].keys())

# Prints:
dict_keys(['weight'])

print(resnet_model.parameters()["bn1"].keys())

# Prints:
dict_keys(['weight', 'bias', 'running_mean', 'running_var'])

Even the individual layers differ: the conv1 layer only has ‘weight’, while bn1 has more keys.

print(mnist_model.parameters()["layers"].keys())

# Prints:
AttributeError: 'list' object has no attribute 'keys'

I don’t know how to convert this into the NumPy ndarray format that Flower understands. Hopefully someone can help me here :grin:

Thanks in advance ^^

Hey @FNG, it seems that a feature close to PyTorch’s state_dict() is provided by MLX via mlx.utils.tree_flatten. For example, if I take the MLX ResNet model you point to and do:

from mlx.utils import tree_flatten
model = resnet()
pp_flat = tree_flatten(model.parameters()) 

# each element in the list is a (name, array) pair, just like the entries of a PyTorch state_dict

print([p[0] for p in pp_flat])

# prints the following
['conv1.weight',
 'bn1.weight',
 'bn1.bias',
 'bn1.running_mean',
 'bn1.running_var',
 'layer1.layers.0.conv1.weight',
 'layer1.layers.0.bn1.weight',
 'layer1.layers.0.bn1.bias',
 'layer1.layers.0.bn1.running_mean',
 'layer1.layers.0.bn1.running_var',
...
]
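
And mlx.utils.tree_unflatten reverses the operation, so the flat list can be loaded back into the model (a minimal round-trip sketch):

from mlx.utils import tree_unflatten

# rebuild the nested {"conv1": {...}, "bn1": {...}, ...} structure
# from the flat (name, array) pairs and load it into the model
model.update(tree_unflatten(pp_flat))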

I believe this is what you want? Let me know if this helps!

Hey @javier, thank you so much! I didn’t know about tree_flatten and tree_unflatten. I use tree_flatten to turn the parameters into NumPy ndarrays for get_parameters, and then do the reverse operation with tree_unflatten for set_parameters. Right now it seems to work quite well. At least I can train rounds and get a reasonable accuracy :smile:

Here’s the code I’m using right now:

import numpy as np
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten

def get_parameters(self, config):
    # flatten the nested parameter tree into (name, array) pairs
    # and convert each mlx array into a NumPy ndarray for Flower
    pp_flat = tree_flatten(self.model.parameters())
    return [np.array(val) for _, val in pp_flat]

def set_parameters(self, parameters):
    # recover the parameter names in the order tree_flatten emits them
    pp_flat = tree_flatten(self.model.parameters())
    keys = [p[0] for p in pp_flat]

    # pair each name with the incoming NumPy array, convert back to
    # mlx arrays, rebuild the nested tree and load it into the model
    params_dict = zip(keys, parameters)
    state_dict = [(k, mx.array(v)) for k, v in params_dict]
    self.model.update(tree_unflatten(state_dict))
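
For context, here’s roughly how these two methods sit inside my Flower client (a sketch: the train and test helpers are placeholders for my local training and evaluation loops, not real Flower or MLX functions):

import flwr as fl

class MLXClient(fl.client.NumPyClient):
    def __init__(self, model, train_data, test_data):
        self.model = model
        self.train_data = train_data
        self.test_data = test_data

    # get_parameters(self, config) and set_parameters(self, parameters)
    # are the two methods shown above

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(self.model, self.train_data)  # placeholder local training loop
        return self.get_parameters(config), len(self.train_data), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, acc = test(self.model, self.test_data)  # placeholder eval helper
        return float(loss), len(self.test_data), {"accuracy": float(acc)}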

Amazing! While working on my answer I noticed we haven’t updated the MLX example in a while (for instance, I had to update to a more recent version of MLX to make use of the tree_flatten feature). Would you be interested in helping us update the example? Maybe even making it work with ResNet or something more exciting than the tiny, tiny model we have?

Sure. Currently I’m working with MLX and other frameworks in Flower for my bachelor’s thesis. I still have a few things to figure out with MLX, but once I have a working example I’ll be more than happy to contribute it to Flower and the community. I’ll open a pull request when I have something in the next weeks / next month, if that’s ok ^^