# Source code for VectorPostProcessing.save_mesh_outputs_as_png

import os
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import netCDF4 as nc
from datetime import timedelta
import xarray as xr
import matplotlib

_VALID_MODES = ('daily', 'monthly', 'yearly', 'hourly')

# Figure font used for every frame; applied through ``matplotlib.rc_context``
# so the caller's global Matplotlib configuration is not modified.
_PLOT_FONT_RC = {
    'font.family': 'DejaVu Serif',
    'font.weight': 'bold',
    'font.size': 24,
}


def _validate_mode(mode):
    """Raise ValueError unless *mode* is one of the supported labelling modes."""
    if mode not in _VALID_MODES:
        raise ValueError(
            f"Invalid mode {mode!r}. "
            "Choose from 'daily', 'monthly', 'yearly', or 'hourly'."
        )


def _detect_layer(filename):
    """Return 'LayerN' if *filename* contains 'IGN' (N checked in order 1..3), else None."""
    for level in (1, 2, 3):
        if f'IG{level}' in filename:
            return f'Layer{level}'
    return None


def _frame_label(start, idx, mode):
    """Return the time label for frame *idx*, counted in *mode* steps from *start*.

    Yearly and monthly labels use calendar arithmetic on the start date's
    year/month instead of fixed 365.25- or 30-day steps, which drift (e.g.
    2000-01-01 + 30 days is still January).
    """
    _validate_mode(mode)
    if mode == 'yearly':
        return f'{start.year + idx:04d}'
    if mode == 'monthly':
        years_ahead, month_index = divmod(start.month - 1 + idx, 12)
        return f'{start.year + years_ahead:04d}-{month_index + 1:02d}'
    if mode == 'daily':
        return (start + timedelta(days=idx)).strftime('%Y-%m-%d')
    # mode == 'hourly'
    return (start + timedelta(hours=idx)).strftime('%Y-%m-%d_%H:%M')


def _read_start_date(path):
    """Return the first decoded time stamp of the 'time' variable in NetCDF *path*.

    Units containing 'years' are rewritten to 'days' and the raw values scaled
    by 365, since cftime's ``num2date`` does not accept plain year units.
    """
    with nc.Dataset(path) as ds:
        time_var = ds.variables['time']
        raw_units = time_var.units
        calendar = getattr(time_var, 'calendar', 'standard')
        first = time_var[:1]  # only the first stamp is needed for labelling
    if 'years' in raw_units:
        first = first * 365
    dates = nc.num2date(first, units=raw_units.replace('years', 'days'),
                        calendar=calendar)
    return dates[0]


