# Source code for accel_hydra.trainer.base

import shutil
from abc import ABC, abstractmethod
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Literal

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from accelerate import DistributedDataParallelKwargs
from accelerate.utils import broadcast, set_seed

# from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import DataLoader
from tqdm import trange

from ..utils import is_package_available
from ..utils.accelerate import AcceleratorSaveTrainableParams


[docs] @dataclass(kw_only=True) class LoggingConfig: report_to: str | None = "wandb" # "wandb" | "swanlab" | "tensorboard" project: str save_dir: str | Path name: str resume_id: str | None = None workspace: str | None = None # organization name in SwanLab def __post_init__(self): self.supported_loggers = ("wandb", "swanlab", "tensorboard") if self.report_to not in self.supported_loggers: raise ValueError( f"Unsupported logger: {self.report_to}. Supported loggers are {self.supported_loggers}." ) if not is_package_available(self.report_to): raise ValueError( f"{self.report_to} is not installed. Please install {self.report_to} using `pip install {self.report_to}`." )
class LRSchedulerInterval(str, Enum):
    """Granularity at which the learning-rate scheduler is stepped.

    Members are also plain strings, so they compare equal to "epoch"/"step".
    """

    EPOCH = "epoch"
    STEP = "step"
class CheckpointMixin(ABC):
    """Interface for objects whose state can be saved to and restored from a dict."""

    @abstractmethod
    def state_dict(self) -> dict:
        """Return the object's state as a dictionary."""

    @abstractmethod
    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the object's state from `state_dict`."""
@dataclass(kw_only=True)
class MetricMonitor(CheckpointMixin):
    """Track the best value seen so far for one named metric.

    Attributes:
        metric_name: Key looked up in the metric dictionary passed to `__call__`.
        mode: "min" if smaller values are better, "max" if larger are better.

    After construction the instance also carries `best_value` (starts at
    +inf for "min", -inf for "max") and `worse_count` (number of consecutive
    calls without improvement).
    """

    metric_name: str = "loss"
    mode: Literal["min", "max"] = "min"

    def __post_init__(self):
        if self.mode not in ("min", "max"):
            raise ValueError("Mode must be 'min' or 'max'.")
        # Start from the worst possible value so the first observation wins.
        worst_by_mode = {"min": np.inf, "max": -np.inf}
        self.best_value = worst_by_mode[self.mode]
        self.worse_count = 0

    def compare(self, x: float, best_x: float) -> bool:
        """Return whether `x` is strictly better than `best_x` under `mode`."""
        if self.mode == "min":
            return x < best_x
        return x > best_x

    def __call__(self, metric_dict: dict[str, Any]) -> bool:
        """Record a new observation.

        Args:
            metric_dict: Mapping that contains `metric_name`. A tensor value
                is converted with `.item()` first.

        Returns:
            True if the value improved on `best_value` (which is then updated
            and `worse_count` reset), otherwise False (`worse_count` grows by 1).
        """
        current = metric_dict[self.metric_name]
        if isinstance(current, torch.Tensor):
            current = current.item()

        if not self.compare(current, self.best_value):
            self.worse_count += 1
            return False

        self.best_value = current
        self.worse_count = 0
        return True

    def state_dict(self) -> dict:
        """Return `mode`, `best_value` and `worse_count` as a dictionary."""
        return {
            key: getattr(self, key)
            for key in ("mode", "best_value", "worse_count")
        }

    def load_state_dict(self, state_dict: dict):
        """Restore `mode`, `best_value` and `worse_count` from `state_dict`."""
        for key in ("mode", "best_value", "worse_count"):
            setattr(self, key, state_dict[key])
