Not sure how to implement SecAgg(+) in this FL example

Hi there,
I am a student doing research on different privacy-preserving methods in FL. I am trying to implement SecAgg(+) in this simple example in order to do some testing/simulations. I found adding the client mod pretty straightforward, but I cannot figure out how/where to implement the SecAggPlusWorkflow on the ServerApp in this example. I could not figure it out using the examples and documentation. Does anyone have a suggestion? You can find my code below:

import logging
import warnings

from typing import Dict

import flwr as fl
import matplotlib.pyplot as plt
import numpy as np

from flwr.client import ClientApp
from flwr.client.mod import secaggplus_mod
from flwr.common import Context
from flwr.common import NDArrays
from flwr.server import LegacyContext
from flwr.server import ServerApp
from flwr.server import ServerAppComponents
from flwr.server import ServerConfig
from flwr.server.workflow import DefaultWorkflow
from flwr.server.workflow import SecAggPlusWorkflow
from flwr.simulation.run_simulation import run_simulation
from flwr_datasets import FederatedDataset
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split


N_CLIENTS = 6
N_GPUS = 0


class MnistClient(fl.client.NumPyClient):
    """Flower NumPy client that fits a scikit-learn LogisticRegression on MNIST.

    Each client loads one partition of the MNIST "train" split, keeps 80% of
    it for local training and 20% for local evaluation, and exchanges the
    model's coefficients (and intercept, when fitted) with the server.
    """

    def __init__(self, cid: str):
        """Load this client's data partition and create its local model.

        Args:
            cid: Client identifier. It is converted with ``int`` and used as
                the index of the partition to load.
        """
        partition_id = int(cid)

        # Load the MNIST dataset, partitioned by the number of clients
        fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS})

        # Get the partition for this client and flatten each image into one row
        dataset = fds.load_partition(partition_id, "train").with_format("numpy")
        X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]

        # Split the partition into 80% train, 20% test
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X,
            y,
            test_size=0.2,
            random_state=None,
        )

        # Create the LogisticRegression model
        self.model = LogisticRegression(
            penalty="l2",
            max_iter=1,  # local epoch
            warm_start=True,  # prevent refreshing weights when fitting
            solver="lbfgs",
            fit_intercept=True,
            class_weight=None,
        )

        # Set initial parameters, akin to model.compile for keras models
        MnistClient.set_initial_params(self.model)

    def get_parameters(self, config):  # type: ignore
        """Return the current local model parameters; ``config`` is not read."""
        return MnistClient.get_model_parameters(self.model)

    def fit(self, parameters, config):  # type: ignore
        """Train the local model starting from the given parameters.

        Args:
            parameters: Parameters to load into the model before fitting.
            config: Fit configuration; ``config["server_round"]`` is read for
                the progress message.

        Returns:
            A tuple of the updated parameters, the number of training
            samples, and an empty metrics dict.
        """
        MnistClient.set_model_params(self.model, parameters)
        # Ignore convergence failure due to low local epochs
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.model.fit(self.X_train, self.y_train)
        print(f"Training finished for round {config['server_round']}")
        return MnistClient.get_model_parameters(self.model), len(self.X_train), {}

    def evaluate(self, parameters, config):  # type: ignore
        """Evaluate the given parameters on the local test split.

        Args:
            parameters: Parameters to load into the model before scoring.
            config: Evaluation configuration; not read here.

        Returns:
            A tuple of the log loss, the number of test samples, and a
            metrics dict with the key ``"accuracy"``.
        """
        MnistClient.set_model_params(self.model, parameters)
        loss = log_loss(self.y_test, self.model.predict_proba(self.X_test))
        accuracy = self.model.score(self.X_test, self.y_test)
        return loss, len(self.X_test), {"accuracy": accuracy}

    @classmethod
    def get_model_parameters(cls, model: LogisticRegression) -> NDArrays:
        """Return the parameters of a sklearn LogisticRegression model.

        The result is ``[coef_, intercept_]`` when the model fits an
        intercept, otherwise ``[coef_]``.
        """
        if model.fit_intercept:
            params = [model.coef_, model.intercept_]
        else:
            params = [model.coef_]
        return params

    @classmethod
    def set_model_params(
        cls, model: LogisticRegression, params: NDArrays
    ) -> LogisticRegression:
        """Set the parameters of a sklearn LogisticRegression model.

        Args:
            model: Model to update in place.
            params: ``params[0]`` becomes ``coef_``; ``params[1]`` becomes
                ``intercept_`` when the model fits an intercept.

        Returns:
            The same ``model`` instance, for convenience.
        """
        model.coef_ = params[0]
        if model.fit_intercept:
            model.intercept_ = params[1]
        return model

    @classmethod
    def set_initial_params(
        cls,
        model: LogisticRegression,
        n_classes: int = 10,
        n_features: int = 784,
    ) -> None:
        """Set the initial parameters of the model to zeros.

        This is required because the model parameters are uninitialized until
        ``model.fit`` is called, but the server asks clients for initial
        parameters at launch. Refer to the
        sklearn.linear_model.LogisticRegression documentation for more
        information.

        Args:
            model: Model to initialize in place.
            n_classes: Number of target classes (MNIST has 10).
            n_features: Number of input features per sample (784 for the
                flattened MNIST images used here).
        """
        model.classes_ = np.arange(n_classes)

        model.coef_ = np.zeros((n_classes, n_features))
        if model.fit_intercept:
            model.intercept_ = np.zeros((n_classes,))


