# Instructions for developers:
#
# This is the main object of OceanSpy.
# All attributes are stored as global attributes (strings!) of the xr.Dataset.
# When users request an attribute, it is decoded from the global attributes.
# Thus, there are custom attribute setters (class setters are inhibited).
#
# There are private and public objects.
# Private objects use OceanSpy's reference aliases (_ds, _grid),
# while public objects are mirrors of the private objects using custom aliases.
#
# All functions in other modules that operate on od,
# must be added here in the shortcuts section.
#
# Add new attributes/methods in docs/api.rst
##############################################################################
# TODO: create a list of OceanSpy names and add a link under aliases.
# TODO: create a dictionary with parameter descriptions and add under aliases.
# TODO: add more xgcm options. E.g., default boundary method.
# TODO: implement xgcm autogenerate in _set_coords,
# set_grid_coords, set_coords when released
# TODO: Use the coords parameter to create the xgcm grid instead of
# _create_grid.
# We will pass a dictionary to xgcm.Grid,
# and we can have the option of using comodo attributes
# (currently cleaned up, so switched off).
##############################################################################
import copy as _copy
import sys as _sys
import warnings as _warnings
from collections import OrderedDict as _OrderedDict
import numpy as _np
# Required dependencies (private)
import xarray as _xr
# From OceanSpy (private)
from . import utils as _utils
from ._ospy_utils import (
_check_instance,
_check_list_of_string,
_check_oceanspy_axes,
_create_grid,
_rename_coord_attrs,
_setter_error_message,
)
from .animate import _animateMethods
from .compute import _computeMethods
from .plot import _plotMethods
from .subsample import _subsampleMethods
# Recommended dependencies (private)
try:
import cartopy.crs as _ccrs
except ImportError: # pragma: no cover
pass
try:
from scipy import spatial as _spatial
except ImportError: # pragma: no cover
pass
try:
from dask.diagnostics import ProgressBar as _ProgressBar
except ImportError: # pragma: no cover
pass
[docs]
class OceanDataset:
"""
OceanDataset combines a :py:obj:`xarray.Dataset`
with other objects used by OceanSpy (e.g., xgcm.Grid).
Additional objects are attached to the
:py:obj:`xarray.Dataset` as global attributes.
"""
[docs]
def __init__(self, dataset):
    """
    Wrap a copy of an xarray.Dataset in an OceanDataset.

    Parameters
    ----------
    dataset: xarray.Dataset
        The multi-dimensional, in memory, array database.
        A copy (``dataset.copy()``) is stored privately.

    References
    ----------
    http://xarray.pydata.org/en/stable/generated/xarray.Dataset.html
    """
    # Validate the argument type first
    _check_instance({"dataset": dataset}, "xarray.Dataset")

    # Private dataset: OceanSpy attributes are kept in its global attrs
    self._ds = dataset.copy()

    # Rename variables with custom names to OceanSpy reference names
    # (_apply_aliases updates self._ds in place and returns self)
    self._apply_aliases()
def __copy__(self):
    """
    Shallow copy: build a new OceanDataset from a copy of the
    public (custom-named) dataset.
    """
    duplicate = self.dataset.copy()
    return OceanDataset(dataset=duplicate)
def __repr__(self):
    """
    Return a multi-line text summary of the OceanDataset.

    The summary lists the main attributes (dataset, grid, projection),
    each shown by the leading ``<...>`` part of its own repr,
    followed by the types/values of the additional attributes
    that are set.
    """

    def _type_tag(obj):
        # Keep only the text between the first "<" and the first ">"
        # of the object's repr (both included).
        text = repr(obj)
        return text[text.find("<") : text.find(">") + 1]

    # Every property below is rebuilt on each access
    # (the dataset is copied and renamed, the grid is re-created),
    # so evaluate each of them once only.
    main_info = ["<oceanspy.OceanDataset>", "\nMain attributes:"]
    main_info.append(" .dataset: %s" % _type_tag(self.dataset))

    grid = self.grid
    if grid is not None:
        main_info.append(" .grid: %s" % _type_tag(grid))

    projection = self.projection
    if projection is not None:
        main_info.append(" .projection: %s" % _type_tag(projection))

    more_info = ["\n\nMore attributes:"]

    name = self.name
    if name:
        more_info.append(" .name: %s" % name)

    description = self.description
    if description:
        more_info.append(" .description: %s" % description)

    # Parameters always exist (defaults are used when none are set)
    more_info.append(" .parameters: %s" % type(self.parameters))

    aliases = self.aliases
    if aliases:
        more_info.append(" .aliases: %s" % type(aliases))

    grid_coords = self.grid_coords
    if grid_coords:
        more_info.append(" .grid_coords: %s" % type(grid_coords))

    grid_periodic = self.grid_periodic
    if grid_periodic:
        more_info.append(" .grid_periodic: %s" % type(grid_periodic))

    face_connections = self.face_connections
    if face_connections:
        more_info.append(" .face_connections: %s" % type(face_connections))

    return "\n".join(main_info) + "\n".join(more_info)