@dataclass(kw_only=True)
class Trainer(CheckpointMixin):
    """Base trainer class providing training workflow management.

    Usage:
        This is an abstract base class. Subclasses must implement
        `training_step()` and `validation_step()` methods to define the
        specific training and validation logic.

    Attributes:
        config_dict: Configuration dictionary for storing training
            configuration information.
        project_dir: Project root directory path for saving training-related
            files.
        checkpoint_dir: Checkpoint save directory. If None, uses
            project_dir/checkpoints.
        logging_config: Logging configuration object for experiment logging
            (wandb/swanlab/tensorboard). If None, no tracker is initialised.
        train_dataloader: Training data loader.
        val_dataloader: Validation data loader. Can be None to skip
            validation. This field has no default, so it must always be
            passed explicitly.
        model: PyTorch model to be trained.
        optimizer: Optimizer for training.
        lr_scheduler: Learning rate scheduler.
        loss_fn: Loss function.
        epochs: Total number of training epochs.
        epoch_length: Number of steps per epoch. If None, uses the length of
            train_dataloader.
        lr_scheduler_interval: Learning rate scheduler update interval. STEP
            means update every step, EPOCH means update every epoch.
        gradient_accumulation_steps: Number of gradient accumulation steps to
            simulate larger batch size.
        max_grad_norm: Maximum gradient norm for gradient clipping. If None,
            no gradient clipping is performed.
        resume_from_checkpoint: Path to checkpoint for resuming training.
            None by default, meaning training from scratch.
        log_every_n_steps: Log "train/loss" and "train/grad_norm" every N
            steps.
        save_every_n_steps: Save checkpoint every N steps. If None, no
            step-based saving.
        permanent_save_every_n_steps: Permanently save checkpoint to
            project_dir every N steps. If None, no permanent saving.
            Checkpoints saved by `save_every_n_steps` and
            `save_every_n_epochs` will be automatically deleted based on
            cleaning strategies but these checkpoints will not be deleted.
        save_every_n_epochs: Save checkpoint every N epochs. Default is 1
            (save every epoch).
        save_last_k: Keep the last K checkpoints and delete older ones.
            Default is 1 (keep only the latest checkpoint).
        metric_monitor: `MetricMonitor` instance for tracking validation
            metrics and saving best model.
        early_stop: Early stopping patience. Stop training if validation
            metric doesn't improve for N consecutive epochs.
        even_batches: Whether to use even batches for handling inconsistent
            data amounts across processes. Must set to False when
            batch_sampler does not have batch_size.
    """

    config_dict: dict | None = None
    project_dir: str | Path
    # None is a legal value (it is also the default), so the annotation says so.
    checkpoint_dir: str | Path | None = None
    logging_config: LoggingConfig | None = None
    train_dataloader: DataLoader
    val_dataloader: DataLoader | None
    model: nn.Module
    optimizer: torch.optim.Optimizer
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler
    loss_fn: nn.Module
    epochs: int
    epoch_length: int | None = None
    lr_scheduler_interval: LRSchedulerInterval = LRSchedulerInterval.STEP
    gradient_accumulation_steps: int = 1
    max_grad_norm: float | None = 2.0
    resume_from_checkpoint: str | Path | None = None
    log_every_n_steps: int | None = 50
    save_every_n_steps: int | None = None
    permanent_save_every_n_steps: int | None = None
    save_every_n_epochs: int | None = 1
    save_last_k: int | None = 1
    metric_monitor: MetricMonitor | None = None
    early_stop: int | None = None
    # use_stateful_dataloader: bool = False
    even_batches: bool = True
def wrap_and_broadcast_value(self, value: Any) -> torch.Tensor:
    """Turn `value` into a tensor and broadcast process 0's copy to every process.

    Used to make all processes agree on a decision (e.g. whether to save a
    checkpoint) that was computed locally.

    Args:
        value: Anything accepted by `torch.tensor` (typically a bool or number).

    Returns:
        A tensor on the accelerator device holding the value from process 0.
    """
    value = torch.tensor(value, device=self.accelerator.device)
    # Use the tensor returned by `broadcast` rather than relying on the input
    # being updated in place.
    # NOTE(review): in-place mutation is presumably backend-dependent — confirm
    # against the accelerate version in use.
    return broadcast(value, from_process=0)
