Skip to content

Commit

Permalink
updated plotting and logging code
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Mar 12, 2024
1 parent 397fd40 commit 1ca2b52
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 78 deletions.
7 changes: 5 additions & 2 deletions src/scripts/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def _eval_step(
test_ppl = perplexity(test_avg_ll, self.metadata['num_variables'])
self.logger.info(f"[{self.args.dataset}] Epoch {epoch_idx}, Test ppl: {test_ppl:.03f}")
self.logger.log_scalar('Test/ppl', test_ppl, step=epoch_idx)
if self._log_distribution:
self.logger.log_best_distribution(
self.model, self.args.discretize, lim=self.metadata['domains'], device=self._device)
metrics['best_valid_epoch'] = epoch_idx
metrics['best_valid_avg_ll'] = valid_avg_ll
metrics['best_valid_std_ll'] = valid_std_ll
Expand Down Expand Up @@ -357,7 +360,7 @@ def run(self):

if self._log_distribution:
self.logger.save_array(self.metadata['hmap'], 'gt.npy')
self.logger.log_distribution(
self.logger.log_step_distribution(
self.model, self.args.discretize, lim=self.metadata['domains'], device=self._device)

# The train loop
Expand Down Expand Up @@ -391,7 +394,7 @@ def run(self):
else (max(1, int(2e-1 * self.args.log_frequency)) if epoch_idx == 2
else self.args.log_frequency)) == 0:
if self._log_distribution:
self.logger.log_distribution(
self.logger.log_step_distribution(
self.model, self.args.discretize, lim=self.metadata['domains'], device=self._device)
opt_counter += 1
if diverged:
Expand Down
23 changes: 20 additions & 3 deletions src/scripts/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from PIL import Image as pillow

from graphics.distributions import bivariate_pmf_heatmap, bivariate_pdf_heatmap
from pcs.models import PC
from pcs.models import PC, TensorizedPC


class Logger:
Expand All @@ -33,6 +34,7 @@ def __init__(
wandb_kwargs = dict()
self._setup_wandb(wandb_path, **wandb_kwargs)

self._best_distribution = None
self._logged_distributions = list()
self._logged_wcoords = list()

Expand Down Expand Up @@ -105,7 +107,21 @@ def log_hparams(
if wandb.run:
wandb.run.summary.update(metric_dict)

def log_distribution(
def log_best_distribution(
    self,
    model: PC,
    discretized: bool,
    lim: Tuple[Tuple[Union[float, int], Union[float, int]], Tuple[Union[float, int], Union[float, int]]],
    device: Optional[Union[str, torch.device]] = None
):
    """Compute the model's bivariate heatmap and remember it as the best one.

    The heatmap is produced by ``bivariate_pmf_heatmap`` when ``discretized``
    is truthy and by ``bivariate_pdf_heatmap`` otherwise. The result is cast
    to ``float32`` and stored in ``self._best_distribution``, replacing any
    previously stored heatmap.

    Args:
        model: The model handed to the heatmap function.
        discretized: Selects the PMF heatmap (truthy) or the PDF heatmap.
        lim: A pair ``(xlim, ylim)``, each itself a ``(low, high)`` pair,
            forwarded to the heatmap function.
        device: Forwarded unchanged to the heatmap function.
    """
    x_range, y_range = lim
    # Pick the heatmap routine once, then call it with identical arguments.
    heatmap_fn = bivariate_pmf_heatmap if discretized else bivariate_pdf_heatmap
    heatmap = heatmap_fn(model, x_range, y_range, device=device)
    # copy=False avoids a redundant copy when the heatmap is already float32.
    self._best_distribution = heatmap.astype(np.float32, copy=False)

def log_step_distribution(
self,
model: PC,
discretized: bool,
Expand All @@ -121,7 +137,8 @@ def log_distribution(

def close(self):
if self._logged_distributions:
self.save_array(np.stack(self._logged_distributions, axis=0), 'distribution.npy')
self.save_array(self._best_distribution, 'distbest.npy')
self.save_array(np.stack(self._logged_distributions, axis=0), 'diststeps.npy')
if self._logged_wcoords:
self.save_array(np.stack(self._logged_wcoords, axis=0), 'wcoords.npy')
if self._tboard_writer is not None:
Expand Down
10 changes: 4 additions & 6 deletions src/scripts/plots/gpt2dist/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,19 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame:
if rows_to_keep is not None:
for r, vs in rows_to_keep.items():
model_df = model_df[model_df[r].isin(vs)]
model_df.to_csv(f'{model_name}-gpt2commongen-results.csv', index=None)
model_df.to_csv(f'gpt2commongen-results-{model_name}.csv', index=None)
group_model_df = model_df.groupby(by=['init_method', 'learning_rate'])
should_label = True
metrics[model_name] = defaultdict(list)
for j, hparam_df in group_model_df:
ms, ps = hparam_df[metric].tolist(), hparam_df['num_components'].tolist()
if len(np.unique(ms)) < num_points or len(np.unique(ps)) < num_points:
if len(np.unique(ps)) < num_points:
continue
ms = np.array(ms, dtype=np.float64)
ps = np.array(ps, dtype=np.int64)
sort_indices = np.argsort(ps)
ps = ps[sort_indices]
ms = ms[sort_indices]
ps = ps[:num_points]
ms = ms[:num_points]
for p, m in zip(ps.tolist(), ms.tolist()):
metrics[model_name][p].append(m)
if not args.median:
Expand All @@ -130,7 +128,7 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame:
assert len(model_names) == 2
model_a, model_b = model_names
spvalues = defaultdict(lambda: defaultdict(dict))
for ts in ['mannwithneyu', 'ttest']:
for ts in ['mannwithneyu']:
for al in ['greater']:
for k in sorted(metrics[model_a].keys() & metrics[model_b].keys()):
lls_a = metrics[model_a][k]
Expand All @@ -141,7 +139,7 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame:
s, p = stats.ttest_ind(lls_b, lls_a, alternative=al)
else:
assert False, "Should not happen :("
spvalues[ts][al][k] = (round(s, 3), round(p, 4))
spvalues[ts][al][k] = (round(s, 3), round(p, 7))
print(spvalues)

#if args.train:
Expand Down
6 changes: 3 additions & 3 deletions src/scripts/plots/ring/distgif.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
parser.add_argument('--drop-last-frames', type=int, default=0, help="The number of last frames to drop")

"""
python -m scripts.plots.ring.distgif checkpoints/loss-landscape --drop-last-frames 164
python -m scripts.plots.ring.distgif checkpoints/gaussian-ring --drop-last-frames 224
"""


if __name__ == '__main__':
def to_rgb(x: np.ndarray, cmap: cm.ScalarMappable, cmap_transform: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
#x = x[51:-50, 51:-50]
x = (cmap.to_rgba(cmap_transform(x)) * 255.0).astype(np.uint8)[..., :-1]
x = (cmap.to_rgba(cmap_transform(x.T)) * 255.0).astype(np.uint8)[..., :-1]
if x.shape[0] != args.gif_size or x.shape[1] != args.gif_size:
x = cv2.resize(x, dsize=(args.gif_size, args.gif_size), interpolation=cv2.INTER_CUBIC)
return x
Expand All @@ -49,7 +49,7 @@ def to_rgb_image(x: np.ndarray) -> pillow.Image:
]
gt_array = np.load(os.path.join(checkpoint_paths[0], 'gt.npy'))
gt_array = np.broadcast_to(gt_array, (args.max_num_frames, gt_array.shape[0], gt_array.shape[1]))
arrays = map(lambda p: np.load(os.path.join(p, 'distribution.npy')), checkpoint_paths)
arrays = map(lambda p: np.load(os.path.join(p, 'diststeps.npy')), checkpoint_paths)
if args.drop_last_frames > 0:
arrays = map(lambda a: a[:-args.drop_last_frames], arrays)
arrays = [gt_array] + list(arrays)
Expand Down
38 changes: 20 additions & 18 deletions src/scripts/plots/ring/ellipses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@


def ring_kde() -> np.ndarray:
splits = load_artificial_dataset('ring', num_samples=50000, dtype=np.dtype(np.float64))
splits = load_artificial_dataset('ring', num_samples=500, dtype=np.dtype(np.float64))
data = np.concatenate(splits, axis=0)
scaler = StandardScaler()
data = scaler.fit_transform(data)
data_min, data_max = np.min(data, axis=0), np.max(data, axis=0)
drange = np.abs(data_max - data_min)
data_min, data_max = (data_min - drange * 0.05), (data_max + drange * 0.05)
#drange = np.abs(data_max - data_min)
#data_min, data_max = (data_min - drange * 0.05), (data_max + drange * 0.05)
xlim, ylim = [(data_min[i], data_max[i]) for i in range(len(data_min))]
return kde_samples_hmap(data, xlim=xlim, ylim=ylim, bandwidth=0.16)

Expand All @@ -52,7 +52,7 @@ def load_mixture(
metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000)
model: TensorizedPC = setup_model(model_name, metadata, num_components=num_components)
exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size)
filepath = os.path.join(args.checkpoint_path, 'gaussian-ring', 'ring', model_name, exp_id, 'model.pt')
filepath = os.path.join(args.checkpoint_path, 'ring', model_name, exp_id, 'model.pt')
state_dict = torch.load(filepath, map_location='cpu')
model.load_state_dict(state_dict['weights'])
return model
Expand All @@ -66,7 +66,7 @@ def load_pdf(
batch_size: int = 64
) -> np.ndarray:
exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size)
filepath = os.path.join(args.checkpoint_path, 'gaussian-ring', 'ring', model, exp_id, 'pdf.npy')
filepath = os.path.join(args.checkpoint_path, 'ring', model, exp_id, 'distbest.npy')
return np.load(filepath)


Expand Down Expand Up @@ -155,41 +155,43 @@ def plot_pdf(
models = [
'MonotonicPC',
'MonotonicPC',
'BornPC',
'MAF',
'NSF'
'BornPC'
]

num_components = [2, 16, 2, 128, 128]
learning_rates = [5e-3, 5e-3, 4e-3, 1e-3, 1e-3]
num_components = [2, 16, 2]
learning_rates = [5e-3, 5e-3, 1e-3]

exp_id_formats = [
'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU',
'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU',
'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IN',
'K{}_OAdam_LR{}_BS{}',
'K{}_OAdam_LR{}_BS{}'
'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IN'
]

truth_pdf = ring_kde()

mixtures = [
load_mixture(m, eif, nc, lr)
for m, eif, nc, lr in zip(models[:3], exp_id_formats, num_components, learning_rates)
] + [None, None]
for m, eif, nc, lr in zip(models, exp_id_formats, num_components, learning_rates)
]

pdfs = [
load_pdf(m, eif, nc, lr)
for m, eif, nc, lr in zip(models, exp_id_formats, num_components, learning_rates)
]
vmax = np.max(pdfs)
vmax = np.max([truth_pdf] + pdfs)
vmin = 0.0

metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000)

os.makedirs(os.path.join('figures', 'gaussian-ring'), exist_ok=True)
for idx, (p, pdf, m, nc) in enumerate(zip(mixtures, pdfs, models, num_components)):
data_pdfs = [(None, truth_pdf, 'Ground Truth', -1)] + list(zip(mixtures, pdfs, models, num_components))
for idx, (p, pdf, m, nc) in enumerate(data_pdfs):
setup_tueplots(1, 1, rel_width=0.2, hw_ratio=1.0)
fig, ax = plt.subplots(1, 1)
title = f"{format_model_name(m, nc)}" if args.title else None
if args.title:
title = f"{format_model_name(m, nc)}" if p is not None else m
else:
title = None

plot_pdf(pdf, metadata, ax=ax, vmin=vmin, vmax=vmax)
if p is not None:
Expand Down
Loading

0 comments on commit 1ca2b52

Please sign in to comment.