diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index e6ebb5782..9b1fdc496 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -448,6 +448,19 @@ public O visit(Expression.IfThen expr, C context) throws E { return visitFallback(expr, context); } + /** + * Visits a Lambda expression. + * + * @param expr the Lambda expression + * @param context the visitation context + * @return the visit result + * @throws E if visitation fails + */ + @Override + public O visit(Expression.Lambda expr, C context) throws E { + return visitFallback(expr, context); + } + /** * Visits a scalar function invocation. * diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 6361af3f5..320a8a614 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -707,6 +707,31 @@ public R accept( } } + @Value.Immutable + abstract class Lambda implements Expression { + public abstract Type.Struct parameters(); + + public abstract Expression body(); + + @Override + public Type getType() { + List paramTypes = parameters().fields(); + Type returnType = body().getType(); + + return Type.withNullability(false).func(paramTypes, returnType); + } + + public static ImmutableExpression.Lambda.Builder builder() { + return ImmutableExpression.Lambda.builder(); + } + + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + /** * Base interface for user-defined literals. 
* diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index 7f094b688..a505af778 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -311,6 +311,16 @@ public interface ExpressionVisitor outerReferenceStepsOut(); + public abstract Optional lambdaParameterReferenceStepsOut(); + @Override public Type getType() { return type(); @@ -38,13 +40,18 @@ public R accept( public boolean isSimpleRootReference() { return segments().size() == 1 && !inputExpression().isPresent() - && !outerReferenceStepsOut().isPresent(); + && !outerReferenceStepsOut().isPresent() + && !lambdaParameterReferenceStepsOut().isPresent(); } public boolean isOuterReference() { return outerReferenceStepsOut().orElse(0) > 0; } + public boolean isLambdaParameterReference() { + return lambdaParameterReferenceStepsOut().isPresent(); + } + public FieldReference dereferenceStruct(int index) { Type newType = StructFieldFinder.getReferencedType(type(), index); return dereference(newType, StructField.of(index)); @@ -134,6 +141,15 @@ public static FieldReference newInputRelReference(int index, List rels) { index, currentOffset)); } + public static FieldReference newLambdaParameterReference( + int paramIndex, Type.Struct lambdaParamsType, int stepsOut) { + return ImmutableFieldReference.builder() + .addSegments(StructField.of(paramIndex)) + .type(lambdaParamsType.fields().get(paramIndex)) + .lambdaParameterReferenceStepsOut(stepsOut) + .build(); + } + public interface ReferenceSegment { FieldReference apply(FieldReference reference); diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 498e6eada..620854139 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java 
+++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -373,6 +373,18 @@ public Expression visit( }); } + @Override + public Expression visit( + io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context) + throws RuntimeException { + return io.substrait.proto.Expression.newBuilder() + .setLambda( + io.substrait.proto.Expression.Lambda.newBuilder() + .setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct()) + .setBody(expr.body().accept(this, context))) + .build(); + } + @Override public Expression visit( io.substrait.expression.Expression.UserDefinedAnyLiteral expr, @@ -603,6 +615,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) { out.setOuterReference( io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder() .setStepsOut(expr.outerReferenceStepsOut().get())); + } else if (expr.lambdaParameterReferenceStepsOut().isPresent()) { + out.setLambdaParameterReference( + io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder() + .setStepsOut(expr.lambdaParameterReferenceStepsOut().get())); } else { out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance()); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 01e25c907..3780a8e1b 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -37,6 +37,7 @@ public class ProtoExpressionConverter { private final Type.Struct rootType; private final ProtoTypeConverter protoTypeConverter; private final ProtoRelConverter protoRelConverter; + private final List lambdaParameterStack = new ArrayList<>(); public ProtoExpressionConverter( ExtensionLookup lookup, @@ -75,6 +76,26 @@ public FieldReference 
from(io.substrait.proto.Expression.FieldReference referenc reference.getDirectReference().getStructField().getField(), rootType, reference.getOuterReference().getStepsOut()); + case LAMBDA_PARAMETER_REFERENCE: + { + io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef = + reference.getLambdaParameterReference(); + + int stepsOut = lambdaParamRef.getStepsOut(); + int lambdaIndex = lambdaParameterStack.size() - 1 - stepsOut; + if (lambdaIndex < 0 || lambdaIndex >= lambdaParameterStack.size()) { + throw new IllegalArgumentException( + String.format( + "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", + stepsOut, lambdaParameterStack.size())); + } + + Type.Struct lambdaParameters = lambdaParameterStack.get(lambdaIndex); + return FieldReference.newLambdaParameterReference( + reference.getDirectReference().getStructField().getField(), + lambdaParameters, + stepsOut); + } case ROOTTYPE_NOT_SET: default: throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase()); @@ -260,6 +281,27 @@ public Type visit(Type.Struct type) throws RuntimeException { } } + case LAMBDA: + { + io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); + Type.Struct parameters = + (Type.Struct) + protoTypeConverter.from( + io.substrait.proto.Type.newBuilder() + .setStruct(protoLambda.getParameters()) + .build()); + + lambdaParameterStack.add(parameters); + + Expression body; + try { + body = from(protoLambda.getBody()); + } finally { + lambdaParameterStack.remove(lambdaParameterStack.size() - 1); + } + + return Expression.Lambda.builder().parameters(parameters).body(body).build(); + } // TODO enum. 
case ENUM: throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase()); diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 7b316d4be..478ed63e2 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -50,6 +50,9 @@ public class DefaultExtensionCatalog { /** Extension identifier for set functions. */ public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; + /** Extension identifier for list functions. */ + public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list"; + /** Extension identifier for string functions. */ public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; diff --git a/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java b/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java index c3360093d..b802d93e0 100644 --- a/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java @@ -16,4 +16,6 @@ public interface ExtendedTypeCreator { T listE(T type); T mapE(T key, T value); + + T funcE(Iterable parameterTypes, T returnType); } diff --git a/core/src/main/java/io/substrait/function/ParameterizedType.java b/core/src/main/java/io/substrait/function/ParameterizedType.java index e514fb975..f35c21f7a 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedType.java +++ b/core/src/main/java/io/substrait/function/ParameterizedType.java @@ -200,6 +200,23 @@ R accept(final ParameterizedTypeVisitor parameter } } + @Value.Immutable + abstract class Func extends BaseParameterizedType implements NullableType { + public abstract java.util.List parameterTypes(); + + public abstract ParameterizedType returnType(); + + public static 
ImmutableParameterizedType.Func.Builder builder() { + return ImmutableParameterizedType.Func.builder(); + } + + @Override + R accept(final ParameterizedTypeVisitor parameterizedTypeVisitor) + throws E { + return parameterizedTypeVisitor.visit(this); + } + } + @Value.Immutable abstract class ListType extends BaseParameterizedType implements NullableType { public abstract ParameterizedType name(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java index 6b89840f6..af7bda7e1 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java @@ -96,6 +96,16 @@ public ParameterizedType listE(ParameterizedType type) { return ParameterizedType.ListType.builder().nullable(nullable).name(type).build(); } + @Override + public ParameterizedType funcE( + Iterable parameterTypes, ParameterizedType returnType) { + return ParameterizedType.Func.builder() + .nullable(nullable) + .addAllParameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + @Override public ParameterizedType mapE(ParameterizedType key, ParameterizedType value) { return ParameterizedType.Map.builder().nullable(nullable).key(key).value(value).build(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java index 9ff42f549..755c99777 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java @@ -29,6 +29,8 @@ public interface ParameterizedTypeVisitor extends TypeVi R visit(ParameterizedType.StringLiteral stringLiteral) throws E; + R visit(ParameterizedType.Func expr) throws E; + abstract class ParameterizedTypeThrowsVisitor extends TypeVisitor.TypeThrowsVisitor implements ParameterizedTypeVisitor { @@ 
-100,5 +102,10 @@ public R visit(ParameterizedType.Map expr) throws E { public R visit(ParameterizedType.StringLiteral stringLiteral) throws E { throw t(); } + + @Override + public R visit(ParameterizedType.Func expr) throws E { + throw t(); + } } } diff --git a/core/src/main/java/io/substrait/function/TypeExpression.java b/core/src/main/java/io/substrait/function/TypeExpression.java index a183c1959..345fc0398 100644 --- a/core/src/main/java/io/substrait/function/TypeExpression.java +++ b/core/src/main/java/io/substrait/function/TypeExpression.java @@ -191,6 +191,22 @@ R acceptE(final TypeExpressionVisitor visitor) th } } + @Value.Immutable + abstract class Func extends BaseTypeExpression implements NullableType { + public abstract java.util.List parameterTypes(); + + public abstract TypeExpression returnType(); + + public static ImmutableTypeExpression.Func.Builder builder() { + return ImmutableTypeExpression.Func.builder(); + } + + @Override + R acceptE(final TypeExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + } + @Value.Immutable abstract class BinaryOperation extends BaseTypeExpression { public enum OpType { diff --git a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java index b7524911b..9d822ffed 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java @@ -82,6 +82,16 @@ public TypeExpression mapE(TypeExpression key, TypeExpression value) { return TypeExpression.Map.builder().nullable(nullable).key(key).value(value).build(); } + @Override + public TypeExpression funcE( + Iterable parameterTypes, TypeExpression returnType) { + return TypeExpression.Func.builder() + .nullable(nullable) + .addAllParameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + public static class Assign { String name; TypeExpression expr; diff --git 
a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java index 31d632c71..44d871337 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java @@ -24,6 +24,8 @@ public interface TypeExpressionVisitor R visit(TypeExpression.Map expr) throws E; + R visit(TypeExpression.Func expr) throws E; + R visit(TypeExpression.BinaryOperation expr) throws E; R visit(TypeExpression.NotOperation expr) throws E; @@ -97,6 +99,11 @@ public R visit(TypeExpression.Map expr) throws E { throw t(); } + @Override + public R visit(TypeExpression.Func expr) throws E { + throw t(); + } + @Override public R visit(TypeExpression.BinaryOperation expr) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index cdb72aea8..292313282 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -432,6 +432,18 @@ public Optional visit( .build()); } + @Override + public Optional visit(Expression.Lambda lambda, EmptyVisitationContext context) + throws E { + Optional newBody = lambda.body().accept(this, context); + + if (allEmpty(newBody)) { + return Optional.empty(); + } + return Optional.of( + Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build()); + } + // utilities protected Optional> visitExprList( diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index e71b0b00c..de7c29dbb 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -242,5 +242,11 @@ public Integer 
visit(Type.Map type) throws RuntimeException { public Integer visit(Type.UserDefined type) throws RuntimeException { return 0; } + + @Override + public Integer visit(Type.Func type) throws RuntimeException { + return type.parameterTypes().stream().mapToInt(p -> p.accept(this)).sum() + + type.returnType().accept(this); + } } } diff --git a/core/src/main/java/io/substrait/type/StringTypeVisitor.java b/core/src/main/java/io/substrait/type/StringTypeVisitor.java index d7c196148..e9d711d96 100644 --- a/core/src/main/java/io/substrait/type/StringTypeVisitor.java +++ b/core/src/main/java/io/substrait/type/StringTypeVisitor.java @@ -150,4 +150,13 @@ public String visit(Type.Map type) throws RuntimeException { public String visit(Type.UserDefined type) throws RuntimeException { return String.format("u!%s%s", type.name(), n(type)); } + + @Override + public String visit(Type.Func type) throws RuntimeException { + return String.format( + "func%s<%s -> %s>", + n(type), + type.parameterTypes().stream().map(t -> t.accept(this)).collect(Collectors.joining(", ")), + type.returnType().accept(this)); + } } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 5a1594e59..685cbe25a 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -352,6 +352,22 @@ public R accept(final TypeVisitor typeVisitor) th } } + @Value.Immutable + abstract class Func implements Type { + public abstract java.util.List parameterTypes(); + + public abstract Type returnType(); + + public static ImmutableType.Func.Builder builder() { + return ImmutableType.Func.builder(); + } + + @Override + public R accept(TypeVisitor typeVisitor) throws E { + return typeVisitor.visit(this); + } + } + @Value.Immutable abstract class Struct implements Type { public abstract java.util.List fields(); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java 
b/core/src/main/java/io/substrait/type/TypeCreator.java index 43358e505..7e4b1eec4 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -89,6 +89,14 @@ public final Type intervalCompound(int precision) { return Type.IntervalCompound.builder().nullable(nullable).precision(precision).build(); } + public Type.Func func(java.util.List parameterTypes, Type returnType) { + return Type.Func.builder() + .nullable(nullable) + .parameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + public Type.Struct struct(Iterable types) { return Type.Struct.builder().nullable(nullable).addAllFields(types).build(); } @@ -258,6 +266,11 @@ public Type visit(Type.Struct type) throws RuntimeException { return Type.Struct.builder().from(type).nullable(nullability).build(); } + @Override + public Type visit(Type.Func type) throws RuntimeException { + return Type.Func.builder().from(type).nullable(nullability).build(); + } + @Override public Type visit(Type.ListType type) throws RuntimeException { return Type.ListType.builder().from(type).nullable(nullability).build(); diff --git a/core/src/main/java/io/substrait/type/TypeVisitor.java b/core/src/main/java/io/substrait/type/TypeVisitor.java index ce6a08910..62e760175 100644 --- a/core/src/main/java/io/substrait/type/TypeVisitor.java +++ b/core/src/main/java/io/substrait/type/TypeVisitor.java @@ -51,6 +51,8 @@ public interface TypeVisitor { R visit(Type.Decimal type) throws E; + R visit(Type.Func type) throws E; + R visit(Type.Struct type) throws E; R visit(Type.ListType type) throws E; @@ -191,6 +193,11 @@ public R visit(Type.PrecisionTimestampTZ type) throws E { throw t(); } + @Override + public R visit(Type.Func type) throws E { + throw t(); + } + @Override public R visit(Type.Struct type) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java 
b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 67d7bc9b5..3909d4033 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -142,6 +142,16 @@ public final T visit(final Type.PrecisionTimestampTZ expr) { return typeContainer(expr).precisionTimestampTZ(expr.precision()); } + @Override + public final T visit(final Type.Func expr) { + return typeContainer(expr) + .func( + expr.parameterTypes().stream() + .map(t -> t.accept(this)) + .collect(java.util.stream.Collectors.toList()), + expr.returnType().accept(this)); + } + @Override public final T visit(final Type.Struct expr) { return typeContainer(expr) diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 57b1f26b5..47842382f 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -119,6 +119,8 @@ public final T precisionTimestampTZ(int precision) { public abstract T intervalCompound(I precision); + public abstract T func(Iterable parameterTypes, T returnType); + public final T struct(T... 
types) { return struct(Arrays.asList(types)); } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index bdb600c1c..24231aebc 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -77,6 +77,13 @@ public Type from(io.substrait.proto.Type type) { case PRECISION_TIMESTAMP_TZ: return n(type.getPrecisionTimestampTz().getNullability()) .precisionTimestampTZ(type.getPrecisionTimestampTz().getPrecision()); + case FUNC: + return n(type.getFunc().getNullability()) + .func( + type.getFunc().getParameterTypesList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList()), + from(type.getFunc().getReturnType())); case STRUCT: return n(type.getStruct().getNullability()) .struct( diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 6422904c4..c0e785db4 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -154,6 +154,16 @@ public Type precisionTimestampTZ(Integer precision) { .build()); } + @Override + public Type func(Iterable parameterTypes, Type returnType) { + return wrap( + Type.Func.newBuilder() + .addAllParameterTypes(parameterTypes) + .setReturnType(returnType) + .setNullability(nullability) + .build()); + } + @Override public Type struct(Iterable types) { return wrap(Type.Struct.newBuilder().addAllTypes(types).setNullability(nullability).build()); @@ -237,6 +247,8 @@ protected Type wrap(final Object o) { return bldr.setPrecisionTimestamp((Type.PrecisionTimestamp) o).build(); } else if (o instanceof Type.PrecisionTimestampTZ) { return bldr.setPrecisionTimestampTz((Type.PrecisionTimestampTZ) o).build(); + } else if (o instanceof Type.Func) { + return 
bldr.setFunc((Type.Func) o).build(); } else if (o instanceof Type.Struct) { return bldr.setStruct((Type.Struct) o).build(); } else if (o instanceof Type.List) { diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java new file mode 100644 index 000000000..c0a979a94 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -0,0 +1,361 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.type.Type; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** + * Tests for Lambda expression round-trip conversion through protobuf. Based on equivalent tests + * from substrait-go. + */ +class LambdaExpressionRoundtripTest extends TestBase { + + /** Test that lambdas with no parameters are valid. Building: () -> i32(42) : func<() -> i32> */ + @Test + void zeroParameterLambda() { + Type.Struct emptyParams = Type.Struct.builder().nullable(false).build(); + + Expression body = ExpressionCreator.i32(false, 42); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(emptyParams).body(body).build(); + + verifyRoundTrip(lambda); + + // Verify the lambda type + Type lambdaType = lambda.getType(); + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + assertEquals(0, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.returnType()); + } + + /** Test valid stepsOut=0 references. 
Building: ($0: i32) -> $0 : func i32> */ + @Test + void validStepsOut0() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Lambda body references parameter 0 with stepsOut=0 + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + verifyRoundTrip(lambda); + + // Verify types + Type lambdaType = lambda.getType(); + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + assertEquals(1, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.I32, funcType.returnType()); + } + + /** + * Test valid field index with multiple parameters. Building: ($0: i32, $1: i64, $2: string) -> $2 + * : func<(i32, i64, string) -> string> + */ + @Test + void validFieldIndex() { + Type.Struct params = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + + // Reference the 3rd parameter (string) + FieldReference paramRef = FieldReference.newLambdaParameterReference(2, params, 0); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + verifyRoundTrip(lambda); + + // Verify return type is string + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(R.STRING, funcType.returnType()); + } + + /** Test type resolution for different parameter types. 
*/ + @Test + void typeResolution() { + // Test cases: (paramTypes, fieldIndex, expectedReturnType) + record TestCase(List paramTypes, int fieldIndex, Type expectedType) {} + + List testCases = + List.of( + new TestCase(List.of(R.I32), 0, R.I32), + new TestCase(List.of(R.I32, R.I64), 1, R.I64), + new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), + new TestCase(List.of(R.FP64), 0, R.FP64), + new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); + + for (TestCase tc : testCases) { + Type.Struct params = + Type.Struct.builder().nullable(false).addAllFields(tc.paramTypes).build(); + + FieldReference paramRef = + FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + verifyRoundTrip(lambda); + + // Verify the body type matches expected + assertEquals( + tc.expectedType, + lambda.body().getType(), + "Body type should match referenced parameter type"); + + // Verify lambda return type + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals( + tc.expectedType, funcType.returnType(), "Lambda return type should match body type"); + } + } + + /** + * Test nested lambda with outer reference. 
Building: ($0: i64, $1: i64) -> (($0: i32) -> + * outer[$0] : i64) : func<(i64, i64) -> func i64>> + */ + @Test + void nestedLambdaWithOuterRef() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64, R.I64).build(); + + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Inner lambda references outer's parameter 0 with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); + + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); + + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + + verifyRoundTrip(outerLambda); + + // Verify structure + assertInstanceOf(Expression.Lambda.class, outerLambda.body()); + Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); + assertEquals(1, resultInner.parameters().fields().size()); + } + + /** + * Test outer reference type resolution in nested lambdas. 
Building: ($0: i32, $1: i64, $2: + * string) -> (($0: fp64) -> outer[$2] : string) : func<...> + */ + @Test + void outerRefTypeResolution() { + Type.Struct outerParams = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.FP64).build(); + + // Inner references outer's field 2 (string) with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(2, outerParams, 1); + + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); + + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + + verifyRoundTrip(outerLambda); + + // Verify inner lambda's return type is string (from outer param 2) + Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); + Type.Func innerFuncType = (Type.Func) resultInner.getType(); + assertEquals( + R.STRING, + innerFuncType.returnType(), + "Inner lambda return type should be string from outer.$2"); + + // Verify body's type is also string + assertEquals(R.STRING, resultInner.body().getType(), "Body type should be string"); + } + + /** + * Test deeply nested field ref inside Cast. 
Building: ($0: i32) -> cast($0 as i64) : func + * i64> + */ + @Test + void deeplyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Cast castExpr = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(castExpr).build(); + + verifyRoundTrip(lambda); + + // Verify the nested FieldRef has its type resolved + Expression.Cast resultCast = (Expression.Cast) lambda.body(); + assertInstanceOf(FieldReference.class, resultCast.input()); + FieldReference resultFieldRef = (FieldReference) resultCast.input(); + + assertNotNull(resultFieldRef.getType(), "Nested FieldRef should have type resolved"); + assertEquals(R.I32, resultFieldRef.getType(), "Should resolve to i32"); + + // Verify lambda return type is i64 (cast output) + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(R.I64, funcType.returnType()); + } + + /** + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). 
Building: ($0: i32) -> cast(cast($0 + * as i64) as string) : func string> + */ + @Test + void doublyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + Expression.Cast innerCast = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = + (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(outerCast).build(); + + verifyRoundTrip(lambda); + + // Navigate to the deeply nested FieldRef (2 levels deep) + Expression.Cast resultOuter = (Expression.Cast) lambda.body(); + Expression.Cast resultInner = (Expression.Cast) resultOuter.input(); + FieldReference resultFieldRef = (FieldReference) resultInner.input(); + + // Verify type is resolved even at depth 2 + assertNotNull(resultFieldRef.getType(), "FieldRef at depth 2 should have type resolved"); + assertEquals(R.I32, resultFieldRef.getType()); + } + + /** + * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func i32> + */ + @Test + void lambdaWithLiteralBody() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + Expression body = ExpressionCreator.i32(false, 42); + + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); + + verifyRoundTrip(lambda); + } + + /** Test lambda getType returns correct Func type. 
*/ + @Test + void lambdaGetTypeReturnsFunc() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32, R.STRING).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(1, params, 0); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + Type lambdaType = lambda.getType(); + + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + + assertEquals(2, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.STRING, funcType.parameterTypes().get(1)); + assertEquals(R.STRING, funcType.returnType()); // body references param 1 which is STRING + } + + // ==================== Validation Error Tests ==================== + + /** + * Test that invalid outer reference (stepsOut too high) fails during proto conversion. Building: + * ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) + */ + @Test + void invalidOuterRef_stepsOutTooHigh() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Create a parameter reference with stepsOut=1 but no outer lambda exists + FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, params, 1); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(invalidRef).build(); + + // Convert to proto - this should work + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(lambda); + + // Converting back should fail because stepsOut=1 references non-existent outer lambda + assertThrows( + IllegalArgumentException.class, + () -> { + protoExpressionConverter.from(protoExpression); + }, + "Should fail when stepsOut references non-existent outer lambda"); + } + + /** + * Test that invalid field index (out of bounds) fails at reference creation.
Building: ($0: + * i32) -> $5 : INVALID (only has 1 param) + */ + @Test + void invalidFieldIndex_outOfBounds() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Create a reference to field 5, but lambda only has 1 parameter (index 0) + // This will fail at build time since newLambdaParameterReference accesses fields.get(5) + assertThrows( + IndexOutOfBoundsException.class, + () -> { + FieldReference.newLambdaParameterReference(5, params, 0); + }, + "Should fail when field index is out of bounds"); + } + + /** + * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). Building: ($0: i64) -> + * (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) + */ + @Test + void nestedInvalidOuterRef() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); + + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Inner lambda references stepsOut=2, but only 1 outer lambda exists + FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, outerParams, 2); + + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(invalidRef).build(); + + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + + // Convert to proto + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(outerLambda); + + // Converting back should fail because stepsOut=2 references non-existent grandparent + assertThrows( + IllegalArgumentException.class, + () -> { + protoExpressionConverter.from(protoExpression); + }, + "Should fail when stepsOut references non-existent grandparent lambda"); + } +} diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index a26ec963e..93f8a9834 100644 --- 
a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -195,6 +195,12 @@ public String visit(Expression.NestedStruct expr, EmptyVisitationContext context return ""; } + @Override + public String visit(Expression.Lambda expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; + } + @Override public String visit(UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws RuntimeException { diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java index 0e13c3e2e..f5f8d93e9 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java @@ -149,6 +149,11 @@ public String visit(Decimal type) throws RuntimeException { return type.getClass().getSimpleName(); } + @Override + public String visit(Type.Func type) throws RuntimeException { + return type.getClass().getSimpleName(); + } + @Override public String visit(Struct type) throws RuntimeException { StringBuffer sb = new StringBuffer(type.getClass().getSimpleName()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index e06f66e8b..a118beec0 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -41,6 +41,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldCollation; import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexLiteral; import 
org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexSubQuery; @@ -381,6 +382,23 @@ public RexNode visit(Expression.IfThen expr, Context context) throws RuntimeExce return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args); } + @Override + public RexNode visit(Expression.Lambda expr, Context context) throws RuntimeException { + List parameters = + IntStream.range(0, expr.parameters().fields().size()) + .mapToObj( + i -> + new RexLambdaRef( + i, + "p" + i, + typeConverter.toCalcite(typeFactory, expr.parameters().fields().get(i)))) + .collect(Collectors.toList()); + + RexNode body = expr.body().accept(this, context); + + return rexBuilder.makeLambdaCall(body, parameters); + } + @Override public RexNode visit(Switch expr, Context context) throws RuntimeException { RexNode match = expr.match().accept(this, context); @@ -655,6 +673,23 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti } return rexInputRef; + } else if (expr.isLambdaParameterReference()) { + // as of now calcite doesn't support nested lambda functions + // https://github.com/substrait-io/substrait-java/issues/711 + int stepsOut = expr.lambdaParameterReferenceStepsOut().get(); + if (stepsOut != 0) { + throw new UnsupportedOperationException( + "Calcite does not support nested lambdas (stepsOut=" + stepsOut + ")"); + } + + final ReferenceSegment segment = expr.segments().get(0); + if (segment instanceof FieldReference.StructField) { + final FieldReference.StructField field = (FieldReference.StructField) segment; + RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType()); + return new RexLambdaRef(field.offset(), "p" + field.offset(), calciteType); + } else { + throw new IllegalArgumentException("Unhandled type: " + segment); + } } return visitFallback(expr, context); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java 
b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java index f8b4be1dd..b56fa4fd3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java @@ -128,6 +128,11 @@ public Boolean visit(Type.Decimal type) { return typeToMatch instanceof Type.Decimal || typeToMatch instanceof ParameterizedType.Decimal; } + @Override + public Boolean visit(Type.Func type) throws RuntimeException { + return typeToMatch instanceof Type.Func || typeToMatch instanceof ParameterizedType.Func; + } + @Override public Boolean visit(Type.PrecisionTime type) { return typeToMatch instanceof Type.PrecisionTime @@ -234,4 +239,9 @@ public Boolean visit(ParameterizedType.Map expr) throws RuntimeException { public Boolean visit(ParameterizedType.StringLiteral stringLiteral) throws RuntimeException { return false; } + + @Override + public Boolean visit(ParameterizedType.Func expr) throws RuntimeException { + return typeToMatch instanceof Type.Func || typeToMatch instanceof ParameterizedType.Func; + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 6993c8451..176b246e7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -7,6 +7,7 @@ import io.substrait.isthmus.TypeConverter; import io.substrait.relation.Rel; import io.substrait.type.StringTypeVisitor; +import io.substrait.type.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -202,12 +203,30 @@ public Expression visitPatternFieldRef(RexPatternFieldRef fieldRef) { @Override public Expression visitLambda(RexLambda rexLambda) { - throw new UnsupportedOperationException("RexLambda not 
supported"); + List paramTypes = + rexLambda.getParameters().stream() + .map(param -> typeConverter.toSubstrait(param.getType())) + .collect(Collectors.toList()); + + Type.Struct parameters = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + + Expression body = rexLambda.getExpression().accept(this); + + return Expression.Lambda.builder().parameters(parameters).body(body).build(); } @Override public Expression visitLambdaRef(RexLambdaRef rexLambdaRef) { - throw new UnsupportedOperationException("RexLambdaRef not supported"); + int fieldIndex = rexLambdaRef.getIndex(); + Type paramType = typeConverter.toSubstrait(rexLambdaRef.getType()); + + return FieldReference.builder() + .addSegments(FieldReference.StructField.of(fieldIndex)) + .type(paramType) + .lambdaParameterReferenceStepsOut( + 0) // Always 0 since Calcite doesn't support nested Lambda expressions for now + // https://github.com/substrait-io/substrait-java/issues/711 + .build(); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java new file mode 100644 index 000000000..add746ee1 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -0,0 +1,149 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** + * Tests for Lambda expression conversion between Substrait and Calcite. Note: Calcite does not + * support nested lambda expressions for the moment, so all tests use stepsOut=0. 
+ */ +class LambdaExpressionTest extends PlanTestBase { + + final Rel emptyTable = sb.emptyVirtualTableScan(); + + /** Test that lambdas with no parameters are valid. Building: () -> i32(42) : func<() -> i32> */ + @Test + void lambdaExpressionZeroParameters() { + Type.Struct params = Type.Struct.builder().nullable(false).build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); + expressionList.add(lambda); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test valid field index with multiple parameters. Building: ($0: i32, $1: i64, $2: string) -> $0 + * : func<(i32, i64, string) -> i32> + */ + @Test + void validFieldIndex() { + Type.Struct params = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + expressionList.add(lambda); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test deeply nested field ref inside Cast.
Building: ($0: i32) -> cast($0 as i64) : func<i32 -> + * i64> + */ + @Test + void deeplyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Cast castExpr = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(castExpr).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). Building: ($0: i32) -> cast(cast($0 + * as i64) as string) : func<i32 -> string> + */ + @Test + void doublyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + Expression.Cast innerCast = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = + (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(outerCast).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test lambda with literal body (no parameter references).
Building: ($0: i32) -> 42 : func<i32 -> i32> + */ + @Test + void lambdaWithLiteralBody() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. Calcite does not + * support nested lambda expressions. + */ + @Test + void nestedLambdaThrowsUnsupportedOperation() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); + + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Inner lambda references outer's parameter with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); + + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); + + List expressionList = new ArrayList<>(); + + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + + expressionList.add(outerLambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); + } +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala index 9cb38b8d0..acf006215 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala @@ -157,4 +157,12
@@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType) @throws[RuntimeException] override def visit(precisionTimestampTZ: Type.PrecisionTimestampTZ): Boolean = typeToMatch.isInstanceOf[Type.PrecisionTimestampTZ] + + @throws[RuntimeException] + override def visit(`type`: Type.Func): Boolean = + typeToMatch.isInstanceOf[Type.Func] || typeToMatch.isInstanceOf[ParameterizedType.Func] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Func): Boolean = + typeToMatch.isInstanceOf[Type.Func] || typeToMatch.isInstanceOf[ParameterizedType.Func] } diff --git a/substrait b/substrait index a9b90657d..92d2e757a 160000 --- a/substrait +++ b/substrait @@ -1 +1 @@ -Subproject commit a9b90657db1e51bba69fcbc4a8c1edac3b661975 +Subproject commit 92d2e757a330f9c973bb566817dc92afd1badcb2