Commit 8c5d4ee

Add a generic deserializer for Java/Scala 2.12 lambdas
Java supports serialization of lambdas by using the serialization proxy pattern. Deserialization of a lambda uses `LambdaMetafactory` to create a new anonymous subclass. More details of the scheme are documented at:

https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html

From those docs:

> SerializedLambda has a readResolve method that looks for a
> (possibly private) static method called $deserializeLambda$
> in the capturing class, invokes that with itself as the first
> argument, and returns the result. Lambda classes implementing
> $deserializeLambda$ are responsible for validating that the
> properties of the SerializedLambda are consistent with a lambda
> actually captured by that class.

The Java compiler generates code in `$deserializeLambda$` that switches on the implementation method name and signature to locate an invokedynamic instruction generated for the particular lambda expression. Then, the `SerializedLambda` is further unpacked, validating that this implementation method still represents the same functional interface as it did when it was serialized. (The source may have been recompiled in the interim.)

In Java, serializable lambda expressions are the exception rather than the rule. In Scala, however, the serializability of `FunctionN` means that we would end up generating a large amount of code to support deserialization. Instead, we are pursuing an alternative approach in which the `$deserializeLambda$` method is a simple forwarder to the generic deserializer added here.

This is capable of deserializing lambdas created by the Java compiler, although this is not its intended use case. The enclosed tests use Java lambdas.

This generic deserializer also works by calling `LambdaMetafactory`, but it does so explicitly, rather than implicitly during linkage of the `invokedynamic` instruction. We have to mimic the caching property of the `invokedynamic` instruction to ensure that we reuse the generated classes when constructing lambda instances.
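To make the serialization proxy scheme described above concrete, here is a minimal Java sketch of a lambda round trip through standard Java serialization. The names `LambdaRoundTrip`, `SerFn`, `greeter` and `roundTrip` are hypothetical and exist only for this illustration; they are not part of this commit.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.function.Function;

final class LambdaRoundTrip {
    // Hypothetical functional interface. Because it extends Serializable, javac
    // emits a synthetic $deserializeLambda$ method in the capturing class.
    interface SerFn<A, B> extends Function<A, B>, Serializable {}

    static SerFn<String, String> greeter(String prefix) {
        return name -> prefix + name; // captures `prefix`
    }

    @SuppressWarnings("unchecked")
    static <T> T roundTrip(T value) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                // The lambda's writeReplace substitutes a SerializedLambda proxy.
                out.writeObject(value);
            }
            try (ObjectInputStream in =
                     new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                // SerializedLambda.readResolve calls $deserializeLambda$ in the capturing class.
                return (T) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        SerFn<String, String> copy = roundTrip(greeter("Hello, "));
        System.out.println(copy.apply("world")); // prints "Hello, world"
    }
}
```

In this sketch the javac-generated `$deserializeLambda$` does the work; the approach in this commit would instead have that method forward to the generic deserializer.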
The cache here uses weak references to keys and values to avoid retention of `Class` or `ClassLoader` instances.

If the name or signature of the implementation method has changed, we fail during deserialization with an `IllegalArgumentException`.

However, we do not fail fast in a few cases that Java would, as we cannot reflect on the "current" functional interface supported by this implementation method. We just instantiate using the "previous" functional interface class/method. This might:

  1. fail inside `LambdaMetafactory` if the new implementation method is not compatible with the old functional interface.
  2. pass through `LambdaMetafactory` by chance, but fail when instantiating the class in other cases. For example:

```
% tail sandbox/test{1,2}.scala
==> sandbox/test1.scala <==
class C {
  def test: (String => String) = {
    val s: String = ""
    (t) => s + t
  }
}

==> sandbox/test2.scala <==
class C {
  def test: (String, String) => String = {
    (s, t) => s + t
  }
}

% (for i in 1 2; do scalac -Ydelambdafy:method -Xprint:delambdafy sandbox/test$i.scala 2>&1 ; done) | grep 'def $anon'
    final <static> <artifact> private[this] def $anonfun$1(t: String, s$1: String): String = s$1.+(t);
    final <static> <artifact> private[this] def $anonfun$1(s: String, t: String): String = s.+(t);
```

  3. silently create an instance of the old functional interface. For example, imagine switching from `FuncInterface1` to `FuncInterface2` where these were identical other than the name.

I don't believe that these are showstoppers.
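The weak-keyed, weak-valued cache mentioned above can be sketched in Java roughly as follows. This is an illustration under assumptions, not the code in this commit: `WeakCache` and `getOrCreate` are hypothetical names, and the sketch synchronizes access (because `WeakHashMap` is not thread-safe), which the deserializer in this commit does not do.

```java
import java.lang.ref.WeakReference;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.function.Function;

// Hypothetical helper: a cache that retains neither its keys nor its values.
final class WeakCache<K, V> {
    // WeakHashMap holds keys weakly; wrapping values in WeakReference keeps them weak too.
    private final Map<K, WeakReference<V>> entries = new WeakHashMap<>();

    // Assumption of this sketch: guard the map with a lock.
    synchronized V getOrCreate(K key, Function<? super K, ? extends V> create) {
        WeakReference<V> ref = entries.get(key);
        V value = (ref == null) ? null : ref.get();
        if (value == null) {
            // Either never cached, or the entry was cleared by the garbage collector.
            value = create.apply(key);
            entries.put(key, new WeakReference<>(value));
        }
        return value;
    }
}
```

A cache hit is only possible while something else keeps both the key and the value reachable, so a cache like this is an optimization and cannot guarantee that the same value is returned twice.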
1 parent 82eba69 commit 8c5d4ee

3 files changed, +222 -2 lines changed

build.sbt

Lines changed: 2 additions & 2 deletions
@@ -2,9 +2,9 @@ import com.typesafe.tools.mima.plugin.{MimaPlugin, MimaKeys}
 
 scalaModuleSettings
 
-scalaVersion := "2.11.5"
+scalaVersion := "2.11.6"
 
-snapshotScalaBinaryVersion := "2.11.5"
+snapshotScalaBinaryVersion := "2.11.6"
 
 organization := "org.scala-lang.modules"
 

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+package scala.compat.java8.runtime
+
+import java.lang.invoke._
+import java.lang.ref.WeakReference
+
+/**
+ * This class is only intended to be called by the synthetic `$deserializeLambda$` method that the Scala 2.12
+ * compiler will add to classes hosting lambdas.
+ *
+ * It is not intended to be consumed directly.
+ */
+object LambdaDeserializer {
+  private final case class CacheKey(implClass: Class[_], implMethodName: String, implMethodSignature: String)
+  private val cache = new java.util.WeakHashMap[CacheKey, WeakReference[CallSite]]()
+
+  /**
+   * Deserialize a lambda by calling `LambdaMetafactory.altMetafactory` to spin up a lambda class
+   * and instantiating this class with the captured arguments.
+   *
+   * A cache is employed to ensure that subsequent deserialization of the same lambda expression
+   * is cheap: it amounts to a reflective call to the constructor of the previously created class.
+   * However, deserialization of the same lambda expression is not guaranteed to use the same class:
+   * concurrent deserialization of the same lambda expression may spin up more than one class.
+   *
+   * This cache is weak in keys and values to avoid retention of the enclosing class (and its classloader)
+   * of deserialized lambdas.
+   *
+   * Assumptions:
+   *  - No additional marker interfaces are required beyond `{java.io,scala.}Serializable`. These are
+   *    not stored in `SerializedLambda`, so we can't reconstitute them.
+   *  - No additional bridge methods are passed to `altMetafactory`. Again, these are not stored.
+   *
+   * Note: The Java compiler
+   *
+   * @param lookup      The factory for method handles. Must have access to the implementation method, the
+   *                    functional interface class, and `java.io.Serializable` or `scala.Serializable` as
+   *                    required.
+   * @param serialized  The lambda to deserialize. Note that this is typically created by the `writeReplace`
+   *                    member of the anonymous class created by `LambdaMetafactory`.
+   * @return            An instance of the functional interface
+   */
+  def deserializeLambda(lookup: MethodHandles.Lookup, serialized: SerializedLambda): AnyRef = {
+    def slashDot(name: String) = name.replaceAll("/", ".")
+    val loader = lookup.lookupClass().getClassLoader
+    val implClass = loader.loadClass(slashDot(serialized.getImplClass))
+
+    def makeCallSite: CallSite = {
+      import serialized._
+      def parseDescriptor(s: String) =
+        MethodType.fromMethodDescriptorString(s, loader)
+
+      val funcInterfacesSignature = parseDescriptor(getFunctionalInterfaceMethodSignature)
+      val methodType: MethodType = funcInterfacesSignature
+      val instantiated = parseDescriptor(getInstantiatedMethodType)
+      val implMethodSig = parseDescriptor(getImplMethodSignature)
+
+      val from = implMethodSig.parameterCount() - funcInterfacesSignature.parameterCount()
+      val to = implMethodSig.parameterCount()
+      val functionalInterfaceClass = loader.loadClass(slashDot(getFunctionalInterfaceClass))
+      var invokedType: MethodType =
+        implMethodSig.dropParameterTypes(from, to)
+          .changeReturnType(functionalInterfaceClass)
+
+      val implMethod: MethodHandle = try {
+        getImplMethodKind match {
+          case MethodHandleInfo.REF_invokeStatic =>
+            lookup.findStatic(implClass, getImplMethodName, implMethodSig)
+          case MethodHandleInfo.REF_invokeVirtual =>
+            invokedType = invokedType.insertParameterTypes(0, implClass)
+            lookup.findVirtual(implClass, getImplMethodName, implMethodSig)
+          case MethodHandleInfo.REF_invokeSpecial =>
+            invokedType = invokedType.insertParameterTypes(0, implClass)
+            lookup.findSpecial(implClass, getImplMethodName, implMethodSig, implClass)
+        }
+      } catch {
+        case e: ReflectiveOperationException =>
+          throw new IllegalArgumentException("Illegal lambda deserialization", e)
+      }
+      val FLAG_SERIALIZABLE = 1
+      val FLAG_MARKERS = 2
+      val flags: Int = FLAG_SERIALIZABLE | FLAG_MARKERS
+      val markerInterface: Class[_] = if (functionalInterfaceClass.getName.startsWith("scala.Function"))
+        loader.loadClass("scala.Serializable")
+      else
+        loader.loadClass("java.io.Serializable")
+
+      LambdaMetafactory.altMetafactory(
+        lookup, getFunctionalInterfaceMethodName, invokedType,
+
+        /* samMethodType          = */ funcInterfacesSignature,
+        /* implMethod             = */ implMethod,
+        /* instantiatedMethodType = */ instantiated,
+        /* flags                  = */ flags.asInstanceOf[AnyRef],
+        /* markerInterfaceCount   = */ 1.asInstanceOf[AnyRef],
+        /* markerInterfaces[0]    = */ markerInterface,
+        /* bridgeCount            = */ 0.asInstanceOf[AnyRef]
+      )
+    }
+
+    val key = new CacheKey(implClass, serialized.getImplMethodName, serialized.getImplMethodSignature)
+    val callSiteRef: WeakReference[CallSite] = cache.get(key)
+    var site = if (callSiteRef == null) null else callSiteRef.get()
+    if (site == null) {
+      site = makeCallSite
+      cache.put(key, new WeakReference(site))
+    }
+
+    val factory = site.getTarget
+    val captures = Array.tabulate(serialized.getCapturedArgCount)(n => serialized.getCapturedArg(n))
+    factory.invokeWithArguments(captures: _*)
+  }
+}
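The core of `makeCallSite` above (parse the stored descriptors, derive the invoked type from the captured parameters, then call `LambdaMetafactory.altMetafactory` explicitly) can be tried in isolation with a small Java sketch. It assumes a static implementation method and a single captured argument; `ExplicitMetafactory`, `greet` and `greeter` are hypothetical names, and the descriptor strings are written by hand rather than read from a `SerializedLambda`.

```java
import java.lang.invoke.CallSite;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.function.Function;

final class ExplicitMetafactory {
    // Hypothetical implementation method: captured argument first, lambda parameter last.
    private static String greet(String prefix, String name) {
        return prefix + name;
    }

    @SuppressWarnings("unchecked")
    static Function<String, String> greeter(String prefix) {
        try {
            MethodHandles.Lookup lookup = MethodHandles.lookup();
            ClassLoader loader = ExplicitMetafactory.class.getClassLoader();

            // Descriptors in the textual form that SerializedLambda stores.
            MethodType implSig = MethodType.fromMethodDescriptorString(
                "(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;", loader);
            MethodType samSig = MethodType.fromMethodDescriptorString(
                "(Ljava/lang/Object;)Ljava/lang/Object;", loader);
            MethodType instantiated = MethodType.fromMethodDescriptorString(
                "(Ljava/lang/String;)Ljava/lang/String;", loader);

            // Drop the trailing SAM parameters; the remaining ones are the captured arguments.
            int to = implSig.parameterCount();
            int from = to - samSig.parameterCount();
            MethodType invokedType =
                implSig.dropParameterTypes(from, to).changeReturnType(Function.class);

            MethodHandle impl = lookup.findStatic(ExplicitMetafactory.class, "greet", implSig);
            CallSite site = LambdaMetafactory.altMetafactory(
                lookup, "apply", invokedType,
                samSig, impl, instantiated, LambdaMetafactory.FLAG_SERIALIZABLE);

            // The call site's target is a factory taking the captured arguments.
            return (Function<String, String>) site.getTarget().invokeWithArguments(prefix);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }
}
```

Calling `ExplicitMetafactory.greeter("Hello, ").apply("world")` should behave like the lambda `name -> "Hello, " + name`. Unlike the deserializer above, the sketch does not pass `FLAG_MARKERS` or cache the call site.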
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+package scala.compat.java8.runtime;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.SerializedLambda;
+import java.lang.reflect.Method;
+import java.util.Arrays;
+
+public final class LambdaDeserializerTest {
+    private LambdaHost lambdaHost = new LambdaHost();
+
+    @Test
+    public void serializationPrivate() {
+        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByPrivateImplMethod();
+        F1<Boolean, String> f2 = reconstitute(f1);
+        Assert.assertEquals(f1.apply(true), f2.apply(true));
+    }
+
+    @Test
+    public void serializationStatic() {
+        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
+        F1<Boolean, String> f2 = reconstitute(f1);
+        Assert.assertEquals(f1.apply(true), f2.apply(true));
+    }
+
+    @Test
+    public void implMethodNameChanged() {
+        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
+        SerializedLambda sl = writeReplace(f1);
+        checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
+    }
+
+    @Test
+    public void implMethodSignatureChanged() {
+        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
+        SerializedLambda sl = writeReplace(f1);
+        checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
+    }
+
+    private void checkIllegalAccess(SerializedLambda serialized) {
+        try {
+            LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), serialized);
+            throw new AssertionError();
+        } catch (IllegalArgumentException iae) {
+            if (!iae.getMessage().contains("Illegal lambda deserialization")) {
+                Assert.fail("Unexpected message: " + iae.getMessage());
+            }
+        }
+    }
+
+    private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) {
+        Object[] captures = new Object[sl.getCapturedArgCount()];
+        for (int i = 0; i < captures.length; i++) {
+            captures[i] = sl.getCapturedArg(i);
+        }
+        return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(),
+                sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature,
+                sl.getInstantiatedMethodType(), captures);
+    }
+
+    private Class<?> loadClass(String className) {
+        try {
+            return Class.forName(className.replace('/', '.'));
+        } catch (ClassNotFoundException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private <A, B> F1<A, B> reconstitute(F1<A, B> f1) {
+        try {
+            return (F1<A, B>) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), writeReplace(f1));
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private <A, B> SerializedLambda writeReplace(F1<A, B> f1) {
+        try {
+            Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace");
+            writeReplace.setAccessible(true);
+            return (SerializedLambda) writeReplace.invoke(f1);
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
+
+
+interface F1<A, B> extends Serializable {
+    B apply(A a);
+}
+
+class LambdaHost {
+    public F1<Boolean, String> lambdaBackedByPrivateImplMethod() {
+        int local = 42;
+        return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString();
+    }
+
+    public F1<Boolean, String> lambdaBackedByStaticImplMethod() {
+        int local = 42;
+        return (b) -> Arrays.asList(local, b ? "true" : "false").toString();
+    }
+
+    public static MethodHandles.Lookup lookup() { return MethodHandles.lookup(); }
+}
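The tests above obtain a `SerializedLambda` by reflectively calling the lambda's `writeReplace` method. The following Java sketch uses the same trick to show how the implementation method kind depends on whether the lambda body refers to `this`. `ImplMethodKinds`, `SerSupplier` and `serializedForm` are hypothetical names for this illustration, and the static-versus-instance outcome describes how javac desugars lambdas rather than anything guaranteed by this commit.

```java
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.function.Supplier;

final class ImplMethodKinds {
    // Hypothetical serializable functional interface.
    interface SerSupplier<T> extends Supplier<T>, Serializable {}

    // Captures only a parameter, so javac desugars the body to a static method.
    SerSupplier<String> capturesLocal(int local) {
        return () -> "local=" + local;
    }

    // Refers to `this`, so the body becomes an instance method and the receiver is captured.
    SerSupplier<String> capturesThis() {
        return () -> "host=" + ImplMethodKinds.this.getClass().getSimpleName();
    }

    // Same reflective trick as the writeReplace helper in the test above.
    static SerializedLambda serializedForm(Serializable lambda) {
        try {
            Method writeReplace = lambda.getClass().getDeclaredMethod("writeReplace");
            writeReplace.setAccessible(true);
            return (SerializedLambda) writeReplace.invoke(lambda);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Comparing `getImplMethodKind()` on the two serialized forms against `MethodHandleInfo.REF_invokeStatic` is a quick way to check which branch of the deserializer's `getImplMethodKind match` a given lambda would exercise.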
