Skip to content

Commit

Permalink
updated plot scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Mar 15, 2024
1 parent 35e6856 commit e7f846f
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 232 deletions.
6 changes: 5 additions & 1 deletion src/graphics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@ def setup_tueplots(
rel_width: float = 1.0,
hw_ratio: Optional[float] = None,
inc_font_size: int = 0,
use_tex: bool = True,
**kwargs
):
font_config = fonts.iclr2023_tex(family='serif')
if use_tex:
font_config = fonts.iclr2023_tex(family='serif')
else:
font_config = fonts.iclr2023(family='serif')
if hw_ratio is not None:
kwargs['height_to_width_ratio'] = hw_ratio
size = figsizes.iclr2023(rel_width=rel_width, nrows=nrows, ncols=ncols, **kwargs)
Expand Down
216 changes: 0 additions & 216 deletions src/scripts/plots/ring/ellipses.py

This file was deleted.

21 changes: 11 additions & 10 deletions src/scripts/plots/ring/pdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,28 @@ def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes):
cov = np.diag(covs[:, i])
v, w = np.linalg.eigh(cov)
v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
ell = mpl.patches.Ellipse(mu, v[0], v[1], linewidth=0.7, fill=False)
ell_dot = mpl.patches.Circle(mu, radius=0.03, fill=True)
ell.set_color('red')
ell = mpl.patches.Ellipse(mu, v[0], v[1], linewidth=0.8, fill=False)
ell_dot = mpl.patches.Circle(mu, radius=0.02, fill=True)
ell.set_color('#E53935')
if isinstance(mixture, MonotonicPC):
#ell.set_alpha(mix_weights[i])
#ell_dot.set_alpha(0.5 * mix_weights[i])
ell_dot.set_color('red')
ell.set_alpha(0.775)
ell_dot.set_alpha(0.775)
ell_dot.set_color('#E53935')
# ell.set_alpha(0.775)
# ell_dot.set_alpha(0.775)
else:
if mix_weights[i] <= 0.0:
#ell.set_alpha(min(1.0, 3 * np.abs(mix_weights[i])))
ell.set_linestyle('dotted')
ell.set_linewidth(1.5)
#ell_dot.set_alpha(0.5 * np.abs(mix_weights[i]))
#%ell_dot.set_color('red')
ell_dot.set_color('#E53935')
else:
#ell.set_alpha(mix_weights[i])
#ell_dot.set_alpha(0.5 * mix_weights[i])
ell_dot.set_color('red')
ell.set_alpha(0.85)
ell_dot.set_alpha(0.85)
ell_dot.set_color('#E53935')
# ell.set_alpha(0.85)
# ell_dot.set_alpha(0.85)
ax.add_artist(ell)
ax.add_artist(ell_dot)

Expand Down
18 changes: 13 additions & 5 deletions src/scripts/plots/wdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,38 @@
path = setup_experiment_path(
args.path, args.dataset, args.model, args.exp_alias, trial_id=build_run_id(args))
sd = torch.load(os.path.join(path, 'model.pt'), map_location='cpu')['weights']
print(sd.keys())

# Concatenate weights in a large vector
ws = list()
for k in sd.keys():
if 'layer' in k and 'weight' in k:
ws.append(sd[k].flatten().numpy())
# Select the parameters of CP layers only
if 'layer' in k and 'weight' in k and 'input' not in k and 'mixture' not in k:
w = sd[k]
if 'Born' in args.model: # Perform squaring
if len(w.shape) == 3: # CP layer
w = torch.einsum('fki,fkj->fkij', w, w)
else:
assert False, "This should not happen :("
ws.append(w.flatten().numpy())
ws = np.concatenate(ws, axis=0)

# Preprocess the weights, and set some flags
if 'Mono' in args.model:
mb = np.quantile(ws, q=[0.9999])
mb = np.quantile(ws, q=[0.99], method='lower')
ws = ws[ws <= mb]
ws = np.exp(ws)
hcol = 'C0'
elif 'Born' in args.model:
ma, mb = np.quantile(ws, q=[0.0005, 0.9995])
ma, mb = np.quantile(ws, q=[0.005, 0.995], method='lower')
ws = ws[(ws >= ma) & (ws <= mb)]
hcol = 'C1'
print(ws.shape)

# Compute and plot the instogram
setup_tueplots(1, 1, rel_width=0.25, hw_ratio=1.0)
hlabel = f'{format_model(args.model)}'
plt.hist(ws, density=True, bins=64, color=hcol, label=hlabel)
plt.hist(ws, bins=64, color=hcol, label=hlabel)
plt.yscale('log')
if args.legend:
plt.legend()
Expand Down

0 comments on commit e7f846f

Please sign in to comment.