-
Notifications
You must be signed in to change notification settings - Fork 24
/
inat2017.py
85 lines (71 loc) · 3.72 KB
/
inat2017.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import json
import os
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import check_integrity, extract_archive
from torchvision.datasets.utils import download_url, verify_str_arg
class INat2017(VisionDataset):
    """`iNaturalist 2017 <https://github.com/visipedia/inat_comp/blob/master/2017/README.md>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train`` or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Only the missing parts (images and/or annotations) are
            fetched and extracted.
    """

    # Directory (relative to ``root``) created by extracting the image archive.
    base_folder = 'train_val_images/'
    # key -> (download URL, archive filename under ``root``, MD5 of the archive)
    file_list = {
        'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
                 'train_val_images.tar.gz',
                 '7c784ea5e424efaec655bd392f87301f'),
        'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
                  'train_val2017.zip',
                  '444c835f6459867ad69fcb36478786e7')
    }

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))

        if self._check_exists():
            print('Files already downloaded and verified.')
        elif download:
            self._download_and_extract()
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')

        # Explicit encoding so loading does not depend on the platform locale.
        with open(self._anno_path(), 'r', encoding='utf-8') as fp:
            all_annos = json.load(fp)
        self.annos = all_annos['annotations']
        self.images = all_annos['images']
        # Resolve each image's label through its id rather than assuming the
        # annotations list is ordered exactly like the images list.
        # Raises KeyError if some image has no annotation.
        category_by_image = {anno['image_id']: anno['category_id'] for anno in self.annos}
        self.targets = [category_by_image[img['id']] for img in self.images]

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        ``image`` is loaded with ``self.loader`` and passed through ``transform``;
        ``target`` is the image's ``category_id`` passed through ``target_transform``.
        """
        path = os.path.join(self.root, self.images[index]['file_name'])
        target = self.targets[index]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.images)

    def _anno_path(self):
        """Return the path of the JSON annotation file for ``self.split``."""
        return os.path.join(self.root, self.split + '2017.json')

    def _images_exist(self):
        """Return True if the extracted image directory is present under ``root``."""
        return os.path.isdir(os.path.join(self.root, self.base_folder))

    def _check_exists(self):
        """Return True if both the images and this split's annotations are extracted."""
        return self._images_exist() and os.path.isfile(self._anno_path())

    def _download_and_extract(self):
        """Download (with MD5 verification) and extract only the missing dataset parts."""
        missing = []
        if not self._images_exist():
            missing.append('imgs')
        if not os.path.isfile(self._anno_path()):
            missing.append('annos')
        print('Downloading...')
        self._download(missing)
        print('Extracting...')
        for key in missing:
            # extract_archive unpacks next to the archive, i.e. into ``root``.
            extract_archive(os.path.join(self.root, self.file_list[key][1]))

    def _download(self, keys=None):
        """Download the archives named by ``keys`` (default: all) into ``root``.

        Archives already present with a matching MD5 are not downloaded again;
        a present-but-corrupt archive is re-downloaded.

        Raises:
            RuntimeError: if an archive is missing or fails its MD5 check afterwards.
        """
        if keys is None:
            keys = list(self.file_list)
        for key in keys:
            url, filename, md5 = self.file_list[key]
            download_url(url, root=self.root, filename=filename, md5=md5)
            if not check_integrity(os.path.join(self.root, filename), md5):
                raise RuntimeError("File not found or corrupted.")
if __name__ == '__main__':
    # Smoke test: build both splits from an already-downloaded local copy.
    data_root = './inat2017'
    train_dataset = INat2017(data_root, split='train', download=False)
    test_dataset = INat2017(data_root, split='val', download=False)