Skip to content

Commit

Permalink
Version 1.1.5 (#69)
Browse files Browse the repository at this point in the history
bugfix for LRFinder with a device function
  • Loading branch information
egillax authored Apr 24, 2023
1 parent f26382e commit cfb1bfa
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 2 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: DeepPatientLevelPrediction
Type: Package
Title: Deep Learning For Patient Level Prediction Using Data In The OMOP Common Data Model
Version: 1.1.4
Version: 1.1.5
Date: 18-04-2023
Authors@R: c(
person("Egill", "Fridgeirsson", email = "e.fridgeirsson@erasmusmc.nl", role = c("aut", "cre")),
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
DeepPatientLevelPrediction 1.1.5
======================
- Fix bug where device function was not working for LRFinder

DeepPatientLevelPrediction 1.1.4
======================
- Remove torchopt dependancy since adamw is now in torch
Expand Down
2 changes: 1 addition & 1 deletion R/LRFinder.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ lrFinder <- function(dataset, modelType, modelParams, estimatorSettings,
optimizer$zero_grad()

batch <- dataset[sample(batchIndex, estimatorSettings$batchSize)]
batch <- batchToDevice(batch, device=estimatorSettings$device)
batch <- batchToDevice(batch, device=device)

output <- model(batch$batch)

Expand Down
22 changes: 22 additions & 0 deletions tests/testthat/test-LRFinder.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,28 @@ test_that("LR finder works", {

expect_true(lr<=10.0)
expect_true(lr>=3e-4)
})

test_that("LR finder works with device specified by a function", {

deviceFun <- function(){
dev = "cpu"
}
lr <- lrFinder(dataset, modelType = ResNet, modelParams = list(catFeatures=dataset$numCatFeatures(),
numFeatures=dataset$numNumFeatures(),
sizeEmbedding=8,
sizeHidden=16,
numLayers=1,
hiddenFactor=1),
estimatorSettings = setEstimator(batchSize=32,
seed = 42,
device = deviceFun),
minLR = 3e-4,
maxLR = 10.0,
numLR = 20,
divergenceThreshold = 1.1)
expect_true(lr<=10.0)
expect_true(lr>=3e-4)


})

0 comments on commit cfb1bfa

Please sign in to comment.