-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
149 lines (123 loc) · 6.06 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Here we implement some CP / training helper functions.
"""
import numpy as np
import torch.nn
import torch
import conformal_prediction as cp
from uncertainty_functions import smx_entropy
import evaluation
def pinball_loss_grad(y: float, yhat: np.ndarray, q: float) -> np.ndarray:
    """
    Subgradient of the pinball (quantile) loss with respect to the prediction ``yhat``.

    For quantile level ``q`` the loss is ``q * (y - yhat)`` when ``y > yhat`` and
    ``(1 - q) * (yhat - y)`` otherwise. Its derivative in ``yhat`` is therefore
    ``-q`` where ``y > yhat``, ``1 - q`` where ``y < yhat`` and ``0`` where they are equal.

    :param y: True value(s).
    :param yhat: Predicted value(s); compared elementwise with ``y``.
    :param q: Quantile level of the loss.
    :return: Elementwise gradient of the pinball loss.
    """
    overshoot = y < yhat   # prediction above the target -> slope (1 - q)
    undershoot = y > yhat  # prediction below the target -> slope -q
    return (1 - q) * overshoot - q * undershoot
def split_conformal(results: list[dict],
                    cal_path: str,
                    alpha: float,
                    cp_method: str,
                    rng: "np.random.Generator | None" = None,
                    ) -> tuple[list[dict], float, float, float, np.ndarray, np.ndarray]:
    """
    Perform split conformal prediction and obtain the conformal threshold.

    The saved calibration outputs are randomly split 50/50. One half calibrates the
    conformal threshold. The other half is held out to report accuracy, coverage and
    prediction-set size for that threshold.

    :param results: List of dicts that will store metrics of interest; a 'calibration'
        entry is appended to it in place.
    :param cal_path: Path to saved softmax outputs / labels from the calibration dataset.
        It is read with ``np.load`` and must expose 'smx' and 'labels' entries.
    :param alpha: Target error level for conformal prediction; coverage is 1 - alpha.
    :param cp_method: The conformal prediction method to use ('thr' or 'raps').
    :param rng: Optional numpy Generator used to draw the calibration / validation split,
        for reproducibility. When None (default) the global numpy RNG is used.
    :return: Updated results list, conformal threshold (tau_thr), upper and lower entropy
        quantiles (upper_q, lower_q), and the calibration softmax scores and labels
        (cal_smx, cal_labels).
    :raises ValueError: If cp_method is not 'thr' or 'raps'.
    """
    # Fail fast: reject an unsupported method before loading data or doing any work.
    if cp_method not in ('thr', 'raps'):
        raise ValueError('CP method not supported choose from [thr, raps]')

    # # # # # # # # CALIBRATION # # # # # # #
    print('Calibrating conformal')
    # Load the saved calibration outputs (e.g. from the imagenet1k validation set).
    data = np.load(cal_path)
    try:
        smx = data['smx']  # softmax scores
        labels = data['labels'].astype(int)
    finally:
        # For .npz archives np.load returns an NpzFile that keeps the file open until it is
        # closed. The arrays read above are already fully materialized in memory.
        if hasattr(data, 'close'):
            data.close()

    # Randomly split the softmax scores into calibration and validation halves.
    n = int(len(labels) * 0.5)
    idx = np.zeros(smx.shape[0], dtype=bool)
    idx[:n] = True
    if rng is None:
        np.random.shuffle(idx)
    else:
        rng.shuffle(idx)
    cal_smx, val_smx = smx[idx, :], smx[~idx, :]
    cal_labels, val_labels = labels[idx], labels[~idx]

    # Entropy of the prediction distribution on each half.
    cal_ent = smx_entropy(torch.Tensor(cal_smx)).numpy()
    val_ent = smx_entropy(torch.Tensor(val_smx)).numpy()
    # In split conformal regression the interval bounds are the alpha/2 and 1 - alpha/2
    # quantiles of the calibration scores, giving a symmetric interval with coverage 1 - alpha.
    # NOTE: our method does not use the lower bound, only the upper bound corresponding to
    # \beta = 1 - \alpha (see Eq. 3 in the paper). lower_q is not needed by our methods;
    # upper_q can serve as an initial starting point.
    lower_q = np.quantile(cal_ent, alpha / 2)
    upper_q = np.quantile(cal_ent, 1 - alpha / 2)
    # Sanity check: count held-out entropies that fall inside [lower_q, upper_q].
    pred_int = ((lower_q <= val_ent) & (val_ent <= upper_q)).sum()
    print(f'Entropy Coverage on validation set: {pred_int / len(val_ent)}')

    # Accuracy on the held-out half.
    acc_cal = evaluation.compute_accuracy(val_smx, val_labels)

    # Calibrate the conformal threshold on the calibration half, then build prediction
    # sets for the held-out half with the same method.
    if cp_method == 'thr':
        tau_thr = cp.calibrate_threshold(cal_smx, cal_labels, alpha)  # conformal quantile
        conf_set_thr = cp.predict_threshold(val_smx, tau_thr)
    else:  # 'raps' (validated at the top of the function)
        tau_thr = cp.calibrate_raps(cal_smx, cal_labels, alpha, k_reg=5, lambda_reg=0.01, rng=True)
        conf_set_thr = cp.predict_raps(val_smx, tau_thr, k_reg=5, lambda_reg=0.01, rng=True)

    # Coverage and average set size (inefficiency) of the held-out prediction sets.
    cov_thr_in1k = float(evaluation.compute_coverage(conf_set_thr, val_labels))
    size_thr_in1k, _ = evaluation.compute_size(conf_set_thr)
    print(f'Accuracy on Calibration data: {acc_cal}')
    print(f'Coverage on Calibration data: {cov_thr_in1k}')
    print(f'Inefficiency on Calibration data: {size_thr_in1k}')

    results.append({
        'update': 'calibration',
        'cal_acc': acc_cal,
        'cal_cov': cov_thr_in1k,
        'cal_size': size_thr_in1k,
    })
    return results, tau_thr, upper_q, lower_q, cal_smx, cal_labels
def update_beta_online(output_ent: torch.Tensor, beta: float, alpha: float, lr: float = 1.0) -> float:
    """
    Update the estimated beta entropy quantile online, for use in adapting the conformal
    prediction threshold (see Eq. 3 of the paper).

    Takes one pinball-loss (sub)gradient step averaged over the batch. The step is
    ``lr * ((1 - alpha) * frac_above - alpha * frac_below)``, where ``frac_above`` and
    ``frac_below`` are the fractions of entropies above and below the current beta.
    Repeated updates therefore push beta toward the (1 - alpha) entropy quantile.

    :param output_ent: Entropy of the output predictions.
    :param beta: Current entropy quantile estimate.
    :param alpha: Target error level (1 - alpha = coverage).
    :param lr: Step size of the update. The default of 1.0 gives the un-scaled update.
    :return: Updated entropy quantile.
    """
    ent = output_ent.detach().cpu().numpy()
    # Gradient of the pinball loss, with beta as the target and the entropies as predictions.
    step = pinball_loss_grad(beta, ent, alpha).mean()
    return beta + lr * step
def update_beta_batch(output_ent: torch.Tensor, alpha: float) -> float:
    """
    Estimate the beta entropy quantile directly from a batch of data (or the entire
    dataset if available), instead of updating it online.

    On a large enough batch the difference from the online estimate is negligible.

    :param output_ent: Entropy of the output predictions.
    :param alpha: Target error level (1 - alpha = coverage).
    :return: The (1 - alpha) quantile of the batch / dataset entropies.
    """
    return np.quantile(output_ent.detach().cpu().numpy(), 1 - alpha)
def t2sev(t, run_length=7, schedule=None):
    """
    Map a time step to a severity level for continuous (time-varying) shifts.

    Time is grouped into runs of ``run_length`` steps.
    - "gradual": severity follows the triangle wave 0,1,2,3,4,5,4,3,2,1 (one value per
      run, repeating every 10 runs).
    - any other value (default): "sudden" schedule that alternates between 0 and 5 each run.
    """
    run_index = t // run_length
    if schedule == "gradual":
        phase = run_index % 10
        # Rising for phases 0..5, then mirrored back down.
        return min(phase, 10 - phase)
    return 5 * (run_index % 2)