Hello, I’m using MLX with Flower to do federated learning. I found the example from the Flower GitHub repository very helpful for getting started.
The problem I have is that instead of the MNIST dataset (as in the example), I want to use CIFAR10 and a more complex model such as ResNet. MLX provides a CIFAR example with a ResNet model (mlx-examples/cifar in the ml-explore/mlx-examples GitHub repository). However, that model is implemented differently from the simple Flower MLP example, and I can’t figure out how to adapt the MNIST example to my use case.
I found out that the two models store their parameters differently:
print(resnet_model.parameters().keys())
print(mnist_model.parameters().keys())
# Outputs:
dict_keys(['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'linear'])
dict_keys(['layers'])
As you can see, mnist_model stores its parameters like this:
{ "layers": [ {"weight": mlx.core.array, "bias": mlx.core.array}, {"weight": mlx.core.array, "bias": mlx.core.array}, ..., {"weight": mlx.core.array, "bias": mlx.core.array} ] }
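For comparison, the MNIST-style conversion can be sketched roughly as follows, assuming the MLP example walks the `layers` list and reads exactly `weight` and `bias` from each entry. The helper names `mlp_get_params` / `mlp_set_params` are hypothetical, and NumPy arrays stand in for `mlx.core.array` so the sketch runs without MLX. The hard-coded layout is what makes this approach unusable for the ResNet.

```python
import numpy as np

def mlp_get_params(params):
    """Flatten {"layers": [{"weight", "bias"}, ...]} into [w0, b0, w1, b1, ...]."""
    arrays = []
    for layer in params["layers"]:
        # Assumes every layer has exactly these two keys, in this order.
        arrays.append(np.asarray(layer["weight"]))
        arrays.append(np.asarray(layer["bias"]))
    return arrays

def mlp_set_params(arrays):
    """Rebuild the nested {"layers": [...]} structure from [w0, b0, w1, b1, ...]."""
    return {
        "layers": [
            {"weight": arrays[i], "bias": arrays[i + 1]}
            for i in range(0, len(arrays), 2)
        ]
    }

# Hypothetical stand-in for mnist_model.parameters(): a 784 -> 32 -> 10 MLP.
mlp_params = {
    "layers": [
        {"weight": np.ones((32, 784)), "bias": np.zeros(32)},
        {"weight": np.ones((10, 32)), "bias": np.zeros(10)},
    ]
}
flat = mlp_get_params(mlp_params)
print([a.shape for a in flat])  # [(32, 784), (32,), (10, 32), (10,)]
```

On the ResNet's parameter dict, `params["layers"]` would raise a `KeyError`, and even a per-layer loop like this would drop keys such as `running_mean`.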
But it seems that the ResNet model has different layer types, such as conv1, bn1, and so on, and their parameters are structured differently:
print(resnet_model.parameters()["conv1"].keys())
# Prints:
dict_keys(['weight'])
print(resnet_model.parameters()["bn1"].keys())
# Prints:
dict_keys(['weight', 'bias', 'running_mean', 'running_var'])
Even the individual layers differ: conv1 only has ‘weight’, while bn1 also has ‘bias’, ‘running_mean’, and ‘running_var’.
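One structure-agnostic option is to flatten the whole nested parameter tree into a list of `(path, array)` pairs by recursing into dicts and lists, instead of assuming a fixed per-layer layout. MLX ships `mlx.utils.tree_flatten` for this, so with MLX something like `[np.array(v) for _, v in tree_flatten(model.parameters())]` should give the plain array list. The sketch below re-implements the idea with NumPy stand-ins so it runs without MLX; `flatten_params` and the toy ResNet-like dict are hypothetical, not the real model's full parameter set.

```python
import numpy as np

def flatten_params(tree, prefix=""):
    """Return a list of (dotted_path, np.ndarray) pairs for every leaf array.

    Dicts and lists are walked recursively in their existing order, so the
    result does not depend on which layer types (conv, batch norm, ...) occur.
    """
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        # Leaf: convert to a NumPy array.
        return [(prefix, np.asarray(tree))]
    flat = []
    for key, value in items:
        path = f"{prefix}.{key}" if prefix else str(key)
        flat.extend(flatten_params(value, path))
    return flat

# Hypothetical stand-in for a few entries of resnet_model.parameters().
resnet_like = {
    "conv1": {"weight": np.ones((16, 3, 3, 3))},
    "bn1": {
        "weight": np.ones(16),
        "bias": np.zeros(16),
        "running_mean": np.zeros(16),
        "running_var": np.ones(16),
    },
    "linear": {"weight": np.ones((10, 16)), "bias": np.zeros(10)},
}

pairs = flatten_params(resnet_like)
keys = [k for k, _ in pairs]      # keep these to rebuild the tree later
ndarrays = [a for _, a in pairs]  # the plain list of arrays Flower works with
print(keys)
# ['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'linear.weight', 'linear.bias']
```

Because `running_mean` and `running_var` show up in `parameters()`, they end up in the list too; whether the server should average them is a separate design choice.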
print(mnist_model.parameters()["layers"].keys())
# Raises:
AttributeError: 'list' object has no attribute 'keys'
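The `AttributeError` occurs because `parameters()["layers"]` is a Python list (presumably one dict per `Linear` layer), so it has to be indexed or iterated before calling `.keys()`; in MLX that would be, for example, `mnist_model.parameters()["layers"][0].keys()`. A small sketch using a hypothetical NumPy stand-in for the MLP parameters:

```python
import numpy as np

# Hypothetical stand-in for mnist_model.parameters(): "layers" holds a list, not a dict.
mlp_params = {
    "layers": [
        {"weight": np.ones((32, 784)), "bias": np.zeros(32)},
        {"weight": np.ones((10, 32)), "bias": np.zeros(10)},
    ]
}

layers = mlp_params["layers"]
print(type(layers).__name__)  # list
for i, layer in enumerate(layers):
    # Each list entry is a dict, so .keys() works on it.
    print(i, list(layer.keys()))
# 0 ['weight', 'bias']
# 1 ['weight', 'bias']
```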
I don’t know how to convert these parameters into the list of NumPy arrays (NDArrays) that Flower expects. Hopefully someone can help me here.
Thanks in advance ^^
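For the reverse direction (Flower hands the client a plain list of arrays), one approach is to keep the dotted keys from flattening, zip them with the incoming arrays, and rebuild the nested dict/list structure. In MLX, `mlx.utils.tree_unflatten` does this rebuild, so the client side would presumably look like `model.update(tree_unflatten([(k, mx.array(v)) for k, v in zip(keys, parameters)]))`. The sketch below re-implements the rebuild in plain Python/NumPy so it runs without MLX; `unflatten_params` and the toy keys are hypothetical. It assumes client and server build the same model, so the key order matches on both sides.

```python
import numpy as np

def unflatten_params(pairs):
    """Rebuild a nested dict from (dotted_path, array) pairs.

    Path segments that are integers (e.g. "layers.0.weight") become list
    indices, mirroring how lists of layers are stored.
    """
    root = {}
    for path, value in pairs:
        node = root
        parts = path.split(".")
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return _dicts_to_lists(root)

def _dicts_to_lists(node):
    """Recursively turn dicts whose keys are all digits into lists."""
    if not isinstance(node, dict):
        return node
    converted = {k: _dicts_to_lists(v) for k, v in node.items()}
    if converted and all(k.isdigit() for k in converted):
        return [converted[k] for k in sorted(converted, key=int)]
    return converted

# Hypothetical keys saved when flattening, and arrays as they might arrive from Flower.
keys = ["conv1.weight", "bn1.weight", "bn1.running_mean", "layers.0.weight", "layers.1.weight"]
arrays = [np.ones((16, 3, 3, 3)), np.ones(16), np.zeros(16), np.ones((32, 784)), np.ones((10, 32))]

params = unflatten_params(list(zip(keys, arrays)))
print(list(params.keys()))              # ['conv1', 'bn1', 'layers']
print(type(params["layers"]).__name__)  # list
```

Keeping the keys on the client (rather than sending them) relies on both sides constructing the model identically; if that cannot be guaranteed, a shape check before updating the model would be a sensible safeguard.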