config.py
import argparse
parser = argparse.ArgumentParser(description='Train prototypical networks')
# data args
parser.add_argument('--data.train', type=str, default='cu_birds', metavar='TRAINSETS', nargs='+', help="Datasets for training extractors")
parser.add_argument('--data.val', type=str, default='cu_birds', metavar='VALSETS', nargs='+',
help="Datasets used for validation")
parser.add_argument('--data.test', type=str, default='cu_birds', metavar='TESTSETS', nargs='+',
help="Datasets used for testing")
parser.add_argument('--data.num_workers', type=int, default=32, metavar='NWORKERS',
help="Number of workers that pre-process images in parallel")
# model args
default_model_name = 'noname'
parser.add_argument('--model.name', type=str, default=default_model_name, metavar='MODELNAME',
help="A name you give to the extractor".format(default_model_name))
parser.add_argument('--model.backbone', type=str, default='resnet18', help="Feature extractor backbone architecture (default: resnet18)")
parser.add_argument('--model.classifier', type=str, default='cosine', choices=['none', 'linear', 'cosine'], help="Classifier head; 'cosine' classifies using cosine similarity between activations and weights")
parser.add_argument('--model.dropout', type=float, default=0, help="Dropout rate applied inside a basic block of the widenet backbone")
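# A minimal sketch of the 'cosine' classifier option above: logits are cosine
# similarities between L2-normalized activations and L2-normalized class weights,
# optionally scaled. PyTorch is assumed to be available; the helper name and the
# `scale` parameter are illustrative, not taken from the original code.
def cosine_logits(features, weights, scale=10.0):
    """features: [B, D], weights: [C, D] -> scaled cosine-similarity logits [B, C]."""
    import torch.nn.functional as F  # imported lazily so config.py itself stays torch-free
    return scale * F.linear(F.normalize(features, dim=-1), F.normalize(weights, dim=-1))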
# train args
parser.add_argument('--train.batch_size', type=int, default=16, metavar='BS',
help='number of images in a batch')
parser.add_argument('--train.max_iter', type=int, default=500000, metavar='MAXITER',
help='maximum number of training iterations (default: 500000)')
parser.add_argument('--train.weight_decay', type=float, default=7e-4, metavar='WD',
help="weight decay coef")
parser.add_argument('--train.optimizer', type=str, default='momentum', metavar='OPTIM',
help='optimization method (default: momentum)')
parser.add_argument('--train.learning_rate', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--train.lr_policy', type=str, default='cosine', metavar='LR_policy',
help='learning rate decay policy')
parser.add_argument('--train.lr_decay_step_gamma', type=float, default=1e-1, metavar='DECAY_GAMMA',
help='multiplicative factor applied to the learning rate at each step decay (default: 0.1)')
parser.add_argument('--train.lr_decay_step_freq', type=int, default=10000, metavar='DECAY_FREQ',
help='number of iterations between step decays of the learning rate')
parser.add_argument('--train.exp_decay_final_lr', type=float, default=8e-5, metavar='FINAL_LR',
help='final learning rate reached by the exponential decay schedule')
parser.add_argument('--train.exp_decay_start_iter', type=int, default=30000, metavar='START_ITER',
help='iteration at which exponential learning rate decay starts')
parser.add_argument('--train.cosine_anneal_freq', type=int, default=4000, metavar='ANNEAL_FREQ',
help='period (in iterations) of the cosine annealing learning rate schedule')
parser.add_argument('--train.nesterov_momentum', action='store_true', help="Use Nesterov momentum with the momentum optimizer")
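# A rough sketch (an assumption, not the repo's actual scheduler) of how the
# learning-rate hyperparameters above could combine into a single schedule; the
# policy names 'step' and 'exp' are guesses implied by the parameter names.
import math
def lr_at(iteration, a):
    """Learning rate at `iteration` given the flat parsed-args dict `a`."""
    base = a['train.learning_rate']
    if a['train.lr_policy'] == 'step':
        # multiply by gamma every lr_decay_step_freq iterations
        return base * a['train.lr_decay_step_gamma'] ** (iteration // a['train.lr_decay_step_freq'])
    if a['train.lr_policy'] == 'exp':
        # decay exponentially from `base` down to exp_decay_final_lr after the start iteration
        if iteration <= a['train.exp_decay_start_iter']:
            return base
        span = max(a['train.max_iter'] - a['train.exp_decay_start_iter'], 1)
        progress = (iteration - a['train.exp_decay_start_iter']) / span
        return base * (a['train.exp_decay_final_lr'] / base) ** progress
    # default 'cosine': annealing with warm restarts every cosine_anneal_freq iterations
    t = (iteration % a['train.cosine_anneal_freq']) / a['train.cosine_anneal_freq']
    return 0.5 * base * (1 + math.cos(math.pi * t))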
# evaluation during training
parser.add_argument('--train.eval_freq', type=int, default=5000, metavar='EVAL_FREQ',
help='How often to evaluate model during training')
parser.add_argument('--train.eval_size', type=int, default=300, metavar='EVAL_SIZE',
help='How many episodes to sample for validation')
parser.add_argument('--train.resume', type=int, default=1, metavar='RESUME_TRAIN',
help="Resume training starting from the last checkpoint (default: True)")
# creating a database of features
parser.add_argument('--dump.name', type=str, default='', metavar='DUMP_NAME',
help='Name for dumped dataset of features')
parser.add_argument('--dump.mode', type=str, default='test', metavar='DUMP_MODE',
help='What split of the original dataset to dump')
parser.add_argument('--dump.size', type=int, default=600, metavar='DUMP_SIZE',
help='How many episodes to dump')
# test args
parser.add_argument('--test.size', type=int, default=600, metavar='TEST_SIZE',
help='The number of test episodes sampled')
parser.add_argument('--test.distance', type=str, choices=['cos', 'l2'], default='cos', metavar='DISTANCE_FN',
help="Distance function used at test time: 'cos' (cosine similarity) or 'l2' (Euclidean)")
# parse all arguments into a flat dict keyed by the dotted argument names
args = vars(parser.parse_args())
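# Usage sketch (an assumption, not from the original repo): because every flag uses a
# dotted 'group.name' key, the flat `args` dict can be regrouped by prefix, e.g.
# group_args(args)['train']['learning_rate'] instead of args['train.learning_rate'].
def group_args(flat_args):
    """Split {'train.learning_rate': 0.01, ...} into {'train': {'learning_rate': 0.01}, ...}."""
    grouped = {}
    for key, value in flat_args.items():
        prefix, _, name = key.partition('.')
        grouped.setdefault(prefix, {})[name] = value
    return grouped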