-
Notifications
You must be signed in to change notification settings - Fork 354
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
db8d4b0
commit cae787d
Showing
18 changed files
with
2,185 additions
and
1,015 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# DeepSpeed CIFAR Example | ||
This example is adapted from the | ||
[DCGAN example in the DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/training/gan) | ||
repository. It is intended to demonstrate a simple usecase of DeepSpeed with Determined. | ||
|
||
## Files | ||
* **model.py**: The DCGANTrial definition. | ||
* **gan_model.py**: Network definitions for generator and discriminator. | ||
* **data.py**: Dataset loading/downloading code. | ||
|
||
### Configuration Files | ||
* **ds_config.json**: The DeepSpeed config file. | ||
* **mnist.yaml**: Determined config to train the model on mnist on a cluster. | ||
|
||
## Data | ||
This repo supports the same datasets as the original example: `["imagenet", "lfw", "lsun", "cifar10", "mnist", "fake", "celeba"]`. The `cifar10` and `mnist` datasets will be downloaded as needed, whereas the rest must be mounted on the agent. For `lsun`, the `data_config.classes` setting must be set. The `folder` dataset can be used to load an arbitrary torchvision `ImageFolder` that is mounted on the agent. | ||
|
||
## To Run Locally | ||
|
||
It is recommended to run this from within one of our agent docker images, found at | ||
https://hub.docker.com/r/determinedai/pytorch-ngc/tags | ||
|
||
After installing docker and pulling an image, users can launch a container via | ||
`docker run --gpus=all -v ~path/to/repo:/src/proj -it <container name>` | ||
|
||
Install necessary dependencies via `pip install determined mpi4py` | ||
|
||
Then, run the following command: | ||
``` | ||
python trainer.py | ||
``` | ||
|
||
Any additional configs can be specified in `mnist.yaml` and `ds_config.json` accordingly. | ||
|
||
## To Run on Cluster | ||
If you have not yet installed Determined, installation instructions can be found | ||
under `docs/install-admin.html` or at https://docs.determined.ai/latest/index.html | ||
|
||
Run the following command: | ||
``` | ||
det experiment create mnist.yaml . | ||
``` | ||
The other configurations can be run by specifying the appropriate configuration file in place | ||
of `mnist.yaml`. | ||
|
||
## Results | ||
Training `mnist` should yield reasonable looking fake digit images on the images tab in TensorBoard after ~5k steps. | ||
|
||
Training `cifar10` does not converge as convincingly, but should look image-like after ~10k steps. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import contextlib | ||
import os | ||
from typing import cast | ||
|
||
import filelock | ||
import torch | ||
import torchvision.datasets as dset | ||
import torchvision.transforms as transforms | ||
|
||
CHANNELS_BY_DATASET = { | ||
"imagenet": 3, | ||
"folder": 3, | ||
"lfw": 3, | ||
"lsun": 3, | ||
"cifar10": 3, | ||
"mnist": 1, | ||
"fake": 3, | ||
"celeba": 3, | ||
} | ||
|
||
|
||
def get_dataset(data_config: dict) -> torch.utils.data.Dataset: | ||
if data_config.get("dataroot", None) is None: | ||
if str(data_config.get("dataset"),"").lower() != "fake": | ||
raise ValueError('`dataroot` parameter is required for dataset "%s"' | ||
% data_config.get("dataset", "")) | ||
else: | ||
context = contextlib.nullcontext() | ||
else: | ||
# Ensure that only one local process attempts to download/validate datasets at once. | ||
context = filelock.FileLock(os.path.join(data_config["dataroot"], ".lock")) | ||
with context: | ||
if data_config["dataset"] in ["imagenet", "folder", "lfw"]: | ||
# folder dataset | ||
dataset = dset.ImageFolder( | ||
root=data_config["dataroot"], | ||
transform=transforms.Compose( | ||
[ | ||
transforms.Resize(data_config["image_size"]), | ||
transforms.CenterCrop(data_config["image_size"]), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
] | ||
), | ||
) | ||
elif data_config["dataset"] == "lsun": | ||
classes = [c + "_train" for c in data_config["classes"].split(",")] | ||
dataset = dset.LSUN( | ||
root=data_config["dataroot"], | ||
classes=classes, | ||
transform=transforms.Compose( | ||
[ | ||
transforms.Resize(data_config["image_size"]), | ||
transforms.CenterCrop(data_config["image_size"]), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
] | ||
), | ||
) | ||
elif data_config["dataset"] == "cifar10": | ||
dataset = dset.CIFAR10( | ||
root=data_config["dataroot"], | ||
download=True, | ||
transform=transforms.Compose( | ||
[ | ||
transforms.Resize(data_config["image_size"]), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
] | ||
), | ||
) | ||
elif data_config["dataset"] == "mnist": | ||
dataset = dset.MNIST( | ||
root=data_config["dataroot"], | ||
download=True, | ||
transform=transforms.Compose( | ||
[ | ||
transforms.Resize(data_config["image_size"]), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5,), (0.5,)), | ||
] | ||
), | ||
) | ||
elif data_config["dataset"] == "fake": | ||
dataset = dset.FakeData( | ||
image_size=(3, data_config["image_size"], data_config["image_size"]), | ||
transform=transforms.ToTensor(), | ||
) | ||
elif data_config["dataset"] == "celeba": | ||
dataset = dset.ImageFolder( | ||
root=data_config["dataroot"], | ||
transform=transforms.Compose( | ||
[ | ||
transforms.Resize(data_config["image_size"]), | ||
transforms.CenterCrop(data_config["image_size"]), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
] | ||
), | ||
) | ||
else: | ||
unknown_dataset_name = data_config["dataset"] | ||
raise Exception(f"Unknown dataset {unknown_dataset_name}") | ||
return cast(torch.utils.data.Dataset, dataset) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"train_batch_size": 64, | ||
"optimizer": { | ||
"type": "Adam", | ||
"params": { | ||
"lr": 0.0002, | ||
"betas": [ | ||
0.5, | ||
0.999 | ||
], | ||
"eps": 1e-8 | ||
} | ||
}, | ||
"steps_per_print": 10 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import cast | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
def weights_init(m: nn.Module) -> None: | ||
classname = m.__class__.__name__ | ||
if classname.find("Conv") != -1: | ||
nn.init.normal_(cast(torch.Tensor, m.weight.data), 0.0, 0.02) | ||
elif classname.find("BatchNorm") != -1: | ||
nn.init.normal_(cast(torch.Tensor, m.weight.data), 1.0, 0.02) | ||
nn.init.constant_(cast(torch.Tensor, m.bias.data), 0) | ||
|
||
|
||
class Generator(nn.Module): | ||
def __init__(self, ngf: int, nc: int, nz: int) -> None: | ||
super(Generator, self).__init__() # type: ignore | ||
self.main = nn.Sequential( | ||
# input is Z, going into a convolution | ||
nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), | ||
nn.BatchNorm2d(ngf * 8), # type: ignore | ||
nn.ReLU(True), | ||
# state size. (ngf*8) x 4 x 4 | ||
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), | ||
nn.BatchNorm2d(ngf * 4), # type: ignore | ||
nn.ReLU(True), | ||
# state size. (ngf*4) x 8 x 8 | ||
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), | ||
nn.BatchNorm2d(ngf * 2), # type: ignore | ||
nn.ReLU(True), | ||
# state size. (ngf*2) x 16 x 16 | ||
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), | ||
nn.BatchNorm2d(ngf), # type: ignore | ||
nn.ReLU(True), | ||
# state size. (ngf) x 32 x 32 | ||
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), | ||
nn.Tanh() # type: ignore | ||
# state size. (nc) x 64 x 64 | ||
) | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
output = self.main(input) | ||
return cast(torch.Tensor, output) | ||
|
||
|
||
class Discriminator(nn.Module): | ||
def __init__(self, ndf: int, nc: int) -> None: | ||
super(Discriminator, self).__init__() # type: ignore | ||
self.main = nn.Sequential( | ||
# input is (nc) x 64 x 64 | ||
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), | ||
nn.LeakyReLU(0.2, inplace=True), | ||
# state size. (ndf) x 32 x 32 | ||
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), | ||
nn.BatchNorm2d(ndf * 2), # type: ignore | ||
nn.LeakyReLU(0.2, inplace=True), | ||
# state size. (ndf*2) x 16 x 16 | ||
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), | ||
nn.BatchNorm2d(ndf * 4), # type: ignore | ||
nn.LeakyReLU(0.2, inplace=True), | ||
# state size. (ndf*4) x 8 x 8 | ||
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), | ||
nn.BatchNorm2d(ndf * 8), # type: ignore | ||
nn.LeakyReLU(0.2, inplace=True), | ||
# state size. (ndf*8) x 4 x 4 | ||
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), | ||
nn.Sigmoid(), # type: ignore | ||
) | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
output = self.main(input) | ||
return cast(torch.Tensor, output.view(-1, 1).squeeze(1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
name: dcgan_deepspeed_mnist | ||
data: | ||
dataroot: /data | ||
dataset: mnist | ||
image_size: 64 | ||
hyperparameters: | ||
deepspeed_config: ds_config.json | ||
noise_length: 100 | ||
generator_width_base: 64 | ||
discriminator_width_base: 64 | ||
data_workers: 16 | ||
environment: | ||
environment_variables: | ||
- NCCL_DEBUG=INFO | ||
- NCCL_SOCKET_IFNAME=ens,eth,ib | ||
image: determinedai/pytorch-ngc-dev:0736b6d | ||
bind_mounts: | ||
- host_path: /tmp | ||
container_path: /data | ||
resources: | ||
slots_per_trial: 2 | ||
searcher: | ||
name: single | ||
metric: no_validation_metric | ||
max_length: | ||
batches: 200 | ||
min_validation_period: | ||
batches: 0 | ||
entrypoint: | ||
- python3 | ||
- -m | ||
- determined.launch.deepspeed | ||
- python3 | ||
- trainer.py | ||
max_restarts: 0 |
Oops, something went wrong.