[Question] Classification for critical cases #1943
-
Probably the best way to do this is to subclass the Bernoulli likelihood and re-weight class zero relative to class one.
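Concretely (my notation, not from the thread), the idea is to replace the usual Bernoulli log-likelihood with a class-weighted version,

$$\log p_w(y \mid f) = w \, y \log \sigma(f) + (1 - w)\,(1 - y) \log\bigl(1 - \sigma(f)\bigr),$$

where $\sigma$ is the link function. With $w < 0.5$ the class-0 term dominates, so predicting 1 when the true label is 0 is penalized more heavily. This is also the weighting convention used in the code in the next reply.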
-
Something like this is how you'd implement it. I just copied over the weighted Bernoulli distribution from the PyTorch Bernoulli source. Note that none of this is tested, but it should be okay.

```python
import torch
import gpytorch
from torch.distributions import Normal
from torch.distributions.utils import broadcast_all
from torch.nn.functional import binary_cross_entropy_with_logits


class WeightedBernoulli(torch.distributions.Bernoulli):
    """Bernoulli whose log-prob re-weights the classes: `weight` scales the
    class-1 term and `1 - weight` scales the class-0 term."""

    def __init__(self, weight, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight = weight

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # return -binary_cross_entropy_with_logits(logits, value, reduction='none', weight=self.weight)
        # Edited for clarity: this is roughly binary_cross_entropy_with_logits,
        # written out so the per-class weights are explicit.
        prob_one = torch.sigmoid(logits)
        log_prob = (
            self.weight * value * prob_one.log().clamp(min=-100.0)
            + (1.0 - self.weight) * (1.0 - value) * (1.0 - prob_one).log().clamp(min=-100.0)
        )
        return log_prob


class WeightedBernoulliLikelihood(gpytorch.likelihoods.BernoulliLikelihood):
    def __init__(self, weight, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight = weight

    def forward(self, function_samples, **kwargs):
        # Probit link: squash GP function samples through the standard normal CDF.
        output_probs = Normal(0, 1).cdf(function_samples)
        return WeightedBernoulli(probs=output_probs, weight=self.weight)

    def marginal(self, function_dist, **kwargs):
        # Probit marginal: E[Phi(f)] = Phi(mean / sqrt(1 + var)) for f ~ N(mean, var).
        mean = function_dist.mean
        var = function_dist.variance
        link = mean.div(torch.sqrt(1 + var))
        output_probs = Normal(0, 1).cdf(link)
        return WeightedBernoulli(probs=output_probs, weight=self.weight)


likelihood = WeightedBernoulliLikelihood(weight=torch.ones(1))
y = likelihood(f)  # f: the GP's latent function distribution (or samples)
```
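For completeness, here is a rough, untested sketch of how this weighted likelihood would slot into a standard GPyTorch variational classification loop. The `GPClassifier` model, the toy data, and the weight value are illustrative assumptions, not part of the snippet above:

```python
import torch
import gpytorch

# Illustrative approximate GP classifier; any standard ApproximateGP works here.
class GPClassifier(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(0)
        )
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


train_x = torch.randn(100, 1)              # toy inputs (assumption)
train_y = (train_x.squeeze() > 0).float()  # toy binary labels (assumption)

model = GPClassifier(inducing_points=train_x[:10])
# weight < 0.5 up-weights class 0, i.e. penalizes predicting 1 when the label is 0.
likelihood = WeightedBernoulliLikelihood(weight=torch.tensor(0.3))
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(likelihood.parameters()), lr=0.1
)

model.train()
likelihood.train()
for _ in range(200):
    optimizer.zero_grad()
    output = model(train_x)
    # The ELBO integrates WeightedBernoulli.log_prob under the variational posterior.
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()
```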
-
Thank you, I'll try it out and get back to you. So is the weight passed to the WeightedBernoulliLikelihood the weight for class 0 or for class 1?
-
Hi everyone,
I'm training an approximate GP with a CholeskyVariationalDistribution and a BernoulliLikelihood for binary classification. Can I somehow modify the loss function so that the penalty for predicting 1 when the true value is 0 is greater than the penalty for predicting 0 when the true value is 1?
Thanks in advance