-
Notifications
You must be signed in to change notification settings - Fork 0
/
segmentation.py
81 lines (68 loc) · 2.86 KB
/
segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu
from networks import U2NET
# Presumably the torch device to run on; nothing in the code visible here
# reads it, so confirm against other callers before relying on it.
device = "cuda"

# Output directory name; likewise not referenced by the code visible here.
result_dir = "output_images"

# Checkpoint file passed to load_checkpoint_mgpu() when process_photo()
# builds the network.
checkpoint_path = "cloth_segm_u2net_latest.pth"
def process_photo(pic_path):
    """Segment an image with U2NET and return one masked cut-out per class.

    The image is converted to RGB, resized to 768x768, normalised and run
    through the network. Each pixel gets the label with the highest score
    among the network's 4 output channels. For every non-zero label that
    occurs at least once, the resized image is multiplied by that label's
    binary mask (all other pixels become black) and scaled back to the
    original size.

    Args:
        pic_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        dict mapping a string key to a ``PIL.Image.Image``:
        label 1 is stored under ``'top'``; labels 2 and 3 are both stored
        under ``'bot'``. Keys are absent for labels that do not occur.
    """
    transform_rgb = transforms.Compose(
        [transforms.ToTensor(), Normalize_image(0.5, 0.5)]
    )

    # NOTE(review): the network is rebuilt and the checkpoint reloaded on
    # every call; callers processing many images may want to cache it.
    net = U2NET(in_ch=3, out_ch=4)
    net = load_checkpoint_mgpu(net, checkpoint_path)
    net = net.eval()

    # The context manager releases the file handle; convert() returns an
    # independent, fully loaded copy, so it stays usable after the close.
    with Image.open(pic_path) as source:
        img = source.convert('RGB')
    img_size = img.size
    img = img.resize((768, 768), Image.BICUBIC)

    image_tensor = torch.unsqueeze(transform_rgb(img), 0)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output_tensor = net(image_tensor)
        output_tensor = F.log_softmax(output_tensor[0], dim=1)
        # Index of the best-scoring channel per pixel.
        output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    # Drop the two leading singleton dimensions to get a 2-D label map.
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_arr = output_tensor.cpu().numpy()

    pixels = np.asarray(img)
    result = {}
    # NOTE(review): labels 2 and 3 share the key 'bot', so when both occur
    # the label-3 image replaces the label-2 one. Kept for compatibility with
    # existing callers; confirm whether label 3 was meant to have its own key.
    for class_id, key in ((1, 'top'), (2, 'bot'), (3, 'bot')):
        mask = output_arr == class_id
        if not mask.any():
            continue
        # Broadcast the 2-D mask over the colour channels; pixels outside
        # the class are zeroed.
        masked = pixels * mask[..., np.newaxis].astype(pixels.dtype)
        output_img = Image.fromarray(masked.astype('uint8'))
        result[key] = output_img.resize(img_size, Image.BICUBIC)
    return result