Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference verification plot #14

Merged
merged 14 commits into from
Apr 12, 2024
89 changes: 89 additions & 0 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Third-party
import os
import cartopy.feature as cf
import matplotlib
import matplotlib.pyplot as plt
Expand All @@ -8,6 +9,7 @@
# First-party
from neural_lam import constants, utils
from neural_lam.rotate_grid import unrotate_latlon
from neural_lam.weather_dataset import WeatherDataModule


@matplotlib.rc_context(utils.fractional_plot_bundle(1))
Expand Down Expand Up @@ -179,3 +181,90 @@ def plot_spatial_error(error, title=None, vrange=None):
fig.suptitle(title, size=10)

return fig

@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def verify_inference(file_path:str, feature_channel:int, vrange=None, save_path=None):
"""
Plot example prediction, forecast, and ground truth.
Each has shape (N_grid,)
"""

# Verify that feature channel is between 0 and 42
if not 0 <= feature_channel <= 42:
clechartre marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("feature_channel must be between 0 and 42, inclusive.")

# Load the inference dataset for plotting
predictions_data_module = WeatherDataModule(
"cosmo",
path_verif_file=file_path,
split="verification",
standardize=False,
subset=False,
batch_size=6,
num_workers=2
)
predictions_data_module.setup(stage='verification')
predictions_loader = predictions_data_module.verification_dataloader()
for predictions_batch in predictions_loader:
predictions = predictions_batch[0] # tensor
break

# get test data
data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0)
lon, lat = unrotate_latlon(data_latlon)

# Get common scale for values
total = predictions[0,:,:,feature_channel]
total_array = np.array(total)
if vrange is None:
vmin = total_array.min()
vmax = total_array.max()
else:
vmin, vmax = float(vrange[0].cpu().item()), float(vrange[1].cpu().item())

# Plot
for i in range(23):
clechartre marked this conversation as resolved.
Show resolved Hide resolved

feature_array = predictions[0,i,:,feature_channel].reshape(*constants.GRID_SHAPE[::-1]).cpu().numpy()
data_array = np.array(feature_array)

fig, axes = plt.subplots(
1,
1,
figsize=constants.FIG_SIZE,
subplot_kw={"projection": constants.SELECTED_PROJ},
)

contour_set = axes.contourf(
lon,
lat,
data_array,
transform=constants.SELECTED_PROJ,
cmap="plasma",
levels=np.linspace(vmin, vmax, num=100),
)
axes.add_feature(cf.BORDERS, linestyle="-", edgecolor="black")
axes.add_feature(cf.COASTLINE, linestyle="-", edgecolor="black")
axes.gridlines(
crs=constants.SELECTED_PROJ,
draw_labels=False,
linewidth=0.5,
alpha=0.5,
)

# Ticks and labels
axes.set_title("Predictions from model inference", size = 15)
axes.text(0.5, 1.05, f"Feature channel {feature_channel}, time step {i}",
ha='center', va='bottom', transform=axes.transAxes, fontsize=12
)
cbar = fig.colorbar(contour_set, orientation="horizontal", aspect=20)
cbar.ax.tick_params(labelsize=10)

# Save the plot!
if save_path:
directory = os.path.dirname(save_path)
if not os.path.exists(directory):
os.makedirs(directory)
plt.savefig(save_path + f"_feature_channel_{feature_channel}_"+ f"{i}.png", bbox_inches='tight')

return fig
Loading
Loading