Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,19 @@ public O visit(Expression.IfThen expr, C context) throws E {
return visitFallback(expr, context);
}

/**
 * Visits a Lambda expression.
 *
 * <p>This implementation does no Lambda-specific work; it forwards the expression and context
 * unchanged to {@code visitFallback} and returns whatever that produces.
 *
 * @param expr the Lambda expression
 * @param context the visitation context
 * @return the result of {@code visitFallback(expr, context)}
 * @throws E if visitation fails
 */
@Override
public O visit(Expression.Lambda expr, C context) throws E {
return visitFallback(expr, context);
}

/**
* Visits a scalar function invocation.
*
Expand Down
25 changes: 25 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,31 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

@Value.Immutable
abstract class Lambda implements Expression {
public abstract Type.Struct parameters();

public abstract Expression body();

@Override
public Type getType() {
List<Type> paramTypes = parameters().fields();
Type returnType = body().getType();

return Type.withNullability(false).func(paramTypes, returnType);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is setting nullability false correct here? I am pretty sure func can be nullable, e.g. func?<i32 -> i32>.

Also, it would probably be good to add a test that captures this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed the substrait-go implementation: the GetType function for lambda returns Nullability: types.NullabilityRequired

func (l *Lambda) GetType() types.Type {
	return &types.FuncType{
		Nullability:    types.NullabilityRequired,
		ParameterTypes: l.Parameters.Types,
		ReturnType:     l.Body.GetType(),
	}
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I am thinking about this right now but that may have been wrong in substrait-go 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a spec error. I raised an issue in upstream here.

Mind leaving a comment there referencing this issue as a TODO so we remember to fix it once it gets resolved?

}

/**
 * Creates a builder for {@link Lambda} expressions.
 *
 * @return a new builder obtained from {@code ImmutableExpression.Lambda.builder()}
 */
public static ImmutableExpression.Lambda.Builder builder() {
return ImmutableExpression.Lambda.builder();
}

/**
 * Dispatches this lambda to the visitor's {@code visit} overload for {@link Lambda}.
 *
 * @param visitor the visitor to call
 * @param context the visitation context, passed through unchanged
 * @return whatever the visitor returns
 * @throws E if the visitor throws
 */
@Override
public <R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E {
return visitor.visit(this, context);
}
}

/**
* Base interface for user-defined literals.
*
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/io/substrait/expression/ExpressionVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,16 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
*/
R visit(Expression.NestedStruct expr, C context) throws E;

/**
 * Visit a Lambda expression (an expression made of a parameter struct and a body expression).
 *
 * @param expr the Lambda expression
 * @param context visitation context
 * @return visit result
 * @throws E on visit failure
 */
R visit(Expression.Lambda expr, C context) throws E;

/**
* Visit a user-defined any literal.
*
Expand Down
18 changes: 17 additions & 1 deletion core/src/main/java/io/substrait/expression/FieldReference.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public abstract class FieldReference implements Expression {

public abstract Optional<Integer> outerReferenceStepsOut();

/**
 * Steps-out value carried by a lambda parameter reference.
 *
 * <p>A present value marks this reference as a lambda parameter reference (see {@link
 * #isLambdaParameterReference()}); an empty value means it is some other kind of reference.
 * NOTE(review): the value presumably counts how many enclosing lambdas to skip, with 0 meaning
 * the innermost one, which is how the proto converter resolves it; confirm against the spec.
 *
 * @return the steps-out value, or empty if this is not a lambda parameter reference
 */
public abstract Optional<Integer> lambdaParameterReferenceStepsOut();

