Fine-tuning LLMs with Flower

Hi! I’m a newbie working on fine-tuning LLMs for a chatbot in Italian. I have followed this guide. The training goes (apparently) well, but when I check the answers of the model, they are not similar to the training set. Even if I ask exactly the exact same questions from the training dataset. The following are the options I have tried (and that have not worked):

  1. Use different models from HF (i.e., google/gemma-2b-it, sapienzanlp/Minerva-7B-instruct-v1.0, swap-uniba/LLaMAntino-2-7b-hf-ITA, microsoft/phi-2, mistralai/Mistral-7B-Instruct-v0.3, and many others similarly)
  2. Increase the number of communications rounds (up to 15, since with the server I use, I can’t go further)
  3. Increase the number of local epochs (up to 3)
  4. Modify the learning rate

Has anybody faced the same issue? Or somebody could give me some hints on what I am doing wrong?

The following is the code I employed:


import os
import warnings
warnings.filterwarnings("ignore")

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Path setup
root_path = '/my/path/'

%cd $root_path

code_path = root_path + 'code'
models_path = root_path + 'models/Daniel/federated_models_Flower_TF'
output_path = root_path + 'output/Daniel/federated_outputs_Flower_TF'
data_path = root_path + 'data'
index_path = root_path + 'index_storage'

import flwr as fl


from huggingface_hub import login
# HuggingFace login
os.environ["HF_KEY"] = "my_hf_key"
login(token=os.environ.get('HF_KEY'), add_to_git_credential=True)

def print_config(config: DictConfig):
    print(OmegaConf.to_yaml(config))

from omegaconf import DictConfig, OmegaConf
import yaml

def print_config(config: DictConfig):
    print(OmegaConf.to_yaml(config))

with open("config/federated_llama.yml", "r") as file:
    config_dict = yaml.safe_load(file)
    cfg = OmegaConf.create(config_dict)  # Convert directly to OmegaConf DictConfig

# Temporal
cfg.model.name = "google/gemma-2b-it" # RUNS VERY FAST ALGOUGHT DOESN-T PROVIDE VERY ACCURATE RESULTS, STILL IT IS GOOD

# cfg.model.name = "meta-llama/Llama-3.2-1B-Instruct"
# cfg.model.name = "swap-uniba/LLaMAntino-2-7b-hf-ITA" # RUNS OK, BUT TAKES TOO LONG
# cfg.model.name = "microsoft/phi-2" # RUNS FASTER BUT NOT GOOD RESULTS
# cfg.model.name = "meta-llama/Llama-3.2-1B" # RUNS VERY FAST ALGOUGHT DOESN-T PROVIDE VERY ACCURATE RESULTS
# cfg.model.name = "mistralai/Mistral-7B-Instruct-v0.3" # Takes long and produce not accurate results
# cfg.model.name = "swap-uniba/LLaMAntino-3-ANITA-8B-Inst-DPO-ITA" # NOT WORKING
# cfg.model.name = "osiria/llama2-13b-italian" # NOT WORKING
# cfg.model.name = "meta-llama/Llama-3.1-8B-Instruct"
# cfg.model.name = "sapienzanlp/Minerva-7B-instruct-v1.0"
# cfg.model.name = "mii-llm/maestrale-chat-v0.1-alpha-sft"
# cfg.model.name = "microsoft/phi-3-mini-4k-instruct"
# cfg.model.name = "google/gemma-2b-it"

cfg.flower.client_resources.num_cpus = 2
cfg.flower.fraction_fit = 1.0
cfg.flower.num_rounds = 5
cfg.train.training_arguments.report_to = "none"



cfg.model.quantization = 4
cfg.model.bnb_4bit_quant_type = "nf4"
cfg.model.use_flash_attention_2 = "true"
cfg.model.gradient_checkpointing = "true"
cfg.model.lora.peft_lora_r = 64
cfg.model.lora.peft_lora_alpha = 128    

cfg.train.training_arguments.num_train_epochs = 3
cfg.train.training_arguments.per_device_train_batch_size = 4
cfg.train.training_arguments.gradient_accumulation_steps = 4
cfg.train.training_arguments.learning_rate = 2e-5
cfg.train.training_arguments.warmup_ratio = 0.03
cfg.train.training_arguments.max_grad_norm = 0.3
cfg.train.training_arguments.optim = "paged_adamw_8bit"



print_config(cfg)

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
import matplotlib.pyplot as plt

