Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import Tree from dask-awkward if not in dask #164

Merged
merged 12 commits into from
Feb 8, 2025
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
run: |
python3 -m pip install pip wheel
python3 -m pip install -q --no-cache-dir -e .[complete]
python3 -m pip install git+https://github.com/dask-contrib/dask-awkward
python3 -m pip list
- name: test
env: {"DASK_DATAFRAME__QUERY_PLANNING": "False"}
Expand Down
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.1
rev: v0.9.3
hooks:
- id: ruff
args: [--fix, --show-fixes]
- id: ruff-format
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ classifiers = [
dependencies = [
"boost-histogram>=1.3.2",
"dask>=2021.03.0",
"dask-awkward >=2025",
]
dynamic = ["version"]

Expand All @@ -39,7 +40,6 @@ complete = [
docs = [
"dask-sphinx-theme >=3.0.2",
"dask[array,dataframe]",
"dask-awkward >=2023.10.0",
# FIXME: `sphinxcontrib-*` pins are a workaround until we have sphinx>=5.
# See https://github.com/dask/dask-sphinx-theme/issues/68.
"sphinx >=4.0.0",
Expand All @@ -51,7 +51,6 @@ docs = [
]
test = [
"dask[array,dataframe]",
"dask-awkward >=2023.10.0",
"hist",
"pytest",
]
Expand Down
4 changes: 2 additions & 2 deletions src/dask_histogram/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,14 +920,14 @@ def _weight_sample_check(

def _is_dask_dataframe(obj):
return (
obj.__class__.__module__ == "dask.dataframe.core"
type(obj).__module__.startswith("dask.dataframe")
and obj.__class__.__name__ == "DataFrame"
)


def _is_dask_series(obj):
return (
obj.__class__.__module__ == "dask.dataframe.core"
type(obj).__module__.startswith("dask.dataframe")
and obj.__class__.__name__ == "Series"
)

Expand Down
16 changes: 15 additions & 1 deletion src/dask_histogram/layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
from dask.layers import DataFrameTreeReduction
try:
from dask.layers import DataFrameTreeReduction
except ImportError:
try:
from dask_awkward.layers import (
AwkwardTreeReductionLayer as DataFrameTreeReduction,
)
except ImportError:
DataFrameTreeReduction = None

if DataFrameTreeReduction is None:
raise ImportError(
"DataFrameReduction is unimportable - either downgrade dask to <2025"
"or install dask-awkward >=2025."
)


class MockableDataFrameTreeReduction(DataFrameTreeReduction):
Expand Down
2 changes: 1 addition & 1 deletion src/dask_histogram/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, Sequence, Tuple, Union

from dask.array.core import Array
from dask.dataframe.core import DataFrame, Series
from dask.dataframe import DataFrame, Series
from numpy.typing import ArrayLike

BinType = Union[int, ArrayLike]
Expand Down
14 changes: 14 additions & 0 deletions tests/test_boost.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import boost_histogram as bh
import boost_histogram.numpy as bhnp
import dask
import dask.array as da
import numpy as np
import pytest
from packaging.version import parse as parse_version

import dask_histogram.boost as dhb
import dask_histogram.core as dhc
Expand Down Expand Up @@ -247,6 +249,10 @@ def test_histogramdd_multicolumn_input():
np.testing.assert_array_almost_equal(h1.view(), h2.view())


@pytest.mark.xfail(
parse_version(dask.__version__) >= parse_version("2025"),
reason="to_dataframe is broken with dask 2025.1.0",
)
def test_histogramdd_series():
pytest.importorskip("pandas")

Expand Down Expand Up @@ -276,6 +282,10 @@ def test_histogramdd_series():
np.testing.assert_array_almost_equal(h1.view()["variance"], h2.view()["variance"])


@pytest.mark.xfail(
parse_version(dask.__version__) >= parse_version("2025"),
reason="to_dataframe is broken with dask 2025.1.0",
)
def test_histogramdd_arrays_and_series():
pytest.importorskip("pandas")

Expand Down Expand Up @@ -305,6 +315,10 @@ def test_histogramdd_arrays_and_series():
np.testing.assert_array_almost_equal(h1.view()["variance"], h2.view()["variance"])


@pytest.mark.xfail(
parse_version(dask.__version__) >= parse_version("2025"),
reason="to_dataframe is broken with dask 2025.1.0",
)
def test_histogramdd_dataframe():
pytest.importorskip("pandas")
x = da.random.standard_normal(size=(1000, 3), chunks=(200, 3))
Expand Down
6 changes: 6 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import boost_histogram as bh
import dask
import dask.array as da
import dask.array.utils as dau
import numpy as np
import pytest
from dask.delayed import delayed
from packaging.version import parse as parse_version

import dask_histogram.core as dhc

Expand Down Expand Up @@ -124,6 +126,10 @@ def test_nd_array(weights):
np.testing.assert_allclose(h.counts(flow=True), dh.compute().counts(flow=True))


@pytest.mark.xfail(
parse_version(dask.__version__) >= parse_version("2025"),
reason="dask dataframe changed substantially in 2025.1.0",
)
@pytest.mark.parametrize("weights", [True, None])
def test_df_input(weights):
pytest.importorskip("pandas")
Expand Down