Trainer

The accel_hydra.Trainer class is the core component of AccelHydra. It provides a structured training loop with customizable hooks, automatic checkpoint management, logging, and distributed training support.

Overview

The accel_hydra.Trainer class wraps the standard PyTorch training workflow and integrates with Accelerate for distributed training and Hydra for configuration management. To use it, subclass accel_hydra.Trainer and implement its abstract methods, training_step() and validation_step().

Training Loop Flow

The training process follows this high-level flow:

+---------------------------------------------------------+
|                      train(seed)                        |
|  - Sets random seed                                     |
|  - Calls setup_accelerator()                            |
|  - on_train_start()                                     |
+---------------------------+-----------------------------+
                            |
                            v
+---------------------------------------------------------+
|              Inside each epoch:                         |
|                                                         |
|  1. on_train_epoch_start()                              |
|  2. Training Loop:                                      |
|     - For each batch:                                   |
|       * training_step(batch, batch_idx)                 |
|       * Backward pass                                   |
|       * Gradient clipping                               |
|       * Optimizer step                                  |
|       * LR scheduler step (if step-based)               |
|       * Checkpoint saving (if triggered)                |
|  3. Validation Loop (if val_dataloader provided):       |
|     - on_validation_start()                             |
|     - For each batch:                                   |
|       * validation_step(batch, batch_idx)               |
|     - on_validation_end()                               |
|  4. LR scheduler step (if epoch-based)                  |
|  5. Checkpoint saving (if epoch-based trigger)          |
|  6. Best model saving (if metric_monitor improved)      |
|  7. Early stopping check                                |
|  8. on_train_epoch_end()                                |
+---------------------------+-----------------------------+
                            |
                            v
+---------------------------------------------------------+
|  - on_train_end()                                       |
+---------------------------------------------------------+
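The call order in the diagram can be illustrated with a toy class that only records which hook or step would fire, and when. HookOrderTrainer is a hypothetical stand-in, not accel_hydra code. It leaves out seeding, setup_accelerator(), backward passes, clipping, scheduler steps, checkpointing and early stopping, and it uses num_val_batches == 0 to stand in for "no val_dataloader".

```python
class HookOrderTrainer:
    """Records the order in which hooks and steps would be called (no real training)."""

    def __init__(self, epochs, num_train_batches, num_val_batches):
        self.epochs = epochs
        self.num_train_batches = num_train_batches
        self.num_val_batches = num_val_batches
        self.events = []

    def train(self):
        self.events.append("on_train_start")
        for _ in range(self.epochs):
            self.events.append("on_train_epoch_start")
            for batch_idx in range(self.num_train_batches):
                # The real loop also runs backward, clipping and optimizer steps here.
                self.events.append(f"training_step[{batch_idx}]")
            if self.num_val_batches:  # stands in for "if val_dataloader is provided"
                self.events.append("on_validation_start")
                for batch_idx in range(self.num_val_batches):
                    self.events.append(f"validation_step[{batch_idx}]")
                self.events.append("on_validation_end")
            # Scheduler, checkpointing and early stopping would happen here.
            self.events.append("on_train_epoch_end")
        self.events.append("on_train_end")


trainer = HookOrderTrainer(epochs=1, num_train_batches=2, num_val_batches=1)
trainer.train()
print(trainer.events)
```

With one epoch, two training batches and one validation batch, the recorded order is on_train_start, on_train_epoch_start, both training steps, the validation block, on_train_epoch_end, and finally on_train_end.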

Core Methods

Abstract Methods (Must Implement)

abstract Trainer.training_step(batch: Any, batch_idx: int) → Tensor

Performs a single training step, like training_step() in PyTorch Lightning.

This method is called for each batch during training. Subclasses must implement this method to define the forward pass, loss computation, and other optional operations. The returned loss will be automatically used for backpropagation.

Parameters:
  • batch – A batch of data from the training DataLoader.

  • batch_idx – The index of the current batch within the current epoch (0-indexed). This can be useful for logging or conditional logic based on batch position.

Returns:

The computed loss tensor, which the Trainer uses for loss.backward().

The tensor should be 0-dimensional (a scalar). The loss is automatically logged as “train/loss” by the Trainer.

Return type:

torch.Tensor

Example

def training_step(self, batch, batch_idx):
    features, labels = batch
    preds = self.model(features)
    loss = self.loss_fn(preds, labels)

    # Optional: Log additional metrics
    lr = self.optimizer.param_groups[0]["lr"]
    self.accelerator.log({"train/lr": lr}, step=self.step)

    return loss

Note

  • You should NOT call loss.backward() manually; the Trainer handles this automatically.

abstract Trainer.validation_step(batch: Any, batch_idx: int) → None

Performs a single validation step, like validation_step() in PyTorch Lightning.

This method is called for each batch during validation. Subclasses must implement it to define the prediction and any metric calculation. Metrics accumulated here can later be used for learning rate scheduling or early stopping.

Parameters:
  • batch – A batch of data from the validation DataLoader.

  • batch_idx – The index of the current batch within the validation loop (0-indexed). This can be useful for logging or conditional logic based on batch position.

Returns:

This method should not return anything. Store validation results in instance variables for later use in get_val_metrics().

Return type:

None

Example

def validation_step(self, batch, batch_idx):
    features, labels = batch
    preds = self.model(features)
    predictions = preds.argmax(dim=-1)

    # Gather predictions from all processes (important for distributed training)
    output = {"predictions": predictions, "labels": labels}
    output = self.accelerator.gather_for_metrics(output)

    # Accumulate metrics
    accurate_preds = (output["predictions"] == output["labels"])
    self.validation_stats["accurate"] += accurate_preds.long().sum()
    self.validation_stats["num_elems"] += accurate_preds.shape[0]

Note

  • Use self.accelerator.gather_for_metrics() to collect predictions from all processes before computing metrics. Otherwise each process computes metrics on its own shard only, and discrepancies between processes may result in deadlocks.
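To see why gathering matters, consider accuracy computed on two shards of different sizes. The counts below are made up and no Accelerate call is made: pooling the raw counts, which is what gathering makes possible, gives the true global accuracy, while averaging each process's local accuracy does not.

```python
# Hypothetical per-process results: (number correct, number of samples).
# Uneven shard sizes are what make the two averages below disagree.
local_results = [(9, 10), (1, 2)]

# Misleading: average the accuracy each process computed on its own shard.
mean_of_local = sum(c / n for c, n in local_results) / len(local_results)

# Correct: pool the counts first (what gathering enables), then divide once.
total_correct = sum(c for c, _ in local_results)
total_samples = sum(n for _, n in local_results)
global_accuracy = total_correct / total_samples

print(mean_of_local, global_accuracy)
```

Here the mean of local accuracies is about 0.70, while the accuracy over all 12 samples is 10/12, about 0.83.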

Hooks (Optional Override)

Hooks are optional methods you can override to customize behavior at specific points in the training process. All hooks have empty default implementations, so you only need to override the ones you need.

Trainer.on_train_start() → None
Trainer.on_train_end() → None
Trainer.on_train_epoch_start() → None
Trainer.on_train_epoch_end() → None
Trainer.on_validation_start() → None
Trainer.on_validation_end() → None
Trainer.get_val_metrics() → dict[str, Any]
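The validation hooks combine naturally with validation_step(): reset accumulators in on_validation_start(), accumulate in validation_step(), and turn the totals into a metrics dict in get_val_metrics(). The sketch below uses a hypothetical TrainerStub in place of accel_hydra.Trainer so it runs without the library, replaces tensors with plain lists, and skips gather_for_metrics(). The "val/accuracy" key is likewise an illustrative choice, and it assumes get_val_metrics() is called after the validation loop.

```python
class TrainerStub:
    """Hypothetical stand-in for accel_hydra.Trainer; subclass the real class instead."""

    def on_validation_start(self):
        pass

    def get_val_metrics(self):
        return {}


class AccuracyTrainer(TrainerStub):
    def on_validation_start(self):
        # Reset counters at the start of every validation loop so results
        # from earlier epochs do not leak into this one.
        self.validation_stats = {"accurate": 0, "num_elems": 0}

    def validation_step(self, batch, batch_idx):
        # Simplified: the batch already holds (predictions, labels) lists.
        predictions, labels = batch
        matches = [p == y for p, y in zip(predictions, labels)]
        self.validation_stats["accurate"] += sum(matches)
        self.validation_stats["num_elems"] += len(matches)

    def get_val_metrics(self):
        stats = self.validation_stats
        # max(..., 1) avoids dividing by zero when no batches were seen.
        return {"val/accuracy": stats["accurate"] / max(stats["num_elems"], 1)}


trainer = AccuracyTrainer()
trainer.on_validation_start()
trainer.validation_step(([1, 0, 2], [1, 1, 2]), 0)
trainer.validation_step(([3], [3]), 1)
print(trainer.get_val_metrics())  # {'val/accuracy': 0.75}
```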

Example: Complete Trainer Implementation

For a complete example, see Step 3: Define Your Trainer in the Getting Started guide.

Full Trainer API

class accel_hydra.trainer.Trainer(*, config_dict: dict | None = None, project_dir: str | Path, checkpoint_dir: str | Path | None = None, logging_config: LoggingConfig | None = None, train_dataloader: DataLoader, val_dataloader: DataLoader | None, model: Module, optimizer: Optimizer, lr_scheduler: LRScheduler, loss_fn: Module, epochs: int, epoch_length: int | None = None, lr_scheduler_interval: LRSchedulerInterval = LRSchedulerInterval.STEP, gradient_accumulation_steps: int = 1, max_grad_norm: float | None = 2.0, resume_from_checkpoint: str | Path | None = None, log_every_n_steps: int | None = 50, save_every_n_steps: int | None = None, permanent_save_every_n_steps: int | None = None, save_every_n_epochs: int | None = 1, save_last_k: int | None = 1, metric_monitor: MetricMonitor | None = None, early_stop: int | None = None, even_batches: bool = True)

Bases: CheckpointMixin

Base trainer class providing training workflow management.

Usage:

This is an abstract base class. Subclasses must implement training_step() and validation_step() methods to define the specific training and validation logic.

config_dict

Configuration dictionary for storing training configuration information.

Type:

dict | None

project_dir

Project root directory path for saving training-related files.

Type:

str | pathlib.Path

checkpoint_dir

Checkpoint save directory. If None, uses project_dir/checkpoints.

Type:

str | pathlib.Path | None

logging_config

Logging configuration object for experiment logging (wandb/swanlab/tensorboard).

Type:

accel_hydra.trainer.base.LoggingConfig | None

train_dataloader

Training data loader.

Type:

torch.utils.data.dataloader.DataLoader

val_dataloader

Validation data loader. Can be None to skip validation.

Type:

torch.utils.data.dataloader.DataLoader | None

model

PyTorch model to be trained.

Type:

torch.nn.modules.module.Module

optimizer

Optimizer for training.

Type:

torch.optim.optimizer.Optimizer

lr_scheduler

Learning rate scheduler.

Type:

torch.optim.lr_scheduler.LRScheduler

loss_fn

Loss function.

Type:

torch.nn.modules.module.Module

epochs

Total number of training epochs.

Type:

int

epoch_length

Number of steps per epoch. If None, uses the length of train_dataloader.

Type:

int | None

lr_scheduler_interval

Learning rate scheduler update interval. STEP means update every step, EPOCH means update every epoch.

Type:

accel_hydra.trainer.base.LRSchedulerInterval

gradient_accumulation_steps

Number of gradient accumulation steps, used to simulate a larger batch size.

Type:

int
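The arithmetic behind gradient accumulation fits in a few lines. The helper names are hypothetical, and the per-epoch count assumes one optimizer update every gradient_accumulation_steps batches, with a trailing partial window still producing an update; how the Trainer handles that boundary is not documented here.

```python
import math


def effective_batch_size(per_device_batch_size, num_processes, gradient_accumulation_steps):
    # Samples that contribute to one optimizer update across all processes.
    return per_device_batch_size * num_processes * gradient_accumulation_steps


def optimizer_updates_per_epoch(num_batches, gradient_accumulation_steps):
    # Assumption: a final, incomplete accumulation window still triggers an update.
    return math.ceil(num_batches / gradient_accumulation_steps)


print(effective_batch_size(16, 4, 2))        # 128
print(optimizer_updates_per_epoch(1000, 4))  # 250
```

For example, a per-device batch of 16 on 4 processes with gradient_accumulation_steps=2 behaves like a batch of 128 per update.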

max_grad_norm

Maximum gradient norm for gradient clipping. If None, no gradient clipping is performed.

Type:

float | None
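Clipping by norm rescales all gradients together whenever their combined L2 norm exceeds max_grad_norm. The pure-Python sketch below treats gradients as plain lists of floats and only illustrates the formula; in PyTorch the same idea is provided by torch.nn.utils.clip_grad_norm_ (which also adds a small epsilon to the denominator), and this page does not state which routine the Trainer calls.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale gradient vectors so their combined L2 norm is at most max_norm.

    Returns (possibly scaled gradients, original total norm).
    max_norm=None mirrors max_grad_norm=None: no clipping at all.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if max_norm is None or total_norm <= max_norm:
        return grads, total_norm
    scale = max_norm / total_norm
    return [[g * scale for g in grad] for grad in grads], total_norm


clipped, norm = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=2.0)
print(norm)  # 5.0
```

After clipping, the gradients keep their direction but their combined norm is approximately 2.0.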

resume_from_checkpoint

Path to a checkpoint to resume training from. Defaults to None, meaning training starts from scratch.

Type:

str | pathlib.Path | None

save_every_n_steps

Save checkpoint every N steps. If None, no step-based saving.

Type:

int | None

permanent_save_every_n_steps

Permanently save a checkpoint to project_dir every N steps. If None, no permanent checkpoints are saved. Checkpoints written by save_every_n_steps and save_every_n_epochs are deleted automatically according to the cleanup policy (save_last_k), but checkpoints written by this option are never deleted.

Type:

int | None

save_every_n_epochs

Save checkpoint every N epochs. Default is 1 (save every epoch).

Type:

int | None

save_last_k

Keep the last K checkpoints and delete older ones. Default is 1 (keep only the latest checkpoint).

Type:

int | None
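A keep-last-K cleanup can be sketched as below. clean_to_last_k mirrors the (checkpoints_dir, k) parameters of clean_checkpoints_to_k() but is a hypothetical helper, not its implementation. The "step_<N>" directory naming and the "None keeps everything" rule are assumptions made for the example.

```python
import shutil
import tempfile
from pathlib import Path


def clean_to_last_k(checkpoints_dir, k):
    """Delete all but the newest k checkpoint directories; return the names kept."""
    # Hypothetical naming scheme: one sub-directory per checkpoint, "step_<N>".
    # Sorting on the integer N (not the string) keeps step_1000 after step_300.
    ckpts = sorted(
        (p for p in Path(checkpoints_dir).iterdir()
         if p.is_dir() and p.name.startswith("step_")),
        key=lambda p: int(p.name.split("_", 1)[1]),
    )
    if k is None:  # assumption: None means "keep everything"
        return [p.name for p in ckpts]
    num_to_delete = max(len(ckpts) - k, 0)
    for path in ckpts[:num_to_delete]:
        shutil.rmtree(path)
    return [p.name for p in ckpts[num_to_delete:]]


with tempfile.TemporaryDirectory() as tmp:
    for step in (100, 200, 300, 1000):
        (Path(tmp) / f"step_{step}").mkdir()
    kept = clean_to_last_k(tmp, k=2)
    remaining = sorted(p.name for p in Path(tmp).iterdir())

print(kept)  # ['step_300', 'step_1000']
```

Sorting numerically matters: a plain string sort would place step_1000 before step_200 and delete the newest checkpoint.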

metric_monitor

MetricMonitor instance for tracking validation metrics and saving the best model.

Type:

accel_hydra.trainer.base.MetricMonitor | None

early_stop

Early stopping patience. Stop training if the monitored validation metric does not improve for N consecutive epochs.

Type:

int | None
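The patience rule amounts to a small counter. EarlyStopper is a hypothetical helper, not part of accel_hydra; it assumes strict improvement (a tie does not reset the counter) and a max or min direction, whereas the Trainer's actual comparison is governed by metric_monitor and may differ.

```python
class EarlyStopper:
    """Hypothetical patience counter: stop after `patience` epochs without improvement."""

    def __init__(self, patience, mode="max"):
        self.patience = patience  # None disables early stopping
        self.mode = mode          # "max" for accuracy-like, "min" for loss-like metrics
        self.best = None
        self.bad_epochs = 0

    def update(self, value):
        """Record one epoch's metric; return True if training should stop."""
        improved = self.best is None or (
            value > self.best if self.mode == "max" else value < self.best
        )
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.patience is not None and self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=2, mode="max")
decisions = [stopper.update(acc) for acc in [0.70, 0.75, 0.74, 0.75, 0.73]]
print(decisions)  # [False, False, False, True, True]
```

With patience 2, the fourth epoch is the second in a row without beating 0.75, so a training loop would stop there.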

even_batches

Whether to pad the data so that every process receives the same number of batches when the data cannot be split evenly across processes. Must be set to False when the DataLoader's batch_sampler has no batch_size attribute.

Type:

bool
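The effect of even batches can be illustrated with a simplified sharding scheme: batches are dealt round-robin to processes, and with even_batches the index list is first padded by wrapping around to the start of the dataset. This is a hypothetical sketch, not Accelerate's exact algorithm. It does show why the padding arithmetic needs a known batch size, which is plausibly why even_batches must be False when the batch_sampler has none, and it produces duplicated samples of the kind that gather_for_metrics() drops.

```python
def shard(num_samples, num_processes, batch_size, even_batches=True):
    """Return, per process, the list of index batches it would see (simplified)."""
    indices = list(range(num_samples))
    if even_batches:
        chunk = num_processes * batch_size
        remainder = num_samples % chunk
        if remainder:
            # Pad by wrapping around to the start of the dataset (duplicates).
            pad = chunk - remainder
            indices += [i % num_samples for i in range(pad)]
    batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
    # Deal batches out round-robin: batch j goes to process j % num_processes.
    return [batches[rank::num_processes] for rank in range(num_processes)]


uneven = shard(7, num_processes=2, batch_size=3, even_batches=False)
even = shard(7, num_processes=2, batch_size=3, even_batches=True)
print([len(b) for b in uneven])  # [2, 1]
print([len(b) for b in even])    # [2, 2]
```

Without padding, one process runs an extra batch, which is exactly the kind of mismatch that can leave the other process waiting at a collective operation.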

config_dict: dict | None = None
project_dir: str | Path
checkpoint_dir: str | Path = None
logging_config: LoggingConfig | None = None
train_dataloader: DataLoader
val_dataloader: DataLoader | None
model: Module
optimizer: Optimizer
lr_scheduler: LRScheduler
loss_fn: Module
epochs: int
epoch_length: int | None = None
lr_scheduler_interval: LRSchedulerInterval = 'step'
gradient_accumulation_steps: int = 1
max_grad_norm: float | None = 2.0
resume_from_checkpoint: str | Path | None = None
log_every_n_steps: int | None = 50
save_every_n_steps: int | None = None
permanent_save_every_n_steps: int | None = None
save_every_n_epochs: int | None = 1
save_last_k: int | None = 1
metric_monitor: MetricMonitor | None = None
early_stop: int | None = None
even_batches: bool = True
wrap_and_broadcast_value(value: Any) → Tensor
setup_accelerator() → None
abstract training_step(batch: Any, batch_idx: int) → Tensor

abstract validation_step(batch: Any, batch_idx: int) → None


get_context() → contextmanager

gather_min_length(length: int) → int
val_loop() → None
on_validation_start() → None
on_validation_end() → None
get_val_metrics() → dict[str, Any]
on_train_epoch_start() → None
on_train_epoch_end() → None
property checkpoint_objects: list[CheckpointMixin]

Returns a list of additional objects to be included in checkpoints.

This property allows subclasses to specify additional objects (beyond the Trainer itself) that should be saved and restored during checkpointing. All objects in the returned list must implement the CheckpointMixin interface (i.e., have state_dict() and load_state_dict() methods). The customized checkpointing is achieved by registering these objects with the Accelerate framework during setup_accelerator().

Returns:

A list of objects to include in checkpoints. Default is an empty list. Subclasses can override this property to return custom objects.

Return type:

list[CheckpointMixin]

Example

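import torch
from accel_hydra.trainer import CheckpointMixin, Trainer

class VersionTracker(CheckpointMixin):
    def __init__(self):
        self.version = torch.__version__

    def state_dict(self) -> dict:
        return {"version": self.version}

    def load_state_dict(self, state_dict: dict) -> None:
        self.version = state_dict["version"]

class MyTrainer(Trainer):
    @property
    def checkpoint_objects(self) -> list[CheckpointMixin]:
        # Create the tracker once and reuse it, so the object whose state is
        # restored from a checkpoint is the same one the trainer keeps using.
        if not hasattr(self, "_version_tracker"):
            self._version_tracker = VersionTracker()
        return [self._version_tracker]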

state_dict() → dict
load_state_dict(state_dict: dict) → None
clean_checkpoints_to_k(checkpoints_dir: Path | str, k: int) → None
save_checkpoint(save_dir: Path | str, clean_old_checkpoints: bool = True) → None

Note: because this method calls wait_for_everyone(), the caller is responsible for making sure that either all processes call it at the same point or none of them do. Otherwise the processes that did call it will wait forever.
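The deadlock risk can be reproduced with threads standing in for processes and a threading.Barrier standing in for wait_for_everyone(). The function names are hypothetical, and the barrier is given a timeout so the sketch terminates; the real call has no timeout and would simply hang. The safe pattern is to derive the save decision only from state that is identical on every process, such as the global step.

```python
import threading


def save_checkpoint_stub(barrier, rank, log, timeout=1.0):
    # Stand-in for save_checkpoint(): the barrier plays the role of
    # wait_for_everyone(). A real barrier has no timeout and would hang forever.
    try:
        barrier.wait(timeout=timeout)
        log.append((rank, "saved"))
    except threading.BrokenBarrierError:
        log.append((rank, "timed out"))


def simulate(step, world_size=2, save_every_n_steps=100, only_rank_zero=False):
    barrier = threading.Barrier(world_size)
    log, threads = [], []
    for rank in range(world_size):
        # Safe trigger: depends only on the global step, so every rank agrees.
        should_save = step % save_every_n_steps == 0
        if only_rank_zero:
            # Unsafe trigger: a rank-dependent condition splits the processes.
            should_save = should_save and rank == 0
        if should_save:
            thread = threading.Thread(target=save_checkpoint_stub, args=(barrier, rank, log))
            thread.start()
            threads.append(thread)
    for thread in threads:
        thread.join()
    return sorted(log)


both_ranks = simulate(step=200)
rank_zero_only = simulate(step=200, only_rank_zero=True)
print(both_ranks)      # [(0, 'saved'), (1, 'saved')]
print(rank_zero_only)  # [(0, 'timed out')]
```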

train_loop() → None
on_train_start() → None
on_train_end() → None
train(seed: int) → None
__init__(*, config_dict: dict | None = None, project_dir: str | Path, checkpoint_dir: str | Path | None = None, logging_config: LoggingConfig | None = None, train_dataloader: DataLoader, val_dataloader: DataLoader | None, model: Module, optimizer: Optimizer, lr_scheduler: LRScheduler, loss_fn: Module, epochs: int, epoch_length: int | None = None, lr_scheduler_interval: LRSchedulerInterval = LRSchedulerInterval.STEP, gradient_accumulation_steps: int = 1, max_grad_norm: float | None = 2.0, resume_from_checkpoint: str | Path | None = None, log_every_n_steps: int | None = 50, save_every_n_steps: int | None = None, permanent_save_every_n_steps: int | None = None, save_every_n_epochs: int | None = 1, save_last_k: int | None = 1, metric_monitor: MetricMonitor | None = None, early_stop: int | None = None, even_batches: bool = True) → None