Skip to content

Commit

Permalink
modelType as attribute and tests to cover database upload (#60)
Browse files Browse the repository at this point in the history
* modelType as attribute and tests to cover database upload
  • Loading branch information
egillax authored Mar 24, 2023
1 parent e8083d0 commit 207a860
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 34 deletions.
9 changes: 6 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: DeepPatientLevelPrediction
Type: Package
Title: Deep Learning For Patient Level Prediction Using Data In The OMOP Common Data Model
Version: 1.1.0
Version: 1.1.1
Date: 15-12-2022
Authors@R: c(
person("Egill", "Fridgeirsson", email = "e.fridgeirsson@erasmusmc.nl", role = c("aut", "cre")),
Expand Down Expand Up @@ -34,11 +34,14 @@ Suggests:
markdown,
plyr,
testthat,
PRROC
PRROC,
ResultModelManager (>= 0.2.0),
DatabaseConnector (>= 6.0.0)
Remotes:
ohdsi/PatientLevelPrediction,
ohdsi/FeatureExtraction,
ohdsi/Eunomia
ohdsi/Eunomia,
ohdsi/ResultModelManager
RoxygenNote: 7.2.3
Encoding: UTF-8
Config/testthat/edition: 3
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
DeepPatientLevelPrediction 1.1.1
======================
- Fix bug introduced by removing modelType from attributes (#59)

DeepPatientLevelPrediction 1.1
======================
- Check for if number of heads is compatible with embedding dimension fixed (#55)
Expand Down
2 changes: 1 addition & 1 deletion R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ fitEstimator <- function(trainData,
isNumeric = cvResult$numericalIndex
)


attr(modelSettings$param, 'settings')$modelType <- modelSettings$modelType
comp <- start - Sys.time()
result <- list(
model = cvResult$estimator, # file.path(outLoc),
Expand Down
1 change: 0 additions & 1 deletion R/MLP.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8),
if (hyperParamSearch == "random") {
suppressWarnings(withr::with_seed(randomSampleSeed, {param <- param[sample(length(param), randomSample)]}))
}

results <- list(
fitFunction = "fitEstimator",
param = param,
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Requires R (version 4.0.0 or higher). Installation on Windows requires [RTools](
Getting Started
===============

- To install the package please read the [Package installation guide]()
- To install the package please read the [Package installation guide](https://ohdsi.github.io/DeepPatientLevelPrediction/articles/Installing.html)
- Please read the main vignette for the package:
[Building Deep Learning Models](https://ohdsi.github.io/DeepPatientLevelPrediction/articles/BuildingDeepModels.html)

Expand Down
33 changes: 32 additions & 1 deletion tests/testthat/test-MLP.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ results <- tryCatch(
plpData = plpData,
outcomeId = 3,
modelSettings = modelSettings,
analysisId = "MLP",
analysisId = "Analysis_MLP",
analysisName = "Testing Deep Learning",
populationSettings = populationSet,
splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
Expand Down Expand Up @@ -133,4 +133,35 @@ test_that("Errors are produced by settings function", {
hyperParamSearch = 'random',
randomSample = randomSample))

})

test_that("Can upload results to database", {
  # Minimal cohort metadata required by insertResultsToSqlite
  cohortDefinitions <- data.frame(
    cohortName = c("blank1"),
    cohortId = c(1),
    json = c("json")
  )

  # Silence console chatter from the upload; `finally` restores the sink
  # even if insertResultsToSqlite errors, so later test output is not lost.
  sink(nullfile())
  sqliteFile <- tryCatch(
    insertResultsToSqlite(
      resultLocation = file.path(testLoc, "MLP"),
      cohortDefinitions = cohortDefinitions
    ),
    finally = sink()
  )

  # The upload should have produced an sqlite database file
  testthat::expect_true(file.exists(sqliteFile))

  connectionDetails <- DatabaseConnector::createConnectionDetails(
    dbms = "sqlite",
    server = sqliteFile
  )
  conn <- DatabaseConnector::connect(connectionDetails = connectionDetails)
  # Release the connection even when an expectation below fails
  on.exit(DatabaseConnector::disconnect(conn), add = TRUE)

  # check the results table is populated
  sql <- "select count(*) as N from main.performances;"
  res <- DatabaseConnector::querySql(conn, sql)
  testthat::expect_true(res$N[1] > 0)
})
85 changes: 58 additions & 27 deletions tests/testthat/test-ResNet.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@ resSet <- setResNet(

test_that("setResNet works", {
  # setResNet should return a modelSettings object wired to fitEstimator
  testthat::expect_s3_class(object = resSet, class = "modelSettings")

  testthat::expect_equal(resSet$fitFunction, "fitEstimator")

  # the hyperparameter search grid must be non-empty
  testthat::expect_true(length(resSet$param) > 0)

  # Asking for more random samples (2) than available hyperparameter
  # combinations (1) must raise an error.
  # NOTE(review): the argument list below appears twice — this looks like a
  # diff-rendering artifact (removed + added hunks merged); confirm against
  # the repository that only one set of arguments is present.
  expect_error(setResNet(numLayers = c(2),
  sizeHidden = c(32),
  hiddenFactor = c(2),
  residualDropout = c(0.1),
  hiddenDropout = c(0.1),
  sizeEmbedding = c(32),
  estimatorSettings = setEstimator(learningRate=c(3e-4),
  weightDecay = c(1e-6),
  seed=42,
  batchSize = 128,
  epochs=1),
  hyperParamSearch = "random",
  randomSample = 2))
  sizeHidden = c(32),
  hiddenFactor = c(2),
  residualDropout = c(0.1),
  hiddenDropout = c(0.1),
  sizeEmbedding = c(32),
  estimatorSettings = setEstimator(learningRate=c(3e-4),
  weightDecay = c(1e-6),
  seed=42,
  batchSize = 128,
  epochs=1),
  hyperParamSearch = "random",
  randomSample = 2))
})

sink(nullfile())
Expand All @@ -44,7 +44,7 @@ res2 <- tryCatch(
plpData = plpData,
outcomeId = 3,
modelSettings = resSet,
analysisId = "ResNet",
analysisId = "Analysis_ResNet",
analysisName = "Testing Deep Learning",
populationSettings = populationSet,
splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
Expand All @@ -59,7 +59,7 @@ res2 <- tryCatch(
runModelDevelopment = T,
runCovariateSummary = F
),
saveDirectory = file.path(testLoc, "Deep")
saveDirectory = file.path(testLoc, "ResNet")
)
},
error = function(e) {
Expand All @@ -71,17 +71,17 @@ sink()

test_that("ResNet with runPlp working checks", {
testthat::expect_false(is.null(res2))

# check structure
testthat::expect_true("prediction" %in% names(res2))
testthat::expect_true("model" %in% names(res2))
testthat::expect_true("covariateSummary" %in% names(res2))
testthat::expect_true("performanceEvaluation" %in% names(res2))

# check prediction same size as pop
testthat::expect_equal(nrow(res2$prediction %>%
dplyr::filter(evaluationType %in% c("Train", "Test"))), nrow(population))

dplyr::filter(evaluationType %in% c("Train", "Test"))), nrow(population))
# check prediction between 0 and 1
testthat::expect_gte(min(res2$prediction$value), 0)
testthat::expect_lte(max(res2$prediction$value), 1)
Expand All @@ -96,22 +96,22 @@ test_that("ResNet nn-module works ", {
normalization = torch::nn_batch_norm1d, hiddenDropout = 0.3,
residualDropout = 0.3, d_out = 1
)

pars <- sum(sapply(model$parameters, function(x) prod(x$shape)))

# expected number of parameters
expect_equal(pars, 1295)

input <- list()
input$cat <- torch::torch_randint(0, 5, c(10, 5), dtype = torch::torch_long())
input$num <- torch::torch_randn(10, 1, dtype = torch::torch_float32())


output <- model(input)

# output is correct shape
expect_equal(output$shape, 10)

input$num <- NULL
model <- ResNet(
catFeatures = 5, numFeatures = 0, sizeEmbedding = 5,
Expand Down Expand Up @@ -154,3 +154,34 @@ test_that("Errors are produced by settings function", {
hyperParamSearch = 'random',
randomSample = randomSample))
})


test_that("Can upload results to database", {
  # Minimal cohort metadata required by insertResultsToSqlite
  cohortDefinitions <- data.frame(
    cohortName = c("blank1"),
    cohortId = c(1),
    json = c("json")
  )

  # Silence console chatter from the upload; `finally` restores the sink
  # even if insertResultsToSqlite errors, so later test output is not lost.
  sink(nullfile())
  sqliteFile <- tryCatch(
    insertResultsToSqlite(
      resultLocation = file.path(testLoc, "ResNet"),
      cohortDefinitions = cohortDefinitions
    ),
    finally = sink()
  )

  # The upload should have produced an sqlite database file
  testthat::expect_true(file.exists(sqliteFile))

  connectionDetails <- DatabaseConnector::createConnectionDetails(
    dbms = "sqlite",
    server = sqliteFile
  )
  conn <- DatabaseConnector::connect(connectionDetails = connectionDetails)
  # Release the connection even when an expectation below fails
  on.exit(DatabaseConnector::disconnect(conn), add = TRUE)

  # check the results table is populated
  sql <- "select count(*) as N from main.performances;"
  res <- DatabaseConnector::querySql(conn, sql)
  testthat::expect_true(res$N[1] > 0)
})

0 comments on commit 207a860

Please sign in to comment.