Kvikio demo¶
Requires:

- [ ] https://github.com/pydata/xarray/pull/10078
- [ ] https://github.com/rapidsai/kvikio/pull/646
%load_ext watermark
%xmode minimal
import cupy_xarray # registers cupy accessor
import kvikio.zarr
import numpy as np
import xarray as xr
import zarr
%watermark -iv
Exception reporting mode: Minimal
numpy : 2.2.3
zarr : 3.0.5
cupy_xarray: 0.1.4+36.ge26ed24.dirty
kvikio : 25.4.0
xarray : 2025.1.3.dev22+g0184702f
xr.backends.list_engines()
{'netcdf4': <NetCDF4BackendEntrypoint>
Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray
Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html,
'kvikio': <KvikioBackendEntrypoint>
Open zarr files (.zarr) using Kvikio
Learn more at https://docs.rapids.ai/api/kvikio/stable/api/#zarr,
'store': <StoreBackendEntrypoint>
Open AbstractDataStore instances in Xarray
Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html,
'zarr': <ZarrBackendEntrypoint>
Open zarr files (.zarr) using zarr in Xarray
Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html}
Create example dataset¶
The data cannot be compressed for kvikio's direct-to-GPU reads, so we disable compression on every variable.
store = "/tmp/air-temperature.zarr"
airt = xr.tutorial.open_dataset("air_temperature", engine="netcdf4")
for var in airt.variables:
    airt[var].encoding["compressors"] = None
airt["scalar"] = 12.0
airt.to_zarr(store, mode="w", zarr_format=3, consolidated=False)
/home/user/mambaforge/envs/cupy-xarray-doc/lib/python3.11/site-packages/xarray/core/dataset.py:2503: SerializationWarning: saving variable None with floating point data as an integer dtype without any _FillValue to use for NaNs
return to_zarr( # type: ignore[call-overload,misc]
<xarray.backends.zarr.ZarrStore at 0x7f44a21e57e0>
Test opening¶
Standard usage¶
ds_cpu = xr.open_dataset(store, engine="zarr")
print(ds_cpu.air.data.__class__)
ds_cpu.air
<class 'numpy.ndarray'>
/tmp/ipykernel_72617/982297347.py:1: RuntimeWarning: Failed to open Zarr store with consolidated metadata, but successfully read with non-consolidated metadata. This is typically much slower for opening a dataset. To silence this warning, consider:
1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.
ds_cpu = xr.open_dataset(store, engine="zarr")
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB array([[[241.2 , 242.5 , ..., 235.5 , 238.6 ], [243.8 , 244.5 , ..., 235.3 , 239.3 ], ..., [295.9 , 296.2 , ..., 295.9 , 295.2 ], [296.29, 296.79, ..., 296.79, 296.6 ]], [[242.1 , 242.7 , ..., 233.6 , 235.8 ], [243.6 , 244.1 , ..., 232.5 , 235.7 ], ..., [296.2 , 296.7 , ..., 295.5 , 295.1 ], [296.29, 297.2 , ..., 296.4 , 296.6 ]], ..., [[245.79, 244.79, ..., 243.99, 244.79], [249.89, 249.29, ..., 242.49, 244.29], ..., [296.29, 297.19, ..., 295.09, 294.39], [297.79, 298.39, ..., 295.49, 295.19]], [[245.09, 244.29, ..., 241.49, 241.79], [249.89, 249.29, ..., 240.29, 241.69], ..., [296.09, 296.89, ..., 295.69, 295.19], [297.69, 298.09, ..., 296.19, 295.69]]], shape=(2920, 25, 53)) Coordinates: * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 Attributes: long_name: 4xDaily Air temperature at sigma level 995 units: degK precision: 2 GRIB_id: 11 GRIB_name: TMP var_desc: Air temperature dataset: NMC Reanalysis level_desc: Surface statistic: Individual Obs parent_stat: Other actual_range: [185.16000366210938, 322.1000061035156]
Now with kvikio!¶
- Must read with consolidated=False (https://github.com/rapidsai/kvikio/issues/119)
- dask.from_zarr to GDSStore / open_mfdataset
# Consolidated must be False
ds = xr.open_dataset(store, engine="kvikio", consolidated=False)
print(ds.air._variable._data)
ds
MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=<xarray.backends.zarr.ZarrArrayWrapper object at 0x7f449f3ed980>, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(<function _scale_offset_decoding at 0x7f44a35896c0>, scale_factor=0.01, add_offset=None, dtype=<class 'numpy.float64'>), dtype=dtype('float64')), key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None))))))
<xarray.Dataset> Size: 31MB Dimensions: (time: 2920, lat: 25, lon: 53) Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 Data variables: scalar float64 8B ... air (time, lat, lon) float64 31MB ... Attributes: Conventions: COARDS title: 4x daily NMC reanalysis (1948) description: Data is from NMC initialized reanalysis\n(4x/day). These a... platform: Model references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
ds.scalar
<xarray.DataArray 'scalar' ()> Size: 8B [1 values with dtype=float64]
Lazy reading¶
ds.air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB [3869000 values with dtype=float64] Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 Attributes: long_name: 4xDaily Air temperature at sigma level 995 units: degK precision: 2 GRIB_id: 11 GRIB_name: TMP var_desc: Air temperature dataset: NMC Reanalysis level_desc: Surface statistic: Individual Obs parent_stat: Other actual_range: [185.16000366210938, 322.1000061035156]
Data load for repr¶
ds["air"].isel(time=0, lat=10).load()
<xarray.DataArray 'air' (lon: 53)> Size: 424B array([277.29, 277.4 , 277.79, 278.6 , 279.5 , 280.1 , 280.6 , 280.9 , 280.79, 280.7 , 280.79, 281. , 280.29, 277.7 , 273.5 , 269. , 265.5 , 264. , 265.2 , 268.1 , 269.79, 267.9 , 263. , 258.1 , 254.6 , 251.8 , 249.6 , 249.89, 252.3 , 254. , 254.3 , 255.89, 260. , 263. , 261.5 , 257.29, 255.5 , 258.29, 264. , 268.7 , 270.5 , 270.6 , 271.2 , 272.9 , 274.79, 276.4 , 278.2 , 280.5 , 282.9 , 284.7 , 286.1 , 286.9 , 286.6 ]) Coordinates: lat float32 4B 50.0 time datetime64[ns] 8B 2013-01-01 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 Attributes: long_name: 4xDaily Air temperature at sigma level 995 units: degK precision: 2 GRIB_id: 11 GRIB_name: TMP var_desc: Air temperature dataset: NMC Reanalysis level_desc: Surface statistic: Individual Obs parent_stat: Other actual_range: [185.16000366210938, 322.1000061035156]
ds.scalar
<xarray.DataArray 'scalar' ()> Size: 8B [1 values with dtype=float64]
ds.air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB [3869000 values with dtype=float64] Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 Attributes: long_name: 4xDaily Air temperature at sigma level 995 units: degK precision: 2 GRIB_id: 11 GRIB_name: TMP var_desc: Air temperature dataset: NMC Reanalysis level_desc: Surface statistic: Individual Obs parent_stat: Other actual_range: [185.16000366210938, 322.1000061035156]
CuPy array on load¶
Configure Zarr to use GPU memory by setting zarr.config.enable_gpu(). See https://zarr.readthedocs.io/en/stable/user-guide/gpu.html#using-gpus-with-zarr
ds["air"].isel(time=0, lat=10).variable._data
MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=<xarray.backends.zarr.ZarrArrayWrapper object at 0x7f449f3ed980>, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(<function _scale_offset_decoding at 0x7f44a35896c0>, scale_factor=0.01, add_offset=None, dtype=<class 'numpy.float64'>), dtype=dtype('float64')), key=BasicIndexer((0, 10, slice(None, None, None))))))
with zarr.config.enable_gpu():
    print(type(ds["air"].isel(time=0, lat=10).load().data))
<class 'cupy.ndarray'>
Load to host¶
zarr.config.enable_gpu()
<donfig.config_obj.ConfigSet at 0x7f449e250d50>
ds.air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB [3869000 values with dtype=float64] Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 Attributes: long_name: 4xDaily Air temperature at sigma level 995 units: degK precision: 2 GRIB_id: 11 GRIB_name: TMP var_desc: Air temperature dataset: NMC Reanalysis level_desc: Surface statistic: Individual Obs parent_stat: Other actual_range: [185.16000366210938, 322.1000061035156]
print(type(ds["air"].data))
<class 'cupy.ndarray'>
type(ds.air.as_numpy().data)
numpy.ndarray
type(ds.air.mean("time").load().data)
cupy.ndarray
Doesn’t work: Chunk with dask¶
The dask meta is wrong: the repr below reports chunktype=numpy.ndarray even though the chunks actually load as cupy arrays.
ds.chunk(time=10).air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB dask.array<xarray-air, shape=(2920, 25, 53), dtype=float64, chunksize=(10, 25, 53), chunktype=numpy.ndarray> Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 Attributes: long_name: 4xDaily Air temperature at sigma level 995 units: degK precision: 2 GRIB_id: 11 GRIB_name: TMP var_desc: Air temperature dataset: NMC Reanalysis level_desc: Surface statistic: Individual Obs parent_stat: Other actual_range: [185.16000366210938, 322.1000061035156]
dask.array.core.getter calls np.asarray on each chunk. That invokes ImplicitToExplicitIndexingAdapter.__array__, which calls np.asarray on a cupy array, which raises. Xarray uses .get_duck_array internally to remove these adapters, so we might need to add:
# handle xarray internal classes that might wrap cupy
if hasattr(c, "get_duck_array"):
    c = c.get_duck_array()
else:
    c = np.asarray(c)
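The proposed change can be sketched standalone with only numpy. `DuckWrapper` below is a hypothetical stand-in for xarray's internal adapter classes (e.g. ImplicitToExplicitIndexingAdapter), which expose get_duck_array(); it is not part of any real library.

```python
import numpy as np

# DuckWrapper: hypothetical stand-in for xarray's lazy indexing adapters,
# which expose get_duck_array() to return the wrapped duck array.
class DuckWrapper:
    def __init__(self, arr):
        self._arr = arr

    def get_duck_array(self):
        # Return the wrapped array (numpy or cupy) as-is,
        # without forcing a host-side numpy conversion.
        return self._arr

def unwrap_chunk(c):
    # handle xarray internal classes that might wrap cupy
    if hasattr(c, "get_duck_array"):
        return c.get_duck_array()
    return np.asarray(c)

wrapped = DuckWrapper(np.arange(3))
# The wrapped array comes back identically, with no np.asarray call:
print(unwrap_chunk(wrapped) is wrapped._arr)
# Plain sequences still go through np.asarray as before:
print(type(unwrap_chunk([1, 2, 3])))
```

With this branch in place, a cupy-backed duck array inside the adapter would be returned directly instead of being funneled through np.asarray.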
from dask.utils import is_arraylike
data = ds.air.variable._data
is_arraylike(data)
False
from xarray.core.indexing import ImplicitToExplicitIndexingAdapter
ImplicitToExplicitIndexingAdapter(data).get_duck_array()
array([[[241.2 , 242.5 , 243.5 , ..., 232.8 , 235.5 , 238.6 ],
[243.8 , 244.5 , 244.7 , ..., 232.8 , 235.3 , 239.3 ],
[250. , 249.8 , 248.89, ..., 233.2 , 236.39, 241.7 ],
...,
[296.6 , 296.2 , 296.4 , ..., 295.4 , 295.1 , 294.7 ],
[295.9 , 296.2 , 296.79, ..., 295.9 , 295.9 , 295.2 ],
[296.29, 296.79, 297.1 , ..., 296.9 , 296.79, 296.6 ]],
[[242.1 , 242.7 , 243.1 , ..., 232. , 233.6 , 235.8 ],
[243.6 , 244.1 , 244.2 , ..., 231. , 232.5 , 235.7 ],
[253.2 , 252.89, 252.1 , ..., 230.8 , 233.39, 238.5 ],
...,
[296.4 , 295.9 , 296.2 , ..., 295.4 , 295.1 , 294.79],
[296.2 , 296.7 , 296.79, ..., 295.6 , 295.5 , 295.1 ],
[296.29, 297.2 , 297.4 , ..., 296.4 , 296.4 , 296.6 ]],
[[242.3 , 242.2 , 242.3 , ..., 234.3 , 236.1 , 238.7 ],
[244.6 , 244.39, 244. , ..., 230.3 , 232. , 235.7 ],
[256.2 , 255.5 , 254.2 , ..., 231.2 , 233.2 , 238.2 ],
...,
[295.6 , 295.4 , 295.4 , ..., 296.29, 295.29, 295. ],
[296.2 , 296.5 , 296.29, ..., 296.4 , 296. , 295.6 ],
[296.4 , 296.29, 296.4 , ..., 297. , 297. , 296.79]],
...,
[[243.49, 242.99, 242.09, ..., 244.19, 244.49, 244.89],
[249.09, 248.99, 248.59, ..., 240.59, 241.29, 242.69],
[262.69, 262.19, 261.69, ..., 239.39, 241.69, 245.19],
...,
[294.79, 295.29, 297.49, ..., 295.49, 295.39, 294.69],
[296.79, 297.89, 298.29, ..., 295.49, 295.49, 294.79],
[298.19, 299.19, 298.79, ..., 296.09, 295.79, 295.79]],
[[245.79, 244.79, 243.49, ..., 243.29, 243.99, 244.79],
[249.89, 249.29, 248.49, ..., 241.29, 242.49, 244.29],
[262.39, 261.79, 261.29, ..., 240.49, 243.09, 246.89],
...,
[293.69, 293.89, 295.39, ..., 295.09, 294.69, 294.29],
[296.29, 297.19, 297.59, ..., 295.29, 295.09, 294.39],
[297.79, 298.39, 298.49, ..., 295.69, 295.49, 295.19]],
[[245.09, 244.29, 243.29, ..., 241.69, 241.49, 241.79],
[249.89, 249.29, 248.39, ..., 239.59, 240.29, 241.69],
[262.99, 262.19, 261.39, ..., 239.89, 242.59, 246.29],
...,
[293.79, 293.69, 295.09, ..., 295.29, 295.09, 294.69],
[296.09, 296.89, 297.19, ..., 295.69, 295.69, 295.19],
[297.69, 298.09, 298.09, ..., 296.49, 296.19, 295.69]]],
shape=(2920, 25, 53))
ds.chunk(time=10).air.compute()
TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.
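The failure mode can be reproduced without a GPU. `FakeGPUArray` below is a hypothetical stand-in that, like cupy.ndarray, refuses implicit conversion to numpy and instead offers an explicit .get() for device-to-host transfer; the raised message mirrors cupy's.

```python
import numpy as np

# FakeGPUArray: hypothetical stand-in mimicking cupy.ndarray's refusal
# to convert implicitly to a numpy array.
class FakeGPUArray:
    def __init__(self, data):
        self._data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        # cupy raises exactly this kind of TypeError on np.asarray()
        raise TypeError(
            "Implicit conversion to a NumPy array is not allowed. "
            "Please use `.get()` to construct a NumPy array explicitly."
        )

    def get(self):
        # Explicit host transfer, as cupy requires.
        return self._data

a = FakeGPUArray([1.0, 2.0])
try:
    np.asarray(a)  # what dask.array.core.getter effectively does per chunk
except TypeError as err:
    print("raised:", err)
print(a.get())
```

This is why compute() fails: somewhere in the dask graph a chunk is still passed through np.asarray rather than being handed over as a duck array.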
Explicit meta¶
import cupy as cp
chunked = ds.chunk(time=10, from_array_kwargs={"meta": cp.array([])})
chunked.air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB dask.array<xarray-air, shape=(2920, 25, 53), dtype=float64, chunksize=(10, 25, 53), chunktype=numpy.ndarray> Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 Attributes: long_name: 4xDaily Air temperature at sigma level 995 units: degK precision: 2 GRIB_id: 11 GRIB_name: TMP var_desc: Air temperature dataset: NMC Reanalysis level_desc: Surface statistic: Individual Obs parent_stat: Other actual_range: [185.16000366210938, 322.1000061035156]
chunked.compute()
TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.