def setup_accelerator(self) -> None:
    """Create the accelerator, prepare the training objects and optionally resume.

    Steps, in order:
      1. Pick the tracker (a `SwanLabTracker` instance for "swanlab",
         otherwise the backend name) when a logging config is present.
      2. Build `self.accelerator`.
      3. Check that dataloaders whose batch sampler lacks `batch_size` are
         used with `even_batches=False` and `drop_last=True`.
      4. Run `accelerator.prepare` on the dataloaders, model, optimizer and
         scheduler.
      5. Register this trainer and every entry of `checkpoint_objects` for
         checkpointing.
      6. Load `resume_from_checkpoint` if it is set.
    """

    def _check_sampler(sampler, drop_last_message):
        # Samplers exposing `batch_size` need no extra constraints.
        if hasattr(sampler, "batch_size"):
            return
        assert self.even_batches is False, "even_batches must be False when batch_sampler does not have batch_size"
        # due to this line: https://github.com/huggingface/accelerate/blob/main/src/accelerate/data_loader.py#L246
        assert getattr(sampler, "drop_last", False) is True, drop_last_message

    log_cfg = self.logging_config
    tracker = None
    if log_cfg is not None:
        assert self.logging_config.report_to in ("wandb", "swanlab", "tensorboard"), (
            f"Unsupported logger: {self.logging_config.report_to}. "
            "Supported loggers are 'wandb', 'swanlab', and 'tensorboard'."
        )
        if log_cfg.report_to != "swanlab":
            tracker = log_cfg.report_to
        else:
            from swanlab.integration.accelerate import SwanLabTracker

            tracker = SwanLabTracker(
                run_name=log_cfg.project,
                experiment_name=log_cfg.name,
                logdir=log_cfg.save_dir,
                workspace=log_cfg.workspace,
                resume=True,
                id=log_cfg.resume_id,
            )

    # dataloader_config = DataLoaderConfiguration(
    #     use_stateful_dataloader=self.use_stateful_dataloader
    #     even_batches=self.even_batches
    # )
    step_with_optimizer = self.lr_scheduler_interval == LRSchedulerInterval.STEP
    self.accelerator = AcceleratorSaveTrainableParams(
        log_with=tracker,
        gradient_accumulation_steps=self.gradient_accumulation_steps,
        project_dir=self.project_dir,
        step_scheduler_with_optimizer=step_with_optimizer,
        # dataloader_config=dataloader_config,
        kwargs_handlers=[
            DistributedDataParallelKwargs(find_unused_parameters=True)
        ],
    )

    _check_sampler(
        self.train_dataloader.batch_sampler,
        "drop_last must be True when batch_sampler does not have batch_size",
    )
    if self.val_dataloader is not None:
        _check_sampler(
            self.val_dataloader.batch_sampler,
            "drop_last must be True when batch_sampler does not have batch_size and even_batches is False",
        )

    self.accelerator.even_batches = self.even_batches

    # TODO when `loss_fn` does not have named_parameters/buffers, loading will raise error
    prepared = self.accelerator.prepare(
        self.train_dataloader,
        self.model,
        self.optimizer,
        self.lr_scheduler,
    )
    (
        self.train_dataloader,
        self.model,
        self.optimizer,
        self.lr_scheduler,
    ) = prepared
    if self.val_dataloader is not None:
        self.val_dataloader = self.accelerator.prepare(self.val_dataloader)

    # The trainer itself goes first, then any user-supplied objects.
    self.accelerator.register_for_checkpointing(self)
    for extra in self.checkpoint_objects:
        self.accelerator.register_for_checkpointing(extra)

    if self.resume_from_checkpoint is None:
        return
    self.accelerator.print(
        f"resume from checkpoint: {self.resume_from_checkpoint}"
    )
    self.accelerator.load_state(self.resume_from_checkpoint, strict=False)
