How can I only send trainable parameters to the client?

Hi all, I’m currently running prompt tuning on a Flower system and having trouble finding a way to send only the trainable parameters (8 virtual tokens) to each client for training.

Here is my current Client class:
import flwr as fl
import torch
from collections import OrderedDict

class Client(fl.client.NumPyClient):
    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in model.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        print("Training Started...")
        # ... local training loop (not shown) ...
        print("Training Finished.")
        return self.get_parameters(config={}), len(train_dataloader), {}

    def evaluate(self, parameters, config):
        loss, accuracy = test_loop(model, ...)  # remaining arguments not shown
        return float(loss), len(valid_dataloader), {"accuracy": float(accuracy)}

fl.client.start_numpy_client(server_address="", client=Client())



Hey @devmonkeyy, interesting question!

As of Flower 1.7 (the current version), what I’d recommend is changing the get_parameters() method so that it returns only the trainable parameters of the model. Since you are using PyTorch, I assume that "trainable" means requires_grad=True.

The adjusted method would look as follows:

def get_parameters(model):
    return [p.cpu().detach().numpy() for p in model.parameters() if p.requires_grad]

You might need to adjust the code a bit to make it work with your client definition. Below is a simple example where I take a standard ResNet18 model and freeze everything except the output fully connected layer.

import torch
from torchvision.models import resnet18

model = resnet18()
# Freeze everything
model = model.requires_grad_(False)
# Unfreeze output head (so it can be finetuned)
model.fc = model.fc.requires_grad_(True)

# Let's check that only the FC layer's parameters are _trainable_
for name, p in model.named_parameters():
    if p.requires_grad:
        print(f"{name} ---> shape: {p.shape}")

# prints
# fc.weight ---> shape: torch.Size([1000, 512])
# fc.bias ---> shape: torch.Size([1000])

# Let's now call the new `get_parameters` function
ndarrays = get_parameters(model)
for ndarray in ndarrays:
    print(ndarray.shape)
# prints
# (1000, 512)
# (1000,)

As you can see, with the new function only the trainable parameters are extracted from the model.

Thanks for the response!

I changed the get_parameters function and froze layers in the model. When I run the code I receive the error:

size mismatch for base_model.bert.embeddings.word_embeddings.weight: copying a param with shape torch.Size([8, 768]) from checkpoint, the shape in current model is torch.Size([30522, 768])

This comes from the following line in set_parameters():

model.load_state_dict(state_dict, strict=True)

Does this mean that I have to add the frozen layers back when I set the parameters of the model?

If you are not modifying the rest of the model during training (e.g. because you initialized it with pre-trained weights and keep them frozen), you can update the set_parameters() function to load only the state_dict of the layers whose weights are updated.
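To see where the size-mismatch error comes from, here is a minimal sketch using a toy `nn.Sequential` model (an assumption standing in for the BERT model). `zip()` stops at the shorter of its inputs, so the short list of trainable arrays gets paired with the *first* keys of the full `state_dict`, which belong to the frozen embedding. In the real model, that is why the 8x768 prompt array lands on `word_embeddings.weight`.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Toy stand-in for the real model: a frozen embedding followed by a trainable head
model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))
model[0].requires_grad_(False)  # freeze the embedding, like the BERT backbone

# What comes back from the server: only the trainable tensors (Linear weight and bias)
parameters = [p.detach().cpu().numpy() for p in model.parameters() if p.requires_grad]

# The original set_parameters() pairs *all* state_dict keys with this shorter list
pairs = list(zip(model.state_dict().keys(), parameters))
for key, arr in pairs:
    print(key, tuple(model.state_dict()[key].shape), "<-", arr.shape)
# prints
# 0.weight (10, 4) <- (2, 4)
# 1.weight (2, 4) <- (2,)

try:
    model.load_state_dict(OrderedDict((k, torch.tensor(v)) for k, v in pairs), strict=True)
except RuntimeError as err:
    # PyTorch reports the size mismatches (and the missing "1.bias" key) here
    print(type(err).__name__)
```

The keys and the arrays are simply out of step, so the fix is to pair the incoming arrays only with the keys of the parameters that were actually sent.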

For example, in the ViT-finetuning example we only finetune the last layers of the model while keeping the rest intact. Here is how we updated set_parameters() there:

from collections import OrderedDict

import torch

def set_parameters(model, parameters):
    # `parameters` contain updated weights for the model heads
    # so let's just load those into the state_dict.
    finetune_layers = model.heads
    params_dict = zip(finetune_layers.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    finetune_layers.load_state_dict(state_dict, strict=True)

So I’d recommend doing something like this: identify which layers are trainable and then construct the state_dict from those layers only. Does this work for you?
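If you would rather not hard-code a specific submodule, a more general variant is to pair the incoming arrays with the *names* of the trainable parameters, in the same order that `get_parameters()` produces them, and load with `strict=False` so the frozen weights can stay absent. This is only a sketch: it assumes every client and the server freeze exactly the same parameters, and the toy `nn.Sequential` model is for illustration only.

```python
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def get_parameters(model):
    # Trainable tensors only, in model.parameters() order
    return [p.detach().cpu().numpy() for p in model.parameters() if p.requires_grad]


def set_parameters(model, parameters):
    # Names of the trainable parameters, in the same order get_parameters() uses
    names = [name for name, p in model.named_parameters() if p.requires_grad]
    if len(names) != len(parameters):
        raise ValueError(f"expected {len(names)} arrays, got {len(parameters)}")
    state_dict = OrderedDict((n, torch.tensor(a)) for n, a in zip(names, parameters))
    # strict=False: the frozen weights are intentionally missing from state_dict
    return model.load_state_dict(state_dict, strict=False)


# Usage with a toy model: frozen embedding + trainable linear head
model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))
model[0].requires_grad_(False)
frozen_before = model[0].weight.detach().clone()

# Pretend the server sent back updated trainable weights (all ones here)
new_params = [np.ones_like(a) for a in get_parameters(model)]
result = set_parameters(model, new_params)
print(result.missing_keys)  # ['0.weight'], i.e. only the frozen embedding
```

Checking `result.missing_keys` is a cheap way to confirm that only the layers you intended to freeze were skipped.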

Thanks for the explanation. I was able to get the code to work by only loading the state dict of the model’s prompt_encoder. Thanks a lot for the help!
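For anyone landing here later, the approach that worked (sending and loading only the `prompt_encoder` weights) can be sketched as below. `ToyPromptModel` is a hypothetical stand-in with a frozen backbone and an 8-token `prompt_encoder` embedding; a real prompt-tuning model may structure its `prompt_encoder` differently, so check its `state_dict()` keys first.

```python
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


class ToyPromptModel(nn.Module):
    """Hypothetical stand-in: frozen backbone plus a trainable prompt encoder."""

    def __init__(self, num_virtual_tokens=8, hidden=768):
        super().__init__()
        self.backbone = nn.Linear(hidden, hidden)  # stands in for the frozen BERT
        self.prompt_encoder = nn.Embedding(num_virtual_tokens, hidden)
        self.backbone.requires_grad_(False)


def get_parameters(model):
    # Only the prompt encoder's tensors travel between client and server
    return [v.detach().cpu().numpy() for v in model.prompt_encoder.state_dict().values()]


def set_parameters(model, parameters):
    # Keys come from the same submodule, so they line up with get_parameters()
    keys = model.prompt_encoder.state_dict().keys()
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
    model.prompt_encoder.load_state_dict(state_dict, strict=True)


model = ToyPromptModel()
params = get_parameters(model)
print([p.shape for p in params])  # [(8, 768)]
set_parameters(model, [np.zeros_like(params[0])])
```

Because both functions read keys from the same submodule, `strict=True` can stay on and will still catch any shape or key mismatch.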

