Customizing Flower Strategies

I am designing a YOLO-based architecture using ultralytics and flwr. In this implementation I plan to keep the non-trainable BN parameters local to each client (meaning that client 1's non-trainable parameters are not changed by aggregation). I am facing a challenge while implementing this.
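Roughly, the split I have in mind looks like the sketch below (illustrative only; BN_BUFFER_SUFFIXES and split_state_dict are hypothetical names, not part of my actual code):

    import torch

    # Suffixes of the BatchNorm buffers that should stay local on each client
    BN_BUFFER_SUFFIXES = ("running_mean", "running_var", "num_batches_tracked")

    def split_state_dict(state_dict: dict[str, torch.Tensor]):
        """Split a state_dict into a globally aggregated part and a local BN part."""
        global_part = {
            k: v for k, v in state_dict.items() if not k.endswith(BN_BUFFER_SUFFIXES)
        }
        local_part = {
            k: v for k, v in state_dict.items() if k.endswith(BN_BUFFER_SUFFIXES)
        }
        return global_part, local_part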

Proposed Solution:

I am introducing a new field in the message named untrain_arrays, of type Iterable[ArrayRecord]. It represents each client's untrainable parameters; say untrain_arrays[1] represents the parameters of the first client, and so on.

sample code:

    def configure_train(
        self,
        server_round: int,
        arrays: ArrayRecord,
        untrain_arrays: Iterable[ArrayRecord],
        config: ConfigRecord,
        grid: Grid,
    ) -> Iterable[Message]:
        if self.fraction_train == 0.0:
            return []
        # Sample nodes
        num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train)
        sample_size = max(num_nodes, self.min_train_nodes)
        node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size)
        log(
            INFO,
            "configure_train: Sampled %s nodes (out of %s)",
            len(node_ids),
            len(num_total),
        )
        # Always inject current server round
        config["server-round"] = server_round

        # Construct messages
        record = RecordDict(
            {
                self.arrayrecord_key: arrays,
                self.configrecord_key: config,
                self.untrain_record_key: untrain_arrays,
            }
        )
        return self._construct_messages(record, node_ids, MessageType.TRAIN)

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[Parameters], dict[str, Scalar], Optional[list[ArrayRecord]]]:
        """Aggregate model weights using weighted average and store checkpoint."""
        aggregated_parameters, aggregated_metrics = super().aggregate_train(
            server_round, replies,
        )

        valid_replies, _ = super()._check_and_log_replies(replies, is_train=True)

        untrain_arrays = None
        if valid_replies:
            untrain_arrays = [msg.content["untrain_arrays"] for msg in valid_replies]

        if aggregated_parameters is not None:
            net = self.load_and_update_model(aggregated_parameters)
            full_parameters = {k: val.detach() for k, val in net.model.state_dict().items()}
            return full_parameters, aggregated_metrics, untrain_arrays

        return aggregated_parameters, aggregated_metrics, untrain_arrays

Question:

  1. The main question: while constructing the evaluate message in configure_evaluate, how can I include an Iterable[ArrayRecord] in the content? Is there any way to handle this?

Hi @raghuram,

Yes!

The simplest approach is to override the configure_evaluate method and store each ArrayRecord separately in the Message.content dictionary. Since Message.content is a RecordDict, you can safely attach ArrayRecord instances to it using unique keys. For example:

def configure_evaluate(
    self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
) -> Iterable[Message]:
    """Configure the next round of federated evaluation."""
    node_ids = list(grid.get_node_ids())

    # Say, you somehow get the arrays from clients
    client_arrays = [arrays_0, arrays_1, arrays_2]

    # Construct message content
    record = RecordDict(
        {
            self.arrayrecord_key: arrays,
            self.configrecord_key: config,
            "client_arrays_0": client_arrays[0],
            "client_arrays_1": client_arrays[1],
            "client_arrays_2": client_arrays[2],
        }
    )
    return self._construct_messages(record, node_ids, MessageType.EVALUATE)

Here, each untrainable parameter set (as an ArrayRecord) is given its own entry in Message.content.
This allows you to later access them in ClientApp like:

untrain_arrays = [msg.content[f"client_arrays_{i}"] for i in range(3)]

You can treat the RecordDict as a dictionary of records, giving you flexibility to store and retrieve any serializable info.

Note:
Message.content is a RecordDict, which supports ArrayRecord, MetricRecord, and ConfigRecord objects. It’s designed to serialize and transmit such records safely across ServerApp and ClientApp.
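For instance, a minimal sketch (the key names and values here are arbitrary, and the ArrayRecord is built from a list of NumPy arrays):

import numpy as np
from flwr.common import ArrayRecord, ConfigRecord, MetricRecord, RecordDict

# A RecordDict can hold several records side by side, each under its own key
record = RecordDict(
    {
        "global_arrays": ArrayRecord([np.zeros(3)]),
        "train_config": ConfigRecord({"lr": 0.01, "server-round": 1}),
        "last_metrics": MetricRecord({"loss": 0.42}),
    }
)

# Records are retrieved by key, just like a regular dictionary
arrays = record["global_arrays"]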

Here is the modified Strategy file:

""" strategy for federated learning """
import io
import os
import time
from logging import INFO
from pathlib import Path
from typing import Optional, Callable, Iterable

from flwr.common import (
    Parameters,
    Scalar,
    log,
    RecordDict,
    ConfigRecord,
    ArrayRecord,
    Message,
    MessageType,
    MetricRecord,
)
from flwr.server import Grid
from flwr.serverapp.strategy import FedAvg, Result
from flwr.serverapp.strategy.strategy_utils import (
    log_strategy_start_info,
    sample_nodes,
)

from ultralytics import YOLO

class CustomFedAvg(FedAvg):
    """
    FedAvg class that works with YOLO architecture
    """
    def __init__(
        self,
        *,
        fraction_train=1.0,
        fraction_evaluate=1.0,
        min_train_nodes=2,
        min_evaluate_nodes=2,
        min_available_nodes=2,
    ):
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
        )
        BASE_LIB_PATH = os.path.abspath(os.path.dirname(__file__))
        BASE_DIR_PATH = os.path.dirname(BASE_LIB_PATH)
        self.model_path = Path(BASE_DIR_PATH) / "yolo_config" / "yolo11n.yaml"
        self.untrain_record_key = "untrain_arrays"
        self.untrain_arrays: dict[int, ArrayRecord] = {}
    
    def __repr__(self):
        rep = f"FedAveraging merged with ultralytics, accept failures = {self.accept_failures}"
        return rep

    def _construct_messages(
        self,
        record: RecordDict,
        node_ids: list[int],
        message_type: str,
    ) -> Iterable[Message]:
        messages = []
        for node_id in node_ids:
            # Give each node its own copy of the shared record; otherwise every
            # message would reference the same RecordDict and the last node's
            # untrainable arrays would overwrite the others
            node_record = RecordDict(dict(record.items()))
            node_record[self.untrain_record_key] = self.untrain_arrays.get(
                node_id,
                self.untrain_arrays[0],  # fall back to the initial untrainable parameters
            )
            message = Message(
                content=node_record,
                message_type=message_type,
                dst_node_id=node_id,
            )
            messages.append(message)
        return messages
    
    def configure_train(
        self,
        server_round: int,
        arrays: ArrayRecord,
        config: ConfigRecord,
        grid: Grid,
    ) -> Iterable[Message]:
        if self.fraction_train == 0.0:
            return []
        # Sample nodes
        num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train)
        sample_size = max(num_nodes, self.min_train_nodes)
        node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size)
        log(
            INFO,
            "configure_train: Sampled %s nodes (out of %s)",
            len(node_ids),
            len(num_total),
        )
        # Always inject current server round
        config["server-round"] = server_round

        # Construct messages
        record = RecordDict(
            {self.arrayrecord_key: arrays, self.configrecord_key: config}
        )
        return self._construct_messages(record, node_ids, MessageType.TRAIN)


    def load_and_update_model(self, aggregated_state) -> YOLO:
        # Rebuild the model from its config and pretrained weights, then overlay the aggregated state
        net = YOLO(self.model_path).load('yolov11n.pt')
        state_dict = net.model.state_dict().copy()
        state_dict.update(aggregated_state)
        net.model.load_state_dict(state_dict)
        return net

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message]
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint."""
        aggregated_parameters, aggregated_metrics = super().aggregate_train(
            server_round, replies
        )

        valid_replies, _ = super()._check_and_log_replies(replies, is_train=True)

        for msg in valid_replies:
            nid = msg.metadata.src_node_id
            client_untrain = msg.content[self.untrain_record_key]
            self.untrain_arrays[nid] = client_untrain

        if aggregated_parameters is not None:
            net = self.load_and_update_model(aggregated_parameters)
            # Keep BN buffers out of the broadcast state so they stay local
            full_parameters = {
                k: val.detach()
                for k, val in net.model.state_dict().items()
                if not k.endswith(('running_mean', 'running_var', 'num_batches_tracked'))
            }
            return full_parameters, aggregated_metrics
            
        return aggregated_parameters, aggregated_metrics

    def configure_evaluate(
        self,
        server_round: int,
        arrays: ArrayRecord,
        config: ConfigRecord,
        grid: Grid,
    ) -> Iterable[Message]:
        if self.fraction_evaluate == 0.0:
            return []

        # Sample nodes
        num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_evaluate)
        sample_size = max(num_nodes, self.min_evaluate_nodes)
        node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size)
        log(
            INFO,
            "configure_evaluate: Sampled %s nodes (out of %s)",
            len(node_ids),
            len(num_total),
        )

        # Always inject current server round
        config["server-round"] = server_round

        # Construct messages
        record = RecordDict(
            {self.arrayrecord_key: arrays, self.configrecord_key: config}
        )
        return self._construct_messages(record, node_ids, MessageType.EVALUATE)

    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        untrainable_parameters: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        log(INFO, "Starting %s strategy:", self.__class__.__name__)
        log_strategy_start_info(
            num_rounds, initial_arrays, train_config, evaluate_config
        )
        self.summary()
        log(INFO, "")

        # Initialize if None
        train_config = ConfigRecord() if train_config is None else train_config
        evaluate_config = ConfigRecord() if evaluate_config is None else evaluate_config
        result = Result()

        t_start = time.time()
        # Evaluate starting global parameters
        if evaluate_fn:
            res = evaluate_fn(0, initial_arrays)
            log(INFO, "Initial global evaluation results: %s", res)
            if res is not None:
                result.evaluate_metrics_serverapp[0] = res
            
        arrays = initial_arrays
        self.untrain_arrays[0] = untrainable_parameters

        for current_round in range(1, num_rounds+1):
            log(INFO, "")
            log(INFO, "[ROUND %s %s]", current_round, num_rounds)

            train_replies = grid.send_and_receive(
                messages=self.configure_train(
                    current_round,
                    arrays,
                    train_config,
                    grid,
                ),
                timeout=timeout,
            )

            agg_arrays, agg_train_metrics = self.aggregate_train(
                current_round, train_replies
            )

            if agg_arrays is not None:
                result.arrays = agg_arrays
                arrays = agg_arrays
            if agg_train_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_train_metrics)
                result.train_metrics_clientapp[current_round] = agg_train_metrics

            evaluate_replies = grid.send_and_receive(
                messages=self.configure_evaluate(
                    current_round,
                    arrays,
                    evaluate_config,    
                    grid,
                ),
                timeout=timeout,
            )

            agg_evaluate_metrics = self.aggregate_evaluate(
                current_round,
                evaluate_replies,
            )

            # Log evaluation metrics and append to history
            if agg_evaluate_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_evaluate_metrics)
                result.evaluate_metrics_clientapp[current_round] = agg_evaluate_metrics

            # Centralized evaluation
            if evaluate_fn:
                log(INFO, "Global evaluation")
                res = evaluate_fn(current_round, arrays)
                log(INFO, "\t└──> MetricRecord: %s", res)
                if res is not None:
                    result.evaluate_metrics_serverapp[current_round] = res

        log(INFO, "")
        log(INFO, "Strategy execution finished in %.2fs", time.time() - t_start)
        log(INFO, "")
        log(INFO, "Final results:")
        log(INFO, "")
        for line in io.StringIO(str(result)):
            log(INFO, "\t%s", line.strip("\n"))
        log(INFO, "")

        return result

Error:

flwr.serverapp.exception.InconsistentMessageReplies: Expected exactly one ArrayRecord in replies. Skipping aggregation.

Can anyone give me an idea of how to solve this problem?

First, note that the error is coming from the training round, not evaluation, since that is where you send multiple ArrayRecords. But in any case, the main reason you're seeing it is that you are calling super()._check_and_log_replies, which enforces the consistency checks of the built-in FedAvg. That helper verifies that each reply contains exactly one ArrayRecord, which is a requirement for the default FedAvg aggregation logic.

In your case, you're intentionally sending additional records (untrain_arrays), so this check will always fail. The fix is simply to skip it: remove the call to super()._check_and_log_replies(...) and collect the extra records yourself, as sketched below. After that, aggregation should proceed normally.
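For reference, here is a minimal sketch of what aggregate_train could look like after dropping the check (this assumes the replies can be materialized into a list and that failed replies are detected via Message.has_error(); adapt as needed):

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ):
        replies = list(replies)  # materialize so we can iterate more than once

        # Let the built-in FedAvg aggregate the trainable arrays
        aggregated_parameters, aggregated_metrics = super().aggregate_train(
            server_round, replies
        )

        # Collect each client's untrainable arrays ourselves, skipping
        # replies that carry an error instead of content
        for msg in replies:
            if msg.has_error() or self.untrain_record_key not in msg.content:
                continue
            self.untrain_arrays[msg.metadata.src_node_id] = msg.content[
                self.untrain_record_key
            ]

        if aggregated_parameters is not None:
            net = self.load_and_update_model(aggregated_parameters)
            # Keep BN buffers out of the broadcast state so they stay local
            full_parameters = {
                k: val.detach()
                for k, val in net.model.state_dict().items()
                if not k.endswith(('running_mean', 'running_var', 'num_batches_tracked'))
            }
            return full_parameters, aggregated_metrics

        return aggregated_parameters, aggregated_metrics

This keeps the built-in weighted aggregation for the trainable arrays, while the untrainable records are handled with plain dictionary bookkeeping.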
