Hi everyone. I'm trying to train a model with transfer learning in a federated learning (FL) setup using Flower. I'm working with CIFAR10 and have 10 clients, each receiving the same model to train, with Flower managing the training process. The machine I'm using has roughly 500GB of RAM, and the simulation initially consumes around 250GB. As training progresses, however, RAM usage keeps growing until all memory is exhausted and the simulation crashes with an OOM error. Do you have any suggestions on how to resolve this?
Below are the components I used in my code:
# Imports assumed by the snippets below (fl is the usual alias for flwr)
from typing import Dict, Optional, Tuple

import flwr as fl
import numpy as np
import tensorflow as tf


def client_fn(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization."""
    # Construct the model for this client and initialize its weights
    model = Transfer_learning_model()
    # Create the optimizer with the given learning rate
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    # Pass the loss and the optimizer to the model
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=['accuracy'],
    )
    # Load the training data for this specific client
    x_train_cid = np.array(list_x_train[int(cid)], dtype=float)
    y_train_cid = np.array(list_y_train[int(cid)], dtype=int)
    # Create and return the client
    return FlowerClient(model, x_train_cid, y_train_cid, Epochs_num)
# Server-side evaluation function, run after each communication round
# on the aggregated (global) model
def evaluation_function(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    # Build a network with the same architecture as the clients' model
    net = Transfer_learning_model()
    # Set the network weights to the aggregated weights of the local models
    net.set_weights(parameters)
    # Evaluate the performance of the aggregated model
    loss, accuracy, precision, recall, f1score = test_model(net, X_test, y_test)
    # Remove the model after testing
    del net
    return loss, {"accuracy": accuracy, "precision": precision, "recall": recall, "f1score": f1score}
# we are going to use the FedAvg strategy
Aggregation_strategy = fl.server.strategy.FedAvg(
    fraction_fit=1,
    fraction_evaluate=0,
    evaluate_fn=evaluation_function,
)
# set the number of the local epochs
Epochs_num = 2
# set the number of the rounds
Rounds_num = 10
# the number of the clients
Clients_num = 30
commun_metrics_history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=Clients_num,
    config=fl.server.ServerConfig(num_rounds=Rounds_num),
    strategy=Aggregation_strategy,
)
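One detail in client_fn that may be worth checking: in NumPy, dtype=float means float64, so the training images are copied at 8 bytes per value. The sketch below is only back-of-the-envelope arithmetic for the raw CIFAR-10 training images (50,000 images of 32 x 32 x 3) under different dtypes. The helper names are hypothetical, and it assumes the images are kept at their original resolution; if the transfer-learning model upsamples its inputs, the numbers would be larger. It is not meant to explain the whole 250GB, only to show how much the dtype choice contributes.

```python
# Rough size of the CIFAR-10 training images under different dtypes.
# Assumption: images stay at their native 32 x 32 x 3 resolution.
NUM_IMAGES = 50_000
IMAGE_SHAPE = (32, 32, 3)

# Bytes per array element for the dtypes of interest
BYTES_PER_ELEMENT = {"uint8": 1, "float32": 4, "float64": 8}


def dataset_bytes(num_images, image_shape, dtype_name):
    """Return the bytes needed to hold the images as one dense array."""
    elements = num_images
    for dim in image_shape:
        elements *= dim
    return elements * BYTES_PER_ELEMENT[dtype_name]


def to_gib(num_bytes):
    """Convert a byte count to GiB."""
    return num_bytes / 2**30


footprint = {
    name: dataset_bytes(NUM_IMAGES, IMAGE_SHAPE, name)
    for name in BYTES_PER_ELEMENT
}

for name, size in footprint.items():
    print(f"{name:8s} {to_gib(size):6.2f} GiB")
```

Under these assumptions, float64 doubles the footprint relative to float32, and each client's copy is made on top of whatever list_x_train already holds, so the dtype is one of the cheaper things to rule out before looking at the simulation backend itself.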