How can I implement a YOLO model using the Flower framework?

How to Pass Weights as Parameters in Flower?
I’m trying to use the Flower framework to train a YOLO model in a federated learning setting. I’m having trouble figuring out how to properly pass the model weights as parameters between the server and clients.

Here’s what I’ve tried so far:

  • I’ve converted the YOLO model weights to a list of NumPy arrays.

However, I’m encountering errors during training, and I suspect it’s related to how the weights are being handled.

Could someone provide guidance or examples on how to correctly pass YOLO model weights as parameters in Flower? Any help would be greatly appreciated!

Hi @wkqco33, welcome to the Flower forum!

Please take a look at the pytorch-quickstart example. Concretely, see how:

  • First, the ServerApp initializes the global model from a randomly initialized model (it could also be pretrained). This is done in this line. The get_weights function simply extracts the state_dict of the PyTorch model and converts it into a list of NumPy arrays. That’s the preferred way of sharing model parameters between ServerApp and ClientApp objects.
  • Second, on the ClientApp side, that exact same list of NumPy arrays is delivered as the parameters input argument to the fit() and evaluate() methods. Note that both methods call set_weights, a utility function that constructs a valid state_dict out of a list of NumPy arrays and applies it to the model local to the ClientApp.

That’s the basic way of communicating model parameters in Flower. Note that the specific implementation of the get_weights() and set_weights() functions depends on the ML framework and the type of model you use. More exotic models might need adjustments.
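As a minimal sketch of the two helpers described above, assuming a plain PyTorch nn.Module: the tiny Conv + BatchNorm model below is hypothetical and only stands in for a real network, and the length check is an addition for illustration rather than something taken from the quickstart.

```python
from collections import OrderedDict
from typing import List

import numpy as np
import torch
from torch import nn


def get_weights(model: nn.Module) -> List[np.ndarray]:
    """Flatten the state_dict into a list of NumPy arrays, in state_dict order."""
    return [val.cpu().numpy() for val in model.state_dict().values()]


def set_weights(model: nn.Module, parameters: List[np.ndarray]) -> None:
    """Pair the model's own keys with the received arrays by position and load them."""
    keys = list(model.state_dict().keys())
    if len(keys) != len(parameters):
        # Positional pairing only makes sense when both sides have the same layout.
        raise ValueError(f"expected {len(keys)} arrays, got {len(parameters)}")
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
    model.load_state_dict(state_dict, strict=True)


def make_model() -> nn.Module:
    # Hypothetical toy model, used only to demonstrate the round trip.
    return nn.Sequential(nn.Conv2d(3, 4, 3, bias=False), nn.BatchNorm2d(4))


source = make_model()
target = make_model()
set_weights(target, get_weights(source))  # target now holds source's weights
```

The important property is that the list carries no names, so the sender and the receiver must build the list from models with an identical state_dict layout.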

Let us know if you succeed!!

Here is the server code:

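from logging import INFO
from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
from flwr.common import FitRes, Metrics, Parameters, Scalar, ndarrays_to_parameters
from flwr.common.logger import log
from flwr.server import ServerConfig
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg

from utils import get_weights, load_model  # helpers shown below (import path assumed)


class SaveModelStrategy(FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:

        # Delegate to FedAvg; nothing is saved to disk yet.
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )

        return aggregated_parameters, aggregated_metrics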


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    return {"accuracy": sum(accuracies) / sum(examples)}  # type: ignore


def server_fn():
    model = load_model()
    parameters = ndarrays_to_parameters(get_weights(model))

    log(INFO, "Initial parameters")

    strategy = SaveModelStrategy(
        fraction_fit=0.5,
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=parameters,
        evaluate_metrics_aggregation_fn=weighted_average,
    )

    num_rounds = 3
    config = ServerConfig(num_rounds=num_rounds)

    return strategy, config


if __name__ == "__main__":
    strategy, config = server_fn()
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        strategy=strategy,
        config=config,
    )

And here is the client code:

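import os
from logging import INFO

import flwr as fl
from flwr.client import NumPyClient
from flwr.common import NDArrays
from flwr.common.logger import log
from ultralytics import YOLO
from ultralytics.utils.metrics import DetMetrics

from utils import get_weights, load_model, set_weights  # helpers shown below (import path assumed)


class RobotClient(NumPyClient):
    def __init__(self, data, epochs):
        self.model: YOLO = load_model()
        self.data = data
        self.epochs = epochs

    def fit(self, parameters: NDArrays, config):
        # Apply the global weights, then train locally with Ultralytics.
        set_weights(self.model, parameters)
        results: DetMetrics | None = self.model.train(
            data=self.data,
            epochs=self.epochs,
        )
        if results is not None:
            log(INFO, f"Results: {results.box.map}")

        # 10 is a hard-coded placeholder for the number of training examples.
        return get_weights(self.model), 10, {}

    def evaluate(self, parameters: NDArrays, config):
        set_weights(self.model, parameters)
        metrics: DetMetrics = self.model.val()
        accuracy = float(metrics.box.map)  # mAP50-95, reported as "accuracy"
        # fitness is a higher-is-better score; it is used here in place of a loss.
        loss = float(metrics.fitness)

        return loss, 10, {"accuracy": accuracy}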


def client_fn():
    data = "coco8.yaml"
    epochs = 1

    return RobotClient(data, epochs).to_client()


if __name__ == "__main__":
    server_ip = os.getenv("SERVER_IP", "localhost")
    server_port = os.getenv("SERVER_PORT", "8080")

    log(INFO, f"Server IP: {server_ip}:{server_port}")

    fl.client.start_client(
        server_address=f"{server_ip}:{server_port}",
        client=client_fn(),
    )

Following your advice, I changed the get_weights and set_weights functions:

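from collections import OrderedDict

import torch
from flwr.common import NDArrays
from ultralytics import YOLO


def load_model():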
    model = YOLO("yolo11n.pt")
    return model

def get_weights(model: YOLO) -> NDArrays:
    return [val.cpu().numpy() for _, val in model.state_dict().items()]


def set_weights(model, parameters):
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=False)
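Because set_weights pairs arrays with keys purely by position, and strict=False only tolerates missing or unexpected keys (shape disagreements still raise), a small pre-check could make a layout difference visible before load_state_dict runs. The following is a sketch under assumptions: describe_layout_problems is a hypothetical helper, not part of Flower or Ultralytics, and the two toy models only imitate a sender and a receiver with different layouts.

```python
from typing import List

import numpy as np
import torch
from torch import nn


def describe_layout_problems(model: nn.Module, parameters: List[np.ndarray]) -> List[str]:
    """Compare received arrays against the model's state_dict, position by position."""
    problems: List[str] = []
    items = list(model.state_dict().items())
    if len(items) != len(parameters):
        problems.append(
            f"count mismatch: model has {len(items)} tensors, received {len(parameters)}"
        )
    for (key, tensor), array in zip(items, parameters):
        if tuple(tensor.shape) != tuple(np.shape(array)):
            problems.append(
                f"{key}: model {tuple(tensor.shape)} vs received {tuple(np.shape(array))}"
            )
    return problems


# Hypothetical toy models: the sender keeps Conv and BatchNorm separate,
# while the receiver only has a convolution with a bias.
sender = nn.Sequential(nn.Conv2d(3, 16, 3, bias=False), nn.BatchNorm2d(16))
receiver = nn.Sequential(nn.Conv2d(3, 16, 3, bias=True))

received = [v.cpu().numpy() for v in sender.state_dict().values()]
problems = describe_layout_problems(receiver, received)
for line in problems:
    print(line)
```

Calling such a check at the top of set_weights and raising when the list is non-empty would turn a long PyTorch error into one short message naming the first disagreement.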

But training fails with the following error messages, so I can’t see any results.

Traceback (most recent call last):
  File "/home/seoyc/Project/work/mlops/flower_study/src/flower_yolo/client.py", line 54, in <module>
    fl.client.start_client(
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/app.py", line 180, in start_client
    start_client_internal(
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/app.py", line 543, in start_client_internal
    raise ex
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/app.py", line 536, in start_client_internal
    reply_message = client_app(message=message, context=context)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/client_app.py", line 143, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/client_app.py", line 126, in ffn
    out_message = handle_legacy_message_from_msgtype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/message_handler/message_handler.py", line 129, in handle_legacy_message_from_msgtype
    fit_res = maybe_call_fit(
              ^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/client.py", line 255, in maybe_call_fit
    return client.fit(fit_ins)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/flwr/client/numpy_client.py", line 259, in _fit
    results = self.numpy_client.fit(parameters, ins.config)  # type: ignore
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/seoyc/Project/work/mlops/flower_study/src/flower_yolo/client.py", line 22, in fit
    set_weights(self.model, parameters)
  File "/home/seoyc/Project/work/mlops/flower_study/src/flower_yolo/utils.py", line 25, in set_weights
    model.load_state_dict(state_dict, strict=False)
  File "/home/seoyc/Project/work/mlops/flower_study/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for YOLO:
        While copying the parameter named "model.model.0.conv.weight", whose dimensions in the model are torch.Size([16, 3, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 3, 3, 3]), an exception occurred : ('Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.',).
        While copying the parameter named "model.model.0.conv.bias", whose dimensions in the model are torch.Size([16]) and whose dimensions in the checkpoint are torch.Size([16]), an exception occurred : ('Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.',).
        size mismatch for model.model.1.conv.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32, 16, 3, 3]).
        size mismatch for model.model.1.conv.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32]).
        size mismatch for model.model.2.cv1.conv.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([32, 32, 1, 1]).
        size mismatch for model.model.2.cv1.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([32]).
        size mismatch for model.model.2.cv2.conv.weight: copying a param with shape torch.Size([32, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 48, 1, 1]).
        size mismatch for model.model.2.cv2.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for model.model.2.m.0.cv1.conv.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([8, 16, 3, 3]).
        size mismatch for model.model.2.m.0.cv1.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([8]).
        size mismatch for model.model.2.m.0.cv2.conv.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16, 8, 3, 3]).
        size mismatch for model.model.2.m.0.cv2.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for model.model.3.conv.weight: copying a param with shape torch.Size([32, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
        size mismatch for model.model.3.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for model.model.4.cv1.conv.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64, 64, 1, 1]).
        size mismatch for model.model.4.cv1.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for model.model.4.cv2.conv.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128, 96, 1, 1]).
        size mismatch for model.model.4.cv2.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for model.model.4.m.0.cv1.conv.weight: copying a param with shape torch.Size([64, 48, 1, 1]) from checkpoint, the shape in current model is torch.Size([16, 32, 3, 3]).
        size mismatch for model.model.4.m.0.cv1.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for model.model.4.m.0.cv2.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32, 16, 3, 3]).
        size mismatch for model.model.4.m.0.cv2.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
        size mismatch for model.model.5.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
        size mismatch for model.model.5.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for model.model.6.cv1.conv.weight: copying a param with shape torch.Size([8, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 1, 1]).
        size mismatch for model.model.6.cv1.conv.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for model.model.6.cv2.conv.weight: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([128, 192, 1, 1]).
        size mismatch for model.model.6.cv2.conv.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([128]).
... Skip ....
size mismatch for model.model.23.cv3.0.0.0.conv.weight: copying a param with shape torch.Size([64, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1, 3, 3]).
        While copying the parameter named "model.model.23.cv3.0.0.0.conv.bias", whose dimensions in the model are torch.Size([64]) and whose dimensions in the checkpoint are torch.Size([64]), an exception occurred : ('Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.',).
        size mismatch for model.model.23.cv3.0.0.1.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 64, 1, 1]).
        size mismatch for model.model.23.cv3.0.0.1.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.0.1.0.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 1, 3, 3]).
        size mismatch for model.model.23.cv3.0.1.0.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.0.1.1.conv.weight: copying a param with shape torch.Size([64, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.0.1.1.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.0.2.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.0.2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.1.0.0.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128, 1, 3, 3]).
        size mismatch for model.model.23.cv3.1.0.0.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for model.model.23.cv3.1.0.1.conv.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([80, 128, 1, 1]).
        size mismatch for model.model.23.cv3.1.0.1.conv.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.1.1.0.conv.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([80, 1, 3, 3]).
        size mismatch for model.model.23.cv3.1.1.0.conv.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.1.1.1.conv.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.1.1.1.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.1.2.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.1.2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.2.0.0.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256, 1, 3, 3]).
        size mismatch for model.model.23.cv3.2.0.0.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for model.model.23.cv3.2.0.1.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 256, 1, 1]).
        size mismatch for model.model.23.cv3.2.0.1.conv.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.2.1.0.conv.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 1, 3, 3]).
        size mismatch for model.model.23.cv3.2.1.0.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.2.1.1.conv.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.2.1.1.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.cv3.2.2.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
        size mismatch for model.model.23.cv3.2.2.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([80]).
        size mismatch for model.model.23.dfl.conv.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 16, 1, 1]).

I really appreciate you and your support!!!
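Reading the "checkpoint" shapes in the log in order gives a kernel, then four vectors of the same length, then a scalar, then the next kernel. That pattern would be consistent with the received list holding six tensors per block (a bias-free convolution followed by BatchNorm weight, bias, running mean, running variance and batch counter), while the client's state_dict exposes only conv.weight and conv.bias per block. This is an inference from the log, not something confirmed in the thread. The pure-Python sketch below replays that positional pairing with the first three blocks; the bn.* key names are assumed for illustration.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]
Entry = Tuple[str, Shape]


def unfused_block(name: str, out_c: int, in_c: int, k: int) -> List[Entry]:
    """Entries of a bias-free conv followed by batch norm (key names are assumed)."""
    return [
        (f"{name}.conv.weight", (out_c, in_c, k, k)),
        (f"{name}.bn.weight", (out_c,)),
        (f"{name}.bn.bias", (out_c,)),
        (f"{name}.bn.running_mean", (out_c,)),
        (f"{name}.bn.running_var", (out_c,)),
        (f"{name}.bn.num_batches_tracked", ()),
    ]


def fused_block(name: str, out_c: int, in_c: int, k: int) -> List[Entry]:
    """Entries of a single convolution that carries its own bias."""
    return [
        (f"{name}.conv.weight", (out_c, in_c, k, k)),
        (f"{name}.conv.bias", (out_c,)),
    ]


def positional_mismatches(
    receiver: List[Entry], sender: List[Entry]
) -> List[Tuple[str, Shape, Shape]]:
    """Pair entries by position, as zip does, and keep the shape disagreements."""
    return [
        (r_key, s_shape, r_shape)
        for (r_key, r_shape), (_, s_shape) in zip(receiver, sender)
        if r_shape != s_shape
    ]


# Block shapes copied from the first lines of the error log.
blocks = [
    ("model.model.0", 16, 3, 3),
    ("model.model.1", 32, 16, 3),
    ("model.model.2.cv1", 32, 32, 1),
]
sender = [entry for block in blocks for entry in unfused_block(*block)]
receiver = [entry for block in blocks for entry in fused_block(*block)]

mismatches = positional_mismatches(receiver, sender)
for key, received_shape, expected_shape in mismatches:
    print(f"size mismatch for {key}: received {received_shape}, model expects {expected_shape}")
```

If the real models behave the same way, comparing len(model.state_dict()) and the key lists on the server and on the client right before set_weights should show the difference directly. The separate "Inplace update to inference tensor" message is what PyTorch raises when a tensor created under torch.inference_mode() is modified in place outside of it. If the client's model object was altered by an earlier validation or prediction call (an unverified guess), recreating it with load_model() before applying the weights might be worth trying.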