Hey @johanrubak, I had the same issue as @wkqco33, and I solved it too!
The issue was in this line of code: `self.model.val()`. I printed `len(self.model.model.state_dict().items())` before validation and got 499, but printing it again after the `self.model.val()` line gave 175. So `self.model.val()` reduces the number of state-dict entries of the model in place for inference (as far as I can tell, it fuses layers such as Conv + BatchNorm). As a result, when we load parameters into `self.model` in the next round inside `fit()`, loading fails: `self.model` now has only 175 state-dict keys, while the parameters coming from the server have 499 entries.
Solution I implemented: save the full model (.pt file) before running validation, then reload that .pt file at the start of `fit()`. Once a round is complete, `fit()` is the first method called in the next round.
Here’s my whole client side code`import os
import cv2
import torch
import numpy as np
from ultralytics import YOLO
import flwr as fl
from collections import OrderedDict
import datetime
import copy
# Load YOLOv11 model
# Alternative checkpoint (disabled; the plain yolo11n.pt below is the one used):
# model_path = 'v11n_best_20241107.pt'
model_path = 'yolo11n.pt'

# Path to corrected data
corrected_data_yaml = 'train2.yaml'  # Ultralytics dataset config used for train/val
train_path = "./datasets/Dataset/train2"  # only used to count training samples
val_path = "./datasets/Dataset/val2"  # only used as a fallback validation sample count

# Federated Learning with Flower
class YOLOClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and validates an Ultralytics YOLO model.

    Ultralytics' ``val()`` alters ``self.model.model`` in place so that it
    exposes fewer state-dict entries than the full training model (175 vs 499
    observed for yolo11n; presumably because layers are fused for inference).
    Loading the server's full parameter list into that altered model fails.
    To avoid this, ``evaluate`` saves a snapshot of the full model before
    calling ``val()``. The next ``fit``/``evaluate`` call reloads that
    snapshot before loading new parameters.
    """

    # Where the full (pre-validation) model snapshot is written and reloaded from.
    BEFORE_VAL_DIR = "yolo_model_before_val"
    BEFORE_VAL_PATH = os.path.join(BEFORE_VAL_DIR, "model_weights_before_val.pt")

    def __init__(self, model_path, corrected_data_yaml):
        """Load the YOLO model and remember the dataset config.

        Args:
            model_path: Path to the initial YOLO ``.pt`` checkpoint.
            corrected_data_yaml: Path to the Ultralytics dataset ``.yaml``.
        """
        # Load YOLOv11 model
        self.model = YOLO(model_path)
        self.corrected_data_yaml = corrected_data_yaml
        # Reference (key, tensor) pairs of the full model. set_parameters maps
        # the server's flat array list onto these keys by position and copies
        # dtype/device from them.
        self.state_dict_items = self.model.model.state_dict().items()
        print(len(self.state_dict_items))
        # True once evaluate() has saved a snapshot and run val(); the
        # in-memory model must then be reloaded before loading parameters.
        self.model_saved_before_validation = False

    @staticmethod
    def _timestamp():
        """Return the current local time formatted for debug-log filenames."""
        return datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

    def _restore_model_if_validated(self):
        """Reload the pre-validation snapshot if ``val()`` ran since the last reload.

        Returns:
            True if the model was reloaded, False if nothing needed to be done.
        """
        if not self.model_saved_before_validation:
            return False
        self.model = YOLO(self.BEFORE_VAL_PATH)
        # The old reference items point at the previous model object; re-point
        # them at the reloaded model.
        self.state_dict_items = self.model.model.state_dict().items()
        self.model_saved_before_validation = False
        return True

    def get_parameters(self, config):
        """Return every state-dict tensor of the model as a numpy array, in order."""
        print("!!!!!Parameters SENT To server!!!!!")
        return [val.cpu().numpy() for _, val in self.model.model.state_dict().items()]

    def set_parameters(self, parameters, str):
        """Load a flat list of arrays from the server into ``self.model.model``.

        Arrays are paired with ``self.state_dict_items`` by position; pairs
        whose shapes differ are skipped. The final ``load_state_dict`` is
        strict, so any missing key makes the whole load fail. That failure is
        printed rather than raised. Each call also writes three debug files
        into the working directory.

        Args:
            parameters: Sequence of numpy arrays in ``get_parameters`` order.
            str: Caller tag (e.g. "Fit" / "Evaluate") written to the logs.
                Shadows the builtin ``str``; the name is kept for
                interface compatibility.
        """
        new_state_dict = OrderedDict()
        filename2 = f"2_processed_parameters_{self._timestamp()}.txt"
        # Append mode: two calls within the same second share a filename, and
        # the previous mode "x" raised FileExistsError outside any try block.
        with open(filename2, "a") as f:
            f.write(f"Coming from {str}\n")
            f.write(f"Current state_dict length: {len(self.state_dict_items)}\n")
            f.write(f"Received parameters length: {len(parameters)}\n\n")
            # Pair reference keys with incoming parameters by position.
            for (k, v), new_param in zip(self.state_dict_items, parameters):
                try:
                    # torch.tensor always copies its input, so no extra
                    # detach()/clone() is needed.
                    new_tensor = torch.tensor(new_param, dtype=v.dtype, device=v.device)
                    f.write(f"Key: {k}\n")
                    f.write(f"Original Tensor (v) Shape: {v.shape}, Device: {v.device}\n")
                    f.write(f"New Tensor Shape: {new_tensor.shape}, Device: {new_tensor.device}\n")
                    f.write(f"Original Tensor Values (first 10): {v.flatten()[:10]}\n")
                    f.write(f"New Tensor Values (first 10): {new_tensor.flatten()[:10]}\n\n")
                    # Only update if shapes match
                    if v.shape == new_tensor.shape:
                        new_state_dict[k] = new_tensor
                    else:
                        print(f"Shape mismatch for {k}: {v.shape} vs {new_tensor.shape}")
                except Exception as e:
                    print(f"Error processing parameter {k}: {e}")

        try:
            print(f"SET PARAMS: {str}")
            # Log the keys the in-memory model currently has.
            filename3 = f"2_keys_{self._timestamp()}.txt"
            with open(filename3, "a") as f:
                current_state = self.model.model.state_dict()
                f.write(f"Coming from {str}\n")
                f.write(f"model.model.state_dict length: {len(current_state.items())}\n")
                for key in current_state:
                    f.write(f"{key}\n")
            # Log the keys about to be loaded.
            filename4 = f"2_newstate_dict_keys_{self._timestamp()}.txt"
            with open(filename4, "a") as f:
                f.write(f"Coming from {str}\n")
                f.write(f"new_state_dict length: {len(new_state_dict)}\n")
                for key in new_state_dict:
                    f.write(f"{key}\n")
            print("!!!!!!Loading New state dict in model.model!!!!!!")
            self.model.model.load_state_dict(new_state_dict, strict=True)
        except Exception as e:
            # Best-effort: report and keep the current weights.
            print(f"Error loading state dict: {e}")

    def fit(self, parameters, config):
        """Load the global parameters, train one local epoch and return the weights.

        Returns:
            ``(parameters, num_examples, metrics)`` as Flower requires. If the
            dataset yaml is missing, training is skipped and the received
            weights are returned with ``num_examples == 0``.
        """
        if self._restore_model_if_validated():
            print("NEXT ROUND STARTED\n")
        print(f"Entered Fit, len of model.model keys: {len(self.model.model.state_dict().items())}")
        self.set_parameters(parameters, "Fit")
        if not os.path.isfile(self.corrected_data_yaml):
            print(f"Corrected data config {self.corrected_data_yaml} not found. Skipping training.")
            # Flower needs a tuple here; a bare `return` (None) breaks the round.
            return self.get_parameters(config), 0, {}
        print("Training on corrected data...")
        self.model.train(
            data=self.corrected_data_yaml,  # Path to your .yaml config file
            epochs=1,  # Number of epochs
            batch=16,  # Batch size
            imgsz=640,  # Image size (640 for resizing and letterboxing)
            project='local_train_output4',  # Output directory for storing results
            name='yolov11_training',  # Folder name to store the results
            exist_ok=True,  # Allow overwriting of existing output
            # cache=True,  # Cache images for faster training
            device='cuda' if torch.cuda.is_available() else 'cpu',  # Use GPU if available
            workers=0,
        )
        print("Training complete. Model updated.")
        print("CHECKING FIT AFTER TRAIN", len(self.model.model.state_dict().items()))
        # Assumes each training sample contributes two directory entries
        # (e.g. image + label) — TODO confirm the dataset layout.
        num_train_samples = len(os.listdir(train_path)) // 2
        return self.get_parameters(config), num_train_samples, {}

    def evaluate(self, parameters, config):
        """Load the global parameters, run YOLO validation and report metrics.

        Returns:
            ``(loss, num_examples, metrics)`` where metrics holds mAP50 and mean recall.
        """
        # A previous val() may have left the model altered, e.g. when this client
        # was sampled for evaluation but not for fit; reload the full model so
        # the incoming parameters can be loaded.
        self._restore_model_if_validated()
        # Update the local YOLO model with the parameters sent from the central server.
        self.set_parameters(parameters, "Evaluate")
        print("Evaluating model using YOLO...")
        print(f"BEFORE EVALUATE, len of model.model keys: {len(self.model.model.state_dict().items())}")
        # Snapshot the full model before val() alters it (overwrites any earlier snapshot).
        os.makedirs(self.BEFORE_VAL_DIR, exist_ok=True)
        self.model.save(self.BEFORE_VAL_PATH)
        print(f"Full model saved to {self.BEFORE_VAL_PATH} before validation!")
        self.model_saved_before_validation = True
        results = self.model.val(data=self.corrected_data_yaml, batch=16)
        print(f"AFTER EVALUATE, len of model.model keys: {len(self.model.model.state_dict().items())}")
        # Extract relevant metrics, falling back when an attribute is absent.
        loss = results.box.loss.item() if hasattr(results.box, 'loss') else 0.0
        mAP50 = results.box.map50 if hasattr(results.box, 'map50') else 0.0
        if hasattr(results.box, 'seen'):
            num_samples = results.box.seen
        else:
            num_samples = len(os.listdir(val_path)) / 2
        mean_recall = results.box.mr if hasattr(results.box, 'mr') else 0
        print(f"Evaluation Complete: Loss={loss}, mAP50={mAP50}, Mean recall: {mean_recall}, Samples={num_samples}")
        return loss, int(num_samples), {'mAP50': mAP50, 'Mean recall:': mean_recall}
# START FLOWER CLIENT
# Guarded so importing this module does not connect to a server.
if __name__ == "__main__":
    fl.client.start_numpy_client(
        server_address="0.0.0.0:8080",
        client=YOLOClient(model_path=model_path, corrected_data_yaml=corrected_data_yaml),
    )
@wkqco33 Your code seems to have the same issue, caused by the `self.model.val()` line. I don't know how upgrading the flwr library helped you resolve it, because the root cause is Ultralytics' `val()`. Maybe you can enlighten me about this. You could also add my solution to your previous code and see if it works for you as well.
I hope I have explained the solution properly :)
Thanks for reading. Happy learning!