au_data_loader.py
import os
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageDraw
import numpy as np
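
# The helpers below walk a directory tree laid out as
# <root>/<subject>/<sequence>/<files> and, for each sequence, read the last
# entry of that folder's listing (an image frame, AU label file, landmark file,
# or emotion label path).  Note: os.listdir() does not guarantee sorted order,
# so "last entry" is only reliable when the filesystem returns names in order
# or when a sequence folder holds a single file.
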
def get_reserved_set(label_path_dir):
    """Scan all AU label files and keep the AU codes that occur more than 35 times.

    Returns the set of kept AU codes and, for every sequence, the raw AU codes
    read from its label file.
    """
    au_label = []
    returned_label = []
    for names in os.listdir(label_path_dir):
        name = os.path.join(label_path_dir, names)
        for sequences in os.listdir(name):
            sequence = os.path.join(name, sequences)
            if os.listdir(sequence):
                # Each label file holds one row per AU; column 0 is the AU code.
                temp = np.loadtxt(os.path.join(sequence, os.listdir(sequence)[-1]))
                if temp.ndim == 1:
                    # Single-AU file: one row.
                    au_label.append(temp.astype(np.int32))
                    returned_label.append(temp[0].astype(np.int32))
                elif temp.ndim == 2:
                    # Multi-AU file: one row per AU.
                    au_label.extend(temp.astype(np.int32))
                    returned_label.append(temp[:, 0].astype(np.int32))
                else:
                    print(temp)
                    raise Exception("label info error!")
    # Flatten to a plain list of AU codes, one entry per occurrence.
    ll = []
    for au in au_label:
        if au.ndim == 2:
            ll.extend(au[:, 0].astype(np.int32))
        elif au.ndim == 1:
            ll.append(au[0].astype(np.int32))
        else:
            print(au)
            raise Exception("label info error!")
    label_set = set(ll)
    reserved_set = set()
    for label in label_set:
        # Keep only AUs with more than 35 occurrences.
        if ll.count(label) > 35:
            reserved_set.add(label)
    return reserved_set, returned_label  # e.g. (1, 2, 4, 5, 6, 7, 9, 12, 14, 15, 17, 20, 23, 24, 25, 27)
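
# Illustrative return values (the concrete numbers are assumptions, not measured
# from any dataset):
#   reserved_set   -> e.g. {1, 2, 4, 6, 12}, the AU codes seen more than 35 times
#   returned_label -> one entry per sequence, e.g. np.array([1, 4, 12]) for a
#                     multi-AU label file, or a scalar AU code for a single-AU file
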
def convert_label(init_label, reserved_set):
    """Turn a raw AU code (or list of codes) into a multi-hot vector over reserved_set.

    Note: positions follow list(reserved_set), i.e. the set's iteration order,
    so the same reserved_set must be used consistently across calls.
    """
    reserved_list = list(reserved_set)
    converted_label = [0] * len(reserved_set)
    if type(init_label) == list:
        for l in init_label:
            for i in range(len(reserved_list)):
                if l == reserved_list[i]:
                    converted_label[i] = 1
    elif type(init_label) == int:
        for i in range(len(reserved_list)):
            if init_label == reserved_list[i]:
                converted_label[i] = 1
    else:
        print("init label is not a list! '", init_label, "'")
        print("type is", type(init_label))
        raise Exception
    return converted_label
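
# Example (illustrative values, assuming reserved_set == {1, 2, 4} and that
# list({1, 2, 4}) iterates as [1, 2, 4]):
#   convert_label([1, 4], {1, 2, 4})  ->  [1, 0, 1]
#   convert_label(2, {1, 2, 4})       ->  [0, 1, 0]
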
def draw_landmark_point(img, landmark):
    # Debug helper: overlay the 1-based index of every landmark point onto the
    # image and display it.
    draw = ImageDraw.Draw(img)
    t = 1
    for point in landmark:
        draw.text(point.tolist(), str(t), fill=255)
        t += 1
    img.show()

def crop_au_img(img, landmark):
    # Crop the face to the landmark bounding box, padded by 10 px on the left,
    # right and bottom and 20 px on top, clamped to the image borders.
    width, height = img.size
    left = max(int(min(landmark[:, 0])) - 10, 0)
    right = min(width, int(max(landmark[:, 0])) + 10)
    top = max(int(min(landmark[:, 1])) - 20, 0)
    bottom = min(height, int(max(landmark[:, 1])) + 10)
    img = img.crop((left, top, right, bottom))
    return img
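
# Illustrative use ("frame.png" is a placeholder path; landmark is assumed to be
# an (N, 2) array of x/y pixel coordinates, e.g. from np.loadtxt):
#   face = crop_au_img(Image.open("frame.png").convert('RGB'), landmark)
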
def load_au_image_from_path(data_path_dir):
    # Prepare AU images: for every <subject>/<sequence> folder, load the last
    # frame as an RGB PIL image.
    au_image = []
    for names in os.listdir(data_path_dir):
        name = os.path.join(data_path_dir, names)
        for sequences in os.listdir(name):
            sequence = os.path.join(name, sequences)
            if os.path.isdir(sequence):
                if os.listdir(sequence):
                    au_image.append(
                        Image.open(os.path.join(sequence, os.listdir(sequence)[-1])).convert('RGB'))
    return au_image

def load_au_label_from_path(label_path_dir, reserved_label, reserved_set):
    # Prepare AU labels: convert the raw AU codes returned by get_reserved_set()
    # into multi-hot vectors.  Note: label_path_dir is currently unused; the
    # labels come entirely from reserved_label.
    au_label = []
    for l in reserved_label:
        au_label.append(convert_label(l.tolist(), reserved_set))
    return au_label

def load_au_landmark_from_path(landmark_path_dir):
    # Prepare landmarks: for every sequence, load the last landmark file as an
    # array of point coordinates.
    au_landmark = []
    for names in os.listdir(landmark_path_dir):
        name = os.path.join(landmark_path_dir, names)
        for sequences in os.listdir(name):
            sequence = os.path.join(name, sequences)
            if os.listdir(sequence):
                au_landmark.append(np.loadtxt(os.path.join(sequence, os.listdir(sequence)[-1])))
    return au_landmark

def load_au_emotion_from_path(emotion_path_dir):
    # Collect, for every sequence, the path of the last emotion label file
    # (the file itself is not loaded here).
    au_emotion_landmark_path = []
    for names in os.listdir(emotion_path_dir):
        name = os.path.join(emotion_path_dir, names)
        for sequences in os.listdir(name):
            sequence = os.path.join(name, sequences)
            if os.listdir(sequence):
                au_emotion_landmark_path.append(os.path.join(sequence, os.listdir(sequence)[-1]))
    return au_emotion_landmark_path

class au_data_loader(Dataset):
    """Torch Dataset pairing pre-loaded AU images with multi-hot AU label vectors."""

    def __init__(self, au_image, au_label, transform=None, target_transform=None):
        # Prepare AU images (a list of PIL images or numpy arrays).
        self.au_image = au_image
        # Prepare AU labels as a float tensor of multi-hot vectors.
        self.au_label = torch.from_numpy(np.array(au_label)).float()
        # Landmarks and emotion paths are not used by this Dataset:
        # self.au_landmark = au_landmark
        # self.au_emotion_landmark_path = au_emotion_landmark_path
        self.transform = transform
        self.target_transform = target_transform
        self.train_data = self.au_image
        self.train_label = self.au_label

    def __getitem__(self, index):
        img = self.train_data[index]
        # load_au_image_from_path() already yields PIL images; only convert when
        # the data was supplied as numpy arrays.
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        target = self.train_label[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.train_data)
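

# Minimal end-to-end usage sketch.  Everything below is illustrative: the two
# placeholder roots, the Resize size and the batch size are assumptions, and it
# is assumed the image and label trees contain the same sequences in the same
# order; only the functions and the Dataset class above come from this module.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import transforms

    label_root = "path/to/au_labels"   # placeholder: <root>/<subject>/<sequence>/<label file>
    image_root = "path/to/au_images"   # placeholder: <root>/<subject>/<sequence>/<frames>

    # Find the frequently occurring AUs, then build multi-hot labels and images.
    reserved_set, reserved_label = get_reserved_set(label_root)
    au_label = load_au_label_from_path(label_root, reserved_label, reserved_set)
    au_image = load_au_image_from_path(image_root)

    dataset = au_data_loader(
        au_image,
        au_label,
        transform=transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]),
    )
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    images, targets = next(iter(loader))
    print(images.shape, targets.shape)  # e.g. [16, 3, 224, 224] and [16, n_reserved_AUs]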