diff --git a/api/src/main/java/org/apache/iceberg/expressions/PathUtil.java b/api/src/main/java/org/apache/iceberg/expressions/PathUtil.java index 2e943bfdbfbd..bb972c1805d0 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/PathUtil.java +++ b/api/src/main/java/org/apache/iceberg/expressions/PathUtil.java @@ -27,13 +27,27 @@ import java.util.stream.Collectors; import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; -import org.apache.iceberg.relocated.com.google.common.base.Splitter; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Streams; public class PathUtil { private PathUtil() {} + /** + * One step in a variant JSONPath: an object member name or a zero-based array index (RFC 9535 + * {@code [n]} selector). + * + *

This is a copy of the canonical implementation from PR #15384. Once that PR merges, this + * class will be the single definition and this note can be removed. + */ + public sealed interface PathSegment permits PathSegment.Name, PathSegment.Index { + record Name(String name) implements PathSegment {} + + record Index(int index) implements PathSegment {} + } + private static final String RFC9535_NAME_FIRST = "[A-Za-z_\\x{0080}-\\x{D7FF}\\x{E000}-\\x{10FFFF}]"; private static final String RFC9535_NAME_CHARS = @@ -41,46 +55,263 @@ private PathUtil() {} private static final Predicate RFC9535_MEMBER_NAME_SHORTHAND = Pattern.compile(RFC9535_NAME_FIRST + RFC9535_NAME_CHARS).asMatchPredicate(); + /** Letters that follow {@code \} for control-character escapes in RFC 9535 quoted segments. */ + private static final String RFC9535_SIMPLE_ESCAPE_LETTERS = "btnfr"; + + private static final String RFC9535_SIMPLE_ESCAPE_CHARS = "\b\t\n\f\r"; + private static final Pattern RFC9535_REQUIRES_ESCAPE = Pattern.compile( "[^\\x{0020}-\\x{0026}\\x{0028}-\\x{005B}\\x{005D}-\\x{D7FF}\\x{E000}-\\x{10FFFF}]"); + /** + * Matches one bracket segment {@code ['...']} where inner text may contain RFC 9535 escapes + * (quote, backslash, control characters, and four-digit hex escapes). + */ + private static final Pattern BRACKET_SEGMENT = Pattern.compile("\\['((?:[^'\\\\]|\\\\.)*)'\\]"); + private static final Map RFC9535_ESCAPE_REPLACEMENTS = buildReplacementMap(); - private static final Splitter DOT = Splitter.on("."); private static final String ROOT = "$"; - static List parse(String path) { + /** + * Parses a path into segments. After the root {@code $}, each segment is either dot shorthand + * ({@code .name} per RFC 9535), a single-quoted bracket name ({@code ['...']}) with RFC 9535 + * escapes, or a numeric array index ({@code [n]}). Forms may be mixed (e.g. {@code $.a['b.c']}, + * {@code $.items[0].tags}, {@code $.matrix[0][1]}). Wildcards and recursive descent are not + * supported. + * + *

The root path {@code $} yields an empty segment list. + */ + public static List parse(String path) { Preconditions.checkArgument(path != null, "Invalid path: null"); + Preconditions.checkArgument(!path.isEmpty(), "Invalid path: empty"); + Preconditions.checkArgument( + path.startsWith(ROOT), "Invalid path, does not start with %s: %s", ROOT, path); + + if (path.equals(ROOT)) { + return Lists.newArrayList(); + } + + return parseAfterRoot(path); + } + + /** Normalizes object field names only (no array indices). */ + public static String toNormalizedPath(Iterable fields) { + return toNormalizedPath( + Streams.stream(fields).map(PathSegment.Name::new).collect(Collectors.toList())); + } + + static String toNormalizedPath(List segments) { + StringBuilder builder = new StringBuilder(ROOT); + for (PathSegment segment : segments) { + if (segment instanceof PathSegment.Name) { + String name = ((PathSegment.Name) segment).name(); + builder.append("['").append(rfc9535escape(name)).append("']"); + } else if (segment instanceof PathSegment.Index) { + int index = ((PathSegment.Index) segment).index(); + Preconditions.checkArgument(index >= 0, "Invalid path, negative array index: %s", index); + builder.append('[').append(index).append(']'); + } else { + throw new IllegalStateException("Unknown segment: " + segment); + } + } + return builder.toString(); + } + + private static List parseAfterRoot(String path) { + List segments = Lists.newArrayList(); + Matcher bracketMatcher = BRACKET_SEGMENT.matcher(path); + int len = path.length(); + int pos = ROOT.length(); + + while (pos < len) { + char ch = path.charAt(pos); + pos = + switch (ch) { + case '.' -> appendDotSegment(segments, path, pos); + case '[' -> appendBracketOrIndexSegment(segments, path, pos, bracketMatcher); + default -> + throw new IllegalArgumentException( + String.format( + "Invalid path, expected '.' 
or '[' at position %s: %s", pos, path)); + }; + } + + return segments; + } + + /** + * Appends a dot-style segment to {@code segments} by reading from {@code path[dotPos]}: a single + * leading {@code .} then an RFC 9535 shorthand name until the next {@code .} or {@code [}. + * + * @param segments output; segments parsed so far, updated in place + * @param path full path + * @param dotPos index of the {@code .} starting the segment + */ + private static int appendDotSegment(List segments, String path, int dotPos) { + int pos = dotPos + 1; + int pathLen = path.length(); + Preconditions.checkArgument(pos < pathLen, "Invalid path, trailing dot: %s", path); + int start = pos; + while (pos < pathLen) { + char ch = path.charAt(pos); + if (ch == '.' || ch == '[') { + break; + } + pos++; + } + + Preconditions.checkArgument(pos > start, "Invalid path, empty segment after '.': %s", path); + String name = path.substring(start, pos); Preconditions.checkArgument( - !path.contains("[") && !path.contains("]"), "Unsupported path, contains bracket: %s", path); + RFC9535_MEMBER_NAME_SHORTHAND.test(name), + "Invalid path: %s (%s has invalid characters)", + path, + name); + segments.add(new PathSegment.Name(name)); + return pos; + } + + /** + * Appends a bracket segment to {@code segments} starting at {@code path[bracketPos]}. If the next + * character is a digit, consumes a numeric array index {@code [n]}; otherwise consumes a quoted + * name {@code ['...']}. A lone {@code [} with no following quoted form (e.g. the path ends at + * {@code $[}) is rejected in {@link #appendQuotedBracketSegment} when the pattern does not match. 
+ * + * @param segments output; segments parsed so far, updated in place + * @param path full path + * @param bracketPos index of the opening {@code [} + */ + private static int appendBracketOrIndexSegment( + List segments, String path, int bracketPos, Matcher bracketMatcher) { Preconditions.checkArgument( - !path.contains("*"), "Unsupported path, contains wildcard: %s", path); + bracketPos < path.length() && path.charAt(bracketPos) == '[', "Invalid path: %s", path); + if (bracketPos + 1 < path.length() && isAsciiDigit(path.charAt(bracketPos + 1))) { + return appendArrayIndexSegment(segments, path, bracketPos); + } + return appendQuotedBracketSegment(segments, path, bracketPos, bracketMatcher); + } + + private static boolean isAsciiDigit(char ch) { + return ch >= '0' && ch <= '9'; + } + + /** + * Appends a non-negative array index from {@code [n]} to {@code segments}, starting with {@code + * [} at {@code path[bracketPos]}. + * + * @param segments output; segments parsed so far, updated in place + * @param path full path + * @param bracketPos index of the opening {@code [} before the digits + */ + private static int appendArrayIndexSegment( + List segments, String path, int bracketPos) { + int pos = bracketPos + 1; + int len = path.length(); + int start = pos; + while (pos < len && isAsciiDigit(path.charAt(pos))) { + pos++; + } + Preconditions.checkArgument(pos > start, "Invalid path, empty array index in: %s", path); Preconditions.checkArgument( - !path.contains(".."), "Unsupported path, contains recursive descent: %s", path); + pos < len && path.charAt(pos) == ']', "Invalid path, unclosed array index in: %s", path); + int index; + String digits = path.substring(start, pos); + try { + index = Integer.parseInt(digits); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + String.format("Invalid path, array index out of int range: %s", path), e); + } + Preconditions.checkArgument(index >= 0, "Invalid path, negative array index in: %s", path); 
+ segments.add(new PathSegment.Index(index)); + return pos + 1; + } - List parts = DOT.splitToList(path); + /** + * Appends a name from a {@code ['...']} segment to {@code segments} using the bracket matcher + * (inner text may use RFC 9535 escapes). Expects a full quoted bracket token at {@code + * path[bracketPos]}; otherwise the matcher or alignment checks throw. + * + * @param segments output; segments parsed so far, updated in place + * @param path full path + * @param bracketPos index of the opening {@code [} that must begin {@code ['} + */ + private static int appendQuotedBracketSegment( + List segments, String path, int bracketPos, Matcher bracketMatcher) { Preconditions.checkArgument( - ROOT.equals(parts.get(0)), "Invalid path, does not start with %s: %s", ROOT, path); - - List names = parts.subList(1, parts.size()); - for (String name : names) { - Preconditions.checkArgument( - RFC9535_MEMBER_NAME_SHORTHAND.test(name), - "Invalid path: %s (%s has invalid characters)", - path, - name); + bracketMatcher.find(bracketPos), "Invalid path, malformed bracket segment: %s", path); + Preconditions.checkArgument( + bracketMatcher.start() == bracketPos, + "Invalid path, unexpected characters at position %s: %s", + bracketPos, + path); + segments.add(new PathSegment.Name(rfc9535unescape(bracketMatcher.group(1)))); + return bracketMatcher.end(); + } + + /** Unescapes the inner text of a {@code ['...']} segment (inverse of {@link #rfc9535escape}). 
*/ + @VisibleForTesting + @SuppressWarnings("StatementSwitchToExpressionSwitch") + static String rfc9535unescape(String escaped) { + if (!escaped.contains("\\")) { + return escaped; + } + + StringBuilder builder = new StringBuilder(escaped.length()); + int cursor = 0; + while (cursor < escaped.length()) { + char ch = escaped.charAt(cursor); + if (ch != '\\') { + builder.append(ch); + cursor += 1; + } else { + Preconditions.checkArgument( + cursor + 1 < escaped.length(), "Invalid escape sequence at end of: %s", escaped); + char next = escaped.charAt(cursor + 1); + switch (next) { + case 'u': + Preconditions.checkArgument( + cursor + 5 < escaped.length(), + "Invalid \\uXXXX escape at position %s in: %s", + cursor, + escaped); + builder.append((char) Integer.parseInt(escaped.substring(cursor + 2, cursor + 6), 16)); + cursor += 6; + break; + case 'b': + case 't': + case 'f': + case 'n': + case 'r': + case '\'': + case '\\': + builder.append(rfc9535SimpleEscapedChar(next)); + cursor += 2; + break; + default: + throw new IllegalArgumentException( + "Invalid escape sequence \\" + next + " in: " + escaped); + } + } } - return names; + return builder.toString(); } - public static String toNormalizedPath(Iterable fields) { - return ROOT - + Streams.stream(fields) - .map(PathUtil::rfc9535escape) - .map(name -> "['" + name + "']") - .collect(Collectors.joining("")); + private static char rfc9535SimpleEscapedChar(char next) { + int idx = RFC9535_SIMPLE_ESCAPE_LETTERS.indexOf(next); + if (idx >= 0) { + return RFC9535_SIMPLE_ESCAPE_CHARS.charAt(idx); + } + if (next == '\'') { + return '\''; + } + if (next == '\\') { + return '\\'; + } + throw new IllegalArgumentException("Invalid simple escape: \\" + next); } @VisibleForTesting diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionBinding.java b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionBinding.java index 24e58ad1e808..5be76bc0d7d4 100644 --- 
a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionBinding.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionBinding.java @@ -361,7 +361,9 @@ public void testExtractExpressionNonVariant() { new String[] { "$", // root path "$.event_id", - "$.event.id" + "$.event.id", + "$['event_id']", // bracket notation + "$.events[0].event_id" // array index accessor }; @ParameterizedTest @@ -378,9 +380,7 @@ public void testExtractExpressionBindingPaths(String path) { null, "", "event_id", // missing root - "$['event_id']", // uses bracket notation "$..event_id", // uses recursive descent - "$.events[0].event_id", // uses position accessor "$.events.*" // uses wildcard }; diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestPathUtil.java b/api/src/test/java/org/apache/iceberg/expressions/TestPathUtil.java index 9115a8fcd2dd..2bb1801cda36 100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestPathUtil.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestPathUtil.java @@ -18,11 +18,14 @@ */ package org.apache.iceberg.expressions; +import static org.apache.iceberg.expressions.PathUtil.PathSegment.Index; +import static org.apache.iceberg.expressions.PathUtil.PathSegment.Name; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; import java.util.List; +import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.FieldSource; @@ -30,9 +33,13 @@ @SuppressWarnings({"AvoidEscapedUnicodeCharacters", "IllegalTokenText"}) public class TestPathUtil { + private static List names(String... 
names) { + return java.util.Arrays.stream(names).map(Name::new).collect(Collectors.toList()); + } + @Test public void testSimplePath() { - assertThat(PathUtil.parse("$.event.id")).isEqualTo(List.of("event", "id")); + assertThat(PathUtil.parse("$.event.id")).isEqualTo(names("event", "id")); } private static final String[] VALID_PATHS = @@ -40,8 +47,18 @@ public void testSimplePath() { "$", // root path "$.event_id", "$.event.id", + "$['event_id']", // bracket form + "$.event['x.y']", // mixed: dot then bracket + "$['event']['id']", // bracket then bracket + "$['a'].b", // bracket then dot "$.\u2603", // snowman "$.\uD834\uDD1E", // surrogate pair, U+1D11E + "$.matrix[0][1]", + "$.basket[0][2].a", + "$.items[0].tags[1]", + "$['matrix'][0][1]", + "$['items'][0]['tags'][1]", + "$['basket'][0][2]['a']", }; @ParameterizedTest @@ -55,9 +72,7 @@ public void testExtractExpressionBindingPaths(String path) { null, "", "event_id", // missing root - "$['event_id']", // uses bracket notation "$..event_id", // uses recursive descent - "$.events[0].event_id", // uses position accessor "$.events.*", // uses wildcard "$.0invalid", // starts with a digit "$._\uD834", // dangling high surrogate @@ -79,6 +94,10 @@ public void testExtractBindingWithInvalidPath(String path) { new String[] {"$.a.b.c", "$['a']['b']['c']"}, new String[] {"$.\u2603", "$['☃']"}, new String[] {"$.a\uD834\uDD1Eb.x", "$['a\uD834\uDD1Eb']['x']"}, + // Dot shorthand vs bracket form for names; array steps use unquoted [n] in both. 
+ new String[] {"$.matrix[0][1]", "$['matrix'][0][1]"}, + new String[] {"$.items[0].tags[1]", "$['items'][0]['tags'][1]"}, + new String[] {"$.basket[0][2].a", "$['basket'][0][2]['a']"}, }; @ParameterizedTest @@ -123,4 +142,30 @@ public void testNormalizedFieldLists(List fields, String normalizedPath) public void testPathEscaping(String name, String escaped) { assertThat(PathUtil.rfc9535escape(name)).isEqualTo(escaped); } + + @Test + void testParseArrayPath() { + assertThat(PathUtil.parse("$.commits[0].author.name")) + .isEqualTo( + List.of(new Name("commits"), new Index(0), new Name("author"), new Name("name"))); + } + + @Test + void testParseArrayIndexOnly() { + assertThat(PathUtil.parse("$.a[1][2].b")) + .isEqualTo(List.of(new Name("a"), new Index(1), new Index(2), new Name("b"))); + } + + @Test + void testParseBracketMixed() { + assertThat(PathUtil.parse("$['issue']['labels'][0]['name']")) + .isEqualTo(List.of(new Name("issue"), new Name("labels"), new Index(0), new Name("name"))); + } + + @Test + void testNameAndIndexSegmentsAreDistinct() { + // $[0] is an array index; $['[0]'] is a field whose name is literally "[0]" — must not conflate + assertThat(PathUtil.parse("$[0]")).containsExactly(new Index(0)); + assertThat(PathUtil.parse("$['[0]']")).containsExactly(new Name("[0]")); + } } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java index f4323f1c0350..119f0eaff34b 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java @@ -26,6 +26,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import 
org.apache.iceberg.spark.data.SparkVariantExtractionUtil; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Type.TypeID; import org.apache.iceberg.types.TypeUtil; @@ -128,6 +129,16 @@ public Type field(Types.NestedField field, Supplier fieldResult) { "Cannot project an optional field as non-null: %s", field.name()); + // When the Iceberg field is VARIANT and Spark has rewritten it to an annotated struct + // (all sub-fields carry __VARIANT_METADATA_KEY via SupportsPushDownVariantExtractions), + // treat the field as a full variant projection. The struct is an optimizer artifact that + // maps ordinal slot i to a specific shredded path; Iceberg's projection must keep the full + // VariantType column so the reader can access both metadata/value and typed_value columns. + if (field.type().typeId() == TypeID.VARIANT + && SparkVariantExtractionUtil.isVariantExtractionStruct(requestedField.dataType())) { + return Types.VariantType.get(); + } + this.current = requestedField.dataType(); try { return fieldResult.get(); diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java index 8128babfa340..34b087b862a1 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java @@ -276,6 +276,14 @@ public boolean aggregatePushDownEnabled() { .parse(); } + public boolean variantExtractionPushDownEnabled() { + return confParser + .booleanConf() + .sessionConf(SparkSQLProperties.VARIANT_EXTRACTION_PUSH_DOWN_ENABLED) + .defaultValue(SparkSQLProperties.VARIANT_EXTRACTION_PUSH_DOWN_ENABLED_DEFAULT) + .parse(); + } + public boolean adaptiveSplitSizeEnabled() { return confParser .booleanConf() diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java 
b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index ddedc36c7126..fcff086fdbf4 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -45,6 +45,14 @@ private SparkSQLProperties() {} "spark.sql.iceberg.aggregate-push-down.enabled"; public static final boolean AGGREGATE_PUSH_DOWN_ENABLED_DEFAULT = true; + // Controls whether Iceberg accepts variant field extractions + // (SupportsPushDownVariantExtractions). + // Enabled by default. Set to false to keep the scan schema as VariantType when Spark pushes + // variant_get into the scan. + public static final String VARIANT_EXTRACTION_PUSH_DOWN_ENABLED = + "spark.sql.iceberg.variant-extraction-push-down.enabled"; + public static final boolean VARIANT_EXTRACTION_PUSH_DOWN_ENABLED_DEFAULT = true; + // Controls write distribution mode public static final String DISTRIBUTION_MODE = "spark.sql.iceberg.distribution-mode"; diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/data/SparkVariantExtractionUtil.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/data/SparkVariantExtractionUtil.java new file mode 100644 index 000000000000..f3055db33f3a --- /dev/null +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/data/SparkVariantExtractionUtil.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import org.apache.iceberg.expressions.PathUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.VarcharType; +import org.apache.spark.sql.types.VariantType; + +/** + * Utilities for Spark variant extraction pushdown ({@code __VARIANT_METADATA_KEY}). + * + *

Path parsing ({@link #isSupportedExtractionPath}) delegates to {@link PathUtil}. + */ +public class SparkVariantExtractionUtil { + public static final String VARIANT_METADATA_KEY = "__VARIANT_METADATA_KEY"; + public static final String PLACEHOLDER_PATH = "$.__placeholder_field__"; + + private SparkVariantExtractionUtil() {} + + public static boolean isVariantExtractionStruct(DataType dataType) { + if (!(dataType instanceof StructType)) { + return false; + } + + StructType structType = (StructType) dataType; + if (structType.fields().length == 0) { + return false; + } + + for (StructField field : structType.fields()) { + if (!field.metadata().contains(VARIANT_METADATA_KEY)) { + return false; + } + } + + return true; + } + + public static String extractionPath(StructField field) { + Preconditions.checkArgument( + field.metadata().contains(VARIANT_METADATA_KEY), + "Missing %s on field %s", + VARIANT_METADATA_KEY, + field.name()); + Metadata variantMetadata = field.metadata().getMetadata(VARIANT_METADATA_KEY); + return variantMetadata.getString("path"); + } + + public static boolean isPlaceholderExtraction(StructField field) { + return PLACEHOLDER_PATH.equals(extractionPath(field)); + } + + /** + * Returns true when the selective Parquet reader can handle the extraction path. Delegates to + * {@link PathUtil#parse} for RFC 9535-compliant validation; any path that fails to parse is + * declined. + */ + public static boolean isSupportedExtractionPath(String jsonPath) { + try { + PathUtil.parse(jsonPath); + return true; + } catch (IllegalArgumentException e) { + return false; + } + } + + /** + * Returns true when Spark pushdown can materialize the extraction target type from shredded + * variant columns. Must stay in sync with the SparkVariantExtractionReaders cast support added in + * the selective Parquet reader change (see https://github.com/apache/iceberg/pull/16714). 
+ */ + public static boolean isSupportedPushdownTargetType(DataType targetType) { + if (targetType == null || targetType instanceof VariantType) { + return false; + } + + // @todo: add support for container types + // Spark treats char/varchar as string + if (isUnsupportedContainerType(targetType)) { + return false; + } + + return isSupportedPrimitiveType(targetType); + } + + private static boolean isUnsupportedContainerType(DataType targetType) { + return targetType instanceof StructType + || targetType instanceof ArrayType + || targetType instanceof MapType + || targetType instanceof CharType + || targetType instanceof VarcharType; + } + + private static boolean isSupportedPrimitiveType(DataType targetType) { + return isSupportedBasicType(targetType) || isSupportedTemporalOrDecimalType(targetType); + } + + private static boolean isSupportedBasicType(DataType targetType) { + return DataTypes.StringType.sameType(targetType) + || DataTypes.IntegerType.sameType(targetType) + || DataTypes.LongType.sameType(targetType) + || DataTypes.ByteType.sameType(targetType) + || DataTypes.ShortType.sameType(targetType) + || DataTypes.BooleanType.sameType(targetType) + || DataTypes.FloatType.sameType(targetType) + || DataTypes.DoubleType.sameType(targetType) + || DataTypes.DateType.sameType(targetType) + || DataTypes.BinaryType.sameType(targetType); + } + + private static boolean isSupportedTemporalOrDecimalType(DataType targetType) { + return targetType instanceof DecimalType + || targetType instanceof TimestampType + || targetType instanceof TimestampNTZType; + } + + public static String extractionPath( + org.apache.spark.sql.connector.read.VariantExtraction extraction) { + org.apache.spark.sql.types.Metadata metadata = extraction.metadata(); + if (metadata.contains(VARIANT_METADATA_KEY)) { + return metadata.getMetadata(VARIANT_METADATA_KEY).getString("path"); + } + + return metadata.getString("path"); + } +} diff --git 
a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java index 814ca410d147..6573dea676b7 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java @@ -18,7 +18,10 @@ */ package org.apache.iceberg.spark.source; +import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.function.Supplier; import org.apache.iceberg.Scan; @@ -29,14 +32,28 @@ import org.apache.iceberg.Table; import org.apache.iceberg.expressions.Expression; import org.apache.iceberg.metrics.ScanReport; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.data.SparkVariantExtractionUtil; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.VariantExtraction; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.MetadataBuilder; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.VariantType; +import org.apache.spark.sql.types.VariantType$; class SparkBatchQueryScan extends SparkRuntimeFilterableScan { private final Snapshot snapshot; private final String branch; + // All extractions Spark passed to pushVariantExtractions (in original order). + private final VariantExtraction[] allVariantExtractions; + // Parallel mask: true iff the corresponding extraction was accepted by this datasource. 
+ private final boolean[] variantExtractionAccepted; SparkBatchQueryScan( SparkSession spark, @@ -48,16 +65,127 @@ class SparkBatchQueryScan extends SparkRuntimeFilterableScan { SparkReadConf readConf, Schema projection, List filters, + VariantExtraction[] allVariantExtractions, + boolean[] variantExtractionAccepted, Supplier scanReportSupplier) { super(spark, table, schema, scan, readConf, projection, filters, scanReportSupplier); this.snapshot = snapshot; this.branch = branch; + this.allVariantExtractions = allVariantExtractions; + this.variantExtractionAccepted = variantExtractionAccepted; } Long snapshotId() { return snapshot != null ? snapshot.snapshotId() : null; } + @Override + public StructType readSchema() { + if (allVariantExtractions.length == 0) { + return super.readSchema(); + } + + // Check whether any extraction was accepted. + boolean hasAcceptedExtraction = false; + for (boolean b : variantExtractionAccepted) { + if (b) { + hasAcceptedExtraction = true; + break; + } + } + final boolean anyAccepted = hasAcceptedExtraction; + + // Group extractions by top-level variant column in Spark batch order. List index per column + // matches the ordinal Spark uses in GetStructField(ref, ordinal) after + // buildScanWithPushedVariants. + Map> pushedSlotsByColumn = new LinkedHashMap<>(); + for (int i = 0; i < allVariantExtractions.length; i++) { + String colName = allVariantExtractions[i].columnName()[0]; + List slots = + pushedSlotsByColumn.computeIfAbsent(colName, k -> Lists.newArrayList()); + slots.add(new PushedVariantSlot(allVariantExtractions[i], variantExtractionAccepted[i])); + } + + StructType base = super.readSchema(); + StructField[] newFields = + Arrays.stream(base.fields()) + .map( + field -> { + List slots = pushedSlotsByColumn.get(field.name()); + if (slots == null) { + return field; + } + + // Emit one struct field per Spark ordinal. 
Accepted extractions carry + // VariantMetadata so the Parquet reader maps them to typed_value.* shredded + // columns. + // Rejected extractions are VariantType placeholders so ordinals in the plan + // remain + // valid even when only a subset of slots are accepted. + StructField[] extracted = new StructField[slots.size()]; + int acceptedCount = 0; + for (int ordinal = 0; ordinal < slots.size(); ordinal++) { + PushedVariantSlot slot = slots.get(ordinal); + VariantExtraction extraction = slot.extraction(); + if (slot.accepted()) { + acceptedCount++; + extracted[ordinal] = + DataTypes.createStructField( + String.valueOf(ordinal), + extraction.expectedDataType(), + true, + extraction.metadata()); + } else { + // Rejected extraction: use VariantType so Spark can fall back to full + // variant evaluation for this ordinal slot. + extracted[ordinal] = + DataTypes.createStructField( + String.valueOf(ordinal), VariantType$.MODULE$, true); + } + } + + // Mirror ParquetScan: IsNotNull-only pushdown uses a boolean placeholder field + // instead of an all-VariantType struct (see Spark ParquetScan.scala). + if (acceptedCount == 0 + && extracted.length == 1 + && slots.get(0).extraction().expectedDataType() instanceof VariantType) { + Metadata placeholderMetadata = + new MetadataBuilder() + .putMetadata( + SparkVariantExtractionUtil.VARIANT_METADATA_KEY, + new MetadataBuilder() + .putString("path", SparkVariantExtractionUtil.PLACEHOLDER_PATH) + .putBoolean("failOnError", false) + .putString("timeZoneId", "UTC") + .build()) + .build(); + extracted[0] = + DataTypes.createStructField( + "0", DataTypes.BooleanType, true, placeholderMetadata); + } else if (!anyAccepted) { + // No typed extractions were accepted; keep the original variant column type. + return field; + } + + // Use empty metadata for the rewritten top-level variant field. 
The original + // Iceberg field metadata (fieldId) must not be inherited here because Spark's + // Alias(realAttr, name)(expectedExprId) in buildScanWithPushedVariants carries + // realAttr.metadata as its own metadata. RemoveRedundantAliases strips the alias + // when alias.metadata == attr.metadata — if both carry the fieldId metadata the + // alias is removed, leaving a stale exprId in the plan and causing + // PLAN_VALIDATION_FAILED. Using empty metadata breaks that equality so the alias + // is preserved and the correct exprId binding is maintained. + return DataTypes.createStructField( + field.name(), + DataTypes.createStructType(extracted), + field.nullable(), + org.apache.spark.sql.types.Metadata$.MODULE$.empty()); + }) + .toArray(StructField[]::new); + + return DataTypes.createStructType(newFields); + } + @Override public Statistics estimateStatistics() { return estimateStatistics(snapshot); @@ -107,4 +235,23 @@ public String description() { runtimeFiltersDesc(), groupingKeyDesc()); } + + /** One pushed variant extraction for a top-level variant column. 
*/ + private static final class PushedVariantSlot { + private final VariantExtraction extraction; + private final boolean accepted; + + private PushedVariantSlot(VariantExtraction extraction, boolean accepted) { + this.extraction = extraction; + this.accepted = accepted; + } + + private VariantExtraction extraction() { + return extraction; + } + + private boolean accepted() { + return accepted; + } + } } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java index 69b6314a7f2b..261ed4f1ff67 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java @@ -57,6 +57,7 @@ import org.apache.spark.sql.connector.read.SupportsPushDownLimit; import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; import org.apache.spark.sql.connector.read.SupportsPushDownV2Filters; +import org.apache.spark.sql.connector.read.VariantExtraction; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.slf4j.Logger; @@ -76,6 +77,24 @@ public class SparkScanBuilder extends BaseSparkScanBuilder private final Long startSnapshotId; private final Long endSnapshotId; private Scan localScan; + // All extractions Spark passed to pushVariantExtractions, kept in original order so that + // readSchema() can reproduce the per-column ordinals Spark already embedded in the plan. + private VariantExtraction[] allVariantExtractions = new VariantExtraction[0]; + // Parallel boolean array: accepted[i] == true iff extractions[i] was accepted. 
+ private boolean[] variantExtractionAccepted = new boolean[0]; + + protected void setVariantExtractions(VariantExtraction[] extractions, boolean[] accepted) { + this.allVariantExtractions = extractions; + this.variantExtractionAccepted = accepted; + } + + protected VariantExtraction[] variantExtractions() { + return allVariantExtractions; + } + + protected boolean[] variantExtractionAccepted() { + return variantExtractionAccepted; + } SparkScanBuilder(SparkSession spark, Table table, CaseInsensitiveStringMap options) { this( @@ -264,6 +283,8 @@ private Scan buildBatchScan() { readConf(), projection, filters(), + variantExtractions(), + variantExtractionAccepted(), metricsReporter()::scanReport); } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java index 80a40d72c8d1..ae385719f510 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java @@ -55,6 +55,7 @@ import org.apache.iceberg.spark.TimeTravel; import org.apache.iceberg.spark.TimeTravel.AsOfTimestamp; import org.apache.iceberg.spark.TimeTravel.AsOfVersion; +import org.apache.iceberg.types.Type; import org.apache.iceberg.util.PropertyUtil; import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.sql.connector.catalog.SupportsDeleteV2; @@ -179,9 +180,19 @@ public Constraint[] constraints() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (hasAnyTopLevelVariantColumn(schema)) { + return new SparkVariantExtractionScanBuilder( + spark(), table(), schema, snapshot, branch, timeTravel, options); + } + return new SparkScanBuilder(spark(), table(), schema, snapshot, branch, timeTravel, options); } + private static boolean hasAnyTopLevelVariantColumn(Schema tableSchema) { + return tableSchema.columns().stream() + .anyMatch(field -> 
field.type().typeId() == Type.TypeID.VARIANT); + } + @Override public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { Preconditions.checkArgument(timeTravel == null, "Cannot write to table with time travel"); diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantExtractionScanBuilder.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantExtractionScanBuilder.java new file mode 100644 index 000000000000..f77fe176f0a5 --- /dev/null +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantExtractionScanBuilder.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.TimeTravel; +import org.apache.iceberg.spark.data.SparkVariantExtractionUtil; +import org.apache.iceberg.types.Type; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.SupportsPushDownVariantExtractions; +import org.apache.spark.sql.connector.read.VariantExtraction; +import org.apache.spark.sql.types.VariantType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link SparkScanBuilder} for tables with one or more top-level VARIANT columns. + * + *

<p>Spark's {@code buildScanWithPushedVariants} is unsafe when a scan has multiple variant columns + * and only a subset of extractions are accepted ({@code rewriteExpr} would emit invalid {@code + * GetStructField} on columns whose type was not rewritten). This builder uses an all-or-nothing + * policy: decline the entire batch if any extraction is unsupported, or if any full-variant slot + * ({@code expectedDataType = VariantType}, path {@code $}) is present. + * + *

<p>{@link #pushVariantExtractions} receives extractions collected from {@code Project} and {@code + * Filter} on each scan subtree ({@code PhysicalOperation}); it does not walk through {@code Join} + * or {@code Aggregate}. Typed paths usually come from filters. Plain {@code VARIANT} attributes in + * the visible project list also add a full-variant slot ({@code expectedDataType = VariantType}, + * Spark path {@code $}; for example, when Spark defaults to {@code scan.output} or passes the raw + * column through on a join branch). {@code variant_get} in {@code GROUP BY} or aggregate functions + * is not collected; extract in a subquery first if aggregate pushdown is needed. + * + *

<p>Iceberg declines batches that include a full-variant slot because accepting only the typed + * extractions would rewrite the scan to a struct while {@code variant_get} above the join/aggregate + * barrier still references the pre-rewrite attribute, causing binding failures or {@code + * ClassCastException}. For example, TPC-DS q42 on the item scan: + * + *

<pre>{@code
+ * HashAggregate GROUP BY variant_get(item_data, '$.category', ...)
+ *   Project [..., variant_get(item_data, '$.category', ...)]   // above Join — not collected
+ *     Join ON ss.ss_item_sk = i.i_item_sk
+ *       Project [i_item_sk, item_data]                         // plain passthrough → full variant
+ *         Filter variant_get(item_data, '$.manager_id', ...) = ...
+ *           Scan item                                          // PhysicalOperation stops here
+ * }</pre>
+ */ +class SparkVariantExtractionScanBuilder extends SparkScanBuilder + implements SupportsPushDownVariantExtractions { + + private static final Logger LOG = + LoggerFactory.getLogger(SparkVariantExtractionScanBuilder.class); + + SparkVariantExtractionScanBuilder( + SparkSession spark, Table table, CaseInsensitiveStringMap options) { + super(spark, table, options); + } + + SparkVariantExtractionScanBuilder( + SparkSession spark, + Table table, + Schema schema, + Snapshot snapshot, + String branch, + TimeTravel timeTravel, + CaseInsensitiveStringMap options) { + super(spark, table, schema, snapshot, branch, timeTravel, options); + } + + @Override + public boolean[] pushVariantExtractions(VariantExtraction[] extractions) { + // Default false: Spark treats each extraction as declined unless explicitly accepted. + boolean[] accepted = new boolean[extractions.length]; + + if (!readConf().variantExtractionPushDownEnabled()) { + return accepted; + } + + boolean[] candidates = new boolean[extractions.length]; + int typedExtractions = 0; + int typedCandidates = 0; + + for (int i = 0; i < extractions.length; i++) { + VariantExtraction extraction = extractions[i]; + if (extraction.expectedDataType() instanceof VariantType) { + // Full-variant slot (expectedDataType = VariantType): decline the batch. See class javadoc. + return accepted; + } + + typedExtractions += 1; + String[] colPath = extraction.columnName(); + boolean candidate = + colPath.length == 1 + && isVariantColumn(colPath[0]) + && SparkVariantExtractionUtil.isSupportedExtractionPath( + SparkVariantExtractionUtil.extractionPath(extraction)) + && SparkVariantExtractionUtil.isSupportedPushdownTargetType( + extraction.expectedDataType()); + candidates[i] = candidate; + if (candidate) { + typedCandidates += 1; + } + } + + // Accept all typed extractions in the batch, including the same path with different target + // types + // when both appear in the visible Project/Filter (e.g. 
SELECT variant_get(v,'$.size','double'), + // variant_get(v,'$.size','long')). variant_get inside Aggregate is not collected by Spark. + // The Parquet reader shares one physical column read per path and casts to each target type. + boolean acceptAllTyped = typedExtractions > 0 && typedCandidates == typedExtractions; + + if (acceptAllTyped) { + for (int i = 0; i < extractions.length; i++) { + accepted[i] = candidates[i]; + } + + LOG.info( + "Accepted {}/{} variant extraction(s) for scan of table {}", + typedCandidates, + extractions.length, + table().name()); + setVariantExtractions(extractions, accepted); + } else if (typedCandidates > 0) { + LOG.info( + "Declined variant extraction pushdown for table {} because {} of {} typed extraction(s)" + + " have unsupported paths, target types, or reference non-variant columns", + table().name(), + typedExtractions - typedCandidates, + typedExtractions); + } + + return accepted; + } + + private boolean isVariantColumn(String colName) { + org.apache.iceberg.types.Types.NestedField field = schema().findField(colName); + return field != null && field.type().typeId() == Type.TypeID.VARIANT; + } +} diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestBaseWithCatalog.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestBaseWithCatalog.java index 1760143d2cb1..f2294dc284fa 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestBaseWithCatalog.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestBaseWithCatalog.java @@ -47,6 +47,7 @@ import org.apache.iceberg.rest.RESTServerExtension; import org.apache.iceberg.util.PropertyUtil; import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.extension.ExtendWith; @@ -153,6 +154,12 @@ private void configureValidationCatalog() { public void before() { configureValidationCatalog(); + // TODO: 
remove once PR #16714 is merged. + // Disable variant extraction pushdown until the Parquet selective reader (PR #16714) is merged + // and wired through engineProjection. Without it the reader returns VariantVal while the + // rewritten readSchema() declares a struct, causing ClassCastException at runtime. + spark.conf().set(SparkSQLProperties.VARIANT_EXTRACTION_PUSH_DOWN_ENABLED, "false"); + spark.conf().set("spark.sql.catalog." + catalogName, implementation); catalogConfig.forEach( (key, value) -> spark.conf().set("spark.sql.catalog." + catalogName + "." + key, value)); @@ -167,6 +174,11 @@ public void before() { sql("CREATE NAMESPACE IF NOT EXISTS default"); } + @AfterEach + public void afterEach() { + spark.conf().unset(SparkSQLProperties.VARIANT_EXTRACTION_PUSH_DOWN_ENABLED); + } + protected String tableName(String name) { return (catalogName.equals("spark_catalog") ? "" : catalogName + ".") + "default." + name; } diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkVariantExtractionUtil.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkVariantExtractionUtil.java new file mode 100644 index 000000000000..c99bd44d6de1 --- /dev/null +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkVariantExtractionUtil.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.VariantType$; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.FieldSource; + +class TestSparkVariantExtractionUtil { + + private static final DataType[] SUPPORTED_TARGET_TYPES = + new DataType[] { + DataTypes.StringType, + DataTypes.IntegerType, + DataTypes.LongType, + DataTypes.ByteType, + DataTypes.ShortType, + DataTypes.BooleanType, + DataTypes.FloatType, + DataTypes.DoubleType, + DataTypes.createDecimalType(9, 2), + DataTypes.DateType, + DataTypes.TimestampType, + DataTypes.TimestampNTZType, + DataTypes.BinaryType + }; + + private static final DataType[] UNSUPPORTED_TARGET_TYPES = + new DataType[] { + VariantType$.MODULE$, + new StructType( + new StructField[] {DataTypes.createStructField("a", DataTypes.IntegerType, true)}), + DataTypes.createArrayType(DataTypes.IntegerType), + DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType), + DataTypes.createCharType(4), + DataTypes.createVarcharType(4) + }; + + @ParameterizedTest + @FieldSource("SUPPORTED_TARGET_TYPES") + void isSupportedPushdownTargetTypeAcceptsSupportedTypes(DataType targetType) { + 
assertThat(SparkVariantExtractionUtil.isSupportedPushdownTargetType(targetType)).isTrue(); + } + + @ParameterizedTest + @FieldSource("UNSUPPORTED_TARGET_TYPES") + void isSupportedPushdownTargetTypeRejectsUnsupportedTypes(DataType targetType) { + assertThat(SparkVariantExtractionUtil.isSupportedPushdownTargetType(targetType)).isFalse(); + } + + @Test + void isSupportedExtractionPathAcceptsAllValidPaths() { + assertThat(SparkVariantExtractionUtil.isSupportedExtractionPath("$.size")).isTrue(); + assertThat(SparkVariantExtractionUtil.isSupportedExtractionPath("$.pull_request.user.login")) + .isTrue(); + assertThat(SparkVariantExtractionUtil.isSupportedExtractionPath("$.commits[0].author.name")) + .isTrue(); + assertThat( + SparkVariantExtractionUtil.isSupportedExtractionPath("$['issue']['labels'][0]['name']")) + .isTrue(); + } +} diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/source/TestVariantShreddingPushdown.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/source/TestVariantShreddingPushdown.java new file mode 100644 index 000000000000..1ff1298649a6 --- /dev/null +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/source/TestVariantShreddingPushdown.java @@ -0,0 +1,743 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.VariantType; +import org.apache.spark.sql.types.VariantType$; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +/** + * Tests that {@code SparkScanBuilder} implements {@code SupportsPushDownVariantExtractions} so + * Spark's {@code V2ScanRelationPushDown} rule rewrites {@code variant_get} expressions into struct + * field accesses and prunes the scan output schema to an annotated struct rather than the full + * {@code VariantType}. + * + *

<p>These tests cover the DSv2 contract (plan shape and query correctness). Actual Parquet I/O + * column pruning is a follow-on change. + * + *

Each test sets {@link SparkSQLProperties#VARIANT_EXTRACTION_PUSH_DOWN_ENABLED} explicitly so + * behavior does not depend on the product default. + */ +public class TestVariantShreddingPushdown extends CatalogTestBase { + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties() + }, + }; + } + + @BeforeEach + public void createTable() { + // Use a schema with a non-variant column + a variant column, matching the GHA table shape + // (type STRING, payload VARIANT) to catch any plan-shape issue with multi-column tables. + sql( + "CREATE TABLE %s (id INT, type STRING, v VARIANT) USING iceberg TBLPROPERTIES ('%s' = '3')", + tableName, TableProperties.FORMAT_VERSION); + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + } + + private void withExtractionPushdown(boolean enabled, Runnable action) { + spark + .conf() + .set(SparkSQLProperties.VARIANT_EXTRACTION_PUSH_DOWN_ENABLED, String.valueOf(enabled)); + try { + action.run(); + } finally { + spark.conf().unset(SparkSQLProperties.VARIANT_EXTRACTION_PUSH_DOWN_ENABLED); + } + } + + private void withVariantPushIntoScan(Runnable action) { + spark.conf().set(SQLConf.PUSH_VARIANT_INTO_SCAN().key(), "true"); + try { + action.run(); + } finally { + spark.conf().unset(SQLConf.PUSH_VARIANT_INTO_SCAN().key()); + } + } + + @AfterEach + public void dropTable() { + spark.conf().unset(SparkSQLProperties.SHRED_VARIANTS); + spark.conf().unset(SparkSQLProperties.VARIANT_EXTRACTION_PUSH_DOWN_ENABLED); + spark.conf().unset(SQLConf.PUSH_VARIANT_INTO_SCAN().key()); + sql("DROP TABLE IF EXISTS %s", tableName); + } + + private List scanReadSchemas(LogicalPlan optimized) { + return scala.collection.JavaConverters.seqAsJavaListConverter(optimized.collectLeaves()) + .asJava() + .stream() + .filter(node -> node 
instanceof DataSourceV2ScanRelation) + .map(node -> ((DataSourceV2ScanRelation) node).scan().readSchema()) + .collect(Collectors.toList()); + } + + private StructField variantField(StructType scanSchema, String colName) { + return scanSchema.apply(colName); + } + + private void assertVariantColumnIsAnnotatedStruct( + StructType scanSchema, String colName, int expectedOrdinalFields) { + StructField field = variantField(scanSchema, colName); + assertThat(field.dataType()) + .as("%s should be rewritten to an annotated struct, not VariantType", colName) + .isInstanceOf(StructType.class) + .isNotInstanceOf(VariantType.class); + + StructType vStruct = (StructType) field.dataType(); + assertThat(vStruct.fields()) + .as("struct for %s should have %s ordinal field(s)", colName, expectedOrdinalFields) + .hasSize(expectedOrdinalFields); + } + + @TestTemplate + public void testVariantExtractionPlanShape() { + sql( + "INSERT INTO %s VALUES (1, 'A', parse_json('{\"city\": \"Austin\", \"zip\": 78701}')), " + + "(2, 'B', parse_json('{\"city\": \"Boston\", \"zip\": 2108}'))", + tableName); + + withExtractionPushdown( + true, + () -> + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format( + "SELECT variant_get(v, '$.city', 'string') FROM %s", tableName)); + List scanSchemas = + scanReadSchemas(df.queryExecution().optimizedPlan()); + + assertThat(scanSchemas).hasSize(1); + assertVariantColumnIsAnnotatedStruct(scanSchemas.get(0), "v", 1); + + StructType vStruct = (StructType) scanSchemas.get(0).apply("v").dataType(); + StructField extracted = vStruct.fields()[0]; + assertThat(extracted.name()).as("ordinal name must be '0'").isEqualTo("0"); + assertThat(extracted.dataType()) + .as("extracted type must be StringType") + .isEqualTo(DataTypes.StringType); + })); + } + + @TestTemplate + public void testFilterPlanShape() { + sql( + "INSERT INTO %s VALUES (1, 'PushEvent', parse_json('{\"size\": 10}')), " + + "(2, 'PushEvent', parse_json('{\"size\": 3}')), " + + "(3, 
'IssueEvent', parse_json('{\"size\": 8}'))", + tableName); + + withExtractionPushdown( + true, + () -> + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format( + "SELECT count(*) AS n FROM %s WHERE type = 'PushEvent'" + + " AND variant_get(v, '$.size', 'long') > 5", + tableName)); + LogicalPlan optimized = df.queryExecution().optimizedPlan(); + + assertThat(optimized).isNotNull(); + assertThat(optimized.toString()) + .as("optimized plan should not retain variant_get after pushdown") + .doesNotContain("variant_get"); + + List scanSchemas = scanReadSchemas(optimized); + assertThat(scanSchemas).hasSize(1); + assertVariantColumnIsAnnotatedStruct(scanSchemas.get(0), "v", 1); + + StructType vStruct = (StructType) scanSchemas.get(0).apply("v").dataType(); + assertThat(vStruct.fields()[0].name()).isEqualTo("0"); + assertThat(vStruct.fields()[0].dataType()).isEqualTo(DataTypes.LongType); + })); + } + + @TestTemplate + public void testPushdownDisabledWhenConfigOff() { + withExtractionPushdown( + false, + () -> { + sql("INSERT INTO %s VALUES (1, 'A', parse_json('{\"city\": \"Austin\"}'))", tableName); + + Dataset df = + spark.sql( + String.format("SELECT variant_get(v, '$.city', 'string') FROM %s", tableName)); + List scanSchemas = scanReadSchemas(df.queryExecution().optimizedPlan()); + + assertThat(scanSchemas).hasSize(1); + assertThat(variantField(scanSchemas.get(0), "v").dataType()) + .as("v stays VariantType when iceberg extraction pushdown is disabled") + .isEqualTo(VariantType$.MODULE$); + }); + } + + @TestTemplate + public void testIsNotNullOnlyFallsBackWithoutError() { + sql( + "INSERT INTO %s VALUES (1, 'A', parse_json('{\"city\": \"Austin\"}')), " + "(2, 'B', null)", + tableName); + + withExtractionPushdown( + true, + () -> + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format("SELECT count(*) FROM %s WHERE v IS NOT NULL", tableName)); + List scanSchemas = + scanReadSchemas(df.queryExecution().optimizedPlan()); + + 
assertThat(scanSchemas).hasSize(1); + assertThat(scanSchemas.get(0).apply("v").dataType()) + .as("IsNotNull-only queries fall back to full variant reads for now") + .isInstanceOf(VariantType.class); + assertThat(df.collectAsList().get(0).getLong(0)).isEqualTo(1L); + })); + } + + @TestTemplate + public void testVariantExtractionCorrectResults() { + sql( + "INSERT INTO %s VALUES (1, 'X', parse_json('{\"city\": \"Austin\"}')), " + + "(2, 'X', parse_json('{\"city\": \"Boston\"}')), " + + "(3, 'Y', null)", + tableName); + + // Execute without variant extraction pushdown so the Parquet reader returns full VariantVal. + withExtractionPushdown( + false, + () -> { + List results; + spark.conf().set(SQLConf.PUSH_VARIANT_INTO_SCAN().key(), "false"); + try { + results = + spark + .sql( + String.format( + "SELECT id, variant_get(v, '$.city', 'string') AS city FROM %s ORDER BY id", + tableName)) + .collectAsList(); + } finally { + spark.conf().unset(SQLConf.PUSH_VARIANT_INTO_SCAN().key()); + } + + assertThat(results).hasSize(3); + assertThat(results.get(0).getString(1)).isEqualTo("Austin"); + assertThat(results.get(1).getString(1)).isEqualTo("Boston"); + assertThat(results.get(2).isNullAt(1)).isTrue(); + }); + } + + @TestTemplate + public void testVariantExtractionWithFilterCorrectResults() { + sql( + "INSERT INTO %s VALUES (1, 'X', parse_json('{\"size\": 10}')), " + + "(2, 'X', parse_json('{\"size\": 3}')), " + + "(3, 'X', parse_json('{\"size\": 7}'))", + tableName); + + // Execute without variant extraction pushdown. Rows with size > 5: id=1 (size=10) and id=3 + // (size=7). 
+ withExtractionPushdown( + false, + () -> { + List results; + spark.conf().set(SQLConf.PUSH_VARIANT_INTO_SCAN().key(), "false"); + try { + results = + spark + .sql( + String.format( + "SELECT id FROM %s WHERE variant_get(v, '$.size', 'int') > 5 ORDER BY id", + tableName)) + .collectAsList(); + } finally { + spark.conf().unset(SQLConf.PUSH_VARIANT_INTO_SCAN().key()); + } + + assertThat(results).hasSize(2); + assertThat(results.get(0).getInt(0)).isEqualTo(1); + assertThat(results.get(1).getInt(0)).isEqualTo(3); + }); + } + + @TestTemplate + public void testMultipleExtractionsFromSameColumnPlanShape() { + sql( + "INSERT INTO %s VALUES (1, 'A', parse_json('{\"city\": \"Austin\", \"zip\": \"78701\"}')), " + + "(2, 'B', parse_json('{\"city\": \"Boston\", \"zip\": \"02108\"}'))", + tableName); + + withExtractionPushdown( + true, + () -> + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format( + "SELECT variant_get(v, '$.city', 'string') AS city, " + + "variant_get(v, '$.zip', 'string') AS zip FROM %s ORDER BY city", + tableName)); + List scanSchemas = + scanReadSchemas(df.queryExecution().optimizedPlan()); + + assertThat(scanSchemas).hasSize(1); + StructType vStruct = (StructType) scanSchemas.get(0).apply("v").dataType(); + + assertThat(vStruct.fields()) + .as("struct should have two ordinal fields for two extractions") + .hasSize(2); + assertThat(vStruct.fields()[0].name()).isEqualTo("0"); + assertThat(vStruct.fields()[0].dataType()).isEqualTo(DataTypes.StringType); + assertThat(vStruct.fields()[1].name()).isEqualTo("1"); + assertThat(vStruct.fields()[1].dataType()).isEqualTo(DataTypes.StringType); + })); + } + + @TestTemplate + public void testAggregateWithSamePathTwoTypesIsCorrect() { + // c-q05 pattern: avg(size as double) and max(size as long) in the same query. 
+ // Spark's PhysicalOperation does not collect variant_get expressions from inside aggregates, + // so the scan falls back to full variant evaluation — but the result must still be correct. + sql( + "INSERT INTO %s VALUES (1, 'PushEvent', parse_json('{\"size\": 10}')), " + + "(2, 'PushEvent', parse_json('{\"size\": 3}')), " + + "(3, 'IssueEvent', parse_json('{\"size\": 8}'))", + tableName); + + withExtractionPushdown( + true, + () -> + withVariantPushIntoScan( + () -> { + Row row = + spark + .sql( + String.format( + "SELECT round(avg(variant_get(v, '$.size', 'double')), 6) AS avg_size, " + + "max(variant_get(v, '$.size', 'long')) AS max_size " + + "FROM %s WHERE type = 'PushEvent' " + + "AND variant_get(v, '$.size', 'long') IS NOT NULL", + tableName)) + .collectAsList() + .get(0); + assertThat(row.getDouble(0)).isEqualTo(6.5D); + assertThat(row.getLong(1)).isEqualTo(10L); + })); + } + + /** + * GHA query shape: table with multiple variant columns (actor, repo, payload), query touches only + * one column. Spark only sends extractions for {@code payload} — {@code actor} and {@code repo} + * are never in {@code variants.mapping} — so only {@code payload} gets rewritten to a struct. + * + *

Plan-shape only (no collect): correctness requires the Parquet extraction reader from the + * {@code variant-extraction-parquet-io} branch; verified on the integration branch. + */ + @TestTemplate + public void testMultipleVariantColumnsSingleColumnQueryPushesDown() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql( + "CREATE TABLE %s (id INT, type STRING, actor VARIANT, repo VARIANT, payload VARIANT)" + + " USING iceberg TBLPROPERTIES ('%s' = '3')", + tableName, TableProperties.FORMAT_VERSION); + sql( + "INSERT INTO %s VALUES " + + "(1, 'PushEvent', parse_json('{\"login\": \"alice\"}')," + + " parse_json('{\"name\": \"repo1\"}'), parse_json('{\"size\": 10}')), " + + "(2, 'PushEvent', parse_json('{\"login\": \"bob\"}')," + + " parse_json('{\"name\": \"repo2\"}'), parse_json('{\"size\": 3}')), " + + "(3, 'IssueEvent', parse_json('{\"login\": \"carol\"}')," + + " parse_json('{\"name\": \"repo3\"}'), parse_json('{\"size\": 8}'))", + tableName); + + withExtractionPushdown( + true, + () -> + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format( + "SELECT count(*) AS n FROM %s WHERE type = 'PushEvent'" + + " AND variant_get(payload, '$.size', 'long') > 5", + tableName)); + List scanSchemas = + scanReadSchemas(df.queryExecution().optimizedPlan()); + + assertThat(scanSchemas).hasSize(1); + StructType scanSchema = scanSchemas.get(0); + + assertVariantColumnIsAnnotatedStruct(scanSchema, "payload", 1); + assertThat( + ((StructType) scanSchema.apply("payload").dataType()) + .fields()[0].dataType()) + .as("payload extraction slot must be LongType") + .isEqualTo(DataTypes.LongType); + + // actor and repo are not referenced in this query — they must stay VariantType. 
+ assertThat(variantField(scanSchema, "actor").dataType()) + .as("actor stays VariantType when not referenced") + .isEqualTo(VariantType$.MODULE$); + assertThat(variantField(scanSchema, "repo").dataType()) + .as("repo stays VariantType when not referenced") + .isEqualTo(VariantType$.MODULE$); + })); + } + + /** + * Two variant columns referenced in the same query (SELECT + WHERE). Spark sends extractions for + * both; the all-or-nothing policy accepts both because all paths are supported, and both columns + * are rewritten to annotated structs. + * + *

Plan-shape only (no collect): correctness requires the Parquet extraction reader from the + * {@code variant-extraction-parquet-io} branch; verified on the integration branch. + */ + @TestTemplate + public void testMultipleVariantColumnsTwoColumnQueryPushesDown() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql( + "CREATE TABLE %s (id INT, type STRING, actor VARIANT, repo VARIANT, payload VARIANT)" + + " USING iceberg TBLPROPERTIES ('%s' = '3')", + tableName, TableProperties.FORMAT_VERSION); + sql( + "INSERT INTO %s VALUES " + + "(1, 'PushEvent', parse_json('{\"login\": \"alice\"}')," + + " parse_json('{\"name\": \"repo1\"}'), parse_json('{\"size\": 10}')), " + + "(2, 'PushEvent', parse_json('{\"login\": \"bob\"}')," + + " parse_json('{\"name\": \"repo2\"}'), parse_json('{\"size\": 3}')), " + + "(3, 'IssueEvent', parse_json('{\"login\": \"carol\"}')," + + " parse_json('{\"name\": \"repo3\"}'), parse_json('{\"size\": 8}'))", + tableName); + + withExtractionPushdown( + true, + () -> + withVariantPushIntoScan( + () -> { + // Query references both repo and payload in the same Project/Filter visible to + // the scan — Spark sends both sets of extractions in one batch. + Dataset df = + spark.sql( + String.format( + "SELECT variant_get(repo, '$.name', 'string') AS rn," + + " variant_get(payload, '$.size', 'long') AS sz" + + " FROM %s WHERE type = 'PushEvent'", + tableName)); + + List scanSchemas = + scanReadSchemas(df.queryExecution().optimizedPlan()); + assertThat(scanSchemas).hasSize(1); + StructType scanSchema = scanSchemas.get(0); + + // Both referenced columns must be rewritten to annotated structs. + assertVariantColumnIsAnnotatedStruct(scanSchema, "repo", 1); + assertVariantColumnIsAnnotatedStruct(scanSchema, "payload", 1); + + // actor is not referenced — must stay VariantType. 
+ assertThat(variantField(scanSchema, "actor").dataType()) + .as("actor stays VariantType when not referenced") + .isEqualTo(VariantType$.MODULE$); + })); + } + + @TestTemplate + public void testVariantExtractionWithPushdownCorrectResults() { + sql( + "INSERT INTO %s VALUES (1, 'X', parse_json('{\"city\": \"Austin\"}')), " + + "(2, 'X', parse_json('{\"city\": \"Boston\"}')), " + + "(3, 'Y', null)", + tableName); + + withExtractionPushdown( + true, + () -> { + // Plan-shape: v must be rewritten to an annotated struct. + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format( + "SELECT id, variant_get(v, '$.city', 'string') AS city FROM %s ORDER BY id", + tableName)); + List scanSchemas = scanReadSchemas(df.queryExecution().optimizedPlan()); + assertThat(scanSchemas).hasSize(1); + assertVariantColumnIsAnnotatedStruct(scanSchemas.get(0), "v", 1); + }); + + // Correctness: disable Spark-level variant pushdown so collect() works on this branch + // without the Parquet extraction reader. End-to-end correctness with pushdown active is + // verified on the integration branch. + spark.conf().set(SQLConf.PUSH_VARIANT_INTO_SCAN().key(), "false"); + try { + List results = + spark + .sql( + String.format( + "SELECT id, variant_get(v, '$.city', 'string') AS city FROM %s ORDER BY id", + tableName)) + .collectAsList(); + assertThat(results).hasSize(3); + assertThat(results.get(0).getString(1)).isEqualTo("Austin"); + assertThat(results.get(1).getString(1)).isEqualTo("Boston"); + assertThat(results.get(2).isNullAt(1)).isTrue(); + } finally { + spark.conf().unset(SQLConf.PUSH_VARIANT_INTO_SCAN().key()); + } + }); + } + + @TestTemplate + public void testArrayPathExtractionViaFieldLevelFallback() { + // Insert rows where the variant column contains a "commits" array. Shredding writes + // v.typed_value.commits with a serialized value column. 
The extraction pushdown reads + // that smaller blob and navigates [0].author.name in-memory rather than reading the full root. + sql( + "INSERT INTO %s VALUES " + + "(1, 'PushEvent', parse_json('{\"commits\": [{\"author\": {\"name\": \"Alice\"}}," + + " {\"author\": {\"name\": \"Bob\"}}]}')), " + + "(2, 'PushEvent', parse_json('{\"commits\": []}')), " + + "(3, 'IssueEvent', parse_json('{\"commits\": null}'))", + tableName); + + withExtractionPushdown( + true, + () -> { + // Plan-shape: v must be rewritten to an annotated struct (array-path extraction). + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format( + "SELECT id, variant_get(v, '$.commits[0].author.name', 'string') AS first_author" + + " FROM %s ORDER BY id", + tableName)); + List scanSchemas = scanReadSchemas(df.queryExecution().optimizedPlan()); + assertThat(scanSchemas).hasSize(1); + assertVariantColumnIsAnnotatedStruct(scanSchemas.get(0), "v", 1); + }); + + // Correctness: disable Spark-level variant pushdown so collect() works on this branch + // without the Parquet extraction reader. + spark.conf().set(SQLConf.PUSH_VARIANT_INTO_SCAN().key(), "false"); + try { + List results = + spark + .sql( + String.format( + "SELECT id, variant_get(v, '$.commits[0].author.name', 'string') AS first_author" + + " FROM %s ORDER BY id", + tableName)) + .collectAsList(); + assertThat(results).hasSize(3); + assertThat(results.get(0).getString(1)).isEqualTo("Alice"); + assertThat(results.get(1).isNullAt(1)).isTrue(); + assertThat(results.get(2).isNullAt(1)).isTrue(); + } finally { + spark.conf().unset(SQLConf.PUSH_VARIANT_INTO_SCAN().key()); + } + }); + } + + @TestTemplate + public void testSamePathDifferentTypesAcceptedInProjectPushdown() { + // When the same path appears with two different target types in the SELECT list (not + // inside an aggregate), both appear within PhysicalOperation and get pushed to the scan. 
+ // The scan builder must accept both; the Parquet reader shares one physical column. + sql( + "INSERT INTO %s VALUES " + + "(1, 'PushEvent', parse_json('{\"size\": 10}')), " + + "(2, 'PushEvent', parse_json('{\"size\": 3}')), " + + "(3, 'IssueEvent', parse_json('{\"size\": 8}'))", + tableName); + + withExtractionPushdown( + true, + () -> { + // Plan-shape: both extractions must be accepted; v becomes a struct with 2 typed fields. + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format( + "SELECT" + + " variant_get(v, '$.size', 'double') AS size_d," + + " variant_get(v, '$.size', 'long') AS size_l" + + " FROM %s" + + " WHERE variant_get(v, '$.size', 'long') IS NOT NULL", + tableName)); + + List scanSchemas = scanReadSchemas(df.queryExecution().optimizedPlan()); + assertThat(scanSchemas).hasSize(1); + StructType vStruct = (StructType) scanSchemas.get(0).apply("v").dataType(); + assertThat(vStruct.fields()) + .as("same path with two types must produce two struct slots") + .hasSize(2); + assertThat(vStruct.fields()[0].dataType()).isEqualTo(DataTypes.DoubleType); + assertThat(vStruct.fields()[1].dataType()).isEqualTo(DataTypes.LongType); + }); + + // Correctness: disable Spark-level variant pushdown so collect() works on this branch + // without the Parquet extraction reader. 
+ spark.conf().set(SQLConf.PUSH_VARIANT_INTO_SCAN().key(), "false"); + try { + List results = + spark + .sql( + String.format( + "SELECT" + + " variant_get(v, '$.size', 'double') AS size_d," + + " variant_get(v, '$.size', 'long') AS size_l" + + " FROM %s" + + " WHERE variant_get(v, '$.size', 'long') IS NOT NULL" + + " ORDER BY size_l", + tableName)) + .collectAsList(); + assertThat(results).hasSize(3); + // All three rows, ordered by size_l: 3, 8, 10 + assertThat(results.get(0).getDouble(0)).isEqualTo(3.0D); + assertThat(results.get(0).getLong(1)).isEqualTo(3L); + assertThat(results.get(2).getDouble(0)).isEqualTo(10.0D); + assertThat(results.get(2).getLong(1)).isEqualTo(10L); + } finally { + spark.conf().unset(SQLConf.PUSH_VARIANT_INTO_SCAN().key()); + } + }); + } + + @TestTemplate + public void testDeclineUnsupportedStructTarget() { + sql( + "INSERT INTO %s VALUES (1, 'X', parse_json('{\"size\": 10}')), " + + "(2, 'X', parse_json('{\"size\": 3}'))", + tableName); + + withExtractionPushdown( + true, + () -> + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format( + "SELECT variant_get(v, '$', 'struct') AS s FROM %s", + tableName)); + List scanSchemas = + scanReadSchemas(df.queryExecution().optimizedPlan()); + assertThat(scanSchemas).hasSize(1); + assertThat(variantField(scanSchemas.get(0), "v").dataType()) + .isInstanceOf(VariantType.class); + })); + } + + @TestTemplate + public void testDeclineUnsupportedArrayTarget() { + sql("INSERT INTO %s VALUES (1, 'X', parse_json('{\"tags\": [\"a\", \"b\"]}'))", tableName); + + withExtractionPushdown( + true, + () -> + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format( + "SELECT variant_get(v, '$.tags', 'array') AS tags FROM %s", + tableName)); + List scanSchemas = + scanReadSchemas(df.queryExecution().optimizedPlan()); + assertThat(scanSchemas).hasSize(1); + assertThat(variantField(scanSchemas.get(0), "v").dataType()) + .isInstanceOf(VariantType.class); + })); + } + + 
@TestTemplate + public void testDateExtractionPlanShape() { + sql( + "INSERT INTO %s VALUES (1, 'X', parse_json('{\"created\": \"1970-03-01\"}')), " + + "(2, 'X', parse_json('{\"created\": \"1970-01-01\"}'))", + tableName); + + withExtractionPushdown( + true, + () -> + withVariantPushIntoScan( + () -> { + Dataset df = + spark.sql( + String.format( + "SELECT variant_get(v, '$.created', 'date') AS created FROM %s" + + " WHERE variant_get(v, '$.created', 'date') IS NOT NULL", + tableName)); + List scanSchemas = + scanReadSchemas(df.queryExecution().optimizedPlan()); + assertThat(scanSchemas).hasSize(1); + assertVariantColumnIsAnnotatedStruct(scanSchemas.get(0), "v", 1); + })); + } + + @TestTemplate + public void testFallbackForNonVariantColumn() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql( + "CREATE TABLE %s (id INT, x INT) USING iceberg TBLPROPERTIES ('%s' = '3')", + tableName, TableProperties.FORMAT_VERSION); + sql("INSERT INTO %s VALUES (1, 42), (2, 99)", tableName); + + List results = + spark.sql(String.format("SELECT id, x FROM %s ORDER BY id", tableName)).collectAsList(); + + assertThat(results).hasSize(2); + assertThat(results.get(0).getInt(1)).isEqualTo(42); + assertThat(results.get(1).getInt(1)).isEqualTo(99); + } +}