From 754c0744596d90882f541533e9fccebe81458640 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Wed, 3 Jun 2026 13:16:40 +0200 Subject: [PATCH 1/5] perf: scatter groupby-sum terms directly instead of unstacking The fast path of LinearExpression.groupby(...).sum() used ds.unstack(group_dim, fill_value=...) followed by a stack, which materializes 2-3 intermediate copies of the padded result (n_groups x max_group_size x nterm) and goes through pandas MultiIndex machinery sized by the number of elements. Instead, factorize the groups and scatter coeffs/vars directly into the preallocated padded result arrays; constants are group-summed with np.add.at. Peak memory drops to input + result (the minimum for the padded layout) and the grouping itself gets considerably faster. The result is unchanged: same dims, coords, term ordering and padding. The unstack-based implementation is kept as _sum_by_unstack and still used for chunked (dask-backed) data, which cannot be scattered into numpy arrays. NaN group labels now raise an informative ValueError instead of failing inside unstack. Co-Authored-By: Claude Opus 4.8 (1M context) --- linopy/expressions.py | 140 +++++++++++++++++++++++++++++---- test/test_linear_expression.py | 124 +++++++++++++++++++++++++++++ 2 files changed, 250 insertions(+), 14 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index ea8588d2..0e42af19 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -340,20 +340,13 @@ def sum( # At this point, group is always a pandas Series assert isinstance(group, pd.Series) - group_dim = group.index.name - - arrays = [group, group.groupby(group).cumcount()] - idx = pd.MultiIndex.from_arrays(arrays, names=[GROUP_DIM, GROUPED_TERM_DIM]) - new_coords = Coordinates.from_pandas_multiindex(idx, group_dim) - # collapsing group_dim invalidates every coordinate aligned to it - names_to_drop = [ - name - for name, coord in self.data.coords.items() - if group_dim in coord.dims - ] - ds = self.data.drop_vars(names_to_drop).assign_coords(new_coords) - ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value) - ds = LinearExpression._sum(ds, dim=GROUPED_TERM_DIM) + + if self._can_sum_by_scatter(group): + ds = self._sum_by_scatter(group) + else: + # chunked (e.g. dask-backed) data or exotic coordinates on the + # grouped dimension: use xarray's unstack machinery + ds = self._sum_by_unstack(group) if int_map is not None: index = ds.indexes[GROUP_DIM].map({v: k for k, v in int_map.items()}) @@ -374,6 +367,125 @@ def func(ds: Dataset) -> Dataset: return self.map(func, **kwargs, shortcut=True) + def _can_sum_by_scatter(self, group: pd.Series) -> bool: + """ + Whether :meth:`_sum_by_scatter` covers the structure of the data. + + The scatter kernel requires numpy-backed arrays (chunked data cannot be + scattered into preallocated numpy arrays) and no coordinates tied to + the grouped dimension besides its own index. Everything else falls + back to :meth:`_sum_by_unstack`. + """ + data = self.data + group_dim = group.index.name + + numpy_backed = all( + isinstance(data[k].data, np.ndarray) for k in ("coeffs", "vars", "const") + ) + if not numpy_backed: + return False + + index = data.indexes.get(group_dim) + index_names = {group_dim, *(index.names if index is not None else ())} + return all( + coord.dims == (group_dim,) and name in index_names + for name, coord in data.coords.items() + if group_dim in coord.dims + ) + + def _sum_by_scatter(self, group: pd.Series) -> Dataset: + """ + Sum groups by scattering all terms directly into the final padded arrays. + + Every group member keeps its block of ``nterm`` terms, so the resulting + term dimension has size ``max_group_size * nterm`` and smaller groups are + padded with fill values. In contrast to :meth:`_sum_by_unstack` only the + result arrays are allocated, without intermediate copies of that size. + + Only the term and constant values are computed with numpy; the result + structure (dimensions, coordinates and their order) is assembled by + xarray. :meth:`_can_sum_by_scatter` decides whether the data is simple + enough for this kernel. + """ + data = self.data + group_dim = group.index.name + fill_value = LinearExpression._fill_value + + codes, unique_groups = pd.factorize(group, sort=True) + if (codes == -1).any(): + raise ValueError( + "Cannot group by a pandas object containing NaN values. " + "Drop or fill the corresponding entries before grouping." + ) + + n_groups = len(unique_groups) + sizes = np.bincount(codes, minlength=n_groups) + max_size = int(sizes.max()) if n_groups else 0 + + # position of each element within its group (order of appearance) + positions = pd.Series(codes).groupby(codes).cumcount().to_numpy() + + def scatter( + da: DataArray, fill: Any + ) -> tuple[tuple[Hashable, ...], np.ndarray]: + """Scatter one term-array into its padded (group x term) layout.""" + rest_dims = [d for d in da.dims if d not in (group_dim, TERM_DIM)] + values = da.transpose(group_dim, *rest_dims, TERM_DIM).values + rest_shape = values.shape[1:-1] + nterm = values.shape[-1] + + out = np.full( + (n_groups, *rest_shape, nterm, max_size), fill, dtype=values.dtype + ) + locs = (codes, *(slice(None),) * (len(rest_shape) + 1), positions) + out[locs] = values + # collapsing (nterm, max_size) into one axis keeps all terms of one + # group member together, with padding at the end of each block + out = out.reshape((n_groups, *rest_shape, nterm * max_size)) + return (GROUP_DIM, *rest_dims, TERM_DIM), out + + coeffs_dims, coeffs = scatter(data.coeffs, fill_value["coeffs"]) + vars_dims, vars = scatter(data.vars, fill_value["vars"]) + + # constants are summed up within each group, skipping NaN values + const_dims = [d for d in data.const.dims if d != group_dim] + const_values = data.const.transpose(group_dim, *const_dims).values + const = np.zeros((n_groups, *const_values.shape[1:]), dtype=const_values.dtype) + np.add.at(const, codes, np.where(np.isnan(const_values), 0, const_values)) + + # only the values above are computed with numpy, the result structure + # (dimensions, coordinates and their order) is assembled by xarray + # itself and thereby matches a result of unstacking the group dimension + structure = data.drop_vars(["coeffs", "vars", "const"]) + structure = structure.drop_dims(group_dim) + structure = structure.expand_dims({GROUP_DIM: unique_groups}) + + return structure.assign( + coeffs=(coeffs_dims, coeffs), + vars=(vars_dims, vars), + const=((GROUP_DIM, *const_dims), const), + ) + + def _sum_by_unstack(self, group: pd.Series) -> Dataset: + """ + Sum groups by unstacking the group dimension into a padded helper + dimension and summing over it. + + Equivalent to :meth:`_sum_by_scatter` but goes through xarray's + unstack/stack machinery, which also supports chunked (dask) data. + """ + group_dim = group.index.name + arrays = [group, group.groupby(group).cumcount()] + idx = pd.MultiIndex.from_arrays(arrays, names=[GROUP_DIM, GROUPED_TERM_DIM]) + new_coords = Coordinates.from_pandas_multiindex(idx, group_dim) + # collapsing group_dim invalidates every coordinate aligned to it + names_to_drop = [ + name for name, coord in self.data.coords.items() if group_dim in coord.dims + ] + ds = self.data.drop_vars(names_to_drop).assign_coords(new_coords) + ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value) + return LinearExpression._sum(ds, dim=GROUPED_TERM_DIM) + def roll(self, **kwargs: Any) -> LinearExpression: """ Roll the groupby object. diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 5ffd7de1..19850999 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1908,6 +1908,130 @@ def test_linear_expression_groupby_from_variable(v: Variable) -> None: assert grouped.nterm == 10 +def test_linear_expression_groupby_skewed_unsorted_groups(v: Variable) -> None: + """ + The scatter-based fast path must match the xarray fallback for groups that + are unsorted, non-contiguous and of very different sizes. + """ + expr = 2 * v + 5 + # 'b' appears 14 times, 'c' 5 times, 'a' once, scattered over the dimension + labels = ["b"] * 4 + ["c", "a"] + ["b"] * 5 + ["c"] * 4 + ["b"] * 5 + groups = pd.Series(labels, index=v.indexes["dim_2"], name="letter") + + grouped = expr.groupby(groups).sum() + fallback = expr.groupby(groups.to_xarray()).sum(use_fallback=True) + + assert list(grouped.data.letter) == ["a", "b", "c"] + # padded to the largest group times the number of terms of the input + assert grouped.nterm == 14 * expr.nterm + assert_linequal(grouped, fallback) + + # every group must carry exactly the variables of its members, the rest is fill + for letter in ["a", "b", "c"]: + members = np.where(np.array(labels) == letter)[0] + vars_of_group = grouped.data.vars.sel(letter=letter).values + assert set(vars_of_group[vars_of_group >= 0]) == set(v.labels.values[members]) + assert (vars_of_group >= 0).sum() == len(members) * expr.nterm + assert grouped.const.sel(letter=letter).item() == 5 * len(members) + + +def test_linear_expression_groupby_chunked(v: Variable) -> None: + """Chunked (dask-backed) expressions group via xarray's unstack machinery.""" + pytest.importorskip("dask") + expr = 2 * v + 5 + groups = pd.Series([1] * 12 + [2] * 8, index=v.indexes["dim_2"], name="group") + + chunked = LinearExpression(expr.data.chunk({"dim_2": 5}), expr.model) + grouped_chunked = chunked.groupby(groups).sum() + grouped = expr.groupby(groups).sum() + + assert grouped_chunked.nterm == grouped.nterm + assert_linequal( + LinearExpression(grouped_chunked.data.compute(), expr.model), grouped + ) + + +def test_linear_expression_groupby_with_nan_groups(v: Variable) -> None: + expr = 1 * v + groups = pd.Series([1.0, np.nan] * 10, index=v.indexes["dim_2"], name="with_nans") + with pytest.raises(ValueError, match="NaN"): + expr.groupby(groups).sum() + + +@pytest.mark.parametrize( + "case", + [ + "skewed_int_groups", + "multidim_with_const", + "nan_const", + "masked_vars", + "quadratic", + "single_group", + "identity_groups", + ], +) +def test_linear_expression_groupby_scatter_equals_unstack(case: str) -> None: + """ + Lock the two groupby-sum kernels together. + + The fast path of groupby(...).sum() scatters terms into numpy arrays + (_sum_by_scatter); the xarray unstack implementation (_sum_by_unstack) is + kept for chunked data and exotic coordinates. Both must stay + interchangeable — if an xarray/pandas update changes the unstack output or + an edge case diverges, this fails. + """ + m = Model() + rng = np.random.default_rng(0) + idx = pd.RangeIndex(60, name="elem") + skewed = pd.Series(rng.choice(8, 60, p=[0.5] + [0.5 / 7] * 7), index=idx, name="g") + groups = skewed + + if case == "skewed_int_groups": + x = m.add_variables(coords=[idx], name="x") + expr: LinearExpression | QuadraticExpression = 3 * x - 2 * x + 7 + elif case == "multidim_with_const": + other = pd.Index(list("abc"), name="other") + y = m.add_variables(coords=[other, idx], name="y") + const = xr.DataArray(rng.normal(size=(3, 60)), coords=[other, idx]) + expr = 2 * y + 1 * y + const + elif case == "nan_const": + x = m.add_variables(coords=[idx], name="x") + expr = 1 * x + np.where(np.arange(60) % 3, np.nan, 5.0) + elif case == "masked_vars": + mask = xr.DataArray(np.arange(60) % 4 != 0, coords=[idx]) + x = m.add_variables(coords=[idx], name="x", mask=mask) + expr = 1 * x + elif case == "quadratic": + x = m.add_variables(coords=[idx], name="x") + expr = x * x + 2 * x + elif case == "single_group": + x = m.add_variables(coords=[idx], name="x") + expr = 1 * x + groups = pd.Series(1, index=idx, name="g") + else: # identity_groups + x = m.add_variables(coords=[idx], name="x") + expr = 1 * x + groups = pd.Series(np.arange(60), index=idx, name="g") + + gb = expr.groupby(groups) + assert gb._can_sum_by_scatter(groups) + scatter = LinearExpression(gb._sum_by_scatter(groups).rename(_group="g"), m) + unstack = LinearExpression(gb._sum_by_unstack(groups).rename(_group="g"), m) + + # identical structure: dims, dim order, coordinates + assert scatter.data.coeffs.dims == unstack.data.coeffs.dims + assert scatter.data.const.dims == unstack.data.const.dims + assert list(scatter.data.coords) == list(unstack.data.coords) + for name in scatter.data.coords: + assert_equal(scatter.data[name], unstack.data[name]) + + # identical values: vars and coeffs bit-exact, including padding positions + np.testing.assert_array_equal(scatter.vars.values, unstack.vars.values) + np.testing.assert_array_equal(scatter.coeffs.values, unstack.coeffs.values) + # constants may differ by floating-point summation order + np.testing.assert_allclose(scatter.const.values, unstack.const.values, rtol=1e-12) + + def test_linear_expression_rolling(v: Variable) -> None: expr = 1 * v rolled = expr.rolling(dim_2=2).sum() From 7f3edea01a5dbb2c4c9a78808a44458a23cec952 Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 30 Jun 2026 13:18:06 +0200 Subject: [PATCH 2/5] test: cover empty group dim in scatter groupby-sum Add a test for grouping over an empty group dimension, which the scatter fast path handles cleanly but the unstack fallback cannot. Trim comments that duplicated the helper docstrings. --- linopy/expressions.py | 10 +++------- test/test_linear_expression.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 0e42af19..405ae7fc 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -344,8 +344,6 @@ def sum( if self._can_sum_by_scatter(group): ds = self._sum_by_scatter(group) else: - # chunked (e.g. dask-backed) data or exotic coordinates on the - # grouped dimension: use xarray's unstack machinery ds = self._sum_by_unstack(group) if int_map is not None: @@ -404,8 +402,9 @@ def _sum_by_scatter(self, group: pd.Series) -> Dataset: Only the term and constant values are computed with numpy; the result structure (dimensions, coordinates and their order) is assembled by - xarray. :meth:`_can_sum_by_scatter` decides whether the data is simple - enough for this kernel. + xarray itself and thereby matches the result of unstacking the group + dimension. :meth:`_can_sum_by_scatter` decides whether the data is + simple enough for this kernel. """ data = self.data group_dim = group.index.name @@ -453,9 +452,6 @@ def scatter( const = np.zeros((n_groups, *const_values.shape[1:]), dtype=const_values.dtype) np.add.at(const, codes, np.where(np.isnan(const_values), 0, const_values)) - # only the values above are computed with numpy, the result structure - # (dimensions, coordinates and their order) is assembled by xarray - # itself and thereby matches a result of unstacking the group dimension structure = data.drop_vars(["coeffs", "vars", "const"]) structure = structure.drop_dims(group_dim) structure = structure.expand_dims({GROUP_DIM: unique_groups}) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 19850999..b710f0b1 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1958,6 +1958,18 @@ def test_linear_expression_groupby_with_nan_groups(v: Variable) -> None: expr.groupby(groups).sum() +def test_linear_expression_groupby_empty_groups() -> None: + """An empty group dimension scatters into an empty, well-formed result.""" + m = Model() + idx = pd.RangeIndex(0, name="elem") + x = m.add_variables(coords=[idx], name="x") + groups = pd.Series([], index=idx, name="g", dtype=int) + + grouped = (1 * x).groupby(groups).sum() + assert grouped.nterm == 0 + assert dict(grouped.data.sizes) == {"g": 0, "_term": 0} + + @pytest.mark.parametrize( "case", [ From 7598180d8829f87741058936523d7fd2ba9976e5 Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 30 Jun 2026 13:19:01 +0200 Subject: [PATCH 3/5] docs: add release note for scatter groupby-sum --- doc/release_notes.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 7e849a42..fb5cd3e2 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -21,6 +21,10 @@ Upcoming Version * ``add_variables(binary=True, ...)`` now accepts ``lower``/``upper`` bounds, as long as they are 0 or 1. Previously binary bounds could only be set via the ``.lower``/``.upper`` setters after creation. (https://github.com/PyPSA/linopy/issues/776) +**Performance** + +* ``LinearExpression.groupby(...).sum()`` now scatters terms directly into the padded result arrays instead of unstacking through pandas ``MultiIndex`` machinery, cutting peak memory to input + result and speeding up the grouping. + **Deprecations** * Mutation via assignment to ``Variable.lower`` / ``Variable.upper`` / ``Constraint.coeffs`` / ``Constraint.vars`` / ``Constraint.lhs`` / ``Constraint.sign`` / ``Constraint.rhs`` is deprecated and emits a ``DeprecationWarning``. Use ``Variable.update(...)`` / ``Constraint.update(...)`` instead — the canonical mutation API with one validation path and one place that flips the persistent-solver dirty flag. Read access to these properties is unchanged. The setters will be removed in a future release. From 81d7c23a513c484804aa176760e4d6c2d84cbf7a Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 30 Jun 2026 21:07:33 +0200 Subject: [PATCH 4/5] perf(groupby): widen scatter fast path to all numpy-backed data MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Relax the groupby-sum scatter gate to a pure numpy/dask check: auxiliary coordinates on the grouped dimension no longer force the slow unstack path. Summing over groups collapses that dimension, so both kernels drop every coordinate tied to it — the scatter result is identical, just cheaper. The unstack kernel now serves only chunked (dask) data, and a debug log records when that fallback is taken. Inline the now-trivial predicate into the dispatch and consolidate the kernel tests into a TestGroupbySumScatterKernel class: a one-line case table over a shared fixture, with added coverage for combined structures, auxiliary coords, and a MultiIndex grouped dimension. Co-Authored-By: Claude Opus 4.8 (1M context) --- linopy/expressions.py | 45 ++--- test/test_linear_expression.py | 303 +++++++++++++++++++-------------- 2 files changed, 191 insertions(+), 157 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 405ae7fc..a3b10ba8 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -341,9 +341,17 @@ def sum( # At this point, group is always a pandas Series assert isinstance(group, pd.Series) - if self._can_sum_by_scatter(group): + numpy_backed = all( + isinstance(self.data[k].data, np.ndarray) + for k in ("coeffs", "vars", "const") + ) + if numpy_backed: ds = self._sum_by_scatter(group) else: + logger.debug( + "groupby-sum: non-numpy-backed (e.g. dask) data, " + "falling back to the unstack kernel." + ) ds = self._sum_by_unstack(group) if int_map is not None: @@ -365,32 +373,6 @@ def func(ds: Dataset) -> Dataset: return self.map(func, **kwargs, shortcut=True) - def _can_sum_by_scatter(self, group: pd.Series) -> bool: - """ - Whether :meth:`_sum_by_scatter` covers the structure of the data. - - The scatter kernel requires numpy-backed arrays (chunked data cannot be - scattered into preallocated numpy arrays) and no coordinates tied to - the grouped dimension besides its own index. Everything else falls - back to :meth:`_sum_by_unstack`. - """ - data = self.data - group_dim = group.index.name - - numpy_backed = all( - isinstance(data[k].data, np.ndarray) for k in ("coeffs", "vars", "const") - ) - if not numpy_backed: - return False - - index = data.indexes.get(group_dim) - index_names = {group_dim, *(index.names if index is not None else ())} - return all( - coord.dims == (group_dim,) and name in index_names - for name, coord in data.coords.items() - if group_dim in coord.dims - ) - def _sum_by_scatter(self, group: pd.Series) -> Dataset: """ Sum groups by scattering all terms directly into the final padded arrays. @@ -403,8 +385,8 @@ def _sum_by_scatter(self, group: pd.Series) -> Dataset: Only the term and constant values are computed with numpy; the result structure (dimensions, coordinates and their order) is assembled by xarray itself and thereby matches the result of unstacking the group - dimension. :meth:`_can_sum_by_scatter` decides whether the data is - simple enough for this kernel. + dimension. The caller dispatches here only for numpy-backed data + (chunked data uses :meth:`_sum_by_unstack`). """ data = self.data group_dim = group.index.name @@ -467,8 +449,9 @@ def _sum_by_unstack(self, group: pd.Series) -> Dataset: Sum groups by unstacking the group dimension into a padded helper dimension and summing over it. - Equivalent to :meth:`_sum_by_scatter` but goes through xarray's - unstack/stack machinery, which also supports chunked (dask) data. + Equivalent to :meth:`_sum_by_scatter`, but goes through xarray's + unstack/stack machinery. It is the fallback for chunked (dask) data, + which cannot be scattered into preallocated numpy buffers. """ group_dim = group.index.name arrays = [group, group.groupby(group).cumcount()] diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index b710f0b1..20032351 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -7,7 +7,10 @@ from __future__ import annotations +import logging import warnings +from collections.abc import Callable +from types import SimpleNamespace from typing import Any import numpy as np @@ -1908,140 +1911,188 @@ def test_linear_expression_groupby_from_variable(v: Variable) -> None: assert grouped.nterm == 10 -def test_linear_expression_groupby_skewed_unsorted_groups(v: Variable) -> None: - """ - The scatter-based fast path must match the xarray fallback for groups that - are unsorted, non-contiguous and of very different sizes. - """ - expr = 2 * v + 5 - # 'b' appears 14 times, 'c' 5 times, 'a' once, scattered over the dimension - labels = ["b"] * 4 + ["c", "a"] + ["b"] * 5 + ["c"] * 4 + ["b"] * 5 - groups = pd.Series(labels, index=v.indexes["dim_2"], name="letter") - - grouped = expr.groupby(groups).sum() - fallback = expr.groupby(groups.to_xarray()).sum(use_fallback=True) - - assert list(grouped.data.letter) == ["a", "b", "c"] - # padded to the largest group times the number of terms of the input - assert grouped.nterm == 14 * expr.nterm - assert_linequal(grouped, fallback) - - # every group must carry exactly the variables of its members, the rest is fill - for letter in ["a", "b", "c"]: - members = np.where(np.array(labels) == letter)[0] - vars_of_group = grouped.data.vars.sel(letter=letter).values - assert set(vars_of_group[vars_of_group >= 0]) == set(v.labels.values[members]) - assert (vars_of_group >= 0).sum() == len(members) * expr.nterm - assert grouped.const.sel(letter=letter).item() == 5 * len(members) - - -def test_linear_expression_groupby_chunked(v: Variable) -> None: - """Chunked (dask-backed) expressions group via xarray's unstack machinery.""" - pytest.importorskip("dask") - expr = 2 * v + 5 - groups = pd.Series([1] * 12 + [2] * 8, index=v.indexes["dim_2"], name="group") - - chunked = LinearExpression(expr.data.chunk({"dim_2": 5}), expr.model) - grouped_chunked = chunked.groupby(groups).sum() - grouped = expr.groupby(groups).sum() - - assert grouped_chunked.nterm == grouped.nterm - assert_linequal( - LinearExpression(grouped_chunked.data.compute(), expr.model), grouped +@pytest.fixture +def scatter_ctx() -> SimpleNamespace: + """Shared 60-element building blocks for the scatter-vs-unstack case table.""" + m = Model() + rng = np.random.default_rng(0) + idx = pd.RangeIndex(60, name="elem") + other = pd.Index(list("abc"), name="other") + p, q = pd.Index(list("pq"), name="p"), pd.Index([10, 20, 30], name="q") + a, b = pd.Index(range(12), name="a"), pd.Index(list("vwxyz"), name="b") + + const = xr.DataArray(rng.normal(size=(3, 60)), coords=[other, idx]) + y = m.add_variables(coords=[other, idx], name="y") + yab = m.add_variables(coords=[a, b], name="yab") + stacked = LinearExpression((2 * yab + 1 * yab).data.stack(elem=["a", "b"]), m) + skewed = pd.Series(rng.choice(8, 60, p=[0.5] + [0.5 / 7] * 7), index=idx, name="g") + + return SimpleNamespace( + m=m, + x=m.add_variables(coords=[idx], name="x"), + y=y, + y3=m.add_variables(coords=[p, q, idx], name="y3"), + mx=m.add_variables( + coords=[idx], name="mx", mask=xr.DataArray(np.arange(60) % 4 != 0, [idx]) + ), + my=m.add_variables( + coords=[other, idx], + name="my", + mask=xr.DataArray(rng.random((3, 60)) > 0.25, [other, idx]), + ), + const=const, + nan_const=const.where(rng.random((3, 60)) > 0.3), + nan_vec=np.where(np.arange(60) % 3, np.nan, 5.0), + y_aux=(2 * y + 1 * y).assign_coords( + carrier=("elem", rng.choice(list("PQ"), 60)), + tag=(("other", "elem"), rng.integers(0, 9, (3, 60))), + ), + stacked=stacked, + skewed=skewed, + one_group=pd.Series(1, index=idx, name="g"), + identity=pd.Series(np.arange(60), index=idx, name="g"), + mi_groups=skewed.set_axis(stacked.data.indexes["elem"]), ) -def test_linear_expression_groupby_with_nan_groups(v: Variable) -> None: - expr = 1 * v - groups = pd.Series([1.0, np.nan] * 10, index=v.indexes["dim_2"], name="with_nans") - with pytest.raises(ValueError, match="NaN"): - expr.groupby(groups).sum() +# Each case maps a structure to (expr, groups) from `scatter_ctx`. The skewed +# group puts ~half the elements in group 0 and spreads 1..7 over the rest. +SCATTER_EQUALS_UNSTACK_CASES = { + "skewed_int_groups": lambda c: (3 * c.x - 2 * c.x + 7, c.skewed), + "multidim_with_const": lambda c: (2 * c.y + 1 * c.y + c.const, c.skewed), + "nan_const": lambda c: (1 * c.x + c.nan_vec, c.skewed), + "masked_vars": lambda c: (1 * c.mx, c.skewed), + "quadratic": lambda c: (c.x * c.x + 2 * c.x, c.skewed), + "single_group": lambda c: (1 * c.x, c.one_group), + "identity_groups": lambda c: (1 * c.x, c.identity), + # combined structures exercising several features at once + "multidim_const_nan": lambda c: ( + 2 * c.y - 3 * c.y + 1 * c.y + c.nan_const, + c.skewed, + ), + "three_dims": lambda c: (4 * c.y3 + 1 * c.y3, c.skewed), + "quadratic_multidim_const": lambda c: (c.y * c.y + 2 * c.y + c.const, c.skewed), + "masked_multidim": lambda c: (5 * c.my - 2 * c.my, c.skewed), + # both collapse the grouped dim, dropping every coordinate tied to it + "aux_coords_on_group_dim": lambda c: (c.y_aux, c.skewed), + "multiindex_dim": lambda c: (c.stacked, c.mi_groups), +} + + +class TestGroupbySumScatterKernel: + """ + ``groupby(...).sum()`` takes a scatter fast path (``_sum_by_scatter``) for + numpy-backed expressions and falls back to the xarray unstack machinery + (``_sum_by_unstack``) for chunked data and exotic coordinates. These tests + pin the two kernels together and cover the structural edge cases. + """ + @staticmethod + def _assert_kernels_identical(gb: Any, groups: pd.Series, m: Model) -> None: + """Force both kernels and assert they produce the same expression.""" + scatter = LinearExpression(gb._sum_by_scatter(groups).rename(_group="g"), m) + unstack = LinearExpression(gb._sum_by_unstack(groups).rename(_group="g"), m) + + assert scatter.data.coeffs.dims == unstack.data.coeffs.dims + assert scatter.data.const.dims == unstack.data.const.dims + assert list(scatter.data.coords) == list(unstack.data.coords) + for name in scatter.data.coords: + assert_equal(scatter.data[name], unstack.data[name]) + + np.testing.assert_array_equal(scatter.vars.values, unstack.vars.values) + np.testing.assert_array_equal(scatter.coeffs.values, unstack.coeffs.values) + # constants may differ only by floating-point summation order + np.testing.assert_allclose( + scatter.const.values, unstack.const.values, rtol=1e-12 + ) -def test_linear_expression_groupby_empty_groups() -> None: - """An empty group dimension scatters into an empty, well-formed result.""" - m = Model() - idx = pd.RangeIndex(0, name="elem") - x = m.add_variables(coords=[idx], name="x") - groups = pd.Series([], index=idx, name="g", dtype=int) - - grouped = (1 * x).groupby(groups).sum() - assert grouped.nterm == 0 - assert dict(grouped.data.sizes) == {"g": 0, "_term": 0} - - -@pytest.mark.parametrize( - "case", - [ - "skewed_int_groups", - "multidim_with_const", - "nan_const", - "masked_vars", - "quadratic", - "single_group", - "identity_groups", - ], -) -def test_linear_expression_groupby_scatter_equals_unstack(case: str) -> None: - """ - Lock the two groupby-sum kernels together. + def test_skewed_unsorted_groups(self, v: Variable) -> None: + """ + The scatter-based fast path must match the xarray fallback for groups + that are unsorted, non-contiguous and of very different sizes. + """ + expr = 2 * v + 5 + # 'b' appears 14 times, 'c' 5 times, 'a' once, scattered over the dimension + labels = ["b"] * 4 + ["c", "a"] + ["b"] * 5 + ["c"] * 4 + ["b"] * 5 + groups = pd.Series(labels, index=v.indexes["dim_2"], name="letter") + + grouped = expr.groupby(groups).sum() + fallback = expr.groupby(groups.to_xarray()).sum(use_fallback=True) + + assert list(grouped.data.letter) == ["a", "b", "c"] + # padded to the largest group times the number of terms of the input + assert grouped.nterm == 14 * expr.nterm + assert_linequal(grouped, fallback) + + # every group carries exactly the variables of its members, rest is fill + for letter in ["a", "b", "c"]: + members = np.where(np.array(labels) == letter)[0] + vars_of_group = grouped.data.vars.sel(letter=letter).values + present = set(vars_of_group[vars_of_group >= 0]) + assert present == set(v.labels.values[members]) + assert (vars_of_group >= 0).sum() == len(members) * expr.nterm + assert grouped.const.sel(letter=letter).item() == 5 * len(members) + + def test_chunked_uses_unstack( + self, v: Variable, caplog: pytest.LogCaptureFixture + ) -> None: + """Chunked (dask-backed) expressions group via xarray's unstack path.""" + pytest.importorskip("dask") + expr = 2 * v + 5 + groups = pd.Series([1] * 12 + [2] * 8, index=v.indexes["dim_2"], name="group") + + chunked = LinearExpression(expr.data.chunk({"dim_2": 5}), expr.model) + with caplog.at_level(logging.DEBUG, logger="linopy.expressions"): + grouped_chunked = chunked.groupby(groups).sum() + assert "falling back to the unstack kernel" in caplog.text + + grouped = expr.groupby(groups).sum() + assert grouped_chunked.nterm == grouped.nterm + assert_linequal( + LinearExpression(grouped_chunked.data.compute(), expr.model), grouped + ) - The fast path of groupby(...).sum() scatters terms into numpy arrays - (_sum_by_scatter); the xarray unstack implementation (_sum_by_unstack) is - kept for chunked data and exotic coordinates. Both must stay - interchangeable — if an xarray/pandas update changes the unstack output or - an edge case diverges, this fails. - """ - m = Model() - rng = np.random.default_rng(0) - idx = pd.RangeIndex(60, name="elem") - skewed = pd.Series(rng.choice(8, 60, p=[0.5] + [0.5 / 7] * 7), index=idx, name="g") - groups = skewed + def test_nan_groups_raise(self, v: Variable) -> None: + expr = 1 * v + groups = pd.Series( + [1.0, np.nan] * 10, index=v.indexes["dim_2"], name="with_nans" + ) + with pytest.raises(ValueError, match="NaN"): + expr.groupby(groups).sum() - if case == "skewed_int_groups": - x = m.add_variables(coords=[idx], name="x") - expr: LinearExpression | QuadraticExpression = 3 * x - 2 * x + 7 - elif case == "multidim_with_const": - other = pd.Index(list("abc"), name="other") - y = m.add_variables(coords=[other, idx], name="y") - const = xr.DataArray(rng.normal(size=(3, 60)), coords=[other, idx]) - expr = 2 * y + 1 * y + const - elif case == "nan_const": - x = m.add_variables(coords=[idx], name="x") - expr = 1 * x + np.where(np.arange(60) % 3, np.nan, 5.0) - elif case == "masked_vars": - mask = xr.DataArray(np.arange(60) % 4 != 0, coords=[idx]) - x = m.add_variables(coords=[idx], name="x", mask=mask) - expr = 1 * x - elif case == "quadratic": - x = m.add_variables(coords=[idx], name="x") - expr = x * x + 2 * x - elif case == "single_group": - x = m.add_variables(coords=[idx], name="x") - expr = 1 * x - groups = pd.Series(1, index=idx, name="g") - else: # identity_groups + def test_empty_groups(self) -> None: + """An empty group dimension scatters into an empty, well-formed result.""" + m = Model() + idx = pd.RangeIndex(0, name="elem") x = m.add_variables(coords=[idx], name="x") - expr = 1 * x - groups = pd.Series(np.arange(60), index=idx, name="g") - - gb = expr.groupby(groups) - assert gb._can_sum_by_scatter(groups) - scatter = LinearExpression(gb._sum_by_scatter(groups).rename(_group="g"), m) - unstack = LinearExpression(gb._sum_by_unstack(groups).rename(_group="g"), m) - - # identical structure: dims, dim order, coordinates - assert scatter.data.coeffs.dims == unstack.data.coeffs.dims - assert scatter.data.const.dims == unstack.data.const.dims - assert list(scatter.data.coords) == list(unstack.data.coords) - for name in scatter.data.coords: - assert_equal(scatter.data[name], unstack.data[name]) - - # identical values: vars and coeffs bit-exact, including padding positions - np.testing.assert_array_equal(scatter.vars.values, unstack.vars.values) - np.testing.assert_array_equal(scatter.coeffs.values, unstack.coeffs.values) - # constants may differ by floating-point summation order - np.testing.assert_allclose(scatter.const.values, unstack.const.values, rtol=1e-12) + groups = pd.Series([], index=idx, name="g", dtype=int) + + grouped = (1 * x).groupby(groups).sum() + assert grouped.nterm == 0 + assert dict(grouped.data.sizes) == {"g": 0, "_term": 0} + + @pytest.mark.parametrize( + "build", + SCATTER_EQUALS_UNSTACK_CASES.values(), + ids=SCATTER_EQUALS_UNSTACK_CASES.keys(), + ) + def test_scatter_equals_unstack( + self, + build: Callable[[SimpleNamespace], tuple[LinearExpression, pd.Series]], + scatter_ctx: SimpleNamespace, + ) -> None: + """ + Lock the two groupby-sum kernels together. + + The fast path scatters terms into numpy arrays (``_sum_by_scatter``); + the unstack implementation (``_sum_by_unstack``) is kept for chunked + data. Both must stay interchangeable — if an xarray/pandas update + changes the unstack output or an edge case diverges, this fails. See + ``SCATTER_EQUALS_UNSTACK_CASES`` for the structures covered. + """ + expr, groups = build(scatter_ctx) + gb = expr.groupby(groups) + self._assert_kernels_identical(gb, groups, scatter_ctx.m) def test_linear_expression_rolling(v: Variable) -> None: From ac8ec47bb9999c409c90006fbe14270ff85f1615 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 30 Jun 2026 22:35:53 +0200 Subject: [PATCH 5/5] perf(groupby): unify scatter kernel over numpy and dask via apply_ufunc MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the previous numpy-scatter / dask-unstack split with a single kernel (`_grouped_sum`) wrapped in `xarray.apply_ufunc`. It scatters terms into the padded result arrays for numpy-backed data and runs the same scatter lazily on chunked (dask) data via `dask="parallelized"`, after gathering the grouped and term dimensions (the scatter's core dims) into single chunks. This removes the last `pd.MultiIndex`/`unstack` usage in groupby-sum, drops the numpy-vs-dask branch in `sum()`, and keeps peak memory at input + result on both backends. Multi-key / DataFrame grouping and its `MultiIndex` result are unaffected — that logic sits above the kernel. Tests verify the kernel from first principles (each group's terms and constant must match its members) across every case shape on both numpy and dask, plus explicit anchors pinning the exact padded layout — member order, fill position, term interleaving and the factor axis — for the linear, multidim and quadratic cases. Co-Authored-By: Claude Opus 4.8 (1M context) --- doc/release_notes.rst | 2 +- linopy/expressions.py | 143 +++++++++---------- test/test_linear_expression.py | 248 ++++++++++++++++++++++++++------- 3 files changed, 258 insertions(+), 135 deletions(-) diff --git a/doc/release_notes.rst b/doc/release_notes.rst index fb5cd3e2..a0a3fa99 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -23,7 +23,7 @@ Upcoming Version **Performance** -* ``LinearExpression.groupby(...).sum()`` now scatters terms directly into the padded result arrays instead of unstacking through pandas ``MultiIndex`` machinery, cutting peak memory to input + result and speeding up the grouping. +* ``LinearExpression.groupby(...).sum()`` now scatters terms directly into the padded result arrays via ``xarray.apply_ufunc``, avoiding intermediate copies and speeding up the grouping. A single kernel covers both numpy and chunked (dask) data, the latter staying lazy. On representative models this lowers build and export peak memory by up to ~3x. **Deprecations** diff --git a/linopy/expressions.py b/linopy/expressions.py index a3b10ba8..3e680c3a 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -74,7 +74,6 @@ FACTOR_DIM, GREATER_EQUAL, GROUP_DIM, - GROUPED_TERM_DIM, HELPER_DIMS, LESS_EQUAL, STACKED_TERM_DIM, @@ -341,18 +340,7 @@ def sum( # At this point, group is always a pandas Series assert isinstance(group, pd.Series) - numpy_backed = all( - isinstance(self.data[k].data, np.ndarray) - for k in ("coeffs", "vars", "const") - ) - if numpy_backed: - ds = self._sum_by_scatter(group) - else: - logger.debug( - "groupby-sum: non-numpy-backed (e.g. dask) data, " - "falling back to the unstack kernel." - ) - ds = self._sum_by_unstack(group) + ds = self._grouped_sum(group) if int_map is not None: index = ds.indexes[GROUP_DIM].map({v: k for k, v in int_map.items()}) @@ -373,20 +361,18 @@ def func(ds: Dataset) -> Dataset: return self.map(func, **kwargs, shortcut=True) - def _sum_by_scatter(self, group: pd.Series) -> Dataset: + def _grouped_sum(self, group: pd.Series) -> Dataset: """ Sum groups by scattering all terms directly into the final padded arrays. Every group member keeps its block of ``nterm`` terms, so the resulting term dimension has size ``max_group_size * nterm`` and smaller groups are - padded with fill values. In contrast to :meth:`_sum_by_unstack` only the - result arrays are allocated, without intermediate copies of that size. + padded with fill values. Only the result arrays are allocated, keeping + peak memory at input + result. - Only the term and constant values are computed with numpy; the result - structure (dimensions, coordinates and their order) is assembled by - xarray itself and thereby matches the result of unstacking the group - dimension. The caller dispatches here only for numpy-backed data - (chunked data uses :meth:`_sum_by_unstack`). + The scatter runs inside :func:`xarray.apply_ufunc`, so it covers numpy + and chunked (dask) data alike: for dask the grouped dimension is gathered + into a single chunk and the scatter is applied lazily. """ data = self.data group_dim = group.index.name @@ -400,70 +386,65 @@ def _sum_by_scatter(self, group: pd.Series) -> Dataset: ) n_groups = len(unique_groups) - sizes = np.bincount(codes, minlength=n_groups) - max_size = int(sizes.max()) if n_groups else 0 - + max_size = int(np.bincount(codes, minlength=n_groups).max()) if n_groups else 0 # position of each element within its group (order of appearance) positions = pd.Series(codes).groupby(codes).cumcount().to_numpy() - - def scatter( - da: DataArray, fill: Any - ) -> tuple[tuple[Hashable, ...], np.ndarray]: - """Scatter one term-array into its padded (group x term) layout.""" - rest_dims = [d for d in da.dims if d not in (group_dim, TERM_DIM)] - values = da.transpose(group_dim, *rest_dims, TERM_DIM).values - rest_shape = values.shape[1:-1] - nterm = values.shape[-1] - - out = np.full( - (n_groups, *rest_shape, nterm, max_size), fill, dtype=values.dtype + nterm = data.sizes[TERM_DIM] + + def scatter_terms(values: np.ndarray, fill: Any) -> np.ndarray: + # (..., n_elem, nterm) -> (..., n_groups, nterm * max_size); each + # member's nterm block is kept together, padding at the block's end + rest = values.shape[:-2] + out = np.full((*rest, n_groups, nterm, max_size), fill, dtype=values.dtype) + out[..., codes, :, positions] = np.moveaxis(values, -2, 0) + return out.reshape((*rest, n_groups, nterm * max_size)) + + def group_sum(values: np.ndarray) -> np.ndarray: + # (..., n_elem) -> (..., n_groups), summing within groups, skipping NaN + moved = np.moveaxis(values, -1, 0) + out = np.zeros((n_groups, *moved.shape[1:]), dtype=values.dtype) + np.add.at(out, codes, np.where(np.isnan(moved), 0, moved)) + return np.moveaxis(out, 0, -1) + + def single_chunk(da: DataArray) -> DataArray: + # the scatter's core dims must each sit in one chunk + if da.chunks is None: + return da + return da.chunk({d: -1 for d in (group_dim, TERM_DIM) if d in da.dims}) + + def scatter(da: DataArray, fill: Any) -> DataArray: + return xr.apply_ufunc( + scatter_terms, + single_chunk(da), + kwargs={"fill": fill}, + input_core_dims=[[group_dim, TERM_DIM]], + output_core_dims=[[GROUP_DIM, TERM_DIM]], + exclude_dims={group_dim, TERM_DIM}, + dask="parallelized", + dask_gufunc_kwargs={ + "output_sizes": {GROUP_DIM: n_groups, TERM_DIM: nterm * max_size} + }, + output_dtypes=[da.dtype], ) - locs = (codes, *(slice(None),) * (len(rest_shape) + 1), positions) - out[locs] = values - # collapsing (nterm, max_size) into one axis keeps all terms of one - # group member together, with padding at the end of each block - out = out.reshape((n_groups, *rest_shape, nterm * max_size)) - return (GROUP_DIM, *rest_dims, TERM_DIM), out - - coeffs_dims, coeffs = scatter(data.coeffs, fill_value["coeffs"]) - vars_dims, vars = scatter(data.vars, fill_value["vars"]) - - # constants are summed up within each group, skipping NaN values - const_dims = [d for d in data.const.dims if d != group_dim] - const_values = data.const.transpose(group_dim, *const_dims).values - const = np.zeros((n_groups, *const_values.shape[1:]), dtype=const_values.dtype) - np.add.at(const, codes, np.where(np.isnan(const_values), 0, const_values)) - - structure = data.drop_vars(["coeffs", "vars", "const"]) - structure = structure.drop_dims(group_dim) - structure = structure.expand_dims({GROUP_DIM: unique_groups}) - - return structure.assign( - coeffs=(coeffs_dims, coeffs), - vars=(vars_dims, vars), - const=((GROUP_DIM, *const_dims), const), - ) - def _sum_by_unstack(self, group: pd.Series) -> Dataset: - """ - Sum groups by unstacking the group dimension into a padded helper - dimension and summing over it. - - Equivalent to :meth:`_sum_by_scatter`, but goes through xarray's - unstack/stack machinery. It is the fallback for chunked (dask) data, - which cannot be scattered into preallocated numpy buffers. - """ - group_dim = group.index.name - arrays = [group, group.groupby(group).cumcount()] - idx = pd.MultiIndex.from_arrays(arrays, names=[GROUP_DIM, GROUPED_TERM_DIM]) - new_coords = Coordinates.from_pandas_multiindex(idx, group_dim) - # collapsing group_dim invalidates every coordinate aligned to it - names_to_drop = [ - name for name, coord in self.data.coords.items() if group_dim in coord.dims - ] - ds = self.data.drop_vars(names_to_drop).assign_coords(new_coords) - ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value) - return LinearExpression._sum(ds, dim=GROUPED_TERM_DIM) + const = xr.apply_ufunc( + group_sum, + single_chunk(data.const), + input_core_dims=[[group_dim]], + output_core_dims=[[GROUP_DIM]], + exclude_dims={group_dim}, + dask="parallelized", + dask_gufunc_kwargs={"output_sizes": {GROUP_DIM: n_groups}}, + output_dtypes=[data.const.dtype], + ) + ds = Dataset( + { + "coeffs": scatter(data.coeffs, fill_value["coeffs"]), + "vars": scatter(data.vars, fill_value["vars"]), + "const": const, + } + ) + return ds.assign_coords({GROUP_DIM: unique_groups}) def roll(self, **kwargs: Any) -> LinearExpression: """ diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 20032351..ab65f60d 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -7,8 +7,8 @@ from __future__ import annotations -import logging import warnings +from collections import Counter from collections.abc import Callable from types import SimpleNamespace from typing import Any @@ -29,7 +29,7 @@ Variable, merge, ) -from linopy.constants import HELPER_DIMS, TERM_DIM +from linopy.constants import FACTOR_DIM, HELPER_DIMS, TERM_DIM from linopy.expressions import ScalarLinearExpression from linopy.testing import assert_linequal, assert_quadequal from linopy.variables import ScalarVariable @@ -1912,8 +1912,8 @@ def test_linear_expression_groupby_from_variable(v: Variable) -> None: @pytest.fixture -def scatter_ctx() -> SimpleNamespace: - """Shared 60-element building blocks for the scatter-vs-unstack case table.""" +def groupby_ctx() -> SimpleNamespace: + """Shared 60-element building blocks for the groupby-sum case table.""" m = Model() rng = np.random.default_rng(0) idx = pd.RangeIndex(60, name="elem") @@ -1955,9 +1955,9 @@ def scatter_ctx() -> SimpleNamespace: ) -# Each case maps a structure to (expr, groups) from `scatter_ctx`. The skewed +# Each case maps a structure to (expr, groups) from `groupby_ctx`. The skewed # group puts ~half the elements in group 0 and spreads 1..7 over the rest. -SCATTER_EQUALS_UNSTACK_CASES = { +GROUPBY_SUM_CASES = { "skewed_int_groups": lambda c: (3 * c.x - 2 * c.x + 7, c.skewed), "multidim_with_const": lambda c: (2 * c.y + 1 * c.y + c.const, c.skewed), "nan_const": lambda c: (1 * c.x + c.nan_vec, c.skewed), @@ -1979,37 +1979,84 @@ def scatter_ctx() -> SimpleNamespace: } -class TestGroupbySumScatterKernel: +def _term_multisets( + ds: xr.Dataset, row_dim: str, labels: list, loop_dims: list +) -> dict: """ - ``groupby(...).sum()`` takes a scatter fast path (``_sum_by_scatter``) for - numpy-backed expressions and falls back to the xarray unstack machinery - (``_sum_by_unstack``) for chunked data and exotic coordinates. These tests - pin the two kernels together and cover the structural edge cases. + Map ``(group label, slice over loop dims)`` to a multiset of live terms. + + A term is a coefficient paired with its sorted variable labels (one label for + a linear term, two for a quadratic factor pair); fill terms (all labels -1) + are skipped. Rows sharing a group label are merged, so the same helper + summarises the per-element input and the per-group result. + """ + coeffs = ds.coeffs.transpose(row_dim, *loop_dims, TERM_DIM).values + var_dims = [row_dim, *loop_dims, TERM_DIM] + if FACTOR_DIM in ds.vars.dims: + var_dims.append(FACTOR_DIM) + vars_ = ds.vars.transpose(*var_dims).values + if FACTOR_DIM not in ds.vars.dims: + vars_ = vars_[..., None] + + terms: dict = {} + for r in range(coeffs.shape[0]): + for loop in np.ndindex(*coeffs.shape[1:-1]): + bucket = terms.setdefault((labels[r], loop), Counter()) + cs, vs = coeffs[(r, *loop)], vars_[(r, *loop)] + for t in range(cs.shape[0]): + factors = tuple(sorted(int(f) for f in vs[t])) + if all(f == -1 for f in factors): + continue + bucket[(round(float(cs[t]), 9), factors)] += 1 + return terms + + +def _assert_grouped_sum_correct( + expr: LinearExpression, groups: pd.Series, *, chunked: bool = False +) -> None: """ + Verify ``groupby(...).sum()`` from first principles, without a reference + implementation. - @staticmethod - def _assert_kernels_identical(gb: Any, groups: pd.Series, m: Model) -> None: - """Force both kernels and assert they produce the same expression.""" - scatter = LinearExpression(gb._sum_by_scatter(groups).rename(_group="g"), m) - unstack = LinearExpression(gb._sum_by_unstack(groups).rename(_group="g"), m) - - assert scatter.data.coeffs.dims == unstack.data.coeffs.dims - assert scatter.data.const.dims == unstack.data.const.dims - assert list(scatter.data.coords) == list(unstack.data.coords) - for name in scatter.data.coords: - assert_equal(scatter.data[name], unstack.data[name]) - - np.testing.assert_array_equal(scatter.vars.values, unstack.vars.values) - np.testing.assert_array_equal(scatter.coeffs.values, unstack.coeffs.values) - # constants may differ only by floating-point summation order + For every group and every slice over the non-grouped dimensions, the result's + live terms must be exactly the multiset of its members' terms, and its + constant the members' NaN-skipping sum. With ``chunked=True`` the same check + runs against dask-backed input, exercising the lazy kernel. + """ + group_dim = groups.index.name + gname = groups.name + if chunked: + expr = LinearExpression(expr.data.chunk({group_dim: 17}), expr.model) + in_ds = expr.data + out_ds = expr.groupby(groups).sum().data + loop_dims = [d for d in in_ds.coeffs.dims if d not in (group_dim, TERM_DIM)] + + expected = _term_multisets(in_ds, group_dim, list(groups.values), loop_dims) + actual = _term_multisets(out_ds, gname, list(out_ds[gname].values), loop_dims) + assert actual == expected + + const_loop = [d for d in in_ds.const.dims if d != group_dim] + in_const = np.nan_to_num(in_ds.const.transpose(group_dim, *const_loop).values) + out_const = out_ds.const.transpose(gname, *const_loop).values + member_of = np.asarray(groups.values) + for i, g in enumerate(out_ds[gname].values): np.testing.assert_allclose( - scatter.const.values, unstack.const.values, rtol=1e-12 + in_const[member_of == g].sum(0), out_const[i], rtol=1e-9, atol=1e-12 ) + +class TestGroupbySumKernel: + """ + ``groupby(...).sum()`` builds the padded result with a single + :func:`xarray.apply_ufunc` kernel (``_grouped_sum``) over numpy and chunked + (dask) data. These tests verify that kernel from first principles and cover + the structural edge cases. + """ + def test_skewed_unsorted_groups(self, v: Variable) -> None: """ - The scatter-based fast path must match the xarray fallback for groups - that are unsorted, non-contiguous and of very different sizes. + The kernel must match the xarray fallback for groups that are unsorted, + non-contiguous and of very different sizes. """ expr = 2 * v + 5 # 'b' appears 14 times, 'c' 5 times, 'a' once, scattered over the dimension @@ -2033,18 +2080,21 @@ def test_skewed_unsorted_groups(self, v: Variable) -> None: assert (vars_of_group >= 0).sum() == len(members) * expr.nterm assert grouped.const.sel(letter=letter).item() == 5 * len(members) - def test_chunked_uses_unstack( - self, v: Variable, caplog: pytest.LogCaptureFixture - ) -> None: - """Chunked (dask-backed) expressions group via xarray's unstack path.""" + @pytest.mark.parametrize("chunks", [{"dim_2": 5}, {"dim_2": 5, "_term": 1}]) + def test_chunked_runs_lazily(self, v: Variable, chunks: dict) -> None: + """ + The kernel handles chunked (dask) data, staying lazy until computed. It + gathers both the grouped and term dimensions into single chunks, so a + split ``_term`` (a core dim of the scatter) is handled too. + """ pytest.importorskip("dask") - expr = 2 * v + 5 + expr = 2 * v + 3 * v + 5 # nterm 2, so `_term` can be split groups = pd.Series([1] * 12 + [2] * 8, index=v.indexes["dim_2"], name="group") - chunked = LinearExpression(expr.data.chunk({"dim_2": 5}), expr.model) - with caplog.at_level(logging.DEBUG, logger="linopy.expressions"): - grouped_chunked = chunked.groupby(groups).sum() - assert "falling back to the unstack kernel" in caplog.text + chunked = LinearExpression(expr.data.chunk(chunks), expr.model) + grouped_chunked = chunked.groupby(groups).sum() + # the result stays a lazy dask graph until explicitly computed + assert grouped_chunked.data.vars.chunks is not None grouped = expr.groupby(groups).sum() assert grouped_chunked.nterm == grouped.nterm @@ -2060,8 +2110,101 @@ def test_nan_groups_raise(self, v: Variable) -> None: with pytest.raises(ValueError, match="NaN"): expr.groupby(groups).sum() + def test_exact_padded_layout(self) -> None: + """ + Pin the concrete padded layout the property checks abstract away: member + order within a group, fill at the end of a short group's block, the + ``(g, _term)`` dim order, and the per-group constant sum. + """ + m = Model() + idx = pd.RangeIndex(4, name="elem") + x = m.add_variables(coords=[idx], name="x") + groups = pd.Series([0, 0, 1, 0], index=idx, name="g") + + grouped = (2 * x + 5).groupby(groups).sum() + lab = x.labels.values + + assert grouped.data.vars.dims == ("g", "_term") + assert list(grouped.data.g.values) == [0, 1] + assert grouped.nterm == 3 # max group size (3) * input nterm (1) + # group 0 holds members 0, 1, 3 in order; group 1 holds member 2 then fill + np.testing.assert_array_equal( + grouped.data.vars.values, [[lab[0], lab[1], lab[3]], [lab[2], -1, -1]] + ) + np.testing.assert_array_equal( + grouped.data.coeffs.values, [[2.0, 2.0, 2.0], [2.0, np.nan, np.nan]] + ) + np.testing.assert_array_equal(grouped.const.values, [15.0, 5.0]) + + def test_exact_padded_layout_multidim(self) -> None: + """ + Anchor the layout with a non-grouped dim and two input terms: the + ``(nterm, max_size)`` interleaving and per-slice padding must hold. + """ + m = Model() + other = pd.Index(list("ab"), name="other") + idx = pd.RangeIndex(4, name="elem") + y = m.add_variables(coords=[other, idx], name="y") + groups = pd.Series([0, 0, 1, 0], index=idx, name="g") + + grouped = (2 * y + 3 * y + 1).groupby(groups).sum() + lab = y.labels.values # (other, elem) + + assert set(grouped.data.vars.dims) == {"g", "other", "_term"} + # the non-grouped coord survives; the group coord holds the sorted keys + assert list(grouped.data.other.values) == ["a", "b"] + assert list(grouped.data.g.values) == [0, 1] + assert grouped.nterm == 6 # max group size (3) * input nterm (2) + # group 0 holds members 0, 1, 3 in order; group 1 holds member 2 then fill + m0 = [0, 1, 3] + exp_vars = [ + [list(lab[o, m0]) * 2 for o in (0, 1)], + [[lab[o, 2], -1, -1] * 2 for o in (0, 1)], + ] + np.testing.assert_array_equal( + grouped.data.vars.transpose("g", "other", "_term").values, exp_vars + ) + term, pad = [2.0, 2, 2, 3, 3, 3], [2.0, np.nan, np.nan, 3, np.nan, np.nan] + np.testing.assert_array_equal( + grouped.data.coeffs.transpose("g", "other", "_term").values, + [[term, term], [pad, pad]], + ) + np.testing.assert_array_equal( + grouped.data.const.transpose("g", "other").values, [[3.0, 3], [1, 1]] + ) + + def test_exact_padded_layout_quadratic(self) -> None: + """ + Anchor the layout for a quadratic: the ``_factor`` axis must scatter + through to the padded result alongside the term positions. + """ + m = Model() + idx = pd.RangeIndex(4, name="elem") + x = m.add_variables(coords=[idx], name="x") + y = m.add_variables(coords=[idx], name="y") + groups = pd.Series([0, 0, 1, 0], index=idx, name="g") + + grouped = (x * y).groupby(groups).sum() + lx, ly = x.labels.values, y.labels.values + + assert set(grouped.data.vars.dims) == {"g", "_factor", "_term"} + assert grouped.nterm == 3 + # factor 0 carries x, factor 1 carries y; members 0, 1, 3 then fill + m0 = [0, 1, 3] + exp_vars = [ + [list(lx[m0]), list(ly[m0])], + [[lx[2], -1, -1], [ly[2], -1, -1]], + ] + np.testing.assert_array_equal( + grouped.data.vars.transpose("g", "_factor", "_term").values, exp_vars + ) + np.testing.assert_array_equal( + grouped.data.coeffs.transpose("g", "_term").values, + [[1.0, 1, 1], [1, np.nan, np.nan]], + ) + def test_empty_groups(self) -> None: - """An empty group dimension scatters into an empty, well-formed result.""" + """An empty group dimension produces an empty, well-formed result.""" m = Model() idx = pd.RangeIndex(0, name="elem") x = m.add_variables(coords=[idx], name="x") @@ -2071,28 +2214,27 @@ def test_empty_groups(self) -> None: assert grouped.nterm == 0 assert dict(grouped.data.sizes) == {"g": 0, "_term": 0} + @pytest.mark.parametrize("backend", ["numpy", "dask"]) @pytest.mark.parametrize( "build", - SCATTER_EQUALS_UNSTACK_CASES.values(), - ids=SCATTER_EQUALS_UNSTACK_CASES.keys(), + GROUPBY_SUM_CASES.values(), + ids=GROUPBY_SUM_CASES.keys(), ) - def test_scatter_equals_unstack( + def test_grouped_sum_correct( self, build: Callable[[SimpleNamespace], tuple[LinearExpression, pd.Series]], - scatter_ctx: SimpleNamespace, + groupby_ctx: SimpleNamespace, + backend: str, ) -> None: """ - Lock the two groupby-sum kernels together. - - The fast path scatters terms into numpy arrays (``_sum_by_scatter``); - the unstack implementation (``_sum_by_unstack``) is kept for chunked - data. Both must stay interchangeable — if an xarray/pandas update - changes the unstack output or an edge case diverges, this fails. See - ``SCATTER_EQUALS_UNSTACK_CASES`` for the structures covered. + Each group's terms and constant must match its members, from first + principles, on both numpy and dask backends. See ``GROUPBY_SUM_CASES`` + for the structures covered. """ - expr, groups = build(scatter_ctx) - gb = expr.groupby(groups) - self._assert_kernels_identical(gb, groups, scatter_ctx.m) + if backend == "dask": + pytest.importorskip("dask") + expr, groups = build(groupby_ctx) + _assert_grouped_sum_correct(expr, groups, chunked=backend == "dask") def test_linear_expression_rolling(v: Variable) -> None: