###############################################################
# ANUGA_hydro.py
###############################################################
from datetime import datetime
import numpy as np
import xarray as xr
from pathlib import Path
import json
from dycove.sim.base import HydroSimulationBase, HydroEngineBase
from dycove.sim.engines.ANUGA_baptist import Baptist_operator
from dycove.utils.simulation_reporting import Reporter
from dycove.constants import H_LIM_VELOCITY
r = Reporter()


def _import_anuga():
    """Lazily import ANUGA's parallel helpers.

    Importing ``anuga`` only on demand keeps this module importable when
    ANUGA is not installed (e.g. when a different engine is being used).

    Returns
    -------
    tuple
        ``(myid, numprocs, finalize, barrier)`` from the ``anuga`` package.

    Raises
    ------
    ImportError
        If the ``anuga`` package cannot be imported; the error is also
        reported through the module-level :class:`Reporter`.
    """
    try:
        from anuga import myid, numprocs, finalize, barrier
    except ImportError:
        msg = ("The `anuga` package is not installed. "
               "Refer to the documentation for installation instructions.")
        r.report(msg, level="ERROR")
        raise ImportError(msg)
    return myid, numprocs, finalize, barrier
class ANUGA(HydroSimulationBase):
    """
    Hydrodynamic simulation wrapper for the ANUGA model.

    This class connects the generic :class:`~dycove.sim.base.HydroSimulationBase`
    interface to the ANUGA hydrodynamic engine
    :class:`~dycove.sim.base.engines.ANUGA_hydro.AnugaEngine`.

    Notes
    -----
    - The ANUGA ``domain`` object is expected to be pre-constructed by the user.
      This preserves the typical ANUGA workflow where domains are created
      directly in Python scripts.
    - All higher-level logic that can be abstracted from the engine classes is
      handled in :class:`~dycove.sim.base.HydroSimulationBase`; all low-level
      model interactions are delegated to
      :class:`~dycove.sim.base.engines.ANUGA_hydro.AnugaEngine`.
    """

    def __init__(self, anuga_domain, vegetation=None):
        """Wrap a pre-built ANUGA domain (and optional vegetation) in an engine.

        Parameters
        ----------
        anuga_domain : anuga.Domain
            Pre-constructed computational domain.
        vegetation : VegetationSpecies or MultipleVegetationSpecies, optional
            Vegetation object forwarded to the engine.
        """
        # Build the low-level ANUGA engine and hand it to the base class.
        super().__init__(AnugaEngine(anuga_domain, vegetation))