# ===========
# ATTRIBUTES
# ===========
# -------------------
# name
# -------------------
@property
def name(self):
"""
Name of the OceanDataset.
"""
name = self._read_from_global_attr("name")
return name
@name.setter
def name(self, name):
"""
Inhibit setter.
"""
raise AttributeError(_setter_error_message("name"))
[docs]
def set_name(self, name, overwrite=None):
    """
    Set name of the OceanDataset.

    Parameters
    ----------
    name: str
        Name of the OceanDataset.
    overwrite: bool or None
        If None, raises error if name has been previously set.
        If True, overwrite previous name.
        If False, combine with previous name.

    Returns
    -------
    The object returned by ``_store_as_global_attr``
    (the name is stored as a dataset global attribute).
    """
    # Only strings are accepted
    _check_instance({"name": name}, "str")

    # Store and hand back the result
    return self._store_as_global_attr(name="name", attr=name, overwrite=overwrite)
# -------------------
# description
# -------------------
@property
def description(self):
"""
Description of the OceanDataset.
"""
description = self._read_from_global_attr("description")
return description
@description.setter
def description(self, description):
"""
Inhibit setter.
"""
raise AttributeError(_setter_error_message("description"))
[docs]
def set_description(self, description, overwrite=None):
    """
    Set description of the OceanDataset.

    Parameters
    ----------
    description: str
        Description of the OceanDataset.
    overwrite: bool or None
        If None, raises error if description has been previously set.
        If True, overwrite previous description.
        If False, combine with previous description.

    Returns
    -------
    The object returned by ``_store_as_global_attr``
    (the description is stored as a dataset global attribute).
    """
    # Only strings are accepted
    _check_instance({"description": description}, "str")

    # Store and hand back the result
    stored = self._store_as_global_attr(
        name="description", attr=description, overwrite=overwrite
    )
    return stored
def __getitem__(self, key):
return self._ds[key]
# -------------------
# aliases
# -------------------
@property
def aliases(self):
"""
A dictionary to connect custom variable names
to OceanSpy reference names.
Keys are OceanSpy reference names, values are custom names:
{'ospy_name': 'custom_name'}
"""
aliases = self._read_from_global_attr("aliases")
return aliases
@property
def _aliases_flipped(self):
"""
Flip aliases:
From {'ospy_name': 'custom_name'}
to {'custom_name': 'ospy_name'}
"""
if self.aliases:
aliases_flipped = {custom: ospy for ospy, custom in self.aliases.items()}
else:
return self.aliases
return aliases_flipped
@aliases.setter
def aliases(self, aliases):
    """
    Inhibit setter: direct assignment always raises AttributeError.
    """
    message = _setter_error_message("aliases")
    raise AttributeError(message)
[docs]
def set_aliases(self, aliases, overwrite=None):
    """
    Set aliases to connect custom variables names
    to OceanSpy reference names.

    Parameters
    ----------
    aliases: dict
        Keys are OceanSpy names, values are custom names:
        {'ospy_name': 'custom_name'}
    overwrite: bool or None
        If None, raises error if aliases has been previously set.
        If True, overwrite previous aliases.
        If False, combine with previous aliases.

    Returns
    -------
    The object returned by ``_store_as_global_attr``,
    after the aliases have been applied to its private dataset.
    """
    # Only dictionaries are accepted
    _check_instance({"aliases": aliases}, "dict")

    # Store the aliases, then rename the variables accordingly
    updated = self._store_as_global_attr(
        name="aliases", attr=aliases, overwrite=overwrite
    )
    return updated._apply_aliases()
def _apply_aliases(self):
"""
Check if there are variables with custom name in _ds,
and rename to OceanSpy reference name
"""
if self._aliases_flipped:
aliases = {
custom: ospy
for custom, ospy in self._aliases_flipped.items()
if custom in self._ds.variables or custom in self._ds.dims
}
self._ds = self._ds.rename(aliases)
return self
# -------------------
# dataset
# -------------------
@property
def dataset(self):
"""
xarray.Dataset: A multi-dimensional, in memory, array database.
References
----------
http://xarray.pydata.org/en/stable/generated/xarray.Dataset.html
"""
# Show _ds with renamed variables.
dataset = self._ds.copy()
if self.aliases:
aliases = {
ospy: custom
for ospy, custom in self.aliases.items()
if ospy in self._ds or ospy in self._ds.dims
}
dataset = dataset.rename(aliases)
return dataset
@dataset.setter
def dataset(self, dataset):
"""
Inhibit setter.
"""
raise AttributeError(
"Set a new dataset using " "`oceanspy.OceanDataset(dataset)`"
)
# -------------------
# parameters
# -------------------
@property
def parameters(self):
    """
    A dictionary defining model parameters that are used by OceanSpy.
    Default values are used for parameters that have not been set
    (see :py:const:`oceanspy.DEFAULT_PARAMETERS`).

    A new dictionary is returned on every access, so modifying the
    result never alters :py:const:`oceanspy.DEFAULT_PARAMETERS`.
    """
    from oceanspy import DEFAULT_PARAMETERS

    stored = self._read_from_global_attr("parameters")
    if stored is None:
        stored = {}

    # Always build a fresh dict: returning DEFAULT_PARAMETERS itself would
    # let callers mutate the module-level defaults by accident.
    return {**DEFAULT_PARAMETERS, **stored}

