{ "cells": [ { "cell_type": "markdown", "id": "5920bb97-1d76-4363-9aee-d1c5cd395409", "metadata": {}, "source": [ "# Kvikio demo\n", "\n", "Requires\n", "- [ ] https://github.com/pydata/xarray/pull/10078\n", "- [ ] https://github.com/rapidsai/kvikio/pull/646" ] }, { "cell_type": "code", "execution_count": 1, "id": "c9ee3a73-6f7b-4875-b5a6-2e6d48fade44", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Exception reporting mode: Minimal\n", "numpy : 2.2.3\n", "zarr : 3.0.5\n", "cupy_xarray: 0.1.4+36.ge26ed24.dirty\n", "kvikio : 25.4.0\n", "xarray : 2025.1.3.dev22+g0184702f\n", "\n" ] } ], "source": [ "%load_ext watermark\n", "%xmode minimal\n", "\n", "import cupy_xarray # registers cupy accessor\n", "import kvikio.zarr\n", "\n", "import numpy as np\n", "import xarray as xr\n", "import zarr\n", "\n", "%watermark -iv" ] }, { "cell_type": "code", "execution_count": 2, "id": "83b1b514-eeb8-4a81-a3e8-3a7dc82ffce4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'netcdf4': \n", " Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray\n", " Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html,\n", " 'kvikio': \n", " Open zarr files (.zarr) using Kvikio\n", " Learn more at https://docs.rapids.ai/api/kvikio/stable/api/#zarr,\n", " 'store': \n", " Open AbstractDataStore instances in Xarray\n", " Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html,\n", " 'zarr': \n", " Open zarr files (.zarr) using zarr in Xarray\n", " Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html}" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xr.backends.list_engines()" ] }, { "cell_type": "markdown", "id": "5f12848d-a5ec-4cea-9a49-4f2bcefd9114", "metadata": { "tags": [] }, "source": [ "## Create example dataset\n", "\n", "- 
data must be written uncompressed, so the `compressors` encoding of every variable is set to `None` before writing" ] }, { "cell_type": "code", "execution_count": 3, "id": "d481cc3b-420e-4b7c-8c5e-77d874128b12", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/user/mambaforge/envs/cupy-xarray-doc/lib/python3.11/site-packages/xarray/core/dataset.py:2503: SerializationWarning: saving variable None with floating point data as an integer dtype without any _FillValue to use for NaNs\n", " return to_zarr( # type: ignore[call-overload,misc]\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "store = \"/tmp/air-temperature.zarr\"\n", "airt = xr.tutorial.open_dataset(\"air_temperature\", engine=\"netcdf4\")\n", "for var in airt.variables:\n", " airt[var].encoding[\"compressors\"] = None\n", "airt[\"scalar\"] = 12.0\n", "airt.to_zarr(store, mode=\"w\", zarr_format=3, consolidated=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "1a3d0ec7-22fb-4558-8e60-9627266e3111", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "883d5507-988f-453a-b576-87bb563b540f", "metadata": { "tags": [] }, "source": [ "## Test opening\n", "\n", "### Standard usage" ] }, { "cell_type": "code", "execution_count": 4, "id": "4a9ba63c-0b29-4eb8-9171-965b90071496", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_72617/982297347.py:1: RuntimeWarning: Failed to open Zarr store with consolidated metadata, but successfully read with non-consolidated metadata. This is typically much slower for opening a dataset. To silence this warning, consider:\n", "1. Consolidating metadata in this existing store with zarr.consolidate_metadata().\n", "2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or\n", "3. 
Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.\n", " ds_cpu = xr.open_dataset(store, engine=\"zarr\")\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
       "array([[[241.2 , 242.5 , ..., 235.5 , 238.6 ],\n",
       "        [243.8 , 244.5 , ..., 235.3 , 239.3 ],\n",
       "        ...,\n",
       "        [295.9 , 296.2 , ..., 295.9 , 295.2 ],\n",
       "        [296.29, 296.79, ..., 296.79, 296.6 ]],\n",
       "\n",
       "       [[242.1 , 242.7 , ..., 233.6 , 235.8 ],\n",
       "        [243.6 , 244.1 , ..., 232.5 , 235.7 ],\n",
       "        ...,\n",
       "        [296.2 , 296.7 , ..., 295.5 , 295.1 ],\n",
       "        [296.29, 297.2 , ..., 296.4 , 296.6 ]],\n",
       "\n",
       "       ...,\n",
       "\n",
       "       [[245.79, 244.79, ..., 243.99, 244.79],\n",
       "        [249.89, 249.29, ..., 242.49, 244.29],\n",
       "        ...,\n",
       "        [296.29, 297.19, ..., 295.09, 294.39],\n",
       "        [297.79, 298.39, ..., 295.49, 295.19]],\n",
       "\n",
       "       [[245.09, 244.29, ..., 241.49, 241.79],\n",
       "        [249.89, 249.29, ..., 240.29, 241.69],\n",
       "        ...,\n",
       "        [296.09, 296.89, ..., 295.69, 295.19],\n",
       "        [297.69, 298.09, ..., 296.19, 295.69]]], shape=(2920, 25, 53))\n",
       "Coordinates:\n",
       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
       "Attributes:\n",
       "    long_name:     4xDaily Air temperature at sigma level 995\n",
       "    units:         degK\n",
       "    precision:     2\n",
       "    GRIB_id:       11\n",
       "    GRIB_name:     TMP\n",
       "    var_desc:      Air temperature\n",
       "    dataset:       NMC Reanalysis\n",
       "    level_desc:    Surface\n",
       "    statistic:     Individual Obs\n",
       "    parent_stat:   Other\n",
       "    actual_range:  [185.16000366210938, 322.1000061035156]
" ], "text/plain": [ " Size: 31MB\n", "array([[[241.2 , 242.5 , ..., 235.5 , 238.6 ],\n", " [243.8 , 244.5 , ..., 235.3 , 239.3 ],\n", " ...,\n", " [295.9 , 296.2 , ..., 295.9 , 295.2 ],\n", " [296.29, 296.79, ..., 296.79, 296.6 ]],\n", "\n", " [[242.1 , 242.7 , ..., 233.6 , 235.8 ],\n", " [243.6 , 244.1 , ..., 232.5 , 235.7 ],\n", " ...,\n", " [296.2 , 296.7 , ..., 295.5 , 295.1 ],\n", " [296.29, 297.2 , ..., 296.4 , 296.6 ]],\n", "\n", " ...,\n", "\n", " [[245.79, 244.79, ..., 243.99, 244.79],\n", " [249.89, 249.29, ..., 242.49, 244.29],\n", " ...,\n", " [296.29, 297.19, ..., 295.09, 294.39],\n", " [297.79, 298.39, ..., 295.49, 295.19]],\n", "\n", " [[245.09, 244.29, ..., 241.49, 241.79],\n", " [249.89, 249.29, ..., 240.29, 241.69],\n", " ...,\n", " [296.09, 296.89, ..., 295.69, 295.19],\n", " [297.69, 298.09, ..., 296.19, 295.69]]], shape=(2920, 25, 53))\n", "Coordinates:\n", " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", " * time (time) datetime64[ns] 23kB 2013-01-01 ... 
2014-12-31T18:00:00\n", "Attributes:\n", " long_name: 4xDaily Air temperature at sigma level 995\n", " units: degK\n", " precision: 2\n", " GRIB_id: 11\n", " GRIB_name: TMP\n", " var_desc: Air temperature\n", " dataset: NMC Reanalysis\n", " level_desc: Surface\n", " statistic: Individual Obs\n", " parent_stat: Other\n", " actual_range: [185.16000366210938, 322.1000061035156]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds_cpu = xr.open_dataset(store, engine=\"zarr\")\n", "print(ds_cpu.air.data.__class__)\n", "ds_cpu.air" ] }, { "cell_type": "markdown", "id": "95161182-6b58-4dbd-9752-9961c251be1a", "metadata": {}, "source": [ "### Now with kvikio!\n", "\n", " - Must read with `consolidated=False` (see https://github.com/rapidsai/kvikio/issues/119).\n", " - Not yet covered in this notebook: `dask.array.from_zarr` with a `GDSStore`, and `xr.open_mfdataset`." ] }, { "cell_type": "code", "execution_count": 5, "id": "8fd27bdf-e317-4de3-891e-41d38d06dcaf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(, scale_factor=0.01, add_offset=None, dtype=), dtype=dtype('float64')), key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None))))))\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset> Size: 31MB\n",
       "Dimensions:  (time: 2920, lat: 25, lon: 53)\n",
       "Coordinates:\n",
       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
       "Data variables:\n",
       "    scalar   float64 8B ...\n",
       "    air      (time, lat, lon) float64 31MB ...\n",
       "Attributes:\n",
       "    Conventions:  COARDS\n",
       "    title:        4x daily NMC reanalysis (1948)\n",
       "    description:  Data is from NMC initialized reanalysis\\n(4x/day).  These a...\n",
       "    platform:     Model\n",
       "    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
" ], "text/plain": [ " Size: 31MB\n", "Dimensions: (time: 2920, lat: 25, lon: 53)\n", "Coordinates:\n", " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", "Data variables:\n", " scalar float64 8B ...\n", " air (time, lat, lon) float64 31MB ...\n", "Attributes:\n", " Conventions: COARDS\n", " title: 4x daily NMC reanalysis (1948)\n", " description: Data is from NMC initialized reanalysis\\n(4x/day). These a...\n", " platform: Model\n", " references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly..." ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Consolidated must be False\n", "ds = xr.open_dataset(store, engine=\"kvikio\", consolidated=False)\n", "print(ds.air._variable._data)\n", "ds" ] }, { "cell_type": "code", "execution_count": 6, "id": "6c939a04-1588-4693-9483-c6ad7152951a", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'scalar' ()> Size: 8B\n",
       "[1 values with dtype=float64]
" ], "text/plain": [ " Size: 8B\n", "[1 values with dtype=float64]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds.scalar" ] }, { "cell_type": "markdown", "id": "bb84a7ad-84dc-4bb3-8636-3f9416953089", "metadata": { "tags": [] }, "source": [ "## Lazy reading" ] }, { "cell_type": "code", "execution_count": 7, "id": "1ecc39b1-b788-4831-9160-5b35afb83598", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
       "[3869000 values with dtype=float64]\n",
       "Coordinates:\n",
       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
       "Attributes:\n",
       "    long_name:     4xDaily Air temperature at sigma level 995\n",
       "    units:         degK\n",
       "    precision:     2\n",
       "    GRIB_id:       11\n",
       "    GRIB_name:     TMP\n",
       "    var_desc:      Air temperature\n",
       "    dataset:       NMC Reanalysis\n",
       "    level_desc:    Surface\n",
       "    statistic:     Individual Obs\n",
       "    parent_stat:   Other\n",
       "    actual_range:  [185.16000366210938, 322.1000061035156]
" ], "text/plain": [ " Size: 31MB\n", "[3869000 values with dtype=float64]\n", "Coordinates:\n", " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", "Attributes:\n", " long_name: 4xDaily Air temperature at sigma level 995\n", " units: degK\n", " precision: 2\n", " GRIB_id: 11\n", " GRIB_name: TMP\n", " var_desc: Air temperature\n", " dataset: NMC Reanalysis\n", " level_desc: Surface\n", " statistic: Individual Obs\n", " parent_stat: Other\n", " actual_range: [185.16000366210938, 322.1000061035156]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds.air" ] }, { "cell_type": "markdown", "id": "7d366864-a2b3-4573-9bf7-41d1f6ee457c", "metadata": { "tags": [] }, "source": [ "## Data load for repr" ] }, { "cell_type": "code", "execution_count": 8, "id": "00205e73-9b43-4254-9cba-f75435251391", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'air' (lon: 53)> Size: 424B\n",
       "array([277.29, 277.4 , 277.79, 278.6 , 279.5 , 280.1 , 280.6 , 280.9 ,\n",
       "       280.79, 280.7 , 280.79, 281.  , 280.29, 277.7 , 273.5 , 269.  ,\n",
       "       265.5 , 264.  , 265.2 , 268.1 , 269.79, 267.9 , 263.  , 258.1 ,\n",
       "       254.6 , 251.8 , 249.6 , 249.89, 252.3 , 254.  , 254.3 , 255.89,\n",
       "       260.  , 263.  , 261.5 , 257.29, 255.5 , 258.29, 264.  , 268.7 ,\n",
       "       270.5 , 270.6 , 271.2 , 272.9 , 274.79, 276.4 , 278.2 , 280.5 ,\n",
       "       282.9 , 284.7 , 286.1 , 286.9 , 286.6 ])\n",
       "Coordinates:\n",
       "    lat      float32 4B 50.0\n",
       "    time     datetime64[ns] 8B 2013-01-01\n",
       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
       "Attributes:\n",
       "    long_name:     4xDaily Air temperature at sigma level 995\n",
       "    units:         degK\n",
       "    precision:     2\n",
       "    GRIB_id:       11\n",
       "    GRIB_name:     TMP\n",
       "    var_desc:      Air temperature\n",
       "    dataset:       NMC Reanalysis\n",
       "    level_desc:    Surface\n",
       "    statistic:     Individual Obs\n",
       "    parent_stat:   Other\n",
       "    actual_range:  [185.16000366210938, 322.1000061035156]
" ], "text/plain": [ " Size: 424B\n", "array([277.29, 277.4 , 277.79, 278.6 , 279.5 , 280.1 , 280.6 , 280.9 ,\n", " 280.79, 280.7 , 280.79, 281. , 280.29, 277.7 , 273.5 , 269. ,\n", " 265.5 , 264. , 265.2 , 268.1 , 269.79, 267.9 , 263. , 258.1 ,\n", " 254.6 , 251.8 , 249.6 , 249.89, 252.3 , 254. , 254.3 , 255.89,\n", " 260. , 263. , 261.5 , 257.29, 255.5 , 258.29, 264. , 268.7 ,\n", " 270.5 , 270.6 , 271.2 , 272.9 , 274.79, 276.4 , 278.2 , 280.5 ,\n", " 282.9 , 284.7 , 286.1 , 286.9 , 286.6 ])\n", "Coordinates:\n", " lat float32 4B 50.0\n", " time datetime64[ns] 8B 2013-01-01\n", " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", "Attributes:\n", " long_name: 4xDaily Air temperature at sigma level 995\n", " units: degK\n", " precision: 2\n", " GRIB_id: 11\n", " GRIB_name: TMP\n", " var_desc: Air temperature\n", " dataset: NMC Reanalysis\n", " level_desc: Surface\n", " statistic: Individual Obs\n", " parent_stat: Other\n", " actual_range: [185.16000366210938, 322.1000061035156]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds[\"air\"].isel(time=0, lat=10).load()" ] }, { "cell_type": "code", "execution_count": 9, "id": "80aa6892-8c7f-44b3-bd52-9795ec4ea6f3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'scalar' ()> Size: 8B\n",
       "[1 values with dtype=float64]
" ], "text/plain": [ " Size: 8B\n", "[1 values with dtype=float64]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds.scalar" ] }, { "cell_type": "code", "execution_count": 10, "id": "ba48a2c0-96e0-41d7-9e07-381e05e8dc33", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
       "[3869000 values with dtype=float64]\n",
       "Coordinates:\n",
       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
       "Attributes:\n",
       "    long_name:     4xDaily Air temperature at sigma level 995\n",
       "    units:         degK\n",
       "    precision:     2\n",
       "    GRIB_id:       11\n",
       "    GRIB_name:     TMP\n",
       "    var_desc:      Air temperature\n",
       "    dataset:       NMC Reanalysis\n",
       "    level_desc:    Surface\n",
       "    statistic:     Individual Obs\n",
       "    parent_stat:   Other\n",
       "    actual_range:  [185.16000366210938, 322.1000061035156]
" ], "text/plain": [ " Size: 31MB\n", "[3869000 values with dtype=float64]\n", "Coordinates:\n", " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", "Attributes:\n", " long_name: 4xDaily Air temperature at sigma level 995\n", " units: degK\n", " precision: 2\n", " GRIB_id: 11\n", " GRIB_name: TMP\n", " var_desc: Air temperature\n", " dataset: NMC Reanalysis\n", " level_desc: Surface\n", " statistic: Individual Obs\n", " parent_stat: Other\n", " actual_range: [185.16000366210938, 322.1000061035156]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds.air" ] }, { "cell_type": "markdown", "id": "d0ea31d2-6c52-4346-b489-fc1e43200213", "metadata": { "tags": [] }, "source": [ "## CuPy array on load\n", "\n", "Configure Zarr to use GPU memory by setting `zarr.config.enable_gpu()`.\n", "\n", "See https://zarr.readthedocs.io/en/stable/user-guide/gpu.html#using-gpus-with-zarr" ] }, { "cell_type": "code", "execution_count": 11, "id": "1b34a68a-a6b3-4273-bf7c-28814ebfce11", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(, scale_factor=0.01, add_offset=None, dtype=), dtype=dtype('float64')), key=BasicIndexer((0, 10, slice(None, None, None))))))" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds[\"air\"].isel(time=0, lat=10).variable._data" ] }, { "cell_type": "code", "execution_count": 12, "id": "db69559c-1fde-4b3b-914d-87d8437ec256", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "with 
zarr.config.enable_gpu():\n", " print(type(ds[\"air\"].isel(time=0, lat=10).load().data))" ] }, { "cell_type": "markdown", "id": "d34a5cce-7bbc-408f-b643-05da1e121c78", "metadata": { "tags": [] }, "source": [ "## Load to host\n", "\n", "The next cell calls `zarr.config.enable_gpu()` outside a `with` block, so GPU output stays enabled for every remaining cell in this notebook. Use `.as_numpy()` to get a NumPy array in host memory back." ] }, { "cell_type": "code", "execution_count": 13, "id": "84094bc6-7884-414a-89cf-4526c3a54aea", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "zarr.config.enable_gpu()" ] }, { "cell_type": "code", "execution_count": 14, "id": "09b40d7d-ed38-4a50-af11-c2e5f0242a97", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
       "[3869000 values with dtype=float64]\n",
       "Coordinates:\n",
       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
       "Attributes:\n",
       "    long_name:     4xDaily Air temperature at sigma level 995\n",
       "    units:         degK\n",
       "    precision:     2\n",
       "    GRIB_id:       11\n",
       "    GRIB_name:     TMP\n",
       "    var_desc:      Air temperature\n",
       "    dataset:       NMC Reanalysis\n",
       "    level_desc:    Surface\n",
       "    statistic:     Individual Obs\n",
       "    parent_stat:   Other\n",
       "    actual_range:  [185.16000366210938, 322.1000061035156]
" ], "text/plain": [ " Size: 31MB\n", "[3869000 values with dtype=float64]\n", "Coordinates:\n", " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", "Attributes:\n", " long_name: 4xDaily Air temperature at sigma level 995\n", " units: degK\n", " precision: 2\n", " GRIB_id: 11\n", " GRIB_name: TMP\n", " var_desc: Air temperature\n", " dataset: NMC Reanalysis\n", " level_desc: Surface\n", " statistic: Individual Obs\n", " parent_stat: Other\n", " actual_range: [185.16000366210938, 322.1000061035156]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds.air" ] }, { "cell_type": "code", "execution_count": 15, "id": "615efd76-2194-4604-9ab8-61499e7d725d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "print(type(ds[\"air\"].data))" ] }, { "cell_type": "code", "execution_count": 16, "id": "eeb9ad78-1353-464f-8419-4c44ea499f17", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "numpy.ndarray" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(ds.air.as_numpy().data)" ] }, { "cell_type": "code", "execution_count": 17, "id": "140fe3e2-ea9b-445d-8401-5c624384c182", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "cupy.ndarray" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(ds.air.mean(\"time\").load().data)" ] }, { "cell_type": "markdown", "id": "cab539a7-d952-4b38-b515-712c52c62501", "metadata": { "tags": [] }, "source": [ "## Doesn't work: Chunk with dask" ] }, { "cell_type": "markdown", "id": "62c084eb-8df4-4b7f-a187-a736d68d430d", "metadata": {}, "source": [ "`meta` is wrong" ] }, { "cell_type": "code", "execution_count": 18, "id": "68f93bfe-fe56-488a-a10b-dc4f48029367", 
"metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
       "dask.array<xarray-air, shape=(2920, 25, 53), dtype=float64, chunksize=(10, 25, 53), chunktype=numpy.ndarray>\n",
       "Coordinates:\n",
       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
       "Attributes:\n",
       "    long_name:     4xDaily Air temperature at sigma level 995\n",
       "    units:         degK\n",
       "    precision:     2\n",
       "    GRIB_id:       11\n",
       "    GRIB_name:     TMP\n",
       "    var_desc:      Air temperature\n",
       "    dataset:       NMC Reanalysis\n",
       "    level_desc:    Surface\n",
       "    statistic:     Individual Obs\n",
       "    parent_stat:   Other\n",
       "    actual_range:  [185.16000366210938, 322.1000061035156]
" ], "text/plain": [ " Size: 31MB\n", "dask.array\n", "Coordinates:\n", " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", "Attributes:\n", " long_name: 4xDaily Air temperature at sigma level 995\n", " units: degK\n", " precision: 2\n", " GRIB_id: 11\n", " GRIB_name: TMP\n", " var_desc: Air temperature\n", " dataset: NMC Reanalysis\n", " level_desc: Surface\n", " statistic: Individual Obs\n", " parent_stat: Other\n", " actual_range: [185.16000366210938, 322.1000061035156]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds.chunk(time=10).air" ] }, { "cell_type": "markdown", "id": "3f4c72f6-22e7-4e99-9f4e-2524d6ab4226", "metadata": {}, "source": [ "`dask.array.core.getter` calls `np.asarray` on each chunk.\n", "\n", "This calls `ImplicitToExplicitIndexingAdapter.__array__` which calls `np.asarray(cupy.array)` which raises.\n", "\n", "Xarray uses `.get_duck_array` internally to remove these adapters. 
We might need to add\n", "```python\n", "# handle xarray internal classes that might wrap cupy\n", "if hasattr(c, \"get_duck_array\"):\n", " c = c.get_duck_array()\n", "else:\n", " c = np.asarray(c)\n", "```" ] }, { "cell_type": "code", "execution_count": 19, "id": "e1256d03-9701-433a-8291-80dc8dccffce", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from dask.utils import is_arraylike\n", "\n", "data = ds.air.variable._data\n", "is_arraylike(data)" ] }, { "cell_type": "code", "execution_count": 20, "id": "308affa5-9fb9-4638-989b-97aac2604c16", "metadata": {}, "outputs": [], "source": [ "from xarray.core.indexing import ImplicitToExplicitIndexingAdapter" ] }, { "cell_type": "code", "execution_count": 21, "id": "985cd2f8-406e-4e9e-8017-42efb16aa40e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[[241.2 , 242.5 , 243.5 , ..., 232.8 , 235.5 , 238.6 ],\n", " [243.8 , 244.5 , 244.7 , ..., 232.8 , 235.3 , 239.3 ],\n", " [250. , 249.8 , 248.89, ..., 233.2 , 236.39, 241.7 ],\n", " ...,\n", " [296.6 , 296.2 , 296.4 , ..., 295.4 , 295.1 , 294.7 ],\n", " [295.9 , 296.2 , 296.79, ..., 295.9 , 295.9 , 295.2 ],\n", " [296.29, 296.79, 297.1 , ..., 296.9 , 296.79, 296.6 ]],\n", "\n", " [[242.1 , 242.7 , 243.1 , ..., 232. , 233.6 , 235.8 ],\n", " [243.6 , 244.1 , 244.2 , ..., 231. , 232.5 , 235.7 ],\n", " [253.2 , 252.89, 252.1 , ..., 230.8 , 233.39, 238.5 ],\n", " ...,\n", " [296.4 , 295.9 , 296.2 , ..., 295.4 , 295.1 , 294.79],\n", " [296.2 , 296.7 , 296.79, ..., 295.6 , 295.5 , 295.1 ],\n", " [296.29, 297.2 , 297.4 , ..., 296.4 , 296.4 , 296.6 ]],\n", "\n", " [[242.3 , 242.2 , 242.3 , ..., 234.3 , 236.1 , 238.7 ],\n", " [244.6 , 244.39, 244. , ..., 230.3 , 232. , 235.7 ],\n", " [256.2 , 255.5 , 254.2 , ..., 231.2 , 233.2 , 238.2 ],\n", " ...,\n", " [295.6 , 295.4 , 295.4 , ..., 296.29, 295.29, 295. ],\n", " [296.2 , 296.5 , 296.29, ..., 296.4 , 296. 
, 295.6 ],\n", " [296.4 , 296.29, 296.4 , ..., 297. , 297. , 296.79]],\n", "\n", " ...,\n", "\n", " [[243.49, 242.99, 242.09, ..., 244.19, 244.49, 244.89],\n", " [249.09, 248.99, 248.59, ..., 240.59, 241.29, 242.69],\n", " [262.69, 262.19, 261.69, ..., 239.39, 241.69, 245.19],\n", " ...,\n", " [294.79, 295.29, 297.49, ..., 295.49, 295.39, 294.69],\n", " [296.79, 297.89, 298.29, ..., 295.49, 295.49, 294.79],\n", " [298.19, 299.19, 298.79, ..., 296.09, 295.79, 295.79]],\n", "\n", " [[245.79, 244.79, 243.49, ..., 243.29, 243.99, 244.79],\n", " [249.89, 249.29, 248.49, ..., 241.29, 242.49, 244.29],\n", " [262.39, 261.79, 261.29, ..., 240.49, 243.09, 246.89],\n", " ...,\n", " [293.69, 293.89, 295.39, ..., 295.09, 294.69, 294.29],\n", " [296.29, 297.19, 297.59, ..., 295.29, 295.09, 294.39],\n", " [297.79, 298.39, 298.49, ..., 295.69, 295.49, 295.19]],\n", "\n", " [[245.09, 244.29, 243.29, ..., 241.69, 241.49, 241.79],\n", " [249.89, 249.29, 248.39, ..., 239.59, 240.29, 241.69],\n", " [262.99, 262.19, 261.39, ..., 239.89, 242.59, 246.29],\n", " ...,\n", " [293.79, 293.69, 295.09, ..., 295.29, 295.09, 294.69],\n", " [296.09, 296.89, 297.19, ..., 295.69, 295.69, 295.19],\n", " [297.69, 298.09, 298.09, ..., 296.49, 296.19, 295.69]]],\n", " shape=(2920, 25, 53))" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ImplicitToExplicitIndexingAdapter(data).get_duck_array()" ] }, { "cell_type": "code", "execution_count": 22, "id": "fa8ef4f7-5014-476f-b4c0-ec2f9abdb6e2", "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.", "output_type": "error", "traceback": [ "\u001b[31mTypeError\u001b[39m\u001b[31m:\u001b[39m Implicit conversion to a NumPy array is not allowed. 
Please use `.get()` to construct a NumPy array explicitly.\n" ] } ], "source": [ "ds.chunk(time=10).air.compute()" ] }, { "cell_type": "markdown", "id": "17dc1bf6-7548-4eee-a5f3-ebcc20d41567", "metadata": {}, "source": [ "### Explicit `meta`\n", "\n", "Passing `meta=cp.array([])` through `from_array_kwargs` does not help: the repr below still reports `chunktype=numpy.ndarray`, and `compute()` raises the same `TypeError`." ] }, { "cell_type": "code", "execution_count": 23, "id": "cdd4b4e6-d69a-4898-964a-0e6096ca1942", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
       "dask.array<xarray-air, shape=(2920, 25, 53), dtype=float64, chunksize=(10, 25, 53), chunktype=numpy.ndarray>\n",
       "Coordinates:\n",
       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
       "Attributes:\n",
       "    long_name:     4xDaily Air temperature at sigma level 995\n",
       "    units:         degK\n",
       "    precision:     2\n",
       "    GRIB_id:       11\n",
       "    GRIB_name:     TMP\n",
       "    var_desc:      Air temperature\n",
       "    dataset:       NMC Reanalysis\n",
       "    level_desc:    Surface\n",
       "    statistic:     Individual Obs\n",
       "    parent_stat:   Other\n",
       "    actual_range:  [185.16000366210938, 322.1000061035156]
" ], "text/plain": [ " Size: 31MB\n", "dask.array\n", "Coordinates:\n", " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", "Attributes:\n", " long_name: 4xDaily Air temperature at sigma level 995\n", " units: degK\n", " precision: 2\n", " GRIB_id: 11\n", " GRIB_name: TMP\n", " var_desc: Air temperature\n", " dataset: NMC Reanalysis\n", " level_desc: Surface\n", " statistic: Individual Obs\n", " parent_stat: Other\n", " actual_range: [185.16000366210938, 322.1000061035156]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import cupy as cp\n", "\n", "chunked = ds.chunk(time=10, from_array_kwargs={\"meta\": cp.array([])})\n", "chunked.air" ] }, { "cell_type": "code", "execution_count": 24, "id": "74f80d94-ebb6-43c3-9411-79e0442d894e", "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.", "output_type": "error", "traceback": [ "\u001b[31mTypeError\u001b[39m\u001b[31m:\u001b[39m Implicit conversion to a NumPy array is not allowed. 
Please use `.get()` to construct a NumPy array explicitly.\n" ] } ], "source": [ "chunked.compute()" ] }, { "cell_type": "code", "execution_count": null, "id": "ac543634-80be-4e44-83e8-9e95a4955030", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "cupy-xarray-doc", "language": "python", "name": "cupy-xarray-doc" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }