Skip to content

Add a generic deserializer for Java/Scala 2.12 lambdas #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 22, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package scala.compat.java8.runtime

import java.lang.invoke._

/**
* This class is only intended to be called by synthetic `$deserializeLambda$` method that the Scala 2.12
* compiler will add to classes hosting lambdas.
*
* It is not intended to be consumed directly.
*/
object LambdaDeserializer {
/**
* Deserialize a lambda by calling `LambdaMetafactory.altMetafactory` to spin up a lambda class
* and instantiating this class with the captured arguments.
*
* A cache may be provided to ensure that subsequent deserialization of the same lambda expression
* is cheap, it amounts to a reflective call to the constructor of the previously created class.
* However, deserialization of the same lambda expression is not guaranteed to use the same class,
* concurrent deserialization of the same lambda expression may spin up more than one class.
*
* Assumptions:
* - No additional marker interfaces are required beyond `{java.io,scala.}Serializable`. These are
* not stored in `SerializedLambda`, so we can't reconstitute them.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I still don't quite understand: why do we need to add the Serializable marker interface if the functional interface extends Serializable anyway? The created lambda object has a class type that implements the functional interface type, so it would have the right Serializable already.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The scala.FunctionN traits don't extends Serializable, but the subclasses synthesized for anonymous functions do.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that's new in the scala/scala PR, currently it just uses metafactory and JFunctionN, right?

* - No additional bridge methods are passed to `altMetafactory`. Again, these are not stored.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just confirming: for normal lambda creation (not deserialization), we use the ordinary metafactory (not altMetafactory) at the indy-callsite, which never creates additional bridge methods in the lambda object?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The companion PR in scala/scala switched to using altMetafactory unconditionally. We don't explicitly ask for additional bridges. The necessary bridges are created without asking for them.

This comment was pointing out the caveat that we can't faithfully unknown lambdas. Our heuristic is to use special knowledge of Scala's use of LMF if the functional interface is scala.FunctionN, and assume javac-s usage otherwise.

*
* @param lookup The factory for method handles. Must have access to the implementation method, the
* functional interface class, and `java.io.Serializable` or `scala.Serializable` as
* required.
* @param cache A cache used to avoid spinning up a class for each deserialization of a given lambda. May be `null`
* @param serialized The lambda to deserialize. Note that this is typically created by the `readResolve`
* member of the anonymous class created by `LambdaMetaFactory`.
* @return An instance of the functional interface
*/
def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = {
def slashDot(name: String) = name.replaceAll("/", ".")
val loader = lookup.lookupClass().getClassLoader
val implClass = loader.loadClass(slashDot(serialized.getImplClass))

def makeCallSite: CallSite = {
import serialized._
def parseDescriptor(s: String) =
MethodType.fromMethodDescriptorString(s, loader)

val funcInterfaceSignature = parseDescriptor(getFunctionalInterfaceMethodSignature)
val instantiated = parseDescriptor(getInstantiatedMethodType)
val functionalInterfaceClass = loader.loadClass(slashDot(getFunctionalInterfaceClass))

val implMethodSig = parseDescriptor(getImplMethodSignature)
// Construct the invoked type from the impl method type. This is the type of a factory
// that will be generated by the meta-factory. It is a method type, with param types
// coming form the types of the captures, and return type being the functional interface.
val invokedType: MethodType = {
// 1. Add receiver for non-static impl methods
val withReceiver = getImplMethodKind match {
case MethodHandleInfo.REF_invokeStatic | MethodHandleInfo.REF_newInvokeSpecial =>
implMethodSig
case _ =>
implMethodSig.insertParameterTypes(0, implClass)
}
// 2. Remove lambda parameters, leaving only captures. Note: the receiver may be a lambda parameter,
// such as in `Function<Object, String> s = Object::toString`
val lambdaArity = funcInterfaceSignature.parameterCount()
val from = withReceiver.parameterCount() - lambdaArity
val to = withReceiver.parameterCount()

// 3. Drop the lambda return type and replace with the functional interface.
withReceiver.dropParameterTypes(from, to).changeReturnType(functionalInterfaceClass)
}

// Lookup the implementation method
val implMethod: MethodHandle = try {
findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig)
} catch {
case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e)
}

val flags: Int = LambdaMetafactory.FLAG_SERIALIZABLE | LambdaMetafactory.FLAG_MARKERS
val isScalaFunction = functionalInterfaceClass.getName.startsWith("scala.Function")
val markerInterface: Class[_] = loader.loadClass(if (isScalaFunction) ScalaSerializable else JavaIOSerializable)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So FLAG_SERIALIZABLE seems to already add the java.io.Serializable interface.

We should check what happens if we also add scala.Serializable as a marker interface - do we get both? Is that OK?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try removing the explicit addition of "java.io.Serializable" and test that is is added. But looking that the bytecode that Javac emits and the implementation of LambdaMetafactory, I think that documentation is misleading, and the marker interfaces must be explicitly added. Conversely, adding a the marker interface that extends Serializable without setting the flag is not sufficient, that will consider it to be "accidentally serializable", and generate a writeObject that throws.

Although scala.Serializable was intended as a JVM-independent way to mark a Scala class a serializable, in practice people can also write methods that depend an instance of that trait, so we definitely need to keep it here. It would be cleaner if it were a straight type alias.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docs are correct, I was looking at the wrong spot in the implementation. I'm about to push a commit that explains that it is redundant but safe to add java.io.Serializable.

LambdaMetafactory.altMetafactory(
lookup, getFunctionalInterfaceMethodName, invokedType,

/* samMethodType = */ funcInterfaceSignature,
/* implMethod = */ implMethod,
/* instantiatedMethodType = */ instantiated,
/* flags = */ flags.asInstanceOf[AnyRef],
/* markerInterfaceCount = */ 1.asInstanceOf[AnyRef],
/* markerInterfaces[0] = */ markerInterface,
/* bridgeCount = */ 0.asInstanceOf[AnyRef]
)
}

val key = serialized.getImplMethodName + " : " + serialized.getImplMethodSignature
val factory: MethodHandle = if (cache == null) {
makeCallSite.getTarget
} else cache.get(key) match {
case null =>
val callSite = makeCallSite
val temp = callSite.getTarget
cache.put(key, temp)
temp
case target => target
}

val captures = Array.tabulate(serialized.getCapturedArgCount)(n => serialized.getCapturedArg(n))
factory.invokeWithArguments(captures: _*)
}

private val ScalaSerializable = "scala.Serializable"

private val JavaIOSerializable = {
// We could actually omit this marker interface as LambdaMetaFactory will add it if
// the FLAG_SERIALIZABLE is set and of the provided markers extend it. But the code
// is cleaner if we uniformly add a single marker, so I'm leaving it in place.
"java.io.Serializable"
}

private def findMember(lookup: MethodHandles.Lookup, kind: Int, owner: Class[_],
name: String, signature: MethodType): MethodHandle = {
kind match {
case MethodHandleInfo.REF_invokeStatic =>
lookup.findStatic(owner, name, signature)
case MethodHandleInfo.REF_newInvokeSpecial =>
lookup.findConstructor(owner, signature)
case MethodHandleInfo.REF_invokeVirtual | MethodHandleInfo.REF_invokeInterface =>
lookup.findVirtual(owner, name, signature)
case MethodHandleInfo.REF_invokeSpecial =>
lookup.findSpecial(owner, name, signature, owner)
}
}
}
193 changes: 193 additions & 0 deletions src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
package scala.compat.java8.runtime;

import org.junit.Assert;
import org.junit.Test;

import java.io.Serializable;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;

public final class LambdaDeserializerTest {
private LambdaHost lambdaHost = new LambdaHost();

@Test
public void serializationPrivate() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByPrivateImplMethod();
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
}