@parameters.setter
def parameters(self, parameters):
    """
    Inhibit setter: direct assignment always raises AttributeError.
    """
    raise AttributeError(_setter_error_message("parameters"))
[docs]
def set_parameters(self, parameters):
    """
    Set model parameters used by OceanSpy.
    See :py:const:`oceanspy.DEFAULT_PARAMETERS` for a list of parameters,
    :py:const:`oceanspy.TYPE_PARAMETERS` for their expected types,
    and :py:const:`oceanspy.AVAILABLE_PARAMETERS` for a list of parameters
    with predefined options.

    Parameters
    ----------
    parameters: dict
        {'name': value}

    Raises
    ------
    TypeError
        If a known parameter has a value of the wrong type.
    ValueError
        If a parameter with predefined options gets a different value.

    Notes
    -----
    Keys that are not OceanSpy parameters only trigger a warning.
    The stored parameters replace any previously stored ones
    (``overwrite=True``).
    """
    from oceanspy import AVAILABLE_PARAMETERS, DEFAULT_PARAMETERS, TYPE_PARAMETERS

    # Only dictionaries are accepted
    _check_instance({"parameters": parameters}, "dict")

    # Validate every item; collect the unknown keys for a single warning
    unknown = []
    for key, value in parameters.items():
        if key not in DEFAULT_PARAMETERS:
            unknown.append(key)
            continue
        if not isinstance(value, TYPE_PARAMETERS[key]):
            raise TypeError(
                "Invalid [{}]. Check oceanspy.TYPE_PARAMETERS".format(key)
            )
        has_options = key in AVAILABLE_PARAMETERS
        if has_options and value not in AVAILABLE_PARAMETERS[key]:
            raise ValueError(
                "Requested [{}] not available. "
                "Check oceanspy.AVAILABLE_PARAMETERS".format(key)
            )
    if unknown:
        _warnings.warn(
            "{} are not OceanSpy parameters".format(unknown), stacklevel=2
        )

    # Store
    return self._store_as_global_attr(
        name="parameters", attr=parameters, overwrite=True
    )
# -------------------
# grid_coords
# -------------------
@property
def grid_coords(self):
"""
Grid coordinates used by :py:obj:`xgcm.Grid`.
References
----------
https://xgcm.readthedocs.io/en/stable/grids.html#Grid-Metadata
"""
grid_coords = self._read_from_global_attr("grid_coords")
return grid_coords
@grid_coords.setter
def grid_coords(self, grid_coords):
"""
Inhibit setter.
"""
raise AttributeError(_setter_error_message("grid_coords"))
[docs]
def set_grid_coords(self, grid_coords, add_midp=False, overwrite=None):
    """
    Set grid coordinates used by :py:obj:`xgcm.Grid`.

    Parameters
    ----------
    grid_coords: dict
        Grid coordinates used by :py:obj:`xgcm.Grid`.
        Keys are axes, and values are dict with
        key=dim and value=c_grid_axis_shift.
        Available c_grid_axis_shift are {0.5, None, -0.5}.
        E.g., {'Y': {'Y': None, 'Yp1': 0.5}}
        See :py:const:`oceanspy.OCEANSPY_AXES` for a list of axes
    add_midp: bool
        If true, add inner dimension (mid points)
        to axes with outer dimension only.
        The new dimension will be named
        as the outer dimension + '_midp'
    overwrite: bool or None
        If None, raises error if grid_coords has been previously set.
        If True, overwrite previous grid_coords.
        If False, combine with previous grid_coords.

    Returns
    -------
    The object returned by ``_store_as_global_attr``
    (``self`` is rebound to it before any variable is added).

    References
    ----------
    https://xgcm.readthedocs.io/en/stable/grids.html#Grid-Metadata
    """
    # Check parameters
    _check_instance(
        {"grid_coords": grid_coords, "add_midp": add_midp},
        {"grid_coords": "dict", "add_midp": "bool"},
    )
    # Check axes
    _check_oceanspy_axes(list(grid_coords.keys()))
    # Set grid_coords
    # (from here on `self` is whatever _store_as_global_attr returned)
    self = self._store_as_global_attr(
        name="grid_coords", attr=grid_coords, overwrite=overwrite
    )
    if add_midp:
        # Collect only the axes that get a new mid-point dimension;
        # they are merged with the stored grid_coords after the loop.
        grid_coords = {}
        for axis in self.grid_coords:
            # check1: the axis has exactly one dimension
            # check2: that dimension has a shift (its value is not None)
            check1 = len(self.grid_coords[axis]) == 1
            check2 = list(self.grid_coords[axis].values())[0] is not None
            if check1 and check2:
                # Deal with aliases:
                # `dim` is the name used in grid_coords, `_dim` is the name
                # used in the private dataset. When `dim` is a custom name,
                # an alias for the new mid-point dimension is registered too.
                dim = list(self.grid_coords[axis].keys())[0]
                if self._aliases_flipped and dim in self._aliases_flipped:
                    _dim = self._aliases_flipped[dim]
                    self = self.set_aliases(
                        {_dim + "_midp": dim + "_midp"}, overwrite=False
                    )
                else:
                    _dim = dim
                # Midpoints are averages of outpoints:
                # every point but the last, plus half the distance
                # to the following point.
                midp = (
                    self._ds[_dim].values[:-1] + self._ds[_dim].diff(_dim) / 2
                ).rename({_dim: _dim + "_midp"})
                self._ds[_dim + "_midp"] = _xr.DataArray(
                    midp, dims=(_dim + "_midp")
                )
                # Propagate the descriptive attributes of the outer dimension
                if "units" in self._ds[_dim].attrs:
                    units = self._ds[_dim].attrs["units"]
                    self._ds[_dim + "_midp"].attrs["units"] = units
                if "long_name" in self._ds[_dim].attrs:
                    long_name = self._ds[_dim].attrs["long_name"]
                    long_name = "Mid-points of {}".format(long_name)
                    self._ds[_dim + "_midp"].attrs["long_name"] = long_name
                if "description" in self._ds[_dim].attrs:
                    desc = self._ds[_dim].attrs["description"]
                    desc = "Mid-points of {}".format(desc)
                    self._ds[_dim + "_midp"].attrs["description"] = desc
                # The new dimension has no shift (None)
                grid_coords[axis] = {**self.grid_coords[axis], dim + "_midp": None}
        # overwrite=False: combine with the grid_coords stored above
        self = self._store_as_global_attr(
            name="grid_coords", attr=grid_coords, overwrite=False
        )
    return self
# -------------------
# grid_periodic
# -------------------
@property
def grid_periodic(self):
"""
List of :py:obj:`xgcm.Grid` axes that are periodic.
"""
grid_periodic = self._read_from_global_attr("grid_periodic")
if not grid_periodic:
grid_periodic = []
return grid_periodic
@grid_periodic.setter
def grid_periodic(self, grid_periodic):
"""
Inhibit setter.
"""
raise AttributeError(_setter_error_message("grid_periodic"))
[docs]
def set_grid_periodic(self, grid_periodic):
    """
    Set grid axes that will be treated as periodic by :py:obj:`xgcm.Grid`.
    Axes that are not set periodic are non-periodic by default.

    Parameters
    ----------
    grid_periodic: list
        List of periodic axes.
        See :py:const:`oceanspy.OCEANSPY_AXES` for a list of axes

    Returns
    -------
    The object returned by ``_store_as_global_attr``.
    """
    # Must be a list of valid OceanSpy axes
    _check_instance({"grid_periodic": grid_periodic}, "list")
    _check_oceanspy_axes(grid_periodic)

    # Always overwrite (original note: the xgcm default is
    # all axes periodic, so the list given here is taken as complete).
    return self._store_as_global_attr(
        name="grid_periodic", attr=grid_periodic, overwrite=True
    )
# -----------------
# face_connections
# -----------------
@property
def face_connections(self):
"""
Defines the topology of the grid used by :py:obj:`xgcm.Grid`.
References
----------
"""
face_connections = self._read_from_global_attr("face_connections")
return face_connections
@face_connections.setter
def face_connections(self, face_connections):
"""
Inhibit setter.
"""
raise AttributeError(_setter_error_message("face_connections"))
def set_face_connections(self, face_connections):
    """
    Set face connections that define the grid topology that gets read by
    :py:obj:`xgcm.Grid`

    Parameters
    ----------
    face_connections: dict
        Dictionary with the connections of each face along each direction.
        When its first key is 'face', every connection that is not already
        a tuple is evaluated (it is expected to be the string
        representation of the connection).
        When its first key is None, the face connections are reset.
        The dictionary passed in is not modified.

    Returns
    -------
    The object returned by ``_store_as_global_attr``
    (previous face connections are always overwritten).
    """
    # check parameters
    _check_instance({"face_connections": face_connections}, "dict")

    # An empty dictionary has no first key: store it unchanged
    # instead of failing with an IndexError.
    keys = list(face_connections)
    first_key = keys[0] if keys else "face_connections_is_empty"

    if keys and first_key == "face":
        # Build new dictionaries so that the caller's input is left untouched
        decoded = {}
        for face, axes in face_connections["face"].items():
            decoded[face] = {}
            for axis, connection in axes.items():
                if not isinstance(connection, tuple):
                    # SECURITY NOTE: eval executes arbitrary code.
                    # Only pass face connections from trusted sources.
                    connection = eval(connection)
                decoded[face][axis] = connection
        face_connections = {**face_connections, "face": decoded}
    elif keys and first_key is None:
        # NOTE(review): None is stored below as the string 'None'
        # (attributes are stored with str()); confirm readers handle it.
        face_connections = None

    return self._store_as_global_attr(
        name="face_connections", attr=face_connections, overwrite=True
    )
# -------------------
# grid
# -------------------
@property
def grid(self):
    """
    :py:obj:`xgcm.Grid`: A collection of axes,
    which is a group of coordinates that all lie
    along the same physical dimension
    but describe different positions relative to a grid cell.
    The grid is created from a copy of the public dataset
    every time the property is accessed.

    References
    ----------
    https://xgcm.readthedocs.io/en/stable/api.html#Grid
    """
    return _create_grid(
        self.dataset.copy(),
        self.grid_coords,
        self.grid_periodic,
        self.face_connections,
    )
@property
def _grid(self):
    """
    :py:obj:`xgcm.Grid` with OceanSpy reference names.
    The grid is created from a copy of the private dataset
    every time the property is accessed.
    """
    aliases = self.aliases
    coords = self.grid_coords
    if aliases and coords:
        # custom name -> OceanSpy reference name
        to_ospy = {custom: ospy for ospy, custom in aliases.items()}
        # Replace custom dimension names with reference names
        for axis_dims in coords.values():
            for dim in list(axis_dims):
                if dim in to_ospy:
                    axis_dims[to_ospy[dim]] = axis_dims.pop(dim)

    return _create_grid(
        self._ds.copy(), coords, self.grid_periodic, self.face_connections
    )
@grid.setter
def grid(self, grid):
"""
Inhibit setter.
"""
raise AttributeError(
"Set a new grid using " ".set_grid_coords and .set_periodic"
)
@_grid.setter
def _grid(self, grid):
"""
Inhibit setter.
"""
raise AttributeError(
"Set a new _grid using " ".set_grid_coords and .set_periodic"
)
# -------------------
# projection
# -------------------
@property
def projection(self):
"""
Cartopy projection of the OceanDataset.
"""
projection = self._read_from_global_attr("projection")
if projection:
if projection == "None":
projection = eval(projection)
else:
if "cartopy" not in _sys.modules: # pragma: no cover
_warnings.warn(
"cartopy is not available," " so projection is None",
stacklevel=2,
)
projection = None
else:
projection = eval("_ccrs.{}".format(projection))
return projection
@projection.setter
def projection(self, projection):
"""
Inhibit setter.
"""
raise AttributeError(_setter_error_message("projection"))
[docs]
def set_projection(self, projection, **kwargs):
    """
    Set Cartopy projection of the OceanDataset.

    Parameters
    ----------
    projection: str or None
        Cartopy projection of the OceanDataset.
        Use None to remove projection.
    **kwargs:
        Keyword arguments for the projection.
        E.g., central_longitude=0.0 for PlateCarree

    Raises
    ------
    ValueError
        If `projection` is not an attribute of cartopy.crs.

    References
    ----------
    https://scitools.org.uk/cartopy/docs/latest/crs/projections.html
    """
    if projection is None:
        # Removal is stored as the string "None"
        encoded = str(projection)
    else:
        _check_instance({"projection": projection}, "str")
        if not hasattr(_ccrs, projection):
            raise ValueError("{} is not a cartopy projection".format(projection))
        # Store the constructor call; it is evaluated when read back
        encoded = "{}(**{})".format(projection, kwargs)

    # Previous projections are always overwritten
    return self._store_as_global_attr(
        name="projection", attr=encoded, overwrite=True
    )
# ===========
# METHODS
# ===========
[docs]
def create_tree(self, grid_pos="C"):
    """
    Create a scipy.spatial.cKDTree for quick nearest-neighbor lookup.

    Parameters
    ----------
    grid_pos: str
        Grid position. Options: {'C', 'G', 'U', 'V'}

    Returns
    -------
    tree: scipy.spatial.cKDTree
        Return a cKDTree object that can be used to query a point.

    References
    ----------
    | cKDTree:
      https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.cKDTree.html
    | Grid:
      https://mitgcm.readthedocs.io/en/latest/algorithm/horiz-grid.html
    """
    # Check parameters
    _check_instance({"grid_pos": grid_pos}, "str")
    valid_positions = ["C", "G", "U", "V"]
    if grid_pos not in valid_positions:
        raise ValueError(
            "`grid_pos` must be one of {}:"
            "\nhttps://mitgcm.readthedocs.io"
            "/en/latest/algorithm/horiz-grid.html"
            "".format(valid_positions)
        )

    # Horizontal coordinates at the requested position
    lat = self._ds["Y" + grid_pos]
    lon = self._ds["X" + grid_pos]

    # With a sphere radius the coordinates are converted to cartesian;
    # otherwise they are used as they are, with z set to zero.
    radius = self.parameters["rSphere"]
    if radius:
        x, y, z = _utils.spherical2cartesian(Y=lat, X=lon, R=radius)
    else:
        x, y, z = lon, lat, _xr.zeros_like(lat)

    # Flatten each component into a list of points, filling NaNs
    nan_fill = 777777
    columns = [
        component.stack(points=component.dims).fillna(nan_fill).data
        for component in (x, y, z)
    ]

    # Construct KD-tree
    return _spatial.cKDTree(_np.column_stack(columns))