class AnugaEngine(HydroEngineBase):
    """
    Engine interface for the ANUGA hydrodynamic model.

    This engine:

    - Holds the ANUGA domain object.
    - Reads/writes flow and vegetation state directly through Python objects
      (unlike the BMI-based engine used for Delft3D FM).
    - Contains all methods that are specific to the ANUGA model.

    Parameters
    ----------
    anuga_domain : anuga.domain
        The pre-constructed computational domain.
    vegetation : VegetationSpecies or MultipleVegetationSpecies, optional
        Vegetation object passed down from the base simulation.

    Notes
    -----
    - ANUGA runs in many short ``domain.evolve()`` loops (see
      :meth:`~dycove.sim.engines.ANUGA_hydro.AnugaEngine.step`)
      rather than one long simulation call; the ``skip_step`` mechanism
      prevents duplicate first timesteps across loops.
    - Parallel execution (if enabled) requires merging local vegetation states
      after simulation; see
      :meth:`~dycove.sim.engines.ANUGA_hydro.AnugaEngine.merge_parallel_veg`.
    """

    def __init__(self, anuga_domain, vegetation=None):
        # Lazy anuga loading to avoid import errors when ANUGA is unused.
        self.myid, self.numprocs, self.finalize, self.barrier = _import_anuga()
        self.domain = anuga_domain
        self.model_dir = self.domain.get_datadir()
        # Vegetation object (None for purely hydrodynamic runs).
        self.veg = vegetation
        # With DYCOVE-ANUGA, we run many consecutive "domain.evolve" loops,
        # rather than just one big loop.  We don't want to skip the first step
        # for the first of these loops, but for every other loop we do;
        # otherwise we get repeated steps.
        self.skip_step = False
        # Interval (seconds) for saving ANUGA output; this is just a
        # placeholder.  The actual value is set via run_simulation() and used
        # as argument to domain.evolve().
        self.save_interval = 3600

    def initialize(self):
        """Prepare the engine before stepping.

        ANUGA doesn't have an "initialize" method like DFM, but the required
        setup steps are grouped here rather than leaving the method empty.
        """
        if self.veg is not None:
            # Attach a Baptist roughness operator with zeroed vegetation;
            # actual values are pushed later through set_vegetation().
            self.Baptist = Baptist_operator(self.domain,
                                            veg_diameter=0,
                                            veg_density=0,
                                            veg_height=0,
                                            drag=self.veg.get_drag())
        # Morphology updates are not performed for ANUGA runs.
        self.morphology = False

    def step(self, seconds):
        """Advance the model by ``seconds`` seconds of simulated time.

        Normally, all processes in ANUGA would be performed within the
        ``domain.evolve()`` loop at each yieldstep, but for consistency across
        all potential models it is wrapped up here under the "step" method.

        Parameters
        ----------
        seconds : float
            Duration of this step in seconds.
        """
        # If performing a "big" step, reduce yieldstep so it equals
        # save_interval.
        yieldstep = min(seconds, self.save_interval)
        self.barrier()
        for _ in self.domain.evolve(yieldstep=yieldstep,
                                    outputstep=self.save_interval,
                                    duration=seconds,
                                    skip_initial_step=self.skip_step):
            if self.myid == 0:
                # Only rank 0 logs progress, to avoid duplicate output.
                r.report(self.domain.timestepping_statistics())
        # Skip initial yieldstep for all future loops, to avoid rerunning
        # yieldsteps at "restart".
        self.skip_step = True
        # Enforces wait time for all cores so they catch up to each other when
        # this is called (ignored if not parallel).
        self.barrier()

    def cleanup(self):
        """Merge parallel output files (if any) and finalize the run."""
        if self.is_parallel():
            # Combine the per-processor .sww files into a single file.
            self.domain.sww_merge(delete_old=True)
        self.finalize()

    def get_refdate(self):
        """Return the simulation reference date.

        Hardcoded for now, because DFM requires this in its input file.
        TODO: fix or remove.
        """
        return datetime(2001, 1, 1)

    def get_cell_count(self):
        """Return the number of (centroid) cells in the domain."""
        return len(self.get_elevation())

    def get_elevation(self):
        """Return bed elevation at cell centroids."""
        return self.domain.quantities["elevation"].centroid_values

    def get_velocity_and_depth(self):
        """Return depth-averaged velocity magnitude and water depth.

        Returns
        -------
        tuple of numpy.ndarray
            ``(velocity, depth)`` at cell centroids.  Velocities are zeroed
            where the depth is below ``H_LIM_VELOCITY`` to avoid spurious
            values from dividing small momenta by near-zero depths.
        """
        stage = self.domain.quantities["stage"].centroid_values
        depth = stage - self.get_elevation()
        xmom = self.domain.quantities["xmomentum"].centroid_values
        ymom = self.domain.quantities["ymomentum"].centroid_values
        # Convert depth-averaged momentum to velocity.
        with np.errstate(divide="ignore", invalid="ignore"):
            # Ignore velocities where depth is insufficient.
            xvel = np.where(depth < H_LIM_VELOCITY, 0., xmom/depth)
            yvel = np.where(depth < H_LIM_VELOCITY, 0., ymom/depth)
        velocity = np.sqrt(xvel**2 + yvel**2)
        return velocity, depth

    def get_vegetation(self):
        """Return vegetation state (density, diameter, height) per cell.

        Pulled directly from the Baptist operator; requires that
        :meth:`initialize` was called with vegetation present.
        """
        stemdensity = self.Baptist.veg_density.centroid_values
        stemdiameter = self.Baptist.veg_diameter.centroid_values
        stemheight = self.Baptist.veg_height.centroid_values
        return stemdensity, stemdiameter, stemheight

    def set_vegetation(self, stemdensity, stemdiameter, stemheight):
        """Push new vegetation quantities into the Baptist operator."""
        self.Baptist.set_vegetation(veg_diameter=stemdiameter,
                                    veg_density=stemdensity,
                                    veg_height=stemheight,
                                    )

    # --------------------------------------------------------
    # Parallel methods
    # --------------------------------------------------------

    def get_rank(self):
        """Return this process's MPI rank (0 when not parallel)."""
        return self.myid

    def is_parallel(self):
        """Return True when running on more than one processor."""
        return self.numprocs > 1

    def merge_parallel_veg(self, OutputManager):
        """Merge per-processor vegetation output files into global files.

        Reads the output index JSON, maps each processor's local cell values
        onto the global triangulation using the ``tri_l2g`` / ``tri_full_flag``
        data in the per-processor .sww files, and re-saves merged NetCDF
        files through ``OutputManager``.  The per-processor .nc files are
        deleted after merging.

        Parameters
        ----------
        OutputManager : object
            Must expose ``veg_dir`` and ``save_netcdf(...)``.

        Raises
        ------
        FileNotFoundError
            If a per-processor file listed in the index is missing.
        """
        outputdir = OutputManager.veg_dir
        # Load in file that tracks all output files.
        with open(outputdir / "_cohort_files_ets_index.json", "r") as f:
            file_index = json.load(f)
        sww_name = self.domain.get_name()
        # Strip the "_0" rank suffix (works because myid == 0 for the call
        # to this method).
        base_name = sww_name[:-2]
        sww_dir = Path(self.domain.get_datadir())
        sww_file_list = [f"{sww_dir}/{base_name}_{p}.sww" for p in range(self.numprocs)]
        # Collect local-to-global triangle maps and "full triangle" flags.
        tri_l2g, tri_full_flag = [], []
        for sww_file in sww_file_list:
            with xr.open_dataset(sww_file) as sww:
                tri_l2g.append(sww["tri_l2g"].values)
                tri_full_flag.append(sww["tri_full_flag"].values)
        n_global = int(max(tri.max() for tri in tri_l2g) + 1)
        # Local indices of full (non-ghost) triangles per processor, and the
        # corresponding global triangle ids.
        f_ids = [np.where(tri_full_flag[p] == 1)[0] for p in range(self.numprocs)]
        f_gids = [tri_l2g[p][f_ids[p]] for p in range(self.numprocs)]
        for year in file_index:
            for ets in file_index[year]:
                for fname in file_index[year][ets]:
                    local_data = []
                    c_merged_attrs = None
                    for p in range(self.numprocs):
                        c_file_sub = outputdir / f"{fname}_proc{p}.nc"
                        if not c_file_sub.exists():
                            msg = (f"No individual processor (proc) output files found with name "
                                   f"'{fname}_proc{p}.nc', but filename was found in the output "
                                   "index file.")
                            r.report(msg, level="ERROR")
                            raise FileNotFoundError(msg)
                        with xr.open_dataset(c_file_sub) as c_sub:
                            if p == 0:
                                # Keep the first processor's attributes for the
                                # merged dataset.
                                c_merged_attrs = dict(c_sub.attrs)
                            local_data.append(
                                {k: v.values for k, v in c_sub.data_vars.items()}
                            )
                        # Per-processor file is no longer needed once read.
                        c_file_sub.unlink()
                    c_merged = self.merge_local_to_global(local_data=local_data,
                                                          f_gids=f_gids,
                                                          f_ids=f_ids,
                                                          n_global=n_global,
                                                          )
                    OutputManager.save_netcdf(
                        outputdir,
                        fname,
                        c_merged,
                        saved_attrs=c_merged_attrs
                    )

    @staticmethod
    def merge_local_to_global(local_data: list[dict[str, np.ndarray]],
                              f_gids: list[np.ndarray],
                              f_ids: list[np.ndarray],
                              n_global: int,
                              ) -> dict[str, np.ndarray]:
        """
        Merge per-processor local arrays into global arrays using
        local-to-global index mappings.

        Parameters
        ----------
        local_data : list of dict
            One dict per processor mapping variable name -> local values.
        f_gids : list of numpy.ndarray
            Global triangle ids of each processor's full triangles.
        f_ids : list of numpy.ndarray
            Local indices of each processor's full triangles.
        n_global : int
            Total number of global triangles; arrays are zero-initialized
            at this length, so cells not written by any processor stay 0.

        Returns
        -------
        dict of numpy.ndarray
            Variable name -> merged global array of length ``n_global``.
        """
        merged: dict[str, np.ndarray] = {}
        for p, local_vars in enumerate(local_data):
            for key, values in local_vars.items():
                if key not in merged:
                    merged[key] = np.zeros(n_global, dtype=float)
                merged[key][f_gids[p]] = values[f_ids[p]]
        return merged