KvikIO demo

Requires

  • [ ] https://github.com/pydata/xarray/pull/10078

  • [ ] https://github.com/rapidsai/kvikio/pull/646

%load_ext watermark
%xmode minimal

import cupy_xarray  # registers cupy accessor
import kvikio.zarr

import numpy as np
import xarray as xr
import zarr

%watermark -iv
Exception reporting mode: Minimal
numpy      : 2.2.3
zarr       : 3.0.5
cupy_xarray: 0.1.4+36.ge26ed24.dirty
kvikio     : 25.4.0
xarray     : 2025.1.3.dev22+g0184702f
xr.backends.list_engines()
{'netcdf4': <NetCDF4BackendEntrypoint>
   Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray
   Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html,
 'kvikio': <KvikioBackendEntrypoint>
   Open zarr files (.zarr) using Kvikio
   Learn more at https://docs.rapids.ai/api/kvikio/stable/api/#zarr,
 'store': <StoreBackendEntrypoint>
   Open AbstractDataStore instances in Xarray
   Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html,
 'zarr': <ZarrBackendEntrypoint>
   Open zarr files (.zarr) using zarr in Xarray
   Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html}

Create example dataset

  • the data must be stored uncompressed (compressors=None), so that chunks can be read directly into GPU memory

store = "/tmp/air-temperature.zarr"
airt = xr.tutorial.open_dataset("air_temperature", engine="netcdf4")
for var in airt.variables:
    airt[var].encoding["compressors"] = None
airt["scalar"] = 12.0
airt.to_zarr(store, mode="w", zarr_format=3, consolidated=False)
/home/user/mambaforge/envs/cupy-xarray-doc/lib/python3.11/site-packages/xarray/core/dataset.py:2503: SerializationWarning: saving variable None with floating point data as an integer dtype without any _FillValue to use for NaNs
  return to_zarr(  # type: ignore[call-overload,misc]
<xarray.backends.zarr.ZarrStore at 0x7f44a21e57e0>

Test opening

Standard usage

ds_cpu = xr.open_dataset(store, engine="zarr")
print(ds_cpu.air.data.__class__)
ds_cpu.air
<class 'numpy.ndarray'>
/tmp/ipykernel_72617/982297347.py:1: RuntimeWarning: Failed to open Zarr store with consolidated metadata, but successfully read with non-consolidated metadata. This is typically much slower for opening a dataset. To silence this warning, consider:
1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.
  ds_cpu = xr.open_dataset(store, engine="zarr")
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB
array([[[241.2 , 242.5 , ..., 235.5 , 238.6 ],
        [243.8 , 244.5 , ..., 235.3 , 239.3 ],
        ...,
        [295.9 , 296.2 , ..., 295.9 , 295.2 ],
        [296.29, 296.79, ..., 296.79, 296.6 ]],

       [[242.1 , 242.7 , ..., 233.6 , 235.8 ],
        [243.6 , 244.1 , ..., 232.5 , 235.7 ],
        ...,
        [296.2 , 296.7 , ..., 295.5 , 295.1 ],
        [296.29, 297.2 , ..., 296.4 , 296.6 ]],

       ...,

       [[245.79, 244.79, ..., 243.99, 244.79],
        [249.89, 249.29, ..., 242.49, 244.29],
        ...,
        [296.29, 297.19, ..., 295.09, 294.39],
        [297.79, 298.39, ..., 295.49, 295.19]],

       [[245.09, 244.29, ..., 241.49, 241.79],
        [249.89, 249.29, ..., 240.29, 241.69],
        ...,
        [296.09, 296.89, ..., 295.69, 295.19],
        [297.69, 298.09, ..., 296.19, 295.69]]], shape=(2920, 25, 53))
Coordinates:
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16000366210938, 322.1000061035156]

Now with kvikio!

  • must read with consolidated=False (https://github.com/rapidsai/kvikio/issues/119)

  • still to explore: dask.array.from_zarr with a GDSStore, and open_mfdataset

# Consolidated must be False
ds = xr.open_dataset(store, engine="kvikio", consolidated=False)
print(ds.air._variable._data)
ds
MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=<xarray.backends.zarr.ZarrArrayWrapper object at 0x7f449f3ed980>, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(<function _scale_offset_decoding at 0x7f44a35896c0>, scale_factor=0.01, add_offset=None, dtype=<class 'numpy.float64'>), dtype=dtype('float64')), key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None))))))
<xarray.Dataset> Size: 31MB
Dimensions:  (time: 2920, lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
Data variables:
    scalar   float64 8B ...
    air      (time, lat, lon) float64 31MB ...
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
ds.scalar
<xarray.DataArray 'scalar' ()> Size: 8B
[1 values with dtype=float64]

Lazy reading

ds.air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB
[3869000 values with dtype=float64]
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16000366210938, 322.1000061035156]

Loading data for the repr

ds["air"].isel(time=0, lat=10).load()
<xarray.DataArray 'air' (lon: 53)> Size: 424B
array([277.29, 277.4 , 277.79, 278.6 , 279.5 , 280.1 , 280.6 , 280.9 ,
       280.79, 280.7 , 280.79, 281.  , 280.29, 277.7 , 273.5 , 269.  ,
       265.5 , 264.  , 265.2 , 268.1 , 269.79, 267.9 , 263.  , 258.1 ,
       254.6 , 251.8 , 249.6 , 249.89, 252.3 , 254.  , 254.3 , 255.89,
       260.  , 263.  , 261.5 , 257.29, 255.5 , 258.29, 264.  , 268.7 ,
       270.5 , 270.6 , 271.2 , 272.9 , 274.79, 276.4 , 278.2 , 280.5 ,
       282.9 , 284.7 , 286.1 , 286.9 , 286.6 ])
Coordinates:
    lat      float32 4B 50.0
    time     datetime64[ns] 8B 2013-01-01
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16000366210938, 322.1000061035156]
ds.scalar
<xarray.DataArray 'scalar' ()> Size: 8B
[1 values with dtype=float64]
ds.air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB
[3869000 values with dtype=float64]
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16000366210938, 322.1000061035156]

CuPy array on load

Configure Zarr to use GPU memory with zarr.config.enable_gpu(), either as a context manager or globally.

See https://zarr.readthedocs.io/en/stable/user-guide/gpu.html#using-gpus-with-zarr

ds["air"].isel(time=0, lat=10).variable._data
MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=<xarray.backends.zarr.ZarrArrayWrapper object at 0x7f449f3ed980>, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(<function _scale_offset_decoding at 0x7f44a35896c0>, scale_factor=0.01, add_offset=None, dtype=<class 'numpy.float64'>), dtype=dtype('float64')), key=BasicIndexer((0, 10, slice(None, None, None))))))
with zarr.config.enable_gpu():
    print(type(ds["air"].isel(time=0, lat=10).load().data))
<class 'cupy.ndarray'>

Load to host

With zarr.config.enable_gpu() set globally, variables load as cupy arrays on the device; .as_numpy() copies them back to the host.

zarr.config.enable_gpu()
<donfig.config_obj.ConfigSet at 0x7f449e250d50>
ds.air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB
[3869000 values with dtype=float64]
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16000366210938, 322.1000061035156]
print(type(ds["air"].data))
<class 'cupy.ndarray'>
type(ds.air.as_numpy().data)
numpy.ndarray
type(ds.air.mean("time").load().data)
cupy.ndarray
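
The cupy accessor registered by the cupy_xarray import at the top can also report where the data lives. A small sketch, assuming the accessor's is_cupy property:

# device check via the accessor registered by `import cupy_xarray`
print(ds.air.cupy.is_cupy)             # True: the loaded data is a cupy.ndarray
print(ds.air.as_numpy().cupy.is_cupy)  # False: copied back to the host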

Doesn’t work: Chunk with dask

The dask meta is wrong: the repr below reports chunktype=numpy.ndarray instead of cupy.ndarray.

ds.chunk(time=10).air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB
dask.array<xarray-air, shape=(2920, 25, 53), dtype=float64, chunksize=(10, 25, 53), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16000366210938, 322.1000061035156]

dask.array.core.getter calls np.asarray on each chunk.

This calls ImplicitToExplicitIndexingAdapter.__array__ which calls np.asarray(cupy.array) which raises.
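
CuPy refuses this implicit device-to-host conversion, which is where the TypeError seen below comes from. A minimal illustration, outside of dask/xarray:

import cupy as cp
import numpy as np

gpu_chunk = cp.ones(3)
try:
    np.asarray(gpu_chunk)  # what dask's getter effectively does per chunk
except TypeError as err:
    print(err)  # "Implicit conversion to a NumPy array is not allowed..."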

Xarray uses .get_duck_array internally to remove these adapters. We might need to add something like the following to dask's chunk handling:

# handle xarray internal classes that might wrap cupy
# (sketch of a change where dask materializes each chunk; `c` is the chunk object)
if hasattr(c, "get_duck_array"):
    # unwrap xarray's lazy-indexing adapters without forcing a numpy copy
    c = c.get_duck_array()
else:
    c = np.asarray(c)
from dask.utils import is_arraylike

data = ds.air.variable._data
is_arraylike(data)
False
from xarray.core.indexing import ImplicitToExplicitIndexingAdapter
ImplicitToExplicitIndexingAdapter(data).get_duck_array()
array([[[241.2 , 242.5 , 243.5 , ..., 232.8 , 235.5 , 238.6 ],
        [243.8 , 244.5 , 244.7 , ..., 232.8 , 235.3 , 239.3 ],
        [250.  , 249.8 , 248.89, ..., 233.2 , 236.39, 241.7 ],
        ...,
        [296.6 , 296.2 , 296.4 , ..., 295.4 , 295.1 , 294.7 ],
        [295.9 , 296.2 , 296.79, ..., 295.9 , 295.9 , 295.2 ],
        [296.29, 296.79, 297.1 , ..., 296.9 , 296.79, 296.6 ]],

       [[242.1 , 242.7 , 243.1 , ..., 232.  , 233.6 , 235.8 ],
        [243.6 , 244.1 , 244.2 , ..., 231.  , 232.5 , 235.7 ],
        [253.2 , 252.89, 252.1 , ..., 230.8 , 233.39, 238.5 ],
        ...,
        [296.4 , 295.9 , 296.2 , ..., 295.4 , 295.1 , 294.79],
        [296.2 , 296.7 , 296.79, ..., 295.6 , 295.5 , 295.1 ],
        [296.29, 297.2 , 297.4 , ..., 296.4 , 296.4 , 296.6 ]],

       [[242.3 , 242.2 , 242.3 , ..., 234.3 , 236.1 , 238.7 ],
        [244.6 , 244.39, 244.  , ..., 230.3 , 232.  , 235.7 ],
        [256.2 , 255.5 , 254.2 , ..., 231.2 , 233.2 , 238.2 ],
        ...,
        [295.6 , 295.4 , 295.4 , ..., 296.29, 295.29, 295.  ],
        [296.2 , 296.5 , 296.29, ..., 296.4 , 296.  , 295.6 ],
        [296.4 , 296.29, 296.4 , ..., 297.  , 297.  , 296.79]],

       ...,

       [[243.49, 242.99, 242.09, ..., 244.19, 244.49, 244.89],
        [249.09, 248.99, 248.59, ..., 240.59, 241.29, 242.69],
        [262.69, 262.19, 261.69, ..., 239.39, 241.69, 245.19],
        ...,
        [294.79, 295.29, 297.49, ..., 295.49, 295.39, 294.69],
        [296.79, 297.89, 298.29, ..., 295.49, 295.49, 294.79],
        [298.19, 299.19, 298.79, ..., 296.09, 295.79, 295.79]],

       [[245.79, 244.79, 243.49, ..., 243.29, 243.99, 244.79],
        [249.89, 249.29, 248.49, ..., 241.29, 242.49, 244.29],
        [262.39, 261.79, 261.29, ..., 240.49, 243.09, 246.89],
        ...,
        [293.69, 293.89, 295.39, ..., 295.09, 294.69, 294.29],
        [296.29, 297.19, 297.59, ..., 295.29, 295.09, 294.39],
        [297.79, 298.39, 298.49, ..., 295.69, 295.49, 295.19]],

       [[245.09, 244.29, 243.29, ..., 241.69, 241.49, 241.79],
        [249.89, 249.29, 248.39, ..., 239.59, 240.29, 241.69],
        [262.99, 262.19, 261.39, ..., 239.89, 242.59, 246.29],
        ...,
        [293.79, 293.69, 295.09, ..., 295.29, 295.09, 294.69],
        [296.09, 296.89, 297.19, ..., 295.69, 295.69, 295.19],
        [297.69, 298.09, 298.09, ..., 296.49, 296.19, 295.69]]],
      shape=(2920, 25, 53))
ds.chunk(time=10).air.compute()
TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.

An explicit meta doesn't help either

import cupy as cp

chunked = ds.chunk(time=10, from_array_kwargs={"meta": cp.array([])})
chunked.air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB
dask.array<xarray-air, shape=(2920, 25, 53), dtype=float64, chunksize=(10, 25, 53), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16000366210938, 322.1000061035156]
chunked.compute()
TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.