@abstractmethod
def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
    """Run one training step, like `training_step()` in Pytorch-Lightning.

    Called once per batch during training. Subclasses implement the forward
    pass and loss computation here (plus any optional extras); the returned
    loss is what the Trainer backpropagates.

    Args:
        batch: A batch of data from the training DataLoader.
        batch_idx: Index passed by the training loop, counted from 0 over
            the steps it runs in the current epoch. Handy for logging or
            position-dependent logic.

    Returns:
        torch.Tensor: The loss, expected to be a 0-dimensional tensor. It is
        passed to the backward call and logged as "train/loss" by the
        Trainer.

    Example:
        .. code-block:: python

            def training_step(self, batch, batch_idx):
                features, labels = batch
                preds = self.model(features)
                loss = self.loss_fn(preds, labels)

                # Optional: Log additional metrics
                lr = self.optimizer.param_groups[0]["lr"]
                self.accelerator.log({"train/lr": lr}, step=self.step)

                return loss

    Note:
        - Do NOT call `loss.backward()` yourself; the Trainer does it.
    """
    raise NotImplementedError("Subclasses must implement this method")
@abstractmethod
def validation_step(self, batch: Any, batch_idx: int) -> None:
    """Run one validation step, like `validation_step()` in Pytorch-Lightning.

    Called once per batch during validation. Subclasses implement prediction
    and, if wanted, metric accumulation here so the metric can later drive
    checkpoint selection or early stopping.

    Args:
        batch: A batch of data from the validation DataLoader.
        batch_idx: Index of the batch within the validation loop (0-indexed).
            Handy for logging or position-dependent logic.

    Returns:
        None: Nothing is returned. Keep results on the instance so that
        `get_val_metrics()` can report them afterwards.

    Example:
        .. code-block:: python

            def validation_step(self, batch, batch_idx):
                features, labels = batch
                preds = self.model(features)
                predictions = preds.argmax(dim=-1)

                # Gather predictions from all processes (important for distributed training)
                output = {"predictions": predictions, "labels": labels}
                output = self.accelerator.gather_for_metrics(output)

                # Accumulate metrics
                accurate_preds = (output["predictions"] == output["labels"])
                self.validation_stats["accurate"] += accurate_preds.long().sum()
                self.validation_stats["num_elems"] += accurate_preds.shape[0]

    Note:
        - Use `self.accelerator.gather_for_metrics()` to collect predictions
          from all processes before computing metrics, otherwise
          discrepancies between processes may result in deadlocks.
    """
    raise NotImplementedError("Subclasses must implement this method")
def get_context(self) -> contextmanager:
    """Return the context manager to wrap loops that may see uneven inputs.

    With `even_batches` enabled a no-op `nullcontext()` is returned;
    otherwise the accelerator's `join_uneven_inputs` context for the model.

    FIXME: why does it not work?
    """
    if not self.even_batches:
        return self.accelerator.join_uneven_inputs([self.model])
    return nullcontext()
def gather_min_length(self, length: int) -> int:
    """Return the smallest `length` reported by any process.

    Performs a MIN all-reduce, so every process must call this together and
    `torch.distributed` must already be initialised.
    """
    shared = torch.tensor(length, device=self.accelerator.device)
    dist.all_reduce(shared, op=dist.ReduceOp.MIN)
    minimum = shared.item()
    return minimum
def val_loop(self) -> None:
    """Run one full pass over the validation dataloader.

    The model is switched to eval mode and autograd is disabled for the
    duration of the pass. `on_validation_start()` runs before the first
    batch and `on_validation_end()` after the last one. When
    `torch.distributed` is initialised, the number of batches is the minimum
    dataloader length across processes so all ranks run the same number of
    steps.

    Train mode and gradient tracking are restored even if a hook or
    `validation_step()` raises, so a failed validation pass cannot leave the
    process with gradients silently disabled.
    """
    self.model.eval()
    torch.set_grad_enabled(False)
    try:
        self.on_validation_start()

        dataloader_len = len(self.val_dataloader)
        if dist.is_initialized():
            dataloader_len = self.gather_min_length(dataloader_len)

        self.val_data_iterator = iter(self.val_dataloader)

        # Only the main process shows a progress bar.
        if self.accelerator.is_main_process:
            range_iterator = trange(dataloader_len, desc="Validation")
        else:
            range_iterator = range(dataloader_len)

        for batch_idx in range_iterator:
            batch = next(self.val_data_iterator)
            self.validation_step(batch, batch_idx)

        self.on_validation_end()
    finally:
        # Same restore order as the success path always had.
        self.model.train()
        torch.set_grad_enabled(True)