def _print_spot_check(varname, idx, shp, comid_field, values_df, merged):
    """Print DDB-order vs. merged values for the first, middle and last shapes."""
    print(f"\nSpot-check for '{varname}' at time index {idx}:")
    n_shapes = len(shp)
    if n_shapes:
        # A set drops duplicate positions for very small shapefiles.
        positions = sorted({0, n_shapes // 2, n_shapes - 1})
        for sid in shp[comid_field].iloc[positions]:
            orig = values_df.loc[values_df['ID'] == sid, 'value']
            if orig.empty:
                print(f"  COMID {sid} not found in df['ID']")
            orig_val = orig.values[0] if not orig.empty else np.nan
            merged_match = merged.loc[merged[comid_field] == sid, 'value']
            merged_val = merged_match.values[0] if not merged_match.empty else np.nan
            print(f"  COMID {sid}: original={orig_val} ⟶ merged={merged_val}")
    print("──────── end spot-check ────────\n")


def _save_frame_png(merged, title, cbar_label, vmin, vmax, cmap, out_path):
    """Draw *merged* coloured by 'value' on a fixed [vmin, vmax] scale and save it."""
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        ax.set_title(title)
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')

        merged.plot(column='value', cmap=cmap, edgecolor='k', linewidth=0.1,
                    ax=ax, vmin=vmin, vmax=vmax)

        # The colour bar uses the same norm as the map. Matplotlib's colorbar
        # widens a singular (vmin == vmax) range on its own.
        mappable = plt.cm.ScalarMappable(
            norm=colors.Normalize(vmin=vmin, vmax=vmax),
            cmap=plt.get_cmap(cmap),  # plt.cm.get_cmap was removed in Matplotlib 3.9
        )
        mappable.set_array([])  # required by older Matplotlib for bare mappables

        cbar_ax = fig.add_axes([0.82, 0.2, 0.02, 0.6])
        fig.colorbar(mappable, cax=cbar_ax, orientation='vertical', extend='max')
        cbar_ax.set_ylabel(cbar_label)

        fig.subplots_adjust(left=0.05, right=0.80, top=0.95, bottom=0.05)
        fig.savefig(out_path, dpi=300)
    finally:
        # Close even if plotting fails, so figures do not accumulate.
        plt.close(fig)


def save_mesh_outputs_as_png(
    shapefile_path,
    netcdf_dir,
    ddb_path,
    varnames,
    filenames,
    cbar_labels,
    outdir,
    indices_to_save,
    mode='monthly',
    domain_name='Basin',
    cmap='gnuplot2_r',
    comid_field='COMID',  # shapefile column holding the subbasin IDs
    subbasin_var='subbasin',  # drainage-database variable holding the subbasin IDs
    spot_check=True,
):
    """
    Generate and save static PNG maps of MESH model outputs for selected time slices.

    For each variable in `varnames`, the values at each index in
    `indices_to_save` are joined to the subbasin polygons of `shapefile_path`.
    The join matches the shapefile's `comid_field` against the drainage-database
    `subbasin_var` IDs. Each map uses a colour scale fixed to that variable's
    min/max over the whole file, so frames of one variable are comparable.

    Parameters
    ----------
    shapefile_path : str
        Path to the shapefile (.shp) of subbasin polygons; must contain
        the column named by `comid_field`.
    netcdf_dir : str
        Directory containing the NetCDF output files.
    ddb_path : str
        Path to the NetCDF drainage database containing `subbasin_var`.
    varnames : list of str
        Variable to read from each NetCDF file (e.g. ['QO', 'SNO']).
    filenames : list of str
        NetCDF filename for each entry of `varnames` (same length). The
        'time' variable of the first file supplies the start date used for
        all labels.
    cbar_labels : list of str
        Colour-bar label for each entry of `varnames` (same length).
    outdir : str
        Directory where PNG files are written; created if missing.
    indices_to_save : list of int
        Positions along the first axis (assumed to be time) to plot,
        e.g. [0, 1, 5].
    mode : {'daily', 'monthly', 'yearly', 'hourly'}, optional
        Step between consecutive indices, used only for the date label.
        Default is 'monthly'.
    domain_name : str, optional
        Name of the domain used in the figure title. Default is 'Basin'.
    cmap : str, optional
        Matplotlib colormap name. Default is 'gnuplot2_r'.
    comid_field : str, optional
        Shapefile column with subbasin IDs. Default is 'COMID'.
    subbasin_var : str, optional
        Drainage-database variable with subbasin IDs. Default is 'subbasin'.
    spot_check : bool, optional
        If True (default), print a short original-vs-merged value comparison
        for three polygons of every frame.

    Returns
    -------
    None
        PNG files are written to `outdir`.

    Raises
    ------
    ValueError
        If `mode` is unsupported, or if `varnames`, `filenames` and
        `cbar_labels` differ in length. Both checks run before any file is
        read or any directory is created.

    Notes
    -----
    - Values at index ``idx`` are attached to the drainage-database IDs by
      position, so the second axis of each variable is assumed to follow the
      order of `subbasin_var`.
    - Yearly/monthly labels are computed by calendar arithmetic from the first
      time stamp; daily/hourly labels add ``idx`` days/hours to it. Time units
      expressed in years are converted to days (365 days/year) before decoding.
    - A layer suffix ('Layer1'..'Layer3') is appended to the title when the
      filename contains 'IG1'..'IG3'.
    - The figure font (bold DejaVu Serif, size 24) is applied only for the
      duration of this call.
    - Files are named ``{var}_{mode}_frame_{idx:03d}_{label}.png`` with ':'
      in the label replaced by '-'.

    Example
    -------
    >>> save_mesh_outputs_as_png(
    ...     shapefile_path='shapes/sras_subbasins.shp',
    ...     netcdf_dir='outputs/monthly',
    ...     ddb_path='outputs/MESH_drainage_database.nc',
    ...     varnames=['QO', 'SNO'],
    ...     filenames=['QO_Y_GRD.nc', 'SNO_Y_GRD.nc'],
    ...     cbar_labels=['Discharge [m³/s]', 'Snow [mm]'],
    ...     outdir='outputs/pngs',
    ...     indices_to_save=[0, 3, 6],
    ...     mode='monthly',
    ...     domain_name='SrAs',
    ...     cmap='viridis',
    ...     subbasin_var='subbasin'
    ... )
    """
    # Cheap argument checks first, so bad input fails before any I/O.
    _validate_mode(mode)
    if not (len(varnames) == len(filenames) == len(cbar_labels)):
        raise ValueError(
            "varnames, filenames and cbar_labels must have the same length "
            f"(got {len(varnames)}, {len(filenames)}, {len(cbar_labels)})."
        )

    os.makedirs(outdir, exist_ok=True)
    if not varnames:
        return  # nothing requested

    # Subbasin IDs from the drainage database; `with` closes it even on error.
    with xr.open_dataset(ddb_path) as db:
        segid = db[subbasin_var].values
    ids_df = pd.DataFrame({'ID': segid})

    # Sort polygons by ID for a deterministic row order.
    shp = gpd.read_file(shapefile_path).sort_values(by=comid_field).reset_index(drop=True)

    start_date = _read_start_date(os.path.join(netcdf_dir, filenames[0]))

    with matplotlib.rc_context(_PLOT_FONT_RC):
        for varname, fname, cbar_label in zip(varnames, filenames, cbar_labels):
            layer = _detect_layer(fname)
            layer_part = f'_{layer}' if layer else ''

            # Read each file once: the full array gives both the global colour
            # range and every requested slice.
            with nc.Dataset(os.path.join(netcdf_dir, fname)) as ds:
                data = ds.variables[varname][:]
            vmin, vmax = np.nanmin(data), np.nanmax(data)

            for raw_idx in indices_to_save:
                idx = int(raw_idx)  # timedelta() rejects numpy integer types

                # Values are aligned to DDB IDs by array position.
                values_df = ids_df.assign(value=data[idx, :])
                merged = shp.merge(values_df, left_on=comid_field,
                                   right_on='ID', how='left')

                if spot_check:
                    _print_spot_check(varname, idx, shp, comid_field,
                                      values_df, merged)

                label = _frame_label(start_date, idx, mode)
                title = f'{domain_name}_MESH_{varname}{layer_part}_{label}'
                out_name = f"{varname}_{mode}_frame_{idx:03d}_{label.replace(':', '-')}.png"
                _save_frame_png(merged, title, cbar_label, vmin, vmax, cmap,
                                os.path.join(outdir, out_name))