Freeze certain modules of the model

Hello, is there any way in the Flower framework to exchange only part of the model's parameters, i.e. to freeze certain modules of the model so they are not communicated, in order to improve communication efficiency?

Hello @3249403788, this can be done by overriding the get_parameters() and set_parameters() methods of your Client class, assuming the frozen module is named 'CNN' (so its entries in the state_dict all carry the 'CNN.' prefix):

from collections import OrderedDict

import torch


def get_parameters(self):
    # Exclude parameters belonging to the frozen "CNN" module.
    # Note: state_dict keys look like "CNN.conv1.weight", so compare
    # with startswith() rather than equality against "CNN".
    return [
        val.cpu().numpy()
        for key, val in self.model.state_dict().items()
        if not key.startswith("CNN")
    ]


def set_parameters(self, parameters):
    # Pair the received arrays with the non-frozen keys, in order.
    valid_keys = [k for k in self.model.state_dict().keys() if not k.startswith("CNN")]
    params_dict = zip(valid_keys, parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    # strict=False allows the frozen "CNN" keys to be absent from state_dict.
    self.model.load_state_dict(state_dict, strict=False)
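To see why the prefix filtering works, here is a minimal, self-contained sketch with a hypothetical model whose frozen feature extractor is registered under the attribute name CNN (the model, layer shapes, and the trainable head fc are assumptions for illustration, not part of the original answer):

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical model: a frozen feature extractor named "CNN"
    # plus a trainable classifier head "fc".
    def __init__(self):
        super().__init__()
        self.CNN = nn.Conv2d(1, 4, 3)
        self.fc = nn.Linear(4, 2)


model = Net()

# Freeze the CNN module locally so it is not updated during training.
for p in model.CNN.parameters():
    p.requires_grad = False

# get_parameters logic: send only the non-frozen weights.
params = [
    val.cpu().numpy()
    for key, val in model.state_dict().items()
    if not key.startswith("CNN")
]

# set_parameters logic: load the received weights back under the same keys.
valid_keys = [k for k in model.state_dict().keys() if not k.startswith("CNN")]
state_dict = OrderedDict(
    {k: torch.tensor(v) for k, v in zip(valid_keys, params)}
)
model.load_state_dict(state_dict, strict=False)

print(valid_keys)  # only the classifier keys remain, e.g. ['fc.weight', 'fc.bias']
```

Because the two CNN tensors never appear in the exchanged list, only the head's parameters travel between client and server, which is where the communication savings come from.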