def on_validation_start(self) -> None:
    """Hook called by `val_loop()` before the first validation batch. No-op by default."""
    return None
def on_validation_end(self) -> None:
    """Hook called by `val_loop()` after the last validation batch. No-op by default."""
    return None
def get_val_metrics(self) -> dict[str, Any]:
    """Return the metrics collected during validation.

    The default implementation reports nothing (an empty dict). Subclasses
    that use a `metric_monitor` should return a dict containing the
    monitored metric name.
    """
    metrics: dict[str, Any] = dict()
    return metrics
def on_train_epoch_start(self) -> None:
    """Hook called at the beginning of every `train_loop()` call. No-op by default."""
    return None
def on_train_epoch_end(self) -> None:
    """Hook called at the very end of every `train_loop()` call. No-op by default."""
    return None
@property
def checkpoint_objects(self) -> list[CheckpointMixin]:
    """Extra objects to save and restore together with the trainer.

    Subclasses override this property to list additional objects (besides
    the Trainer itself) whose state belongs in a checkpoint. Every entry
    must implement the `CheckpointMixin` interface, i.e. provide
    `state_dict()` and `load_state_dict()`. The entries are registered with
    the accelerator during `setup_accelerator()`.

    Returns:
        list[CheckpointMixin]: Objects to include in checkpoints. Empty by
        default.

    Example:
        .. code-block:: python

            import torch
            from accel_hydra.trainer import CheckpointMixin

            class VersionTracker(CheckpointMixin):
                def __init__(self):
                    self.version = torch.__version__

                def state_dict(self) -> dict:
                    return {"version": self.version}

                def load_state_dict(self, state_dict: dict) -> None:
                    self.version = state_dict["version"]

            class MyTrainer(Trainer):
                @property
                def checkpoint_objects(self) -> list[CheckpointMixin]:
                    return [VersionTracker()]
    """
    return list()
def state_dict(self) -> dict:
    """Return the trainer's progress (epoch, step) and, if present, monitor state.

    Returns:
        A dict with keys "epoch" and "step", plus "metric_monitor" when a
        metric monitor is configured.
    """
    snapshot = {"epoch": self.epoch, "step": self.step}
    # NOTE: dataloader state is not captured; the StatefulDataLoader path is
    # disabled.
    # FIXME: after `accelerator.prepare`, how to determine if
    # `train_dataloader` is `StatefulDataLoader`?
    monitor = self.metric_monitor
    if monitor is not None:
        snapshot["metric_monitor"] = monitor.state_dict()
    return snapshot
def load_state_dict(self, state_dict: dict) -> None:
    """Restore training progress (and monitor state) from `state_dict`.

    Args:
        state_dict: Dict with "epoch" and "step", and optionally
            "metric_monitor".

    Saved monitor state is applied only when this trainer actually has a
    `metric_monitor`. A checkpoint written with a monitor can therefore be
    resumed by a trainer configured without one; the saved monitor state is
    then ignored instead of raising `AttributeError` on `None`.
    """
    self.epoch = state_dict["epoch"]
    self.step = state_dict["step"]
    # NOTE: dataloader state is not restored; the StatefulDataLoader path is
    # disabled.
    if "metric_monitor" in state_dict and self.metric_monitor is not None:
        self.metric_monitor.load_state_dict(state_dict["metric_monitor"])
