Hi @mistersunshine, excellent question!
It’s not a bug! The type returned as the last element from a client’s `fit()` or `evaluate()` (such as the one in `examples/quickstart-pytorch`) has to be one of the types in `flwr.common.Scalar`. You’ll notice that lists are not supported there; however, `bytes` are. This means you can send back to your `ServerApp` (where the strategy lives) pretty much any data structure carrying metrics from training or evaluation.
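If I remember correctly, `Scalar` boils down to the following union (a paraphrase of `flwr.common.typing`, worth double-checking against the version you have installed):

```python
from typing import Union

# The value types a client may return in its metrics dict
Scalar = Union[bool, bytes, float, int, str]
```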
Here’s how you can test this quickly by making a few modifications to the aforementioned example:
- Update the `evaluate()` method so that it also returns a NumPy array serialized into bytes. For example, replace the entire `evaluate` method with this:
```python
def evaluate(self, parameters, config):
    """Evaluate the model on the data this client has."""
    set_weights(self.net, parameters)
    loss, accuracy = test(self.net, self.valloader, self.device)
    delta_c = np.array([1.1, 2.2, 3.3], dtype=np.float32)  # dummy delta_c
    delta_c_bytes = delta_c.tobytes()  # serialize into bytes (a valid Scalar)
    metrics = {"accuracy": accuracy, "delta_c_bytes": delta_c_bytes}
    return loss, len(self.valloader.dataset), metrics
```

(this requires `import numpy as np` at the top of the file)
- Now, let’s say you want to use the `delta_c` arrays sent by each `ClientApp` in the strategy. There (following the example) you can deserialize them to recover their standard NumPy representation. I’ve added a few extra lines to the `weighted_average` callback in the example to illustrate this. It should now look like this:
```python
from typing import List, Tuple

import numpy as np
from flwr.common import Metrics


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    for _, m in metrics:
        # Get metric of interest
        delta_c_bytes = m["delta_c_bytes"]
        # Deserialize back into a NumPy array
        delta_c = np.frombuffer(delta_c_bytes, dtype=np.float32)
        # Print the recovered array
        print(delta_c)
    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}
```
If you run the example, you’ll see that after each round of federated evaluation the deserialized `delta_c` sent back by each client gets printed:
```
INFO : [ROUND 3]
INFO : configure_fit: strategy sampled 10 clients (out of 10)
INFO : aggregate_fit: received 10 results and 0 failures
INFO : configure_evaluate: strategy sampled 5 clients (out of 10)
INFO : aggregate_evaluate: received 5 results and 0 failures
[1.1 2.2 3.3]
[1.1 2.2 3.3]
[1.1 2.2 3.3]
[1.1 2.2 3.3]
[1.1 2.2 3.3]
```
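Two closing notes. First, if you want to do more than print the deltas, e.g. aggregate them the same way the accuracies are aggregated, you could replace the print loop in `weighted_average` with something along these lines (my sketch, not part of the example):

```python
# Example-weighted average of the client deltas
deltas = [np.frombuffer(m["delta_c_bytes"], dtype=np.float32) for _, m in metrics]
weights = [num_examples for num_examples, _ in metrics]
delta_c_avg = np.average(np.stack(deltas), axis=0, weights=weights)
```

Just remember that the dict returned by the callback is still bound by the `Scalar` constraint, so if `delta_c_avg` needs to travel anywhere else, serialize it to bytes again.

Second, `tobytes()` drops the array’s shape and dtype, which is why the receiver hard-codes `np.float32` above. If your real `delta_c` is multi-dimensional or its dtype may vary, a pair of helpers like these (plain NumPy, nothing Flower-specific) keeps that metadata inside the payload:

```python
import io

import numpy as np


def ndarray_to_bytes(arr: np.ndarray) -> bytes:
    """Serialize an array in .npy format, preserving shape and dtype."""
    buf = io.BytesIO()
    np.save(buf, arr, allow_pickle=False)
    return buf.getvalue()


def bytes_to_ndarray(data: bytes) -> np.ndarray:
    """Recover the original array from .npy bytes."""
    return np.load(io.BytesIO(data), allow_pickle=False)
```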
I hope this helps!!!