Skip to content

Commit

Permalink
BUG: Fix models to log train loss on step, fixes #720 (#722)
Browse files (browse the repository at this point in the history)
* BUG: Fix models so we log train loss on step, fixes #720

* CLN: Apply linting changes: Black, iSort

* CLN: Make flake8 fixes
  • Loading branch information
NickleDave authored Oct 17, 2023
1 parent f85132b commit fefd576
Show file tree
Hide file tree
Showing 41 changed files with 241 additions and 229 deletions.
6 changes: 4 additions & 2 deletions src/vak/cli/prep.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Function called by command-line interface for prep command"""
from __future__ import annotations

import pathlib
import shutil
import warnings
import pathlib

import toml

Expand All @@ -13,7 +13,9 @@
from ..config.validators import are_sections_valid


def purpose_from_toml(config_toml: dict, toml_path: str | pathlib.Path | None = None) -> str:
def purpose_from_toml(
config_toml: dict, toml_path: str | pathlib.Path | None = None
) -> str:
"""determine "purpose" from toml config,
i.e., the command that will be run after we ``prep`` the data.
Expand Down
4 changes: 3 additions & 1 deletion src/vak/common/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,9 @@ def map_annotated_to_annot(
reference section of the documentation:
https://vak.readthedocs.io/en/latest/reference/filenames.html
"""
if isinstance(annotated_files, np.ndarray): # e.g., vak DataFrame['spect_path'].values
if isinstance(
annotated_files, np.ndarray
): # e.g., vak DataFrame['spect_path'].values
annotated_files = annotated_files.tolist()

if annot_format in (
Expand Down
1 change: 0 additions & 1 deletion src/vak/common/files/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from . import spect
from .files import find_fname, from_dir


__all__ = [
"find_fname",
"from_dir",
Expand Down
8 changes: 6 additions & 2 deletions src/vak/common/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from ..common.typing import PathLike


def get_summary_writer(log_dir: PathLike, filename_suffix: str) -> SummaryWriter:
def get_summary_writer(
log_dir: PathLike, filename_suffix: str
) -> SummaryWriter:
"""Get an instance of ``tensorboard.SummaryWriter``,
to use with a vak.Model during training.
Expand Down Expand Up @@ -45,7 +47,9 @@ def get_summary_writer(log_dir: PathLike, filename_suffix: str) -> SummaryWriter


def events2df(
events_path: PathLike, size_guidance: dict | None = None, drop_wall_time: bool = True
events_path: PathLike,
size_guidance: dict | None = None,
drop_wall_time: bool = True,
) -> pd.DataFrame:
"""Convert :mod:`tensorboard` events file to pandas.DataFrame
Expand Down
1 change: 0 additions & 1 deletion src/vak/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
validators,
)


__all__ = [
"config",
"eval",
Expand Down
6 changes: 1 addition & 5 deletions src/vak/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from . import frame_classification, parametric_umap


__all__ = [
"frame_classification",
"parametric_umap"
]
__all__ = ["frame_classification", "parametric_umap"]
22 changes: 16 additions & 6 deletions src/vak/datasets/frame_classification/frames_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from . import constants, helper
from .metadata import Metadata
from ... import common


class FramesDataset:
Expand Down Expand Up @@ -119,7 +118,10 @@ def __init__(
Transform applied to each item :math:`(x, y)`
returned by :meth:`FramesDataset.__getitem__`.
"""
from ... import prep # avoid circular import, use for constants.INPUT_TYPES
from ... import (
prep,
) # avoid circular import, use for constants.INPUT_TYPES

if input_type not in prep.constants.INPUT_TYPES:
raise ValueError(
f"``input_type`` must be one of: {prep.constants.INPUT_TYPES}\n"
Expand Down Expand Up @@ -165,7 +167,7 @@ def _load_frames(self, frames_path):
the input to the frame classification model.
Loads audio or spectrogram, depending on
:attr:`self.input_type`.
This function assumes that audio is in wav format
This function assumes that audio is in wav format
and spectrograms are in npz files.
"""
return helper.load_frames(frames_path, self.input_type)
Expand Down Expand Up @@ -233,15 +235,23 @@ def from_dataset_path(

split_path = dataset_path / split
if subset:
sample_ids_path = split_path / helper.sample_ids_array_filename_for_subset(subset)
sample_ids_path = (
split_path
/ helper.sample_ids_array_filename_for_subset(subset)
)
else:
sample_ids_path = split_path / constants.SAMPLE_IDS_ARRAY_FILENAME
sample_ids = np.load(sample_ids_path)

if subset:
inds_in_sample_path = split_path / helper.inds_in_sample_array_filename_for_subset(subset)
inds_in_sample_path = (
split_path
/ helper.inds_in_sample_array_filename_for_subset(subset)
)
else:
inds_in_sample_path = split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME
inds_in_sample_path = (
split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME
)
inds_in_sample = np.load(inds_in_sample_path)

return cls(
Expand Down
8 changes: 4 additions & 4 deletions src/vak/datasets/frame_classification/helper.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
"""Helper functions used with frame classification datasets."""
from __future__ import annotations

from . import constants
from ... import common
from . import constants


def sample_ids_array_filename_for_subset(subset: str) -> str:
"""Returns name of sample IDs array file for a subset of the training data."""
return constants.SAMPLE_IDS_ARRAY_FILENAME.replace(
'.npy', f'-{subset}.npy'
)
".npy", f"-{subset}.npy"
)


def inds_in_sample_array_filename_for_subset(subset: str) -> str:
"""Returns name of inds in sample array file for a subset of the training data."""
return constants.INDS_IN_SAMPLE_ARRAY_FILENAME.replace(
'.npy', f'-{subset}.npy'
".npy", f"-{subset}.npy"
)


Expand Down
32 changes: 20 additions & 12 deletions src/vak/datasets/frame_classification/window_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from . import constants, helper
from .metadata import Metadata
from ... import common


def get_window_inds(n_frames: int, window_size: int, stride: int = 1):
Expand Down Expand Up @@ -231,7 +230,10 @@ def __init__(
The transform applied to the target for the output
of the neural network :math:`y`.
"""
from ... import prep # avoid circular import, use for constants.INPUT_TYPES
from ... import (
prep,
) # avoid circular import, use for constants.INPUT_TYPES

if input_type not in prep.constants.INPUT_TYPES:
raise ValueError(
f"``input_type`` must be one of: {prep.constants.INPUT_TYPES}\n"
Expand Down Expand Up @@ -284,15 +286,15 @@ def _load_frames(self, frames_path):
the input to the frame classification model.
Loads audio or spectrogram, depending on
:attr:`self.input_type`.
This function assumes that audio is in wav format
This function assumes that audio is in wav format
and spectrograms are in npz files.
"""
return helper.load_frames(frames_path, self.input_type)

def __getitem__(self, idx):
window_idx = self.window_inds[idx]
sample_ids = self.sample_ids[
window_idx: window_idx + self.window_size
window_idx : window_idx + self.window_size # noqa: E203
]
uniq_sample_ids = np.unique(sample_ids)
if len(uniq_sample_ids) == 1:
Expand All @@ -309,9 +311,7 @@ def __getitem__(self, idx):
frame_labels = []
for sample_id in sorted(uniq_sample_ids):
frames_path = self.dataset_path / self.frames_paths[sample_id]
frames.append(
self._load_frames(frames_path)
)
frames.append(self._load_frames(frames_path))
frame_labels.append(
np.load(
self.dataset_path / self.frame_labels_paths[sample_id]
Expand All @@ -331,10 +331,10 @@ def __getitem__(self, idx):

inds_in_sample = self.inds_in_sample[window_idx]
frames = frames[
..., inds_in_sample: inds_in_sample + self.window_size
..., inds_in_sample : inds_in_sample + self.window_size # noqa: E203
]
frame_labels = frame_labels[
inds_in_sample: inds_in_sample + self.window_size
inds_in_sample : inds_in_sample + self.window_size # noqa: E203
]
if self.transform:
frames = self.transform(frames)
Expand Down Expand Up @@ -405,15 +405,23 @@ def from_dataset_path(

split_path = dataset_path / split
if subset:
sample_ids_path = split_path / helper.sample_ids_array_filename_for_subset(subset)
sample_ids_path = (
split_path
/ helper.sample_ids_array_filename_for_subset(subset)
)
else:
sample_ids_path = split_path / constants.SAMPLE_IDS_ARRAY_FILENAME
sample_ids = np.load(sample_ids_path)

if subset:
inds_in_sample_path = split_path / helper.inds_in_sample_array_filename_for_subset(subset)
inds_in_sample_path = (
split_path
/ helper.inds_in_sample_array_filename_for_subset(subset)
)
else:
inds_in_sample_path = split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME
inds_in_sample_path = (
split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME
)
inds_in_sample = np.load(inds_in_sample_path)

window_inds_path = split_path / constants.WINDOW_INDS_ARRAY_FILENAME
Expand Down
3 changes: 2 additions & 1 deletion src/vak/datasets/parametric_umap/parametric_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
from pynndescent import NNDescent
import scipy.sparse._coo
from pynndescent import NNDescent
from sklearn.utils import check_random_state
from torch.utils.data import Dataset

Expand All @@ -21,6 +21,7 @@

warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
from umap.umap_ import fuzzy_simplicial_set # noqa: E402

# isort: on


Expand Down
1 change: 0 additions & 1 deletion src/vak/eval/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from . import eval_, frame_classification, parametric_umap
from .eval_ import eval


__all__ = [
"eval",
"eval_",
Expand Down
5 changes: 1 addition & 4 deletions src/vak/eval/parametric_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytorch_lightning as lightning
import torch.utils.data

from .. import datasets, models, transforms
from .. import models, transforms
from ..common import validators
from ..datasets.parametric_umap import ParametricUMAPDataset

Expand Down Expand Up @@ -85,9 +85,6 @@ def eval_parametric_umap_model(
logger.info(
f"Loading metadata from dataset path: {dataset_path}",
)
metadata = datasets.parametric_umap.Metadata.from_dataset_path(
dataset_path
)

if not validators.is_a_directory(output_dir):
raise NotADirectoryError(
Expand Down
8 changes: 1 addition & 7 deletions src/vak/learncurve/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
from . import (
curvefit,
dirname,
frame_classification,
learncurve,
)
from . import curvefit, dirname, frame_classification, learncurve
from .learncurve import learning_curve


__all__ = [
"curvefit",
"dirname",
Expand Down
1 change: 0 additions & 1 deletion src/vak/metrics/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .classification import Accuracy


__all__ = [
"Accuracy",
]
2 changes: 1 addition & 1 deletion src/vak/models/frame_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def training_step(self, batch: tuple, batch_idx: int):
x, y = batch[0], batch[1]
out = self.network(x)
loss = self.loss(out, y)
self.log("train_loss", loss)
self.log("train_loss", loss, on_step=True)
return loss

def validation_step(self, batch: tuple, batch_idx: int):
Expand Down
8 changes: 5 additions & 3 deletions src/vak/models/parametric_umap_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ def training_step(self, batch, batch_idx):
loss_umap, loss_reconstruction, loss = self.loss(
embedding_to, embedding_from, reconstruction, before_encoding
)
self.log("train_umap_loss", loss_umap)
self.log("train_umap_loss", loss_umap, on_step=True)
if loss_reconstruction:
self.log("train_reconstruction_loss", loss_reconstruction)
self.log(
"train_reconstruction_loss", loss_reconstruction, on_step=True
)
# note if there's no ``loss_reconstruction``, then ``loss`` == ``loss_umap``
self.log("train_loss", loss)
self.log("train_loss", loss, on_step=True)
return loss

def validation_step(self, batch, batch_idx):
Expand Down
4 changes: 1 addition & 3 deletions src/vak/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def __getattr__(name: str) -> Any:
model_name_family_name_map[model_name] = family_name
return model_name_family_name_map
elif name == "MODEL_NAMES":
return list(
MODEL_REGISTRY.keys()
)
return list(MODEL_REGISTRY.keys())
else:
raise AttributeError(
f"Not an attribute of `vak.models.registry`: {name}"
Expand Down
1 change: 0 additions & 1 deletion src/vak/nn/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .dice import DiceLoss, dice_loss
from .umap import UmapLoss, umap_loss


__all__ = [
"DiceLoss",
"dice_loss",
Expand Down
1 change: 1 addition & 0 deletions src/vak/nn/loss/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
from umap.umap_ import find_ab_params # noqa : E402

# isort: on


Expand Down
1 change: 0 additions & 1 deletion src/vak/predict/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from . import frame_classification, parametric_umap, predict_
from .predict_ import predict


__all__ = [
"frame_classification",
"parametric_umap",
Expand Down
1 change: 0 additions & 1 deletion src/vak/prep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from .prep_ import prep


__all__ = [
"audio_dataset",
"constants",
Expand Down
1 change: 0 additions & 1 deletion src/vak/prep/audio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ..common.typing import PathLike
from .spectrogram_dataset.audio_helper import files_from_dir


logger = logging.getLogger(__name__)


Expand Down
Loading

0 comments on commit fefd576

Please sign in to comment.