@Override
public Type getType() {
return type();
Expand All @@ -38,13 +40,18 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
public boolean isSimpleRootReference() {
return segments().size() == 1
&& !inputExpression().isPresent()
&& !outerReferenceStepsOut().isPresent();
&& !outerReferenceStepsOut().isPresent()
&& !lambdaParameterReferenceStepsOut().isPresent();
}

/**
 * Tells whether this reference points outside the current scope.
 *
 * @return {@code true} only when an outer-reference steps-out value is present and greater than
 *     zero; an absent value or a value of zero or less yields {@code false}
 */
public boolean isOuterReference() {
  final Optional<Integer> stepsOut = outerReferenceStepsOut();
  return stepsOut.isPresent() && stepsOut.get() > 0;
}

/**
 * Tells whether this reference points at a lambda parameter.
 *
 * <p>Only the presence of the steps-out value is checked, so a value of zero still counts.
 *
 * @return {@code true} if a lambda-parameter steps-out value is present
 */
public boolean isLambdaParameterReference() {
  final Optional<Integer> lambdaStepsOut = lambdaParameterReferenceStepsOut();
  return lambdaStepsOut.isPresent();
}

public FieldReference dereferenceStruct(int index) {
Type newType = StructFieldFinder.getReferencedType(type(), index);
return dereference(newType, StructField.of(index));
Expand Down Expand Up @@ -134,6 +141,15 @@ public static FieldReference newInputRelReference(int index, List<Rel> rels) {
index, currentOffset));
}

/**
 * Creates a reference to one parameter of a lambda.
 *
 * <p>The reference has a single struct-field segment for {@code paramIndex}, takes its type from
 * the field at that index in {@code lambdaParamsType}, and records {@code stepsOut} as its
 * lambda-parameter steps-out value. No range check is done here beyond what the field list's
 * {@code get} performs.
 *
 * @param paramIndex zero-based index of the parameter within {@code lambdaParamsType}
 * @param lambdaParamsType parameter struct of the lambda being referenced
 * @param stepsOut steps-out value to store on the reference, not validated here
 * @return the new lambda parameter reference
 */
public static FieldReference newLambdaParameterReference(
int paramIndex, Type.Struct lambdaParamsType, int stepsOut) {
return ImmutableFieldReference.builder()
.addSegments(StructField.of(paramIndex))
.type(lambdaParamsType.fields().get(paramIndex))
.lambdaParameterReferenceStepsOut(stepsOut)
.build();
}

public interface ReferenceSegment {
FieldReference apply(FieldReference reference);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,18 @@ public Expression visit(
});
}

/**
 * Converts a lambda expression to its protobuf form.
 *
 * <p>The parameter struct is converted through the type converter and the body is converted by
 * visiting it with this same converter.
 *
 * @param expr the lambda expression to convert
 * @param context the visitation context, passed on when converting the body
 * @return a proto expression whose lambda field holds the converted parameters and body
 */
@Override
public Expression visit(
    io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context)
    throws RuntimeException {
  final io.substrait.proto.Expression.Lambda.Builder lambdaProto =
      io.substrait.proto.Expression.Lambda.newBuilder();
  lambdaProto.setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct());
  lambdaProto.setBody(expr.body().accept(this, context));
  return io.substrait.proto.Expression.newBuilder().setLambda(lambdaProto).build();
}