def format_dataset(dataset):
    """Helper function to format the dataset"""

    dataset = dataset.remove_columns(['instruction'])
    dataset = dataset.rename_column("output", "response")
    dataset = dataset.rename_column("input", "instruction")
    return dataset

def visualize_partitions(fed_dataset: FederatedDataset):
    """Helper function to visualize the partitions of the dataset"""

    _ = fed_dataset.load_partition(0)
    num_partitions = fed_dataset.partitioners['train'].num_partitions

    plt.bar(range(num_partitions), [len(fed_dataset.load_partition(i)) for i in range(num_partitions)])
    plt.xticks(range(num_partitions))
    plt.xlabel("Partition ID")
    plt.ylabel("Number of examples")
    plt.title(f"IID partitioning into {num_partitions} partitions")


from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import NaturalIdPartitioner

partitioner = NaturalIdPartitioner(partition_by="source")
fds = FederatedDataset(dataset="damjimenezgu/fl_data_fub",
                       partitioners={"train": partitioner})
partition = fds.load_partition(1)    
print("len(partition)",len(partition))
visualize_partitions(fds)

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from peft.utils import prepare_model_for_kbit_training
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

def get_model(model_cfg: DictConfig):
    """Load model with appropiate quantization config and
    other optimizations. Notice here that we are returning the
    LoRA model, not the full model."""

    use_cuda = torch.cuda.is_available()
    quantization_config = None
    model_name = model_cfg.name
    if use_cuda:
        if model_cfg.quantization == 4:
            quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        elif model_cfg.quantization == 8:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            raise ValueError(
                f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}/"
            )

        model_name = model_cfg.name

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        # token = os.environ["HF_ACCESS_TOKEN"]
    )

    if use_cuda:
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
        )

    target_modules = model_cfg.lora.target_modules
    if target_modules:
        target_modules = list(target_modules)
    peft_config = LoraConfig(
        r=model_cfg.lora.peft_lora_r,
        lora_alpha=model_cfg.lora.peft_lora_alpha,
        lora_dropout=0.075,
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )

    peft_model = get_peft_model(model, peft_config)
    if not (use_cuda):
        peft_model.enable_input_require_grads()

    if model_cfg.gradient_checkpointing:
        model.config.use_cache = False

    return peft_model

def formatting_prompts_func(example):
    mssg = "Rispondi SOLAMENTE con le informazioni fornite di seguito."
    return f"{mssg}\n### Istruzione:\n{example['instruction']}\n### Risposta: {example['response']}"
    
def get_tokenizer_and_data_collator_and_prompt_formatting(
    model_name: str, use_fast: bool, padding_side: str
):
    # From: https://huggingface.co/docs/trl/en/sft_trainer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=use_fast, padding_side=padding_side
    )

    tokenizer.pad_token = (
        tokenizer.bos_token if padding_side == "left" else tokenizer.eos_token
    )
    response_template_with_context = "\n### Risposta:"  
    response_template_ids = tokenizer.encode(
        response_template_with_context, add_special_tokens=False
    )[2:]
    data_collator = DataCollatorForCompletionOnlyLM(
        response_template_ids, tokenizer=tokenizer
    )
    
    return tokenizer, data_collator, formatting_prompts_func    

tokenizer, data_collator, formatting_prompts_func = get_tokenizer_and_data_collator_and_prompt_formatting(
    cfg.model.name,
    cfg.model.use_fast_tokenizer,
    cfg.train.padding_side,
)    

import math

def cosine_annealing(
    current_round: int,
    total_round: int,
    lrate_max: float = 0.001,
    lrate_min: float = 0.0,
) -> float:
    """Implement cosine annealing learning rate schedule. Strictly speaking this
    is not necessary."""

    # Convert all inputs to the correct numeric types
    current_round = int(current_round)
    total_round = int(total_round)
    lrate_max = float(lrate_max)
    lrate_min = float(lrate_min)

    cos_inner = math.pi * current_round / total_round
    return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))

from flwr.common import Context
from flwr.common.typing import NDArrays, Scalar
from flwr.client import NumPyClient

from typing import Dict, Tuple, Callable
from collections import OrderedDict
from trl import SFTConfig

