Allow customized loss functions for membership inference attack.
PiperOrigin-RevId: 430267951
shs037 authored and tensorflower-gardener committed Feb 22, 2022
1 parent 39fa1d3 commit ec7d442
Showing 4 changed files with 323 additions and 78 deletions.
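For orientation, a minimal sketch of the new API this commit enables; the custom hinge_loss callable, its {-1, 1} label convention, and all example values are illustrative assumptions rather than part of the commit:

import numpy as np

from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

# Hypothetical custom per-example loss: takes (labels, logits_or_probs) and
# returns one loss value per example, as the new `loss_function` field expects.
def hinge_loss(labels, preds):
  return np.maximum(0.0, 1.0 - labels * preds)

attack_input = AttackInputData(
    logits_train=np.array([2.0, -0.5, 1.2]),
    logits_test=np.array([0.3, -1.0]),
    labels_train=np.array([1.0, -1.0, 1.0]),
    labels_test=np.array([-1.0, 1.0]),
    loss_function=hinge_loss,
    loss_function_using_logits=True,  # call hinge_loss with logits, not probs
)

print(attack_input.get_loss_train())  # [0.  0.5 0. ]
print(attack_input.get_loss_test())   # [1.3 2. ]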
@@ -19,7 +19,7 @@
import glob
import os
import pickle
from typing import Any, Iterable, MutableSequence, Optional, Union
from typing import Any, Callable, Iterable, MutableSequence, Optional, Union

import numpy as np
import pandas as pd
@@ -165,6 +165,12 @@ def _log_value(probs, small_value=1e-30):
return -np.log(np.maximum(probs, small_value))


class LossFunction(enum.Enum):
"""An enum that defines loss function to use in `AttackInputData`."""
CROSS_ENTROPY = 'cross_entropy'
SQUARED = 'squared'


@dataclasses.dataclass
class AttackInputData:
"""Input data for running an attack.
@@ -196,6 +202,17 @@ class AttackInputData:
entropy_train: Optional[np.ndarray] = None
entropy_test: Optional[np.ndarray] = None

# If loss is not explicitly specified, this function will be used to derive
# loss from logits and labels. It can be a pre-defined `LossFunction` or a
# callable. If a callable is provided, it should take two arguments: the
# first is labels, the second is logits or probs.
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
LossFunction] = LossFunction.CROSS_ENTROPY
# Whether `loss_function` will be called with logits or probs. If not set
# (None), this is decided by the availability of logits and probs, with
# logits preferred when both are available.
loss_function_using_logits: Optional[bool] = None

@property
def num_classes(self):
if self.labels_train is None or self.labels_test is None:
@@ -248,37 +265,70 @@ def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
true_labels]
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)

@staticmethod
def _get_loss(
loss: Optional[np.ndarray], labels: Optional[np.ndarray],
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
LossFunction],
loss_function_using_logits: Optional[bool]) -> Optional[np.ndarray]:
"""Calculates (if needed) losses.
Args:
loss: the loss of each example.
labels: the scalar label of each example.
logits: the logits vector of each example.
probs: the probability vector of each example.
loss_function: if `loss` is not available, `labels` and one of `logits`
and `probs` are available, we will use this function to compute loss. It
is supposed to take in (label, logits / probs) as input.
loss_function_using_logits: if `loss_function` expects `logits` or
`probs`.
Returns:
Loss (or None if neither the loss nor the labels are present).
"""
if loss is not None:
return loss
if labels is None or (logits is None and probs is None):
return None
if loss_function_using_logits and logits is None:
raise ValueError('We need logits to compute loss, but it is set to None.')
if not loss_function_using_logits and probs is None:
raise ValueError('We need probs to compute loss, but it is set to None.')

predictions = logits if loss_function_using_logits else probs
if loss_function == LossFunction.CROSS_ENTROPY:
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
elif loss_function == LossFunction.SQUARED:
loss = utils.squared_loss(labels, predictions)
else:
loss = loss_function(labels, predictions)
return loss
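
# For reference, a minimal illustration of how `_get_loss` resolves its
# inputs; the values below are made-up examples, not from this commit:
#
#   # An explicitly provided loss is returned unchanged.
#   AttackInputData._get_loss(
#       loss=np.array([0.5, 1.0]), labels=None, logits=None, probs=None,
#       loss_function=LossFunction.CROSS_ENTROPY,
#       loss_function_using_logits=None)  # -> array([0.5, 1. ])
#
#   # Otherwise loss is derived, here from labels and probs with the squared
#   # loss: (0 - 0.2)**2 = 0.04 and (1 - 0.9)**2 = 0.01.
#   AttackInputData._get_loss(
#       loss=None, labels=np.array([0., 1.]), logits=None,
#       probs=np.array([0.2, 0.9]), loss_function=LossFunction.SQUARED,
#       loss_function_using_logits=False)  # -> array([0.04, 0.01])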

def get_loss_train(self):
"""Calculates (if needed) cross-entropy losses for the training set.
Returns:
Loss (or None if neither the loss nor the labels are present).
"""
if self.loss_train is None:
if self.labels_train is None:
return None
if self.logits_train is not None:
self.loss_train = utils.log_loss_from_logits(self.labels_train,
self.logits_train)
else:
self.loss_train = utils.log_loss(self.labels_train, self.probs_train)
return self.loss_train
if self.loss_function_using_logits is None:
self.loss_function_using_logits = (self.logits_train is not None)
return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
self.probs_train, self.loss_function,
self.loss_function_using_logits)

def get_loss_test(self):
"""Calculates (if needed) cross-entropy losses for the test set.
Returns:
Loss (or None if neither the loss nor the labels are present).
"""
if self.loss_test is None:
if self.labels_test is None:
return None
if self.logits_test is not None:
self.loss_test = utils.log_loss_from_logits(self.labels_test,
self.logits_test)
else:
self.loss_test = utils.log_loss(self.labels_test, self.probs_test)
return self.loss_test
if self.loss_function_using_logits is None:
self.loss_function_using_logits = (self.logits_test is not None)
return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
self.probs_test, self.loss_function,
self.loss_function_using_logits)

def get_entropy_train(self):
"""Calculates prediction entropy for the training set."""
@@ -19,13 +19,13 @@
from absl.testing import parameterized
import numpy as np
import pandas as pd

from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
@@ -48,9 +48,9 @@ def testStr(self, feature, value, expected_str):
self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)


class AttackInputDataTest(absltest.TestCase):
class AttackInputDataTest(parameterized.TestCase):

def test_get_loss_from_logits(self):
def test_get_xe_loss_from_logits(self):
attack_input = AttackInputData(
logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]),
logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]),
@@ -62,7 +62,7 @@ def test_get_loss_from_logits(self):
np.testing.assert_allclose(
attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7)

def test_get_loss_from_probs(self):
def test_get_xe_loss_from_probs(self):
attack_input = AttackInputData(
probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
probs_test=np.array([[0, 0.0001, 0.9999], [0.07, 0.18, 0.75]]),
@@ -74,6 +74,130 @@ def test_get_loss_from_probs(self):
np.testing.assert_allclose(
attack_input.get_loss_test(), [18.42068074, 0.28768207], atol=1e-7)

def test_get_binary_xe_loss_from_logits(self):
attack_input = AttackInputData(
logits_train=np.array([-10, -5, 0., 5, 10]),
logits_test=np.array([-10, -5, 0., 5, 10]),
labels_train=np.zeros((5,)),
labels_test=np.ones((5,)),
loss_function_using_logits=True)
expected_loss0 = np.array([0.000045398, 0.006715348, 0.6931471825, 5, 10])
np.testing.assert_allclose(
attack_input.get_loss_train(), expected_loss0, rtol=1e-2)
np.testing.assert_allclose(
attack_input.get_loss_test(), expected_loss0[::-1], rtol=1e-2)

def test_get_binary_xe_loss_from_probs(self):
attack_input = AttackInputData(
probs_train=np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008]),
probs_test=np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008]),
labels_train=np.zeros((6,)),
labels_test=np.ones((6,)),
loss_function_using_logits=False)

expected_loss0 = np.array([
0.2231435513, 1.2039728043, 0.1053605157, 4.6051701860, 0.0020020027,
0.0080321717
])
expected_loss1 = np.array([
1.6094379124, 0.3566749439, 2.3025850930, 0.0100503359, 6.2146080984,
4.8283137373
])
np.testing.assert_allclose(
attack_input.get_loss_train(), expected_loss0, atol=1e-7)
np.testing.assert_allclose(
attack_input.get_loss_test(), expected_loss1, atol=1e-7)

@parameterized.named_parameters(
('use_logits', True, np.array([1, 0.]), np.array([0, 4.])),
('use_default', None, np.array([1, 0.]), np.array([0, 4.])),
('use_probs', False, np.array([0, 1.]), np.array([1, 1.])),
)
def test_get_squared_loss(self, loss_function_using_logits, expected_train,
expected_test):
attack_input = AttackInputData(
logits_train=np.array([0, 0.]),
logits_test=np.array([0, 0.]),
probs_train=np.array([1, 1.]),
probs_test=np.array([1, 1.]),
labels_train=np.array([1, 0.]),
labels_test=np.array([0, 2.]),
loss_function=LossFunction.SQUARED,
loss_function_using_logits=loss_function_using_logits,
)
np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
np.testing.assert_allclose(attack_input.get_loss_test(), expected_test)

@parameterized.named_parameters(
('use_logits', True, np.array([125.]), np.array([121.])),
('use_default', None, np.array([125.]), np.array([121.])),
('use_probs', False, np.array([458.]), np.array([454.])),
)
def test_get_customized_loss(self, loss_function_using_logits, expected_train,
expected_test):

def fake_loss(x, y):
return 2 * x + y

attack_input = AttackInputData(
logits_train=np.array([
123.,
]),
logits_test=np.array([
123.,
]),
probs_train=np.array([
456.,
]),
probs_test=np.array([
456.,
]),
labels_train=np.array([1.]),
labels_test=np.array([-1.]),
loss_function=fake_loss,
loss_function_using_logits=loss_function_using_logits,
)
np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
np.testing.assert_allclose(attack_input.get_loss_test(), expected_test)

@parameterized.named_parameters(
('both', np.array([0, 0.]), np.array([1, 1.]), np.array([1, 0.])),
('only_logits', np.array([0, 0.]), None, np.array([1, 0.])),
('only_probs', None, np.array([1, 1.]), np.array([0, 1.])),
)
def test_default_loss_function_using_logits(self, logits, probs, expected):
"""Tests for `loss_function_using_logits = None`. Should prefer logits."""
attack_input = AttackInputData(
logits_train=logits,
logits_test=logits,
probs_train=probs,
probs_test=probs,
labels_train=np.array([1, 0.]),
labels_test=np.array([1, 0.]),
loss_function=LossFunction.SQUARED,
)
np.testing.assert_allclose(attack_input.get_loss_train(), expected)
np.testing.assert_allclose(attack_input.get_loss_test(), expected)

@parameterized.parameters(
(None, np.array([1.]), True),
(np.array([1.]), None, False),
)
def test_loss_wrong_input(self, logits, probs, loss_function_using_logits):
attack_input = AttackInputData(
logits_train=logits,
logits_test=logits,
probs_train=probs,
probs_test=probs,
labels_train=np.array([
1.,
]),
labels_test=np.array([0.]),
loss_function_using_logits=loss_function_using_logits,
)
self.assertRaises(ValueError, attack_input.get_loss_train)
self.assertRaises(ValueError, attack_input.get_loss_test)

def test_get_loss_explicitly_provided(self):
attack_input = AttackInputData(
loss_train=np.array([1.0, 3.0, 6.0]),
@@ -17,23 +17,58 @@
from scipy import special


def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8):
"""Compute the cross entropy loss.
def log_loss(labels: np.ndarray,
pred: np.ndarray,
from_logits=False,
small_value=1e-8) -> np.ndarray:
"""Computes the per-example cross entropy loss.
Args:
labels: numpy array of shape (num_samples,) labels[i] is the true label
(scalar) of the i-th sample
pred: numpy array of shape(num_samples, num_classes) where pred[i] is the
probability vector of the i-th sample
labels: numpy array of shape (num_samples,). labels[i] is the true label
(scalar) of the i-th sample and is one of {0, 1, ..., num_classes-1}.
pred: numpy array of shape (num_samples, num_classes) or (num_samples,). For
categorical cross entropy loss, the shape should be (num_samples,
num_classes) and pred[i] is the logits or probability vector of the i-th
sample. For binary logistic loss, the shape should be (num_samples,) and
pred[i] is the probability of the positive class.
from_logits: whether `pred` is logits or probability vector.
small_value: a scalar. np.log can become -inf if the probability is too
close to 0, so the probability is clipped below by small_value.

Returns:
the cross-entropy loss of each sample
"""
classes = np.unique(labels)

# Binary logistic loss
if pred.ndim == 1:
if classes.min() < 0 or classes.max() > 1:
raise ValueError('Each value in pred is a scalar, but labels are not in '
'{0, 1}.')
if from_logits:
pred = special.expit(pred)

indices_class0 = (labels == 0)
prob_correct = np.copy(pred)
prob_correct[indices_class0] = 1 - prob_correct[indices_class0]
return -np.log(np.maximum(prob_correct, small_value))

# Multi-class categorical cross entropy loss
if classes.min() < 0 or classes.max() >= pred.shape[1]:
raise ValueError('labels should be in the range [0, num_classes-1].')
if from_logits:
pred = special.softmax(pred, axis=-1)
return -np.log(np.maximum(pred[range(labels.size), labels], small_value))
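
# A quick worked example of both code paths (values are illustrative):
#
#   labels = np.array([0, 1])
#   # 1-D pred -> binary path: sigmoid, then -log of the true-class prob.
#   log_loss(labels, np.array([-1.0, 2.0]), from_logits=True)
#   # -> approximately [0.313, 0.127]
#
#   # 2-D pred -> multi-class path: softmax, then index the true class.
#   log_loss(labels, np.array([[2.0, 0.5], [0.1, 1.5]]), from_logits=True)
#   # -> approximately [0.201, 0.220]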


def log_loss_from_logits(labels: np.ndarray, logits: np.ndarray):
"""Compute the cross entropy loss from logits."""
return log_loss(labels, special.softmax(logits, axis=-1))
def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
"""Computes the per-example squared loss.
Args:
y_true: numpy array of shape (num_samples,) representing the true labels.
y_pred: numpy array of shape (num_samples,) representing the predictions.
Returns:
the squared loss of each sample.
"""
return (y_true - y_pred)**2
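
# For example (illustrative values):
#   squared_loss(np.array([1., 0.]), np.array([0.5, 2.0]))  # -> [0.25 4.  ]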
