Skip to content

Commit

Permalink
Merge pull request #5 from sfcheung/faster_find_products
Browse files Browse the repository at this point in the history
Add parallel to find_all_products
  • Loading branch information
sfcheung authored Jun 26, 2024
2 parents 129dc1b + 30fb4a4 commit 9e74881
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 15 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: betaselectr
Title: Selective Standardization in Structural Equation Models
Version: 0.0.1.1
Version: 0.0.1.2
Authors@R:
c(person(given = "Shu Fai",
family = "Cheung",
Expand Down
10 changes: 9 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# betaselectr 0.0.1.1
# betaselectr 0.0.1.2

- Added `lm_betaselect()` and related
methods and helper functions.
(0.0.1.1)

- Added parallel processing support to
the internal function
`find_all_products()`. (0.0.1.2)

- For `lav_betaselect()`, added an
option to skip finding product
terms. (0.0.1.2)
23 changes: 21 additions & 2 deletions R/lav_betaselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,16 @@
#' `FALSE`, then the bootstrap estimates
#' will not be stored.
#'
#' @param find_product_terms String.
#' If it is certain that a model does
#' not have product terms, setting this
#' to `FALSE` will skip the search, which
#' is time consuming for a models with
#' many paths and/or many variables.
#' Default is `TRUE`, and the function
#' will automatically identify product
#' terms, if any.
#'
#'
#' @author Shu Fai Cheung <https://orcid.org/0000-0002-9871-9448>
#'
Expand Down Expand Up @@ -333,7 +343,8 @@ lav_betaselect <- function(object,
iseed = NULL,
...,
delta_method = c("lavaan", "numDeriv"),
vector_form = TRUE) {
vector_form = TRUE,
find_product_terms = TRUE) {
if (!isTRUE(requireNamespace("pbapply", quietly = TRUE)) ||
!interactive()) {
progress <- FALSE
Expand All @@ -350,7 +361,15 @@ lav_betaselect <- function(object,
ngroups <- lavaan::lavTech(object, what = "ngroups")

# Get the variables to be standardized
prods <- find_all_products(object)
if (find_product_terms) {
prods <- find_all_products(object,
parallel = (parallel != "none"),
ncpus = ncpus,
cl = cl,
progress = progress)
} else {
prods <- list()
}
to_standardize <- fix_to_standardize(object = object,
to_standardize = to_standardize,
not_to_standardize = not_to_standardize,
Expand Down
59 changes: 52 additions & 7 deletions R/lav_betaselect_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -841,16 +841,61 @@ to_standardize_for_i <- function(prods,

#' @noRd

find_all_products <- function(fit) {
find_all_products <- function(fit,
parallel = TRUE,
ncpus = parallel::detectCores(logical = FALSE) - 1,
cl = NULL,
progress = FALSE) {

ptable <- lavaan::parameterTable(fit)
reg_paths <- all_reg_paths(ptable)
if (length(reg_paths) == 0) return(list())
prods <- mapply(manymome::get_prod,
x = reg_paths[, "rhs"],
y = reg_paths[, "lhs"],
MoreArgs = list(fit = fit,
expand = TRUE),
SIMPLIFY = FALSE)
if (parallel) {
if (is.null(cl)) {
cl <- parallel::makeCluster(min(nrow(reg_paths), ncpus))
on.exit(parallel::stopCluster(cl), add = TRUE)
}
tmp <- split(reg_paths,
seq_len(nrow(reg_paths)),
drop = FALSE)
tmpfct <- function(xx) {
force(fit)
manymome::get_prod(x = xx[2],
y = xx[1],
fit = fit,
expand = TRUE)
}
if (progress) {
cat("\nFinding product terms in the model ...\n")
prods <- pbapply::pblapply(tmp,
FUN = tmpfct,
cl = cl)
cat("\nFinished finding product terms.\n")
} else {
prods <- parallel::parLapplyLB(cl = cl,
tmp,
fun = tmpfct,
chunk.size = 1)
}
} else {
if (progress) {
cat("\nFinding product terms in the model ...\n")
prods <- pbapply::pbmapply(manymome::get_prod,
x = reg_paths[, "rhs"],
y = reg_paths[, "lhs"],
MoreArgs = list(fit = fit,
expand = TRUE),
SIMPLIFY = FALSE)
cat("\nFinished finding product terms.\n")
} else {
prods <- mapply(manymome::get_prod,
x = reg_paths[, "rhs"],
y = reg_paths[, "lhs"],
MoreArgs = list(fit = fit,
expand = TRUE),
SIMPLIFY = FALSE)
}
}
prods_ok <- sapply(prods, function(x) is.list(x))
if (!any(prods_ok)) {
return(list())
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Not ready for use.

# betaselectr: Do selective standardization in structural equation models and regression models

(Version 0.0.1.1, updated on 2024-06-26, [release history](https://sfcheung.github.io/betaselectr/news/index.html))
(Version 0.0.1.2, updated on 2024-06-26, [release history](https://sfcheung.github.io/betaselectr/news/index.html))

It computes Beta_Select, standardization
in structural equation models with only
Expand Down
13 changes: 12 additions & 1 deletion man/lav_betaselect.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions tests/testthat/test-lav_betaselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ test_that("All est", {
"Standardized")
})

# Check skipping the search for product terms

test_that("All est", {
out1 <- lav_betaselect(fit, progress = FALSE)
out2 <- lav_betaselect(fit, find_product_terms = FALSE, progress = FALSE)
expect_equal(out1,
out2,
ignore_attr = TRUE)
})


# (which(std_nox$est.std != std$est.std))
Expand Down
15 changes: 13 additions & 2 deletions tests/testthat/test_find_all_products_cats.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ fit <- sem(mod, dat, meanstructure = TRUE, fixed.x = FALSE)
fit2 <- sem(mod, dat, meanstructure = TRUE, fixed.x = FALSE, group = "city")


prods <- find_all_products(fit)
prods2 <- find_all_products(fit2)
prods <- find_all_products(fit, parallel = FALSE)
prods2 <- find_all_products(fit2, parallel = FALSE)

test_that("Find prods", {
expect_true(setequal(names(prods),
Expand All @@ -38,3 +38,14 @@ test_that("Find cat", {
expect_true(setequal(find_categorical(fit2),
c("gpgp3", "gpgp2")))
})


test_that("Find prods: parallel", {
skip_on_cran()
prodsb <- find_all_products(fit, parallel = TRUE, ncpus = 2)
prods2b <- find_all_products(fit2, parallel = TRUE, ncpus = 2)
expect_true(setequal(names(prods),
c("x:gpgp2", "x:gpgp3", "x:w4")))
expect_true(setequal(names(prods2),
c("x:gpgp2", "x:gpgp3", "x:w4")))
})

0 comments on commit 9e74881

Please sign in to comment.