@Test
public void serializationStatic() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
}

@Test
public void serializationVirtualMethodReference() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByVirtualMethodReference();
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
}

@Test
public void serializationInterfaceMethodReference() {
F1<I, Object> f1 = lambdaHost.lambdaBackedByInterfaceMethodReference();
I i = new I() {
};
Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i));
}

@Test
public void serializationStaticMethodReference() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticMethodReference();
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
}

@Test
public void serializationNewInvokeSpecial() {
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
Assert.assertEquals(f1.apply(), reconstitute(f1).apply());
}

@Test
public void uncached() {
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
F0<Object> reconstituted1 = reconstitute(f1);
F0<Object> reconstituted2 = reconstitute(f1);
Assert.assertNotEquals(reconstituted1.getClass(), reconstituted2.getClass());
}

@Test
public void cached() {
HashMap<String, MethodHandle> cache = new HashMap<>();
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
F0<Object> reconstituted1 = reconstitute(f1, cache);
F0<Object> reconstituted2 = reconstitute(f1, cache);
Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass());
}

@Test
public void cachedStatic() {
HashMap<String, MethodHandle> cache = new HashMap<>();
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
// Check that deserialization of a static lambda always returns the
// same instance.
Assert.assertSame(reconstitute(f1, cache), reconstitute(f1, cache));

// (as is the case with regular invocation.)
Assert.assertSame(f1, lambdaHost.lambdaBackedByStaticImplMethod());
}

@Test
public void implMethodNameChanged() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
SerializedLambda sl = writeReplace(f1);
checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
}

@Test
public void implMethodSignatureChanged() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
SerializedLambda sl = writeReplace(f1);
checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
}

private void checkIllegalAccess(SerializedLambda serialized) {
try {
LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized);
throw new AssertionError();
} catch (IllegalArgumentException iae) {
if (!iae.getMessage().contains("Illegal lambda deserialization")) {
Assert.fail("Unexpected message: " + iae.getMessage());
}
}
}

private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) {
Object[] captures = new Object[sl.getCapturedArgCount()];
for (int i = 0; i < captures.length; i++) {
captures[i] = sl.getCapturedArg(i);
}
return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(),
sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature,
sl.getInstantiatedMethodType(), captures);
}

private Class<?> loadClass(String className) {
try {
return Class.forName(className.replace('/', '.'));
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
private <A, B> A reconstitute(A f1) {
return reconstitute(f1, null);
}

@SuppressWarnings("unchecked")
private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) {
try {
return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1));
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private <A> SerializedLambda writeReplace(A f1) {
try {
Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace");
writeReplace.setAccessible(true);
return (SerializedLambda) writeReplace.invoke(f1);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}


interface F1<A, B> extends Serializable {
B apply(A a);
}

interface F0<A> extends Serializable {
A apply();
}

class LambdaHost {
public F1<Boolean, String> lambdaBackedByPrivateImplMethod() {
int local = 42;
return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString();
}

@SuppressWarnings("Convert2MethodRef")
public F1<Boolean, String> lambdaBackedByStaticImplMethod() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks the same as the above lambdaBackedByPrivateImplMethod, in both cases the lambda body seems to be non-static

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've fixed this and confirmed we get test coverage of the different paths. I've also added some variations of Java lambdas that were not working before, e.g. String::new or Object::toString or Clonable::clone.

return (b) -> String.valueOf(b);
}

public F1<Boolean, String> lambdaBackedByStaticMethodReference() {
return String::valueOf;
}

public F1<Boolean, String> lambdaBackedByVirtualMethodReference() {
return Object::toString;
}

public F1<I, Object> lambdaBackedByInterfaceMethodReference() {
return I::i;
}

public F0<Object> lambdaBackedByConstructorCall() {
return String::new;
}

public static MethodHandles.Lookup lookup() {
return MethodHandles.lookup();
}
}

interface I {
default String i() { return "i"; };
}