Skip to content

Commit

Permalink
Remove unused and broken rstan_seq_perm
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed May 20, 2024
1 parent b0d1b29 commit 1d17cb1
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 104 deletions.
12 changes: 0 additions & 12 deletions rstan/rstan/R/chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,3 @@ rstan_ess2_cpp <- function(sims) {
}
ess
}

rstan_seq_perm <- function(n, chains, seed, chain_id = 1) {
  # Obtain a permutation of 1..n from the native 'seq_permutation' routine.
  #
  # Args:
  #   n: length of the sequence to be generated
  #   chains: the number of chains for which the permutations are applied
  #   seed: the seed for the RNG
  #   chain_id: the id of the chain to which the returned permutation applies
  #
  # Returns:
  #   The result of the native routine plus one, so indices start from 1.
  zero_based <- .Call(
    seq_permutation,
    list(n = n, chains = chains, seed = seed, chain_id = chain_id)
  )
  zero_based + 1L
}
123 changes: 43 additions & 80 deletions rstan/rstan/src/chains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,37 +47,37 @@ class perm_args {
private:
int n, chains, chain_id;
unsigned int seed;

/**
 * Convert an R seed value to an unsigned int.
 *
 * A character seed (STRSXP) is parsed with std::stoull and the result is
 * narrowed to unsigned int; any other R type is converted with Rcpp::as.
 *
 * @param seed the seed as an R object
 * @return the seed as an unsigned int
 */
inline unsigned int sexp2seed(SEXP seed) {
  const bool is_string = (TYPEOF(seed) == STRSXP);
  if (!is_string) {
    return Rcpp::as<unsigned int>(seed);
  }
  const std::string text = Rcpp::as<std::string>(seed);
  return static_cast<unsigned int>(std::stoull(text));
}

public:
/**
 * Build the permutation arguments from a named R list.
 *
 * Required elements: "n" and "chains". Optional elements: "chain_id"
 * (defaults to 1) and "seed" (defaults to the current time).
 *
 * @param lst named R list holding the arguments
 * @throws std::runtime_error if "n" or "chains" is missing
 */
perm_args(Rcpp::List& lst) : chain_id(1) {
  // Mandatory: number of kept iterations.
  if (!lst.containsElementNamed("n")) {
    throw std::runtime_error("number of iterations kept (n) is not specified");
  }
  n = Rcpp::as<int>(lst["n"]);

  // Mandatory: number of chains.
  if (!lst.containsElementNamed("chains")) {
    throw std::runtime_error("number of chains is not specified");
  }
  chains = Rcpp::as<int>(lst["chains"]);

  // Optional: chain id; keeps the default set in the initializer list.
  if (lst.containsElementNamed("chain_id")) {
    chain_id = Rcpp::as<int>(lst["chain_id"]);
  }

  // Optional: seed; fall back to the current time when none is given.
  const bool has_seed = lst.containsElementNamed("seed");
  if (has_seed) {
    seed = sexp2seed(lst["seed"]);
  } else {
    seed = std::time(0);
  }
}

// Read-only accessors for the stored arguments. Member functions defined
// inside the class body are implicitly inline, so the keyword is omitted.

/** @return the number of iterations kept (n). */
int get_n() const { return n; }

/** @return the chain id. */
int get_chain_id() const { return chain_id; }

/** @return the RNG seed. */
unsigned int get_seed() const { return seed; }

/** @return the number of chains. */
int get_chains() const { return chains; }

inline SEXP perm_args_to_rlist() const {
Rcpp::List lst;
std::stringstream ss;
Expand Down Expand Up @@ -114,7 +114,7 @@ void validate_sim(SEXP sim) {
snames.push_back("permutation");
Rcpp::List lst(sim);
std::vector<std::string> names = lst.names();

for (std::vector<std::string>::const_iterator it = snames.begin();
it != snames.end();
++it) {
Expand All @@ -124,7 +124,7 @@ void validate_sim(SEXP sim) {
throw std::domain_error(msg.str());
}
}

unsigned int type = TYPEOF(lst["chains"]);
if (type != INTSXP && type != REALSXP) {
std::stringstream msg;
Expand All @@ -133,15 +133,15 @@ void validate_sim(SEXP sim) {
<< ", but INTSXP/REALSXP needed";
throw std::domain_error(msg.str());
}

SEXP sample_sexp = lst["samples"];

if (TYPEOF(lst["samples"]) != VECSXP) {
std::stringstream msg;
msg << "sim$samples is not a list";
throw std::domain_error(msg.str());
}

int nchains2 = Rcpp::List(sample_sexp).size();
if (nchains2 != Rcpp::as<int>(lst["chains"])) {
std::stringstream msg;
Expand Down Expand Up @@ -171,7 +171,7 @@ void get_kept_samples(SEXP sim, const size_t k, const size_t n,
Rcpp::List allsamples(static_cast<SEXP>(lst["samples"]));
Rcpp::IntegerVector n_save(static_cast<SEXP>(lst["n_save"]));
Rcpp::IntegerVector warmup2(static_cast<SEXP>(lst["warmup2"]));

Rcpp::List slst(static_cast<SEXP>(allsamples[k])); // chain k
Rcpp::NumericVector nv(static_cast<SEXP>(slst[n])); // parameter n
samples.assign(warmup2[k] + nv.begin(), nv.end());
Expand Down Expand Up @@ -205,7 +205,7 @@ void apply_kept_samples(SEXP sim, size_t k,
Rcpp::List allsamples(static_cast<SEXP>(lst["samples"]));
Rcpp::IntegerVector n_save(static_cast<SEXP>(lst["n_save"]));
Rcpp::IntegerVector warmup2(static_cast<SEXP>(lst["warmup2"]));

Rcpp::List slst(static_cast<SEXP>(allsamples[k])); // chain k
Rcpp::NumericVector nv(static_cast<SEXP>(slst[n])); // parameter n
// use int instead of size_t since these are R integers.
Expand Down Expand Up @@ -280,7 +280,6 @@ RcppExport SEXP effective_sample_size(SEXP sim, SEXP n_);
RcppExport SEXP effective_sample_size2(SEXP sims);
RcppExport SEXP split_potential_scale_reduction(SEXP sim, SEXP n_);
RcppExport SEXP split_potential_scale_reduction2(SEXP sims_);
RcppExport SEXP seq_permutation(SEXP conf);
RcppExport SEXP CPP_read_comments(SEXP file, SEXP n);

RcppExport SEXP stan_prob_autocovariance(SEXP v);
Expand Down Expand Up @@ -312,38 +311,38 @@ SEXP effective_sample_size(SEXP sim, SEXP n_) {
rstan::validate_param_idx(sim,n);
unsigned int m(rstan::num_chains(sim));
// need to generalize to each jagged samples per chain

std::vector<unsigned int> ns_save =
Rcpp::as<std::vector<unsigned int> >(lst["n_save"]);

std::vector<unsigned int> ns_warmup2 =
Rcpp::as<std::vector<unsigned int> >(lst["warmup2"]);

std::vector<unsigned int> ns_kept(ns_save);
for (size_t i = 0; i < ns_kept.size(); i++)
ns_kept[i] -= ns_warmup2[i];

unsigned int n_samples = ns_kept[0];
for (size_t chain = 1; chain < m; chain++) {
n_samples = std::min(n_samples, ns_kept[chain]);
}

using std::vector;
vector< vector<double> > acov;
for (size_t chain = 0; chain < m; chain++) {
vector<double> acov_chain;
rstan::autocovariance(sim, chain, n, acov_chain);
acov.push_back(acov_chain);
}

vector<double> chain_mean;
vector<double> chain_var;
for (size_t chain = 0; chain < m; chain++) {
unsigned int n_kept_samples = ns_kept[chain];
chain_mean.push_back(rstan::get_chain_mean(sim,chain,n));
chain_var.push_back(acov[chain][0]*n_kept_samples/(n_kept_samples-1));
}

double mean_var = stan::math::mean(chain_var);
double var_plus = mean_var*(n_samples-1)/n_samples;
if (m > 1) var_plus += stan::math::variance(chain_mean);
Expand All @@ -352,7 +351,7 @@ SEXP effective_sample_size(SEXP sim, SEXP n_) {
for (size_t chain = 0; chain < m; chain++) {
acov_t[chain] = acov[chain][1];
}

double rho_hat_even = 1;
double rho_hat_odd = 1 - (mean_var - stan::math::mean(acov_t)) / var_plus;
rho_hat_t[1] = rho_hat_odd;
Expand All @@ -373,15 +372,15 @@ SEXP effective_sample_size(SEXP sim, SEXP n_) {
}
max_t = t + 2;
}

// Geyer's initial monotone sequence
for (int t = 3; t <= max_t - 2; t += 2) {
if (rho_hat_t[t + 1] + rho_hat_t[t + 2] > rho_hat_t[t - 1] + rho_hat_t[t]) {
rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2;
rho_hat_t[t + 2] = rho_hat_t[t + 1];
}
}

double ess = m*n_samples;
ess /= (1 + 2 * stan::math::sum(rho_hat_t));
SEXP __sexp_result;
Expand Down Expand Up @@ -433,12 +432,12 @@ SEXP effective_sample_size2(SEXP sims) {
acov.push_back(acov_chain);
chain_mean.push_back(stan::math::mean(samples));
}

vector<double> chain_var;
for (size_t chain = 0; chain < m; chain++) {
chain_var.push_back(acov[chain][0]*n_samples/(n_samples-1));
}

double mean_var = stan::math::mean(chain_var);
double var_plus = mean_var*(n_samples-1)/n_samples;
if (m > 1) var_plus += stan::math::variance(chain_mean);
Expand All @@ -453,7 +452,7 @@ SEXP effective_sample_size2(SEXP sims) {
if (rho_hat >= 0)
rho_hat_t.push_back(rho_hat);
}

double ess = m*n_samples;
if (rho_hat_t.size() > 0) {
ess /= 1 + 2 * stan::math::sum(rho_hat_t);
Expand Down Expand Up @@ -482,10 +481,10 @@ SEXP split_potential_scale_reduction2(SEXP sims_) {
unsigned int n_samples = nm.nrow();
if (n_samples % 2 == 1)
n_samples--;

std::vector<double> split_chain_mean;
std::vector<double> split_chain_var;

for (size_t chain = 0; chain < n_chains; chain++) {
std::vector<double> split_chain(n_samples/2);
Rcpp::NumericMatrix::Column samples = nm(Rcpp::_, chain);
Expand All @@ -501,10 +500,10 @@ SEXP split_potential_scale_reduction2(SEXP sims_) {
// copied and pasted from split_potential_scale_reduction
double var_between = n_samples/2 * stan::math::variance(split_chain_mean);
double var_within = stan::math::mean(split_chain_var);

// rewrote [(n-1)*W/n + B/n]/W as (n-1+ B/W)/n
double srhat = sqrt((var_between/var_within + n_samples/2 -1)/(n_samples/2));

SEXP __sexp_result;
PROTECT(__sexp_result = Rcpp::wrap(srhat));
UNPROTECT(1);
Expand All @@ -524,56 +523,56 @@ SEXP split_potential_scale_reduction2(SEXP sims_) {
* @return split R hat.
*/
SEXP split_potential_scale_reduction(SEXP sim, SEXP n_) {

BEGIN_RCPP
rstan::validate_sim(sim);
Rcpp::List lst(sim);
unsigned int n = Rcpp::as<unsigned int>(n_);
// Rcpp::Rcout << "n=" << n << std::endl;
unsigned int n_chains(rstan::num_chains(sim));
// Rcpp::Rcout << "n_chains=" << n_chains << std::endl;

std::vector<unsigned int> ns_save =
Rcpp::as<std::vector<unsigned int> >(lst["n_save"]);

std::vector<unsigned int> ns_warmup2 =
Rcpp::as<std::vector<unsigned int> >(lst["warmup2"]);

std::vector<unsigned int> ns_kept(ns_save);
for (size_t i = 0; i < ns_kept.size(); i++)
ns_kept[i] -= ns_warmup2[i];

unsigned int n_samples = ns_kept[0];
for (size_t chain = 1; chain < n_chains; chain++) {
n_samples = std::min(n_samples, ns_kept[chain]);
}

if (n_samples % 2 == 1)
n_samples--;

std::vector<double> split_chain_mean;
std::vector<double> split_chain_var;

for (size_t chain = 0; chain < n_chains; chain++) {
std::vector<double> samples; // (n_samples);
rstan::get_kept_samples(sim, chain, n, samples);
// Rcpp::Rcout << samples[0] << ", " << samples.size() << std::endl;

std::vector<double> split_chain(n_samples/2);
split_chain.assign(samples.begin(),
samples.begin()+n_samples/2);
split_chain_mean.push_back(stan::math::mean(split_chain));
split_chain_var.push_back(stan::math::variance(split_chain));

split_chain.assign(samples.end()-n_samples/2,
samples.end());
split_chain_mean.push_back(stan::math::mean(split_chain));
split_chain_var.push_back(stan::math::variance(split_chain));
}

double var_between = n_samples/2 * stan::math::variance(split_chain_mean);
double var_within = stan::math::mean(split_chain_var);

// rewrote [(n-1)*W/n + B/n]/W as (n-1+ B/W)/n
double srhat = sqrt((var_between/var_within + n_samples/2 -1)/(n_samples/2));
SEXP __sexp_result;
Expand All @@ -583,42 +582,6 @@ SEXP split_potential_scale_reduction(SEXP sim, SEXP n_) {
END_RCPP
}

/*
 * Obtain a permutation of size n.
 * see <code>permutation</code> in <code>mcmc::chains.hpp</code>.
 *
 * The permutation is produced by an in-place shuffle of 0..n-1 driven by a
 * boost::random::mixmax generator. The generator for a given chain is
 * selected by adding (chain_id + chains - 1) to the seed.
 *
 * @param conf an R named list containing the elements n and chains, and
 *   optionally seed and chain_id.
 *
 * @return A permutation of length 'n' starting from 0.
 */
SEXP seq_permutation(SEXP conf) {
  BEGIN_RCPP
  Rcpp::List lst(conf);
  rstan::perm_args args(lst);

  const int n = args.get_n();
  // A negative length cannot describe a sequence; report it as an R error
  // (BEGIN_RCPP/END_RCPP translate the exception) before allocating.
  if (n < 0)
    throw std::domain_error("n must be non-negative");

  const int cid = args.get_chain_id() + args.get_chains();

  // Previously the stream was selected with
  //   rng.discard((1 << 50) * (cid - 1)).
  // NOTE(review): boost::random::mixmax appears to implement discard() by
  // stepping the generator one draw at a time, so that call could not finish
  // in practice -- confirm against the Boost version in use. Offsetting the
  // seed selects a distinct stream per chain without any skipping.
  typedef boost::random::mixmax RNG;
  RNG rng(args.get_seed() + static_cast<unsigned int>(cid - 1));

  Rcpp::IntegerVector x(n);
  for (int i = 0; i < n; ++i)
    x[i] = i;
  // Sequences of length 0 or 1 have a single ordering.
  if (n < 2) return x;

  // Walk from the last position down to 1, swapping each position with a
  // uniformly chosen position in [0, i].
  for (int i = n - 1; i > 0; --i) {
    boost::random::uniform_int_distribution<int> uid(0, i);
    const int j = uid(rng);
    const int temp = x[i];
    x[i] = x[j];
    x[j] = temp;
  }

  SEXP __sexp_result;
  PROTECT(__sexp_result = Rcpp::wrap(x));
  UNPROTECT(1);
  return __sexp_result;
  END_RCPP
}

/**
* Read comments
* @param file The filename (a character string in R)
Expand Down
2 changes: 0 additions & 2 deletions rstan/rstan/src/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ SEXP effective_sample_size(SEXP sim, SEXP n_);
SEXP effective_sample_size2(SEXP sims);
SEXP split_potential_scale_reduction(SEXP sim, SEXP n_);
SEXP split_potential_scale_reduction2(SEXP sims_);
SEXP seq_permutation(SEXP conf);
SEXP CPP_read_comments(SEXP file, SEXP n);
SEXP stan_prob_autocovariance(SEXP v);
SEXP is_Null_NS(SEXP ns);
Expand All @@ -64,7 +63,6 @@ static const R_CallMethodDef CallEntries[] = {
CALLDEF(effective_sample_size2, 1),
CALLDEF(split_potential_scale_reduction, 2),
CALLDEF(split_potential_scale_reduction2, 1),
CALLDEF(seq_permutation, 1),
CALLDEF(CPP_read_comments, 2),
CALLDEF(stan_prob_autocovariance, 1),
CALLDEF(is_Null_NS, 1),
Expand Down
10 changes: 0 additions & 10 deletions rstan/rstan/tests/testthat/test-chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,3 @@ test_that("ess and rhat work", {
expect_equal(rhat2, 1.003782, tolerance = 0.001);
})

test_that("seq_perm works", {
  # Arguments are (n, chains, seed, chain_id); chain_id defaults to 1.

  # A sequence of length one can only be the single index 1.
  single <- rstan:::rstan_seq_perm(1, 4, 12345, 1)
  expect_equal(single, 1L)

  # Longer results must contain each of 1..n exactly once.
  perm10 <- rstan:::rstan_seq_perm(10, 4, 12345)
  expect_equal(length(perm10), 10L)
  expect_equal(sort(perm10), seq_len(10L))

  perm107 <- rstan:::rstan_seq_perm(107, 4, 12345)
  expect_equal(sort(perm107), seq_len(107L))
})

0 comments on commit 1d17cb1

Please sign in to comment.