[Fix] Associative Embedding inference align with master #1960

Draft · wants to merge 3 commits into base: dev-1.x
@@ -96,7 +96,7 @@
decoder=dict(codec, heatmap_size=codec['input_size'])),
test_cfg=dict(
multiscale_test=False,
flip_test=True,
flip_test=False,
shift_heatmap=True,
restore_heatmap_size=True,
align_corners=False))
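Note on `flip_test=False`: disabling flip TTA leaves a single forward pass per scale, presumably to make the decoder's numbers easier to compare against master while the inference path is being aligned. For reference, a rough sketch of what flip-test averaging does when it is enabled (illustrative only, not the exact head implementation):

```python
import torch

# Illustrative flip-test averaging: take the heatmaps predicted on the
# mirrored image, swap left/right keypoint channels, flip back along the
# width axis, and average with the original prediction.
heatmaps = torch.rand(1, 17, 128, 96)          # forward pass on the original image
heatmaps_flipped = torch.rand(1, 17, 128, 96)  # forward pass on the flipped image
flip_indices = list(range(17))                 # identity here; real configs swap L/R joints

restored = heatmaps_flipped[:, flip_indices].flip(-1)
averaged = (heatmaps + restored) * 0.5
print(averaged.shape)                          # torch.Size([1, 17, 128, 96])
```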
@@ -113,9 +113,14 @@
dict(
type='BottomupResize',
input_size=codec['input_size'],
size_factor=32,
size_factor=64,
resize_mode='expand'),
dict(type='PackPoseInputs')
dict(
type='PackPoseInputs',
meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
'img_shape', 'input_size', 'input_center', 'input_scale',
'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
'skeleton_links'))
]

# data loaders
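The test pipeline now pads the resized input to a multiple of 64 instead of 32, and `PackPoseInputs` gets an explicit `meta_keys` list (including `flip_indices`, `raw_ann_info` and `skeleton_links`) so the decoder and evaluator receive the metainfo they expect. The `size_factor` only affects the ceil-to-multiple padding of the input size; a small sketch of that arithmetic (an illustration, not the `BottomupResize` implementation):

```python
import math

def ceil_to_multiple(size, factor):
    """Round a (w, h) pair up to the nearest multiple of ``factor``."""
    return tuple(math.ceil(s / factor) * factor for s in size)

print(ceil_to_multiple((683, 512), 32))   # (704, 512)
print(ceil_to_multiple((683, 512), 64))   # (704, 512) -- often identical...
print(ceil_to_multiple((600, 400), 64))   # (640, 448) -- ...but not always
```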
84 changes: 63 additions & 21 deletions mmpose/codecs/associative_embedding.py
@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import namedtuple
# from copy import deepcopy
from itertools import product
from typing import Any, List, Optional, Tuple

import numpy as np
import torch
# from mmengine import dump
from munkres import Munkres
from torch import Tensor

@@ -75,7 +77,9 @@ def _init_group():
tag_list=[])
return _group

for i in keypoint_order:
# group_history = []

for idx, i in enumerate(keypoint_order):
# Get all valid candidates of the i-th keypoint
valid = vals[i] > val_thr
if not valid.any():
@@ -87,12 +91,22 @@

if len(groups) == 0: # Initialize the group pool
for tag, val, loc in zip(tags_i, vals_i, locs_i):

# Check if the keypoint belongs to existing groups
if len(groups):
prev_tags = np.stack([g.tag_list[0] for g in groups])
dists = np.linalg.norm(prev_tags - tag, ord=2, axis=1)
if dists.min() < 1:
continue

group = _init_group()
group.kpts[i] = loc
group.scores[i] = val
group.tag_list.append(tag)

groups.append(group)
# costs_copy = None
matches = None

else: # Match keypoints to existing groups
groups = groups[:max_groups]
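The block added above guards group-pool initialization: a candidate keypoint whose embedding tag is within L2 distance 1 of the first tag of an existing group is skipped instead of spawning a second group for the same person. A minimal sketch of that check (hypothetical helper; the diff hard-codes the threshold 1 and compares against each group's first stored tag):

```python
import numpy as np

def should_start_new_group(tag, group_tags, thr=1.0):
    """Start a new group only if the candidate tag is not close to any
    existing group's tag (sketch of the duplicate check above)."""
    if len(group_tags) == 0:
        return True
    dists = np.linalg.norm(np.stack(group_tags) - tag, ord=2, axis=1)
    return dists.min() >= thr

group_tags = [np.array([0.12]), np.array([1.95])]
print(should_start_new_group(np.array([0.18]), group_tags))  # False: ~0.12 is already grouped
print(should_start_new_group(np.array([3.40]), group_tags))  # True: looks like a new identity
```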
@@ -101,17 +115,18 @@
# Calculate distance matrix between group tags and tag candidates
# of the i-th keypoint
# Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
diff = tags_i[:, None] - np.array(group_tags)[None]
diff = (tags_i[:, None] -
np.array(group_tags)[None]).astype(np.float64)
dists = np.linalg.norm(diff, ord=2, axis=2)
num_kpts, num_groups = dists.shape[:2]

# Experimental cost function for keypoint-group matching
# Experimental cost function for keypoint-group matching2
costs = np.round(dists) * 100 - vals_i[..., None]

if num_kpts > num_groups:
padding = np.full((num_kpts, num_kpts - num_groups),
1e10,
dtype=np.float32)
padding = np.full((num_kpts, num_kpts - num_groups), 1e10)
costs = np.concatenate((costs, padding), axis=1)
# costs_copy = costs.copy()

# Match keypoints and groups by Munkres algorithm
matches = munkres.compute(costs)
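In this hunk the tag difference is cast to float64 and the padding is created without an explicit float32 dtype, presumably to keep the cost matrix in double precision as in the 0.x decoder. A toy example of the matching step itself: rounded tag distance dominates (scaled by 100), heatmap confidence breaks ties, and surplus keypoints are absorbed by high-cost padded columns:

```python
import numpy as np
from munkres import Munkres

dists = np.array([[0.1, 2.3],     # 3 candidate keypoints x 2 existing groups
                  [1.9, 0.2],
                  [4.0, 3.8]])
vals = np.array([0.9, 0.8, 0.4])  # heatmap scores of the candidates

costs = np.round(dists) * 100 - vals[:, None]
num_kpts, num_groups = costs.shape
if num_kpts > num_groups:
    padding = np.full((num_kpts, num_kpts - num_groups), 1e10)
    costs = np.concatenate((costs, padding), axis=1)

matches = Munkres().compute(costs.tolist())
print(matches)  # [(0, 0), (1, 1), (2, 2)]: the third keypoint lands on the
                # padded column and would start a new group in the loop above
```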
@@ -121,13 +136,30 @@
# Add the keypoint to the matched group
group = groups[group_idx]
else:
# Initialize a new group with unmatched keypoint
group = _init_group()
groups.append(group)

group.kpts[i] = locs_i[kpt_idx]
group.scores[i] = vals_i[kpt_idx]
group.tag_list.append(tags_i[kpt_idx])
# if dists[kpt_idx].min() < 0.2:
if False:
group = None
else:
# Initialize a new group with unmatched keypoint
group = _init_group()
groups.append(group)
if group is not None:
group.kpts[i] = locs_i[kpt_idx]
group.scores[i] = vals_i[kpt_idx]
group.tag_list.append(tags_i[kpt_idx])

# out = {
# 'idx': idx,
# 'i': i,
# 'costs': costs_copy,
# 'matches': matches,
# 'kpts': np.array([g.kpts for g in groups]),
# 'scores': np.array([g.scores for g in groups]),
# 'tag_list': [np.array(g.tag_list) for g in groups],
# }
# group_history.append(deepcopy(out))

# dump(group_history, 'group_history.pkl')

groups = groups[:max_groups]
if groups:
@@ -210,7 +242,7 @@ def __init__(
decode_gaussian_kernel: int = 3,
decode_keypoint_thr: float = 0.1,
decode_tag_thr: float = 1.0,
decode_topk: int = 20,
decode_topk: int = 30,
decode_max_instances: Optional[int] = None,
) -> None:
super().__init__()
@@ -336,6 +368,12 @@ def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
B, K, H, W = batch_heatmaps.shape
L = batch_tags.shape[1] // K

# Heatmap NMS
# dump(batch_heatmaps.cpu().numpy(), 'heatmaps.pkl')
batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
self.decode_nms_kernel)
# dump(batch_heatmaps.cpu().numpy(), 'heatmaps_nms.pkl')

# shape of topk_val, top_indices: (B, K, TopK)
topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
k, dim=-1)
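Moving the heatmap NMS into `_get_batch_topk` applies peak suppression immediately before the top-k selection (the call is removed from `batch_decode` further down). A minimal sketch of max-pool style heatmap NMS, assuming the usual keep-only-local-maxima trick rather than the exact `batch_heatmap_nms` kernel:

```python
import torch
import torch.nn.functional as F

def heatmap_nms(heatmaps: torch.Tensor, kernel: int = 5) -> torch.Tensor:
    """Keep only local maxima: a pixel survives iff it equals the maximum of
    its kernel x kernel neighbourhood."""
    pad = (kernel - 1) // 2
    maxima = F.max_pool2d(heatmaps, kernel, stride=1, padding=pad)
    return heatmaps * (heatmaps == maxima)

hm = torch.rand(1, 17, 128, 128)                      # (B, K, H, W)
topk_vals, topk_inds = heatmap_nms(hm).flatten(-2, -1).topk(30, dim=-1)
print(topk_vals.shape)                                # torch.Size([1, 17, 30])
```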
@@ -433,9 +471,8 @@ def _fill_missing_keypoints(self, keypoints: np.ndarray,
cost_map = np.round(dist_map) * 100 - heatmaps[k] # H, W
y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W))
keypoints[n, k] = [x, y]
keypoint_scores[n, k] = heatmaps[k, y, x]

return keypoints, keypoint_scores
return keypoints

def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
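`_fill_missing_keypoints` now returns only the keypoint array; the score of a filled-in joint is no longer overwritten with the heatmap value at the chosen location, so it presumably stays at zero as in the 0.x decoder. A toy illustration of the cost it minimizes, assuming (per the surrounding helper) that `dist_map` is the per-pixel distance between the tag map and the instance's mean tag:

```python
import numpy as np

H, W = 4, 4
instance_tag = 0.5                 # mean tag of the instance's detected joints
tag_map = np.full((H, W), 5.0)     # 1-D embedding for simplicity
tag_map[1, 1] = 0.6                # close to the instance tag
tag_map[2, 3] = 0.4                # also close, but weaker heatmap below
heatmap = np.zeros((H, W))
heatmap[1, 1] = 0.35
heatmap[2, 3] = 0.20

dist_map = np.abs(tag_map - instance_tag)
cost_map = np.round(dist_map) * 100 - heatmap   # distance first, response second
y, x = np.unravel_index(np.argmin(cost_map), (H, W))
print(x, y)                        # 1 1: same rounded tag distance, higher heatmap wins
```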
@@ -457,15 +494,12 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
batch, each is in shape (N, K). It usually represents the
confidence of the keypoint prediction
"""

B, _, H, W = batch_heatmaps.shape
assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), (
f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and '
f'tagging map ({batch_tags.shape})')

# Heatmap NMS
batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
self.decode_nms_kernel)

# Get top-k in each heatmap and convert to numpy
batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
self._get_batch_topk(
@@ -489,7 +523,7 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor

if keypoints.size > 0:
# identify missing keypoints
keypoints, scores = self._fill_missing_keypoints(
keypoints = self._fill_missing_keypoints(
keypoints, scores, heatmaps, tags)

# refine keypoint coordinates according to heatmap distribution
@@ -500,6 +534,14 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
blur_kernel_size=self.decode_gaussian_kernel)
else:
keypoints = refine_keypoints(keypoints, heatmaps)
# The following 0.5-pixel shift is adapted from mmpose 0.x
# where the heatmap center is calculated by a biased
# rounding ``mu=[int(x), int(y)]``. We keep this shift
# operation for now to be compatible with 0.x checkpoints.
# In mmpose 1.x, AE heatmap center is calculated by the
# unbiased rounding ``mu=[int(x+0.5), int(y+0.5)]``, so the
# following shift will be removed in the future.
keypoints += 0.5

batch_keypoints.append(keypoints)
batch_keypoint_scores.append(scores)
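A tiny worked example of the biased vs. unbiased rounding described in the new comment, and why the +0.5 shift compensates for checkpoints trained with the 0.x codec:

```python
# mmpose 0.x rendered the training-heatmap center with floor-style rounding,
# mmpose 1.x rounds to the nearest pixel.
x_gt = 3.7
biased_center = int(x_gt)          # 3  (0.x: floor, on average 0.5 px low)
unbiased_center = int(x_gt + 0.5)  # 4  (1.x: nearest pixel)

# Decoding a 0.x-style heatmap therefore sits half a pixel low on average;
# adding 0.5 at decode time brings it back: 3 + 0.5 = 3.5, closer to 3.7.
decoded_with_shift = biased_center + 0.5
```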
12 changes: 10 additions & 2 deletions mmpose/datasets/transforms/bottomup_transforms.py
@@ -484,6 +484,7 @@ def transform(self, results: Dict) -> Optional[dict]:
output_size=actual_input_size)
else:
center = np.array([img_w / 2, img_h / 2], dtype=np.float32)
# center = np.round(center)
scale = np.array([
img_w * padded_input_size[0] / actual_input_size[0],
img_h * padded_input_size[1] / actual_input_size[1]
@@ -495,11 +496,18 @@
rot=0,
output_size=padded_input_size)

_img = cv2.warpAffine(
img, warp_mat, padded_input_size, flags=cv2.INTER_LINEAR)
_img = cv2.warpAffine(img, warp_mat, padded_input_size)

imgs.append(_img)

# print('#' * 20)
# print('w,h: ', img_w, img_h, 'center: ', center, 'scale: ',
# scale,
# 'actual_input_size: ', actual_input_size,
# 'padded_input_size: ', padded_input_size)
# print(warp_mat)
# print('#' * 20)

# Store the transform information w.r.t. the main input size
if i == 0:
results['img_shape'] = padded_input_size[::-1]
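Dropping the explicit `flags=cv2.INTER_LINEAR` should be behaviour-neutral, since bilinear interpolation is already the default for `cv2.warpAffine`; a quick check:

```python
import cv2
import numpy as np

img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
warp_mat = np.array([[0.8, 0.0, 16.0],
                     [0.0, 0.8, 16.0]], dtype=np.float32)

a = cv2.warpAffine(img, warp_mat, (512, 512))                          # default flags
b = cv2.warpAffine(img, warp_mat, (512, 512), flags=cv2.INTER_LINEAR)  # explicit
assert np.array_equal(a, b)  # bilinear is warpAffine's default interpolation
```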
12 changes: 11 additions & 1 deletion mmpose/models/heads/heatmap_heads/ae_head.py
@@ -2,6 +2,7 @@
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from mmengine.structures import PixelData
from mmengine.utils import is_list_of
from torch import Tensor
@@ -110,7 +111,7 @@ def predict(self,
# TTA: multi-scale test
assert is_list_of(feats, list if flip_test else tuple)
else:
assert is_list_of(feats, tuple if flip_test else Tensor)
assert isinstance(feats, list if flip_test else tuple)
feats = [feats]

# resize heatmaps to align with input size
@@ -129,6 +130,15 @@
for scale_idx, _feats in enumerate(feats):
if not flip_test:
_heatmaps, _tags = self.forward(_feats)
if heatmap_size:
_heatmaps = F.interpolate(
_heatmaps, (img_h, img_w),
mode='bilinear',
align_corners=align_corners)
_tags = F.interpolate(
_tags, (img_h, img_w),
mode='bilinear',
align_corners=align_corners)

else:
# TTA: flip test
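With this change the non-flip branch also upsamples its heatmaps and tag maps to the model input resolution, as the flip-test branch below already does, so decoding happens at the same scale on both paths (assuming `heatmap_size`, `img_h` and `img_w` come from the `restore_heatmap_size` handling earlier in `predict`). A minimal sketch of that resize:

```python
import torch
import torch.nn.functional as F

heatmaps = torch.rand(1, 17, 128, 96)   # (B, K, h, w) at the output stride
tags = torch.rand(1, 17, 128, 96)       # one tag channel per keypoint here
img_h, img_w = 512, 384                 # network input resolution

heatmaps = F.interpolate(heatmaps, (img_h, img_w), mode='bilinear', align_corners=False)
tags = F.interpolate(tags, (img_h, img_w), mode='bilinear', align_corners=False)
print(heatmaps.shape, tags.shape)       # torch.Size([1, 17, 512, 384]) twice
```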