diff --git a/compass/landice/mesh.py b/compass/landice/mesh.py
index 9c2a125c5..d66ad3ee6 100644
--- a/compass/landice/mesh.py
+++ b/compass/landice/mesh.py
@@ -15,7 +15,8 @@
 from mpas_tools.mesh.creation import build_planar_mesh
 from mpas_tools.mesh.creation.sort_mesh import sort_mesh
 from netCDF4 import Dataset
-from scipy.interpolate import NearestNDInterpolator, interpn
+from scipy.interpolate import interpn
+from scipy.ndimage import distance_transform_edt
 
 
 def mpas_flood_fill(seed_mask, grow_mask, cellsOnCell, nEdgesOnCell,
@@ -954,8 +955,7 @@ def add_bedmachine_thk_to_ais_gridded_data(self, source_gridded_dataset,
     return gridded_dataset_with_bm_thk
 
 
-def preprocess_ais_data(self, source_gridded_dataset,
-                        floodFillMask):
+def preprocess_ais_data(self, source_gridded_dataset, floodFillMask):
     """
     Perform adjustments to gridded AIS datasets needed for rest of compass
     workflow to utilize them
@@ -973,12 +973,52 @@ def preprocess_ais_data(self, source_gridded_dataset,
     preprocessed_gridded_dataset : str
         name of NetCDF file with preprocessed version of gridded dataset
     """
     logger = self.logger
 
+    def _nearest_fill_from_valid(field2d, valid_mask):
+        """
+        Fill invalid cells in a 2D regular raster using the value from the
+        nearest valid cell on the same grid.
+
+        Parameters
+        ----------
+        field2d : numpy.ndarray
+            2D field to be filled
+        valid_mask : numpy.ndarray
+            Boolean mask where True marks valid cells
+
+        Returns
+        -------
+        filled : numpy.ndarray
+            Copy of field2d with invalid cells filled
+        """
+        valid_mask = np.asarray(valid_mask, dtype=bool)
+
+        if field2d.shape != valid_mask.shape:
+            raise ValueError('field2d and valid_mask must have the same shape')
+
+        if not np.any(valid_mask):
+            raise ValueError('No valid cells available for nearest fill.')
+
+        # For EDT, foreground=True cells get mapped to nearest background=False
+        # cell when return_indices=True. So we pass ~valid_mask.
+        nearest_inds = distance_transform_edt(
+            ~valid_mask, return_distances=False, return_indices=True
+        )
+
+        filled = np.array(field2d, copy=True)
+        invalid = ~valid_mask
+        filled[invalid] = field2d[
+            nearest_inds[0, invalid],
+            nearest_inds[1, invalid]
+        ]
+        return filled
+
     # Apply floodFillMask to thickness field to help with culling
     file_with_flood_fill = \
         f"{source_gridded_dataset.split('.')[:-1][0]}_floodFillMask.nc"
     copyfile(source_gridded_dataset, file_with_flood_fill)
+
     gg = Dataset(file_with_flood_fill, 'r+')
     gg.variables['thk'][0, :, :] *= floodFillMask
     gg.variables['vx'][0, :, :] *= floodFillMask
@@ -989,15 +1029,19 @@ def preprocess_ais_data(self, source_gridded_dataset,
     # Now deal with the peculiarities of the AIS dataset.
     preprocessed_gridded_dataset = \
         f"{file_with_flood_fill.split('.')[:-1][0]}_filledFields.nc"
-    copyfile(file_with_flood_fill,
-             preprocessed_gridded_dataset)
+    copyfile(file_with_flood_fill, preprocessed_gridded_dataset)
+
     data = Dataset(preprocessed_gridded_dataset, 'r+')
     data.set_auto_mask(False)
+
     x1 = data.variables["x1"][:]
     y1 = data.variables["y1"][:]
-    cellsWithIce = data.variables["thk"][:].ravel() > 0.
+
+    thk = data.variables["thk"][0, :, :]
+    cellsWithIce = thk > 0.0
+
     data.createVariable('iceMask', 'f', ('time', 'y1', 'x1'))
-    data.variables['iceMask'][:] = data.variables["thk"][:] > 0.
+    data.variables['iceMask'][:] = data.variables["thk"][:] > 0.0
 
     # Note: dhdt is only reported over grounded ice, so we will have to
     # either update the dataset to include ice shelves or give them values of
@@ -1005,48 +1049,41 @@ def preprocess_ais_data(self, source_gridded_dataset,
     dHdt = data.variables["dhdt"][:]
     dHdtErr = 0.05 * dHdt  # assign arbitrary uncertainty of 5%
     # Where dHdt data are missing, set large uncertainty
-    dHdtErr[dHdt > 1.e30] = 1.
+    dHdtErr[dHdt > 1.e30] = 1.0
 
     # Extrapolate fields beyond region with ice to avoid interpolation
-    # artifacts of undefined values outside the ice domain
-    # Do this by creating a nearest neighbor interpolator of the valid data
-    # to recover the actual data within the ice domain and assign nearest
-    # neighbor values outside the ice domain
-    xGrid, yGrid = np.meshgrid(x1, y1)
-    xx = xGrid.ravel()
-    yy = yGrid.ravel()
+    # artifacts of undefined values outside the ice domain.
+    #
+    # The masks below are masks of valid cells.
     bigTic = time.perf_counter()
     for field in ['thk', 'bheatflx', 'vx', 'vy', 'ex', 'ey', 'thkerr',
                   'dhdt']:
         tic = time.perf_counter()
-        logger.info(f"Beginning building interpolator for {field}")
+        logger.info(f'Beginning nearest-fill preprocessing for {field}')
+
+        field2d = data.variables[field][0, :, :]
+
         if field in ['thk', 'thkerr']:
-            mask = cellsWithIce.ravel()
+            valid_mask = cellsWithIce
         elif field == 'bheatflx':
-            mask = np.logical_and(
-                data.variables[field][:].ravel() < 1.0e9,
-                data.variables[field][:].ravel() != 0.0)
+            valid_mask = np.logical_and(field2d < 1.0e9, field2d != 0.0)
         elif field in ['vx', 'vy', 'ex', 'ey', 'dhdt']:
-            mask = np.logical_and(
-                data.variables[field][:].ravel() < 1.0e9,
-                cellsWithIce.ravel() > 0)
+            valid_mask = np.logical_and(field2d < 1.0e9, cellsWithIce)
         else:
-            mask = cellsWithIce
-        interp = NearestNDInterpolator(
-            list(zip(xx[mask], yy[mask])),
-            data.variables[field][:].ravel()[mask])
-        toc = time.perf_counter()
-        logger.info(f"Finished building interpolator in {toc - tic} seconds")
+            valid_mask = cellsWithIce
+
+        logger.info(f'{field}: {valid_mask.sum()} valid cells, '
+                    f'{(~valid_mask).sum()} cells to fill')
+
+        filled2d = _nearest_fill_from_valid(field2d, valid_mask)
+        data.variables[field][0, :, :] = filled2d
 
-        tic = time.perf_counter()
-        logger.info(f"Beginning interpolation for {field}")
-        # NOTE: Do not need to evaluate the extrapolator at all grid cells.
-        # Only needed for ice-free grid cells, since is NN extrapolation
-        data.variables[field][0, :] = interp(xGrid, yGrid)
         toc = time.perf_counter()
-        logger.info(f"Interpolation completed in {toc - tic} seconds")
+        logger.info(f'Nearest-fill preprocessing for {field} completed in '
+                    f'{toc - tic:.3f} seconds')
     bigToc = time.perf_counter()
-    logger.info(f"All interpolations completed in {bigToc - bigTic} seconds.")
+    logger.info(f'All nearest-fill preprocessing completed in '
+                f'{bigToc - bigTic:.3f} seconds.')
 
     # Now perform some additional clean up adjustments to the dataset
     data.createVariable('dHdtErr', 'f', ('time', 'y1', 'x1'))
@@ -1062,5 +1099,4 @@ def preprocess_ais_data(self, source_gridded_dataset,
     data.variables['subm'][:] *= -1.0  # correct basal melting sign
     data.variables['subm_ss'][:] *= -1.0
 
-    data.renameVariable('dhdt', 'dHdt')
     data.renameVariable('thkerr', 'topgerr')