Custom client selection strategy

This question was migrated from GitHub Discussions.

Original question:
“Hi! I would like to create a custom client selection strategy that, at each round, randomly selects a number “m” of clients, each selected with probability proportional to the fraction of data it owns (in other words, clients with more data are more likely to be selected). For this reason, I am writing my own custom strategy and overriding the configure_fit method.
However, I cannot find any way to get the data size of each client via ClientManager; on the clients I can only call “fit” and “evaluate”. Perhaps my goal is achievable by using the “sample” method of ClientManager with a custom Criterion? Is there any way I can access the data size of each client, which I assume could differ from round to round?”
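The selection rule in the question gives client k the probability p_k = n_k / (n_1 + ... + n_K), where n_k is the number of samples that client holds. Below is a minimal NumPy sketch of drawing m distinct clients this way. It assumes the server already knows every client's data size, which is exactly the missing piece the question asks about. The function name `sample_by_data_size` and the example client ids are hypothetical. Drawing without replacement renormalizes after each pick, so for m > 1 the inclusion probabilities are only approximately proportional to n_k.

```python
import numpy as np


def sample_by_data_size(data_sizes, m, seed=None):
    """Pick up to m distinct client ids, each draw weighted by the client's data size.

    data_sizes: dict mapping client id -> number of local samples (hypothetical input);
    at least one client is assumed to have a non-zero size.
    """
    cids = list(data_sizes)
    sizes = np.array([data_sizes[cid] for cid in cids], dtype=float)
    probs = sizes / sizes.sum()  # p_k = n_k / sum_j n_j
    rng = np.random.default_rng(seed)
    # replace=False: no client is picked twice in the same round
    chosen = rng.choice(len(cids), size=min(m, len(cids)), replace=False, p=probs)
    return [cids[i] for i in chosen]


if __name__ == "__main__":
    sizes = {"client_0": 100, "client_1": 300, "client_2": 600}
    print(sample_by_data_size(sizes, m=2, seed=0))
```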

Answer 1:
"I’ve never done that, so I’ll give a high-level answer. You have to write your own ClientManager class, your own Criterion, and your own ClientProxy.

In the sample() method of the ClientManager you can communicate with the available clients (use SimpleClientManager as a starting point), which are instances of ClientProxy. The problem is that ClientProxy has no method to ask for the data sample size, which is why you have to write your own implementation. Once that is done, you pass the sample sizes and the clients’ cids to your own Criterion, and that’s it.

I know it’s not a high-quality answer, but I hope it helps somehow."
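The approach in Answer 1 can be sketched without a Flower server. The classes below only mimic the shape of Flower's `SimpleClientManager.sample()` and `Criterion`. `SizedProxy`, `get_data_size()`, `MinSizeCriterion` and `SizeAwareClientManager` are all hypothetical names, not Flower API. As far as I know, Flower's `Criterion` only exposes a boolean `select(client)` filter. For that reason, in this sketch the criterion only filters clients, and the size-based weighting happens inside `sample()`.

```python
import random
from typing import Dict, List, Optional


class SizedProxy:
    """Hypothetical stand-in for a custom ClientProxy that can report its data size."""

    def __init__(self, cid: str, data_size: int):
        self.cid = cid
        self._data_size = data_size

    def get_data_size(self) -> int:
        # In a real deployment this would be a round trip to the client.
        return self._data_size


class MinSizeCriterion:
    """Hypothetical criterion: a boolean filter, like Criterion.select() in Flower."""

    def __init__(self, min_size: int):
        self.min_size = min_size

    def select(self, client: SizedProxy) -> bool:
        return client.get_data_size() >= self.min_size


class SizeAwareClientManager:
    """Sketch of a client manager whose sample() weights clients by data size."""

    def __init__(self, seed: Optional[int] = None):
        self.clients: Dict[str, SizedProxy] = {}
        self.rng = random.Random(seed)

    def register(self, client: SizedProxy) -> None:
        self.clients[client.cid] = client

    def sample(
        self, num_clients: int, criterion: Optional[MinSizeCriterion] = None
    ) -> List[SizedProxy]:
        # Ask every registered client for its current data size (it may change per round)
        sizes = {cid: c.get_data_size() for cid, c in self.clients.items()}
        # Drop empty clients and anything the criterion rejects
        pool = [
            c for c in self.clients.values()
            if sizes[c.cid] > 0 and (criterion is None or criterion.select(c))
        ]
        chosen: List[SizedProxy] = []
        # Weighted draws without replacement: each winner is removed from the pool
        while pool and len(chosen) < num_clients:
            weights = [sizes[c.cid] for c in pool]
            pick = self.rng.choices(pool, weights=weights, k=1)[0]
            chosen.append(pick)
            pool.remove(pick)
        return chosen
```

In real Flower code the same idea would live in a subclass of `SimpleClientManager`. The `get_data_size()` round trip would go through the custom ClientProxy that Answer 1 describes.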

Answer 2:
"Thanks for the info! I actually found a “workaround” to get this done, without having to make my own ClientManager etc. It’s a little “dirty”, but it seems to be working.
Basically, in my custom strategy I implemented the following custom configure_fit:

```python
from typing import List, Tuple

import numpy as np
from flwr.common import FitIns, Parameters
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy


def configure_fit(
    self, server_round: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
    """Configure the next round of training."""
    # self.p (initially []), self.d, self.m and self.selected_clients are assumed to be set in __init__
    # Get list with all the available clients (K clients)
    available_clients = list(client_manager.clients.values())

    if self.p == []:
        # epochs = -1 tells the client to only report its data size
        data_sizes = {
            client.cid: client.fit(
                FitIns(parameters, {"epochs": -1}), timeout=4
            ).metrics.get("data_size", 0)
            for client in available_clients
        }

        # Compute selection probabilities based on data size
        total_size = sum(data_sizes.values())
        self.p = [size / total_size for size in data_sizes.values()]

    # Sample d clients with a probability proportional to their data size
    candidate_clients = np.random.choice(
        available_clients,
        size=min(self.d, len(available_clients)),
        replace=False,
        p=self.p,
    )

    # Request the candidate clients to compute their local losses and return them to the server
    # (epochs = 0 tells the client to evaluate without training)
    local_losses = {
        client.cid: client.fit(
            FitIns(parameters, {"epochs": 0}), timeout=4
        ).metrics.get("local_loss", float("inf"))
        for client in candidate_clients
    }

    # Select the top m clients with the highest local losses
    selected_clients_cids = sorted(local_losses, key=local_losses.get, reverse=True)[: self.m]
    self.selected_clients.append(list(selected_clients_cids))  # Keep a history of selected cids

    # Return the selected clients with the FitIns objects
    return [(client_manager.clients.get(cid), FitIns(parameters, {})) for cid in selected_clients_cids]
```
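The selection in this configure_fit works in two stages. It first draws d candidates with probability proportional to data size, then keeps the m candidates with the highest local loss. This reads like the "power-of-choice" selection idea, though that label is my interpretation rather than the poster's. The sketch below isolates the two stages in plain NumPy so they can be checked without a Flower server. `select_clients`, `data_sizes` and `local_loss_fn` are hypothetical names used only for illustration.

```python
import numpy as np


def select_clients(data_sizes, local_loss_fn, d, m, seed=None):
    """Two-stage selection: size-weighted candidate draw, then top-m by local loss.

    data_sizes: dict cid -> number of local samples (hypothetical input)
    local_loss_fn: callable cid -> loss of the current global model on that client
    """
    rng = np.random.default_rng(seed)
    cids = list(data_sizes)
    sizes = np.array([data_sizes[c] for c in cids], dtype=float)
    p = sizes / sizes.sum()

    # Stage 1: d distinct candidates, weighted by data size
    idx = rng.choice(len(cids), size=min(d, len(cids)), replace=False, p=p)
    candidates = [cids[i] for i in idx]

    # Stage 2: query only the candidates; tiny noise breaks exact ties at random
    losses = {c: local_loss_fn(c) + rng.uniform(1e-10, 1e-9) for c in candidates}
    return sorted(losses, key=losses.get, reverse=True)[:m]
```

Only the d candidates are asked for their loss, so each round costs d extra round trips rather than K.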

In the client, I have the following custom fit:

```python
import numpy as np


def fit(self, parameters, config):
    epochs = config.get("epochs", self.epochs)  # Get the number of epochs from config
    if epochs == -1:
        return parameters, 0, {"data_size": self.data_size}  # Just return client's data size

    # Load the global parameters received from the server before evaluating or training
    self.model.set_weights(parameters)

    if epochs == 0:
        # Estimate local loss without training the model
        local_loss, _ = self.model.evaluate(self.x_test, self.y_test, verbose=0)
        local_loss += np.random.uniform(low=1e-10, high=1e-9)  # Make sure that potential ties are broken at random
        return parameters, 0, {"local_loss": local_loss}  # Return the estimated local loss without updating the parameters

    # Train the model and return the updated parameters
    self.model.fit(self.x_train, self.y_train, validation_data=(self.x_test, self.y_test), epochs=epochs, verbose=0)
    return self.model.get_weights(), self.data_size, {}
```

As you can see, from the server I call the only client method I know will always answer, namely fit. When I call fit only to get the data size or the local loss, I pass an epochs value lower than 1: -1 for the data size and 0 for the local loss. The client handles these cases with a simple if statement."
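To see this sentinel protocol end to end without a Flower server, here is a small mock. `FakeClient` stands in for the client above, with the Keras model replaced by a fixed loss. `query` plays the role of `client.fit(FitIns(...)).metrics.get(...)` on the server. All names here are hypothetical.

```python
from typing import Any, Dict, List, Tuple

# Sentinel values carried in the "epochs" config key, as in the post
DATA_SIZE_REQUEST = -1
LOSS_REQUEST = 0


class FakeClient:
    """Hypothetical stand-in for the client above; no real model is trained."""

    def __init__(self, data_size: int, loss: float, epochs: int = 5):
        self.data_size = data_size
        self.loss = loss
        self.epochs = epochs
        self.trained_epochs = 0

    def fit(
        self, parameters: List[float], config: Dict[str, Any]
    ) -> Tuple[List[float], int, Dict[str, Any]]:
        epochs = config.get("epochs", self.epochs)
        if epochs == DATA_SIZE_REQUEST:
            return parameters, 0, {"data_size": self.data_size}
        if epochs == LOSS_REQUEST:
            return parameters, 0, {"local_loss": self.loss}
        self.trained_epochs += epochs  # Placeholder for real training
        return [w + 1.0 for w in parameters], self.data_size, {}


def query(client: FakeClient, parameters: List[float], epochs: int, key: str, default: float) -> float:
    """Server-side helper: send a sentinel and read one metric from the reply."""
    _, _, metrics = client.fit(parameters, {"epochs": epochs})
    return metrics.get(key, default)
```

Naming the sentinels as shared constants keeps the server and the client in agreement about what -1 and 0 mean.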