I am doing research on model-heterogeneous FL and trying to implement my idea (based on knowledge distillation) using Flower. I wonder how I can return the parameters without aggregation. I have implemented the knowledge distillation part in the aggregate_fit function, and I want to return the updated parameters without aggregating them, i.e. return a list of parameters (with different dimensions) instead of a single global one.
Hi @boyufan24, if you have implemented a custom strategy with your own aggregate_fit(), then you have full control over what happens with the parameters the clients have returned to the server/strategy. You can decide whether to aggregate them or not.
The signature of aggregate_fit() in the current version of Flower (1.7) looks like this (see the full code here for FedAvg). Note that the results contain the model parameters sent by the clients that participated in the fit round:
def aggregate_fit(
    self,
    server_round: int,
    results: List[Tuple[ClientProxy, FitRes]],
    failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
    # your knowledge distillation logic goes here
    # normally you'd do aggregation, but you can skip it
    # let's say you want to return a list of NumPy arrays
    ndarrays = [np.random.randn(1, 2, 3) for _ in range(10)]
    # You'll need to convert them to `common.Parameters`
    parameters = ndarrays_to_parameters(ndarrays)
    return parameters
Note I used ndarrays_to_parameters to serialize those arrays. This is also a stage done in the strategies (they normally work with NumPy arrays to do all the math, but before returning, those objects need to be serialized as common.Parameters).
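For reference, here is a minimal round-trip sketch of those two helpers (just an illustration with dummy arrays; in practice the arrays would be your model weights):

import numpy as np
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays

# a list of NumPy arrays is what the strategies work with internally
ndarrays = [np.zeros((2, 3)), np.ones(5)]

# serialize into `common.Parameters` for transport
parameters = ndarrays_to_parameters(ndarrays)

# ...and deserialize back into a list of NumPy arrays on the other side
recovered = parameters_to_ndarrays(parameters)
assert all(np.array_equal(a, b) for a, b in zip(ndarrays, recovered))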
But then take into account that your client will need to understand how to process what arrives in the fit() method's parameters argument. It is no longer the global model, so the usual logic in set_parameters() will need to be adjusted.
Hope this helps!!
Thank you @javier! It really helps. So the return value of aggregate_fit() is the same object that arrives as the parameters argument of fit(), right? And it can be either a list or a single value?
Besides, if the fit() method's parameters argument always comes from aggregate_fit(), how do I conduct the first round of local training with heterogeneous models? In my setting there is no global model, but if we do not initialize the global model, Flower will randomly select a client model as the global one. The fit() will then use these parameters to conduct the first local update. Therefore, there is a conflict between the initialization (first round) and the following rounds.
What I want to achieve is that the clients use their own models (passed in via client_fn()) for the local training in the first round; then, after aggregate_fit() returns the model parameters, fit() can catch them and choose the right one to update in the following rounds.
Here is the client class I define:
class FlowerClientKD(fl.client.NumPyClient):
    def __init__(self, model, cid, trainloader, testloader) -> None:
        super().__init__()
        self.model = model
        self.cid = cid
        self.trainloader = trainloader
        self.testloader = testloader
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"the current client is {self.cid}")

    def set_parameters(self, parameters):
        # assume we got the parameters list from aggregate_fit(), which consists
        # of model parameters for heterogeneous models;
        # then we need to use cid to choose the corresponding one
        index = int(self.cid) - 1
        parameter = parameters[index]
        params_dict = zip(self.model.state_dict().keys(), parameter)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(self.model, self.trainloader, epochs=1)
        print("local train finished!")
        return self.get_parameters({}), len(self.trainloader), {}
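For completeness, the client_fn I mentioned passes a different model to each client, roughly like this (a sketch; build_model_for_cid and load_data are placeholders for however the heterogeneous models and data loaders are constructed per client):

def client_fn(cid: str):
    # build a (potentially different) model architecture for this client
    model = build_model_for_cid(cid)          # hypothetical helper
    trainloader, testloader = load_data(cid)  # hypothetical helper
    return FlowerClientKD(model, cid, trainloader, testloader)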
and this is part of the aggregate_fit(); I didn't show the definition of _kd_aggregate():
def aggregate_fit(self, server_round, results, failures):
    # sort results by client id so each client can later pick its own
    # parameters by index
    results = sorted(results, key=lambda x: int(x[0].cid))
    # deserialize each client's update into a list of NumPy arrays
    parameters_in_ndarrays = [
        parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results
    ]
    kd_parameters = self._kd_aggregate(parameters_in_ndarrays, self.hetero_net)
    kd_parameters = ndarrays_to_parameters(kd_parameters)
    return kd_parameters
In the aggregate_fit() code earlier I showed it returning just the aggregated parameters. But strategies return another argument: a dictionary of aggregated metrics (which will remain empty {} if the clients don't return any, or if no function to aggregate them is provided when instantiating the strategy). You can see this in the FedAvg strategy.
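Applied to your code, the last line would become something like this (an empty dict is fine if you don't aggregate any metrics):

# return the distilled parameters together with an (empty) metrics dict
return kd_parameters, {}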
The aggregated parameters that aggregate_fit returns are of type flwr.common.Parameters, which is a serialized version of List[NDArray] (a list of standard NumPy arrays). Even if you have a single NumPy array, it still has to be inside a list. Serialized objects are easier to integrate in communication stacks (like gRPC).
What a flwr.client.NumPyClient receives in its fit() method is the deserialized parameters object, i.e. exactly that List[NDArray] that got serialized before returning from aggregate_fit(). Sounds a bit complex when writing it…
To do this, I'd suggest:
- Initialize the global model with some dummy structure and ensure your clients are aware of which round they are participating in. You can do this by passing a function to the on_fit_config_fn arg in the strategy. For example:
def fit_config(server_round: int):  # signature must be like this
    config = {
        "round": server_round,
        "lr": 0.1,  # you can pass anything else you want
    }
    return config

# something trivial as the initial (dummy) global model
initial_parameters = ndarrays_to_parameters([np.array([0])])

# Now you can pass them to your strategy
strategy = FedAvg(..., initial_parameters=initial_parameters,
                  on_fit_config_fn=fit_config)
- Make your clients aware of which round they are in:

def fit(self, parameters, config):
    if config["round"] == 1:  # server rounds start at 1
        # it's the first round: instantiate the model locally,
        # don't expect the server to send anything useful
        model = client_fn()
    else:
        model = ...  # use parameters sent by server
    # local training
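Plugged into your FlowerClientKD, that could look roughly like this (just a sketch; it assumes the fit_config above is used and that aggregate_fit() returns the per-client parameter lists in cid order, as your set_parameters() expects):

def fit(self, parameters, config):
    if config["round"] == 1:
        # first round: keep the model passed in via client_fn(),
        # ignore the dummy parameters coming from the server
        pass
    else:
        # following rounds: pick this client's parameters from the list
        # returned by aggregate_fit() and load them into the model
        self.set_parameters(parameters)
    train(self.model, self.trainloader, epochs=1)
    return self.get_parameters({}), len(self.trainloader), {}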
Does something like this look useful?
Thank you so much for your kind explanation, this is exactly what I want! I appreciate your time and hopefully I can also contribute to this lovely Flower community in the future.