From 6e8c634437adfc5c20f32f304cc9a8bd84bc3e90 Mon Sep 17 00:00:00 2001 From: Kaleb-Wang <80630440+Kaleb-Wang@users.noreply.github.com> Date: Thu, 8 Aug 2024 20:12:05 -0700 Subject: [PATCH 1/2] Updated thinning algo according to NHP's Algorithm 2 --- easy_tpp/model/torch_model/torch_thinning.py | 75 ++++++++++---------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/easy_tpp/model/torch_model/torch_thinning.py b/easy_tpp/model/torch_model/torch_thinning.py index 3b73d80..8b690c9 100644 --- a/easy_tpp/model/torch_model/torch_thinning.py +++ b/easy_tpp/model/torch_model/torch_thinning.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from easy_tpp.utils import logger class EventSampler(nn.Module): @@ -34,6 +35,11 @@ def __init__(self, num_sample, num_exp, over_sample_rate, num_samples_boundary, def compute_intensity_upper_bound(self, time_seq, time_delta_seq, event_seq, intensity_fn, compute_last_step_only): + # logger.critical(f'time_seq: {time_seq}') + # logger.critical(f'time_delta_seq: {time_delta_seq}') + # logger.critical(f'event_seq: {event_seq}') + # logger.critical(f'intensity_fn: {intensity_fn}') + # logger.critical(f'compute_last_step_only: {compute_last_step_only}') """Compute the upper bound of intensity at each event timestamp. Args: @@ -54,10 +60,10 @@ def compute_intensity_upper_bound(self, time_seq, time_delta_seq, event_seq, int steps=self.num_samples_boundary, device=self.device)[None, None, :] - # [batch_size, seq_len, num_sample] + # [batch_size, seq_len, num_samples_boundary] dtime_for_bound_sampled = time_delta_seq[:, :, None] * time_for_bound_sampled - # [batch_size, seq_len, num_sample, event_num] + # [batch_size, seq_len, num_samples_boundary, event_num] intensities_for_bound = intensity_fn(time_seq, time_delta_seq, event_seq, @@ -120,34 +126,46 @@ def sample_uniform_distribution(self, intensity_upper_bound): return unif_numbers - def sample_accept(self, unif_numbers, sample_rate, total_intensities): + def sample_accept(self, unif_numbers, sample_rate, total_intensities, exp_numbers): """Do the sample-accept process. - For each parallel draw, find its min criterion: if that < 1.0, the 1st (i.e. smallest) sampled time - with cri < 1.0 is accepted; if none is accepted, use boundary / maxsampletime for that draw + For the accumulated exp (delta) samples drawn for each event timestamp, find (from left to right) the first + that makes the criterion < 1 and accept it as the sampled next-event time. If all exp samples are rejected + (criterion >= 1), then we set the sampled next-event time dtime_max. Args: unif_numbers (tensor): [batch_size, max_len, num_sample, num_exp], sampled uniform random number. sample_rate (tensor): [batch_size, max_len], sample rate (intensity). total_intensities (tensor): [batch_size, seq_len, num_sample, num_exp] + exp_numbers (tensor): [batch_size, seq_len, num_sample, num_exp]: sampled exp numbers (delta in Algorithm 2). Returns: - list: two tensors, - criterion, [batch_size, max_len, num_sample, num_exp] - who_has_accepted_times, [batch_size, max_len, num_sample] + result (tensor): [batch_size, seq_len, num_sample], sampled next-event times. """ # [batch_size, max_len, num_sample, num_exp] criterion = unif_numbers * sample_rate[:, :, None, None] / total_intensities - + + # [batch_size, max_len, num_sample, num_exp] + masked_crit_less_than_1 = torch.where(criterion<1,1,0) + # [batch_size, max_len, num_sample] - min_cri_each_draw, _ = criterion.min(dim=-1) - - # find out unif_numbers * sample_rate < intensity + non_accepted_filter = (1-masked_crit_less_than_1).all(dim=3) + # [batch_size, max_len, num_sample] - who_has_accepted_times = min_cri_each_draw < 1.0 - - return criterion, who_has_accepted_times + first_accepted_indexer = masked_crit_less_than_1.argmax(dim=3) + + # [batch_size, max_len, num_sample,1] + # indexer must be unsqueezed to 4D to match the number of dimensions of exp_numbers + result_non_accepted_unfiltered = torch.gather(exp_numbers, 3, first_accepted_indexer.unsqueeze(3)) + + # [batch_size, max_len, num_sample,1] + result = torch.where(non_accepted_filter.unsqueeze(3), torch.tensor(self.dtime_max), result_non_accepted_unfiltered) + + # [batch_size, max_len, num_sample] + result = result.squeeze(dim=-1) + + return result def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_boundary, intensity_fn, compute_last_step_only=False): @@ -177,7 +195,8 @@ def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_bou # we apply fast approximation, i.e., re-use exp sample times for computation # [batch_size, seq_len, num_exp] exp_numbers = self.sample_exp_distribution(intensity_upper_bound) - + exp_numbers = torch.cumsum(exp_numbers, dim=-1) + # 3. compute intensity at sampled times from exp distribution # [batch_size, seq_len, num_exp, event_num] intensities_at_sampled_times = intensity_fn(time_seq, @@ -193,9 +212,10 @@ def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_bou # add one dim of num_sample: re-use the intensity for samples for prediction # [batch_size, seq_len, num_sample, num_exp] total_intensities = torch.tile(total_intensities[:, :, None, :], [1, 1, self.num_sample, 1]) + # [batch_size, seq_len, num_sample, num_exp] exp_numbers = torch.tile(exp_numbers[:, :, None, :], [1, 1, self.num_sample, 1]) - + # 4. draw uniform distribution # [batch_size, seq_len, num_sample, num_exp] unif_numbers = self.sample_uniform_distribution(intensity_upper_bound) @@ -203,16 +223,7 @@ def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_bou # 5. find out accepted intensities # criterion, [batch_size, max_len, num_sample, num_exp] # who_has_accepted_times, [batch_size, max_len, num_sample] - criterion, who_has_accepted_times = self.sample_accept(unif_numbers, intensity_upper_bound, - total_intensities) - - # 6. find out accepted dtimes - sampled_dtimes_accepted = exp_numbers.clone() - - # for unaccepted, use boundary/maxsampletime for that draw - sampled_dtimes_accepted[criterion >= 1.0] = exp_numbers.max() + 1.0 - - accepted_times_each_draw, accepted_id_each_draw = sampled_dtimes_accepted.min(dim=-1) + res = self.sample_accept(unif_numbers, intensity_upper_bound, total_intensities, exp_numbers) # 7. fill out result dtime_boundary_ = dtime_boundary[:, -1:] if compute_last_step_only else dtime_boundary @@ -220,19 +231,9 @@ def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_bou # [batch_size, seq_len, num_sample] dtime_boundary_ = torch.tile(dtime_boundary_[..., None], [1, 1, self.num_sample]) - # [batch_size, seq_len, num_sample] - res = torch.ones_like(dtime_boundary_) * dtime_boundary_ - # [batch_size, seq_len, num_sample] weights = torch.ones_like(dtime_boundary_) weights /= weights.sum(dim=-1, keepdim=True) - res[who_has_accepted_times] = accepted_times_each_draw[who_has_accepted_times] - who_not_accept = ~who_has_accepted_times - - who_reach_further = exp_numbers[..., -1] > dtime_boundary_ - - res[who_not_accept & who_reach_further] = exp_numbers[..., -1][who_not_accept & who_reach_further] - # add a upper bound here in case it explodes, e.g., in ODE models return res.clamp(max=1e5), weights From c67216fbd672c8e616bcb3c902a03e94bc24861e Mon Sep 17 00:00:00 2001 From: Kaleb-Wang <80630440+Kaleb-Wang@users.noreply.github.com> Date: Fri, 9 Aug 2024 09:06:23 -0700 Subject: [PATCH 2/2] deleted unnecessary calculations near the end of draw_next_time_one_step --- easy_tpp/model/torch_model/torch_thinning.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/easy_tpp/model/torch_model/torch_thinning.py b/easy_tpp/model/torch_model/torch_thinning.py index 8b690c9..7472195 100644 --- a/easy_tpp/model/torch_model/torch_thinning.py +++ b/easy_tpp/model/torch_model/torch_thinning.py @@ -221,19 +221,11 @@ def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_bou unif_numbers = self.sample_uniform_distribution(intensity_upper_bound) # 5. find out accepted intensities - # criterion, [batch_size, max_len, num_sample, num_exp] - # who_has_accepted_times, [batch_size, max_len, num_sample] - res = self.sample_accept(unif_numbers, intensity_upper_bound, total_intensities, exp_numbers) - - # 7. fill out result - dtime_boundary_ = dtime_boundary[:, -1:] if compute_last_step_only else dtime_boundary - # [batch_size, seq_len, num_sample] - dtime_boundary_ = torch.tile(dtime_boundary_[..., None], [1, 1, self.num_sample]) + res = self.sample_accept(unif_numbers, intensity_upper_bound, total_intensities, exp_numbers) # [batch_size, seq_len, num_sample] - weights = torch.ones_like(dtime_boundary_) - weights /= weights.sum(dim=-1, keepdim=True) - + weights = torch.ones_like(res)/res.shape[2] + # add a upper bound here in case it explodes, e.g., in ODE models return res.clamp(max=1e5), weights