MLX: send / get model parameters of a complex model

Hello, I’m using MLX with Flower to do federated learning. I found the example from the Flower GitHub repository very helpful to get started :smile:

The problem I have is that, instead of the MNIST dataset used in the example, I want to use CIFAR-10 and a more complex model like ResNet. MLX provides a CIFAR example (mlx-examples/cifar at main · ml-explore/mlx-examples · GitHub) with a ResNet model. That model, however, is implemented differently from the simple Flower MLP example, and I can’t figure out how to transfer the MNIST example to my use case :sweat_smile:

I found out that the model is structured differently:

print(resnet_model.parameters().keys())
print(mnist_model.parameters().keys())

# Outputs:
dict_keys(['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'linear'])
dict_keys(['layers'])

As you can see, the mnist_model stores its parameters like this:

{ "layers": [ {"weight": mlx.core.array, "bias": mlx.core.array}, {"weight": mlx.core.array, "bias": mlx.core.array}, ..., {"weight": mlx.core.array, "bias": mlx.core.array} ] }

The ResNet model, however, has different layer types like conv1, bn1, and so on, and their internal structure differs too.

print(resnet_model.parameters()["conv1"].keys())

# Prints:
dict_keys(['weight'])

print(resnet_model.parameters()["bn1"].keys())

# Prints:
dict_keys(['weight', 'bias', 'running_mean', 'running_var'])

Even the individual layers differ: the conv1 layer only has ‘weight’, while bn1 has several keys (including the batch-norm running statistics).
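
Because of that nesting, I guess I could hack together a recursive flatten myself, something like this (untested sketch):

import numpy as np

# Naive recursive flatten: walk the nested dicts/lists and collect
# every leaf array as a NumPy array.
def flatten_params(tree):
    if isinstance(tree, dict):
        return [a for v in tree.values() for a in flatten_params(v)]
    if isinstance(tree, list):
        return [a for v in tree for a in flatten_params(v)]
    return [np.array(tree)]

But then I’d have no clean way to reverse it for set_parameters. And the MNIST model has its own quirk: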

print(mnist_model.parameters()["layers"].keys())

# Raises:
AttributeError: 'list' object has no attribute 'keys'
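
So "layers" maps to a plain Python list, and indexing into it works instead:

print(mnist_model.parameters()["layers"][0].keys())

# Prints (per the structure above):
dict_keys(['weight', 'bias'])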

I don’t know how to convert all of this into the NumPy ndarray format that Flower expects. Hopefully someone can help me here :grin:

Thanks in advance ^^

Hey @FNG, it seems that MLX provides a feature close to PyTorch’s state_dict() via mlx.utils.tree_flatten. For example, if I take the MLX ResNet model you point to and do:

from mlx.utils import tree_flatten

model = resnet()
pp_flat = tree_flatten(model.parameters())

# Each element in the list is a (name, array) pair, just like a PyTorch state dict:
print([p[0] for p in pp_flat])

# Prints the following:
['conv1.weight',
 'bn1.weight',
 'bn1.bias',
 'bn1.running_mean',
 'bn1.running_var',
 'layer1.layers.0.conv1.weight',
 'layer1.layers.0.bn1.weight',
 'layer1.layers.0.bn1.bias',
 'layer1.layers.0.bn1.running_mean',
 'layer1.layers.0.bn1.running_var',
...
]
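
For the reverse direction, mlx.utils.tree_unflatten rebuilds the nested dict from the flat (name, array) pairs, so something like this (untested sketch) should load them back into the model:

from mlx.utils import tree_unflatten

# Rebuild the nested parameter tree and load it back into the model:
model.update(tree_unflatten(pp_flat))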

I believe this is what you want? Let me know if this helps!

Hey @Javier, thank you so much! I didn’t know about tree_flatten and tree_unflatten. I used tree_flatten to turn the parameters into a list of NumPy ndarrays for get_parameters, and the reverse operation with tree_unflatten for set_parameters. Right now it seems to work quite well; at least I can train rounds and get a reasonable accuracy :smile:

Here’s the code I’m using right now:

import numpy as np
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten

def get_parameters(self, config):
    # Flatten the nested parameter tree and convert each MLX array
    # to a NumPy ndarray for Flower:
    pp_flat = tree_flatten(self.model.parameters())
    return [np.array(val) for _, val in pp_flat]

def set_parameters(self, parameters):
    # Recover the parameter names from the current model, pair them
    # with the incoming NumPy arrays, and rebuild the nested tree:
    pp_flat = tree_flatten(self.model.parameters())
    keys = [p[0] for p in pp_flat]
    state_dict = [(k, mx.array(v)) for k, v in zip(keys, parameters)]
    self.model.update(tree_unflatten(state_dict))
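
To double-check the conversion, I also ran a quick round-trip test (a sketch; client here stands for my Flower client instance):

# Sanity check: a get -> set -> get round trip should return identical arrays.
before = client.get_parameters(config={})
client.set_parameters(before)
after = client.get_parameters(config={})
assert all(np.array_equal(b, a) for b, a in zip(before, after))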

Amazing! While working on my answer I noticed we haven’t updated the MLX example in a while (for instance, I had to update to a more recent version of MLX to use the tree_flatten feature). Would you be interested in helping us update the example? Maybe even making it work with ResNet or something more exciting than the tiny model we have?

Sure! I’m currently working with MLX and other frameworks in Flower for my bachelor thesis. I still have a few things to figure out with MLX, but once I have a working example I’ll be more than happy to contribute it to Flower and the community. I’ll open a pull request when I have something in the next weeks / month, if that’s OK ^^