@Override
public Expression visit(
io.substrait.expression.Expression.UserDefinedAnyLiteral expr,
Expand Down Expand Up @@ -603,6 +615,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) {
out.setOuterReference(
io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder()
.setStepsOut(expr.outerReferenceStepsOut().get()));
} else if (expr.lambdaParameterReferenceStepsOut().isPresent()) {
out.setLambdaParameterReference(
io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder()
.setStepsOut(expr.lambdaParameterReferenceStepsOut().get()));
} else {
out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class ProtoExpressionConverter {
private final Type.Struct rootType;
private final ProtoTypeConverter protoTypeConverter;
private final ProtoRelConverter protoRelConverter;
/**
 * Parameter structs of the lambdas currently being converted, outermost first.
 *
 * <p>Converting a lambda appends its parameter struct before its body is converted and removes
 * it again afterwards, so the last element always belongs to the innermost enclosing lambda.
 * Lambda parameter references are resolved by index {@code size() - 1 - stepsOut}; a list is
 * used rather than a dedicated stack type because that lookup needs positional access.
 */
private final List<Type.Struct> lambdaParameterStack = new ArrayList<>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use a stack here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When building the lambda parameter reference I need positional access: lambdaIndex = lambdaParameterStack.size() - 1 - stepsOut, then lambdaParameters = lambdaParameterStack.get(lambdaIndex). That's why I didn't use a stack.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair. Java's stack does extend Vector though so you still get the get method. I'm not so particular on this one.

The only thing I will say is I do think that the whole stack parameter stuff can be confusing for people without a little bit of PL experience, which is why I pushed @Slimsammylim to encapsulate some of that logic and put it in docstring comments. I could see a case for encapsulating this logic in a local private class, though it probably makes more sense to do this only if this logic comes up again.

I expected to see something like this logic elsewhere, as you need to do it to do validation when building lambdas. Though I didn't find anything like this. Does the builder for lambdas in this PR actually validate that the lambda is semantically correct?


public ProtoExpressionConverter(
ExtensionLookup lookup,
Expand Down Expand Up @@ -75,6 +76,26 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc
reference.getDirectReference().getStructField().getField(),
rootType,
reference.getOuterReference().getStepsOut());
case LAMBDA_PARAMETER_REFERENCE:
{
// Resolve a reference to a parameter of one of the lambdas currently being converted.
io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef =
reference.getLambdaParameterReference();

// The innermost lambda is the last stack entry, so stepsOut=0 selects it and each
// additional step moves one entry towards the outermost lambda.
int stepsOut = lambdaParamRef.getStepsOut();
int lambdaIndex = lambdaParameterStack.size() - 1 - stepsOut;
// lambdaIndex < 0: stepsOut reaches past the outermost lambda (or no lambda is open).
// lambdaIndex >= size: only possible when stepsOut is negative.
if (lambdaIndex < 0 || lambdaIndex >= lambdaParameterStack.size()) {
throw new IllegalArgumentException(
String.format(
"Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)",
stepsOut, lambdaParameterStack.size()));
}

// The referenced field index selects the parameter within the chosen lambda's struct.
Type.Struct lambdaParameters = lambdaParameterStack.get(lambdaIndex);
return FieldReference.newLambdaParameterReference(
reference.getDirectReference().getStructField().getField(),
lambdaParameters,
stepsOut);
}
case ROOTTYPE_NOT_SET:
default:
throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase());
Expand Down Expand Up @@ -260,6 +281,27 @@ public Type visit(Type.Struct type) throws RuntimeException {
}
}

