from abc import abstractmethod
from typing import Any

import pandas as pd
from torch.utils.data.dataloader import DataLoader, default_collate

from tracklab.datastruct import EngineDatapipe
from tracklab.pipeline import Module


class ImageLevelModule(Module):
    """Abstract class to implement a module that operates directly on images.

    This can be, for example, a bounding-box detector or a bottom-up
    pose estimator (which outputs keypoints directly).

    The functions to implement are:
        - __init__, which can take any configuration needed
        - preprocess
        - process
        - datapipe (optional): returns an object which will be used to create
          the pipeline. (Only modify this if you know what you're doing.)
        - dataloader (optional): returns a dataloader for the datapipe

    You should also provide the following class properties:
        - input_columns: the detection information your module needs as input
        - output_columns: the detection information your module will provide
          when called
        - collate_fn (optional): the function used to collate the inputs into
          a batch. (Default: the PyTorch collate function.)

    A description of the expected behavior is provided below, and a minimal
    illustrative subclass is sketched at the end of this file.
    """

    collate_fn = default_collate
    input_columns = None
    output_columns = None

    @abstractmethod
    def __init__(self, batch_size: int):
        """Init function.

        The arguments to this function are completely free
        and will be provided by a configuration file.
        You should call the __init__ function from the super() class.
        """
        self.batch_size = batch_size
        self._datapipe = None

    @abstractmethod
    def preprocess(self, image, detections: pd.DataFrame, metadata: pd.Series) -> Any:
        """Adapts the default input to your specific case.

        Args:
            image: a numpy array of the current image
            detections: a DataFrame containing all the detections pertaining
                to a single image
            metadata: additional information about the image

        Returns:
            preprocessed_sample: input for the process function
        """
        pass

    @abstractmethod
    def process(self, batch: Any, detections: pd.DataFrame, metadatas: pd.DataFrame):
        """The main processing function. Runs on GPU.

        Args:
            batch: The batched outputs of `preprocess`
            detections: The previous detections.
            metadatas: The metadata of the previous images.

        Returns:
            output: Either a DataFrame containing the new/updated detections,
                or a tuple containing the detections and metadatas (in that
                order). The detections can be a list of Series, a list of
                DataFrames, or a single DataFrame. The returned objects will
                be aggregated automatically according to the `name` of each
                Series / the `index` of each DataFrame. **It is therefore
                mandatory to name your Series and index your DataFrames
                correctly.** The output will override any previous detections
                with the same name/index.
        """
        pass

    @property
    def datapipe(self):
        # Lazily create the datapipe so it is only instantiated when the
        # module is actually used in a pipeline.
        if self._datapipe is None:
            self._datapipe = EngineDatapipe(self)
        return self._datapipe

    def dataloader(self, engine: "TrackingEngine"):
        datapipe = self.datapipe
        return DataLoader(
            dataset=datapipe,
            batch_size=self.batch_size,
            # Look up collate_fn on the class so a plain function stored as
            # a class attribute is not bound as an instance method.
            collate_fn=type(self).collate_fn,
            num_workers=engine.num_workers,
            persistent_workers=False,
        )
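

# ---------------------------------------------------------------------------
# Minimal illustrative sketch of a concrete subclass, NOT part of tracklab's
# API. The class name, the "bbox_ltwh"/"bbox_conf" column names, the use of
# `metadatas.index` as image ids, and the single full-frame dummy "detection"
# are assumptions chosen only to show how the preprocess -> collate ->
# process flow fits together.
# ---------------------------------------------------------------------------
class FullFrameDetector(ImageLevelModule):
    input_columns = []
    output_columns = ["image_id", "bbox_ltwh", "bbox_conf"]

    def __init__(self, batch_size: int = 8):
        super().__init__(batch_size)

    def preprocess(self, image, detections: pd.DataFrame, metadata: pd.Series) -> Any:
        # Hand the raw image to the batch; default_collate stacks
        # same-shaped arrays into a single batched tensor.
        return {"image": image}

    def process(self, batch: Any, detections: pd.DataFrame, metadatas: pd.DataFrame):
        new_detections = []
        for i, image_id in enumerate(metadatas.index):
            height, width = batch["image"][i].shape[:2]
            # A real model would emit one Series per predicted box; here we
            # emit a single full-frame box per image. Naming each Series is
            # mandatory (see the `process` docstring above); the naming
            # scheme used here is an assumption, not an engine convention.
            new_detections.append(
                pd.Series(
                    {
                        "image_id": image_id,
                        "bbox_ltwh": [0, 0, width, height],
                        "bbox_conf": 1.0,
                    },
                    name=f"{image_id}_0",
                )
            )
        return new_detections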