Skip to content

Commit

Permalink
updated ring pdfs plots
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Apr 2, 2024
1 parent 50ae243 commit 38191a2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 16 deletions.
8 changes: 1 addition & 7 deletions src/scripts/plots/ring/distgif.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ def to_rgb_image(x: np.ndarray) -> pillow.Image:
f"{args.path}/ring/MonotonicPC/RGran_R1_K16_D1_Lcp_OAdam_LR0.005_BS64_IU",
f"{args.path}/ring/BornPC/RGran_R1_K2_D1_Lcp_OAdam_LR0.001_BS64_IN"
]
labels = [
'Ground Truth',
'GMM-2',
'GMM-16',
'NGMM-2'
]
gt_array = np.load(os.path.join(checkpoint_paths[0], 'gt.npy'))
gt_array = np.broadcast_to(gt_array, (args.max_num_frames, gt_array.shape[0], gt_array.shape[1]))
arrays = map(lambda p: np.load(os.path.join(p, 'diststeps.npy')), checkpoint_paths)
Expand Down Expand Up @@ -80,7 +74,7 @@ def to_rgb_image(x: np.ndarray) -> pillow.Image:
(int(0.5 * (x[1].shape[2] - cv2.getTextSize(x[0], font, fontscale, thickness)[0][0])),
int(0.5 * (caption_height + cv2.getTextSize(x[0], font, fontscale, thickness)[0][1]))),
font, fontscale, (16, 16, 16), thickness, cv2.LINE_AA), reps=(num_frames, 1, 1, 1))
], axis=1), zip(labels, arrays)
], axis=1), zip(['Ground Truth'] + ['GMM (K=2)', 'GMM (K=16)', 'NGMM (K=2)'], arrays)
)
gif_images = np.concatenate(list(arrays), axis=2)

Expand Down
16 changes: 7 additions & 9 deletions src/scripts/plots/ring/pdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,15 @@
from scipy import special
import torch
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler

from datasets.loaders import load_artificial_dataset
from graphics.distributions import kde_samples_hmap
from graphics.utils import setup_tueplots
from pcs.models import TensorizedPC, PC, MonotonicPC
from scripts.utils import setup_model, setup_data_loaders

parser = argparse.ArgumentParser(
description="PDFs and ellipses plotter"
)
parser.add_argument('--checkpoint-path', default='checkpoints', type=str, help="The checkpoints path")
parser.add_argument('path', default='checkpoints', type=str, help="The checkpoints path")
parser.add_argument('--show-ellipses', default=False, action='store_true',
help="Whether to show the Gaussian components as ellipses")
parser.add_argument('--title', default=False, action='store_true', help="Whether to show a title")
Expand All @@ -42,7 +39,7 @@ def load_mixture(
metadata, _ = setup_data_loaders('ring', 'datasets', 1)
model: TensorizedPC = setup_model(model_name, metadata, num_components=num_components)
exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size)
filepath = os.path.join(args.checkpoint_path, 'ring', model_name, exp_id, 'model.pt')
filepath = os.path.join(args.path, 'ring', model_name, exp_id, 'model.pt')
state_dict = torch.load(filepath, map_location='cpu')
model.load_state_dict(state_dict['weights'])
return model
Expand All @@ -56,7 +53,7 @@ def load_pdf(
batch_size: int = 64
) -> np.ndarray:
exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size)
filepath = os.path.join(args.checkpoint_path, 'ring', model, exp_id, 'distbest.npy')
filepath = os.path.join(args.path, 'ring', model, exp_id, 'distbest.npy')
return np.load(filepath)


Expand Down Expand Up @@ -110,7 +107,7 @@ def plot_pdf(
Optional[float] = None,
vmax: Optional[float] = None
):
pdf = pdf[8:-8, 8:-8]
#pdf = pdf[8:-8, 8:-8]

x_lim = metadata['domains'][0]
y_lim = metadata['domains'][1]
Expand Down Expand Up @@ -147,7 +144,7 @@ def plot_pdf(
]

truth_pdf = np.load(
os.path.join(args.checkpoint_path, 'ring', models[0],
os.path.join(args.path, 'ring', models[0],
exp_id_formats[0].format(num_components[0], learning_rates[0], 64), 'gt.npy')
)
# truth_pdf = ring_kde()
Expand Down Expand Up @@ -192,7 +189,8 @@ def plot_pdf(
ax.set_yticks([])
ax.set_aspect(1.0)
if args.title:
ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center')
#ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center')
ax.set_title(title, y=-0.275)

filename = f'pdfs-ellipses-{idx}.png' if args.show_ellipses else f'pdfs-{idx}.png'
plt.savefig(os.path.join('figures', 'gaussian-ring', filename), dpi=1200)

0 comments on commit 38191a2

Please sign in to comment.