# Patch summary (explanatory preamble placed before the first "diff --git" header).
#
# Files touched:
#   * daft_lance/lance_merge_column.py
#       In the hunk labelled _can_use_fast_path, the DataFrame row count is
#       obtained with df.count_rows() instead of len(df.collect()). Per the
#       comment added by the patch, the intent is to avoid caching one-shot
#       Python objects (e.g. BlobFile) in df._result_cache, so that a later
#       groupby().map_groups() does not receive already-exhausted objects.
#       The comparison against lance_ds.count_rows() is unchanged.
#   * tests/io/lancedb/test_fast_path_merge.py
#       Adds section "10. Regressions" with
#       TestRegressions.test_fast_path_check_does_not_set_result_cache, which
#       asserts that df._result_cache is None both before and after calling
#       _can_use_fast_path(df, ds, "_rowaddr"), and that the call returns True.
#
# NOTE(review): the hunk headers below declare multi-line hunks
# (-342,7 +342,11 and -587,3 +587,31), but the patch text is stored on only
# two physical lines, i.e. the original line breaks and indentation appear to
# have been lost. The patch presumably needs its line structure restored
# before it can be applied -- confirm against the original commit.
diff --git a/daft_lance/lance_merge_column.py b/daft_lance/lance_merge_column.py index c0e0cb9..ac624b7 100644 --- a/daft_lance/lance_merge_column.py +++ b/daft_lance/lance_merge_column.py @@ -342,7 +342,11 @@ def _can_use_fast_path( return False if "fragment_id" not in df.column_names: return False - df_row_count = len(df.collect()) + # Use count_rows() rather than collect() to avoid caching one-shot Python + # objects (e.g. BlobFile) into df._result_cache. collect() caches its result + # so a subsequent groupby().map_groups() would receive the same exhausted + # Python objects instead of fresh ones. + df_row_count = df.count_rows() ds_row_count = lance_ds.count_rows() return df_row_count == ds_row_count diff --git a/tests/io/lancedb/test_fast_path_merge.py b/tests/io/lancedb/test_fast_path_merge.py index cde4235..99d768b 100644 --- a/tests/io/lancedb/test_fast_path_merge.py +++ b/tests/io/lancedb/test_fast_path_merge.py @@ -587,3 +587,31 @@ def test_scan_fragments_individually(self, ds_path): combined = sorted(zip(all_ids, all_doubled)) assert combined == [(1, 2), (2, 4), (3, 6), (4, 8)] + + +# --------------------------------------------------------------------------- +# 10. Regressions +# --------------------------------------------------------------------------- + + +class TestRegressions: + def test_fast_path_check_does_not_set_result_cache(self, ds_path): + """Bug: _can_use_fast_path called df.collect(), which sets df._result_cache. + + Daft caches collect() results in _result_cache. One-shot Python objects + in that cache (e.g. BlobFile from take_blobs()) are exhausted; when the + fast-path merge re-executes the pipeline via groupby().map_groups(), the + same stale objects are returned and downstream UDFs produce null/wrong output. + + The fix uses df.count_rows() which does NOT populate _result_cache. 
+ """ + ds = create_dataset(ds_path, [{"id": [1, 2, 3]}]) + df = read_with_metadata(ds_path).with_column("x", daft.lit(1)) + + assert getattr(df, "_result_cache", None) is None, "cache should start empty" + result = _can_use_fast_path(df, ds, "_rowaddr") + assert result is True + # count_rows() must not populate _result_cache; collect() would have done so + assert getattr(df, "_result_cache", None) is None, ( + "_result_cache was set — fast-path check used collect() instead of count_rows()" + )