[docs]
def merge_into_oceandataset(self, obj, overwrite=False):
    """
    Merge a Dataset or DataArray into the OceanDataset.

    Parameters
    ----------
    obj: xarray.DataArray or xarray.Dataset
        object to merge. Its coordinates are dropped before merging.
    overwrite: bool
        If True, overwrite existing DataArrays with same name.
        If False, variables that already exist are not merged
        (a warning lists them).

    Returns
    -------
    od: OceanDataset
        New OceanDataset built from the merged dataset.

    Raises
    ------
    TypeError
        If `obj` is neither a DataArray nor a Dataset.
    ValueError
        If `obj` is a DataArray without a name.
    """
    # Check and make dataset
    if not isinstance(obj, (_xr.DataArray, _xr.Dataset)):
        raise TypeError("`obj` must be xarray.DataArray or xarray.Dataset")
    _check_instance({"overwrite": overwrite}, "bool")

    obj = obj.drop_vars(obj.coords)
    if isinstance(obj, _xr.DataArray):
        # A name is needed to turn the array into a dataset variable
        if obj.name is None:
            raise ValueError(
                "xarray.DataArray doesn't have a name." "Set it using da.rename()"
            )
        obj = obj.to_dataset()

    # Variables that are already in the OceanDataset
    dataset = self.dataset
    clashing = [var for var in obj.variables if var in dataset]
    if overwrite is False:
        obj = obj.drop_vars(clashing)
        if clashing:
            _warnings.warn(
                "{} will not be merged."
                "\nSet `overwrite=True` if you wish otherwise."
                "".format(clashing),
                stacklevel=2,
            )
    elif clashing:
        _warnings.warn("{} will be overwritten.".format(clashing), stacklevel=2)

    for var in obj.data_vars:
        # Store dimension attributes that get lost
        saved_attrs = {}
        for dim in obj[var].dims:
            if dim not in dataset.dims:
                continue
            paired = zip(obj[dim].attrs.items(), dataset[dim].attrs.items())
            if all(new == old for new, old in paired):
                saved_attrs[dim] = dataset[dim].attrs

        # Merge
        dataset[var] = obj[var]

        # Add attributes
        for dim, attrs in saved_attrs.items():
            dataset[dim].attrs = attrs

    return OceanDataset(dataset)
[docs]
def to_netcdf(self, path, **kwargs):
    """
    Write contents to a netCDF file.

    Parameters
    ----------
    path: str
        Path to which to save.
    **kwargs:
        Keyword arguments for :py:func:`xarray.Dataset.to_netcdf()`.
        When `compute` is omitted, None, or False, the write is created
        as a delayed object and computed right away under a progress bar;
        any other value is forwarded to xarray as is.

    References
    ----------
    http://xarray.pydata.org/en/stable/generated/xarray.Dataset.to_netcdf.html
    """
    # Check parameters
    _check_instance({"path": path}, "str")

    # to_netcdf doesn't like coordinates attribute
    dataset = _rename_coord_attrs(self.dataset)

    compute = kwargs.pop("compute", None)
    print("Writing dataset to [{}].".format(path))

    if compute is not None and compute is not False:
        # Let xarray handle the computation
        dataset.to_netcdf(path, compute=compute, **kwargs)
        return

    # Delayed write, computed here with a progress bar
    delayed_write = dataset.to_netcdf(path, compute=False, **kwargs)
    with _ProgressBar():
        delayed_write.compute()
[docs]
def to_zarr(self, path, **kwargs):
    """
    Write contents to a zarr group.

    Parameters
    ----------
    path: str
        Path to which to save.
    **kwargs:
        Keyword arguments for :py:func:`xarray.Dataset.to_zarr()`.
        When `compute` is omitted, None, or False, the write is created
        as a delayed object and computed right away under a progress bar;
        any other value is forwarded to xarray as is.

    References
    ----------
    http://xarray.pydata.org/en/stable/generated/xarray.Dataset.to_zarr.html
    """
    # Check parameters
    _check_instance({"path": path}, "str")

    # to_zarr doesn't like coordinates attribute
    dataset = _rename_coord_attrs(self.dataset)

    compute = kwargs.pop("compute", None)
    print("Writing dataset to [{}].".format(path))

    if compute is not None and compute is not False:
        # Let xarray handle the computation
        dataset.to_zarr(path, compute=compute, **kwargs)
        return

    # Delayed write, computed here with a progress bar
    delayed_write = dataset.to_zarr(path, compute=False, **kwargs)
    with _ProgressBar():
        delayed_write.compute()
