Hi there,
I am a student doing research on different privacy-preserving method in FL. I am trying to implement SecAgg(+) in this simple example, for doing some testing/simulations. I found adding the clientMod pretty straight forward, but I cannot figure out how/where to implement the SecAggPlusWorkflow on the ServerApp in this example. I could not figure it out using the examples and documentation. Does anyone have a suggestion? You can find my code below:
import logging
import warnings
from typing import Dict
import flwr as fl
import matplotlib.pyplot as plt
import numpy as np
from flwr.client import ClientApp
from flwr.client.mod import secaggplus_mod
from flwr.common import Context
from flwr.common import NDArrays
from flwr.server import LegacyContext
from flwr.server import ServerApp
from flwr.server import ServerAppComponents
from flwr.server import ServerConfig
from flwr.server.workflow import DefaultWorkflow
from flwr.server.workflow import SecAggPlusWorkflow
from flwr.simulation.run_simulation import run_simulation
from flwr_datasets import FederatedDataset
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split
N_CLIENTS = 6
N_GPUS = 0
class MnistClient(fl.client.NumPyClient):
def __init__(self, cid: str):
partition_id = int(cid)
# Load the MNSIT dataset, partition in number of clients
fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS})
# get the partition for the client
dataset = fds.load_partition(partition_id, "train").with_format("numpy")
X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
# split partition dataset in 80% train, 20% test
self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
X,
y,
test_size=0.2,
random_state=None,
)
# Create LogisticRegression Model
self.model = LogisticRegression(
penalty="l2",
max_iter=1, # local epoch
warm_start=True, # prevent refreshing weights when fitting
solver="lbfgs",
fit_intercept=True,
class_weight=None,
)
# Setting initial parameters, akin to model.compile for keras models
MnistClient.set_initial_params(self.model)
def get_parameters(self, config): # type: ignore
return MnistClient.get_model_parameters(self.model)
def fit(self, parameters, config): # type: ignore
MnistClient.set_model_params(self.model, parameters)
# Ignore convergence failure due to low local epochs
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.model.fit(self.X_train, self.y_train)
print(f"Training finished for round {config['server_round']}")
return MnistClient.get_model_parameters(self.model), len(self.X_train), {}
def evaluate(self, parameters, config): # type: ignore
MnistClient.set_model_params(self.model, parameters)
loss = log_loss(self.y_test, self.model.predict_proba(self.X_test))
accuracy = self.model.score(self.X_test, self.y_test)
return loss, len(self.X_test), {"accuracy": accuracy}
@classmethod
def get_model_parameters(cls, model: LogisticRegression) -> NDArrays:
"""Returns the parameters of a sklearn LogisticRegression model."""
if model.fit_intercept:
params = [model.coef_, model.intercept_]
else:
params = [model.coef_]
return params
@classmethod
def set_model_params(cls, model: LogisticRegression, params: NDArrays) -> LogisticRegression:
"""Sets the parameters of a sklean LogisticRegression model."""
model.coef_ = params[0]
if model.fit_intercept:
model.intercept_ = params[1]
return model
@classmethod
def set_initial_params(cls, model: LogisticRegression):
"""Sets initial parameters as zeros Required since model params are uninitialized
until model.fit is called.
But server asks for initial parameters from clients at launch. Refer to
sklearn.linear_model.LogisticRegression documentation for more information.
"""
n_classes = 10 # MNIST has 10 classes
n_features = 784 # Number of features in dataset
model.classes_ = np.array([i for i in range(n_classes)])
model.coef_ = np.zeros((n_classes, n_features))
if model.fit_intercept:
model.intercept_ = np.zeros((n_classes,))
def fit_round(server_round: int) -> Dict:
"""Send round number to client."""
return {"server_round": server_round}
def get_evaluate_fn(model: LogisticRegression):
"""Return an evaluation function for server-side evaluation."""
# Load test data here to avoid the overhead of doing it in `evaluate` itself
fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS})
dataset = fds.load_split("test").with_format("numpy")
X_test, y_test = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
# The `evaluate` function will be called after every round
def evaluate(server_round, parameters: fl.common.NDArrays, config):
# Update model with the latest parameters
MnistClient.set_model_params(model, parameters)
loss = log_loss(y_test, model.predict_proba(X_test))
loss_list.append(loss)
accuracy = model.score(X_test, y_test)
accuracy_list.append(accuracy)
return loss, {"accuracy": accuracy}
return evaluate
def get_client_fn(cid: str):
return MnistClient(cid).to_client()
def server_fn(context: Context):
server_config = ServerConfig(num_rounds=50)
model = LogisticRegression()
MnistClient.set_initial_params(model)
strategy = fl.server.strategy.FedAvg(
min_available_clients=N_CLIENTS,
min_fit_clients=N_CLIENTS,
min_evaluate_clients=N_CLIENTS,
evaluate_fn=get_evaluate_fn(model),
on_fit_config_fn=fit_round,
)
return ServerAppComponents(
strategy=strategy,
config=server_config,
)
# context = LegacyContext(
# context=context,
# config=ServerConfig(num_rounds=50),
# strategy=strategy,
# )
# workflow = DefaultWorkflow(
# fit_workflow=SecAggPlusWorkflow(
# num_shares=6,
# reconstruction_threshold=4,
# ),
# )
# Start Flower server for five rounds of federated learning
if __name__ == "__main__":
accuracy_list = []
loss_list = []
run_simulation(
server_app=ServerApp(server_fn=server_fn),
client_app=ClientApp(
client_fn=get_client_fn,
mods=[
secaggplus_mod,
],
),
num_supernodes=N_CLIENTS,
backend_name="ray",
backend_config={
"num_cpus": N_CLIENTS,
"num_gpus": N_GPUS,
"client_resources": {"num_cpus": 1, "num_gpus": N_GPUS},
},
verbose_logging=True,
)
# Change logging detail level from default to info and log some values for administration
logging.getLogger().setLevel(logging.INFO)
logging.info(
f"{len(accuracy_list)} accuracy scores aggregated. The highest score is {max(accuracy_list)}, the score of the last round is {accuracy_list[-1]}",
)
logging.info(
f"{len(loss_list)} loss values aggregated. The lowest recorded loss {min(loss_list)}, the loss of the last round is {loss_list[-1]}",
)