import argparse
from pathlib import Path
from typing import Callable
import hydra
import omegaconf
from omegaconf import OmegaConf
[docs]
def multiply(*args):
    """
    Return the product of all positional arguments.

    The running product starts at 1, so calling this with no arguments
    returns 1 and calling it with a single argument returns ``1 * arg``.

    Args:
        *args: The factors to multiply together, applied left to right.

    Returns:
        The accumulated product.
    """
    product = 1
    for factor in args:
        product *= factor
    return product
[docs]
def register_omegaconf_resolvers(clear_resolvers: bool = True) -> None:
    """
    Register the custom OmegaConf resolvers used by the YAML configs.

    Two resolvers are registered, each with ``replace=True`` so a previous
    registration under the same name is overwritten:

    * ``len``      -> the builtin ``len``
    * ``multiply`` -> this module's ``multiply``

    Args:
        clear_resolvers: If True, call ``OmegaConf.clear_resolvers()`` before
            registering. Set to False to keep resolvers that were registered
            earlier and only add/replace the two above.
    """
    if clear_resolvers:
        OmegaConf.clear_resolvers()

    # Resolver name (as written in YAML interpolations) -> implementation.
    custom_resolvers = (
        ("len", len),
        ("multiply", multiply),
    )
    for resolver_name, resolver_fn in custom_resolvers:
        OmegaConf.register_new_resolver(resolver_name, resolver_fn, replace=True)
[docs]
def load_config_with_overrides(
    config_file: str | Path,
    overrides: list[str],
    register_resolver_fn: Callable = register_omegaconf_resolvers,
) -> dict:
    """
    Compose a Hydra config from a file and return it as a resolved container.

    The resolver-registration callback is invoked first so that custom
    interpolations are available while the config is resolved. The config
    file's directory is used as the Hydra config directory and its file
    name as the config name.

    Args:
        config_file: Path to the config file. It is resolved to an absolute
            path before being split into directory and file name.
        overrides: Hydra override strings passed to ``hydra.compose``.
        register_resolver_fn: Zero-argument callable run before composing;
            defaults to ``register_omegaconf_resolvers``.

    Returns:
        The composed config converted with
        ``OmegaConf.to_container(..., resolve=True)``, i.e. a plain Python
        container with interpolations resolved, not a ``DictConfig``.
    """
    register_resolver_fn()

    config_path = Path(config_file).resolve()
    # The parent of an already-resolved path is itself absolute, which is
    # what ``initialize_config_dir`` is given here.
    config_dir = str(config_path.parent)
    config_name = config_path.name

    with hydra.initialize_config_dir(version_base=None, config_dir=config_dir):
        composed = hydra.compose(config_name=config_name, overrides=overrides)
        config = OmegaConf.to_container(composed, resolve=True)
    return config
[docs]
def load_config_from_cli(
    return_config: bool = True,
    register_resolver_fn: Callable = register_omegaconf_resolvers,
):
    """
    Read the config file path and overrides from the command line.

    Recognized options are ``--config_file``/``-c`` (default
    ``configs/train.yaml``) and ``--overrides``/``-o`` (zero or more values,
    default empty list). Parsing uses ``parse_known_args``, so command-line
    arguments this parser does not define are ignored rather than rejected.

    Args:
        return_config: If True, load the config via
            ``load_config_with_overrides`` and return it. If False, return
            the raw command-line values instead.
        register_resolver_fn: Forwarded to ``load_config_with_overrides``
            when ``return_config`` is True.

    Returns:
        The loaded config when ``return_config`` is True, otherwise the
        tuple ``(config_file, overrides)``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config_file",
        "-c",
        type=str,
        default="configs/train.yaml",
        help="Path to the config file",
    )
    cli.add_argument(
        "--overrides",
        "-o",
        nargs="*",
        default=[],
        help="Overrides to the config",
    )
    known_args, _unknown = cli.parse_known_args()

    # Caller only wants the raw CLI values: skip loading entirely.
    if not return_config:
        return known_args.config_file, known_args.overrides

    return load_config_with_overrides(
        known_args.config_file, known_args.overrides, register_resolver_fn
    )
[docs]
def parse_launch_args():
    """
    Read the launcher entrypoint from the command line.

    The only recognized option is ``--launcher``/``-l``. Parsing uses
    ``parse_known_args``, so any other command-line arguments are ignored
    rather than rejected.

    Returns:
        The value of ``--launcher`` as a string, or
        ``"accel_hydra.train_launcher.TrainLauncher"`` when it is not given.
    """
    launcher_option = {
        "default": "accel_hydra.train_launcher.TrainLauncher",
        "type": str,
        "help": "The entrypoint of the training script to use",
    }
    cli = argparse.ArgumentParser()
    cli.add_argument("--launcher", "-l", **launcher_option)

    known_args, _unknown = cli.parse_known_args()
    return known_args.launcher