Skip to content

Commit

Permalink
add tests - data loader test
Browse files Browse the repository at this point in the history
  • Loading branch information
iLampard committed Oct 20, 2024
1 parent ea238c3 commit 1f85043
Show file tree
Hide file tree
Showing 6 changed files with 1,045 additions and 3 deletions.
2 changes: 1 addition & 1 deletion easy_tpp/preprocess/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def get_dt_stats(self):
s_2_y = dts.var()
m = dts.shape[0]
n += m
# Formulat taken from https://math.stackexchange.com/questions/3604607/can-i-work-out-the-variance-in-batches
# Formula taken from https://math.stackexchange.com/questions/3604607/can-i-work-out-the-variance-in-batches
s_2_x = (((n - 1) * s_2_x + (m - 1) * s_2_y) / (n + m - 1)) + (
(n * m * ((x_bar - y_bar) ** 2)) / ((n + m) * (n + m - 1)))
x_bar = (n * x_bar + m * y_bar) / (n + m)
Expand Down
4 changes: 3 additions & 1 deletion easy_tpp/preprocess/event_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,9 @@ def _pad(
max_len=max_length,
dtype=np.int64)
else:
batch_output = encoded_inputs
batch_output[self.model_input_names[0]] = np.array(encoded_inputs[self.model_input_names[0]])
batch_output[self.model_input_names[1]] = np.array(encoded_inputs[self.model_input_names[1]])
batch_output[self.model_input_names[2]] = np.array(encoded_inputs[self.model_input_names[2]])

# non_pad_mask; replaced the use of event types by using the original sequence length
seq_pad_mask = np.full_like(batch_output[self.model_input_names[2]], fill_value=True, dtype=bool)
Expand Down
3 changes: 2 additions & 1 deletion examples/hf_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ def load_data_from_hf(hf_dir=None, local_dir=None):
if __name__ == '__main__':
# in case one fails to load from hf directly
# one can load the json data file locally
load_data_from_hf(hf_dir=None, local_dir={'validation':'dev.json'})
# load_data_from_hf(hf_dir=None, local_dir={'validation':'dev.json'})
load_data_from_hf(hf_dir='easytpp/taxi')
Empty file added tests/__init__.py
Empty file.
Loading

0 comments on commit 1f85043

Please sign in to comment.