-
Notifications
You must be signed in to change notification settings - Fork 0
/
processors.py
45 lines (35 loc) · 1.32 KB
/
processors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
import torchvision.transforms as T
from torchvision.models.optical_flow import raft_large
class OpticalFlowProcessor:
    """Estimate dense optical flow between consecutive frames with torchvision's RAFT.

    Calling an instance preprocesses each frame, runs ``raft_large`` on every
    consecutive pair ``(frames[i], frames[i + 1])`` and returns one flow map
    per pair, on the CPU.
    """

    def preprocess(self, batch):
        """Convert one frame into a float32 tensor normalized to ``[-1, 1]``.

        Args:
            batch: A single frame accepted by ``T.ToTensor`` (presumably a PIL
                image or an ``H x W x C`` array -- confirm against callers).
                Despite the name, ``__call__`` passes one frame at a time.

        Returns:
            The transformed float32 tensor.
        """
        transforms = T.Compose(
            [
                T.ToTensor(),
                T.ConvertImageDtype(torch.float32),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
            ]
        )
        return transforms(batch)

    def get_optical_flow(self, frames, device):
        """Compute RAFT optical flow for each consecutive pair of ``frames``.

        The pretrained model is loaded for this call only. It is released
        afterwards (moved to CPU, dereferenced, CUDA cache emptied) so it does
        not hold accelerator memory between calls, even if inference fails.

        Args:
            frames: Sequence (or any iterable) of preprocessed frame tensors,
                presumably shaped ``(1, 3, H, W)`` as produced by ``__call__``.
                NOTE(review): RAFT expects H and W divisible by 8 -- confirm
                callers guarantee this.
            device: Device on which to run the model (e.g. ``"cuda"``).

        Returns:
            A list of ``len(frames) - 1`` flow tensors on the CPU, one per
            consecutive pair. The list is empty when fewer than two frames are
            given.
        """
        frames = list(frames)
        if len(frames) < 2:
            # No pairs to process: skip loading (and possibly downloading) the weights.
            return []

        # NOTE: ``pretrained=True`` is deprecated since torchvision 0.13 in favour of
        # ``weights=Raft_Large_Weights...``; kept here for compatibility with older
        # torchvision releases.
        model = raft_large(pretrained=True, progress=False)
        try:
            model = model.to(device, torch.float32).eval()
            with torch.no_grad():
                # Device-side frame tensors live only inside the helper, so they
                # are already unreferenced when the cache is emptied below.
                flow_maps = self._pairwise_flows(model, frames, device)
        finally:
            # Release accelerator memory even if inference fails part-way.
            model.to("cpu")
            del model
            torch.cuda.empty_cache()
        return flow_maps

    @staticmethod
    def _pairwise_flows(model, frames, device):
        """Return ``model``'s final prediction for each consecutive pair of ``frames``.

        Each frame is moved to ``device`` exactly once. It is then reused as the
        first image of the following pair. Results are copied back to the CPU.
        """
        flow_maps = []
        previous = frames[0].to(device)
        for frame in frames[1:]:
            current = frame.to(device)
            # RAFT returns one prediction per refinement iteration; keep the last.
            predictions = model(previous, current)
            flow_maps.append(predictions[-1].to("cpu"))
            previous = current
        return flow_maps

    def __call__(self, frames, device):
        """Preprocess ``frames`` and return the optical flow between consecutive ones.

        Each preprocessed frame gets a leading batch dimension of size 1 before
        being passed to ``get_optical_flow``.

        Args:
            frames: Iterable of raw frames accepted by ``preprocess``.
            device: Device on which to run the flow model.

        Returns:
            The list returned by ``get_optical_flow`` (one CPU tensor per pair).
        """
        tensor_frames = [self.preprocess(frame)[None, :] for frame in frames]
        return self.get_optical_flow(tensor_frames, device)