[docs] def clean_checkpoints_to_k( self, checkpoints_dir: Path | str, k: int ) -> None: checkpoints_dir = Path(checkpoints_dir) checkpoints = ( list(checkpoints_dir.glob("epoch_*")) + list(checkpoints_dir.glob("step_*")) ) # sort `checkpoints` by their last modified timestamp (ascending order) checkpoints.sort(key=lambda x: x.stat().st_mtime) if k > 0: to_delete = checkpoints[:-k] if len(checkpoints) > k else [] elif k == 0: to_delete = checkpoints for checkpoint in to_delete: shutil.rmtree(checkpoint)
[docs] def save_checkpoint( self, save_dir: Path | str, clean_old_checkpoints: bool = True ) -> None: """ Note: since `wait_for_everyone` is called, user must be responsible for making sure all processes call or not call this function at the same time!!! """ self.accelerator.wait_for_everyone() if self.accelerator.is_main_process: save_dir = Path(save_dir) if clean_old_checkpoints and self.save_last_k: self.clean_checkpoints_to_k( save_dir.parent, self.save_last_k - 1 ) self.accelerator.wait_for_everyone() self.accelerator.save_state(save_dir) self.accelerator.wait_for_everyone()
def train_loop(self) -> None:
    """Run one training epoch followed by validation and end-of-epoch bookkeeping.

    Per step: fetch a batch (restarting the dataloader iterator when it is
    exhausted), call `training_step()`, backpropagate, clip gradients when
    gradients are synchronised, step the optimizer (and the scheduler when
    the interval is STEP), then handle step-based checkpointing.

    After the last step: run validation if a validation dataloader exists,
    advance `self.epoch`, step the scheduler when the interval is EPOCH,
    handle epoch-based checkpointing, save a "best" checkpoint when the
    monitored metric improved, and flag early stopping.

    Logging of "train/loss" and "train/grad_norm" happens every
    `log_every_n_steps` steps and is skipped entirely when that setting is
    `None` or 0.
    """
    torch.set_grad_enabled(True)
    self.model.train()
    self.on_train_epoch_start()

    # Steps remaining until the end of the current epoch; fewer than
    # `epoch_length` when `self.step` is already past the epoch start.
    epoch_steps = (self.epoch + 1) * self.epoch_length - self.step
    if dist.is_initialized():
        # All ranks must run the same number of steps.
        epoch_steps = self.gather_min_length(epoch_steps)

    if self.accelerator.is_main_process:
        range_iterator = trange(
            epoch_steps, desc=f"Epoch {self.epoch + 1}/{self.epochs}"
        )
    else:
        range_iterator = range(epoch_steps)

    for batch_idx in range_iterator:
        try:
            batch = next(self.train_data_iterator)
        except StopIteration:
            # The dataloader ran out before the epoch did; start it again.
            self.train_data_iterator = iter(self.train_dataloader)
            batch = next(self.train_data_iterator)

        with self.accelerator.accumulate(self.model):
            loss = self.training_step(batch, batch_idx)

            # `log_every_n_steps` may legally be None (or 0): never log then,
            # instead of failing on the modulo.
            should_log = (
                bool(self.log_every_n_steps)
                and self.step % self.log_every_n_steps == 0
            )
            if should_log:
                self.accelerator.log({"train/loss": loss.item()},
                                     step=self.step)

            self.accelerator.backward(loss)

            # gradient clipping and logging
            if self.accelerator.sync_gradients:
                if self.max_grad_norm:
                    grad_norm = self.accelerator.clip_grad_norm_(
                        self.model.parameters(), self.max_grad_norm
                    )
                else:
                    # An infinite threshold clips nothing but still returns
                    # the total norm for logging.
                    grad_norm = nn.utils.clip_grad_norm_(
                        self.model.parameters(), float('inf')
                    )
                if should_log:
                    self.accelerator.log({"train/grad_norm": grad_norm},
                                         step=self.step)

            self.optimizer.step()
            if self.lr_scheduler_interval == LRSchedulerInterval.STEP:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()

        self.step += 1

        if self.save_every_n_steps:
            should_save_checkpoint = self.step % self.save_every_n_steps == 0
            if should_save_checkpoint:
                self.save_checkpoint(
                    self.checkpoint_dir / f"step_{self.step}"
                )

        # FIXME `self.epoch` may be not set properly at this step
        if self.permanent_save_every_n_steps:
            should_save_checkpoint = (
                self.step % self.permanent_save_every_n_steps == 0
            )
            if should_save_checkpoint:
                self.save_checkpoint(
                    self.project_dir / f"ckpt_step_{self.step}",
                    clean_old_checkpoints=False
                )

    if self.val_dataloader is not None:
        self.val_loop()
    else:
        self.accelerator.print("No validation data, skipping validation")

    self.epoch += 1

    if self.lr_scheduler_interval == LRSchedulerInterval.EPOCH:
        self.lr_scheduler.step()

    if self.save_every_n_epochs:
        # Process 0 decides so that every rank takes the same branch.
        should_save_checkpoint = self.wrap_and_broadcast_value(
            self.epoch % self.save_every_n_epochs == 0
        )
        if should_save_checkpoint:
            self.accelerator.print("\n Saving latest checkpoint...")
            self.save_checkpoint(
                self.checkpoint_dir / f"epoch_{self.epoch}"
            )

    if self.val_dataloader is not None:
        metric_dict: dict = self.get_val_metrics()
        if self.metric_monitor is not None:
            # save checkpoint if the monitored metric improves
            should_save_checkpoint = self.wrap_and_broadcast_value(
                self.metric_monitor(metric_dict)
            )
            if should_save_checkpoint:
                self.accelerator.print("\n Saving best checkpoint...")
                self.save_checkpoint(self.checkpoint_dir / "best")

            if (
                self.early_stop is not None
                and self.metric_monitor.worse_count >= self.early_stop
            ):
                self.should_stop_training = True

    self.on_train_epoch_end()