# ==================================
# IMPORT (used by open_oceandataset)
# ==================================
[docs]
def shift_averages(self, averageList=None):
    """
    Shift average variables to time_midp.
    Average variables are defined as
    variables with attribute [original_output='average'],
    or variables in averageList.

    Parameters
    ----------
    averageList: 1D array_like, str, or None
        List of variables (strings).

    Returns
    -------
    self, whose private dataset has been modified in place.
    """
    if averageList is not None:
        averageList = _check_list_of_string(averageList, "averageList")
    else:
        averageList = []
    for var in self._ds.data_vars:
        # Take the attribute out while the variable is processed;
        # it is put back at the end of the iteration if it was present.
        original_output = self._ds[var].attrs.pop("original_output", None)
        if original_output == "average" or var in averageList:
            # Drop the time coordinate and the first time step,
            # then move the variable onto the time_midp dimension.
            ds_tmp = self._ds[var].drop_vars("time").isel(time=slice(1, None))
            self._ds[var] = ds_tmp.rename({"time": "time_midp"})
        if original_output is not None:
            self._ds[var].attrs["original_output"] = original_output
    return self
[docs]
def manipulate_coords(
    self,
    fillna=False,
    coords1Dfrom2D=False,
    coords2Dfrom1D=False,
    coordsUVfromG=False,
):
    """
    Manipulate coordinates to make them compatible with OceanSpy.

    Parameters
    ----------
    fillna: bool
        If True, fill NaNs in 2D coordinates
        (e.g., NaNs are created by MITgcm exch2).
    coords1Dfrom2D: bool
        If True, infer 1D coordinates from 2D coordinates (mean of 2D).
        Use with rectilinear grid only.
    coords2Dfrom1D: bool
        If True, infer 2D coordinates from 1D coordinates (broadcast 1D).
    coordsUVfromG: bool
        If True, compute missing coords (U and V points) from G points.

    Returns
    -------
    od: OceanDataset
        A copy of the OceanDataset with manipulated coordinates.

    References
    ----------
    Grid:
    https://mitgcm.readthedocs.io/en/latest/algorithm/horiz-grid.html
    """
    # Copy because the dataset will change
    self = _copy.copy(self)

    # Coordinates are dimensions only
    self._ds = self._ds.reset_coords()

    # Fill nans (e.g., because of exch2)
    if fillna:
        coords = ["YC", "XC", "YG", "XG", "YU", "XU", "YV", "XV"]
        dims = ["X", "Y", "Xp1", "Yp1", "Xp1", "Y", "X", "Yp1"]
        for coord, dim in zip(coords, dims):
            if coord in self._ds.variables:
                ds_tmp = self._ds[coord].ffill(dim).bfill(dim).persist()
                self._ds[coord] = ds_tmp

    # Get U and V by rolling G
    if coordsUVfromG:
        for point_pos, dim2roll in zip(["U", "V"], ["Yp1", "Xp1"]):
            for dim in ["Y", "X"]:
                # Mean of two consecutive G points along dim2roll
                coord = self._ds[dim + "G"].rolling(**{dim2roll: 2})
                # `how` must be passed by keyword:
                # it is keyword-only in current xarray versions.
                coord = coord.mean().dropna(dim2roll, how="all")
                # "Yp1" -> "Y", "Xp1" -> "X"
                coord = coord.drop_vars(coord.coords).rename(
                    {dim2roll: dim2roll[0]}
                )
                self._ds[dim + point_pos] = coord
                if "units" in self._ds[dim + "G"].attrs:
                    units = self._ds[dim + "G"].attrs["units"]
                    self._ds[dim + point_pos].attrs["units"] = units

    # For cartesian grid we can use 1D coordinates
    if coords1Dfrom2D:
        # Take mean
        self._ds["Y"] = self._ds["YC"].mean("X", keep_attrs=True).persist()
        self._ds["X"] = self._ds["XC"].mean("Y", keep_attrs=True).persist()
        self._ds["Yp1"] = self._ds["YG"].mean("Xp1", keep_attrs=True).persist()
        self._ds["Xp1"] = self._ds["XG"].mean("Yp1", keep_attrs=True).persist()

    # Get 2D coordinates broadcasting 1D
    if coords2Dfrom1D:
        # Broadcast
        self._ds["YC"], self._ds["XC"] = _xr.broadcast(self._ds["Y"], self._ds["X"])
        self._ds["YG"], self._ds["XG"] = _xr.broadcast(
            self._ds["Yp1"], self._ds["Xp1"]
        )
        self._ds["YU"], self._ds["XU"] = _xr.broadcast(
            self._ds["Y"], self._ds["Xp1"]
        )
        self._ds["YV"], self._ds["XV"] = _xr.broadcast(
            self._ds["Yp1"], self._ds["X"]
        )

        # Add units
        dims2 = ["YC", "XC", "YG", "XG", "YU", "XU", "YV", "XV"]
        dims1 = ["Y", "X", "Yp1", "Xp1", "Y", "Xp1", "Yp1", "X"]
        for D2, D1 in zip(dims2, dims1):
            if "units" in self._ds[D1].attrs:
                self._ds[D2].attrs["units"] = self._ds[D1].attrs["units"]

    # Set 2D coordinates
    self._ds = self._ds.set_coords(["YC", "XC", "YG", "XG", "YU", "XU", "YV", "XV"])

    # Attributes (use xmitgcm).
    # Without xmitgcm no attribute is added.
    try:
        from xmitgcm import variables
    except ImportError:  # pragma: no cover
        return self

    if self.parameters["rSphere"] is None:
        # Cartesian grid: attributes of U and V points are added here
        coords = variables.horizontal_coordinates_cartesian
        locations = [("U", "u"), ("V", "v")]
        axis_info = {
            "X": ("plane_x_coordinate", "x coordinate", "m"),
            "Y": ("plane_y_coordinate", "y coordinate", "m"),
        }
    else:
        # Spherical grid: attributes of C, U, and V points are added here
        coords = variables.horizontal_coordinates_spherical
        locations = [("C", "T"), ("U", "u"), ("V", "v")]
        axis_info = {
            "X": ("longitude", "longitude", "degrees_east"),
            "Y": ("latitude", "latitude", "degrees_north"),
        }

    # Build the additional attributes, e.g. for XU:
    # standard_name="..._at_u_location", coordinate="YU XU"
    add_coords = _OrderedDict()
    for point_pos, location in locations:
        for axis in ("X", "Y"):
            standard_prefix, long_name, units = axis_info[axis]
            add_coords[axis + point_pos] = dict(
                attrs=dict(
                    standard_name="{}_at_{}_location".format(
                        standard_prefix, location
                    ),
                    long_name=long_name,
                    units=units,
                    coordinate="Y{0} X{0}".format(point_pos),
                )
            )

    # Entries of add_coords replace xmitgcm entries with the same name
    coords = _OrderedDict(list(coords.items()) + list(add_coords.items()))
    for var in coords:
        attrs = coords[var]["attrs"]
        for attr in attrs:
            # Never overwrite attributes that are already defined
            if attr not in self._ds[var].attrs:
                self._ds[var].attrs[attr] = attrs[attr]

    return self
