Skip to content

Commit

Permalink
Update for new versions of black and pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Mar 22, 2024
1 parent 2a4c293 commit e7d694d
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 20 deletions.
4 changes: 1 addition & 3 deletions direct/nn/cirim/cirim.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ def __init__(
bias=bias,
)
self.hh = nn.Parameter(
nn.init.normal_(
torch.empty(1, hidden_channels, 1, 1), std=1.0 / (hidden_channels * (1 + kernel_size**2))
)
nn.init.normal_(torch.empty(1, hidden_channels, 1, 1), std=1.0 / (hidden_channels * (1 + kernel_size**2)))
)

self.reset_parameters()
Expand Down
4 changes: 1 addition & 3 deletions direct/nn/didn/didn.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def __init__(
Padding size. Default: 0.
"""
super().__init__()
self.conv = nn.Conv2d(
in_channels, out_channels * upscale_factor**2, kernel_size=kernel_size, padding=padding
)
self.conv = nn.Conv2d(in_channels, out_channels * upscale_factor**2, kernel_size=kernel_size, padding=padding)
self.pixelshuffle = nn.PixelShuffle(upscale_factor)

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down
10 changes: 5 additions & 5 deletions direct/nn/get_nn_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ def _get_model_config(
{
"hidden_channels": kwargs.get("conv_hidden_channels", 64),
"n_convs": kwargs.get("conv_n_convs", 15),
"activation": nn.PReLU()
if kwargs.get("conv_activation", "prelu") == ActivationType.prelu
else nn.ReLU()
if kwargs.get("conv_activation", "relu") == ActivationType.relu
else nn.LeakyReLU(),
"activation": (
nn.PReLU()
if kwargs.get("conv_activation", "prelu") == ActivationType.prelu
else nn.ReLU() if kwargs.get("conv_activation", "relu") == ActivationType.relu else nn.LeakyReLU()
),
"batchnorm": kwargs.get("conv_batchnorm", False),
}
)
Expand Down
12 changes: 8 additions & 4 deletions direct/nn/mri_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,10 +586,14 @@ def reconstruct_volumes( # type: ignore
)
# Maybe not needed.
del data
yield (curr_volume, curr_target, reduce_list_of_dicts(loss_dict_list), filename) if add_target else (
curr_volume,
reduce_list_of_dicts(loss_dict_list),
filename,
yield (
(curr_volume, curr_target, reduce_list_of_dicts(loss_dict_list), filename)
if add_target
else (
curr_volume,
reduce_list_of_dicts(loss_dict_list),
filename,
)
)

@torch.no_grad()
Expand Down
6 changes: 3 additions & 3 deletions direct/nn/unet/unet_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def __init__(
def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]:
output_image = self.model(
masked_kspace=data["masked_kspace"],
sensitivity_map=data["sensitivity_map"]
if self.cfg.model.image_initialization == "sense" # type: ignore
else None,
sensitivity_map=(
data["sensitivity_map"] if self.cfg.model.image_initialization == "sense" else None # type: ignore
),
)
output_image = T.modulus(output_image)

Expand Down
4 changes: 2 additions & 2 deletions tests/tests_data/test_mri_transforms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

"""Tests for the direct.data.mri_transforms module."""

import functools
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -372,7 +372,7 @@ def test_EstimateSensitivityMap(shape, type_of_map, gaussian_sigma, espirit_iter
else:
transform = EstimateSensitivityMap(**args)
if shape[0] == 1 or sense_map_in_sample:
with pytest.warns(None):
with warnings.catch_warnings(record=True):
sample = transform(sample)
else:
sample = transform(sample)
Expand Down

0 comments on commit e7d694d

Please sign in to comment.