How do I write a custom client selection protocol?

A frequently asked question we get is:

How do I write a custom client selection protocol?

You can write a custom client selection mechanism by extending an existing strategy (e.g. `FedAvg`) and overriding only its `configure_fit()` method, which decides which clients take part in each training round and what instructions they receive. Let’s see an example.

Let’s write a strategy that samples N clients uniformly at random in the first round. In every subsequent round, it keeps each client from the previous round with probability 0.5 (so about N/2 of them on average) and fills the remaining slots with clients sampled from the pool of available clients. This means that, on average, about half of the clients that participate in round T will also participate in round T+1.
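To get a feel for the overlap this rule produces, here is a small framework-free simulation (plain Python, no Flower). It is a sketch under stated assumptions: clients are plain integer ids, the pool never changes, and the extra clients are drawn only from clients that were not kept, so nobody is picked twice in one round. The helpers `sample_round` and `average_overlap` are hypothetical names used only for this illustration.

```python
import random
from typing import List, Sequence


def sample_round(pool: Sequence[int], prev: List[int], n: int,
                 keep_ratio: float, rng: random.Random) -> List[int]:
    """One round of the rule: keep each previous client with probability
    `keep_ratio`, then fill the free slots uniformly from clients not kept."""
    kept = [c for c in prev if rng.random() < keep_ratio]
    candidates = [c for c in pool if c not in kept]
    return kept + rng.sample(candidates, n - len(kept))


def average_overlap(num_clients: int = 100, n: int = 10, rounds: int = 2000,
                    keep_ratio: float = 0.5, seed: int = 0) -> float:
    """Average fraction of round-T clients that also appear in round T+1."""
    rng = random.Random(seed)
    pool = list(range(num_clients))
    prev = rng.sample(pool, n)  # round 1: plain uniform sampling
    total = 0.0
    for _ in range(rounds):
        current = sample_round(pool, prev, n, keep_ratio, rng)
        total += len(set(prev) & set(current)) / n
        prev = current
    return total / rounds


if __name__ == "__main__":
    print(f"average overlap: {average_overlap():.3f}")
```

Because clients dropped after round T can be drawn again when the free slots are filled, the measured overlap tends to sit slightly above `keep_ratio`. Setting `keep_ratio=1.0` reproduces the same client set in every round.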


```python
from random import random
from typing import Dict, Iterable, List, Tuple

from flwr.common import FitIns, Parameters, Scalar
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
from flwr.server.strategy import FedAvg


class ExcludeCids(Criterion):
    """Only lets through clients whose cid is not in `cids`."""

    def __init__(self, cids: Iterable[str]):
        self.cids = set(cids)

    def select(self, client: ClientProxy) -> bool:
        return client.cid not in self.cids


class RepeatHalfSamplingFedAvg(FedAvg):
    """Behaves just like FedAvg but with a modified sampling.
    """
    def __init__(self, *args, keep_ratio: float = 0.5, **kwargs):
        # probability that each client sampled in round T
        # is sampled again in round T+1
        self.keep_ratio = keep_ratio
        self.prev_clients: List[ClientProxy] = []
        super().__init__(*args, **kwargs)

    def configure_fit(self, server_round: int, parameters: Parameters,
                      client_manager: ClientManager
                      ) -> List[Tuple[ClientProxy, FitIns]]:
        """Resamples some clients from the previous round except
        in the first round of the FL experiment.
        """
        config: Dict[str, Scalar] = {}
        if self.on_fit_config_fn is not None:
            # Custom fit config function provided
            config = self.on_fit_config_fn(server_round)

        # construct instructions to be sent to each client.
        fit_ins = FitIns(parameters, config)

        # interface with the client manager to get statistics about
        # clients that can be sampled in this given round.
        av_clients = client_manager.num_available()
        sample_size, min_num_clients = self.num_fit_clients(av_clients)

        # So far the code above is very similar to the one in the
        # base `configure_fit()` method in FedAvg

        if server_round == 1 or not self.prev_clients:
            # first round, random uniform sampling (standard)
            clients = client_manager.sample(
                num_clients=sample_size, min_num_clients=min_num_clients)

        else:
            # stochastically keep clients used in the previous round,
            # skipping any that have disconnected since then
            connected = client_manager.all()
            clients = [cli for cli in self.prev_clients
                       if cli.cid in connected and random() < self.keep_ratio]
            print(f"Will resample clients: {[client.cid for client in clients]}")

            # sample the remaining clients, excluding those already kept
            # so that no client receives two instructions in one round
            extra_clients = client_manager.sample(
                num_clients=sample_size - len(clients),
                min_num_clients=min_num_clients,
                criterion=ExcludeCids(cli.cid for cli in clients),
            )
            # append to list of kept clients
            clients.extend(extra_clients)

        # record client proxies (so we can resample in next round)
        self.prev_clients = clients

        print(f"Round {server_round} sampled clients "
              f"with cid: {[client.cid for client in self.prev_clients]}"
        )

        # Return client/config tuples
        return [(client, fit_ins) for client in clients]
```

Another idea would be to change the sampling of clients so that different clients receive different `FitIns` instructions. You might want to do this to, for example, tell some clients to use a different set of hyperparameters, or even send them different models. As in the example above, the easiest way to implement this is to write a custom strategy that modifies the default behaviour of an existing one. The same ideas apply if you want to change how clients are sampled for an evaluation round; for that, customize the `configure_evaluate()` method instead.
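The core of this idea is to build one `FitIns` per client instead of reusing a single `fit_ins` for everyone. Below is a minimal, framework-free sketch under stated assumptions: `FitIns` here is a hypothetical stand-in dataclass that only mirrors Flower's `(parameters, config)` pair, clients are plain string ids rather than `ClientProxy` objects, and the `"lr"` key with its alternating values is purely illustrative, not a Flower convention.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union

# Hypothetical stand-ins so the sketch runs without Flower installed
Scalar = Union[bool, int, float, str]


@dataclass
class FitIns:
    """Stand-in mirroring the (parameters, config) pair of flwr.common.FitIns."""
    parameters: object
    config: Dict[str, Scalar] = field(default_factory=dict)


def assign_instructions(cids: List[str], parameters: object,
                        base_config: Dict[str, Scalar]
                        ) -> List[Tuple[str, FitIns]]:
    """Build one FitIns per client instead of sharing a single object.
    Even-indexed clients get a larger (hypothetical) learning rate."""
    pairs = []
    for i, cid in enumerate(cids):
        config = dict(base_config)  # copy so clients never share one dict
        config["lr"] = 0.1 if i % 2 == 0 else 0.01
        pairs.append((cid, FitIns(parameters, config)))
    return pairs


if __name__ == "__main__":
    for cid, ins in assign_instructions(["a", "b", "c"], "global-weights", {"epochs": 1}):
        print(cid, ins.config)
```

Inside a real strategy, the same loop would run in `configure_fit()` over the `ClientProxy` objects returned by `client_manager.sample(...)` and return `(client, FitIns(parameters, config))` pairs; to send different models, pass different `parameters` per client. The evaluation counterpart builds `EvaluateIns` objects in `configure_evaluate()`.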
