Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion daft_lance/lance_merge_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,11 @@ def _can_use_fast_path(
return False
if "fragment_id" not in df.column_names:
return False
df_row_count = len(df.collect())
# Use count_rows() rather than collect() to avoid caching one-shot Python
# objects (e.g. BlobFile) into df._result_cache. collect() caches its result
# so a subsequent groupby().map_groups() would receive the same exhausted
# Python objects instead of fresh ones.
df_row_count = df.count_rows()
ds_row_count = lance_ds.count_rows()
return df_row_count == ds_row_count

Expand Down
28 changes: 28 additions & 0 deletions tests/io/lancedb/test_fast_path_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,31 @@ def test_scan_fragments_individually(self, ds_path):

combined = sorted(zip(all_ids, all_doubled))
assert combined == [(1, 2), (2, 4), (3, 6), (4, 8)]


# ---------------------------------------------------------------------------
# 10. Regressions
# ---------------------------------------------------------------------------


class TestRegressions:
    """Regression tests for previously fixed fast-path merge bugs."""

    def test_fast_path_check_does_not_set_result_cache(self, ds_path):
        """Checking fast-path eligibility must leave ``df._result_cache`` unset.

        Background: ``_can_use_fast_path`` used to obtain the row count via
        ``df.collect()``, and Daft stores the outcome of ``collect()`` in
        ``_result_cache``. If that cache holds one-shot Python objects (for
        example ``BlobFile`` values produced by ``take_blobs()``), they are
        already exhausted by the time the fast-path merge runs the pipeline
        again through ``groupby().map_groups()``; the stale objects are handed
        back and downstream UDFs yield null or wrong output.

        The fixed implementation counts rows with ``df.count_rows()``, which
        does not populate ``_result_cache``. This test pins that behaviour.
        """
        lance_ds = create_dataset(ds_path, [{"id": [1, 2, 3]}])
        frame = read_with_metadata(ds_path).with_column("x", daft.lit(1))

        # Precondition: nothing has been materialised for this frame yet.
        cache_before = getattr(frame, "_result_cache", None)
        assert cache_before is None, "cache should start empty"

        # Row counts of the frame and the dataset match, so the fast path
        # must be reported as usable.
        assert _can_use_fast_path(frame, lance_ds, "_rowaddr") is True

        # count_rows() must leave the cache untouched; collect() would have
        # filled it in.
        cache_after = getattr(frame, "_result_cache", None)
        assert cache_after is None, (
            "_result_cache was set — fast-path check used collect() instead of count_rows()"
        )