API Reference
This document describes the helper functions and classes provided by AccelHydra.
Models
- class accel_hydra.models.common.LoadPretrainedMixin[source]
Bases: object
Base class for loading pretrained checkpoints with a custom state dict processing function.
- process_state_dict(model_dict: dict[str, Tensor], state_dict: dict[str, Tensor])[source]
Model-specific processing function that transforms the state_dict loaded from a checkpoint into a state dict that can be passed to load_state_dict. By default, merge_matched_keys is used to update parameters whose names and shapes match.
- Parameters:
model_dict – The state dict of the current model, which is going to load pretrained parameters
state_dict – A dictionary of parameters from a pre-trained model.
- Returns:
The updated state dict, where parameters with matched keys and shape are updated with values in state_dict.
- Return type:
Trainer Components
- class accel_hydra.trainer.Trainer(*, config_dict: dict | None = None, project_dir: str | Path, checkpoint_dir: str | Path | None = None, logging_config: LoggingConfig | None = None, train_dataloader: DataLoader, val_dataloader: DataLoader | None, model: Module, optimizer: Optimizer, lr_scheduler: LRScheduler, loss_fn: Module, epochs: int, epoch_length: int | None = None, lr_scheduler_interval: LRSchedulerInterval = LRSchedulerInterval.STEP, gradient_accumulation_steps: int = 1, max_grad_norm: float | None = 2.0, resume_from_checkpoint: str | Path | None = None, log_every_n_steps: int | None = 50, save_every_n_steps: int | None = None, permanent_save_every_n_steps: int | None = None, save_every_n_epochs: int | None = 1, save_last_k: int | None = 1, metric_monitor: MetricMonitor | None = None, early_stop: int | None = None, even_batches: bool = True)[source]
Bases: CheckpointMixin
Base trainer class providing training workflow management.
- Usage:
This is an abstract base class. Subclasses must implement training_step() and validation_step() methods to define the specific training and validation logic.
- config_dict
Configuration dictionary for storing training configuration information.
- Type:
dict | None
- project_dir
Project root directory path for saving training-related files.
- Type:
str | pathlib.Path
- checkpoint_dir
Checkpoint save directory. If None, uses project_dir/checkpoints.
- Type:
str | pathlib.Path | None
- logging_config
Logging configuration object for experiment logging (wandb/swanlab/tensorboard).
- Type:
accel_hydra.trainer.base.LoggingConfig | None
- train_dataloader
Training data loader.
- val_dataloader
Validation data loader. Can be None to skip validation.
- Type:
DataLoader | None
- model
PyTorch model to be trained.
- optimizer
Optimizer for training.
- lr_scheduler
Learning rate scheduler.
- loss_fn
Loss function.
- epochs
Total number of training epochs.
- Type:
int
- epoch_length
Number of steps per epoch. If None, uses the length of train_dataloader.
- Type:
int | None
- lr_scheduler_interval
Learning rate scheduler update interval. STEP means update every step, EPOCH means update every epoch.
- Type:
accel_hydra.trainer.base.LRSchedulerInterval
- gradient_accumulation_steps
Number of gradient accumulation steps to simulate larger batch size.
- Type:
int
- max_grad_norm
Maximum gradient norm for gradient clipping. If None, no gradient clipping is performed.
- Type:
float | None
- resume_from_checkpoint
Path to checkpoint for resuming training. None by default, meaning training from scratch.
- Type:
str | pathlib.Path | None
- save_every_n_steps
Save checkpoint every N steps. If None, no step-based saving.
- Type:
int | None
- permanent_save_every_n_steps
Permanently save a checkpoint to project_dir every N steps. If None, no permanent checkpoints are saved. Checkpoints written by save_every_n_steps and save_every_n_epochs are deleted automatically according to the cleanup strategy, whereas permanent checkpoints are never deleted.
- Type:
int | None
- save_every_n_epochs
Save checkpoint every N epochs. Default is 1 (save every epoch).
- Type:
int | None
- save_last_k
Keep the last K checkpoints and delete older ones. Default is 1 (keep only the latest checkpoint).
- Type:
int | None
- metric_monitor
MetricMonitor instance for tracking validation metrics and saving best model.
- Type:
accel_hydra.trainer.base.MetricMonitor | None
- early_stop
Early stopping patience. Stop training if validation metric doesn’t improve for N consecutive epochs.
- Type:
int | None
- even_batches
Whether to use even batches for handling inconsistent data amounts across processes. Must be set to False when the batch_sampler does not have a batch_size.
- Type:
bool
- logging_config: LoggingConfig | None = None
- train_dataloader: DataLoader
- val_dataloader: DataLoader | None
- model: Module
- optimizer: Optimizer
- lr_scheduler: LRScheduler
- loss_fn: Module
- epochs: int
- lr_scheduler_interval: LRSchedulerInterval = 'step'
- gradient_accumulation_steps: int = 1
- metric_monitor: MetricMonitor | None = None
- even_batches: bool = True
- abstract training_step(batch: Any, batch_idx: int) Tensor[source]
Performs a single training step, like training_step() in PyTorch Lightning.
This method is called for each batch during training. Subclasses must implement this method to define the forward pass, loss computation, and other optional operations. The returned loss will be automatically used for backpropagation.
- Parameters:
batch – A batch of data from the training DataLoader.
batch_idx – The index of the current batch within the current epoch (0-indexed). This can be useful for logging or conditional logic based on batch position.
- Returns:
The computed loss tensor, which will be used for loss.backward(). The tensor should be 0-dimensional (a scalar). The loss is automatically logged as “train/loss” by the Trainer.
- Return type:
Tensor
Example
def training_step(self, batch, batch_idx):
    features, labels = batch
    preds = self.model(features)
    loss = self.loss_fn(preds, labels)

    # Optional: Log additional metrics
    lr = self.optimizer.param_groups[0]["lr"]
    self.accelerator.log({"train/lr": lr}, step=self.step)

    return loss
Note
You should NOT call loss.backward() manually - the Trainer handles this automatically.
- abstract validation_step(batch: Any, batch_idx: int) None[source]
Performs a single validation step, like validation_step() in PyTorch Lightning.
This method is called for each batch during validation. Subclasses must implement this method to define the prediction operation and the potential metric calculation. You can specify the metric calculation logic to use the metric for learning rate scheduling or early stopping later.
- Parameters:
batch – A batch of data from the validation DataLoader.
batch_idx – The index of the current batch within the validation loop (0-indexed). This can be useful for logging or conditional logic based on batch position.
- Returns:
This method should not return anything. Store validation results in instance variables for later use in get_val_metrics().
- Return type:
None
Example
def validation_step(self, batch, batch_idx):
    features, labels = batch
    preds = self.model(features)
    predictions = preds.argmax(dim=-1)

    # Gather predictions from all processes (important for distributed training)
    output = {"predictions": predictions, "labels": labels}
    output = self.accelerator.gather_for_metrics(output)

    # Accumulate metrics
    accurate_preds = (output["predictions"] == output["labels"])
    self.validation_stats["accurate"] += accurate_preds.long().sum()
    self.validation_stats["num_elems"] += accurate_preds.shape[0]
Note
Use self.accelerator.gather_for_metrics() to collect predictions from all processes before computing metrics, otherwise discrepancies between processes may result in deadlocks.
- get_context() contextmanager[source]
FIXME: why does it not work?
- property checkpoint_objects: list[CheckpointMixin]
Returns a list of additional objects to be included in checkpoints.
This property allows subclasses to specify additional objects (beyond the Trainer itself) that should be saved and restored during checkpointing. All objects in the returned list must implement the CheckpointMixin interface (i.e., have state_dict() and load_state_dict() methods). The customized checkpointing is achieved by registering these objects with the Accelerate framework during setup_accelerator().
- Returns:
A list of objects to include in checkpoints. Default is an empty list. Subclasses can override this property to return custom objects.
- Return type:
list[CheckpointMixin]
Example
import torch
from accel_hydra.trainer import CheckpointMixin, Trainer

class VersionTracker(CheckpointMixin):
    def __init__(self):
        self.version = torch.__version__

    def state_dict(self) -> dict:
        return {"version": self.version}

    def load_state_dict(self, state_dict: dict) -> None:
        self.version = state_dict["version"]

class MyTrainer(Trainer):
    @property
    def checkpoint_objects(self) -> list[CheckpointMixin]:
        return [VersionTracker()]
- save_checkpoint(save_dir: Path | str, clean_old_checkpoints: bool = True) None[source]
Note: because this method calls wait_for_everyone, the caller is responsible for making sure that either all processes call it or none of them do. Otherwise, the processes that do call it may block while waiting for the others.
- class accel_hydra.trainer.MetricMonitor(*, metric_name: str = 'loss', mode: Literal['min', 'max'] = 'min')[source]
Bases: CheckpointMixin
- metric_name: str = 'loss'
- mode: Literal['min', 'max'] = 'min'
- compare(x: float, best_x: float) bool[source]
Compares the current value with the best value based on mode.
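As a rough illustration of a mode-based comparison, the following standalone sketch shows one way such a check could behave. It is not the library's implementation: the function name is_better is hypothetical, and treating a tie as "no improvement" is an assumption.

```python
from typing import Literal


def is_better(x: float, best_x: float, mode: Literal["min", "max"] = "min") -> bool:
    """Hypothetical helper: True if x improves on best_x under the given mode."""
    if mode == "min":
        return x < best_x  # lower is better, e.g. a validation loss
    if mode == "max":
        return x > best_x  # higher is better, e.g. an accuracy
    raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")


# Track the best validation loss across a few epochs.
best = float("inf")
for val_loss in [0.9, 0.7, 0.8]:
    if is_better(val_loss, best, mode="min"):
        best = val_loss
```

With mode='min' the running best only moves downward, which is the behavior one would typically want when monitoring a loss.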
Training Launcher
- class accel_hydra.train_launcher.TrainLauncher[source]
Bases: object
Base class for training launchers that handle configuration and training setup.
This class provides a structured way to launch training with Hydra configuration. Subclasses can override specific methods to customize the training process.
- static get_register_resolver_fn() Callable[source]
Get the function to register custom OmegaConf resolvers.
Subclasses can override this staticmethod to return a custom resolver registration function. The returned function should be callable without arguments.
- Returns:
A function that registers custom OmegaConf resolvers
- Return type:
Callable
- get_steps_for_lr_scheduler(train_dataloader: DataLoader)[source]
Calculate steps for LR scheduler.
This method handles the complexity of step counting in distributed training with gradient accumulation.
- Parameters:
train_dataloader – The training dataloader
- Returns:
(num_training_updates, num_warmup_updates) where num_warmup_updates can be None
- Return type:
- run()[source]
Main entry point that orchestrates the training setup and launch.
This method follows the standard training setup flow:
1. Load configuration
2. Setup resume if needed
3. Create model, dataloaders, optimizer, LR scheduler, loss function
4. Create trainer and start training
Subclasses can override this method to customize the entire flow, or override individual methods to customize specific steps.
Utilities
Data Utilities
- accel_hydra.utils.data.init_dataloader_from_config(config: dict)[source]
A helper function to initialize a dataloader from a config.
- Parameters:
config – A dictionary or DictConfig containing the dataloader configuration.
- Returns:
The instantiated dataloader object.
Example
from omegaconf import OmegaConf
from accel_hydra.utils.data import init_dataloader_from_config

config = '''
train_dataloader:
  _target_: torch.utils.data.DataLoader
  dataset:
    _target_: data.train_dataset
    data_root: /path/to/data
  # sampler:
  #   _target_: torch.utils.data.Sampler
  #   ...
  # batch_sampler:
  #   _target_: torch.utils.data.BatchSampler
  #   ...
'''
config = OmegaConf.create(config)
train_dataloader = init_dataloader_from_config(config["train_dataloader"])
Configuration Utilities
- accel_hydra.utils.config.register_omegaconf_resolvers(clear_resolvers: bool = True) None[source]
Register custom resolvers for Hydra configs, which can be used in YAML files to set values dynamically.
- Parameters:
clear_resolvers – If True, clear all existing resolvers before registering. Set to False if you want to extend existing resolvers.
- accel_hydra.utils.config.load_config_with_overrides(config_file: str | Path, overrides: list[str], register_resolver_fn: Callable = register_omegaconf_resolvers) DictConfig[source]
General Utilities
Learning Rate Scheduler Utilities
- accel_hydra.utils.lr_scheduler.get_warmup_steps(dataloader_one_pass_outside_steps: int, warmup_steps: int | None = None, warmup_epochs: float | None = None, epoch_length: int | None = None) int[source]
Derive the number of warmup steps from either a step count or an epoch count. If warmup_steps is provided, it is returned as-is. Otherwise, the warmup steps are derived from the epoch length and the number of warmup epochs.
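The derivation described above can be sketched as follows. This is a minimal sketch and not the library's code: the function name warmup_steps_sketch is hypothetical, and three details are assumptions (falling back to the dataloader length when epoch_length is None, truncating with int(), and raising an error when neither warmup value is given).

```python
from typing import Optional


def warmup_steps_sketch(
    dataloader_one_pass_outside_steps: int,
    warmup_steps: Optional[int] = None,
    warmup_epochs: Optional[float] = None,
    epoch_length: Optional[int] = None,
) -> int:
    """Hypothetical sketch of deriving warmup steps."""
    # An explicit step count takes priority.
    if warmup_steps is not None:
        return warmup_steps
    if warmup_epochs is None:
        # Assumption: the real function may handle this case differently.
        raise ValueError("provide either warmup_steps or warmup_epochs")
    # Assumption: use the dataloader length when no fixed epoch length is set.
    steps_per_epoch = (
        epoch_length if epoch_length is not None else dataloader_one_pass_outside_steps
    )
    return int(warmup_epochs * steps_per_epoch)


half_epoch_warmup = warmup_steps_sketch(1000, warmup_epochs=0.5)
```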
- accel_hydra.utils.lr_scheduler.get_dataloader_one_pass_outside_steps(train_dataloader: DataLoader, num_processes: int = 1)[source]
Return the dataloader length after DDP sharding, which is close to original_length / num_processes.
- accel_hydra.utils.lr_scheduler.get_total_training_steps(train_dataloader: DataLoader, epochs: int, num_processes: int = 1, epoch_length: int | None = None)[source]
Calculate the total number of “visible” training steps.
If epoch_length is provided, it is used as the fixed length for each epoch. Otherwise, the function will determine the epoch length from train_dataloader.
- Parameters:
train_dataloader – Training dataloader object.
epochs – The total number of epochs to run.
num_processes – The number of parallel processes used for distributed training.
epoch_length – A fixed number of training steps for each epoch. Defaults to None.
- Returns:
The total number of training steps (i.e., epochs * epoch_length).
- Return type:
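The calculation can be illustrated with plain integers. The sketch below is an assumption-laden stand-in, not the library's function: total_training_steps_sketch is a hypothetical name, it takes the number of batches instead of a DataLoader so that it runs without PyTorch, and it assumes the per-process length is the ceiling of the batch count divided by num_processes.

```python
import math
from typing import Optional


def total_training_steps_sketch(
    num_batches: int,
    epochs: int,
    num_processes: int = 1,
    epoch_length: Optional[int] = None,
) -> int:
    """Hypothetical sketch: epochs * epoch_length, with a derived epoch length."""
    if epoch_length is None:
        # Assumption: each process sees roughly num_batches / num_processes batches.
        epoch_length = math.ceil(num_batches / num_processes)
    return epochs * epoch_length


steps = total_training_steps_sketch(num_batches=1000, epochs=10, num_processes=4)
```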
- accel_hydra.utils.lr_scheduler.get_dataloader_one_pass_steps_inside_accelerator(dataloader_one_pass_steps: int, gradient_accumulation_steps: int, num_processes: int)[source]
Calculate the number of “visible” training steps for a single pass over the dataloader inside an accelerator, accounting for gradient accumulation and distributed training.
- Parameters:
dataloader_one_pass_steps – The number of steps (batches) in one pass over the dataset.
gradient_accumulation_steps – The number of steps to accumulate gradients before performing a parameter update.
num_processes – The number of parallel processes used for distributed training.
- Returns:
The total number of “visible” training steps for one pass over the dataset, multiplied by the number of processes.
- Return type:
- accel_hydra.utils.lr_scheduler.get_steps_inside_accelerator_from_outside_steps(outside_steps: int, dataloader_one_pass_outside_steps: int, dataloader_one_pass_steps_inside_accelerator: int, gradient_accumulation_steps: int, num_processes: int)[source]
Convert “outside” steps (as observed in wandb logger or similar context) to the corresponding number of “inside” steps (for accelerate lr scheduler).
Specifically, the accelerate LR scheduler calls step() num_processes times for every gradient_accumulation_steps outside steps.
- Parameters:
outside_steps – The total number of steps counted outside accelerate context.
dataloader_one_pass_outside_steps – The number of steps (batches) to complete one pass of the dataloader outside accelerate.
dataloader_one_pass_steps_inside_accelerator – The number of lr_scheduler.step() calls inside accelerate, calculated via get_dataloader_one_pass_steps_inside_accelerator.
gradient_accumulation_steps – The number of steps to accumulate gradients.
num_processes – The number of parallel processes (GPUs) used in distributed training.
- Returns:
The total number of lr_scheduler.step() calls inside accelerate that correspond to the given outside_steps.
- Return type:
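To make the outside-to-inside conversion concrete, here is a small arithmetic sketch. It is not the library's implementation. Both function names are hypothetical, the step counts are treated as per-process batch counts, and the sketch assumes that a partial accumulation window at the end of a dataloader pass still triggers an update (the default behavior of Accelerate's gradient accumulation, which syncs at the end of the dataloader).

```python
import math


def one_pass_inside_steps_sketch(
    dataloader_one_pass_steps: int,
    gradient_accumulation_steps: int,
    num_processes: int,
) -> int:
    """Hypothetical: scheduler step() calls during one pass over the dataloader."""
    # One update per accumulation window; the last window may be partial (assumption).
    updates = math.ceil(dataloader_one_pass_steps / gradient_accumulation_steps)
    # Each update advances the scheduler num_processes times.
    return updates * num_processes


def inside_steps_sketch(
    outside_steps: int,
    dataloader_one_pass_outside_steps: int,
    dataloader_one_pass_steps_inside_accelerator: int,
    gradient_accumulation_steps: int,
    num_processes: int,
) -> int:
    """Hypothetical: convert logged (outside) steps to scheduler (inside) steps."""
    full_passes, remainder = divmod(outside_steps, dataloader_one_pass_outside_steps)
    inside = full_passes * dataloader_one_pass_steps_inside_accelerator
    # Within the unfinished pass, only complete accumulation windows count.
    inside += (remainder // gradient_accumulation_steps) * num_processes
    return inside


per_pass = one_pass_inside_steps_sketch(10, 4, 2)
inside = inside_steps_sketch(25, 10, per_pass, 4, 2)
```

In this example, 10 batches with an accumulation window of 4 give 3 updates per pass, so 6 scheduler steps on 2 processes. After 25 outside steps, two full passes plus one complete window have elapsed.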
PyTorch Utilities
- accel_hydra.utils.torch.merge_matched_keys(model_dict: dict[str, Tensor], state_dict: dict[str, Tensor]) dict[str, Tensor][source]
- Parameters:
model_dict – The state dict of the current model, which is going to load pretrained parameters
state_dict – A dictionary of parameters from a pre-trained model.
- Returns:
The updated state dict, where parameters with matched keys and shape are updated with values in state_dict.
- Return type:
dict[str, Tensor]
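The matching rule described above (same key and same shape) can be sketched in a few lines. This is an illustrative sketch and not the library's code: merge_matched_sketch is a hypothetical name, and returning a new dictionary instead of updating model_dict in place is an assumption.

```python
import torch


def merge_matched_sketch(
    model_dict: dict[str, torch.Tensor],
    state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Hypothetical sketch: copy over entries whose key and shape both match."""
    merged = dict(model_dict)  # shallow copy; the input dict is left untouched
    for name, value in state_dict.items():
        if name in merged and merged[name].shape == value.shape:
            merged[name] = value
    return merged


model_dict = {"a": torch.zeros(2), "b": torch.zeros(3)}
pretrained = {"a": torch.ones(2), "b": torch.ones(4), "c": torch.ones(1)}
merged = merge_matched_sketch(model_dict, pretrained)
```

Here "a" is taken from the pretrained weights, "b" keeps the model's own values because the shapes differ, and "c" is ignored because the model has no such key.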
- accel_hydra.utils.torch.load_pretrained_model(model: Module, ckpt_or_state_dict: str | Path | dict[str, Tensor], state_dict_process_fn: Callable = merge_matched_keys) None[source]
- accel_hydra.utils.torch.create_mask_from_length(lengths: Tensor, max_length: int | None = None)[source]
- accel_hydra.utils.torch.loss_with_mask(loss: Tensor, mask: Tensor, reduce: bool = True) Tensor[source]
Apply a mask to the loss tensor and optionally reduce it.
- Parameters:
loss – Tensor of shape (b, t, …) representing the loss values.
mask – Tensor of shape (b, t) where 1 indicates valid positions and 0 indicates masked positions.
reduce – If True, return a single scalar value; otherwise, return a tensor of shape (b,).
- Returns:
A scalar if reduce is True, otherwise a tensor of shape (b,).
- Return type:
Tensor
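One plausible way to apply such a mask is a mean over valid positions, sketched below. This is an assumption about the reduction, not the library's implementation, and masked_loss_sketch is a hypothetical name.

```python
import torch


def masked_loss_sketch(
    loss: torch.Tensor, mask: torch.Tensor, reduce: bool = True
) -> torch.Tensor:
    """Hypothetical sketch: average the loss over positions where mask == 1."""
    # Broadcast the (b, t) mask over any trailing dimensions of the loss.
    extra_dims = loss.dim() - mask.dim()
    expanded = mask.reshape(*mask.shape, *([1] * extra_dims)).expand_as(loss)
    expanded = expanded.to(loss.dtype)
    masked = loss * expanded
    if reduce:
        # Single scalar: mean over all valid elements in the batch.
        return masked.sum() / expanded.sum().clamp(min=1)
    # Per-sample mean: reduce every dimension except the batch dimension.
    dims = tuple(range(1, loss.dim()))
    return masked.sum(dim=dims) / expanded.sum(dim=dims).clamp(min=1)


loss = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]])
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
scalar_loss = masked_loss_sketch(loss, mask)
per_sample_loss = masked_loss_sketch(loss, mask, reduce=False)
```

The clamp guards against division by zero when a sample has no valid positions.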
- accel_hydra.utils.torch.trim_or_pad_length(x: Tensor, target_length: int, length_dim: int)[source]
Adjusts the size of the specified dimension of tensor x to match target_length.
- Parameters:
x – Input tensor.
target_length – Desired size of the specified dimension.
length_dim – The dimension to modify.
- Returns:
The adjusted tensor.
- Return type:
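A minimal sketch of this kind of length adjustment is shown below. It is not the library's code: trim_or_pad_sketch is a hypothetical name, and trimming from the end and padding with zeros at the end are both assumptions, since the padding value and side are not specified here.

```python
import torch


def trim_or_pad_sketch(x: torch.Tensor, target_length: int, length_dim: int) -> torch.Tensor:
    """Hypothetical sketch: cut or zero-pad x along length_dim to target_length."""
    dim = length_dim % x.dim()  # normalize negative dimensions
    current = x.shape[dim]
    if current >= target_length:
        # Keep the first target_length entries along the chosen dimension.
        return x.narrow(dim, 0, target_length)
    # Append zeros (same dtype and device as x) to reach the target length.
    pad_shape = list(x.shape)
    pad_shape[dim] = target_length - current
    return torch.cat([x, x.new_zeros(pad_shape)], dim=dim)


x = torch.arange(6.0).reshape(2, 3)
trimmed = trim_or_pad_sketch(x, target_length=2, length_dim=1)
padded = trim_or_pad_sketch(x, target_length=5, length_dim=-1)
```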
- accel_hydra.utils.torch.concat_non_padding(seq1: Tensor, mask1: BoolTensor, seq2: Tensor, mask2: BoolTensor) tuple[Tensor, BoolTensor, LongTensor][source]
- Parameters:
seq1 – Tensor (B, L1, E) First sequence.
mask1 – BoolTensor (B, L1) True for valid tokens in seq1, False for padding.
seq2 – Tensor (B, L2, E) Second sequence.
mask2 – BoolTensor (B, L2) True for valid tokens in seq2, False for padding.
- Returns:
- concat_seq: Tensor (B, L1+L2, E)
Both sequences concatenated; valid tokens are left-aligned, padding on the right is 0.
- concat_mask: BoolTensor (B, L1+L2)
Mask for the concatenated sequence.
- perm: LongTensor (B, L1+L2)
Permutation that maps original indices → new indices. Needed for restoring the original sequences.
- Return type:
tuple[Tensor, BoolTensor, LongTensor]
- accel_hydra.utils.torch.restore_from_concat(concat_seq: Tensor, mask1: BoolTensor, mask2: BoolTensor, perm: LongTensor) tuple[Tensor, Tensor][source]
Restore (seq1, seq2) from the concatenated sequence produced by concat_non_padding, using the returned permutation perm. Fully vectorised — no Python loops.
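The round trip between concatenation and restoration can be sketched with a stable sort and its inverse permutation. This is one possible vectorised approach under stated assumptions, not the library's implementation: both function names are hypothetical, and the masks are assumed to be boolean tensors.

```python
import torch


def concat_non_padding_sketch(seq1, mask1, seq2, mask2):
    """Hypothetical sketch: left-align the valid tokens of two padded sequences."""
    seq = torch.cat([seq1, seq2], dim=1)    # (B, L1+L2, E)
    mask = torch.cat([mask1, mask2], dim=1)  # (B, L1+L2)
    # A stable sort on "is padding" moves valid tokens first and keeps their order.
    _, order = torch.sort((~mask).long(), dim=1, stable=True)
    concat_seq = torch.gather(seq, 1, order.unsqueeze(-1).expand_as(seq))
    concat_mask = torch.gather(mask, 1, order)
    concat_seq = concat_seq * concat_mask.unsqueeze(-1)  # zero out padding
    # The inverse of `order` maps each original index to its new index.
    perm = torch.argsort(order, dim=1)
    return concat_seq, concat_mask, perm


def restore_from_concat_sketch(concat_seq, mask1, mask2, perm):
    """Hypothetical sketch: undo the reordering and split back into two sequences."""
    restored = torch.gather(concat_seq, 1, perm.unsqueeze(-1).expand_as(concat_seq))
    len1 = mask1.shape[1]
    return restored[:, :len1], restored[:, len1:]


seq1 = torch.tensor([[[1.0], [2.0], [0.0]]])
mask1 = torch.tensor([[True, True, False]])
seq2 = torch.tensor([[[3.0], [0.0]]])
mask2 = torch.tensor([[True, False]])
concat_seq, concat_mask, perm = concat_non_padding_sketch(seq1, mask1, seq2, mask2)
out1, out2 = restore_from_concat_sketch(concat_seq, mask1, mask2, perm)
```

Because padding positions are zeroed during concatenation, the restored sequences contain zeros wherever the original masks were False.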
Accelerate Extensions
- class accel_hydra.utils.accelerate.AcceleratorSaveTrainableParams(device_placement: bool = True, split_batches: bool = <object object>, mixed_precision: PrecisionType | str | None = None, gradient_accumulation_steps: int = 1, cpu: bool = False, dataloader_config: DataLoaderConfiguration | None = None, deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None, fsdp_plugin: FullyShardedDataParallelPlugin | None = None, torch_tp_plugin: TorchTensorParallelPlugin | None = None, megatron_lm_plugin: MegatronLMPlugin | None = None, rng_types: list[str | RNGType] | None = None, log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None, project_dir: str | PathLike | None = None, project_config: ProjectConfiguration | None = None, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None, step_scheduler_with_optimizer: bool = True, kwargs_handlers: list[KwargsHandler] | None = None, dynamo_backend: DynamoBackend | str | None = None, dynamo_plugin: TorchDynamoPlugin | None = None, deepspeed_plugins: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None, parallelism_config: ParallelismConfig | None = None)[source]
Bases: Accelerator
Extended Accelerator that only saves trainable parameters and buffers.
This class extends the base accelerate.Accelerator class to support selective state dict saving. When a model has the param_names_to_save attribute (typically accel_hydra.models.common.SaveTrainableParamsBase), only the parameters and buffers specified in that attribute will be saved.
This is particularly useful for models with frozen pre-trained components, where you only want to save trainable parameters to save space.
- Parameters:
*args – Positional arguments passed to the base accelerate.Accelerator class.
**kwargs – Keyword arguments passed to the base accelerate.Accelerator class.
Example
from accel_hydra.utils.accelerate import AcceleratorSaveTrainableParams
from accel_hydra.models.common import SaveTrainableParamsBase
import torch.nn as nn

class MyModel(SaveTrainableParamsBase):
    def __init__(self):
        super().__init__()
        self.frozen_layer = nn.Linear(10, 10)  # Frozen pre-trained layer
        self.trainable_layer = nn.Linear(10, 5)  # Trainable layer
        self.frozen_layer.requires_grad_(False)

model = MyModel()
accelerator = AcceleratorSaveTrainableParams()
model = accelerator.prepare(model)

# When saving, only trainable parameters and buffers are saved
state_dict = accelerator.get_state_dict(model)
- get_state_dict(model, unwrap=True)[source]
Get the state dict of the model, filtering to only trainable parameters.
- Parameters:
model – The model to get the state dict from.
unwrap – Whether to unwrap the model before getting the state dict. Defaults to True.
- Returns:
The trainable state dict of the model. The filtering works when the model has the param_names_to_save attribute.
- Return type: