-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
77 lines (62 loc) · 8.58 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import pandas as pd
import numpy as np
import torch
from nets import HandNet
def _min_max_scale(values):
    """Linearly rescale a 1-D array into [0, 1] using its own min and max.

    If every value is identical (max == min), the input is returned unchanged,
    because the range is zero and dividing by it is undefined.
    """
    lo = values.min()
    hi = values.max()
    if hi == lo:  # Avoid division by zero; leave a constant axis as-is
        return values
    return (values - lo) / (hi - lo)


def scale_single_coord(coordinates):
    """Min-max scale one flattened sample of interleaved (x, y, z) coordinates.

    The input is laid out as ``[x0, y0, z0, x1, y1, z1, ...]``. Each axis is
    scaled independently into [0, 1] using the minimum and maximum of that
    axis within this single sample. An axis whose values are all equal is
    copied through unscaled.

    Args:
        coordinates: Flattened coordinates as a NumPy array or any sequence
            accepted by ``np.asarray``. Any length is accepted; the original
            use case is 21 landmarks x 3 axes = 63 values.

    Returns:
        A new float64 NumPy array with the same length as ``coordinates``,
        in the same interleaved x/y/z order.

    Raises:
        ValueError: If ``coordinates`` is empty (NumPy cannot take the min of
            an empty axis).
    """
    coords = np.asarray(coordinates)
    # Size the output from the input rather than a fixed 63, so other landmark
    # counts work and 63-value inputs behave exactly as before.
    new_row = np.empty(len(coords))
    for axis in range(3):  # 0 = x, 1 = y, 2 = z
        new_row[axis::3] = _min_max_scale(coords[axis::3])
    return new_row
def predict(
    model: HandNet,
    scaled_coords: list,
    threshold: float = 2.0,
    labels=("peace sign", "euro footballer", "thumbs up", "kpop heart", "what the sigma"),
) -> "str | None":
    """Classify one hand pose and return its label, or None if not confident.

    Args:
        model: A loaded HandNet. It is called once on a 1-D float32 tensor of
            the scaled coordinates.
        scaled_coords: Flattened (x, y, z) landmark coordinates. Despite the
            name, these do not need to be pre-scaled: this function min-max
            scales them via ``scale_single_coord`` before inference.
        threshold: The largest model output must exceed this value for a
            label to be returned. Defaults to 2, the original hard-coded cutoff.
        labels: Class names indexed by the position of the model's largest
            output.

    Returns:
        ``labels[argmax(output)]`` when ``max(output) > threshold``, otherwise
        None.
    """
    coords = scale_single_coord(np.array(scaled_coords))
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        output = model(torch.tensor(coords, dtype=torch.float32))
    print(output)  # debug: raw model outputs (presumably logits -- confirm)
    if torch.max(output) > threshold:
        return labels[int(torch.argmax(output))]
    return None
# Smoke test: load the trained model and classify one recorded sample per gesture.
# NOTE(review): the checkpoint name says "unscaled", yet predict() min-max scales
# its input before inference -- confirm which preprocessing the model was trained on.
model = HandNet()
model.setup("./models/hand_model_unscaled_epoch20_lr0.001_bs_32.pth")

