diff --git a/intermediate/indexing/build_custom_index_1d.ipynb b/intermediate/indexing/build_custom_index_1d.ipynb new file mode 100644 index 00000000..fa6f50cd --- /dev/null +++ b/intermediate/indexing/build_custom_index_1d.ipynb @@ -0,0 +1,1172 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Creating a custom Xarray index \n", + "\n", + "This tutorial demonstrates [Xarray's](https://xarray.dev/) (relatively) new Flexible Index feature, which allows users to modify traditional Xarray indexes to add custom functionality. \n", + "\n", + "Indexes are an important element of any Xarray object. They facilitate the label-based indexing that makes Xarray a great tool for n-dimensional array data. Most Xarray objects are indexed with `pandas.Index` objects, which fit a wide range of use cases. However, these indexes also have limitations: \n", + "- All coordinate labels must be explicitly loaded in memory, \n", + "- It can be difficult to fit irregularly-sampled data within the `pandas.Index` structure, \n", + "- There is no built-in support for dimensions that require additional metadata. \n", + "\n", + "Xarray's flexible index feature allows users to define their own indexes and add them to Xarray objects. A few examples of situations where this is useful are: \n", + "- Periodic index, for datasets with periodic dimensions (such as longitude). \n", + "- Unit-aware index (see the [Pint] project) \n", + "- An index for coordinates described by a function rather than an array \n", + "- An index that can handle a 2D rotation \n", + "(add links to those that have references/examples out there) \n", + "\n", + "## Overview\n", + "We will focus on the following example: \n", + "- We have a 1-dimensional `xarray.Dataset` indexed in a given coordinate system. However, we frequently want to query the dataset from a different coordinate reference system. 
\n", + "- Information describing the transformation between the two coordinate systems is stored as an attribute of the Xarray object. In this example, we use a simple, multiplicative transform.\n", + "- We want to define a custom index that will handle the coordinate transformation. This is a simplified analog of a common scenario: a geospatial dataset is in a given coordinate reference system and you would like to query it in another coordinate system.\n", + "- link to existing documentation\n", + "\n", + "We start by defining a very simple index and then increase the complexity by adding more functionality.\n", + "\n", + "![coord transform](img2.png)\n", + "\n", + "## Learning goals\n", + "This notebook shows how to build a custom Xarray index and assign it to an Xarray object using [`Dataset.set_xindex()`](https://docs.xarray.dev/en/stable/generated/xarray.Dataset.set_xindex.html). After working through this tutorial, users should expect to understand:\n", + "- How to define a custom Xarray index \n", + "- How to add a custom index to an existing Xarray object using `Dataset.set_xindex()`\n", + "- The different components of an Xarray index and their function.\n", + "- How to implement Xarray index methods such as `.sel()` and the methods that handle alignment.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import numpy as np\n", + "from collections.abc import Sequence\n", + "from copy import deepcopy\n", + "\n", + "\n", + "from xarray import Index\n", + "from xarray.core.indexes import PandasIndex\n", + "from xarray.core.indexing import merge_sel_results\n", + "from xarray.core.indexes import Index, PandasIndex, get_indexer_nd\n", + "from xarray.core.indexing import merge_sel_results" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "### Define sample data\n", + "First, we 
define a sample dataset to work with. The functions below define parameters that are used to generate an Xarray dataset with a data variable that exists at arbitrary coordinates along an `x` dimension. It also has a scalar variable, `spatial_ref`, where metadata describing the coordinate transform is stored as an attribute. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "def make_kwargs(factor, range_ls, data_len):\n", + " \"\"\"\n", + " Create keyword arguments for a function.\n", + "\n", + " Parameters\n", + " ----------\n", + " factor : int or float\n", + " Multiplicative factor for coordinate transform\n", + " range_ls : list\n", + " Range describing the x-coordinate\n", + " data_len : int\n", + " Length of dataset along x-dim\n", + "\n", + " Returns\n", + " -------\n", + " dict\n", + "\n", + " \"\"\"\n", + " da_kwargs = {\n", + " 'factor': factor,\n", + " 'range': range_ls,\n", + " 'idx_name': 'x',\n", + " 'real_name': 'lon',\n", + " 'data_len': data_len,\n", + " }\n", + " return da_kwargs\n", + "\n", + "\n", + "def create_sample_data(kwargs: dict) -> xr.Dataset:\n", + " \"\"\"\n", + " Function to create sample data.\n", + "\n", + " Parameters\n", + " ----------\n", + " kwargs : dict\n", + " A dictionary generated from make_kwargs() containing the following key-value pairs:\n", + " - 'factor' (float): A multiplicative factor.\n", + " - 'range' (tuple): A tuple specifying the range of the x-coordinate.\n", + " - 'idx_name' (str): The name of the coordinate reference system A.\n", + " - 'real_name' (str): The name of the coordinate reference system B.\n", + "\n", + " Returns\n", + " -------\n", + " xr.Dataset\n", + " An Xarray dataset containing the sample data.\n", + "\n", + " Notes\n", + " -----\n", + " This function creates an Xarray dataset with random data. 
The dimensions and coordinates of the dataset are specified by the input arguments.\n", + "\n", + " Example\n", + " -------\n", + " >>> kwargs = {\n", + " ... 'factor': 2.0,\n", + " ... 'range': (0, 10, 1),\n", + " ... 'idx_name': 'x',\n", + " ... 'real_name': 'lon', 'data_len': 10,\n", + " ... }\n", + " >>> dataset = create_sample_data(kwargs)\n", + " \"\"\"\n", + " attrs = {\n", + " 'factor': kwargs['factor'],\n", + " 'range': kwargs['range'],\n", + " 'idx_name': kwargs['idx_name'],\n", + " 'real_name': kwargs['real_name'],\n", + " }\n", + "\n", + " da = xr.DataArray(\n", + " data=np.random.rand(kwargs['data_len']),\n", + " dims=(kwargs['idx_name']),\n", + " coords={'x': np.arange(kwargs['range'][0], kwargs['range'][1], kwargs['range'][2])},\n", + " )\n", + "\n", + " ds = xr.Dataset({'var1': da})\n", + "\n", + " spatial_ref = xr.DataArray()\n", + " spatial_ref.attrs = attrs\n", + "\n", + " ds['spatial_ref'] = spatial_ref\n", + " ds = ds.set_coords('spatial_ref')\n", + "\n", + " # ds = ds.expand_dims({'y': 1})\n", + "\n", + " return ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "# create sample data\n", + "sample_ds1 = create_sample_data(make_kwargs(2, [0, 10, 1], 10))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "sample_ds1" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Defining a custom index\n", + "\n", + "### First, how will it be used?\n", + "Before we get into defining the custom index, it's helpful to see how it will be used. We have the object `sample_ds1`, which has a `PandasIndex`. 
\n", + "\n", + "Note: `PandasIndex` is an Xarray wrapper for a `pandas.Index` object." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "type(sample_ds1.xindexes['x'])" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "We want to replace the `PandasIndex` with our `CustomIndex`. To do this, we'll drop the `x` index from the dataset: \n", + "`sample_ds1 = sample_ds1.drop_indexes('x')`\n", + "\n", + "Once we define the new index, we'll attach it to the Xarray object using the `Dataset.set_xindex()` method. This takes the names of the coordinates used to build the index, and the index class. It will look like this:\n", + "\n", + "`ds1 = sample_ds1.set_xindex(['x', 'spatial_ref'], CustomIndex)`\n", + "\n", + "Now, let's define the custom index class." + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## The smallest `CustomIndex` \n", + "\n", + "This is an index that contains only the required component of an Xarray index, the `from_variables()` method. It can be successfully added to `ds` but it can't do much beyond that, and it doesn't contain any information about the transform between coordinate reference systems that we're interested in. Still, it's helpful to understand because `from_variables()` is how information gets from `ds` to our new index. `from_variables()` receives information about `ds` from `set_xindex()` and uses it to construct an instance of `CustomIndex`. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "class CustomIndex_tiny(xr.Index): # customindex inherits xarray Index\n", + " def __init__(self, x_indexes, variables=None):\n", + "\n", + " self.indexes = variables\n", + " self._xindexes = x_indexes\n", + "\n", + " self.spatial_ref = variables['spatial_ref']\n", + "\n", + " @classmethod\n", + " def from_variables(cls, variables, **kwargs):\n", + " '''this method creates a CustomIndex obj from a variables object.\n", + " variables is a dict created from ds1, keys are variable names,\n", + " values are associated xr.variables. created like this:\n", + " coord_vars = {name:ds._variables[name] for name in coord_names}\n", + " coord_names is passed to set_xindex\n", + " '''\n", + " # this index class expects to work with datasets with certain properties\n", + " # it must have exactly 2 variables: x and spatial_ref\n", + " assert len(variables) == 2\n", + " assert 'x' in variables\n", + " assert 'spatial_ref' in variables\n", + "\n", + " # separate dimensional, scalar variables into own dicts\n", + " dim_variables = {}\n", + " scalar_vars = {}\n", + " for k, i in variables.items():\n", + " if variables[k].ndim == 1:\n", + " dim_variables[k] = variables[k]\n", + " if variables[k].ndim == 0:\n", + " scalar_vars[k] = variables[k]\n", + "\n", + " options = {'dim': 'x', 'name': 'x'}\n", + "\n", + " # make dict of PandasIndexes for dim. 
variable\n", + " x_indexes = {\n", + " k: PandasIndex.from_variables({k: v}, options=options) for k, v in dim_variables.items()\n", + " }\n", + " # add scalar var to dict\n", + " x_indexes['spatial_ref'] = variables['spatial_ref']\n", + "\n", + " return cls(x_indexes, variables) # return an instance of CustomIndex class" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "sample_ds1 = sample_ds1.drop_indexes('x')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "ds1 = sample_ds1.set_xindex(['x', 'spatial_ref'], CustomIndex_tiny)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "ds1" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "As mentioned above, `ds1` now has the `CustomIndex`, but it can't do much yet. For example, label-based selection with `.sel()` is not implemented:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "%xmode Minimal\n", + "\n", + "ds1.sel(x=4)" + ] + }, + { + "cell_type": "markdown", + "id": "16", + "metadata": {}, + "source": [ + "### More detail on `from_variables()`\n", + "> - During `set_xindex()`, a dict object called `variables` is created. For every coordinate name passed to `set_xindex()`, `variables` has a key-value pair as follows: `name: ds._variables[name]`. \n", + "> - `variables` is passed to `from_variables()` and used to create another dict. The values in this dictionary hold a `PandasIndex` for each dimensional coordinate, and an `xr.Variable` for each scalar coordinate. \n", + "> - It's important to note that `from_variables()` is a **[class method](https://docs.python.org/3/library/functions.html#classmethod)**. This means that it acts as an alternative constructor, returning an instance of the `CustomIndex` class. 
\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "## Adding a coordinate transform and `.sel()` to `CustomIndex`\n", + "This section adds three new methods:\n", + "1. `create_variables()`: Returns the coordinate variables created from the new index.\n", + "2. `transform()`: Handles the coordinate transform between CRS A and CRS B. <- NOTE: remove this from class and pass to set_xindex?\n", + "3. `sel()`: Selects points from `ds1` using `transform()`. This allows the user to pass labels in coordinate reference system B, and `.sel()` returns the corresponding elements from `ds1`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "# create new sample data\n", + "sample_ds1 = create_sample_data(make_kwargs(2, [0, 10, 1], 10))\n", + "\n", + "# create a copy used for testing later\n", + "orig_ds1 = sample_ds1.copy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "sample_ds1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "class CustomIndex_sel(xr.Index): # customindex inherits xarray Index\n", + " def __init__(self, x_indexes, variables=None):\n", + "\n", + " self.indexes = variables\n", + " self._xindexes = x_indexes\n", + "\n", + " self.spatial_ref = variables['spatial_ref']\n", + "\n", + " @classmethod\n", + " def from_variables(cls, variables, **kwargs):\n", + " '''this method creates a CustomIndex obj from a variables object.\n", + " variables is a dict created from ds1, keys are variable names,\n", + " values are associated xr.variables. 
created like this:\n", + " coord_vars = {name:ds._variables[name] for name in coord_names}\n", + " coord_names is passed to set_xindex\n", + " '''\n", + " # this index class expects to work with datasets with certain properties\n", + " # it must have exactly 2 variables: x and spatial_ref\n", + " assert len(variables) == 2\n", + " assert 'x' in variables\n", + " assert 'spatial_ref' in variables\n", + "\n", + " dim_variables = {}\n", + " scalar_vars = {}\n", + " for k, i in variables.items():\n", + " if variables[k].ndim == 1:\n", + " dim_variables[k] = variables[k]\n", + " if variables[k].ndim == 0:\n", + " scalar_vars[k] = variables[k]\n", + "\n", + " options = {'dim': 'x', 'name': 'x'}\n", + "\n", + " x_indexes = {\n", + " k: PandasIndex.from_variables({k: v}, options=options) for k, v in dim_variables.items()\n", + " }\n", + "\n", + " x_indexes['spatial_ref'] = variables['spatial_ref']\n", + "\n", + " return cls(x_indexes, variables) # return an instance of CustomIndex class\n", + "\n", + " def create_variables(self, variables=None):\n", + " '''\n", + " Creates coord variable from index.\n", + "\n", + " Parameters:\n", + " -----------\n", + " variables : dict, optional\n", + " A dictionary of variables.\n", + "\n", + " Returns:\n", + " --------\n", + " dict\n", + " A dictionary containing the created variables.\n", + "\n", + " Notes:\n", + " ------\n", + " This method iterates over the `_xindexes` values and creates coord variables from the indexes.\n", + " It skips the spatial reference variable and updates the `idx_variables` dictionary with the created variables.\n", + " Finally, it adds the `spatial_ref` variable from the `variables` dictionary to the `idx_variables` dictionary.\n", + "\n", + " Example:\n", + " --------\n", + " >>> variables = {'spatial_ref': 123}\n", + " >>> result = create_variables(variables)\n", + " >>> print(result)\n", + " {'var1': ..., 'var2': ..., 'spatial_ref': 123}\n", + " '''\n", + " idx_variables = {}\n", + "\n", + " for index 
in self._xindexes.values():\n", + " if type(index) == xr.core.variable.Variable:\n", + " pass\n", + " else:\n", + " x = index.create_variables(variables)\n", + " idx_variables.update(x)\n", + "\n", + " idx_variables['spatial_ref'] = variables['spatial_ref']\n", + " return idx_variables\n", + "\n", + " idx_variables = {}\n", + "\n", + " for index in self._xindexes.values():\n", + " # want to skip spatial ref\n", + " if type(index) == xr.core.variable.Variable:\n", + " pass\n", + " else:\n", + "\n", + " x = index.create_variables(variables)\n", + " idx_variables.update(x)\n", + "\n", + " idx_variables['spatial_ref'] = variables['spatial_ref']\n", + " return idx_variables\n", + "\n", + " def transform(self, value):\n", + " \"\"\"\n", + " Transform the given value based on the spatial reference attributes. Currently, this only handles a very simple transform.\n", + " NOTE: this could be removed from the index class and passed to set_xindex()?\n", + "\n", + " Parameters:\n", + " -----------\n", + " value : int, float, slice, or list\n", + " The value to be transformed.\n", + "\n", + " Returns:\n", + " --------\n", + " transformed_labels : dict\n", + " A dictionary containing the transformed labels.\n", + "\n", + " Notes:\n", + " ------\n", + " - If `value` is a slice, it will be transformed based on the factor and index name attributes.\n", + " - If `value` is a single value or a list of values, each value will be transformed based on the factor attribute.\n", + "\n", + " Examples:\n", + " ---------\n", + " >>> spatial_ref = SpatialReference(factor=2, idx_name='index')\n", + " >>> transformed_labels = spatial_ref.transform(10)\n", + " >>> print(transformed_labels)\n", + " {'index': 5}\n", + "\n", + " >>> transformed_labels = spatial_ref.transform([10, 20, 30])\n", + " >>> print(transformed_labels)\n", + " {'index': [5, 10, 15]}\n", + "\n", + " >>> transformed_labels = spatial_ref.transform(slice(10, 20, 2))\n", + " >>> print(transformed_labels)\n", + " {'index': 
slice(5, 10, 2)}\n", + " \"\"\"\n", + " # extract attrs\n", + " fac = self.spatial_ref.attrs['factor']\n", + " key = self.spatial_ref.attrs['idx_name']\n", + "\n", + " # handle slice\n", + " if isinstance(value, slice):\n", + "\n", + " start, stop, step = value.start, value.stop, value.step\n", + " new_start, new_stop, new_step = start / fac, stop / fac, step\n", + " new_val = slice(new_start, new_stop, new_step)\n", + " transformed_labels = {key: new_val}\n", + " return transformed_labels\n", + "\n", + " # single or list of values\n", + " else:\n", + "\n", + " vals_to_transform = []\n", + "\n", + " if not isinstance(value, Sequence):\n", + " value = [value]\n", + "\n", + " for k in range(len(value)):\n", + "\n", + " val = value[k]\n", + " vals_to_transform.append(val)\n", + "\n", + " # logic for parsing attrs\n", + " transformed_x = [int(v / fac) for v in vals_to_transform]\n", + "\n", + " transformed_labels = {key: transformed_x}\n", + " return transformed_labels\n", + "\n", + " def sel(self, labels):\n", + " \"\"\"\n", + " Selects data from the index based on the provided labels.\n", + "\n", + " Parameters:\n", + " -----------\n", + " labels : dict\n", + " A dictionary containing the labels for each dimension.\n", + "\n", + " Returns:\n", + " --------\n", + " matches : PandasIndex\n", + " A PandasIndex object containing the selected data.\n", + "\n", + " Raises:\n", + " -------\n", + " AssertionError:\n", + " If the type of `labels` is not a dictionary.\n", + "\n", + " Notes:\n", + " ------\n", + " - The `labels` dictionary should have keys corresponding to the dimensions of the index.\n", + " - The values of the `labels` dictionary should be the labels to select from each dimension.\n", + " - The method uses the `transform` method to convert the labels to coordinate CRS.\n", + " - The selection is performed on the index created in the `.sel()` method.\n", + "\n", + " Example:\n", + " --------\n", + " >>> labels = {'x': 10}\n", + " >>> matches = 
obj.sel(labels)\n", + " >>> print(matches)\n", + " PandasIndex([10], dtype='int64', name='x')\n", + " \"\"\"\n", + "\n", + " assert isinstance(labels, dict)\n", + "\n", + " # the label(s) the user passed to .sel()\n", + " label = next(iter(labels.values()))\n", + "\n", + " # materialize the coordinate array to index against\n", + " params = self.spatial_ref.attrs['range']\n", + " full_arr = np.arange(params[0], params[1], params[2])\n", + " toy_index = PandasIndex(full_arr, dim='x')\n", + "\n", + " # transform user labels to the coordinate's CRS\n", + " idx = self.transform(label)\n", + "\n", + " # select on the index created above\n", + " matches = toy_index.sel(idx)\n", + "\n", + " return matches" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "Drop index:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "sample_ds1 = sample_ds1.drop_indexes('x')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "ds1 = sample_ds1.set_xindex(['x', 'spatial_ref'], CustomIndex_sel)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "ds1" + ] + }, + { + "cell_type": "markdown", + "id": "25", + "metadata": {}, + "source": [ + "Let's see if this works! Remember our coordinate transform (add desc. 
or illustration)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "ds1.sel(x=14)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27", + "metadata": {}, + "outputs": [], + "source": [ + "assert np.allclose(ds1.sel(x=14)['var1'].data, orig_ds1.sel(x=7)['var1'].data)" + ] + }, + { + "cell_type": "markdown", + "id": "28", + "metadata": {}, + "source": [ + "`.sel()` can also handle lists and slices:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29", + "metadata": {}, + "outputs": [], + "source": [ + "ds1.sel(x=[8, 10, 14])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [], + "source": [ + "# dim order switches? so need to specify data to assert\n", + "assert np.array_equal(ds1.sel(x=[8, 10, 14])['var1'].data, orig_ds1.sel(x=[4, 5, 7])['var1'].data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [], + "source": [ + "ds1.sel(x=slice(4, 18))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": {}, + "outputs": [], + "source": [ + "assert np.array_equal(\n", + " ds1.sel(x=slice(4, 18))['var1'].data, orig_ds1.sel(x=slice(2, 9))['var1'].data\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "33", + "metadata": {}, + "source": [ + "## Adding align\n", + "\n", + "NOTE: add illustration? \n", + "\n", + "Alignment is an important capability of Xarray indexes. It relies on three methods: `equals()`, `join()` and `reindex_like()`. 
\n", + "- `equals()`: Checks whether the index is equal to the other index passed in.\n", + "- `join()`: Joins the two indexes.\n", + "- `reindex_like()`: Reindexes the current index to match the result of the join.\n", + "Let's add them to the index:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + "metadata": {}, + "outputs": [], + "source": [ + "class CustomIndex(xr.Index): # customindex inherits xarray Index\n", + " def __init__(self, x_indexes, variables=None):\n", + "\n", + " self.indexes = variables\n", + " self._xindexes = x_indexes\n", + " if variables is not None:\n", + "\n", + " self.spatial_ref = variables['spatial_ref']\n", + " else:\n", + " self.spatial_ref = None\n", + "\n", + " @classmethod\n", + " def from_variables(cls, variables, **kwargs):\n", + " '''this method creates a CustomIndex obj from a variables object.\n", + " variables is a dict created from ds1, keys are variable names,\n", + " values are associated xr.variables. 
created like this:\n", + " coord_vars = {name:ds._variables[name] for name in coord_names}\n", + " coord_names is passed to set_xindex\n", + " '''\n", + " # this index class expects to work with datasets with certain properties\n", + " # must have exactly 2 variables: x and spatial_ref\n", + " assert len(variables) == 2\n", + " assert 'x' in variables\n", + " assert 'spatial_ref' in variables\n", + "\n", + " dim_variables = {}\n", + " scalar_vars = {}\n", + " for k, i in variables.items():\n", + " if variables[k].ndim == 1:\n", + " dim_variables[k] = variables[k]\n", + " if variables[k].ndim == 0:\n", + " scalar_vars[k] = variables[k]\n", + "\n", + " options = {'dim': 'x', 'name': 'x'}\n", + "\n", + " x_indexes = {\n", + " k: PandasIndex.from_variables({k: v}, options=options) for k, v in dim_variables.items()\n", + " }\n", + "\n", + " x_indexes['spatial_ref'] = variables['spatial_ref']\n", + "\n", + " return cls(x_indexes, variables)\n", + "\n", + " def create_variables(self, variables=None):\n", + " '''creates coord variable from index'''\n", + " if not variables:\n", + " variables = self.joined_var\n", + "\n", + " idx_variables = {}\n", + "\n", + " for index in self._xindexes.values():\n", + " # want to skip spatial ref\n", + " if type(index) == xr.core.variable.Variable:\n", + " pass\n", + " else:\n", + "\n", + " x = index.create_variables(variables)\n", + " idx_variables.update(x)\n", + "\n", + " idx_variables['spatial_ref'] = variables['spatial_ref']\n", + " return idx_variables\n", + "\n", + " def transform(self, value):\n", + "\n", + " # extract attrs\n", + " fac = self.spatial_ref.attrs['factor']\n", + " key = self.spatial_ref.attrs['idx_name']\n", + "\n", + " # handle slice\n", + " if isinstance(value, slice):\n", + "\n", + " start, stop, step = value.start, value.stop, value.step\n", + " new_start, new_stop, new_step = start / fac, stop / fac, step\n", + " new_val = slice(new_start, new_stop, new_step)\n", + " transformed_labels = {key: new_val}\n", + " 
return transformed_labels\n", + "\n", + " # single or list of values\n", + " else:\n", + "\n", + " vals_to_transform = []\n", + "\n", + " if not isinstance(value, Sequence):\n", + " value = [value]\n", + "\n", + " for k in range(len(value)):\n", + "\n", + " val = value[k]\n", + " vals_to_transform.append(val)\n", + "\n", + " # logic for parsing attrs, todo: switch to actual transform\n", + " transformed_x = [int(v / fac) for v in vals_to_transform]\n", + "\n", + " transformed_labels = {key: transformed_x}\n", + " return transformed_labels\n", + "\n", + " def sel(self, labels):\n", + "\n", + " assert isinstance(labels, dict)\n", + "\n", + " # the label(s) the user passed to .sel()\n", + " label = next(iter(labels.values()))\n", + "\n", + " # materialize the coordinate array to index against\n", + " params = self.spatial_ref.attrs['range']\n", + " full_arr = np.arange(params[0], params[1], params[2])\n", + " toy_index = PandasIndex(full_arr, dim='x')\n", + "\n", + " # transform user labels to the coordinate's CRS\n", + " idx = self.transform(label)\n", + "\n", + " # select on the index created above\n", + " matches = toy_index.sel(idx)\n", + "\n", + " return matches\n", + "\n", + " def equals(self, other):\n", + " \"\"\"\n", + " Check if the current instance is equal to another instance.\n", + " Parameters\n", + " ----------\n", + " other : object\n", + " The other instance to compare with.\n", + "\n", + " Returns\n", + " -------\n", + " bool\n", + " True if the current instance is equal to the other instance, False otherwise.\n", + " \"\"\"\n", + "\n", + " result = self._xindexes['x'].equals(other._xindexes['x']) and self._xindexes[\n", + " 'spatial_ref'\n", + " ].equals(other._xindexes['spatial_ref'])\n", + "\n", + " return result\n", + "\n", + " def join(self, other, how='inner'):\n", + " \"\"\"\n", + " Join the current index with another index.\n", + "\n", + " Parameters:\n", + " -----------\n", + " other : CustomIndex\n", + " The index to join with.\n", + " how : str, optional\n", + " The type of join to 
perform. Default is 'inner'.\n", + "\n", + " Returns:\n", + " --------\n", + " new_obj : CustomIndex\n", + " A new CustomIndex object representing the joined index.\n", + "\n", + " Notes:\n", + " ------\n", + " This method joins the current index with another index based on a common dimension.\n", + "\n", + " The current index and the other index are first converted into PandasIndex objects.\n", + "\n", + " The spatial reference information of the joined index is updated based on the start, stop, and step values of the joined index.\n", + "\n", + " The joined PandasIndex is then wrapped in a new instance of this index class and returned.\n", + " \"\"\"\n", + " # make self index obj\n", + " params_self = self.spatial_ref.attrs['range']\n", + " full_arr_self = np.arange(params_self[0], params_self[1], params_self[2])\n", + " toy_index_self = PandasIndex(full_arr_self, dim='x')\n", + "\n", + " # make other index obj\n", + " other_start = other._xindexes['x'].index.array[0]\n", + " other_stop = other._xindexes['x'].index.array[-1]\n", + " other_step = np.abs(\n", + " int((other_start - other_stop) / (len(other._xindexes['x'].index.array) - 1))\n", + " )\n", + "\n", + " params_other = other.spatial_ref.attrs['range']\n", + " full_arr_other = np.arange(\n", + " other_start, other_stop + other_step, other_step\n", + " ) # np.arange excludes its stop value, so extend by one step to keep the last label\n", + " toy_index_other = PandasIndex(full_arr_other, dim='x')\n", + "\n", + " self._indexes = {'x': toy_index_self}\n", + " other._indexes = {'x': toy_index_other}\n", + "\n", + " new_indexes = {'x': toy_index_self.join(toy_index_other, how=how)}\n", + "\n", + " # need to return an index obj, but don't want to have to pass variables\n", + " # so need to add all of the things that index needs to new_indexes before passing it to return?\n", + "\n", + " # this will need to be generalized / tested more\n", + " new_indexes['spatial_ref'] = deepcopy(self.spatial_ref)\n", + " start = int(new_indexes['x'].index.array[0])\n", + 
" stop = int(new_indexes['x'].index.array[-1])\n", + " step = int((stop - start) / (len(new_indexes['x'].index.array) - 1))\n", + "\n", + " new_indexes['spatial_ref'].attrs['range'] = [start, stop, step]\n", + "\n", + " idx_var = xr.IndexVariable(\n", + " dims=new_indexes['x'].index.name, data=new_indexes['x'].index.array\n", + " )\n", + " attr_var = new_indexes['spatial_ref']\n", + "\n", + " idx_dict = {'x': idx_var, 'spatial_ref': attr_var}\n", + "\n", + " new_obj = type(self)(new_indexes)\n", + " new_obj.joined_var = idx_dict\n", + " return new_obj\n", + "\n", + " def reindex_like(self, other, method=None, tolerance=None):\n", + " \"\"\"\n", + " Reindexes the current object to match the index of another object.\n", + "\n", + " Parameters:\n", + " -----------\n", + " other : object\n", + " The object whose index will be used for reindexing.\n", + " method : str, optional\n", + " The method to use for reindexing. Default is None.\n", + " tolerance : float, optional\n", + " The tolerance value to use for reindexing. 
Default is None.\n", + "\n", + " Returns:\n", + " --------\n", + " dict\n", + " A dictionary containing the reindexed values.\n", + "\n", + " Raises:\n", + " -------\n", + " None\n", + "\n", + " Notes:\n", + " ------\n", + " This method reindexes the current object to match the index of the `other` object.\n", + " It uses the `method` and `tolerance` parameters to determine the reindexing behavior.\n", + " The reindexed values are returned as a dictionary.\n", + " \"\"\"\n", + "\n", + " params_self = self.spatial_ref.attrs['range']\n", + " full_arr_self = np.arange(params_self[0], params_self[1], params_self[2])\n", + " toy_index_self = PandasIndex(full_arr_self, dim='x')\n", + "\n", + " toy_index_other = other._xindexes['x']\n", + "\n", + " d = {'x': toy_index_self.index.get_indexer(other._xindexes['x'].index, method, tolerance)}\n", + "\n", + " return d" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [], + "source": [ + "# create new sample data\n", + "sample_ds1 = create_sample_data(make_kwargs(2, [0, 10, 1], 10))\n", + "sample_ds2 = create_sample_data(make_kwargs(5, [5, 15, 1], 10))\n", + "\n", + "\n", + "# create a copy used for testing later\n", + "orig_ds1 = sample_ds1.copy()\n", + "orig_ds2 = sample_ds2.copy()" + ] + }, + { + "cell_type": "markdown", + "id": "36", + "metadata": {}, + "source": [ + "*** reindex_like needs to return an object like variables to pass to create vars (?)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37", + "metadata": {}, + "outputs": [], + "source": [ + "sample_ds1 = sample_ds1.drop_indexes('x')\n", + "sample_ds2 = sample_ds2.drop_indexes('x')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": {}, + "outputs": [], + "source": [ + "ds1 = sample_ds1.set_xindex(['x', 'spatial_ref'], CustomIndex)\n", + "ds2 = sample_ds2.set_xindex(['x', 'spatial_ref'], CustomIndex)" + ] + }, + { + "cell_type": "markdown", 
+ "id": "39", + "metadata": {}, + "source": [ + "## Align" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40", + "metadata": {}, + "outputs": [], + "source": [ + "# create sample data -- we define two datasets for alignment\n", + "sample_ds1 = create_sample_data(make_kwargs(2, [0, 10, 1], 10))\n", + "sample_ds2 = create_sample_data(make_kwargs(5, [8, 18, 1], 10))\n", + "\n", + "# create copies used for testing later\n", + "orig_ds1 = sample_ds1.copy()\n", + "orig_ds2 = sample_ds2.copy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41", + "metadata": {}, + "outputs": [], + "source": [ + "sample_ds1 = sample_ds1.drop_indexes('x')\n", + "sample_ds2 = sample_ds2.drop_indexes('x')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42", + "metadata": {}, + "outputs": [], + "source": [ + "ds1 = sample_ds1.set_xindex(['x', 'spatial_ref'], CustomIndex)\n", + "ds2 = sample_ds2.set_xindex(['x', 'spatial_ref'], CustomIndex)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43", + "metadata": {}, + "outputs": [], + "source": [ + "inner_align, _ = xr.align(ds1, ds2, join='inner')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44", + "metadata": {}, + "outputs": [], + "source": [ + "outer_align, _ = xr.align(ds1, ds2, join='outer')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45", + "metadata": {}, + "outputs": [], + "source": [ + "outer_align" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46", + "metadata": {}, + "outputs": [], + "source": [ + "inner_align" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47", + "metadata": {}, + "outputs": [], + "source": [ + "# reindex_like is not implemented for PandasIndex,\n", + "# but that defaults to inner, and these are successfully producing left and right, so that shouldn't be it\n", + "# left_align,_ = xr.align(ds1, ds2, join='left')\n", + "# right_align,_ = 
xr.align(ds1, ds2, join='right')\n", + "\n", + "# TODO: revisit the note above. Is reindex_like not implemented for left and right joins?" + ] + }, + { + "cell_type": "markdown", + "id": "48", + "metadata": {}, + "source": [ + "## Wrap up / summary\n", + "To do" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/intermediate/indexing/img2.png b/intermediate/indexing/img2.png new file mode 100644 index 00000000..6282ff3c Binary files /dev/null and b/intermediate/indexing/img2.png differ