diff --git a/mmpretrain/models/backbones/poolformer.py b/mmpretrain/models/backbones/poolformer.py index e2ad67043db..2ef27c44e3c 100644 --- a/mmpretrain/models/backbones/poolformer.py +++ b/mmpretrain/models/backbones/poolformer.py @@ -220,6 +220,7 @@ class PoolFormer(BaseBackbone): Defaults to ``dict(type='LN2d', eps=1e-6)``. act_cfg (dict): The config dict for activation between pointwise convolution. Defaults to ``dict(type='GELU')``. + in_chans (int): The num of channels of input image. in_patch_size (int): The patch size of input image patch embedding. Defaults to 7. in_stride (int): The stride of input image patch embedding. @@ -285,6 +286,7 @@ def __init__(self, pool_size=3, norm_cfg=dict(type='GN', num_groups=1), act_cfg=dict(type='GELU'), + in_chans=3, in_patch_size=7, in_stride=4, in_pad=2, @@ -320,7 +322,7 @@ def __init__(self, patch_size=in_patch_size, stride=in_stride, padding=in_pad, - in_chans=3, + in_chans=in_chans, embed_dim=embed_dims[0]) # set the main block in network