Hello, is there any way in the Flower framework to specify that only part of the model's parameters are exchanged, i.e., to freeze certain modules of the model in order to reduce the communication cost?
Hello @3249403788, this can be done by overriding the `get_parameters()` and `set_parameters()` methods of your client class. Assuming the frozen submodule is named `CNN`:
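```python
from collections import OrderedDict
import torch

def get_parameters(self, config):  # Flower 1.x passes a config dict here
    # state_dict keys are qualified names like "CNN.0.weight", so filter by prefix
    return [val.cpu().numpy() for key, val in self.model.state_dict().items()
            if not key.startswith("CNN.")]

def set_parameters(self, parameters):
    valid_keys = [k for k in self.model.state_dict().keys() if not k.startswith("CNN.")]
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in zip(valid_keys, parameters)})
    # strict=False leaves the frozen "CNN" weights as they are
    self.model.load_state_dict(state_dict, strict=False)
```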
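Excluding a module from `get_parameters()` only saves bandwidth; to actually freeze it, its parameters should also have `requires_grad = False` so local training leaves them alone. Below is a minimal, self-contained sketch assuming a hypothetical `Net` whose feature extractor is an `nn.Sequential` named `CNN`. Because `state_dict()` keys are qualified names such as `CNN.0.weight`, it filters on the prefix `"CNN."`. The hypothetical `PartialClient` only mirrors the two methods and does not import `flwr`, so it runs with PyTorch alone; in a real app it would subclass `flwr.client.NumPyClient`.

```python
from collections import OrderedDict
from typing import List

import numpy as np
import torch
import torch.nn as nn

FROZEN_PREFIX = "CNN."  # state_dict keys of the frozen submodule start with this


class Net(nn.Module):
    """Hypothetical model: a feature extractor named `CNN` plus a small trainable head."""

    def __init__(self) -> None:
        super().__init__()
        self.CNN = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.head = nn.Linear(32, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.CNN(x))


class PartialClient:
    """Sketch of the two overridden methods (would subclass flwr.client.NumPyClient)."""

    def __init__(self, model: nn.Module) -> None:
        self.model = model
        # Freeze the submodule so the local optimizer never updates it
        for p in self.model.CNN.parameters():
            p.requires_grad = False

    def _shared_keys(self) -> List[str]:
        # Same order every call, so get/set stay aligned
        return [k for k in self.model.state_dict() if not k.startswith(FROZEN_PREFIX)]

    def get_parameters(self, config=None) -> List[np.ndarray]:
        state = self.model.state_dict()
        return [state[k].cpu().numpy() for k in self._shared_keys()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        state_dict = OrderedDict(
            (k, torch.tensor(v)) for k, v in zip(self._shared_keys(), parameters)
        )
        # Missing "CNN.*" keys are allowed and keep their current values
        self.model.load_state_dict(state_dict, strict=False)


if __name__ == "__main__":
    client = PartialClient(Net())
    sent = client.get_parameters()
    total = sum(v.numel() for v in client.model.state_dict().values())
    print(f"sending {sum(a.size for a in sent)} of {total} values")
    # prints: sending 330 of 5130 values
```

With this split, each round only transfers the head's 330 values instead of all 5130; the server's initial parameters then need to contain just those shared arrays as well.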