Source code for accel_hydra.models.common

from pathlib import Path

import torch

from ..utils.torch import load_pretrained_model, merge_matched_keys


class LoadPretrainedMixin:
    """Mixin that loads pretrained checkpoints through an overridable state-dict hook."""

    def process_state_dict(
        self,
        model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor],
    ):
        """
        Turn a checkpoint `state_dict` into one that `load_state_dict` can consume.

        This is the per-model customization point: subclasses override it when
        checkpoint entries need to be transformed before loading. The default
        implementation delegates to `merge_matched_keys`, which updates the
        parameters whose names and shapes match.

        Args:
            model_dict: The state dict of the current model, i.e. the model
                that is about to receive the pretrained parameters.
            state_dict: A dictionary of parameters from a pre-trained model.

        Returns:
            dict[str, torch.Tensor]: The updated state dict, where parameters
            with matched keys and shapes carry the values from `state_dict`.
        """
        return merge_matched_keys(model_dict, state_dict)

    def load_pretrained(self, ckpt_path: str | Path):
        """
        Load pretrained parameters from `ckpt_path` into this model.

        The work is delegated to `load_pretrained_model`, which receives this
        object's `process_state_dict` as its state-dict processing callback.
        Whatever the helper returns is discarded; this method returns `None`.

        Args:
            ckpt_path: Location of the checkpoint to load.
        """
        state_dict_hook = self.process_state_dict
        load_pretrained_model(
            self, ckpt_path, state_dict_process_fn=state_dict_hook
        )
class CountParamsMixin:
    """Mixin that reports parameter counts for a host exposing `parameters()`."""

    def count_params(self):
        """
        Count the elements of all parameters and of the trainable ones.

        Returns:
            A `(num_params, trainable_params)` tuple. The first entry sums
            `numel()` over every parameter yielded by `self.parameters()`;
            the second sums it only over those with `requires_grad` set.
        """
        # Walk the parameters once, then fold the snapshot twice.
        sizes = [(p.numel(), p.requires_grad) for p in self.parameters()]
        total = sum(count for count, _ in sizes)
        trainable = sum(count for count, needs_grad in sizes if needs_grad)
        return total, trainable
class SaveTrainableParamsMixin:
    """
    Mixin that tracks which state-dict entries must be present in a checkpoint.

    The host class is expected to provide the `torch.nn.Module` API used here
    (`state_dict`, `named_parameters`, `named_buffers`, `load_state_dict`),
    so this mixin must precede the module base class in the MRO.
    """

    @property
    def param_names_to_save(self):
        """
        Names of the entries that have to be saved and restored.

        Returns:
            list[str]: Names of all parameters with `requires_grad` set, in
            `named_parameters()` order, followed by the names of the buffers
            that also appear in `state_dict()` (i.e. persistent buffers), in
            `named_buffers()` order.
        """
        state_dict_keys = set(self.state_dict())
        names = [
            name
            for name, param in self.named_parameters()
            if param.requires_grad
        ]
        # Filter buffers in registration order instead of iterating a set
        # intersection, so the result is stable across runs and processes.
        names.extend(
            name for name, _ in self.named_buffers() if name in state_dict_keys
        )
        return names

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """
        Load `state_dict` after checking that every required entry is present.

        Every name in `param_names_to_save` must be a key of `state_dict`,
        regardless of `strict`; other entries are left to the parent
        implementation's own `strict` handling.

        Args:
            state_dict: Mapping from entry names to values to load.
            strict: Forwarded unchanged to the parent `load_state_dict`.
            **kwargs: Extra keyword arguments forwarded to the parent
                `load_state_dict` (for example `assign` on torch versions
                that support it).

        Returns:
            Whatever the parent `load_state_dict` returns.

        Raises:
            RuntimeError: If a required entry is missing from `state_dict`.
                `RuntimeError` is an `Exception` subclass, so handlers that
                catch `Exception` keep working.
        """
        for key in self.param_names_to_save:
            if key not in state_dict:
                raise RuntimeError(
                    f"{key} not found in either pre-trained models (e.g. BERT)"
                    " or resumed checkpoints (e.g. epoch_40/model.pt)"
                )
        return super().load_state_dict(state_dict, strict, **kwargs)