case LAMBDA:
{
io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda();
Type.Struct parameters =
(Type.Struct)
protoTypeConverter.from(
io.substrait.proto.Type.newBuilder()
.setStruct(protoLambda.getParameters())
.build());

lambdaParameterStack.add(parameters);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then we can use the standard stack push here


Expression body;
try {
body = from(protoLambda.getBody());
} finally {
lambdaParameterStack.remove(lambdaParameterStack.size() - 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the standard stack pop here

}

return Expression.Lambda.builder().parameters(parameters).body(body).build();
}
// TODO enum.
case ENUM:
throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ public class DefaultExtensionCatalog {
/** Extension identifier for set functions. */
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";

/** Extension identifier for list functions. */
public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list";

/** Extension identifier for string functions. */
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ public interface ExtendedTypeCreator<T, I> {
T listE(T type);

T mapE(T key, T value);

/**
 * Creates a function type from its parameter types and its return type.
 *
 * @param parameterTypes the types of the function's parameters
 * @param returnType the type the function returns
 * @return the function type
 */
T funcE(Iterable<? extends T> parameterTypes, T returnType);
}
17 changes: 17 additions & 0 deletions core/src/main/java/io/substrait/function/ParameterizedType.java
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,23 @@ <R, E extends Throwable> R accept(final ParameterizedTypeVisitor<R, E> parameter
}
}

/** Parameterized function type, described by its parameter types and its return type. */
@Value.Immutable
abstract class Func extends BaseParameterizedType implements NullableType {
/** Types of the function's parameters. */
public abstract java.util.List<ParameterizedType> parameterTypes();

/** Type the function returns. */
public abstract ParameterizedType returnType();

/** Creates a builder for {@link Func}. */
public static ImmutableParameterizedType.Func.Builder builder() {
return ImmutableParameterizedType.Func.builder();
}

/** Dispatches this type to the visitor's {@code visit} overload for {@link Func}. */
@Override
<R, E extends Throwable> R accept(final ParameterizedTypeVisitor<R, E> parameterizedTypeVisitor)
throws E {
return parameterizedTypeVisitor.visit(this);
}
}

@Value.Immutable
abstract class ListType extends BaseParameterizedType implements NullableType {
public abstract ParameterizedType name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ public ParameterizedType listE(ParameterizedType type) {
return ParameterizedType.ListType.builder().nullable(nullable).name(type).build();
}

/**
 * Builds a parameterized function type that uses this creator's nullability.
 *
 * @param parameterTypes the types of the function's parameters
 * @param returnType the type the function returns
 * @return the function type
 */
@Override
public ParameterizedType funcE(
    Iterable<? extends ParameterizedType> parameterTypes, ParameterizedType returnType) {
  final ImmutableParameterizedType.Func.Builder func = ParameterizedType.Func.builder();
  func.nullable(nullable);
  func.addAllParameterTypes(parameterTypes);
  func.returnType(returnType);
  return func.build();
}

@Override
public ParameterizedType mapE(ParameterizedType key, ParameterizedType value) {
return ParameterizedType.Map.builder().nullable(nullable).key(key).value(value).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public interface ParameterizedTypeVisitor<R, E extends Throwable> extends TypeVi

R visit(ParameterizedType.StringLiteral stringLiteral) throws E;

/**
 * Visits a parameterized function type.
 *
 * @param expr the function type
 * @return the visit result
 * @throws E on visit failure
 */
R visit(ParameterizedType.Func expr) throws E;

abstract class ParameterizedTypeThrowsVisitor<R, E extends Throwable>
extends TypeVisitor.TypeThrowsVisitor<R, E> implements ParameterizedTypeVisitor<R, E> {

Expand Down Expand Up @@ -100,5 +102,10 @@ public R visit(ParameterizedType.Map expr) throws E {
public R visit(ParameterizedType.StringLiteral stringLiteral) throws E {
throw t();
}

/**
 * Rejects parameterized function types: always throws the exception produced by {@code t()}.
 *
 * @param expr the function type (unused)
 * @return never returns normally
 * @throws E always
 */
@Override
public R visit(ParameterizedType.Func expr) throws E {
throw t();
}
}
}
16 changes: 16 additions & 0 deletions core/src/main/java/io/substrait/function/TypeExpression.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,22 @@ <R, E extends Throwable> R acceptE(final TypeExpressionVisitor<R, E> visitor) th
}
}

/** Function type expression, described by its parameter types and its return type. */
@Value.Immutable
abstract class Func extends BaseTypeExpression implements NullableType {
/** Type expressions of the function's parameters. */
public abstract java.util.List<TypeExpression> parameterTypes();

/** Type expression of the function's return value. */
public abstract TypeExpression returnType();

/** Creates a builder for {@link Func}. */
public static ImmutableTypeExpression.Func.Builder builder() {
return ImmutableTypeExpression.Func.builder();
}

/** Dispatches this type expression to the visitor's {@code visit} overload for {@link Func}. */
@Override
<R, E extends Throwable> R acceptE(final TypeExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract class BinaryOperation extends BaseTypeExpression {
public enum OpType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ public TypeExpression mapE(TypeExpression key, TypeExpression value) {
return TypeExpression.Map.builder().nullable(nullable).key(key).value(value).build();
}

/**
 * Builds a function type expression that uses this creator's nullability.
 *
 * @param parameterTypes the type expressions of the function's parameters
 * @param returnType the type expression of the function's return value
 * @return the function type expression
 */
@Override
public TypeExpression funcE(
    Iterable<? extends TypeExpression> parameterTypes, TypeExpression returnType) {
  final ImmutableTypeExpression.Func.Builder func = TypeExpression.Func.builder();
  func.nullable(nullable);
  func.addAllParameterTypes(parameterTypes);
  func.returnType(returnType);
  return func.build();
}

public static class Assign {
String name;
TypeExpression expr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public interface TypeExpressionVisitor<R, E extends Throwable>

R visit(TypeExpression.Map expr) throws E;

/**
 * Visits a function type expression.
 *
 * @param expr the function type expression
 * @return the visit result
 * @throws E on visit failure
 */
R visit(TypeExpression.Func expr) throws E;

R visit(TypeExpression.BinaryOperation expr) throws E;

R visit(TypeExpression.NotOperation expr) throws E;
Expand Down Expand Up @@ -97,6 +99,11 @@ public R visit(TypeExpression.Map expr) throws E {
throw t();
}

/**
 * Rejects function type expressions: always throws the exception produced by {@code t()}.
 *
 * @param expr the function type expression (unused)
 * @return never returns normally
 * @throws E always
 */
@Override
public R visit(TypeExpression.Func expr) throws E {
throw t();
}

@Override
public R visit(TypeExpression.BinaryOperation expr) throws E {
throw t();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,18 @@ public Optional<Expression> visit(
.build());
}

/**
 * Visits a lambda, rebuilding it only if visiting its body produced a replacement.
 *
 * @param lambda the lambda expression to visit
 * @param context the visitation context, passed on to the body
 * @return empty when {@code allEmpty} reports no replacement for the body; otherwise a copy of
 *     {@code lambda} whose body is the replacement
 * @throws E if visiting the body fails
 */
@Override
public Optional<Expression> visit(Expression.Lambda lambda, EmptyVisitationContext context)
    throws E {
  final Optional<Expression> rewrittenBody = lambda.body().accept(this, context);
  if (allEmpty(rewrittenBody)) {
    // Nothing changed below this node, so signal "keep the original".
    return Optional.empty();
  }

  final Expression body = rewrittenBody.orElse(lambda.body());
  final Expression rebuilt = Expression.Lambda.builder().from(lambda).body(body).build();
  return Optional.of(rebuilt);
}

// utilities

protected Optional<List<Expression>> visitExprList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,5 +242,11 @@ public Integer visit(Type.Map type) throws RuntimeException {
public Integer visit(Type.UserDefined type) throws RuntimeException {
return 0;
}

/**
 * Computes the value for a function type as the sum of this visitor's results for every
 * parameter type plus its result for the return type.
 *
 * @param type the function type
 * @return the summed result
 */
@Override
public Integer visit(Type.Func type) throws RuntimeException {
  int total = 0;
  for (Type parameterType : type.parameterTypes()) {
    total += parameterType.accept(this);
  }
  return total + type.returnType().accept(this);
}
}
}
9 changes: 9 additions & 0 deletions core/src/main/java/io/substrait/type/StringTypeVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,13 @@ public String visit(Type.Map type) throws RuntimeException {
public String visit(Type.UserDefined type) throws RuntimeException {
return String.format("u!%s%s", type.name(), n(type));
}

/**
 * Renders a function type as {@code func<p1, p2 -> r>}, with the nullability marker from
 * {@code n(type)} placed directly after {@code func}.
 *
 * @param type the function type
 * @return the string form, with parameter types joined by {@code ", "}
 */
@Override
public String visit(Type.Func type) throws RuntimeException {
  final String nullability = n(type);
  final String parameters =
      type.parameterTypes().stream()
          .map(parameterType -> parameterType.accept(this))
          .collect(Collectors.joining(", "));
  final String result = type.returnType().accept(this);
  return String.format("func%s<%s -> %s>", nullability, parameters, result);
}
}
Loading