[Improvement] Update NasMutator to build search_space in NAS (#426)
* update space_mixin

* update NAS algorithms with SpaceMixin

* update pruning algorithms with SpaceMixin

* fix ut

* fix comments

* revert _load_fix_subnet_by_mutator

* fix dcff test

* add ut for registry

* update autoslim_greedy_search

* fix repeat-mutables bug

* fix slice_weight in export_fix_subnet

* Update NasMutator:
1. unify mutators for NAS algorithms as the NasMutator;
2. regard ChannelMutator as pruning-specified;
3. remove value_mutators & module_mutators;
4. set GroupMixin only for NAS;
5. revert all changes in ChannelMutator.

* update NAS algorithms using NasMutator

* update channel mutator

* update one_shot_channel_mutator

* fix comments

* update UT for NasMutator

* fix isort version

* fix comments

---------

Co-authored-by: gaoyang07 <[email protected]>
Co-authored-by: liukai <[email protected]>
3 people authored Feb 1, 2023
1 parent b750375 commit a27952d
Showing 71 changed files with 1,126 additions and 1,686 deletions.
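In short: every NAS config that previously assembled a channel mutator plus a value mutator (or a NAS-specific module mutator) now passes a single NasMutator, which builds the whole search space by itself. A minimal before/after sketch, condensed from the config diffs below (`supernet` is defined earlier in the config, as in the Autoformer diff):

# Before: NAS algorithms wired up several mutators by hand.
model = dict(
    type='mmrazor.Autoformer',
    architecture=supernet,
    mutators=dict(
        channel_mutator=dict(
            type='mmrazor.OneShotChannelMutator',
            channel_unit_cfg={
                'type': 'OneShotMutableChannelUnit',
                'default_args': {'unit_predefined': True}
            },
            parse_cfg={'type': 'Predefined'}),
        value_mutator=dict(type='mmrazor.DynamicValueMutator')))

# After: a single NasMutator builds the search space.
model = dict(
    type='mmrazor.Autoformer',
    architecture=supernet,
    mutator=dict(type='mmrazor.NasMutator'))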
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -5,8 +5,8 @@ repos:
     rev: 4.0.1
     hooks:
       - id: flake8
-  - repo: https://github.com/timothycrosley/isort
-    rev: 5.10.1
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.11.5
     hooks:
       - id: isort
   - repo: https://github.com/pre-commit/mirrors-yapf
30 changes: 10 additions & 20 deletions configs/_base_/settings/cifar10_darts_supernet.py
@@ -48,36 +48,26 @@
 
 # optimizer
 optim_wrapper = dict(
     constructor='mmrazor.SeparateOptimWrapperConstructor',
     architecture=dict(
-        type='mmcls.SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
-    mutator=dict(type='mmcls.Adam', lr=3e-4, weight_decay=1e-3),
-    clip_grad=dict(max_norm=5, norm_type=2))
+        optimizer=dict(
+            type='mmcls.SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
+        clip_grad=dict(max_norm=5, norm_type=2)),
+    mutator=dict(
+        optimizer=dict(type='mmcls.Adam', lr=3e-4, weight_decay=1e-3)))
 
+search_epochs = 50
 # learning policy
 # TODO: use a different scheduler per optimizer (waiting for mmengine support)
 param_scheduler = [
     dict(
         type='mmcls.CosineAnnealingLR',
-        T_max=50,
+        T_max=search_epochs,
         eta_min=1e-3,
         begin=0,
-        end=50),
+        end=search_epochs),
 ]
-# param_scheduler = dict(
-#     architecture = dict(
-#         type='mmcls.CosineAnnealingLR',
-#         T_max=50,
-#         eta_min=1e-3,
-#         begin=0,
-#         end=50),
-#     mutator = dict(
-#         type='mmcls.ConstantLR',
-#         factor=1,
-#         begin=0,
-#         end=50))
 
 # train, val, test setting
 # TODO: split the CIFAR dataset
 train_cfg = dict(
     type='mmrazor.DartsEpochBasedTrainLoop',
     mutator_dataloader=dict(
@@ -92,7 +82,7 @@
         sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
         persistent_workers=True,
     ),
-    max_epochs=50)
+    max_epochs=search_epochs)
 
 val_cfg = dict()  # validate each epoch
 test_cfg = dict()  # dataset settings
12 changes: 1 addition & 11 deletions configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py
@@ -53,17 +53,7 @@
     type='mmrazor.Autoformer',
     architecture=supernet,
     fix_subnet=None,
-    mutators=dict(
-        channel_mutator=dict(
-            type='mmrazor.OneShotChannelMutator',
-            channel_unit_cfg={
-                'type': 'OneShotMutableChannelUnit',
-                'default_args': {
-                    'unit_predefined': True
-                }
-            },
-            parse_cfg={'type': 'Predefined'}),
-        value_mutator=dict(type='mmrazor.DynamicValueMutator')))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 # runtime setting
 custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
@@ -44,17 +44,7 @@
             loss_kl=dict(
                 preds_S=dict(recorder='fc', from_student=True),
                 preds_T=dict(recorder='fc', from_student=False)))),
-    mutators=dict(
-        channel_mutator=dict(
-            type='mmrazor.OneShotChannelMutator',
-            channel_unit_cfg={
-                'type': 'OneShotMutableChannelUnit',
-                'default_args': {
-                    'unit_predefined': True
-                }
-            },
-            parse_cfg={'type': 'Predefined'}),
-        value_mutator=dict(type='DynamicValueMutator')))
+    mutators=dict(type='mmrazor.NasMutator'))
 
 model_wrapper_cfg = dict(
     type='mmrazor.BigNASDDP',
20 changes: 5 additions & 15 deletions configs/nas/mmcls/darts/darts_supernet_unroll_1xb96_cifar10.py
@@ -4,9 +4,11 @@
     'mmcls::_base_/default_runtime.py',
 ]
 
-# model
-mutator = dict(type='mmrazor.DiffModuleMutator')
+custom_hooks = [
+    dict(type='mmrazor.DumpSubnetHook', interval=10, by_epoch=True)
+]
 
+# model
 model = dict(
     type='mmrazor.Darts',
     architecture=dict(
@@ -20,24 +22,12 @@
         loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
         topk=(1, 5),
         cal_acc=True)),
-    mutator=dict(type='mmrazor.DiffModuleMutator'),
+    mutator=dict(type='mmrazor.NasMutator'),
     unroll=True)
 
 model_wrapper_cfg = dict(
     type='mmrazor.DartsDDP',
     broadcast_buffers=False,
     find_unused_parameters=False)
-
-# TRAINING
-optim_wrapper = dict(
-    _delete_=True,
-    constructor='mmrazor.SeparateOptimWrapperConstructor',
-    architecture=dict(
-        type='OptimWrapper',
-        optimizer=dict(type='SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
-        clip_grad=dict(max_norm=5, norm_type=2)),
-    mutator=dict(
-        type='OptimWrapper',
-        optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))
 
 find_unused_parameter = False
2 changes: 1 addition & 1 deletion configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py
@@ -23,7 +23,7 @@
             mode='original',
             loss_weight=1.0),
         topk=(1, 5))),
-    mutator=dict(type='mmrazor.DiffModuleMutator'),
+    mutator=dict(type='mmrazor.NasMutator'),
     pretrain_epochs=15,
     finetune_epochs=_base_.search_epochs,
 )
@@ -43,17 +43,7 @@
             loss_kl=dict(
                 preds_S=dict(recorder='fc', from_student=True),
                 preds_T=dict(recorder='fc', from_student=False)))),
-    mutators=dict(
-        channel_mutator=dict(
-            type='mmrazor.OneShotChannelMutator',
-            channel_unit_cfg={
-                'type': 'OneShotMutableChannelUnit',
-                'default_args': {
-                    'unit_predefined': True
-                }
-            },
-            parse_cfg={'type': 'Predefined'}),
-        value_mutator=dict(type='DynamicValueMutator')))
+    mutators=dict(type='mmrazor.NasMutator'))
 
 model_wrapper_cfg = dict(
     type='mmrazor.BigNASDDP',
@@ -25,6 +25,6 @@
 model = dict(
     type='mmrazor.SPOS',
     architecture=supernet,
-    mutator=dict(type='mmrazor.OneShotModuleMutator'))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 find_unused_parameters = True
@@ -25,6 +25,6 @@
 model = dict(
     type='mmrazor.SPOS',
     architecture=supernet,
-    mutator=dict(type='mmrazor.OneShotModuleMutator'))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 find_unused_parameters = True
@@ -25,6 +25,6 @@
     _delete_=True,
     type='mmrazor.SPOS',
     architecture=supernet,
-    mutator=dict(type='mmrazor.OneShotModuleMutator'))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 find_unused_parameters = True
@@ -22,6 +22,6 @@
     _delete_=True,
     type='mmrazor.SPOS',
    architecture=supernet,
-    mutator=dict(type='mmrazor.OneShotModuleMutator'))
+    mutator=dict(type='mmrazor.NasMutator'))
 
 find_unused_parameters = True
1 change: 0 additions & 1 deletion configs/pruning/mmcls/dcff/dcff_resnet_8xb32_in1k.py
@@ -76,7 +76,6 @@
             type='ChannelAnalyzer',
             demo_input=(1, 3, 224, 224),
             tracer_type='BackwardTracer')),
-    fix_subnet=None,
     data_preprocessor=None,
     target_pruning_ratio=target_pruning_ratio,
     step_freq=1,
23 changes: 18 additions & 5 deletions mmrazor/engine/hooks/dump_subnet_hook.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 import os.path as osp
 from pathlib import Path
 from typing import Optional, Sequence, Union
@@ -8,6 +9,9 @@
 from mmengine.hooks import Hook
 from mmengine.registry import HOOKS
 
+from mmrazor.models.mutables.base_mutable import BaseMutable
+from mmrazor.structures import convert_fix_subnet, export_fix_subnet
+
 DATA_BATCH = Optional[Sequence[dict]]


@@ -103,16 +107,25 @@ def after_train_epoch(self, runner) -> None:
 
     @master_only
     def _save_subnet(self, runner) -> None:
-        """Save the current subnet and delete outdated subnet.
+        """Save the current best subnet.
 
         Args:
             runner (Runner): The runner of the training process.
         """
+        model = runner.model.module if runner.distributed else runner.model
 
-        if runner.distributed:
-            subnet_dict = runner.model.module.search_subnet()
-        else:
-            subnet_dict = runner.model.search_subnet()
+        # delete non-leaf tensors so that deepcopy(model) succeeds.
+        # TODO: handle the hard cases.
+        for module in model.architecture.modules():
+            if isinstance(module, BaseMutable):
+                if hasattr(module, 'arch_weights'):
+                    delattr(module, 'arch_weights')
+
+        copied_model = copy.deepcopy(model)
+        copied_model.mutator.set_choices(copied_model.mutator.sample_choices())
+
+        subnet_dict = export_fix_subnet(copied_model)[0]
+        subnet_dict = convert_fix_subnet(subnet_dict)
 
         if self.by_epoch:
             subnet_filename = self.args.get(
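Both hooks now share one export path: strip non-leaf tensors, deepcopy the algorithm, fix a choice per mutable, then serialize. For reference, a sketch of what convert_fix_subnet is expected to do, mirroring the local _convert_fix_subnet helper this commit deletes from the AutoSlim search loop further down (the shared implementation may differ in detail):

from typing import Any, Dict

from mmrazor.utils.typing import DumpChosen


def convert_fix_subnet(fixed_subnet: Dict[str, Any]) -> Dict[str, dict]:
    """Convert DumpChosen records into plain dicts so the subnet can be
    dumped to YAML/JSON."""
    converted_fix_subnet = dict()
    for key, val in fixed_subnet.items():
        assert isinstance(val, DumpChosen)
        converted_fix_subnet[key] = dict(val._asdict())
    return converted_fix_subnet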
8 changes: 5 additions & 3 deletions mmrazor/engine/hooks/estimate_resources_hook.py
@@ -104,7 +104,7 @@ def export_subnet(self, model) -> torch.nn.Module:
         """
         # Avoid circular import
         from mmrazor.models.mutables.base_mutable import BaseMutable
-        from mmrazor.structures import load_fix_subnet
+        from mmrazor.structures import export_fix_subnet, load_fix_subnet
 
         # delete non-leaf tensors so that deepcopy(model) succeeds.
         # TODO: handle the hard cases.
@@ -114,7 +114,9 @@ def export_subnet(self, model) -> torch.nn.Module:
                 delattr(module, 'arch_weights')
 
         copied_model = copy.deepcopy(model)
-        fix_mutable = copied_model.search_subnet()
-        load_fix_subnet(copied_model, fix_mutable)
+        copied_model.mutator.set_choices(copied_model.mutator.sample_choices())
+
+        subnet_dict = export_fix_subnet(copied_model)[0]
+        load_fix_subnet(copied_model, subnet_dict)
 
         return copied_model
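After load_fix_subnet, the copy is an ordinary static network, so standard resource counting applies. A hedged usage sketch of the same flow outside the hook (the estimator wiring is illustrative, not taken from this diff):

import copy

from mmrazor.registry import TASK_UTILS
from mmrazor.structures import export_fix_subnet, load_fix_subnet

# `algorithm` is assumed to be a NAS algorithm instance holding a NasMutator.
copied = copy.deepcopy(algorithm)
copied.mutator.set_choices(copied.mutator.sample_choices())

subnet_dict = export_fix_subnet(copied)[0]   # record the chosen ops/channels
load_fix_subnet(copied, subnet_dict)         # mutate the copy into a static net

estimator = TASK_UTILS.build(dict(type='mmrazor.ResourceEstimator'))
results = estimator.estimate(model=copied)   # e.g. FLOPs and parameter counts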
30 changes: 10 additions & 20 deletions mmrazor/engine/runner/autoslim_greedy_search_loop.py
@@ -11,7 +11,7 @@
 from torch.utils.data import DataLoader
 
 from mmrazor.registry import LOOPS, TASK_UTILS
-from mmrazor.structures import export_fix_subnet
+from mmrazor.structures import convert_fix_subnet, export_fix_subnet
 from .utils import check_subnet_resources

@@ -68,14 +68,15 @@ def __init__(self,
         self.model = runner.model
 
         assert hasattr(self.model, 'mutator')
-        search_groups = self.model.mutator.search_groups
+        units = self.model.mutator.mutable_units
 
         self.candidate_choices = {}
-        for group_id, modules in search_groups.items():
-            self.candidate_choices[group_id] = modules[0].candidate_choices
+        for unit in units:
+            self.candidate_choices[unit.alias] = unit.candidate_choices
 
         self.max_subnet = {}
-        for group_id, candidate_choices in self.candidate_choices.items():
-            self.max_subnet[group_id] = len(candidate_choices)
+        for name, candidate_choices in self.candidate_choices.items():
+            self.max_subnet[name] = len(candidate_choices)
         self.current_subnet = self.max_subnet
 
         current_subnet_choices = self._channel_bins2choices(
@@ -117,7 +118,7 @@ def run(self) -> None:
                 pruned_subnet[unit_name] -= 1
                 pruned_subnet_choices = self._channel_bins2choices(
                     pruned_subnet)
-                self.model.set_subnet(pruned_subnet_choices)
+                self.model.mutator.set_choices(pruned_subnet_choices)
                 metrics = self._val_subnet()
                 score = metrics[self.score_key] \
                     if len(metrics) != 0 else 0.
@@ -195,27 +196,16 @@ def _save_searcher_ckpt(self) -> None:
 
     def _save_searched_subnet(self):
         """Save the final searched subnet dict."""
-
-        def _convert_fix_subnet(fixed_subnet: Dict[str, Any]):
-            from mmrazor.utils.typing import DumpChosen
-
-            converted_fix_subnet = dict()
-            for key, val in fixed_subnet.items():
-                assert isinstance(val, DumpChosen)
-                converted_fix_subnet[key] = dict(val._asdict())
-
-            return converted_fix_subnet
-
         if self.runner.rank != 0:
             return
         self.runner.logger.info('Search finished:')
         for subnet, flops in zip(self.searched_subnet,
                                  self.searched_subnet_flops):
             subnet_choice = self._channel_bins2choices(subnet)
-            self.model.set_subnet(subnet_choice)
+            self.model.mutator.set_choices(subnet_choice)
             fixed_subnet, _ = export_fix_subnet(self.model)
             save_name = 'FLOPS_{:.2f}M.yaml'.format(flops)
-            fixed_subnet = _convert_fix_subnet(fixed_subnet)
+            fixed_subnet = convert_fix_subnet(fixed_subnet)
             fileio.dump(fixed_subnet, osp.join(self.runner.work_dir,
                                                save_name))
             self.runner.logger.info(
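The loop's bookkeeping is now keyed by pruning-unit alias instead of opaque search-group ids. A self-contained toy of the bins-to-choices mapping used above (unit names and candidate lists are made up for the example):

# Each mutable unit exposes an alias and a list of candidate channel choices.
candidate_choices = {
    'backbone.layer1': [16, 32, 48, 64],
    'backbone.layer2': [32, 64, 96, 128],
}

# max_subnet records how many channel bins each unit starts with.
max_subnet = {name: len(choices) for name, choices in candidate_choices.items()}

def channel_bins2choices(subnet_bins):
    """Map 'keep the first k bins' per unit to a concrete channel choice,
    as the loop does before calling model.mutator.set_choices()."""
    return {name: candidate_choices[name][k - 1]
            for name, k in subnet_bins.items()}

print(channel_bins2choices(max_subnet))
# -> {'backbone.layer1': 64, 'backbone.layer2': 128}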