def fit_round(server_round: int) -> Dict:
    """Build the fit configuration that tells clients the current round.

    Args:
        server_round: Number of the round about to be run.

    Returns:
        A dict with the single key ``"server_round"``.
    """
    fit_config = dict(server_round=server_round)
    return fit_config


def get_evaluate_fn(model: LogisticRegression):
    """Build the callback used for server-side evaluation.

    The MNIST "test" split is loaded once, when this factory runs, so the
    returned callback does not pay that cost on every call.

    Args:
        model: LogisticRegression instance whose parameters are overwritten
            each time the callback runs.

    Returns:
        A function ``evaluate(server_round, parameters, config)`` that returns
        ``(loss, {"accuracy": accuracy})``. It also appends the loss and the
        accuracy to the module-level ``loss_list`` and ``accuracy_list``,
        which must exist before the callback is invoked.
    """
    fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS})
    test_split = fds.load_split("test").with_format("numpy")

    # Flatten every image into a single feature row
    n_samples = len(test_split)
    X_test = test_split["image"].reshape((n_samples, -1))
    y_test = test_split["label"]

    def evaluate(server_round, parameters: fl.common.NDArrays, config):
        # Load the parameters to be scored into the shared model
        MnistClient.set_model_params(model, parameters)

        test_loss = log_loss(y_test, model.predict_proba(X_test))
        loss_list.append(test_loss)

        test_accuracy = model.score(X_test, y_test)
        accuracy_list.append(test_accuracy)

        return test_loss, {"accuracy": test_accuracy}

    return evaluate


def get_client_fn(cid: str):
    """Create the Flower client for the given client id.

    Args:
        cid: Client identifier forwarded to ``MnistClient``.

    Returns:
        The result of calling ``to_client()`` on the new ``MnistClient``.
    """
    numpy_client = MnistClient(cid)
    return numpy_client.to_client()


def server_fn(context: Context):
    """Assemble the strategy and run configuration for the ServerApp.

    A fresh LogisticRegression is given zeroed parameters and handed to the
    server-side evaluation callback. FedAvg is configured to require all
    ``N_CLIENTS`` clients for availability, fitting, and evaluation.

    Args:
        context: Context supplied by the caller; it is not read here.

    Returns:
        ServerAppComponents holding the FedAvg strategy and a ServerConfig
        with 50 rounds.
    """
    # Server-side model, used only by the centralized evaluation callback
    global_model = LogisticRegression()
    MnistClient.set_initial_params(global_model)
    central_eval = get_evaluate_fn(global_model)

    fedavg = fl.server.strategy.FedAvg(
        min_available_clients=N_CLIENTS,
        min_fit_clients=N_CLIENTS,
        min_evaluate_clients=N_CLIENTS,
        evaluate_fn=central_eval,
        on_fit_config_fn=fit_round,
    )
    rounds_config = ServerConfig(num_rounds=50)
    return ServerAppComponents(strategy=fedavg, config=rounds_config)

    # Draft of the SecAgg+ wiring; it sits after the return and is commented
    # out, so it is never executed. Kept for reference.
    # context = LegacyContext(
    #     context=context,
    #     config=ServerConfig(num_rounds=50),
    #     strategy=strategy,
    # )

    # workflow = DefaultWorkflow(
    #     fit_workflow=SecAggPlusWorkflow(
    #         num_shares=6,
    #         reconstruction_threshold=4,
    #     ),
    # )


# Run the federated-learning simulation (the number of rounds is set in server_fn)
if __name__ == "__main__":
    # Filled by the server-side evaluate callback, one entry per evaluation
    accuracy_list = []
    loss_list = []

    run_simulation(
        server_app=ServerApp(server_fn=server_fn),
        client_app=ClientApp(
            client_fn=get_client_fn,
            mods=[
                secaggplus_mod,
            ],
        ),
        num_supernodes=N_CLIENTS,
        backend_name="ray",
        backend_config={
            "num_cpus": N_CLIENTS,
            "num_gpus": N_GPUS,
            "client_resources": {"num_cpus": 1, "num_gpus": N_GPUS},
        },
        verbose_logging=True,
    )
    # Change logging detail level from default to info and log some values for administration
    logging.getLogger().setLevel(logging.INFO)

    # max()/min() and [-1] raise on an empty list, so only summarize when at
    # least one evaluation was recorded.
    if accuracy_list:
        logging.info(
            f"{len(accuracy_list)} accuracy scores aggregated. The highest score is {max(accuracy_list)}, the score of the last round is {accuracy_list[-1]}",
        )
    else:
        logging.warning("No accuracy scores were recorded.")

    if loss_list:
        logging.info(
            f"{len(loss_list)} loss values aggregated. The lowest recorded loss {min(loss_list)}, the loss of the last round is {loss_list[-1]}",
        )
    else:
        logging.warning("No loss values were recorded.")

2 Likes

Hello @chielflower, there are two examples showing how to use SecAgg+ with Flower: