How to replace Parameter with Gradient? Instead of weights, use aggregated gradients.

My understanding is that the communication process between the Flower framework client (C) and server (S) currently is as follows:

  • In the first round, S sends the initial model parameters to C, C receives the parameters and sets them in the model, then after training, returns the new model parameters obtained from training to S.

  • S aggregates all the model parameters sent back by C through some strategy (such as FedAVG) and then sends the new model parameters to each C.

My current questions are:

  1. Gradient Replacement

I don’t want to use model parameters in the communication and aggregation process. Instead, I want to replace them with “model gradients”. I want the local Client to not update parameters and clear gradients during the training process (even if it’s multiple epochs), extract the accumulated gradients after training to update the model, and send the gradients back to the Server.
The Server will aggregate the gradients received from each Client and send them back. In the new round, the Client will receive the global gradients and update the local model accordingly. How can I achieve this?

  2. Code Issue

I have attempted to modify some of the code (as shown below), and after debugging, it can run, but the Accuracy performance is extremely poor. Why is this?

class CustomClient(NumPyClient):
    """Flower client that exchanges *gradients* (not weights) with the server.

    Protocol:
      * Round 0: the server sends an empty parameter list; the client
        initialises a fresh model locally.
      * Every round: the client trains WITHOUT stepping the optimizer or
        clearing gradients between epochs, extracts the accumulated
        gradients, delta-encodes them against the global gradient received
        from the server, applies the (optionally LDP-perturbed) gradient to
        its local model, persists the model weights in the client Context
        (since weights never travel over the wire), and returns the
        gradient delta to the server for aggregation.
    """

    def __init__(self, model, trainloader, valloader, local_epochs, lr, beta, context: Context):
        self.model = model
        self.trainloader = trainloader
        self.valloader = valloader
        self.local_epochs = local_epochs
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.lr = lr
        self.beta = beta
        self.round = 0

        # Persist model weights across rounds in the client's Context state,
        # because only gradients are communicated with the server.
        self.client_state = context.state
        if "model_state" not in self.client_state.parameters_records:
            self.client_state.parameters_records["model_state"] = ParametersRecord()

    def fit(self, g_gradient, config):
        """Train locally and return the gradient delta w.r.t. ``g_gradient``.

        Args:
            g_gradient: aggregated global gradient from the server as a list
                of ndarrays (empty list in round 0).
            config: server-side config dict (unused here).

        Returns:
            Tuple of (gradient_delta, num_training_examples, metrics).
        """
        model_state = self.client_state.parameters_records["model_state"]

        # When num_round == 0, len(g_gradient) is 0: initialise the model.
        if len(g_gradient) == 0:
            ndarrays = get_weights(Net())
            set_weights(self.model, ndarrays)
        else:
            # Restore the locally persisted model weights from Context state.
            state_dict = {}
            for k, v in model_state.items():
                state_dict[k] = torch.from_numpy(v.numpy())
            self.model.load_state_dict(state_dict, strict=True)

        # Accumulate gradients over all local epochs; `train` is assumed not
        # to call optimizer.step()/zero_grad() — TODO confirm against helper.
        zero_grad(self.model)
        train_loss = train(
            self.model,
            self.trainloader,
            self.local_epochs,
            self.device,
        )
        grad = get_gradients(self.model)  # accumulated local gradient

        # Round 0: no global gradient yet — treat it as all zeros so the
        # delta below equals the raw local gradient.
        if len(g_gradient) == 0:
            g_gradient = [np.zeros_like(g) for g in grad]

        # Delta-encode: send only the difference "grad - g_gradient".
        grad_update = [x - y for x, y in zip(grad, g_gradient)]

        ########################################
        # Some operations for grad_update such as LDP

        ########################################

        # Reconstruct the (possibly LDP-perturbed) local gradient as
        # g_gradient + grad_update and take one SGD step with it.
        # BUG FIX: the original zipped grad_update with `grad`, producing
        # -lr * (2*grad - g_gradient) — the fresh gradient was applied
        # twice, which is why accuracy collapsed.
        _update = [
            -self.lr * (delta + base)
            for delta, base in zip(grad_update, g_gradient)
        ]

        def update_model(model, update_list):
            # Apply the precomputed update in place, outside autograd.
            # (Re-indented with spaces: the original used tabs here, which
            # raises TabError in a space-indented file.)
            with torch.no_grad():
                for param, update in zip(model.parameters(), update_list):
                    update_tensor = torch.from_numpy(update).to(param.device)
                    param.data.add_(update_tensor)

        update_model(self.model, _update)

        # Persist the updated weights for the next round and for evaluate().
        for k, v in self.model.state_dict().items():
            model_state[k] = Array(v.detach().cpu().numpy())
        self.client_state.parameters_records["model_state"] = model_state

        return (
            grad_update,
            len(self.trainloader.dataset),
            {"train_loss": train_loss},
        )

    def evaluate(self, parameters, config):
        """Evaluate the locally persisted model on the validation set.

        Ignores ``parameters`` (which carry gradients, not weights, in this
        scheme) and reloads the weights saved in Context state instead.
        """
        model_state = self.client_state.parameters_records["model_state"]
        state_dict = {}
        for k, v in model_state.items():
            state_dict[k] = torch.from_numpy(v.numpy())
        self.model.load_state_dict(state_dict, strict=True)

        loss, accuracy = test(self.model, self.valloader, self.device)
        return loss, len(self.valloader.dataset), {"accuracy": accuracy}