Skip to content

Commit

Permalink
1. Add gpu to model config; 2. force event type tensor to be int32
Browse files Browse the repository at this point in the history
  • Loading branch information
alilevy committed Sep 10, 2023
1 parent de24c0c commit d218f64
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 1 deletion.
3 changes: 3 additions & 0 deletions easy_tpp/config_factory/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def __init__(self, **kwargs):
self.pad_token_id = kwargs.get('event_pad_index', None)
self.model_id = kwargs.get('model_id', None)
self.pretrained_model_dir = kwargs.get('pretrained_model_dir', None)
self.gpu = kwargs.get('gpu', -1)
self.model_specs = kwargs.get('model_specs', {})

def get_yaml_config(self):
Expand All @@ -238,6 +239,7 @@ def get_yaml_config(self):
'event_pad_index': self.pad_token_id,
'model_id': self.model_id,
'pretrained_model_dir': self.pretrained_model_dir,
'gpu': self.gpu,
'model_specs': self.model_specs}

@staticmethod
Expand Down Expand Up @@ -272,4 +274,5 @@ def copy(self):
num_event_types=self.num_event_types,
event_pad_index=self.pad_token_id,
pretrained_model_dir=self.pretrained_model_dir,
gpu=self.gpu,
model_specs=self.model_specs)
1 change: 1 addition & 0 deletions easy_tpp/config_factory/runner_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def update_config(self):
current_stage = get_stage(self.base_config.stage)
is_training = current_stage == RunnerPhase.TRAIN
self.model_config.is_training = is_training
self.model_config.gpu = self.trainer_config.gpu

# update the dataset config => model config
self.model_config.num_event_types_pad = self.data_config.data_specs.num_event_types_pad
Expand Down
2 changes: 1 addition & 1 deletion easy_tpp/preprocess/event_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def make_type_mask_for_pad_sequence(self, pad_seqs):
np.array: a 3-dim matrix, where the last dim (one-hot vector) indicates the type of event
"""
type_mask = np.zeros([*pad_seqs.shape, self.num_event_types])
type_mask = np.zeros([*pad_seqs.shape, self.num_event_types], dtype=np.int32)
for i in range(self.num_event_types):
type_mask[:, :, i] = pad_seqs == i

Expand Down

0 comments on commit d218f64

Please sign in to comment.