Trainer
The accel_hydra.Trainer class is the core component of AccelHydra. It provides a structured training loop with customizable hooks, automatic checkpoint management, logging, and distributed training support.
Overview
The accel_hydra.Trainer class wraps the standard PyTorch training workflow and integrates with Accelerate for distributed training and Hydra for configuration management. To use it, you need to inherit from accel_hydra.Trainer and implement the required abstract methods.
Training Loop Flow
The training process follows this high-level flow:
+---------------------------------------------------------+
| train(seed) |
| - Sets random seed |
| - Calls setup_accelerator() |
| - on_train_start() |
+---------------------------+-----------------------------+
|
v
+---------------------------------------------------------+
| Inside each epoch: |
| |
| 1. on_train_epoch_start() |
| 2. Training Loop: |
| - For each batch: |
| * training_step(batch, batch_idx) |
| * Backward pass |
| * Gradient clipping |
| * Optimizer step |
| * LR scheduler step (if step-based) |
| * Checkpoint saving (if triggered) |
| 3. Validation Loop (if val_dataloader provided): |
| - on_validation_start() |
| - For each batch: |
| * validation_step(batch, batch_idx) |
| - on_validation_end() |
| 4. LR scheduler step (if epoch-based) |
| 5. Checkpoint saving (if epoch-based trigger) |
| 6. Best model saving (if metric_monitor improved) |
| 7. Early stopping check |
| 8. on_train_epoch_end() |
+---------------------------+-----------------------------+
|
v
+---------------------------------------------------------+
| - on_train_end() |
+---------------------------------------------------------+
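The ordering in the diagram above can be sketched with a small stand-in class that records each call. This is a simplified illustration, not the real Trainer: `LoopSketch` and its constructor arguments are hypothetical, and checkpoint saving, best model saving, and early stopping are left out for brevity.

```python
class LoopSketch:
    """Hypothetical stand-in that records the order of calls; not the real Trainer."""

    def __init__(self, epochs, train_batches, val_batches, step_based_scheduler=True):
        self.epochs = epochs
        self.train_batches = train_batches
        self.val_batches = val_batches
        self.step_based_scheduler = step_based_scheduler
        self.calls = []

    def _record(self, name):
        self.calls.append(name)

    def train(self):
        self._record("on_train_start")
        for _ in range(self.epochs):
            self._record("on_train_epoch_start")
            for batch_idx in range(self.train_batches):
                self._record(f"training_step[{batch_idx}]")
                self._record("backward+clip+optimizer_step")
                if self.step_based_scheduler:
                    # Step-based interval: the scheduler advances after every batch.
                    self._record("lr_scheduler_step")
            if self.val_batches:
                self._record("on_validation_start")
                for batch_idx in range(self.val_batches):
                    self._record(f"validation_step[{batch_idx}]")
                self._record("on_validation_end")
            if not self.step_based_scheduler:
                # Epoch-based interval: the scheduler advances once per epoch.
                self._record("lr_scheduler_step")
            self._record("on_train_epoch_end")
        self._record("on_train_end")
        return self.calls


calls = LoopSketch(epochs=1, train_batches=2, val_batches=1).train()
print("\n".join(calls))
```

Passing `step_based_scheduler=False` moves the scheduler call from inside the batch loop to after validation, which matches items 2 and 4 of the diagram.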
Core Methods
Abstract Methods (Must Implement)
- abstract Trainer.training_step(batch: Any, batch_idx: int) -> Tensor
Performs a single training step, like training_step() in PyTorch Lightning.
This method is called for each batch during training. Subclasses must implement this method to define the forward pass, loss computation, and other optional operations. The returned loss will be automatically used for backpropagation.
- Parameters:
batch – A batch of data from the training DataLoader.
batch_idx – The index of the current batch within the current epoch (0-indexed). This can be useful for logging or conditional logic based on batch position.
- Returns:
The computed loss tensor, which the Trainer passes to loss.backward(). The tensor should be 0-dimensional (a scalar). The Trainer automatically logs the loss as "train/loss".
- Return type:
Tensor
Example
    def training_step(self, batch, batch_idx):
        features, labels = batch
        preds = self.model(features)
        loss = self.loss_fn(preds, labels)

        # Optional: Log additional metrics
        lr = self.optimizer.param_groups[0]["lr"]
        self.accelerator.log({"train/lr": lr}, step=self.step)

        return loss
Note
Do NOT call loss.backward() manually; the Trainer handles this automatically.
- abstract Trainer.validation_step(batch: Any, batch_idx: int) -> None
Performs a single validation step, like validation_step() in PyTorch Lightning.
This method is called for each batch during validation. Subclasses must implement this method to define the prediction step and any metric computation. Metrics computed here can later be used for learning rate scheduling or early stopping.
- Parameters:
batch – A batch of data from the validation DataLoader.
batch_idx – The index of the current batch within the validation loop (0-indexed). This can be useful for logging or conditional logic based on batch position.
- Returns:
This method should not return anything. Store validation results in instance variables for later use in get_val_metrics().
- Return type:
None
Example
    def validation_step(self, batch, batch_idx):
        features, labels = batch
        preds = self.model(features)
        predictions = preds.argmax(dim=-1)

        # Gather predictions from all processes (important for distributed training)
        output = {"predictions": predictions, "labels": labels}
        output = self.accelerator.gather_for_metrics(output)

        # Accumulate metrics
        accurate_preds = (output["predictions"] == output["labels"])
        self.validation_stats["accurate"] += accurate_preds.long().sum()
        self.validation_stats["num_elems"] += accurate_preds.shape[0]
Note
Use self.accelerator.gather_for_metrics() to collect predictions from all processes before computing metrics, otherwise discrepancies between processes may result in deadlocks.
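The accumulate-then-report pattern used in the validation example can be sketched without any framework. The source names get_val_metrics() but does not show its signature, so the `AccuracyStats` class, the `reset()` method, and the returned dictionary shape below are all assumptions made for illustration.

```python
class AccuracyStats:
    """Hypothetical accumulator mirroring the validation_stats dict in the example."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the counters so that results from different epochs do not mix.
        self.accurate = 0
        self.num_elems = 0

    def update(self, predictions, labels):
        # In a distributed run these would be the already-gathered values.
        matches = [int(p == y) for p, y in zip(predictions, labels)]
        self.accurate += sum(matches)
        self.num_elems += len(matches)

    def get_val_metrics(self):
        # Assumed return shape: a dict mapping metric name to a float.
        if self.num_elems == 0:
            return {"val/accuracy": 0.0}
        return {"val/accuracy": self.accurate / self.num_elems}


stats = AccuracyStats()
stats.update([1, 0, 2], [1, 1, 2])  # 2 of 3 correct
stats.update([0], [0])              # 1 of 1 correct
metrics = stats.get_val_metrics()
print(metrics)
```

One plausible place to call a reset like this is a hook that runs before validation, such as on_validation_start() from the flow diagram, though the source does not prescribe it.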
Hooks (Optional Override)
Hooks are optional methods you can override to customize behavior at specific points in the training process. All hooks have empty default implementations, so you only need to override the ones you need.
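As a rough illustration of how no-op hooks behave, the sketch below uses a stand-in base class rather than the real Trainer. The hook names are taken from the flow diagram; `BaseWithHooks`, `EpochCounter`, and `run()` are hypothetical, and the assumption that hooks take no arguments is not confirmed by the source.

```python
class BaseWithHooks:
    """Hypothetical stand-in for the hook mechanism; not the real Trainer."""

    def on_train_start(self):
        pass

    def on_train_epoch_start(self):
        pass

    def on_train_epoch_end(self):
        pass

    def on_train_end(self):
        pass

    def run(self, epochs):
        # Call every hook at its fixed point, whether or not it was overridden.
        self.on_train_start()
        for _ in range(epochs):
            self.on_train_epoch_start()
            self.on_train_epoch_end()
        self.on_train_end()


class EpochCounter(BaseWithHooks):
    """Overrides only the hooks it needs; the rest keep their no-op defaults."""

    def __init__(self):
        self.finished_epochs = 0
        self.done = False

    def on_train_epoch_end(self):
        self.finished_epochs += 1

    def on_train_end(self):
        self.done = True


counter = EpochCounter()
counter.run(epochs=3)
print(counter.finished_epochs, counter.done)
```

Because the defaults do nothing, a subclass stays small: it only carries the behavior it actually adds.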
Example: Complete Trainer Implementation
For a complete example, see Step 3: Define Your Trainer in the Getting Started guide.
Full Trainer API
- class accel_hydra.trainer.Trainer(*, config_dict: dict | None = None, project_dir: str | Path, checkpoint_dir: str | Path | None = None, logging_config: LoggingConfig | None = None, train_dataloader: DataLoader, val_dataloader: DataLoader | None, model: Module, optimizer: Optimizer, lr_scheduler: LRScheduler, loss_fn: Module, epochs: int, epoch_length: int | None = None, lr_scheduler_interval: LRSchedulerInterval = LRSchedulerInterval.STEP, gradient_accumulation_steps: int = 1, max_grad_norm: float | None = 2.0, resume_from_checkpoint: str | Path | None = None, log_every_n_steps: int | None = 50, save_every_n_steps: int | None = None, permanent_save_every_n_steps: int | None = None, save_every_n_epochs: int | None = 1, save_last_k: int | None = 1, metric_monitor: MetricMonitor | None = None, early_stop: int | None = None, even_batches: bool = True)
Bases: CheckpointMixin
Base trainer class providing training workflow management.
- Usage:
This is an abstract base class. Subclasses must implement training_step() and validation_step() methods to define the specific training and validation logic.
- config_dict
Configuration dictionary for storing training configuration information.
- Type:
dict | None
- project_dir
Project root directory path for saving training-related files.
- Type:
str | pathlib.Path
- checkpoint_dir
Checkpoint save directory. If None, uses project_dir/checkpoints.
- Type:
str | pathlib.Path | None
- logging_config
Logging configuration object for experiment logging (wandb/swanlab/tensorboard).
- Type:
accel_hydra.trainer.base.LoggingConfig | None
- train_dataloader
Training data loader.
- val_dataloader
Validation data loader. Can be None to skip validation.
- Type:
DataLoader | None
- model
PyTorch model to be trained.
- optimizer
Optimizer for training.
- lr_scheduler
Learning rate scheduler.
- loss_fn
Loss function.
- epochs
Total number of training epochs.
- Type:
int
- epoch_length
Number of steps per epoch. If None, uses the length of train_dataloader.
- Type:
int | None
- lr_scheduler_interval
Learning rate scheduler update interval. STEP means update every step, EPOCH means update every epoch.
- Type:
accel_hydra.trainer.base.LRSchedulerInterval
- gradient_accumulation_steps
Number of gradient accumulation steps to simulate larger batch size.
- Type:
int
- max_grad_norm
Maximum gradient norm for gradient clipping. If None, no gradient clipping is performed.
- Type:
float | None
- resume_from_checkpoint
Path to checkpoint for resuming training. None by default, meaning training from scratch.
- Type:
str | pathlib.Path | None
- save_every_n_steps
Save checkpoint every N steps. If None, no step-based saving.
- Type:
int | None
- permanent_save_every_n_steps
Permanently save a checkpoint to project_dir every N steps. If None, no permanent checkpoints are saved. Checkpoints written by save_every_n_steps and save_every_n_epochs are deleted automatically according to the cleaning strategy; permanent checkpoints are not.
- Type:
int | None
- save_every_n_epochs
Save checkpoint every N epochs. Default is 1 (save every epoch).
- Type:
int | None
- save_last_k
Keep the last K checkpoints and delete older ones. Default is 1 (keep only the latest checkpoint).
- Type:
int | None
- metric_monitor
MetricMonitor instance for tracking validation metrics and saving best model.
- Type:
accel_hydra.trainer.base.MetricMonitor | None
- early_stop
Early stopping patience. Stop training if validation metric doesn’t improve for N consecutive epochs.
- Type:
int | None
- even_batches
Whether to use even batches to handle unequal amounts of data across processes. Must be set to False when the batch_sampler does not have a batch_size.
- Type:
bool
- logging_config: LoggingConfig | None = None
- train_dataloader: DataLoader
- val_dataloader: DataLoader | None
- model: Module
- optimizer: Optimizer
- lr_scheduler: LRScheduler
- loss_fn: Module
- epochs: int
- lr_scheduler_interval: LRSchedulerInterval = 'step'
- gradient_accumulation_steps: int = 1
- metric_monitor: MetricMonitor | None = None
- even_batches: bool = True
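The early_stop patience rule described in the attribute list above can be sketched as a small standalone function. This is a minimal illustration under stated assumptions: `epochs_until_stop` and `higher_is_better` are hypothetical names, ties are treated as "no improvement", and how MetricMonitor actually compares values is not shown in this document.

```python
def epochs_until_stop(metric_history, patience, higher_is_better=True):
    """Return the 1-based epoch at which training would stop, or None if it never stops."""
    best = None
    epochs_without_improvement = 0
    for epoch, value in enumerate(metric_history, start=1):
        if best is None:
            improved = True
        elif higher_is_better:
            improved = value > best
        else:
            improved = value < best

        if improved:
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            # Stop once the metric has failed to improve for `patience` epochs in a row.
            if epochs_without_improvement >= patience:
                return epoch
    return None


history = [0.61, 0.70, 0.69, 0.70, 0.68]
stop_epoch = epochs_until_stop(history, patience=3)
print(stop_epoch)
```

In this history the best value is reached at epoch 2 and is only matched, never beaten, afterwards, so a patience of 3 is exhausted at epoch 5.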
- abstract training_step(batch: Any, batch_idx: int) -> Tensor
Performs a single training step, like training_step() in PyTorch Lightning.
- abstract validation_step(batch: Any, batch_idx: int) -> None
Performs a single validation step, like validation_step() in PyTorch Lightning.
- get_context() -> contextmanager
FIXME: why does it not work?
- property checkpoint_objects: list[CheckpointMixin]
Returns a list of additional objects to be included in checkpoints.
This property allows subclasses to specify additional objects (beyond the Trainer itself) that should be saved and restored during checkpointing. All objects in the returned list must implement the CheckpointMixin interface (i.e., have state_dict() and load_state_dict() methods). The customized checkpointing is achieved by registering these objects with the Accelerate framework during setup_accelerator().
- Returns:
A list of objects to include in checkpoints. Default is an empty list. Subclasses can override this property to return custom objects.
- Return type:
list[CheckpointMixin]
Example
    import torch
    from accel_hydra.trainer import CheckpointMixin, Trainer


    class VersionTracker(CheckpointMixin):
        def __init__(self):
            self.version = torch.__version__

        def state_dict(self) -> dict:
            return {"version": self.version}

        def load_state_dict(self, state_dict: dict) -> None:
            self.version = state_dict["version"]


    class MyTrainer(Trainer):
        @property
        def checkpoint_objects(self) -> list[CheckpointMixin]:
            return [VersionTracker()]
- save_checkpoint(save_dir: Path | str, clean_old_checkpoints: bool = True) -> None
Note: because this method calls wait_for_everyone, the caller is responsible for making sure that either all processes call it at the same point or none do.
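The rotation implied by clean_old_checkpoints and save_last_k can be pictured with a small filesystem sketch. Everything here is an assumption made for illustration: the helper name `prune_checkpoints`, the `step_<N>` directory naming, and the rule "keep the K highest step numbers" are hypothetical, and the library's actual cleaning strategy is not shown in this document.

```python
import tempfile
from pathlib import Path


def prune_checkpoints(checkpoint_dir, save_last_k):
    """Delete all but the newest `save_last_k` checkpoint folders; return the kept names."""
    # Assumed layout: one sub-directory per checkpoint, named "step_<N>".
    checkpoints = sorted(
        (p for p in Path(checkpoint_dir).iterdir()
         if p.is_dir() and p.name.startswith("step_")),
        key=lambda p: int(p.name.split("_")[1]),
    )
    if save_last_k is None:
        # Assumption: None means "keep everything".
        return [p.name for p in checkpoints]

    stale = checkpoints[:-save_last_k] if save_last_k > 0 else checkpoints
    for path in stale:
        for child in path.iterdir():
            child.unlink()
        path.rmdir()
    return [p.name for p in checkpoints if p not in stale]


with tempfile.TemporaryDirectory() as tmp:
    for step in (100, 200, 1000):
        folder = Path(tmp) / f"step_{step}"
        folder.mkdir()
        (folder / "state.txt").write_text("placeholder")
    kept = prune_checkpoints(tmp, save_last_k=2)
    remaining = sorted(p.name for p in Path(tmp).iterdir())
print(kept)
```

Sorting by the parsed step number rather than by name matters here: a plain string sort would place `step_1000` before `step_200` and delete the wrong folder.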