Skip to content

Commit

Permalink
Predict like lightning-flash (#62)
Browse files Browse the repository at this point in the history
* Add missing copyright heads [skip ci]

* Rename to detection_datamodule.py

* Add classes and utils from lightning-flash

* Load model with Task and add DataPipeline into DataModule as lightning-flash

* Move the Task in flash.py directly into pl_wrapper.py

* Fix DataPipeline and support inference in LightningModule

* Add missing Copyright

* Fix ImportError in torch._six

* Cleanup codes [skip ci]

* Rename pl_wrapper.py -> yolo_module.py and detection_datamodule.py -> datamodule.py

* Add unittest for lightning-flash mechanism

* Bug fixes for collate in ObjectDetectionDataPipeline

* Update README.md for updated inference interfaces

* Remove predict.py and typo fixes

* Fix typo [skip ci]
  • Loading branch information
zhiqwang authored Feb 16, 2021
1 parent b2c1c0d commit c526565
Show file tree
Hide file tree
Showing 17 changed files with 374 additions and 168 deletions.
34 changes: 7 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,19 @@ There are no extra compiled components in `yolort` and package dependencies are
pip install -e .
```

- To run batched inference with YOLOv5s
- To read one or more source images and detect the objects in them 🔥

```python
from torchvision.io import read_image
from yolort.models import yolov5s

# Model
# Load model
model = yolov5s(pretrained=True, score_thresh=0.45)
model.eval()

# Images
img1 = read_image('zidane.jpg') / 255.
img2 = read_image('bus.jpg') / 255.
images = [img1, img2] # batched list of images

# Inference
results = model(images)
# Perform inference on an image file
predictions = model.predict('bus.jpg')
# Perform inference on a list of image files
predictions = model.predict(['bus.jpg', 'zidane.jpg'])
```

### Loading via `torch.hub`
Expand Down Expand Up @@ -112,25 +108,9 @@ The module state of `yolort` has some differences comparing to `ultralytics/yolo

</details>

### Inference on `PyTorch` backend 🔥

To read a source image and detect its objects run:

```bash
python -m detect [--input_source ./test/assets/zidane.jpg]
[--labelmap ./notebooks/assets/coco.names]
[--output_dir ./data-bin/output]
[--min_size 640]
[--max_size 640]
[--save_img]
[--gpu] # GPU switch, Set False as default
```

You can also see the [inference-pytorch-export-libtorch](notebooks/inference-pytorch-export-libtorch.ipynb) notebook for more information.

### Inference on `LibTorch` backend 🚀

We provide an [example](./deployment) of getting `LibTorch` inference to work. For details see the [GitHub actions](.github/workflows/nightly.yml).
We provide a [notebook](notebooks/inference-pytorch-export-libtorch.ipynb) to demonstrate how the model is transformed into `torchscript`, and a [C++ example](./deployment) of how to run inference with the transformed `torchscript` model. For details see the [GitHub actions](.github/workflows/nightly.yml).

## 🎨 Model Graph Visualization

Expand Down
102 changes: 91 additions & 11 deletions test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,35 @@
import torch
import pytorch_lightning as pl

from yolort.models import YOLOLitWrapper
from yolort.models.yolo import yolov5_darknet_pan_s_r31
from yolort.models.transform import nested_tensor_from_tensor_list
from yolort.models import yolov5s

from yolort.datasets import DetectionDataModule

from .torch_utils import image_preprocess
from .dataset_utils import DummyCOCODetectionDataset

from typing import Dict

from torchvision.io import read_image


def default_loader(img_name, is_half=False):
    """
    Load an image file with ``torchvision.io.read_image`` as a float tensor.

    Args:
        img_name: path of the image file to read.
        is_half: convert to fp16 when true, otherwise to fp32.

    Returns:
        The image tensor with its values divided by 255.
    """
    raw = read_image(img_name)
    # uint8 to fp16/32
    if is_half:
        scaled = raw.half()
    else:
        scaled = raw.float()
    # 0 - 255 to 0.0 - 1.0
    scaled /= 255.
    return scaled


class EngineTester(unittest.TestCase):
def test_train(self):
# Read Image using TorchVision.io Here
# Do forward over image
img_name = "test/assets/zidane.jpg"
img_tensor = image_preprocess(img_name)
img_tensor = default_loader(img_name)
self.assertEqual(img_tensor.ndim, 3)
# Add a dummy image to train
img_dummy = torch.rand((3, 416, 360), dtype=torch.float32)
Expand All @@ -39,9 +51,8 @@ def test_train(self):

def test_train_one_step(self):
# Load model
model = YOLOLitWrapper()
model = yolov5s()
model.train()

# Setup the DataModule
train_dataset = DummyCOCODetectionDataset(num_samples=128)
datamodule = DetectionDataModule(train_dataset, batch_size=16)
Expand All @@ -50,21 +61,90 @@ def test_train_one_step(self):
trainer.fit(model, datamodule)

def test_inference(self):
    """Call the pretrained model on a one-image list and check the output structure."""
    # Set image inputs
    img_name = "test/assets/zidane.jpg"
    img_input = default_loader(img_name)
    self.assertEqual(img_input.ndim, 3)
    # Load model (built once; the superseded YOLOLitWrapper construction is dropped)
    model = yolov5s(pretrained=True)
    model.eval()
    # Perform inference on a list of tensors
    out = model([img_input])
    self.assertIsInstance(out, list)
    self.assertEqual(len(out), 1)
    self.assertIsInstance(out[0], Dict)
    self.assertIsInstance(out[0]["boxes"], torch.Tensor)
    self.assertIsInstance(out[0]["labels"], torch.Tensor)
    self.assertIsInstance(out[0]["scores"], torch.Tensor)

def test_predict_tensor(self):
    """``predict`` on a single image tensor yields a one-element list of result dicts."""
    # Set image inputs
    img_name = "test/assets/zidane.jpg"
    img_tensor = default_loader(img_name)
    self.assertEqual(img_tensor.ndim, 3)
    # Load model
    model = yolov5s(pretrained=True)
    model.eval()
    # Perform inference on a single image tensor
    predictions = model.predict(img_tensor)
    self.assertIsInstance(predictions, list)
    self.assertEqual(len(predictions), 1)
    self.assertIsInstance(predictions[0], Dict)
    self.assertIsInstance(predictions[0]["boxes"], torch.Tensor)
    self.assertIsInstance(predictions[0]["labels"], torch.Tensor)
    self.assertIsInstance(predictions[0]["scores"], torch.Tensor)

def test_predict_tensors(self):
    """``predict`` on a list of two image tensors yields one result dict per image."""
    # Set image inputs
    img_tensor1 = default_loader("test/assets/zidane.jpg")
    self.assertEqual(img_tensor1.ndim, 3)
    img_tensor2 = default_loader("test/assets/bus.jpg")
    self.assertEqual(img_tensor2.ndim, 3)
    img_tensors = [img_tensor1, img_tensor2]
    # Load model
    model = yolov5s(pretrained=True)
    model.eval()
    # Perform inference on a list of image tensors
    predictions = model.predict(img_tensors)
    self.assertIsInstance(predictions, list)
    self.assertEqual(len(predictions), 2)
    # Validate every prediction, not only the first one
    for prediction in predictions:
        self.assertIsInstance(prediction, Dict)
        self.assertIsInstance(prediction["boxes"], torch.Tensor)
        self.assertIsInstance(prediction["labels"], torch.Tensor)
        self.assertIsInstance(prediction["scores"], torch.Tensor)

def test_predict_image_file(self):
    """``predict`` on one image path yields a one-element list of result dicts."""
    # Load model
    model = yolov5s(pretrained=True)
    model.eval()
    # Perform inference on an image file
    img_name = "test/assets/zidane.jpg"
    predictions = model.predict(img_name)
    self.assertIsInstance(predictions, list)
    self.assertEqual(len(predictions), 1)
    first = predictions[0]
    self.assertIsInstance(first, Dict)
    for field in ("boxes", "labels", "scores"):
        self.assertIsInstance(first[field], torch.Tensor)

def test_predict_image_files(self):
    """``predict`` on a list of two image paths yields one result dict per image."""
    # Set image inputs
    img_name1 = "test/assets/zidane.jpg"
    img_name2 = "test/assets/bus.jpg"
    img_names = [img_name1, img_name2]
    # Load model
    model = yolov5s(pretrained=True)
    model.eval()
    # Perform inference on a list of image files
    predictions = model.predict(img_names)
    self.assertIsInstance(predictions, list)
    self.assertEqual(len(predictions), 2)
    # Validate every prediction, not only the first one
    for prediction in predictions:
        self.assertIsInstance(prediction, Dict)
        self.assertIsInstance(prediction["boxes"], torch.Tensor)
        self.assertIsInstance(prediction["labels"], torch.Tensor)
        self.assertIsInstance(prediction["scores"], torch.Tensor)


if __name__ == '__main__':
unittest.main()
11 changes: 0 additions & 11 deletions test/torch_utils.py

This file was deleted.

3 changes: 2 additions & 1 deletion yolort/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from .pl_datamodule import DetectionDataModule, VOCDetectionDataModule, CocoDetectionDataModule
from .datapipeline import DataPipeline
from .datamodule import DetectionDataModule, VOCDetectionDataModule, CocoDetectionDataModule
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,77 @@
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from torch import Tensor
from torchvision.io import read_image

from pytorch_lightning import LightningDataModule

from .transforms import collate_fn, default_train_transforms, default_val_transforms
from .voc import VOCDetection
from .coco import CocoDetection
from .datapipeline import DataPipeline

from typing import Callable, List, Any, Optional, Type
from collections.abc import Sequence


class ObjectDetectionDataPipeline(DataPipeline):
    """
    Pipeline that prepares image paths or tensors for the detection model.

    Modified from:
    <https://github.com/PyTorchLightning/lightning-flash/blob/24c5b66/flash/vision/detection/data.py#L133-L160>
    """
    def __init__(self, loader: Optional[Callable] = None):
        """
        Args:
            loader: callable that maps an image path to a tensor. When omitted,
                the file is read with ``read_image`` and divided by 255.
        """
        self._loader = self._default_loader if loader is None else loader

    @staticmethod
    def _default_loader(path: str) -> Tensor:
        """Read the image at ``path`` and divide its values by 255."""
        return read_image(path) / 255.

    def before_collate(self, samples: Any) -> Any:
        """
        Turn path inputs into loaded tensors and pass tensor inputs through.

        Args:
            samples: a tensor, a path, or a list/tuple of paths or tensors.

        Returns:
            ``samples`` unchanged when it already contains a tensor, otherwise
            a list with one loaded item per path.

        Raises:
            NotImplementedError: if ``samples`` is none of the supported kinds.
        """
        if _contains_any_tensor(samples, Tensor):
            return samples

        # A single path is handled as a one-element list of paths
        if isinstance(samples, str):
            samples = [samples]

        if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
            return [self._loader(sample) for sample in samples]

        raise NotImplementedError("The samples should either be a tensor or path, a list of paths or tensors.")

    def collate(self, samples: Any) -> Any:
        """
        Assemble ``samples`` into a batch.

        A lone tensor gets a leading dimension added. A collection whose first
        element is a ``Sequence`` is delegated to ``collate_fn``; any other
        collection is returned as a plain list.
        """
        if isinstance(samples, Tensor):
            return samples.unsqueeze(dim=0)

        elem = samples[0]
        if isinstance(elem, Sequence):
            return collate_fn(samples)

        return list(samples)

    def after_collate(self, batch: Any) -> Any:
        """
        Split a batch into an ``(inputs, target)`` pair.

        A dict batch yields ``(batch["x"], batch["target"])``; anything else is
        returned as ``(batch, None)``.
        """
        return (batch["x"], batch["target"]) if isinstance(batch, dict) else (batch, None)


def _contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool:
"""
TODO: we should refactor FlashDatasetFolder to better integrate
with DataPipeline. That way, we wouldn't need this check.
This is because we are running transforms in both places.
Ref:
<https://github.com/PyTorchLightning/lightning-flash/blob/24c5b66/flash/core/data/utils.py#L80-L90>
"""
if isinstance(value, dtype):
return True
if isinstance(value, (list, tuple)):
return any(_contains_any_tensor(v, dtype=dtype) for v in value)
elif isinstance(value, dict):
return any(_contains_any_tensor(v, dtype=dtype) for v in value.values())
return False


class DetectionDataModule(LightningDataModule):
Expand Down Expand Up @@ -78,6 +142,20 @@ def val_dataloader(self, batch_size: int = 16) -> None:

return loader

@property
def data_pipeline(self) -> DataPipeline:
    """
    Pipeline attached to this data module.

    When ``_data_pipeline`` is ``None``, the result of ``default_pipeline()``
    is stored in it first and then returned.
    """
    pipeline = self._data_pipeline
    if pipeline is not None:
        return pipeline
    self._data_pipeline = self.default_pipeline()
    return self._data_pipeline

@data_pipeline.setter
def data_pipeline(self, data_pipeline: DataPipeline) -> None:
    """Store ``data_pipeline`` in ``_data_pipeline``, replacing any previous value."""
    self._data_pipeline = data_pipeline

@staticmethod
def default_pipeline() -> DataPipeline:
    """Return a new ``ObjectDetectionDataPipeline`` built with its default arguments."""
    return ObjectDetectionDataPipeline()


class VOCDetectionDataModule(DetectionDataModule):
def __init__(
Expand Down
Loading

0 comments on commit c526565

Please sign in to comment.