def set_parameters(model, parameters: NDArrays) -> None:
    """Change the parameters of the model using the given ones."""

    peft_state_dict_keys = get_peft_model_state_dict(model).keys()
    params_dict = zip(peft_state_dict_keys, parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    set_peft_model_state_dict(model, state_dict)

class FlowerClient(NumPyClient):
    def __init__(
        self,
        model_cfg: DictConfig,
        train_cfg: DictConfig,
        trainset,
        tokenizer,
        formatting_prompts_func,
        data_collator,
        save_path,
    ):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.train_cfg = train_cfg
        self.training_arguments = SFTConfig(**train_cfg.training_arguments)
        self.training_arguments.report_to = "none"
        self.tokenizer = tokenizer
        self.formatting_prompts_func = formatting_prompts_func
        self.data_collator = data_collator
        self.save_path = save_path

        # instantiate model
        self.model = get_model(model_cfg)

        self.trainset = trainset

    def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
        """Return the parameters of the current net."""

        state_dict = get_peft_model_state_dict(self.model)
        return [val.cpu().numpy() for _, val in state_dict.items()]

    def fit(
        self, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Tuple[NDArrays, int, Dict]:
        """Implement distributed fit function for a given client."""
        set_parameters(self.model, parameters)
        print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX ENTERED TO FITTING XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
        new_lr = cosine_annealing(
            int(config["current_round"]),  # Ensure this is an int
            int(self.train_cfg.num_rounds),  # Ensure these are numbers
            float(self.train_cfg.learning_rate_max),
            float(self.train_cfg.learning_rate_min),
        )
        self.training_arguments.learning_rate = new_lr
        self.training_arguments.output_dir = self.save_path

        evalset = True
        if self.train_cfg.evaluate_split:
            train_test = self.trainset.train_test_split(test_size=0.1, seed=1234)
            trainset = train_test['train']
            evalset = train_test['test']
        else:
            trainset = self.trainset

        trainer = SFTTrainer(
            model=self.model,
            processing_class=self.tokenizer,
            args=self.training_arguments,
            train_dataset=trainset,
            eval_dataset=evalset,
            formatting_func=self.formatting_prompts_func,
            data_collator=self.data_collator,
        )

        metrics = {}
        if self.train_cfg.evaluate_split:
            eval_res = trainer.evaluate()
            metrics['eval_loss'] = eval_res['eval_loss']
            print(eval_res)

        # Do local training
        results = trainer.train()

        metrics = {**metrics, "train_loss": results.training_loss}

        return (
            self.get_parameters({}),
            len(self.trainset),
            metrics,
        )

def gen_client_fn(
    fds,
    tokenizer,
    formatting_prompts_func,
    data_collator,
    model_cfg: DictConfig,
    train_cfg: DictConfig,
    save_path: str,
) -> Callable[[str], FlowerClient]:
    """Generate the client function that creates the Flower Clients."""

    def client_fn(context: Context) -> FlowerClient:
        """Create a Flower client representing a single organization."""
        print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX entering the client function XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
        # Let's get the partition corresponding to the i-th client
        partition_id = int(context.node_config["partition-id"])
        client_trainset = fds.load_partition(partition_id, "train")
        client_trainset = client_trainset.remove_columns(["instruction","source"])
        client_trainset = client_trainset.rename_column("input", "instruction")
        client_trainset = client_trainset.rename_column("output", "response")

        
        return FlowerClient(
            model_cfg,
            train_cfg,
            client_trainset,
            tokenizer,
            formatting_prompts_func,
            data_collator,
            save_path,
        ).to_client()

    return client_fn

from flwr.client.mod import fixedclipping_mod

save_path = "./config/my_fl_llama_model"
client = fl.client.ClientApp(
    client_fn=gen_client_fn(
        fds,
        tokenizer,
        formatting_prompts_func,
        data_collator,
        cfg.model,
        cfg.train,
        save_path,
    ),
    # mods=[fixedclipping_mod]
)

from flwr.common import Context

def get_on_fit_config():
    """
    •	Purpose: This function provides the configuration dictionary sent to each
    client before training begins in a given round.
	•	Use Case: Clients can adapt behaviors (e.g., learning rate schedules) based
    on the current round number.
	•	Design Pattern: Returns a function (fit_config_fn) that Flower calls at each
    round.

    🧠 Analogy: This is like giving each chef a new cooking instruction every day
    based on how many days the kitchen has been operating.

    """
    def fit_config_fn(server_round: int):
        fit_config = {"current_round": server_round}
        return fit_config

    return fit_config_fn

def fit_weighted_average(metrics):
    """
    •	Purpose: Calculates a weighted average of training loss across all clients.
    •	Mechanism:
        •	Each client’s training loss is scaled by the number of examples it used.
        •	This ensures larger datasets have more influence on the final average.
    •	Why It Matters: Without weighting, small datasets could skew the overall metric.

    Example:
    # Client 1: 100 examples, loss = 0.5 → 100 * 0.5 = 50
    # Client 2: 200 examples, loss = 0.25 → 200 * 0.25 = 50
    # Weighted average = (50 + 50) / (100 + 200) = 0.333
    """
    # Multiply accuracy of each client by number of examples used
    losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"train_loss": sum(losses) / sum(examples)}

def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
    """
    •	Purpose: Save the global model periodically during training.
	•	Conditions:
	•	Skip round 0.
	•	Save if it’s the final round or every n rounds (save_every_round).
	•	How:
	•	Reconstruct the model from the config.
	•	Load the current global parameters.
	•	Save using the HuggingFace save_pretrained method.

    🧠 Why This Matters: In federated learning, there’s no single centralized
    training process. If something goes wrong, saved checkpoints are your recovery
    point.

    🔁 Return Format: Always returns 0.0, {} because evaluation is optional here
    — the function is used mainly for checkpointing.
    """

    def evaluate(server_round: int, parameters, config):
        # Save model
        if server_round != 0 and (
            server_round == total_round or server_round % save_every_round == 0
        ):
            # Init model
            model = get_model(model_cfg)
            set_parameters(model, parameters)

            model.save_pretrained(f"{save_path}/peft_{server_round}")

        return 0.0, {}

    return evaluate

from flwr.server.strategy import (
    DifferentialPrivacyClientSideFixedClipping
)

def server_fn(context: Context):
    """
    This function returns a configured Flower server app, which will be passed to
    start_simulation() later. It defines:
	•	The strategy for aggregation (FedAvg)
	•	How rounds behave (e.g., which clients to sample)
	•	Optional advanced features like Differential Privacy
    """

    # Define the Strategy
    ## 🔁 FedAvg is the classic federated strategy where model updates are averaged
    ## across clients.
    strategy = fl.server.strategy.FedAvg(
        min_available_clients=cfg.flower.num_clients, # total clients
        fraction_fit=cfg.flower.fraction_fit, # ratio of clients to sample
        fraction_evaluate=0.0, # No federated evaluation
        # A (optional) function used to configure a "fit()" round
        on_fit_config_fn=get_on_fit_config(),
        # A (optional) function to aggregate metrics sent by clients
        fit_metrics_aggregation_fn=fit_weighted_average,
        # A (optional) function to execute on the server after each round.
        # In this example the function only saves the global model.
        evaluate_fn=get_evaluate_fn(
            cfg.model,
            cfg.train.save_every_round,
            cfg.flower.num_rounds,
            save_path
        ),
    )

    # Number of rounds to run the simulation
    num_rounds = cfg.flower.num_rounds
    config = fl.server.ServerConfig(num_rounds=num_rounds)

    return fl.server.ServerAppComponents(strategy=strategy, config=config)    

server = fl.server.ServerApp(server_fn=server_fn)    

from logging import ERROR, DEBUG
import time

client_resources = dict(cfg.flower.client_resources)
backend_setup = {"logging_level": DEBUG, "log_to_driver": False}


start_time = time.time()

fl.simulation.run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=cfg.flower.num_clients,
    backend_config={
        "client_resources": client_resources,
        "init_args": backend_setup
    }
)

print("--- %s seconds ---" % (time.time() - start_time))

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch
import os

save_path = "./config/my_fl_llama_model"

# 1a. Load the base 4-bit LLM
base = AutoModelForCausalLM.from_pretrained( 
    cfg.model.name,        
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.bfloat16,
)

# 1b. Wrap it with your federated fine-tuned adapter
peft_model = PeftModel.from_pretrained(
    base,
    os.path.join(save_path, "peft_20"),    
)

merged_model = peft_model.merge_and_unload()

from transformers import AutoTokenizer

# os.makedirs("my_fl_llama_model")

# 3a. Create a local folder
output_dir = "my_fl_llama_model/merged_llama_32_1b"

# 3b. Save merged model weights & config
merged_model.save_pretrained(output_dir)

# 3c. Also save tokenizer files for complete deployability
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(cfg.model.name, use_fast=True)
tokenizer.save_pretrained(output_dir)

from llama_index.llms.huggingface import HuggingFaceLLM

# Initialize HuggingFaceLLM
llm = HuggingFaceLLM(
    model_name="my_fl_llama_model/merged_llama_32_1b",
    tokenizer_name="my_fl_llama_model/merged_llama_32_1b" 
)