How to prevent OOM error while training?

Hi everyone. I’m attempting to train a model using transfer learning in a federated setting with Flower. I’m working with CIFAR10 and have 10 clients, each receiving the same model to train, with Flower managing the training process. The machine I’m using has roughly 500GB of RAM, and the simulation initially consumes around 250GB. However, as training progresses, the RAM usage keeps increasing until all memory is exhausted and the simulation crashes with an OOM error. Do you have any suggestions on how to resolve this issue?

Below you can see the components I have used in my code:

def client_fn(cid: str) -> FlowerClient:

  """Create a Flower client representing a single organization."""

  # Construct the model for the client and initialize its weights
  model = Transfer_learning_model()

  # Create the optimizer with the chosen learning rate
  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

  # Pass the loss and the optimizer to the model
  model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

  # Load the training data for this specific client
  x_train_cid = np.array(list_x_train[int(cid)], dtype=float)
  y_train_cid = np.array(list_y_train[int(cid)], dtype=int)

  # Create and return the client
  return FlowerClient(model, x_train_cid, y_train_cid, Epochs_num)


# Define the evaluation function that runs on the server side after each
# communication round in the federated setting
def evaluation_function(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:

  # Define the architecture of the model's network
  net = Transfer_learning_model()

  # Set the weights of the network from the aggregated weights of the local models
  net.set_weights(parameters)

  # Evaluate the performance of the aggregated model
  loss, accuracy, precision, recall, f1score = test_model(net, X_test, y_test)

  # Remove the model after testing
  del net

  return loss, {"accuracy": accuracy, "precision": precision, "recall": recall, "f1score": f1score}

# we are going to use the FedAvg strategy
# We are going to use the FedAvg strategy
Aggregation_strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=0.0,
    evaluate_fn=evaluation_function)


# Set the number of local epochs
Epochs_num = 2

# Set the number of communication rounds
Rounds_num = 10

# Set the number of clients
Clients_num = 30

commun_metrics_history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=Clients_num,
    config=fl.server.ServerConfig(num_rounds=Rounds_num),
    strategy=Aggregation_strategy)

Hi @mehrdadhz ,

Could you share some details on how big your model is, as well as the data you use? 250GB sounds like too much.

Have you tried running the examples/simulation-tensorflow example? Does your project follow a similar structure?

With TensorFlow/Keras, it seems that creating/compiling models over and over makes RAM consumption increase steadily. I’d advise asking in the TF/Keras forums about the best practices to avoid that.
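One pattern that often mitigates this is clearing the Keras global state before each client builds its model. Here is a minimal sketch of your client_fn with that added (tf.keras.backend.clear_session() and gc.collect() are standard TF/Python calls; whether they fully plug the leak in your setup is something to verify):

import gc
import tensorflow as tf

def client_fn(cid: str) -> FlowerClient:
  # Drop the global graph state left behind by previously
  # created/compiled Keras models before building a fresh one
  tf.keras.backend.clear_session()
  gc.collect()

  model = Transfer_learning_model()
  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                metrics=['accuracy'])

  x_train_cid = np.array(list_x_train[int(cid)], dtype=float)
  y_train_cid = np.array(list_y_train[int(cid)], dtype=int)
  return FlowerClient(model, x_train_cid, y_train_cid, Epochs_num)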


Hi Javier,

Thank you very much for your swift reply. The model I’m employing is EfficientNetB0 with pre-trained ImageNet weights from Keras, applied to the CIFAR10 image dataset.
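For reference, the model is built roughly like this (a simplified sketch, not my exact code; the classification head and the frozen backbone are the usual transfer-learning setup and may differ from what I actually use):

def Transfer_learning_model(num_classes=10):
  # EfficientNetB0 backbone with ImageNet weights; CIFAR10 images are 32x32x3
  base = tf.keras.applications.EfficientNetB0(
      include_top=False, weights='imagenet', input_shape=(32, 32, 3))
  base.trainable = False  # freeze the backbone for transfer learning

  # Small classification head on top of the frozen features
  x = tf.keras.layers.GlobalAveragePooling2D()(base.output)
  outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
  return tf.keras.Model(inputs=base.input, outputs=outputs)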

This is not the first time I’m using Flower for federated learning, and I have been able to run CIFAR10 on simpler CNN models without any problem. Since the model's structure and the dataset are somewhat large, I can more or less accept that baseline memory load.

As you correctly mentioned, I have the same hypothesis about why I see this behavior, and I’m looking for a way to release the memory allocated to the clients after each round.

I tried deleting the model as well as the data passed to the clients in the FlowerClient class after the fitting process, but I face the same problem.
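Concretely, the cleanup I tried inside the client looks roughly like this (a sketch; the training logic is simplified and my actual FlowerClient differs):

import gc

class FlowerClient(fl.client.NumPyClient):
  def __init__(self, model, x_train, y_train, epochs):
    self.model = model
    self.x_train, self.y_train = x_train, y_train
    self.epochs = epochs

  def fit(self, parameters, config):
    self.model.set_weights(parameters)
    self.model.fit(self.x_train, self.y_train, epochs=self.epochs, verbose=0)
    weights = self.model.get_weights()

    # Attempted cleanup: drop references and force a collection cycle
    del self.model
    tf.keras.backend.clear_session()
    gc.collect()

    return weights, len(self.x_train), {}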

I was wondering if there is a way to release the whole memory allocated to the clients in Flower.

Thank you very much again.


Hi @mehrdadhz, I’m not familiar enough with TF to make good suggestions on what to change. The data pipelines from TF are very powerful, but my understanding is that they are not intended to run in parallel alongside others, which is what happens when you run a Flower simulation – the idea is to parallelize the execution of clients within the same round. My advice would be to switch to PyTorch.
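In the meantime, one knob worth trying is limiting how many clients run in parallel through the client_resources argument of start_simulation (this is part of Flower's simulation API; the CPU budget below is just an example value):

# With 4 CPUs reserved per client, a machine with N CPUs runs at most
# N/4 clients (and their models) concurrently, capping peak RAM
commun_metrics_history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=Clients_num,
    config=fl.server.ServerConfig(num_rounds=Rounds_num),
    strategy=Aggregation_strategy,
    client_resources={"num_cpus": 4})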