How can I implement a YOLO model using the Flower framework?

How to Pass Weights as Parameters in Flower?
I’m trying to use the Flower framework to train a YOLO model in a federated learning setting. I’m having trouble figuring out how to properly pass the model weights as parameters between the server and clients.

Here’s what I’ve tried so far:

  • I’ve converted the YOLO model weights to a list of NumPy arrays.

However, I’m encountering errors during training, and I suspect it’s related to how the weights are being handled.

Could someone provide guidance or examples on how to correctly pass YOLO model weights as parameters in Flower? Any help would be greatly appreciated!


Hi @wkqco33, welcome to the Flower forum!

Please take a look at the pytorch-quickstart example. Concretely, see how:

  • First, the ServerApp initializes the global model from a randomly initialized model (it could also be pretrained). This is done in this line. The get_weights function simply extracts the state_dict of the PyTorch model and converts it into a list of NumPy arrays. That’s the preferred way of sharing model parameters between ServerApp and ClientApp objects.
  • Second, on the ClientApp side, that exact same list of NumPy arrays is delivered as the parameters input argument to the fit() and evaluate() methods. Note that both methods call set_weights, a utility function that constructs a valid state_dict out of a list of NumPy arrays and applies it to the model local to the ClientApp.

That’s the basic way of communicating model parameters in Flower. Note that the specific implementation of the get_weights() and set_weights() functions depends on the ML framework you use as well as on the type of model. More exotic models might need adjustments.
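For a plain PyTorch nn.Module, the two helpers follow roughly the pattern below. This is a minimal sketch modelled on the quickstart's helpers, not a verbatim copy; the type hints and the explicit strict=True are choices made here. Because the arrays carry no key names, both sides must enumerate state_dict() in the same order with the same architecture.

```python
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def get_weights(net: nn.Module) -> list[np.ndarray]:
    # Serialise every state_dict entry (parameters *and* buffers such as
    # BatchNorm running statistics) in the module's own key order.
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_weights(net: nn.Module, parameters: list[np.ndarray]) -> None:
    # Pair the local keys with the incoming arrays *by position*; this only
    # works when sender and receiver share exactly the same architecture.
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    # strict=True raises if a key is missing or unexpected
    net.load_state_dict(state_dict, strict=True)
```

As a quick check, copying the weights of one small nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4)) into a freshly initialised twin with set_weights(dst, get_weights(src)) should leave both with identical arrays.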

Let us know if you succeed!!


Here is my server code:

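from logging import INFO
from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
from flwr.common import FitRes, Metrics, Parameters, Scalar, ndarrays_to_parameters
from flwr.common.logger import log
from flwr.server import ServerConfig
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg

from flower_yolo.utils import get_weights, load_model  # helpers shown further below


class SaveModelStrategy(FedAvg):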
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:

        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )

        return aggregated_parameters, aggregated_metrics


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    return {"accuracy": sum(accuracies) / sum(examples)}  # type: ignore


def server_fn():
    model = load_model()
    parameters = ndarrays_to_parameters(get_weights(model))

    log(INFO, "Initial parameters")

    strategy = SaveModelStrategy(
        fraction_fit=0.5,
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=parameters,
        evaluate_metrics_aggregation_fn=weighted_average,
    )

    num_rounds = 3
    config = ServerConfig(num_rounds=num_rounds)

    return strategy, config


if __name__ == "__main__":
    strategy, config = server_fn()
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        strategy=strategy,
        config=config,
    )

And here is my client code:

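import os
from logging import INFO

import flwr as fl
from flwr.client import NumPyClient
from flwr.common import NDArrays
from flwr.common.logger import log
from ultralytics import YOLO
from ultralytics.utils.metrics import DetMetrics

from flower_yolo.utils import get_weights, load_model, set_weights


class RobotClient(NumPyClient):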
    def __init__(self, data, epochs):
        self.model: YOLO = load_model()
        self.data = data
        self.epochs = epochs

    def fit(self, parameters: NDArrays, config):
        set_weights(self.model, parameters)
        results: DetMetrics | None = self.model.train(
            data=self.data,
            epochs=self.epochs,
        )
        if results is not None:
            log(INFO, f"Results: {results.box.map}")

        return get_weights(self.model), 10, {}

    def evaluate(self, parameters: NDArrays, config):
        set_weights(self.model, parameters)
        metrics: DetMetrics = self.model.val()
        accuracy = metrics.box.map
        loss = metrics.fitness

        return loss, 10, {"accuracy": accuracy}


def client_fn():
    data = "coco8.yaml"
    epochs = 1

    return RobotClient(data, epochs).to_client()


if __name__ == "__main__":
    server_ip = os.getenv("SERVER_IP", "localhost")
    server_port = os.getenv("SERVER_PORT", "8080")

    log(INFO, f"Server IP: {server_ip}:{server_port}")

    fl.client.start_client(
        server_address=f"{server_ip}:{server_port}",
        client=client_fn(),
    )

I changed the get_weights and set_weights functions following your advice:

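from collections import OrderedDict

import torch
from flwr.common import NDArrays
from ultralytics import YOLO


def load_model():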
    model = YOLO("yolo11n.pt")
    return model

def get_weights(model: YOLO) -> NDArrays:
    return [val.cpu().numpy() for _, val in model.state_dict().items()]


def set_weights(model, parameters):
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=False)

But I can’t see any results, because training fails with the following error:

Traceback (most recent call last):
  File "/home/seoyc/Project/work/mlops/flower_study/src/flower_yolo/client.py", line 54, in <module>
    fl.client.start_client(
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/app.py", line 180, in start_client
    start_client_internal(
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/app.py", line 543, in start_client_internal
    raise ex
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/app.py", line 536, in start_client_internal
    reply_message = client_app(message=message, context=context)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/client_app.py", line 143, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/client_app.py", line 126, in ffn
    out_message = handle_legacy_message_from_msgtype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/message_handler/message_handler.py", line 129, in handle_legacy_message_from_msgtype
    fit_res = maybe_call_fit(
              ^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/client.py", line 255, in maybe_call_fit
    return client.fit(fit_ins)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/numpy_client.py", line 259, in _fit
    results = self.numpy_client.fit(parameters, ins.config)  # type: ignore
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/src/flower_yolo/client.py", line 22, in fit
    set_weights(self.model, parameters)
  File "/home/seoyc/Project/work/mlops/flower_study/src/flower_yolo/utils.py", line 25, in set_weights
    model.load_state_dict(state_dict, strict=False)
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for YOLO:
        While copying the parameter named "model.model.0.conv.weight", whose dimensions in the model are torch.Size([16, 3, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 3, 3, 3]), an exception occurred : ('Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.',).
        While copying the parameter named "model.model.0.conv.bias", whose dimensions in the model are torch.Size([16]) and whose dimensions in the checkpoint are torch.Size([16]), an exception occurred : ('Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.',).
        size mismatch for model.model.1.conv.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32, 16, 3, 3]).
        size mismatch for model.model.1.conv.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
        size mismatch for model.model.2.cv1.conv.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32, 32, 1, 1]).
        size mismatch for model.model.2.cv1.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([32]).
        size mismatch for model.model.2.cv2.conv.weight: copying a param with shape torch.Size([32, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 48, 1, 1]).
        size mismatch for model.model.2.cv2.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for model.model.2.m.0.cv1.conv.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([8, 16, 3, 3]).
        size mismatch for model.model.2.m.0.cv1.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([8]).
        size mismatch for model.model.2.m.0.cv2.conv.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16, 8, 3, 3]).
        size mismatch for model.model.2.m.0.cv2.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for model.model.3.conv.weight: copying a param with shape torch.Size([32, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
        size mismatch for model.model.3.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for model.model.4.cv1.conv.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64, 64, 1, 1]).
        size mismatch for model.model.4.cv1.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for model.model.4.cv2.conv.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128, 96, 1, 1]).
        size mismatch for model.model.4.cv2.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for model.model.4.m.0.cv1.conv.weight: copying a param with shape torch.Size([64, 48, 1, 1]) from checkpoint, the shape in current model is torch.Size([16, 32, 3, 3]).
        size mismatch for model.model.4.m.0.cv1.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for model.model.4.m.0.cv2.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32, 16, 3, 3]).
        size mismatch for model.model.4.m.0.cv2.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
        size mismatch for model.model.5.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
        size mismatch for model.model.5.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for model.model.6.cv1.conv.weight: copying a param with shape torch.Size([8, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 1, 1]).
        size mismatch for model.model.6.cv1.conv.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for model.model.6.cv2.conv.weight: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([128, 192, 1, 1]).
        size mismatch for model.model.6.cv2.conv.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([128]).
... Skip ....
size mismatch for model.model.23.cv3.0.0.0.conv.weight: copying a param with shape torch.Size([64, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1, 3, 3]).
        While copying the parameter named "model.model.23.cv3.0.0.0.conv.bias", whose dimensions in the model are torch.Size([64]) and whose dimensions in the checkpoint are torch.Size([64]), an exception occurred : ('Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.',).
        size mismatch for model.model.23.cv3.0.0.1.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 64, 1, 1]).
        size mismatch for model.model.23.cv3.0.0.1.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.0.1.0.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 1, 3, 3]).
        size mismatch for model.model.23.cv3.0.1.0.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.0.1.1.conv.weight: copying a param with shape torch.Size([64, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.0.1.1.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.0.2.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.0.2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.1.0.0.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128, 1, 3, 3]).
        size mismatch for model.model.23.cv3.1.0.0.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for model.model.23.cv3.1.0.1.conv.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([80, 128, 1, 1]).
        size mismatch for model.model.23.cv3.1.0.1.conv.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.1.1.0.conv.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([80, 1, 3, 3]).
        size mismatch for model.model.23.cv3.1.1.0.conv.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.1.1.1.conv.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.1.1.1.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.1.2.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.1.2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.2.0.0.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256, 1, 3, 3]).
        size mismatch for model.model.23.cv3.2.0.0.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for model.model.23.cv3.2.0.1.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 256, 1, 1]).
        size mismatch for model.model.23.cv3.2.0.1.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.2.1.0.conv.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 1, 3, 3]).
        size mismatch for model.model.23.cv3.2.1.0.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.2.1.1.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.2.1.1.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.2.2.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.2.2.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.dfl.conv.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 16, 1, 1]).

I really appreciate you and your support!!!

It looks like your checkpoint doesn’t match the model architecture. If you set strict=True in model.load_state_dict(state_dict, strict=True), you can debug which layers don’t match. Alternatively, you could print out the layers of the model and of the checkpoint and check whether they match. Maybe you need to clone the tensors as well, but I am not sure.
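The "print out the layers and compare" approach can be sketched as a small check run before load_state_dict(). report_mismatches is a hypothetical helper (not part of Flower or Ultralytics), and it assumes the received arrays are meant to line up with the model's state_dict() order, as they do with the zip()-based set_weights above.

```python
import numpy as np


def report_mismatches(state_dict, parameters):
    """List problems between a local state_dict and the arrays received
    from the server, which are paired with the local keys by position."""
    problems = []
    if len(state_dict) != len(parameters):
        problems.append(
            f"length mismatch: model has {len(state_dict)} entries, received {len(parameters)}"
        )
    # zip() stops at the shorter sequence, just like the zip() in set_weights
    for (name, local), incoming in zip(state_dict.items(), parameters):
        # tuple() makes torch.Size and NumPy shapes directly comparable
        if tuple(local.shape) != tuple(incoming.shape):
            problems.append(
                f"{name}: model {tuple(local.shape)} vs received {tuple(incoming.shape)}"
            )
    return problems


# Toy example with NumPy arrays standing in for tensors
local = {"conv.weight": np.zeros((16, 3, 3, 3)), "conv.bias": np.zeros(16)}
print(report_mismatches(local, [np.ones(16), np.ones(32)]))
```

Called with self.model.model.state_dict() and the incoming parameters, a length mismatch at the top of the list would point to the two sides having different numbers of entries, rather than individual layers having been resized.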

I am facing the same issue. Did you find any solution?

@johanrubak
I solved this problem by using the latest version of the Flower framework. I updated the project using this guide and then ran Flower with the Deployment Engine. That solved my problem, and I got the results I expected from the YOLO model.

Hey @johanrubak, I had the same issue as @wkqco33 and I solved it too!
The issue was caused by this line of code: self.model.val(). I printed len(self.model.model.state_dict().items()) before validation and got 499, but printing the same expression after self.model.val() gave 175. This means that self.model.val() reduces the number of state_dict entries for inference. As a result, when the parameters are loaded into self.model inside fit() in the next round, loading fails: self.model now has only 175 keys, while the parameters coming from the server have 499.
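The drop from 499 to 175 entries is consistent with each Conv2d + BatchNorm2d pair being fused into a single Conv2d for inference. For a bias-free convolution, the pair stores six state_dict entries (the conv weight plus BatchNorm weight, bias, running_mean, running_var and num_batches_tracked), but the fused conv stores only two (weight and bias). The toy block below reproduces this with PyTorch's own fuse_conv_bn_eval; that val() performs an equivalent fusion on self.model is inferred from the key counts, not verified here.

```python
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval


def state_dict_sizes() -> tuple[int, int]:
    # A bias-free conv followed by BatchNorm, the usual training-time layout
    conv = nn.Conv2d(3, 16, kernel_size=3, bias=False)
    bn = nn.BatchNorm2d(16)
    unfused = nn.Sequential(conv, bn).eval()  # fusion is only valid in eval mode
    # Fold the BatchNorm statistics into one conv that has a weight and a bias
    fused = fuse_conv_bn_eval(conv, bn)
    return len(unfused.state_dict()), len(fused.state_dict())


print(state_dict_sizes())  # (6, 2)
```

Because set_weights pairs keys and arrays by position, a fused model on one side makes every later array land on the wrong key, which matches the cascade of shifted shapes in the traceback above.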
Solution I implemented: save the full model (.pt file) before running validation, then load this .pt file in fit(), because as soon as one round is complete, fit() is the first method called in the next round.
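A variation on the same idea, sketched under the assumption that the underlying torch module can be deep-copied: run the inference-only fusion on a copy, so the training module keeps its original layout and nothing has to be reloaded from disk. fused_copy is a hypothetical helper shown on a toy block; whether Ultralytics' val() can be pointed at such a copy was not tested in this thread.

```python
import copy

import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval


def fused_copy(block: nn.Sequential) -> nn.Conv2d:
    # Deep-copy first so the caller's training module is never modified
    eval_block = copy.deepcopy(block).eval()
    conv, bn = eval_block[0], eval_block[1]
    # Fuse only the copy: this is the inference-time layout
    return fuse_conv_bn_eval(conv, bn)


train_block = nn.Sequential(nn.Conv2d(3, 16, 3, bias=False), nn.BatchNorm2d(16))
inference_conv = fused_copy(train_block)
print(len(train_block.state_dict()), len(inference_conv.state_dict()))
```

The trade-off is memory: the copy briefly doubles the model's footprint, whereas the save-and-reload approach trades that for disk I/O.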

Here’s my whole client-side code:

```python
import datetime
import os
from collections import OrderedDict

import flwr as fl
import torch
from ultralytics import YOLO

# Load YOLOv11 model
# model_path = 'v11n_best_20241107.pt'
model_path = 'yolo11n.pt'

# Path to corrected data
corrected_data_yaml = 'train2.yaml'
train_path = "./datasets/Dataset/train2"
val_path = "./datasets/Dataset/val2"


# Federated Learning with Flower
class YOLOClient(fl.client.NumPyClient):
    def __init__(self, model_path, corrected_data_yaml):
        # Load YOLOv11 model
        self.model = YOLO(model_path)
        self.corrected_data_yaml = corrected_data_yaml
        self.state_dict_items = self.model.model.state_dict().items()
        print(len(self.state_dict_items))
        self.model_saved_before_validation = False

    def get_parameters(self, config):
        print("!!!!!Parameters SENT To server!!!!!")
        return [val.cpu().numpy() for _, val in self.model.model.state_dict().items()]

    def set_parameters(self, parameters, caller):
        # Create a new state dict to update
        new_state_dict = OrderedDict()
        filename2 = f"2_processed_parameters_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

        with open(filename2, "x") as f:
            f.write(f"Coming from {caller}\n")
            f.write(f"Current state_dict length: {len(self.state_dict_items)}\n")
            f.write(f"Received parameters length: {len(parameters)}\n\n")

            # Zip current state dict keys with incoming parameters
            for (k, v), new_param in zip(self.state_dict_items, parameters):
                try:
                    # Convert the numpy array to a fresh tensor with matching dtype and device
                    new_tensor = torch.tensor(new_param, dtype=v.dtype, device=v.device).detach().clone()

                    # Log the key, original tensor (v), and new tensor
                    f.write(f"Key: {k}\n")
                    f.write(f"Original Tensor (v) Shape: {v.shape}, Device: {v.device}\n")
                    f.write(f"New Tensor Shape: {new_tensor.shape}, Device: {new_tensor.device}\n")
                    f.write(f"Original Tensor Values (first 10): {v.flatten()[:10]}\n")
                    f.write(f"New Tensor Values (first 10): {new_tensor.flatten()[:10]}\n\n")

                    # Only update if shapes match
                    if v.shape == new_tensor.shape:
                        new_state_dict[k] = new_tensor
                    else:
                        print(f"Shape mismatch for {k}: {v.shape} vs {new_tensor.shape}")
                except Exception as e:
                    print(f"Error processing parameter {k}: {e}")

        # Load the new state dict (strict=True fails if any key is missing)
        try:
            print(f"SET PARAMS: {caller}")
            # Log the keys of self.model.model
            filename3 = f"2_keys_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
            with open(filename3, "x") as f:
                f.write(f"Coming from {caller}\n")
                f.write(f"model.model.state_dict length: {len(self.model.model.state_dict())}\n")
                for key in self.model.model.state_dict():
                    f.write(f"{key}\n")

            # Log the keys of new_state_dict
            filename4 = f"2_newstate_dict_keys_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
            with open(filename4, "w") as f:
                f.write(f"Coming from {caller}\n")
                f.write(f"new_state_dict length: {len(new_state_dict)}\n")
                for key in new_state_dict:
                    f.write(f"{key}\n")

            print("!!!!!!Loading New state dict in model.model!!!!!!")
            self.model.model.load_state_dict(new_state_dict, strict=True)
        except Exception as e:
            print(f"Error loading state dict: {e}")

    def fit(self, parameters, config):
        if self.model_saved_before_validation:
            # val() in the previous round shrank the model for inference,
            # so reload the full model saved just before validation
            self.model = YOLO("./yolo_model_before_val/model_weights_before_val.pt")
            print("NEXT ROUND STARTED\n")
            self.model_saved_before_validation = False

        print(f"Entered Fit, len of model.model keys: {len(self.model.model.state_dict())}")
        self.set_parameters(parameters, "Fit")

        if not os.path.isfile(self.corrected_data_yaml):
            print(f"Corrected data config {self.corrected_data_yaml} not found. Skipping training.")
            # fit() must still return (parameters, num_examples, metrics)
            return parameters, 0, {}

        print("Training on corrected data...")
        # Train the model
        self.model.train(
            data=self.corrected_data_yaml,  # Path to your .yaml config file
            epochs=1,  # Number of epochs
            batch=16,  # Batch size
            imgsz=640,  # Image size (640 for resizing and letterboxing)
            project='local_train_output4',  # Output directory for storing results
            name='yolov11_training',  # Folder name to store the results
            exist_ok=True,  # Allow overwriting of existing output
            # cache=True,  # Cache images for faster training
            device='cuda',  # Train on the GPU
            workers=0,
        )

        print("Training complete. Model updated.")
        print("CHECKING FIT AFTER TRAIN", len(self.model.model.state_dict()))
        num_train_samples = int(len(os.listdir(train_path)) / 2)
        return self.get_parameters(config), num_train_samples, {}

    def evaluate(self, parameters, config):
        # Update the local YOLO model with the parameters sent from the central server
        self.set_parameters(parameters, "Evaluate")

        print("Evaluating model using YOLO...")
        print(f"BEFORE EVALUATE, len of model.model keys: {len(self.model.model.state_dict())}")

        # Save the full model before validation, overwriting any previous file
        save_dir = "yolo_model_before_val"
        save_path = os.path.join(save_dir, "model_weights_before_val.pt")
        os.makedirs(save_dir, exist_ok=True)
        self.model.save(save_path)
        print(f"Full model saved to {save_path} before validation!")
        self.model_saved_before_validation = True

        # Perform evaluation using the validation data
        results = self.model.val(data=self.corrected_data_yaml, batch=16)
        print(f"AFTER EVALUATE, len of model.model keys: {len(self.model.model.state_dict())}")

        # Extract relevant metrics
        loss = results.box.loss.item() if hasattr(results.box, 'loss') else 0.0
        mAP50 = results.box.map50 if hasattr(results.box, 'map50') else 0.0
        if hasattr(results.box, 'seen'):
            num_samples = results.box.seen
        else:
            num_samples = len(os.listdir(val_path)) / 2
        mean_recall = results.box.mr if hasattr(results.box, 'mr') else 0

        print(f"Evaluation Complete: Loss={loss}, mAP50={mAP50}, Mean recall: {mean_recall}, Samples={num_samples}")

        # Return the evaluation results
        return loss, int(num_samples), {'mAP50': mAP50, 'Mean recall': mean_recall}


# START FLOWER CLIENT
fl.client.start_numpy_client(
    server_address="0.0.0.0:8080",
    client=YOLOClient(model_path=model_path, corrected_data_yaml=corrected_data_yaml),
)
```
@wkqco33 Your code seems to have the same issue, caused by the self.model.val() line. I don't know how upgrading the flwr library helped you resolve this, because the main issue is caused by Ultralytics' val(). Maybe you can enlighten me about this. You can also add my solution to your previous code and see whether it works for you as well.

I hope I have properly explained the solution:)
Thanks for reading. Happy Learning :slight_smile: