Remove print
jloveric committed Jan 3, 2024
1 parent ccf7fb7 · commit 4842a63
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions language_interpolation/networks.py
@@ -120,10 +120,10 @@ def forward(self, x: Tensor) -> Tensor:
 class ClassificationMixin:
     def eval_step(self, batch: Tensor, name: str):
         x, y, idx = batch
-        print('x eval.shape', x.shape)
+        print("x eval.shape", x.shape)
         y_hat = self(x)
-        print('y_hat.shape', y_hat.shape)
-        print('y.shape',y.shape)
+        print("y_hat.shape", y_hat.shape)
+        print("y.shape", y.shape)
         loss = self.loss(y_hat, y.flatten())

         diff = torch.argmax(y_hat, dim=1) - y.flatten()
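
For context: the diff line above suggests eval_step scores predictions by exact match, since a zero difference means the argmax equals the target. A minimal standalone sketch under assumed shapes (y_hat as (batch, classes) logits, y as (batch, 1) integer targets, with cross_entropy standing in for self.loss):

    import torch

    # Assumed shapes for illustration; the real values come from the batch.
    y_hat = torch.randn(8, 128)            # logits over 128 classes
    y = torch.randint(0, 128, (8, 1))      # integer targets

    loss = torch.nn.functional.cross_entropy(y_hat, y.flatten())
    diff = torch.argmax(y_hat, dim=1) - y.flatten()
    accuracy = (diff == 0).float().mean()  # fraction of exact matches
    print(loss.item(), accuracy.item())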
@@ -160,9 +160,7 @@ def eval_step(self, batch: Tensor, name: str):

 class PredictionNetMixin:
     def forward(self, x):
-        print('x.shape',x.shape)
         ans = self.model(x)
-        print('prediction ans', ans.shape)
         return ans

     def training_step(self, batch, batch_idx):
@@ -759,7 +757,7 @@ def select_network(cfg: DictConfig, device: str = None):
         Only the input layer is high order, the rest
         of the layers are standard linear+relu and normalization.
         """
-        print('using low order mlp')
+        print("using low order mlp")
         layer_list = []
         input_layer = torch.nn.Embedding(
             num_embeddings=128,
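
The docstring above describes the architecture this branch constructs: a high-order input layer (here an Embedding over 128 symbols) followed by ordinary linear+relu layers and normalization. A minimal sketch of that shape, with hypothetical sizes standing in for the cfg.net values and LayerNorm standing in for `normalization`:

    import torch

    # Hypothetical sizes; the real ones come from cfg.net.*.
    characters = 10       # input tokens per sample
    embedding_dim = 16
    hidden_width = 128

    model = torch.nn.Sequential(
        torch.nn.Embedding(num_embeddings=128, embedding_dim=embedding_dim),
        torch.nn.Flatten(),                  # -> (batch, characters * embedding_dim)
        torch.nn.Linear(characters * embedding_dim, hidden_width),
        torch.nn.ReLU(),
        torch.nn.LayerNorm(hidden_width),    # stand-in for `normalization`
        torch.nn.Linear(hidden_width, 128),  # logits over 128 output symbols
    )

    x = torch.randint(0, 128, (4, characters))
    print(model(x).shape)                    # torch.Size([4, 128])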
@@ -779,7 +777,7 @@ def select_network(cfg: DictConfig, device: str = None):
             hidden_layers=cfg.net.hidden.layers - 1,
             non_linearity=torch.nn.ReLU(),
             normalization=normalization,
-            #device=cfg.accelerator,
+            # device=cfg.accelerator,
         )
         layer_list.append(lower_layers)

@@ -810,7 +808,7 @@ def select_network(cfg: DictConfig, device: str = None):
             hidden_layers=cfg.net.hidden.layers - 1,
             non_linearity=torch.nn.ReLU(),
             normalization=normalization,
-            #device=cfg.accelerator,
+            # device=cfg.accelerator,
         )
         layer_list.append(lower_layers)