# Each entry: (caption printed before the prediction, flattened x/y/z landmarks).
SAMPLES = [
    (
        "peace sign:",
        [0.7336673140525818,0.8495280742645264,1.0091436024595168e-06,0.6723695397377014,0.8020392656326294,-0.04144725203514099,0.6415610313415527,0.70493483543396,-0.06892213970422745,0.6862631440162659,0.617411196231842,-0.09715491533279419,0.7308752536773682,0.5455268621444702,-0.12282225489616394,0.6220554709434509,0.519255518913269,-0.0405026376247406,0.5809885263442993,0.3876289129257202,-0.07717211544513702,0.5570772886276245,0.3020611107349396,-0.10012024641036987,0.5395866632461548,0.2303922176361084,-0.11448702961206436,0.6799868941307068,0.5021679401397705,-0.048963211476802826,0.6822059750556946,0.3265255093574524,-0.08738874644041061,0.6921988725662231,0.21608835458755493,-0.11221276223659515,0.7080681920051575,0.12600713968276978,-0.12400253117084503,0.7333049178123474,0.5320874452590942,-0.06217460706830025,0.7282810211181641,0.4819796681404114,-0.1107586920261383,0.7128937840461731,0.5624113082885742,-0.11770589649677277,0.710083544254303,0.6140854358673096,-0.10863137990236282,0.7773230075836182,0.584037721157074,-0.07907985895872116,0.7580936551094055,0.556846022605896,-0.11327849328517914,0.7386305332183838,0.6197518110275269,-0.10778772830963135,0.7311854362487793,0.6622230410575867,-0.09531086683273315],
    ),
    (
        "Euro Footballer:",
        [0.8492257595062256,0.6839681267738342,-6.473388793892809e-07,0.7999305129051208,0.6120420098304749,0.034273676574230194,0.7431418895721436,0.5735766887664795,0.04264887049794197,0.6967589855194092,0.5446481704711914,0.04546050727367401,0.6702293157577515,0.5029789209365845,0.047389592975378036,0.7032467126846313,0.5726965665817261,-0.004215891007333994,0.6496597528457642,0.6631115674972534,0.0028754619415849447,0.6708903312683105,0.6921699047088623,0.015950465574860573,0.6872950792312622,0.6837653517723083,0.02290438301861286,0.7150073051452637,0.6128968596458435,-0.02335749939084053,0.666991114616394,0.7172957062721252,-0.016618480905890465,0.6919753551483154,0.734817385673523,-0.00146484246943146,0.7096695899963379,0.721472978591919,0.0034847590140998363,0.7306035757064819,0.6686764359474182,-0.03730066120624542,0.6863994002342224,0.7575149536132812,-0.026639185845851898,0.710110604763031,0.7680844664573669,-0.007837279699742794,0.728296160697937,0.7547785043716431,-0.0030633199494332075,0.7495203614234924,0.730252742767334,-0.04833385348320007,0.6997395753860474,0.7783752679824829,-0.046500787138938904,0.6724326014518738,0.8008431196212769,-0.04178245738148689,0.6452920436859131,0.8158155083656311,-0.04197729378938675],
    ),
    (
        "Thumbs up:",
        [0.28713059425354004,0.757659912109375,-2.1653801240972825e-07,0.29923102259635925,0.6824442744255066,0.00553923100233078,0.32900315523147583,0.6193813681602478,0.003293841378763318,0.3558714687824249,0.5716983079910278,-0.0004087153065484017,0.364929735660553,0.5258445143699646,-0.0027999829035252333,0.3518580496311188,0.5993287563323975,-0.008322825655341148,0.40404850244522095,0.6082196235656738,-0.015889639034867287,0.3938250243663788,0.6334624290466309,-0.01794915273785591,0.3757784068584442,0.6378895044326782,-0.01827279105782509,0.35802483558654785,0.6381528377532959,-0.016119593754410744,0.4066096842288971,0.6520358324050903,-0.015889467671513557,0.3941640853881836,0.6718127727508545,-0.011149170808494091,0.3771132826805115,0.6704304814338684,-0.010449591092765331,0.361183226108551,0.6835852265357971,-0.02326744608581066,0.4038568437099457,0.6956864595413208,-0.02148977480828762,0.39017847180366516,0.7104352116584778,-0.011251932010054588,0.3733932375907898,0.7096193432807922,-0.007443400099873543,0.36208924651145935,0.7283036112785339,-0.03011762723326683,0.3933548331260681,0.7310906648635864,-0.025354333221912384,0.379855215549469,0.7417626976966858,-0.0147931557148695,0.3650161027908325,0.7406395673751831,-0.008991260081529617],
    ),
    (
        "kpop heart:",
        [0.26195427775382996,0.9775311350822449,1.931276329969478e-07,0.23879195749759674,0.8511254191398621,-0.0030527226626873016,0.24311569333076477,0.6897066235542297,-0.026214037090539932,0.2664613127708435,0.5579319000244141,-0.04983844980597496,0.2780214846134186,0.46756091713905334,-0.07133600860834122,0.1821352243423462,0.7219663262367249,-0.07456377148628235,0.25464001297950745,0.5563022494316101,-0.10620272904634476,0.3017522096633911,0.5314424633979797,-0.112911157310009,0.3235512375831604,0.5352469086647034,-0.1135348528623581,0.21586103737354279,0.7852230668067932,-0.09167969226837158,0.28681913018226624,0.5940459370613098,-0.11906694620847702,0.3086465001106262,0.6304773092269897,-0.10556620359420776,0.30181536078453064,0.6829167008399963,-0.09557411819696426,0.26008743047714233,0.8389890193939209,-0.1066732183098793,0.3237495422363281,0.6533142924308777,-0.12455764412879944,0.33498379588127136,0.6918531656265259,-0.09585420042276382,0.3237071633338928,0.7426030039787292,-0.07641878724098206,0.3088681697845459,0.8787086009979248,-0.1220976710319519,0.35675048828125,0.7144303917884827,-0.12448173016309738,0.3569049835205078,0.7318533658981323,-0.10109706968069077,0.342883825302124,0.7740848660469055,-0.08557042479515076],
    ),
    (
        "What the sigma:",
        [0.4718482494354248,0.8738421201705933,6.751648840008784e-08,0.42348185181617737,0.8828456401824951,-0.05818401649594307,0.37854012846946716,0.8523808121681213,-0.10732463747262955,0.3679359555244446,0.8207353353500366,-0.15359406173229218,0.40433746576309204,0.793335497379303,-0.20028264820575714,0.4180438220500946,0.6476503014564514,-0.0861424058675766,0.4025708734989166,0.49867022037506104,-0.1431209295988083,0.3891969323158264,0.40533551573753357,-0.1745455265045166,0.3823089003562927,0.3304564654827118,-0.19446299970149994,0.4739235043525696,0.6412985324859619,-0.08438004553318024,0.41139933466911316,0.6249727010726929,-0.1568947583436966,0.40212592482566833,0.7334544062614441,-0.16179953515529633,0.41602885723114014,0.7827587127685547,-0.14769133925437927,0.5205585956573486,0.6656454801559448,-0.0886324867606163,0.45953986048698425,0.6944772005081177,-0.153018981218338,0.44931983947753906,0.7905809879302979,-0.12805454432964325,0.46161457896232605,0.8248273134231567,-0.0942612886428833,0.5594648122787476,0.7064246535301208,-0.09878034144639969,0.5108219385147095,0.7346844673156738,-0.14033280313014984,0.4936932623386383,0.807847261428833,-0.12171010673046112,0.4971807599067688,0.8357378244400024,-0.09566106647253036],
    ),
]

for caption, coordinates in SAMPLES:
    print(caption, predict(model, coordinates))