Merge pull request #267 from neuronets/ohinds_train_mods
Small changes to support long, preemptable training runs
satra authored Oct 7, 2023
2 parents f926fba + 19a75b3 commit 59ec482
Showing 7 changed files with 22 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -11,7 +11,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-20.04]
- python-version: ["3.10", "3.9", "3.8"]
+ python-version: ["3.11", "3.10", "3.9"]

steps:
- uses: actions/checkout@v3
@@ -34,7 +34,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-20.04]
- python-version: ["3.10", "3.9", "3.8"]
+ python-version: ["3.11", "3.10", "3.9"]
steps:
- uses: actions/checkout@v3
with:
1 change: 1 addition & 0 deletions .github/workflows/guide-notebooks-ec2.yml
@@ -40,6 +40,7 @@ jobs:
export LD_LIBRARY_PATH=opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/usr/local/cuda/efa/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/lib:/usr/lib
echo $LD_LIBRARY_PATH
pip install matplotlib nilearn
+ pip install -U tensorflow
pip install -e .
nobrainer info
- name: run
2 changes: 1 addition & 1 deletion docker/cpu.Dockerfile
@@ -1,4 +1,4 @@
- FROM tensorflow/tensorflow:2.13.0-jupyter
+ FROM tensorflow/tensorflow:2.14.0-jupyter
RUN curl -sSL http://neuro.debian.net/lists/focal.us-nh.full | tee /etc/apt/sources.list.d/neurodebian.sources.list \
&& export GNUPGHOME="$(mktemp -d)" \
&& echo "disable-ipv6" >> ${GNUPGHOME}/dirmngr.conf \
2 changes: 1 addition & 1 deletion docker/gpu.Dockerfile
@@ -1,4 +1,4 @@
- FROM tensorflow/tensorflow:2.13.0-gpu-jupyter
+ FROM tensorflow/tensorflow:2.14.0-gpu-jupyter
RUN curl -sSL http://neuro.debian.net/lists/focal.us-nh.full | tee /etc/apt/sources.list.d/neurodebian.sources.list \
&& export GNUPGHOME="$(mktemp -d)" \
&& echo "disable-ipv6" >> ${GNUPGHOME}/dirmngr.conf \
7 changes: 7 additions & 0 deletions nobrainer/dataset.py
@@ -78,6 +78,7 @@ def from_tfrecords(
block_shape=None,
scalar_labels=False,
n_classes=1,
+ tf_dataset_options=None,
num_parallel_calls=1,
):
"""Function to retrieve a saved tf record as a nobrainer Dataset
@@ -97,6 +98,9 @@ def from_tfrecords(
compressed = _is_gzipped(files[0], filesys=fs)
dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)

+ if tf_dataset_options:
+     dataset = dataset.with_options(tf_dataset_options)

# Read each of these files as a TFRecordDataset.
# Assume all files have same compression type as the first file.
compression_type = "GZIP" if compressed else None
@@ -150,6 +154,7 @@ def from_files(
eval_size=0.1,
n_classes=1,
block_shape=None,
+ tf_dataset_options=None,
):
"""Create Nobrainer datasets from data
filepaths: List(str), list of paths to individual input data files.
@@ -211,6 +216,7 @@ def from_files(
scalar_labels=scalar_labels,
n_classes=n_classes,
block_shape=block_shape,
+ tf_dataset_options=tf_dataset_options,
num_parallel_calls=num_parallel_calls,
)
ds_eval = None
@@ -223,6 +229,7 @@ def from_files(
scalar_labels=scalar_labels,
n_classes=n_classes,
block_shape=block_shape,
+ tf_dataset_options=tf_dataset_options,
num_parallel_calls=num_parallel_calls,
)
return ds_train, ds_eval
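The new `tf_dataset_options` parameter is forwarded straight to `tf.data.Dataset.with_options`. A minimal sketch of what a caller might pass for a long, preemptible multi-worker run; the specific policy and threadpool values here are illustrative assumptions, not defaults introduced by this PR:

```python
import tensorflow as tf

# Options object a caller could pass as tf_dataset_options; the patched
# code simply forwards it to Dataset.with_options.
options = tf.data.Options()
# Shard by DATA so each worker in a distributed run reads a distinct
# slice of the records (illustrative choice, not a nobrainer default).
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA
)
options.threading.private_threadpool_size = 4

# Equivalent to what from_tfrecords now does with the parameter:
dataset = tf.data.Dataset.range(8).with_options(options)
```

Because `with_options` returns a new dataset, passing the options down into `from_tfrecords` (rather than applying them afterwards) ensures they also govern the file-listing stage of the pipeline.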
5 changes: 5 additions & 0 deletions nobrainer/processing/checkpoint.py
@@ -47,8 +47,13 @@ def load(
"""
checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/"))
if not checkpoints:
+ self.last_epoch = 0
return None

+ # TODO, we should probably exclude non-checkpoint files here,
+ # and maybe parse the filename for the epoch number
+ self.last_epoch = len(checkpoints)

latest = max(checkpoints, key=os.path.getctime)
self.estimator = self.estimator.load(
latest,
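The TODO above notes that counting directories is fragile and that parsing the epoch out of each checkpoint's name would be more robust. A minimal sketch of that idea, assuming an `epoch_NNNN/` naming scheme; the layout and helper name are hypothetical, not nobrainer's actual convention:

```python
import os
import re
from glob import glob

def last_epoch_from_checkpoints(checkpoint_root):
    """Return the highest epoch encoded in checkpoint directory names
    like 'epoch_0007/', or 0 if no numbered checkpoints exist."""
    epochs = []
    for path in glob(os.path.join(checkpoint_root, "*/")):
        name = os.path.basename(path.rstrip("/"))
        m = re.search(r"(\d+)$", name)  # trailing integer, if any
        if m:
            epochs.append(int(m.group(1)))
    return max(epochs, default=0)
```

Unlike `len(checkpoints)`, this stays correct when non-checkpoint directories (logs, exports) share the same parent, which is exactly the failure mode the TODO warns about.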
10 changes: 5 additions & 5 deletions setup.cfg
@@ -19,27 +19,27 @@ classifiers =
Operating System :: OS Independent
Programming Language :: Python :: 3
Programming Language :: Python :: 3 :: Only
- Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
Topic :: Scientific/Engineering :: Artificial Intelligence
Topic :: Software Development
Topic :: Software Development :: Libraries :: Python Modules
project_urls =
Source Code = https://github.com/neuronets/nobrainer

[options]
- python_requires = >= 3.8
+ python_requires = >= 3.9
install_requires =
click
fsspec
joblib
nibabel
numpy
scikit-image
- tensorflow-probability >= 0.11.0
- tensorflow >= 2.12.0
- tensorflow-addons >= 0.12.0
+ tensorflow-probability ~= 0.22.0
+ tensorflow ~= 2.13
+ tensorflow-addons ~= 0.21.0
psutil
zip_safe = False
packages = find:
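The dependency pins switch from open-ended `>=` bounds to the compatible-release operator `~=`, which also caps the allowed versions: `tensorflow ~= 2.13` means `>= 2.13, == 2.*`, and `tensorflow-probability ~= 0.22.0` means `>= 0.22.0, == 0.22.*`. A quick check of that semantics with the `packaging` library (a common setuptools dependency, assumed available here):

```python
from packaging.specifiers import SpecifierSet

tf_spec = SpecifierSet("~=2.13")      # equivalent to >= 2.13, == 2.*
tfp_spec = SpecifierSet("~=0.22.0")   # equivalent to >= 0.22.0, == 0.22.*

assert tf_spec.contains("2.14.0")     # newer 2.x releases still allowed
assert not tf_spec.contains("3.0.0")  # major version bumps excluded
assert tfp_spec.contains("0.22.1")    # patch releases allowed
assert not tfp_spec.contains("0.23.0")  # minor bumps excluded
```

Capping the TensorFlow major version and the probability/addons minor versions keeps resumed long-running jobs from silently picking up incompatible releases mid-project.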