# =====
# UTILS
# =====
def _store_as_global_attr(self, name, attr, overwrite):
    """
    Store an OceanSpy attribute as dataset global attribute.

    Parameters
    ----------
    name: str
        Name of the attribute. Attribute is stored as OceanSpy_+name.
    attr: str or dict
        Attribute to store (it is stored as ``str(attr)``).
    overwrite: bool or None
        If None, raises error if attr has been previously set.
        If True, overwrite previous attributes.
        If False, combine with previous attributes:
        dictionaries are merged (new keys win),
        anything else is joined as previous + '_' + new.

    Returns
    -------
    od: OceanDataset
        Copy of the OceanDataset with the attribute stored.

    Raises
    ------
    ValueError
        If the attribute has been previously set and `overwrite` is None.
    """
    # Attribute name
    name = "OceanSpy_" + name
    if overwrite is None and name in self._ds.attrs:
        raise ValueError(
            "[{}] has been previously set: "
            "`overwrite` must be bool"
            "".format(name.replace("OceanSpy_", ""))
        )

    # Copy because attributes are added to _ds
    self = _copy.copy(self)

    # Combine with the previous value if requested
    if not overwrite and name in self._ds.attrs:
        prev_attr = self._ds.attrs[name]
        # startswith/endswith (rather than indexing) so that an empty
        # previous value does not raise IndexError.
        if prev_attr.startswith("{") and prev_attr.endswith("}"):
            # SECURITY NOTE: eval executes arbitrary code.
            # Attributes of datasets from untrusted sources are not safe.
            attr = {**eval(prev_attr), **attr}
        else:
            attr = prev_attr + "_" + attr

    # Store
    self._ds.attrs[name] = str(attr)
    return self
def _read_from_global_attr(self, name):
"""
Read an OceanSpy attribute stored as dataset global attribute.
Parameters
----------
name: str
Name of the attribute.
Attribute is decoded from 'OceanSpy_'+name.
Returns
-------
attr: str or dict
Attribute that has been decoded.
"""
# Attribute name
name = "OceanSpy_" + name
# Check if attributes exists
if name not in self._ds.attrs:
return None
# Read attribute
attr = self._ds.attrs[name]
check_dict = attr[0] == "{" and attr[-1] == "}"
check_list = attr[0] == "[" and attr[-1] == "]"
if check_dict or check_list:
attr = eval(attr)
return attr
# ===========
# SHORTCUTS
# ===========
@property
def subsample(self):
    """
    Access :py:mod:`oceanspy.subsample` functions.
    """
    shortcut = _subsampleMethods(self)
    return shortcut
@property
def compute(self):
    """
    Access :py:mod:`oceanspy.compute` functions,
    and merge the computed Dataset into the OceanDataset.
    Set overwrite=True
    to overwrite DataArrays already existing in the OceanDataset.
    """
    shortcut = _computeMethods(self)
    return shortcut
@property
def plot(self):
    """
    Access :py:mod:`oceanspy.plot` functions.
    """
    shortcut = _plotMethods(self)
    return shortcut
@property
def animate(self):
    """
    Access :py:mod:`oceanspy.animate` functions.
    """
    shortcut = _animateMethods(self)
    return shortcut