API Reference

This document describes the helper functions and classes provided by AccelHydra.

Models

class accel_hydra.models.common.LoadPretrainedMixin

Bases: object

Mixin for loading pretrained checkpoints with a custom state dict processing function.

process_state_dict(model_dict: dict[str, Tensor], state_dict: dict[str, Tensor])

Model-specific hook that transforms the state_dict loaded from a checkpoint into a state dict that can be passed to load_state_dict. By default, it uses merge_matched_keys to update the parameters whose names and shapes match.

Parameters:
  • model_dict – The state dict of the current model, into which the pretrained parameters will be loaded.

  • state_dict – A dictionary of parameters from a pre-trained model.

Returns:

The updated state dict, in which parameters with matching keys and shapes are updated with the values in state_dict.

Return type:

dict[str, torch.Tensor]
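As a rough illustration of the default behaviour described above, the following sketch reimplements the "update by matching name and shape" rule in plain PyTorch. The function merge_matched_keys_sketch is hypothetical and is not the library's merge_matched_keys; it assumes only that entries without a match keep the model's current values.

```python
import torch


def merge_matched_keys_sketch(
    model_dict: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    # Start from the model's own tensors so unmatched entries keep their current values
    merged = dict(model_dict)
    for key, value in state_dict.items():
        # Copy a pretrained tensor only when both the name and the shape match
        if key in model_dict and model_dict[key].shape == value.shape:
            merged[key] = value
    return merged


model_dict = {"encoder.weight": torch.zeros(2, 2), "head.bias": torch.zeros(2)}
pretrained = {
    "encoder.weight": torch.ones(2, 2),  # same name and shape: loaded
    "head.bias": torch.ones(3),          # shape mismatch: skipped
    "extra.weight": torch.ones(1),       # not in the model: ignored
}
merged = merge_matched_keys_sketch(model_dict, pretrained)
```

Because the merged dict has exactly the model's keys, it can be passed to model.load_state_dict() without missing or unexpected keys.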

load_pretrained(ckpt_path: str | Path)
class accel_hydra.models.common.CountParamsMixin

Bases: object

count_params()
class accel_hydra.models.common.SaveTrainableParamsMixin

Bases: object

property param_names_to_save
load_state_dict(state_dict, strict=True)

Trainer Components

class accel_hydra.trainer.Trainer(*, config_dict: dict | None = None, project_dir: str | Path, checkpoint_dir: str | Path | None = None, logging_config: LoggingConfig | None = None, train_dataloader: DataLoader, val_dataloader: DataLoader | None, model: Module, optimizer: Optimizer, lr_scheduler: LRScheduler, loss_fn: Module, epochs: int, epoch_length: int | None = None, lr_scheduler_interval: LRSchedulerInterval = LRSchedulerInterval.STEP, gradient_accumulation_steps: int = 1, max_grad_norm: float | None = 2.0, resume_from_checkpoint: str | Path | None = None, log_every_n_steps: int | None = 50, save_every_n_steps: int | None = None, permanent_save_every_n_steps: int | None = None, save_every_n_epochs: int | None = 1, save_last_k: int | None = 1, metric_monitor: MetricMonitor | None = None, early_stop: int | None = None, even_batches: bool = True)

Bases: CheckpointMixin

Base trainer class providing training workflow management.

Usage:

This is an abstract base class. Subclasses must implement training_step() and validation_step() methods to define the specific training and validation logic.

config_dict

Configuration dictionary for storing training configuration information.

Type:

dict | None

project_dir

Project root directory path for saving training-related files.

Type:

str | pathlib.Path

checkpoint_dir

Checkpoint save directory. If None, uses project_dir/checkpoints.

Type:

str | pathlib.Path

logging_config

Logging configuration object for experiment logging (wandb/swanlab/tensorboard).

Type:

accel_hydra.trainer.base.LoggingConfig | None

train_dataloader

Training data loader.

Type:

torch.utils.data.dataloader.DataLoader

val_dataloader

Validation data loader. Can be None to skip validation.

Type:

torch.utils.data.dataloader.DataLoader | None

model

PyTorch model to be trained.

Type:

torch.nn.modules.module.Module

optimizer

Optimizer for training.

Type:

torch.optim.optimizer.Optimizer

lr_scheduler

Learning rate scheduler.

Type:

torch.optim.lr_scheduler.LRScheduler

loss_fn

Loss function.

Type:

torch.nn.modules.module.Module

epochs

Total number of training epochs.

Type:

int

epoch_length

Number of steps per epoch. If None, uses the length of train_dataloader.

Type:

int | None

lr_scheduler_interval

Learning rate scheduler update interval. STEP means update every step, EPOCH means update every epoch.

Type:

accel_hydra.trainer.base.LRSchedulerInterval

gradient_accumulation_steps

Number of steps over which gradients are accumulated before each optimizer update, used to simulate a larger batch size.

Type:

int
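The effect of gradient accumulation on the number of samples behind each optimizer update can be worked out with simple arithmetic. The sketch below assumes data-parallel training in which every process loads per_device_batch_size samples per step; the function name is illustrative only.

```python
def effective_batch_size(
    per_device_batch_size: int, num_processes: int, gradient_accumulation_steps: int
) -> int:
    # Each optimizer update sees one batch per process for every accumulated step
    return per_device_batch_size * num_processes * gradient_accumulation_steps


# 16 samples per GPU, 4 GPUs, gradients accumulated over 2 steps
print(effective_batch_size(16, 4, 2))  # 128
```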

max_grad_norm

Maximum gradient norm for gradient clipping. If None, no gradient clipping is performed.

Type:

float | None

resume_from_checkpoint

Path to a checkpoint to resume training from. Defaults to None, which means training starts from scratch.

Type:

str | pathlib.Path | None

save_every_n_steps

Save checkpoint every N steps. If None, no step-based saving.

Type:

int | None

permanent_save_every_n_steps

Save a permanent checkpoint to project_dir every N steps. If None, no permanent checkpoints are saved. Checkpoints written by save_every_n_steps and save_every_n_epochs are deleted automatically according to the cleaning strategy, but permanent checkpoints are never deleted.

Type:

int | None

save_every_n_epochs

Save checkpoint every N epochs. Default is 1 (save every epoch).

Type:

int | None

save_last_k

Keep the last K checkpoints and delete older ones. Default is 1 (keep only the latest checkpoint).

Type:

int | None
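A minimal sketch of a keep-last-K rule follows. It assumes, purely for illustration, that checkpoint directory names end in their step number (for example step_300) and that k is at least 1; the library's actual naming and its clean_checkpoints_to_k helper may work differently. Permanent checkpoints from permanent_save_every_n_steps would not be part of this list.

```python
import re


def checkpoints_to_delete(checkpoint_names: list[str], k: int) -> list[str]:
    """Return the checkpoints a keep-last-k policy would remove (assumes k >= 1)."""

    def step_of(name: str) -> int:
        # Hypothetical naming scheme: trailing digits are the training step
        match = re.search(r"(\d+)$", name)
        return int(match.group(1)) if match else -1

    ordered = sorted(checkpoint_names, key=step_of)  # oldest first
    return ordered[:-k]  # everything except the k newest


print(checkpoints_to_delete(["step_100", "step_300", "step_200"], 1))
# ['step_100', 'step_200']
```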

metric_monitor

MetricMonitor instance for tracking validation metrics and saving the best model.

Type:

accel_hydra.trainer.base.MetricMonitor | None

early_stop

Early stopping patience. Training stops if the validation metric does not improve for N consecutive epochs.

Type:

int | None
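The patience rule can be sketched as a counter of consecutive non-improving epochs. This is an illustration only: it assumes that "improve" means a strictly lower value in "min" mode and a strictly higher value in "max" mode, mirroring MetricMonitor's mode, and the function name early_stop_epoch is hypothetical.

```python
from typing import Optional


def early_stop_epoch(val_metrics: list[float], patience: int, mode: str = "min") -> Optional[int]:
    """Return the epoch index at which training would stop, or None if it never stops."""
    best = None
    epochs_without_improvement = 0
    for epoch, value in enumerate(val_metrics):
        improved = best is None or (value < best if mode == "min" else value > best)
        if improved:
            best = value
            epochs_without_improvement = 0  # reset the patience counter
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Two consecutive epochs without a new minimum (0.85, 0.9) stop training at epoch 3
print(early_stop_epoch([1.0, 0.8, 0.85, 0.9, 0.79], patience=2))  # 3
```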

even_batches

Whether to use even batches so that every process receives the same number of batches when the data cannot be split equally across processes. Must be set to False when the batch_sampler has no batch_size.

Type:

bool

config_dict: dict | None = None
project_dir: str | Path
checkpoint_dir: str | Path = None
logging_config: LoggingConfig | None = None
train_dataloader: DataLoader
val_dataloader: DataLoader | None
model: Module
optimizer: Optimizer
lr_scheduler: LRScheduler
loss_fn: Module
epochs: int
epoch_length: int | None = None
lr_scheduler_interval: LRSchedulerInterval = 'step'
gradient_accumulation_steps: int = 1
max_grad_norm: float | None = 2.0
resume_from_checkpoint: str | Path | None = None
log_every_n_steps: int | None = 50
save_every_n_steps: int | None = None
permanent_save_every_n_steps: int | None = None
save_every_n_epochs: int | None = 1
save_last_k: int | None = 1
metric_monitor: MetricMonitor | None = None
early_stop: int | None = None
even_batches: bool = True
wrap_and_broadcast_value(value: Any) → Tensor
setup_accelerator() → None
abstract training_step(batch: Any, batch_idx: int) → Tensor

Performs a single training step, similar to training_step() in PyTorch Lightning.

This method is called for each batch during training. Subclasses must implement this method to define the forward pass, loss computation, and other optional operations. The returned loss will be automatically used for backpropagation.

Parameters:
  • batch – A batch of data from the training DataLoader.

  • batch_idx – The index of the current batch within the current epoch (0-indexed). This can be useful for logging or conditional logic based on batch position.

Returns:

The computed loss tensor. It will be used for loss.backward().

The tensor should be 0-dimensional (a scalar). The loss is automatically logged as “train/loss” by the Trainer.

Return type:

torch.Tensor

Example

def training_step(self, batch, batch_idx):
    features, labels = batch
    preds = self.model(features)
    loss = self.loss_fn(preds, labels)

    # Optional: Log additional metrics
    lr = self.optimizer.param_groups[0]["lr"]
    self.accelerator.log({"train/lr": lr}, step=self.step)

    return loss

Note

  • Do NOT call loss.backward() manually; the Trainer handles backpropagation automatically.

abstract validation_step(batch: Any, batch_idx: int) → None

Performs a single validation step, similar to validation_step() in PyTorch Lightning.

This method is called for each batch during validation. Subclasses must implement it to define the prediction logic and any metric computation. Metrics computed here can later be used for learning rate scheduling or early stopping.

Parameters:
  • batch – A batch of data from the validation DataLoader.

  • batch_idx – The index of the current batch within the validation loop (0-indexed). This can be useful for logging or conditional logic based on batch position.

Returns:

This method should not return anything. Store validation results in instance variables for later use in get_val_metrics().

Return type:

None

Example

def validation_step(self, batch, batch_idx):
    features, labels = batch
    preds = self.model(features)
    predictions = preds.argmax(dim=-1)

    # Gather predictions from all processes (important for distributed training)
    output = {"predictions": predictions, "labels": labels}
    output = self.accelerator.gather_for_metrics(output)

    # Accumulate metrics
    accurate_preds = (output["predictions"] == output["labels"])
    self.validation_stats["accurate"] += accurate_preds.long().sum()
    self.validation_stats["num_elems"] += accurate_preds.shape[0]

Note

  • Use self.accelerator.gather_for_metrics() to collect predictions from all processes before computing metrics, otherwise discrepancies between processes may result in deadlocks.
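The example above accumulates into self.validation_stats but does not show where that dictionary is created or how it becomes a metric. One plausible arrangement is sketched below, assuming the validation loop calls on_validation_start() before the first batch and get_val_metrics() after the last one. The class is a hypothetical stand-in rather than a real Trainer subclass, and it works on already-gathered Python lists so that it runs in a single process without Accelerate.

```python
class ValidationHooksSketch:
    """Hypothetical stand-in for a Trainer subclass; only the validation hooks are shown."""

    def on_validation_start(self) -> None:
        # Reset the accumulators at the start of every validation loop
        self.validation_stats = {"accurate": 0, "num_elems": 0}

    def validation_step(self, batch, batch_idx: int) -> None:
        # In a real Trainer, predictions would be gathered with gather_for_metrics first
        predictions, labels = batch
        matches = [int(p == y) for p, y in zip(predictions, labels)]
        self.validation_stats["accurate"] += sum(matches)
        self.validation_stats["num_elems"] += len(matches)

    def get_val_metrics(self) -> dict:
        stats = self.validation_stats
        # Assumption: a MetricMonitor(metric_name="accuracy") would look up this key
        return {"accuracy": stats["accurate"] / max(stats["num_elems"], 1)}


hooks = ValidationHooksSketch()
hooks.on_validation_start()
for idx, batch in enumerate([([1, 0, 1], [1, 1, 1]), ([0], [0])]):
    hooks.validation_step(batch, idx)
metrics = hooks.get_val_metrics()  # 3 of 4 predictions are correct
```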

get_context() → contextmanager

Known issue: this context manager is marked FIXME in the source and does not currently work as intended.

gather_min_length(length: int) → int
val_loop() → None
on_validation_start() → None
on_validation_end() → None
get_val_metrics() → dict[str, Any]
on_train_epoch_start() → None
on_train_epoch_end() → None
property checkpoint_objects: list[CheckpointMixin]

Returns a list of additional objects to be included in checkpoints.

This property allows subclasses to specify additional objects (beyond the Trainer itself) that should be saved and restored during checkpointing. All objects in the returned list must implement the CheckpointMixin interface (i.e., have state_dict() and load_state_dict() methods). The customized checkpointing is achieved by registering these objects with the Accelerate framework during setup_accelerator().

Returns:

A list of objects to include in checkpoints. Default is an empty list. Subclasses can override this property to return custom objects.

Return type:

list[CheckpointMixin]

Example

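import torch
from accel_hydra.trainer import CheckpointMixin, Trainer

class VersionTracker(CheckpointMixin):
    def __init__(self):
        self.version = torch.__version__

    def state_dict(self) -> dict:
        return {"version": self.version}

    def load_state_dict(self, state_dict: dict) -> None:
        self.version = state_dict["version"]

class MyTrainer(Trainer):
    @property
    def checkpoint_objects(self) -> list[CheckpointMixin]:
        # Create the tracker once and reuse it, so the object registered for
        # checkpointing is the same one whose state is restored on resume
        if not hasattr(self, "_version_tracker"):
            self._version_tracker = VersionTracker()
        return [self._version_tracker]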

state_dict() → dict
load_state_dict(state_dict: dict) → None
clean_checkpoints_to_k(checkpoints_dir: Path | str, k: int) → None
save_checkpoint(save_dir: Path | str, clean_old_checkpoints: bool = True) → None

Note: This method calls wait_for_everyone(), so the caller must ensure that either all processes call it at the same point or none of them do; calling it from only some processes will hang the run.

train_loop() → None
on_train_start() → None
on_train_end() → None
train(seed: int) → None
__init__(*, config_dict: dict | None = None, project_dir: str | Path, checkpoint_dir: str | Path | None = None, logging_config: LoggingConfig | None = None, train_dataloader: DataLoader, val_dataloader: DataLoader | None, model: Module, optimizer: Optimizer, lr_scheduler: LRScheduler, loss_fn: Module, epochs: int, epoch_length: int | None = None, lr_scheduler_interval: LRSchedulerInterval = LRSchedulerInterval.STEP, gradient_accumulation_steps: int = 1, max_grad_norm: float | None = 2.0, resume_from_checkpoint: str | Path | None = None, log_every_n_steps: int | None = 50, save_every_n_steps: int | None = None, permanent_save_every_n_steps: int | None = None, save_every_n_epochs: int | None = 1, save_last_k: int | None = 1, metric_monitor: MetricMonitor | None = None, early_stop: int | None = None, even_batches: bool = True) → None
class accel_hydra.trainer.MetricMonitor(*, metric_name: str = 'loss', mode: Literal['min', 'max'] = 'min')

Bases: CheckpointMixin

metric_name: str = 'loss'
mode: Literal['min', 'max'] = 'min'
compare(x: float, best_x: float) → bool

Compares the current value with the best value based on mode.

__call__(metric_dict: dict[str, Any]) → bool

Checks if the new value is better and updates best_value if so.

state_dict() → dict

Returns the state of the object as a dictionary.

load_state_dict(state_dict: dict)

Loads the state from a dictionary.

__init__(*, metric_name: str = 'loss', mode: Literal['min', 'max'] = 'min') → None
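To make the comparison and checkpointing behaviour concrete, here is a self-contained sketch of a monitor with the same fields and methods. It is not the library's implementation: the best_value attribute name comes from the __call__ description above, while starting with no best value (so the first metric always counts as an improvement) and the state dict layout are assumptions.

```python
from dataclasses import dataclass
from typing import Any, Literal, Optional


@dataclass
class MetricMonitorSketch:
    metric_name: str = "loss"
    mode: Literal["min", "max"] = "min"
    best_value: Optional[float] = None  # assumption: no best value before the first call

    def compare(self, x: float, best_x: float) -> bool:
        # True when x is strictly better than best_x for the chosen mode
        return x < best_x if self.mode == "min" else x > best_x

    def __call__(self, metric_dict: dict[str, Any]) -> bool:
        value = float(metric_dict[self.metric_name])
        if self.best_value is None or self.compare(value, self.best_value):
            self.best_value = value
            return True
        return False

    def state_dict(self) -> dict:
        # Hypothetical layout; only the running best needs to survive a resume
        return {"best_value": self.best_value}

    def load_state_dict(self, state_dict: dict) -> None:
        self.best_value = state_dict["best_value"]


monitor = MetricMonitorSketch(metric_name="accuracy", mode="max")
monitor({"accuracy": 0.5})  # True: first value
monitor({"accuracy": 0.4})  # False: worse than 0.5
```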
class accel_hydra.trainer.LoggingConfig(*, report_to: str | None = 'wandb', project: str, save_dir: str | Path, name: str, resume_id: str | None = None, workspace: str | None = None)

Bases: object

report_to: str | None = 'wandb'
project: str
save_dir: str | Path
name: str
resume_id: str | None = None
workspace: str | None = None
__init__(*, report_to: str | None = 'wandb', project: str, save_dir: str | Path, name: str, resume_id: str | None = None, workspace: str | None = None) → None
class accel_hydra.trainer.CheckpointMixin

Bases: ABC

abstract state_dict() → dict
abstract load_state_dict(state_dict: dict) → None

Training Launcher

class accel_hydra.train_launcher.TrainLauncher

Bases: object

Base class for training launchers that handle configuration and training setup.

This class provides a structured way to launch training with Hydra configuration. Subclasses can override specific methods to customize the training process.

static get_register_resolver_fn() → Callable

Get the function to register custom OmegaConf resolvers.

Subclasses can override this staticmethod to return a custom resolver registration function. The returned function should be callable without arguments.

Returns:

A function that registers custom OmegaConf resolvers

Return type:

Callable

get_steps_for_lr_scheduler(train_dataloader: DataLoader)

Calculate steps for LR scheduler.

This method handles the complexity of step counting in distributed training with gradient accumulation.

Parameters:

train_dataloader – The training dataloader

Returns:

(num_training_updates, num_warmup_updates) where num_warmup_updates can be None

Return type:

tuple

get_dataloaders()
run()

Main entry point that orchestrates the training setup and launch.

This method follows the standard training setup flow:

  1. Load the configuration.

  2. Set up resuming if needed.

  3. Create the model, dataloaders, optimizer, LR scheduler, and loss function.

  4. Create the trainer and start training.

Subclasses can override this method to customize the entire flow, or override individual methods to customize specific steps.

Utilities

Data Utilities

accel_hydra.utils.data.init_dataloader_from_config(config: dict)

A helper function to initialize a dataloader from a config.

Parameters:

config – A dictionary or DictConfig containing the dataloader configuration.

Returns:

The instantiated dataloader object.

Example

from omegaconf import OmegaConf
from accel_hydra.utils.data import init_dataloader_from_config

# data.train_dataset is a placeholder for your own dataset class
config = '''
train_dataloader:
  _target_: torch.utils.data.DataLoader
  dataset:
    _target_: data.train_dataset
    data_root: /path/to/data
  # sampler:
  #   _target_: torch.utils.data.Sampler
  #   ...
  # batch_sampler:
  #   _target_: torch.utils.data.BatchSampler
  #   ...
'''
config = OmegaConf.create(config)
train_dataloader = init_dataloader_from_config(config["train_dataloader"])

Configuration Utilities

accel_hydra.utils.config.multiply(*args)
accel_hydra.utils.config.register_omegaconf_resolvers(clear_resolvers: bool = True) → None

Register custom resolvers for Hydra configs. The resolvers can be used in YAML files to set values dynamically.

Parameters:

clear_resolvers – If True, clear all existing resolvers before registering. Set to False if you want to extend existing resolvers.

accel_hydra.utils.config.load_config_with_overrides(config_file: str | Path, overrides: list[str], register_resolver_fn: Callable = register_omegaconf_resolvers) → DictConfig
accel_hydra.utils.config.load_config_from_cli(return_config: bool = True, register_resolver_fn: Callable = register_omegaconf_resolvers)
accel_hydra.utils.config.parse_launch_args()

General Utilities

accel_hydra.utils.general.is_package_available(package_name: str) → bool
accel_hydra.utils.general.read_jsonl_to_mapping(jsonl_file: str | Path | list[str | Path], key_col: str, value_col: str) → Dict[str, str]

Read the two columns indicated by key_col and value_col from the given JSONL file(s) and return them as a mapping from key to value. Note: duplicate keys are not yet handled.

accel_hydra.utils.general.setup_resume_cfg(config: dict, do_print: bool = True)

Learning Rate Scheduler Utilities

accel_hydra.utils.lr_scheduler.get_warmup_steps(dataloader_one_pass_outside_steps: int, warmup_steps: int | None = None, warmup_epochs: float | None = None, epoch_length: int | None = None) → int

Derive the number of warmup steps from either a step count or an epoch count. If warmup_steps is provided, it is returned unchanged; otherwise, the warmup steps are derived from the epoch length and warmup_epochs.
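A sketch of this precedence rule follows, under assumptions not stated above: when epoch_length is None, one epoch is taken to be one pass over the dataloader (dataloader_one_pass_outside_steps), and fractional results are truncated with int(). The ValueError raised when neither argument is given is also hypothetical, since that case is undocumented.

```python
from typing import Optional


def get_warmup_steps_sketch(
    dataloader_one_pass_outside_steps: int,
    warmup_steps: Optional[int] = None,
    warmup_epochs: Optional[float] = None,
    epoch_length: Optional[int] = None,
) -> int:
    # An explicit step count takes precedence
    if warmup_steps is not None:
        return warmup_steps
    if warmup_epochs is None:
        # Assumption: behaviour for this case is not documented
        raise ValueError("Either warmup_steps or warmup_epochs must be given")
    # Use the fixed epoch length if set, otherwise one pass over the dataloader
    steps_per_epoch = (
        epoch_length if epoch_length is not None else dataloader_one_pass_outside_steps
    )
    return int(warmup_epochs * steps_per_epoch)


print(get_warmup_steps_sketch(1000, warmup_epochs=0.5))  # 500
```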

accel_hydra.utils.lr_scheduler.get_dataloader_one_pass_outside_steps(train_dataloader: DataLoader, num_processes: int = 1)

Return the dataloader length after DDP sharding, which is approximately original_length / num_processes.

accel_hydra.utils.lr_scheduler.get_total_training_steps(train_dataloader: DataLoader, epochs: int, num_processes: int = 1, epoch_length: int | None = None)

Calculate the total number of “visible” training steps.

If epoch_length is provided, it is used as the fixed length for each epoch. Otherwise, the function will determine the epoch length from train_dataloader.

Parameters:
  • train_dataloader – Training dataloader object.

  • epochs – The total number of epochs to run.

  • num_processes – The number of parallel processes used for distributed training.

  • epoch_length – A fixed number of training steps for each epoch. Defaults to None.

Returns:

The total number of training steps (i.e., epochs * epoch_length).

Return type:

int
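The arithmetic behind get_dataloader_one_pass_outside_steps and get_total_training_steps can be sketched with plain integers. Here num_batches stands in for len(train_dataloader) before sharding, and the per-process length is rounded up, which assumes the even-batches behaviour in which an incomplete final round of batches is filled rather than dropped. Both function names are hypothetical.

```python
import math
from typing import Optional


def one_pass_outside_steps_sketch(num_batches: int, num_processes: int = 1) -> int:
    # Each process sees roughly num_batches / num_processes batches per pass
    return math.ceil(num_batches / num_processes)


def total_training_steps_sketch(
    num_batches: int, epochs: int, num_processes: int = 1, epoch_length: Optional[int] = None
) -> int:
    # A fixed epoch_length overrides the dataloader-derived length
    if epoch_length is None:
        epoch_length = one_pass_outside_steps_sketch(num_batches, num_processes)
    return epochs * epoch_length


# 1001 batches on 4 GPUs -> 251 steps per epoch -> 753 steps over 3 epochs
print(total_training_steps_sketch(1001, epochs=3, num_processes=4))  # 753
```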

accel_hydra.utils.lr_scheduler.get_dataloader_one_pass_steps_inside_accelerator(dataloader_one_pass_steps: int, gradient_accumulation_steps: int, num_processes: int)

Calculate the number of “visible” training steps for a single pass over the dataloader inside an accelerator, accounting for gradient accumulation and distributed training.

Parameters:
  • dataloader_one_pass_steps – The number of steps (batches) in one pass over the dataset.

  • gradient_accumulation_steps – The number of steps to accumulate gradients before performing a parameter update.

  • num_processes – The number of parallel processes used for distributed training.

Returns:

The total number of “visible” training steps for one pass over the dataset, multiplied by the number of processes.

Return type:

int
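A sketch of this count follows. It assumes that Accelerate's scheduler wrapper steps the underlying scheduler once per process for every optimizer update, that a partial accumulation window at the end of the dataloader still triggers an update, and that dataloader_one_pass_steps is the per-process number of batches. The function name is hypothetical.

```python
import math


def one_pass_steps_inside_accelerator_sketch(
    dataloader_one_pass_steps: int, gradient_accumulation_steps: int, num_processes: int
) -> int:
    # Optimizer updates per pass; a trailing partial window still counts (assumption)
    updates_per_pass = math.ceil(dataloader_one_pass_steps / gradient_accumulation_steps)
    # The wrapped scheduler's step() runs num_processes times per update
    return updates_per_pass * num_processes


# 10 batches, accumulate over 4 -> 3 updates, on 2 processes -> 6 scheduler steps
print(one_pass_steps_inside_accelerator_sketch(10, 4, 2))  # 6
```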

accel_hydra.utils.lr_scheduler.get_steps_inside_accelerator_from_outside_steps(outside_steps: int, dataloader_one_pass_outside_steps: int, dataloader_one_pass_steps_inside_accelerator: int, gradient_accumulation_steps: int, num_processes: int)

Convert “outside” steps (as observed in wandb logger or similar context) to the corresponding number of “inside” steps (for accelerate lr scheduler).

Specifically, the Accelerate LR scheduler calls step() num_processes times for every gradient_accumulation_steps outside steps.

Parameters:
  • outside_steps – The total number of steps counted outside accelerate context.

  • dataloader_one_pass_outside_steps – The number of steps (batches) to complete one pass of the dataloader outside accelerate.

  • dataloader_one_pass_steps_inside_accelerator – The number of lr_scheduler.step() calls inside accelerate, calculated via get_dataloader_one_pass_steps_inside_accelerator.

  • gradient_accumulation_steps – The number of steps to accumulate gradients.

  • num_processes – The number of parallel processes (GPUs) used in distributed training.

Returns:

The total number of lr_scheduler.step() calls inside accelerate that correspond to the given outside_steps.

Return type:

int
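A sketch of the conversion follows, assuming that every completed pass over the dataloader contributes dataloader_one_pass_steps_inside_accelerator calls and that, within an unfinished pass, only completed accumulation windows have stepped the scheduler. The function name is hypothetical, and the library may treat edge cases differently.

```python
def inside_from_outside_steps_sketch(
    outside_steps: int,
    dataloader_one_pass_outside_steps: int,
    dataloader_one_pass_steps_inside_accelerator: int,
    gradient_accumulation_steps: int,
    num_processes: int,
) -> int:
    full_passes, remainder = divmod(outside_steps, dataloader_one_pass_outside_steps)
    # Completed passes contribute their full count of scheduler calls
    inside = full_passes * dataloader_one_pass_steps_inside_accelerator
    # In the unfinished pass, only completed accumulation windows have triggered updates
    inside += (remainder // gradient_accumulation_steps) * num_processes
    return inside


# Two full passes (2 x 6) plus one completed window in the third pass (1 x 2)
print(inside_from_outside_steps_sketch(25, 10, 6, 4, 2))  # 14
```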

accel_hydra.utils.lr_scheduler.lr_scheduler_param_adapter(config_dict: dict[str, Any], num_training_steps: int, num_warmup_steps: int | None) → dict[str, Any]

Adapts the parameters in the LR scheduler configuration to the given numbers of training steps and warmup steps.

Parameters:
  • config_dict – The configuration dictionary of the LR scheduler.

  • num_training_steps – The number of training steps.

  • num_warmup_steps – The number of warmup steps.

Returns:

The adapted configuration dictionary of the LR scheduler.

Return type:

dict[str, Any]

PyTorch Utilities

accel_hydra.utils.torch.remove_key_prefix_factory(prefix: str = 'module.')
accel_hydra.utils.torch.merge_matched_keys(model_dict: dict[str, Tensor], state_dict: dict[str, Tensor]) → dict[str, Tensor]
Parameters:
  • model_dict – The state dict of the current model, into which the pretrained parameters will be loaded.

  • state_dict – A dictionary of parameters from a pre-trained model.

Returns:

The updated state dict, in which parameters with matching keys and shapes are updated with the values in state_dict.

Return type:

dict[str, torch.Tensor]

accel_hydra.utils.torch.load_pretrained_model(model: Module, ckpt_or_state_dict: str | Path | dict[str, Tensor], state_dict_process_fn: Callable = merge_matched_keys) → None
accel_hydra.utils.torch.create_mask_from_length(lengths: Tensor, max_length: int | None = None)
accel_hydra.utils.torch.loss_with_mask(loss: Tensor, mask: Tensor, reduce: bool = True) → Tensor

Apply a mask to the loss tensor and optionally reduce it.

Parameters:
  • loss – Tensor of shape (b, t, …) representing the loss values.

  • mask – Tensor of shape (b, t) where 1 indicates valid positions and 0 indicates masked positions.

  • reduce – If True, return a single scalar value; otherwise, return a tensor of shape (b,).

Returns:

A scalar if reduce is True, otherwise a tensor of shape (b,).

Return type:

torch.Tensor
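A minimal sketch of masked loss reduction follows. It assumes the loss is averaged over valid elements (summed and divided by their count) rather than summed, and it builds the (b, t) mask from sequence lengths in the way create_mask_from_length presumably does. Both helpers are hypothetical stand-ins, not the library functions.

```python
from typing import Optional

import torch


def mask_from_lengths(lengths: torch.Tensor, max_length: Optional[int] = None) -> torch.Tensor:
    # (b,) lengths -> (b, t) bool mask, True at valid positions
    max_length = int(lengths.max()) if max_length is None else max_length
    return torch.arange(max_length, device=lengths.device)[None, :] < lengths[:, None]


def loss_with_mask_sketch(loss: torch.Tensor, mask: torch.Tensor, reduce: bool = True) -> torch.Tensor:
    # Reshape the (b, t) mask to (b, t, 1, ...) and expand it over trailing loss dims
    mask = mask.to(loss.dtype).reshape(*mask.shape, *([1] * (loss.dim() - 2)))
    mask = mask.expand_as(loss)
    masked = loss * mask
    if reduce:
        # Average over every valid element in the batch
        return masked.sum() / mask.sum().clamp(min=1)
    # Per-sample average over that sample's valid elements -> shape (b,)
    dims = tuple(range(1, loss.dim()))
    return masked.sum(dim=dims) / mask.sum(dim=dims).clamp(min=1)


loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = mask_from_lengths(torch.tensor([2, 1]), max_length=3)
per_sample = loss_with_mask_sketch(loss, mask, reduce=False)  # tensor([1.5, 4.0])
```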

accel_hydra.utils.torch.trim_or_pad_length(x: Tensor, target_length: int, length_dim: int)

Adjusts the size of the specified dimension of tensor x to match target_length.

Parameters:
  • x – Input tensor.

  • target_length – Desired size of the specified dimension.

  • length_dim – The dimension to modify.

Returns:

The adjusted tensor.

Return type:

torch.Tensor
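A sketch of this trim-or-pad behaviour follows, assuming that trimming keeps the first target_length entries and that padding appends zeros at the end of length_dim; the library may pad differently, and the function name is hypothetical.

```python
import torch


def trim_or_pad_length_sketch(x: torch.Tensor, target_length: int, length_dim: int) -> torch.Tensor:
    current = x.size(length_dim)
    if current >= target_length:
        # Keep the first target_length entries along length_dim
        return x.narrow(length_dim, 0, target_length)
    # Append zeros so that length_dim reaches target_length
    pad_shape = list(x.shape)
    pad_shape[length_dim] = target_length - current
    padding = x.new_zeros(pad_shape)
    return torch.cat([x, padding], dim=length_dim)


x = torch.arange(10.0).reshape(2, 5)
trimmed = trim_or_pad_length_sketch(x, 3, length_dim=1)  # shape (2, 3)
padded = trim_or_pad_length_sketch(x, 7, length_dim=1)   # shape (2, 7)
```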

accel_hydra.utils.torch.concat_non_padding(seq1: Tensor, mask1: BoolTensor, seq2: Tensor, mask2: BoolTensor) → tuple[Tensor, BoolTensor, LongTensor]
Parameters:
  • seq1 – Tensor (B, L1, E) First sequence.

  • mask1 – BoolTensor (B, L1) True for valid tokens in seq1, False for padding.

  • seq2 – Tensor (B, L2, E) Second sequence.

  • mask2 – BoolTensor (B, L2) True for valid tokens in seq2, False for padding.

Returns:
  • concat_seq – Tensor (B, L1+L2, E). Both sequences concatenated; valid tokens are left-aligned, and the padding on the right is 0.

  • concat_mask – BoolTensor (B, L1+L2). Mask for the concatenated sequence.

  • perm – LongTensor (B, L1+L2). Permutation that maps original indices → new indices; needed to restore the original sequences.

Return type:

tuple[Tensor, BoolTensor, LongTensor]

accel_hydra.utils.torch.restore_from_concat(concat_seq: Tensor, mask1: BoolTensor, mask2: BoolTensor, perm: LongTensor) → tuple[Tensor, Tensor]

Restore (seq1, seq2) from the concatenated sequence produced by concat_non_padding, using the returned permutation perm. Fully vectorised, with no Python loops.
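One way to implement this pair of functions without Python loops is a stable sort on the padding flag, sketched below. The sketch assumes perm maps each position of the original concatenation to its new position (the inverse of the sort order), so restoring is a single gather; padding positions come back as zeros. The function names are hypothetical, and the library's implementation may differ.

```python
import torch


def concat_non_padding_sketch(seq1, mask1, seq2, mask2):
    # Naive concatenation: (B, L1+L2, E) and (B, L1+L2)
    seq = torch.cat([seq1, seq2], dim=1)
    mask = torch.cat([mask1, mask2], dim=1)
    # Stable sort on "is padding": valid tokens (key 0) move left, keeping their order
    _, order = torch.sort((~mask).long(), dim=1, stable=True)  # order[b, new] = old
    concat_seq = torch.gather(seq, 1, order.unsqueeze(-1).expand_as(seq))
    concat_mask = torch.gather(mask, 1, order)
    # Zero out the right-hand padding
    concat_seq = concat_seq * concat_mask.unsqueeze(-1).to(concat_seq.dtype)
    # Invert the sort order so that perm[b, old] = new
    perm = torch.argsort(order, dim=1)
    return concat_seq, concat_mask, perm


def restore_from_concat_sketch(concat_seq, mask1, mask2, perm):
    # Fetch every original position from its new position in one gather
    restored = torch.gather(concat_seq, 1, perm.unsqueeze(-1).expand_as(concat_seq))
    length1 = mask1.shape[1]
    return restored[:, :length1], restored[:, length1:]


seq1, mask1 = torch.tensor([[[1.0], [2.0], [0.0]]]), torch.tensor([[True, True, False]])
seq2, mask2 = torch.tensor([[[3.0], [0.0]]]), torch.tensor([[True, False]])
concat_seq, concat_mask, perm = concat_non_padding_sketch(seq1, mask1, seq2, mask2)
restored1, restored2 = restore_from_concat_sketch(concat_seq, mask1, mask2, perm)
```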

accel_hydra.utils.torch.contains_nan(data)

Check whether data contains NaN values.

Accelerate Extensions

class accel_hydra.utils.accelerate.AcceleratorSaveTrainableParams(device_placement: bool = True, split_batches: bool = <object object>, mixed_precision: PrecisionType | str | None = None, gradient_accumulation_steps: int = 1, cpu: bool = False, dataloader_config: DataLoaderConfiguration | None = None, deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None, fsdp_plugin: FullyShardedDataParallelPlugin | None = None, torch_tp_plugin: TorchTensorParallelPlugin | None = None, megatron_lm_plugin: MegatronLMPlugin | None = None, rng_types: list[str | RNGType] | None = None, log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None, project_dir: str | PathLike | None = None, project_config: ProjectConfiguration | None = None, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None, step_scheduler_with_optimizer: bool = True, kwargs_handlers: list[KwargsHandler] | None = None, dynamo_backend: DynamoBackend | str | None = None, dynamo_plugin: TorchDynamoPlugin | None = None, deepspeed_plugins: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None, parallelism_config: ParallelismConfig | None = None)

Bases: Accelerator

Extended Accelerator that only saves trainable parameters and buffers.

This class extends the base accelerate.Accelerator class to support selective state dict saving. When a model has the param_names_to_save attribute (typically provided by accel_hydra.models.common.SaveTrainableParamsMixin), only the parameters and buffers specified in that attribute are saved.

This is particularly useful for models with frozen pretrained components, where saving only the trainable parameters reduces checkpoint size.

Example

from accel_hydra.utils.accelerate import AcceleratorSaveTrainableParams
from accel_hydra.models.common import SaveTrainableParamsMixin
import torch.nn as nn

# The mixin comes first so that its load_state_dict overrides nn.Module's
class MyModel(SaveTrainableParamsMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.frozen_layer = nn.Linear(10, 10)  # Frozen pre-trained layer
        self.trainable_layer = nn.Linear(10, 5)  # Trainable layer
        self.frozen_layer.requires_grad_(False)

model = MyModel()
accelerator = AcceleratorSaveTrainableParams()
model = accelerator.prepare(model)

# When saving, only trainable parameters and buffers are saved
state_dict = accelerator.get_state_dict(model)

get_state_dict(model, unwrap=True)

Get the state dict of the model, filtering to only trainable parameters.

Parameters:
  • model – The model to get the state dict from.

  • unwrap – Whether to unwrap the model before getting the state dict. Defaults to True.

Returns:

The trainable state dict of the model. The filtering applies only when the model has the param_names_to_save attribute.

Return type:

dict
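The filtering idea can be sketched in plain PyTorch. The sketch assumes the names to keep are all parameters with requires_grad=True plus every buffer; the real param_names_to_save property may select differently, and both function names are hypothetical.

```python
import torch.nn as nn


def trainable_param_and_buffer_names(model: nn.Module) -> set[str]:
    # Parameters that receive gradients
    names = {name for name, param in model.named_parameters() if param.requires_grad}
    # Buffers (e.g. BatchNorm running statistics) are kept as well
    names |= {name for name, _ in model.named_buffers()}
    return names


def trainable_state_dict(model: nn.Module) -> dict:
    keep = trainable_param_and_buffer_names(model)
    return {key: value for key, value in model.state_dict().items() if key in keep}


model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model[0].requires_grad_(False)  # freeze the first layer
saved = trainable_state_dict(model)  # only the BatchNorm entries remain
```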