def on_train_start(self) -> None:
    """Prepare directories, counters, the data iterator and trackers before training.

    - Converts `project_dir` / `checkpoint_dir` to `Path` objects and
      creates them (the checkpoint directory defaults to
      `project_dir/"checkpoints"`).
    - Initialises `epoch` and `step` to 0 unless they already exist, which
      is the case after resuming from a checkpoint.
    - Fills in `epoch_length` from the training dataloader when unset and
      creates the training data iterator.
    - Initialises the trackers when a logging config is present.
    - Asserts that early stopping is not requested without a metric monitor
      when validation is enabled.
    """
    project_dir = Path(self.project_dir)
    self.project_dir = project_dir
    project_dir.mkdir(parents=True, exist_ok=True)

    if self.checkpoint_dir:
        self.checkpoint_dir = Path(self.checkpoint_dir)
    else:
        self.checkpoint_dir = self.project_dir / "checkpoints"
    self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    self.accelerator.print(
        f"{self.accelerator.state.num_processes} devices are used in training"
    )

    # if load from previous checkpoint, `epoch` and `step` have been set
    for counter in ("epoch", "step"):
        if not hasattr(self, counter):
            setattr(self, counter, 0)
    self.should_stop_training = False

    # set up `epoch_length` and training data iterator
    if self.epoch_length is None:
        self.epoch_length = len(self.train_dataloader)
    self.train_data_iterator = iter(self.train_dataloader)

    self.accelerator.print("training start ............")

    log_cfg = self.logging_config
    if log_cfg is not None:
        wandb_kwargs = {
            "name": log_cfg.name,
            "dir": log_cfg.save_dir,
            "id": log_cfg.resume_id,
            "resume": "allow",
        }
        self.accelerator.init_trackers(
            log_cfg.project, init_kwargs={"wandb": wandb_kwargs}
        )

    validates_without_monitor = (
        self.val_dataloader is not None and self.metric_monitor is None
    )
    if validates_without_monitor:
        assert self.early_stop is None, "early stop does not have metrics to monitor!"
def on_train_end(self) -> None:
    """Finish training: close the trackers and make sure a wandb run is ended."""
    accelerator = self.accelerator
    accelerator.print("training end ............")
    accelerator.end_training()

    # wandb sometimes stuck in finishing
    if not is_package_available("wandb"):
        return

    import wandb

    if wandb.run is not None:
        wandb.finish()
def train(self, seed: int) -> None:
    """Entry point: seed, set up, run the epochs and finish.

    Args:
        seed: Random seed passed to `set_seed`.

    The epoch loop starts at `self.epoch` (non-zero after resuming) and
    stops early once `should_stop_training` is set by `train_loop()`.
    """
    set_seed(seed)
    self.setup_accelerator()
    self.on_train_start()

    # The range is fixed here; `train_loop()` advances `self.epoch` itself.
    remaining_epochs = range(self.epoch, self.epochs)
    for _ in remaining_epochs:
        self.train_loop()
        if self.should_stop_training:
            break

    self.on_train_end()