How can I implement a YOLO model using the Flower framework?

hey @johanrubak I had the same issue as @wkqco33 and I solved it too!
The issue was in this line of code: `self.model.val()`. I printed `len(self.model.model.state_dict().items())` before validation and got 499, but printing the same expression after the `self.model.val()` call gave me 175. This means that `self.model.val()` reduces the number of entries in the model's state dict for inference purposes (most likely by fusing layers). As a result, when we load parameters into `self.model` in the next round inside `fit()`, we get this error, because `self.model` now has only 175 state-dict keys while the parameters coming from the server have 499.
The solution I implemented: save the full model (a .pt file) before running validation, and then load this .pt file in `fit()`, because once a round is complete, `fit()` is the first method called in the next round.

Here’s my whole client-side code:
import os
import cv2
import torch
import numpy as np
from ultralytics import YOLO
import flwr as fl
from collections import OrderedDict
import datetime
import copy

# Load YOLOv11 model
# NOTE(review): the next line looks like it was a commented-out alternative
# checkpoint in the original script; it was immediately overwritten anyway.
# model_path = "v11n_best_20241107.pt"
model_path = "yolo11n.pt"

# Path to corrected data
corrected_data_yaml = "train2.yaml"
train_path = "./datasets/Dataset/train2"
val_path = "./datasets/Dataset/val2"

# Federated Learning with Flower

class YOLOClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and validates an Ultralytics YOLO model.

    Each round the server calls ``fit`` (local training) and ``evaluate``
    (local validation). ``evaluate`` saves the full model to disk before it
    calls ``self.model.val()``, and the next ``fit`` reloads that checkpoint.
    The author observed that the state dict shrinks (499 -> 175 entries)
    after ``val()``, which otherwise makes the next parameter load fail.
    """

    # Checkpoint written by evaluate() before validation and reloaded by fit().
    _CHECKPOINT_DIR = "yolo_model_before_val"
    _CHECKPOINT_PATH = os.path.join(_CHECKPOINT_DIR, "model_weights_before_val.pt")

    def __init__(self, model_path, corrected_data_yaml):
        """Load the YOLO weights and remember the dataset config path.

        Args:
            model_path: Path of the ``.pt`` weights passed to ``YOLO``.
            corrected_data_yaml: Path of the dataset ``.yaml`` used for both
                training and validation.
        """
        super().__init__()
        self.model = YOLO(model_path)
        self.corrected_data_yaml = corrected_data_yaml
        # Snapshot of the state-dict items taken right after loading.
        # set_parameters() pairs incoming arrays with these (key, tensor)
        # entries rather than with the live model's current state dict.
        self.state_dict_items = self.model.model.state_dict().items()
        print(len(self.state_dict_items))
        self.model_saved_before_validation = False

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays.

        Args:
            config: Configuration sent by the server (unused here).

        Returns:
            One array per state-dict entry, in state-dict order.
        """
        print("!!!!!Parameters SENT To server!!!!!")
        return [val.cpu().numpy() for val in self.model.model.state_dict().values()]

    @staticmethod
    def _dump_keys(filename, caller, label, mapping):
        """Write the caller tag, the mapping length and its keys to a file."""
        with open(filename, "w") as f:
            f.write(f"Coming from {caller}\n")
            f.write(f"{label} length: {len(mapping)}\n")
            for key in mapping:
                f.write(f"{key}\n")

    def set_parameters(self, parameters, str):
        """Load server-provided arrays into the local model.

        Incoming arrays are paired positionally with the state-dict items
        cached in ``__init__``; ``zip`` stops at the shorter of the two.
        Entries whose shape differs from the cached tensor are skipped with a
        printed message. Debug dumps of the tensors and key lists are written
        to timestamped text files in the working directory.

        Args:
            parameters: Sequence of NumPy arrays received from the server.
            str: Free-form tag naming the caller (e.g. ``"Fit"``); only used
                in log output. The name shadows the builtin but is kept so
                existing callers are unaffected.
        """
        caller = str  # avoid using the shadowing name in the body
        stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        new_state_dict = OrderedDict()

        # Mode "w" (not "x"): the timestamp only has second resolution, so two
        # calls within one second must not abort on an existing debug file.
        with open(f"2_processed_parameters_{stamp}.txt", "w") as f:
            f.write(f"Coming from {caller}\n")
            f.write(f"Current state_dict length: {len(self.state_dict_items)}\n")
            f.write(f"Received parameters length: {len(parameters)}\n\n")

            for (k, v), new_param in zip(self.state_dict_items, parameters):
                try:
                    # torch.tensor always copies its input, so the result is
                    # already detached from the NumPy buffer.
                    new_tensor = torch.tensor(new_param, dtype=v.dtype, device=v.device)

                    f.write(f"Key: {k}\n")
                    f.write(f"Original Tensor (v) Shape: {v.shape}, Device: {v.device}\n")
                    f.write(f"New Tensor Shape: {new_tensor.shape}, Device: {new_tensor.device}\n")
                    f.write(f"Original Tensor Values (first 10): {v.flatten()[:10]}\n")
                    f.write(f"New Tensor Values (first 10): {new_tensor.flatten()[:10]}\n\n")

                    # Only accept tensors whose shape matches the cached one.
                    if v.shape == new_tensor.shape:
                        new_state_dict[k] = new_tensor
                    else:
                        print(f"Shape mismatch for {k}: {v.shape} vs {new_tensor.shape}")
                except Exception as e:
                    print(f"Error processing parameter {k}: {e}")

        print(f"SET PARAMS: {caller}")
        # The key dumps are diagnostics only; a failure to write them must not
        # prevent the weights from being loaded below.
        try:
            self._dump_keys(
                f"2_keys_{stamp}.txt",
                caller,
                "model.model.state_dict",
                self.model.model.state_dict(),
            )
            self._dump_keys(
                f"2_newstate_dict_keys_{stamp}.txt",
                caller,
                "new_state_dict",
                new_state_dict,
            )
        except OSError as e:
            print(f"Could not write state-dict key dumps: {e}")

        print("!!!!!!Loading New state dict in model.model!!!!!!")
        try:
            # strict=True: any key skipped above (or missing from the live
            # model) makes this raise instead of loading a partial state.
            self.model.model.load_state_dict(new_state_dict, strict=True)
        except Exception as e:
            # NOTE(review): the error is only printed, so the client carries
            # on with its previous weights. Kept as in the original script.
            print(f"Error loading state dict: {e}")

    def fit(self, parameters, config):
        """Train locally for one epoch starting from the server's weights.

        Args:
            parameters: Global model weights as a list of NumPy arrays.
            config: Configuration sent by the server; forwarded to
                ``get_parameters``.

        Returns:
            ``(updated_parameters, num_train_samples, metrics)``. When the
            dataset config file is missing, training is skipped and the
            current parameters are returned with ``0`` examples, so the
            return value keeps the three-element shape.
        """
        if self.model_saved_before_validation:
            # Restore the full model that evaluate() saved before val().
            self.model = YOLO(self._CHECKPOINT_PATH)
            print("NEXT ROUND STARTED\n")
            self.model_saved_before_validation = False

        print(f"Entered Fit, len of model.model keys: {len(self.model.model.state_dict().items())}")
        self.set_parameters(parameters, "Fit")

        if not os.path.isfile(self.corrected_data_yaml):
            print(f"Corrected data config {self.corrected_data_yaml} not found. Skipping training.")
            return self.get_parameters(config), 0, {}

        print("Training on corrected data...")
        self.model.train(
            data=self.corrected_data_yaml,  # Path to your .yaml config file
            epochs=1,  # Number of epochs
            batch=16,  # Batch size
            imgsz=640,  # Image size (640 for resizing and letterboxing)
            project='local_train_output4',  # Output directory for storing results
            name='yolov11_training',  # Folder name to store the results
            exist_ok=True,  # Allow overwriting of existing output
            # Use the GPU if available, otherwise fall back to the CPU.
            device='cuda' if torch.cuda.is_available() else 'cpu',
            workers=0,
        )

        print("Training complete. Model updated.")
        print("CHECKING FIT AFTER TRAIN", len(self.model.model.state_dict().items()))
        # Presumably every sample contributes two entries to train_path
        # (hence the division by two) -- TODO confirm the dataset layout.
        num_train_samples = len(os.listdir(train_path)) // 2
        return self.get_parameters(config), num_train_samples, {}

    def evaluate(self, parameters, config):
        """Validate the server's weights on the local validation data.

        The full model is saved to ``_CHECKPOINT_PATH`` before ``val()`` runs
        so that the next ``fit`` call can reload it.

        Args:
            parameters: Global model weights as a list of NumPy arrays.
            config: Configuration sent by the server (unused here).

        Returns:
            ``(loss, num_samples, metrics)`` where ``metrics`` holds the
            mAP50 and mean recall reported by the validation results.
        """
        # Update the local YOLO model with the parameters sent by the server.
        self.set_parameters(parameters, "Evaluate")

        print("Evaluating model using YOLO...")
        print(f"BEFORE EVALUATE, len of model.model keys: {len(self.model.model.state_dict().items())}")

        # Save the full model first, overwriting any previous checkpoint.
        os.makedirs(self._CHECKPOINT_DIR, exist_ok=True)
        self.model.save(self._CHECKPOINT_PATH)
        print(f"Full model saved to {self._CHECKPOINT_PATH} before validation!")
        self.model_saved_before_validation = True

        results = self.model.val(data=self.corrected_data_yaml, batch=16)
        print(f"AFTER EVALUATE, len of model.model keys: {len(self.model.model.state_dict().items())}")

        # Every metric falls back to a default when the attribute is absent.
        box = results.box
        loss = box.loss.item() if hasattr(box, 'loss') else 0.0
        mAP50 = box.map50 if hasattr(box, 'map50') else 0.0
        mean_recall = box.mr if hasattr(box, 'mr') else 0

        if hasattr(box, 'seen'):
            num_samples = box.seen
        else:
            # Same "two entries per sample" assumption as in fit().
            num_samples = len(os.listdir(val_path)) / 2

        print(f"Evaluation Complete: Loss={loss}, mAP50={mAP50}, Mean recall: {mean_recall}, Samples={num_samples}")

        return loss, int(num_samples), {'mAP50': mAP50, 'Mean recall:': mean_recall}

# START FLOWER CLIENT
# Builds the client from the module-level model/data paths and connects it to
# the Flower server at the given address.
# NOTE(review): newer Flower releases appear to deprecate start_numpy_client in
# favour of start_client(..., client=client.to_client()) -- confirm against
# the installed flwr version before switching.
fl.client.start_numpy_client(
    server_address="0.0.0.0:8080",
    client=YOLOClient(model_path=model_path, corrected_data_yaml=corrected_data_yaml),
)
@wkqco33 In your code, it seems to be the same issue, caused by the `self.model.val()` line. I don't know how upgrading the flwr library helped you resolve this, because the main issue is caused by Ultralytics' `val()`. Maybe you can enlighten me about this. You can also add my solution to your previous code and see if it works for you as well.

I hope I have properly explained the solution:)
Thanks for reading. Happy Learning :slight_smile: