# Local Intrinsic Dimensionality (LID) estimators: maximum-likelihood and
# method-of-moments versions in PyTorch, plus NumPy reference implementations.
import torch
import numpy as np
from scipy.spatial.distance import cdist


def gmean(input_x, dim=0):
    """Geometric mean along `dim`, computed as exp(mean(log(x)))."""
    log_x = torch.log(input_x)
    return torch.exp(torch.mean(log_x, dim=dim))


def get_lid_r(data, reference):
    """Return the full distance matrix from `data` to `reference`, the distances
    sorted along each row, and the corresponding sort indices."""
    data = torch.flatten(data, start_dim=1)
    reference = torch.flatten(reference, start_dim=1)
    r = torch.cdist(data, reference, p=2)
    a, idx = torch.sort(r, dim=1)
    return r, a, idx


def lid_mle(data, reference, k=20, get_idx=False,
            compute_mode='use_mm_for_euclid_dist_if_necessary'):
    """Maximum-likelihood LID estimate of each row of `data` from its k nearest
    neighbours in `reference` (column 0 is skipped as the self-distance)."""
    b = data.shape[0]
    k = min(k, b - 2)
    data = torch.flatten(data, start_dim=1)
    reference = torch.flatten(reference, start_dim=1)
    r = torch.cdist(data, reference, p=2, compute_mode=compute_mode)
    a, idx = torch.sort(r, dim=1)
    # MLE estimator: -k / sum_j log(d_j / d_k); the 1e-4 offset guards against log(0).
    lids = -k / torch.sum(torch.log(a[:, 1:k] / a[:, k].view(-1, 1) + 1.e-4), dim=1)
    if get_idx:
        return idx, lids
    return lids


def lid_mom_est(data, reference, k, get_idx=False,
                compute_mode='use_mm_for_euclid_dist_if_necessary'):
    """Method-of-moments LID estimate of each row of `data` from its k nearest
    neighbours in `reference`."""
    b = data.shape[0]
    k = min(k, b - 2)
    data = torch.flatten(data, start_dim=1)
    reference = torch.flatten(reference, start_dim=1)
    r = torch.cdist(data, reference, p=2, compute_mode=compute_mode)
    a, idx = torch.sort(r, dim=1)
    # MoM estimator: m / (d_k - m), with m the mean of the first k-1 neighbour distances.
    m = torch.mean(a[:, 1:k], dim=1)
    lids = m / (a[:, k] - m)
    if get_idx:
        return idx, lids
    return lids


def lid_mom_est_eps(data, reference, k, get_idx=False):
    """Method-of-moments LID estimate with a small epsilon in the denominator to
    avoid division by zero when the k-th distance equals the neighbour mean."""
    b = data.shape[0]
    k = min(k, b - 2)
    data = torch.flatten(data, start_dim=1)
    reference = torch.flatten(reference, start_dim=1)
    r = torch.cdist(data, reference, p=2)
    a, idx = torch.sort(r, dim=1)
    m = torch.mean(a[:, 1:k], dim=1)
    lids = m / ((a[:, k] - m) + 1.e-4)
    if get_idx:
        return idx, lids
    return lids


# ======== NumPy versions =================
def mle_batch(data, batch, k):
    """Maximum-likelihood LID estimate of each point in `batch` with respect to `data`."""
    data = np.asarray(data, dtype=np.float32)
    batch = np.asarray(batch, dtype=np.float32)
    k = min(k, len(data) - 1)

    def f(v):
        # v holds the k nearest-neighbour distances; v[-1] is the k-th (largest) of them.
        return -k / np.sum(np.log(v / v[-1]))

    a = cdist(batch, data)
    a = np.apply_along_axis(np.sort, axis=1, arr=a)[:, 1:k + 1]
    a = np.apply_along_axis(f, axis=1, arr=a)
    return a


# LID of a single query point x.
def mle_single(data, x, k=20, dist=True, metric='euclidean'):
    """Maximum-likelihood LID estimate of a single query point `x` with respect to `data`.

    If `dist` is False, `data` is taken to be a precomputed matrix of distances
    rather than a set of points.
    """
    data = np.asarray(data, dtype=np.float32)
    x = np.asarray(x, dtype=np.float32)
    if x.ndim == 1:
        x = x.reshape((-1, x.shape[0]))

    def f(v):
        return -k / np.sum(np.log(v / v[-1]))

    if dist:
        k = min(k, len(data) - 1)
        a = cdist(x, data, metric=metric)
        a = np.apply_along_axis(np.sort, axis=1, arr=a)[:, 1:k + 1]
    else:
        a = data
        a = np.apply_along_axis(np.sort, axis=1, arr=a)
    a = np.apply_along_axis(f, axis=1, arr=a)
    return a[0]
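

if __name__ == "__main__":
    # Illustrative usage sketch (synthetic data assumed here, not part of the
    # estimators above): points drawn from a d-dimensional Gaussian embedded in a
    # higher-dimensional ambient space, so the LID estimates should land roughly
    # near the intrinsic dimension d.
    torch.manual_seed(0)
    n, d, ambient = 1000, 8, 64
    z = torch.zeros(n, ambient)
    z[:, :d] = torch.randn(n, d)

    print("torch MLE LID (mean):", lid_mle(z, z, k=20).mean().item())
    print("torch MoM LID (mean):", lid_mom_est(z, z, k=20).mean().item())

    # NumPy versions on the same data.
    z_np = z.numpy()
    print("numpy batch LID (mean):", mle_batch(z_np, z_np, k=20).mean())
    print("numpy single LID (x[0]):", mle_single(z_np, z_np[0], k=20))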