Shortcuts

Source code for mmpose.engine.hooks.mode_switch_hooks

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, Optional, Sequence

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner

from mmpose.registry import HOOKS
from mmpose.utils.hooks import rgetattr, rsetattr


@HOOKS.register_module()
class YOLOXPoseModeSwitchHook(Hook):
    """Switch the mode of YOLOX-Pose during training.

    This hook:
    1) Turns off mosaic and mixup data augmentation.
    2) Uses instance mask to assist positive anchor selection.
    3) Uses auxiliary L1 loss in the head.

    The switch is performed once, in ``before_train_epoch``, as soon as
    ``runner.epoch + 1 >= runner.max_epochs - num_last_epochs``. Using
    ``>=`` together with a one-shot flag (instead of an exact equality
    check) means the switch is still applied when training is resumed
    from a checkpoint taken after the switch epoch.

    Args:
        num_last_epochs (int): The number of last epochs at the end of
            training to close the data augmentation and switch to L1 loss.
            Defaults to 20.
        new_train_dataset (dict, optional): New training dataset
            configuration that will be used in place of the original
            training dataset. Defaults to None.
        new_train_pipeline (Sequence[dict], optional): New data
            augmentation pipeline configuration that will be used in place
            of the original pipeline during training. Defaults to None.
    """

    def __init__(self,
                 num_last_epochs: int = 20,
                 new_train_dataset: Optional[dict] = None,
                 new_train_pipeline: Optional[Sequence[dict]] = None):
        self.num_last_epochs = num_last_epochs
        self.new_train_dataset = new_train_dataset
        self.new_train_pipeline = new_train_pipeline
        # Set once the switch has been applied so it happens exactly once
        # per hook instance (a resumed run creates a fresh instance and
        # therefore re-applies it).
        self._has_switched = False

    def _modify_dataloader(self, runner: Runner):
        """Rebuild the training dataloader with the new dataset/pipeline.

        The original ``runner.cfg.train_dataloader`` config is deep-copied,
        the dataset and/or pipeline are replaced when provided, and the
        resulting dataloader is installed on ``runner.train_loop``.
        """
        runner.logger.info(f'New Pipeline: {self.new_train_pipeline}')

        train_dataloader_cfg = copy.deepcopy(runner.cfg.train_dataloader)
        if self.new_train_dataset:
            # Copy so that setting the pipeline below can never mutate the
            # dataset config stored on this hook.
            train_dataloader_cfg.dataset = copy.deepcopy(
                self.new_train_dataset)
        if self.new_train_pipeline:
            train_dataloader_cfg.dataset.pipeline = copy.deepcopy(
                self.new_train_pipeline)

        # NOTE(review): no seed is forwarded here, so worker seeding may
        # differ from the dataloader built by the runner — confirm whether
        # ``seed=runner.seed`` should be passed.
        new_train_dataloader = Runner.build_dataloader(train_dataloader_cfg)
        runner.train_loop.dataloader = new_train_dataloader
        runner.logger.info('Recreated the dataloader!')

    def before_train_epoch(self, runner: Runner):
        """Close mosaic and mixup augmentation, switch to use L1 loss."""
        epoch = runner.epoch
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        switch_epoch = runner.max_epochs - self.num_last_epochs
        # In an uninterrupted run this fires on the same epoch as the
        # former ``epoch + 1 == switch_epoch`` check; ``>=`` additionally
        # covers runs resumed past that epoch.
        if not self._has_switched and epoch + 1 >= switch_epoch:
            self._modify_dataloader(runner)
            runner.logger.info('Added additional reg loss now!')
            model.head.use_aux_loss = True
            self._has_switched = True
@HOOKS.register_module()
class RTMOModeSwitchHook(Hook):
    """A hook to switch the mode of RTMO during training.

    This hook allows for dynamic adjustments of model attributes at
    specified training epochs. It is designed to modify configurations
    such as turning off specific augmentations or changing loss functions
    at different stages of the training process.

    Every scheduled change whose epoch has been reached is applied exactly
    once, in ascending epoch order. In an uninterrupted run this means the
    changes for epoch ``e`` are applied at the start of epoch ``e``; when
    training is resumed past ``e``, they are (re-)applied at the start of
    the first resumed epoch.

    Args:
        epoch_attributes (Dict[int, Dict]): A mapping from epoch number to
            the attribute changes to apply to ``model.head``. Epoch keys
            are compared against ``runner.epoch``. Each value is a dict
            that maps a (possibly dotted) attribute path on the head to its
            new value.

    Example:
        epoch_attributes = {
            5: {"attr1.subattr": new_value1, "attr2.subattr": new_value2},
            10: {"attr3.subattr": new_value3}
        }
    """

    def __init__(self, epoch_attributes: Dict[int, Dict]):
        self.epoch_attributes = epoch_attributes
        # Epoch keys whose attribute changes have already been applied by
        # this hook instance.
        self._applied_epochs = set()

    def before_train_epoch(self, runner: Runner):
        """Apply all scheduled attribute changes that are due.

        Checks which keys of ``epoch_attributes`` are ``<=`` the current
        epoch and have not been applied yet, and sets the corresponding
        attributes on ``model.head`` (unwrapping model wrappers first).
        """
        epoch = runner.epoch
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        # Catch up on every due entry rather than only the exact current
        # epoch. The changed values are plain head attributes that are
        # presumably not restored from a checkpoint (not parameters or
        # buffers) — so a resumed run must re-apply them.
        due_epochs = sorted(
            e for e in self.epoch_attributes
            if e <= epoch and e not in self._applied_epochs)
        for switch_epoch in due_epochs:
            for key, value in self.epoch_attributes[switch_epoch].items():
                rsetattr(model.head, key, value)
                runner.logger.info(
                    f'Change model.head.{key} to {rgetattr(model.head, key)}')
            self._applied_epochs.add(switch_epoch)
Read the Docs v: latest
Versions
latest
0.x
dev-1.x
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.