diff --git a/src/vak/metrics/distance/functional.py b/src/vak/metrics/distance/functional.py index e429e3a82..8683afa47 100644 --- a/src/vak/metrics/distance/functional.py +++ b/src/vak/metrics/distance/functional.py @@ -1,4 +1,5 @@ import numpy as np +import torch def levenshtein(source, target): @@ -65,7 +66,7 @@ def levenshtein(source, target): d0, d1 = d1, d0 - return d0[-1] + return torch.tensor(d0[-1], dtype=torch.int32) def segment_error_rate(y_pred, y_true): @@ -95,4 +96,5 @@ def segment_error_rate(y_pred, y_true): "segment error rate is undefined when length of y_true is zero" ) - return levenshtein(y_pred, y_true) / len(y_true) + rate = levenshtein(y_pred, y_true) / len(y_true) + return torch.tensor(rate, dtype=torch.float32)