Commit 6472976

Add a generic deserializer for Java/Scala 2.12 lambdas
Java supports serialization of lambdas by using the serialization proxy pattern. Deserialization of a lambda uses `LambdaMetafactory` to create a new anonymous subclass. More details of the scheme are documented at:

https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html

From those docs:

> SerializedLambda has a readResolve method that looks for a
> (possibly private) static method called $deserializeLambda$
> in the capturing class, invokes that with itself as the first
> argument, and returns the result. Lambda classes implementing
> $deserializeLambda$ are responsible for validating that the
> properties of the SerializedLambda are consistent with a lambda
> actually captured by that class.

The Java compiler generates code in `$deserializeLambda$` that switches on the implementation method name and signature to locate an `invokedynamic` instruction generated for the particular lambda expression. Then, the `SerializedLambda` is further unpacked, validating that this implementation method still represents the same functional interface as it did when it was serialized. (The source may have been recompiled in the interim.)

In Java, serializable lambda expressions are the exception rather than the rule. In Scala, however, the serializability of `FunctionN` means that we would end up generating a large amount of code to support deserialization. Instead, we are pursuing an alternative approach in which the `$deserializeLambda$` method is a simple forwarder to the generic deserializer added here.

This is capable of deserializing lambdas created by the Java compiler, although this is not its intended use case. The enclosed tests use Java lambdas.

This generic deserializer also works by calling `LambdaMetafactory`, but it does so explicitly, rather than implicitly during linkage of the `invokedynamic` instruction. We have to mimic the caching property of the `invokedynamic` instruction to ensure we reuse the generated classes when constructing instances.
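To illustrate what calling `LambdaMetafactory` explicitly looks like, here is a minimal Java sketch. It is not code from this commit: the class `ExplicitMetafactoryDemo` and its `greet` method are hypothetical stand-ins for a capturing class and a compiler-generated implementation method with one captured argument.

```java
import java.io.Serializable;
import java.lang.invoke.CallSite;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.function.Function;

public class ExplicitMetafactoryDemo {
    // Hypothetical implementation method: captured argument first, lambda parameter last.
    static String greet(String prefix, String name) {
        return prefix + name;
    }

    @SuppressWarnings("unchecked")
    public static Function<String, String> link(String prefix) {
        try {
            MethodHandles.Lookup lookup = MethodHandles.lookup();
            MethodHandle impl = lookup.findStatic(
                ExplicitMetafactoryDemo.class, "greet",
                MethodType.methodType(String.class, String.class, String.class));
            CallSite site = LambdaMetafactory.altMetafactory(
                lookup,
                "apply",                                              // functional interface method name
                MethodType.methodType(Function.class, String.class),  // invoked type: captures -> interface
                MethodType.methodType(Object.class, Object.class),    // erased signature of Function.apply
                impl,                                                 // implementation method handle
                MethodType.methodType(String.class, String.class),    // instantiated method type
                LambdaMetafactory.FLAG_SERIALIZABLE | LambdaMetafactory.FLAG_MARKERS,
                1,                                                    // marker interface count
                Serializable.class);                                  // the single marker interface
            // The call site target is a factory that takes the captures and returns the lambda.
            return (Function<String, String>) site.getTarget().invokeWithArguments(prefix);
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    public static void main(String[] args) {
        System.out.println(link("Hello, ").apply("world"));
    }
}
```

Each call to `link` here spins up a new class, which is the cost that a cache of call sites is meant to avoid.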
The cache here uses weak references to keys and values to avoid retention of `Class` or `ClassLoader` instances.

If the name or signature of the implementation method has changed, we fail during deserialization with an `IllegalArgumentException`.

However, we do not fail fast in a few cases that Java would, as we cannot reflect on the "current" functional interface supported by this implementation method. We just instantiate using the "previous" functional interface class/method. This might:

1. Fail inside `LambdaMetafactory` if the new implementation method is not compatible with the old functional interface.
2. Pass through `LambdaMetafactory` by chance, but fail when instantiating the class in other cases. For example:

```
% tail sandbox/test{1,2}.scala
==> sandbox/test1.scala <==
class C {
  def test: (String => String) = {
    val s: String = ""
    (t) => s + t
  }
}

==> sandbox/test2.scala <==
class C {
  def test: (String, String) => String = {
    (s, t) => s + t
  }
}

% (for i in 1 2; do scalac -Ydelambdafy:method -Xprint:delambdafy sandbox/test$i.scala 2>&1 ; done) | grep 'def $anon'
    final <static> <artifact> private[this] def $anonfun$1(t: String, s$1: String): String = s$1.+(t);
    final <static> <artifact> private[this] def $anonfun$1(s: String, t: String): String = s.+(t);
```

3. Silently create an instance of the old functional interface. For example, imagine switching from `FuncInterface1` to `FuncInterface2` where these were identical other than the name.

I don't believe that these are showstoppers.
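The properties that deserialization depends on (the implementation method name and signature, and the functional interface) can be inspected on any serializable Java lambda. The sketch below is illustrative only: `SerializedLambdaInspector` and `SerFn` are hypothetical names, and it relies on the same reflective `writeReplace` call that the enclosed tests use.

```java
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.function.Function;

public class SerializedLambdaInspector {
    // Hypothetical serializable functional interface, used only for this illustration.
    public interface SerFn<A, B> extends Function<A, B>, Serializable {}

    public static SerFn<String, String> makeLambda(String prefix) {
        return name -> prefix + name; // captures `prefix`
    }

    // Serializable lambda classes carry a private writeReplace method returning the proxy.
    public static SerializedLambda toSerializedLambda(Object lambda) {
        try {
            Method writeReplace = lambda.getClass().getDeclaredMethod("writeReplace");
            writeReplace.setAccessible(true);
            return (SerializedLambda) writeReplace.invoke(lambda);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        SerializedLambda sl = toSerializedLambda(makeLambda("Hello, "));
        System.out.println("impl class:           " + sl.getImplClass());
        System.out.println("impl method name:     " + sl.getImplMethodName());
        System.out.println("impl method sig:      " + sl.getImplMethodSignature());
        System.out.println("functional interface: " + sl.getFunctionalInterfaceClass());
        System.out.println("interface method:     " + sl.getFunctionalInterfaceMethodName());
        System.out.println("captured arg count:   " + sl.getCapturedArgCount());
    }
}
```

If the class were recompiled so that the implementation method name or signature printed here no longer existed, looking up that method during deserialization would fail, which is the case reported as an `IllegalArgumentException`.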
1 parent 82eba69 commit 6472976

2 files changed

+295
-0
lines changed

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
package scala.compat.java8.runtime

import java.lang.invoke._
import java.lang.ref.WeakReference

/**
 * This class is only intended to be called by the synthetic `$deserializeLambda$` method that the Scala 2.12
 * compiler will add to classes hosting lambdas.
 *
 * It is not intended to be consumed directly.
 */
object LambdaDeserializer {
  private final case class CacheKey(implClass: Class[_], implMethodName: String, implMethodSignature: String)
  private val cache = new java.util.WeakHashMap[CacheKey, WeakReference[CallSite]]()

  /**
   * Deserialize a lambda by calling `LambdaMetafactory.altMetafactory` to spin up a lambda class
   * and instantiating this class with the captured arguments.
   *
   * A cache is employed to ensure that subsequent deserialization of the same lambda expression
   * is cheap: it amounts to a reflective call to the constructor of the previously created class.
   * However, deserialization of the same lambda expression is not guaranteed to use the same class:
   * concurrent deserialization of the same lambda expression may spin up more than one class.
   *
   * This cache is weak in keys and values to avoid retention of the enclosing class (and its classloader)
   * of deserialized lambdas.
   *
   * Assumptions:
   *  - No additional marker interfaces are required beyond `{java.io,scala.}Serializable`. These are
   *    not stored in `SerializedLambda`, so we can't reconstitute them.
   *  - No additional bridge methods are passed to `altMetafactory`. Again, these are not stored.
   *
   * @param lookup      The factory for method handles. Must have access to the implementation method, the
   *                    functional interface class, and `java.io.Serializable` or `scala.Serializable` as
   *                    required.
   * @param serialized  The lambda to deserialize. Note that this is typically created by the `writeReplace`
   *                    member of the anonymous class created by `LambdaMetafactory`.
   * @return            An instance of the functional interface
   */
  def deserializeLambda(lookup: MethodHandles.Lookup, serialized: SerializedLambda): AnyRef = {
    def slashDot(name: String) = name.replaceAll("/", ".")
    val loader = lookup.lookupClass().getClassLoader
    val implClass = loader.loadClass(slashDot(serialized.getImplClass))

    def makeCallSite: CallSite = {
      import serialized._
      def parseDescriptor(s: String) =
        MethodType.fromMethodDescriptorString(s, loader)

      val funcInterfaceSignature = parseDescriptor(getFunctionalInterfaceMethodSignature)
      val instantiated = parseDescriptor(getInstantiatedMethodType)
      val functionalInterfaceClass = loader.loadClass(slashDot(getFunctionalInterfaceClass))

      val implMethodSig = parseDescriptor(getImplMethodSignature)
      // Construct the invoked type from the impl method type. This is the type of a factory
      // that will be generated by the meta-factory. It is a method type, with param types
      // coming from the types of the captures, and return type being the functional interface.
      val invokedType: MethodType = {
        // 1. Add receiver for non-static impl methods
        val withReceiver = getImplMethodKind match {
          case MethodHandleInfo.REF_invokeStatic | MethodHandleInfo.REF_newInvokeSpecial =>
            implMethodSig
          case _ =>
            implMethodSig.insertParameterTypes(0, implClass)
        }
        // 2. Remove lambda parameters, leaving only captures. Note: the receiver may be a lambda parameter,
        //    such as in `Function<Object, String> s = Object::toString`
        val lambdaArity = funcInterfaceSignature.parameterCount()
        val from = withReceiver.parameterCount() - lambdaArity
        val to = withReceiver.parameterCount()

        // 3. Drop the lambda return type and replace with the functional interface.
        withReceiver.dropParameterTypes(from, to).changeReturnType(functionalInterfaceClass)
      }

      // Look up the implementation method
      val implMethod: MethodHandle = try {
        findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig)
      } catch {
        case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e)
      }

      val flags: Int = LambdaMetafactory.FLAG_SERIALIZABLE | LambdaMetafactory.FLAG_MARKERS
      val isScalaFunction = functionalInterfaceClass.getName.startsWith("scala.Function")
      val markerInterface: Class[_] = loader.loadClass(if (isScalaFunction) ScalaSerializable else JavaIOSerializable)

      LambdaMetafactory.altMetafactory(
        lookup, getFunctionalInterfaceMethodName, invokedType,

        /* samMethodType          = */ funcInterfaceSignature,
        /* implMethod             = */ implMethod,
        /* instantiatedMethodType = */ instantiated,
        /* flags                  = */ flags.asInstanceOf[AnyRef],
        /* markerInterfaceCount   = */ 1.asInstanceOf[AnyRef],
        /* markerInterfaces[0]    = */ markerInterface,
        /* bridgeCount            = */ 0.asInstanceOf[AnyRef]
      )
    }

    val key = new CacheKey(implClass, serialized.getImplMethodName, serialized.getImplMethodSignature)
    val callSiteRef: WeakReference[CallSite] = cache.get(key)
    var site = if (callSiteRef == null) null else callSiteRef.get()
    if (site == null) {
      site = makeCallSite
      cache.put(key, new WeakReference(site))
    }

    val factory = site.getTarget
    val captures = Array.tabulate(serialized.getCapturedArgCount)(n => serialized.getCapturedArg(n))
    factory.invokeWithArguments(captures: _*)
  }

  private val ScalaSerializable = "scala.Serializable"

  private val JavaIOSerializable = {
    // We could actually omit this marker interface as LambdaMetafactory will add it if
    // the FLAG_SERIALIZABLE is set and none of the provided markers extend it. But the code
    // is cleaner if we uniformly add a single marker, so I'm leaving it in place.
    "java.io.Serializable"
  }

  private def findMember(lookup: MethodHandles.Lookup, kind: Int, owner: Class[_],
                         name: String, signature: MethodType): MethodHandle = {
    kind match {
      case MethodHandleInfo.REF_invokeStatic =>
        lookup.findStatic(owner, name, signature)
      case MethodHandleInfo.REF_newInvokeSpecial =>
        lookup.findConstructor(owner, signature)
      case MethodHandleInfo.REF_invokeVirtual | MethodHandleInfo.REF_invokeInterface =>
        lookup.findVirtual(owner, name, signature)
      case MethodHandleInfo.REF_invokeSpecial =>
        lookup.findSpecial(owner, name, signature, owner)
    }
  }
}
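The three numbered steps that build `invokedType` can be tried in isolation. The following Java sketch mirrors them under stated assumptions: `InvokedTypeSketch` and its `invokedType` helper are hypothetical names introduced for illustration, and the inputs are hand-built `MethodType` values rather than ones parsed from a `SerializedLambda`.

```java
import java.lang.invoke.MethodHandleInfo;
import java.lang.invoke.MethodType;
import java.util.function.Function;

public class InvokedTypeSketch {
    public static MethodType invokedType(int implMethodKind, Class<?> implClass, MethodType implMethodSig,
                                         MethodType samMethodType, Class<?> functionalInterface) {
        // 1. Add the receiver for non-static, non-constructor implementation methods.
        boolean noReceiver = implMethodKind == MethodHandleInfo.REF_invokeStatic
            || implMethodKind == MethodHandleInfo.REF_newInvokeSpecial;
        MethodType withReceiver = noReceiver
            ? implMethodSig
            : implMethodSig.insertParameterTypes(0, implClass);

        // 2. Drop the trailing lambda parameters, leaving only the captures.
        int to = withReceiver.parameterCount();
        int from = to - samMethodType.parameterCount();

        // 3. Replace the return type with the functional interface.
        return withReceiver.dropParameterTypes(from, to).changeReturnType(functionalInterface);
    }

    public static void main(String[] args) {
        MethodType sam = MethodType.methodType(Object.class, Object.class);

        // A static impl method (String capture, String arg): one capture remains.
        System.out.println(invokedType(MethodHandleInfo.REF_invokeStatic, InvokedTypeSketch.class,
            MethodType.methodType(String.class, String.class, String.class), sam, Function.class));

        // Object::toString as Function<Object, String>: the receiver is the lambda parameter,
        // so no captures remain.
        System.out.println(invokedType(MethodHandleInfo.REF_invokeVirtual, Object.class,
            MethodType.methodType(String.class), sam, Function.class));
    }
}
```

The second case is the one called out in the comment on step 2: once the receiver has been added and the lambda parameters removed, the factory takes no arguments at all.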
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
package scala.compat.java8.runtime;

import org.junit.Assert;
import org.junit.Test;

import java.io.Serializable;
import java.lang.Cloneable;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.Arrays;

public final class LambdaDeserializerTest {
    private LambdaHost lambdaHost = new LambdaHost();

    @Test
    public void serializationPrivate() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByPrivateImplMethod();
        Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
    }

    @Test
    public void serializationStatic() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
        Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
    }

    @Test
    public void serializationVirtualMethodReference() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByVirtualMethodReference();
        Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
    }

    @Test
    public void serializationInterfaceMethodReference() {
        F1<I, Object> f1 = lambdaHost.lambdaBackedByInterfaceMethodReference();
        I i = new I() {
        };
        Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i));
    }

    @Test
    public void serializationStaticMethodReference() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticMethodReference();
        Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
    }

    @Test
    public void serializationNewInvokeSpecial() {
        F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
        Assert.assertEquals(f1.apply(), reconstitute(f1).apply());
    }

    @Test
    public void implMethodNameChanged() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
        SerializedLambda sl = writeReplace(f1);
        checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
    }

    @Test
    public void implMethodSignatureChanged() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
        SerializedLambda sl = writeReplace(f1);
        checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
    }

    private void checkIllegalAccess(SerializedLambda serialized) {
        try {
            LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), serialized);
            throw new AssertionError();
        } catch (IllegalArgumentException iae) {
            if (!iae.getMessage().contains("Illegal lambda deserialization")) {
                Assert.fail("Unexpected message: " + iae.getMessage());
            }
        }
    }

    private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) {
        Object[] captures = new Object[sl.getCapturedArgCount()];
        for (int i = 0; i < captures.length; i++) {
            captures[i] = sl.getCapturedArg(i);
        }
        return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(),
                sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature,
                sl.getInstantiatedMethodType(), captures);
    }

    private Class<?> loadClass(String className) {
        try {
            return Class.forName(className.replace('/', '.'));
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    @SuppressWarnings("unchecked")
    private <A, B> A reconstitute(A f1) {
        try {
            return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), writeReplace(f1));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private <A> SerializedLambda writeReplace(A f1) {
        try {
            Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace");
            writeReplace.setAccessible(true);
            return (SerializedLambda) writeReplace.invoke(f1);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}


interface F1<A, B> extends Serializable {
    B apply(A a);
}

interface F0<A> extends Serializable {
    A apply();
}

class LambdaHost {
    public F1<Boolean, String> lambdaBackedByPrivateImplMethod() {
        int local = 42;
        return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString();
    }

    @SuppressWarnings("Convert2MethodRef")
    public F1<Boolean, String> lambdaBackedByStaticImplMethod() {
        return (b) -> String.valueOf(b);
    }

    public F1<Boolean, String> lambdaBackedByStaticMethodReference() {
        return String::valueOf;
    }

    public F1<Boolean, String> lambdaBackedByVirtualMethodReference() {
        return Object::toString;
    }

    public F1<I, Object> lambdaBackedByInterfaceMethodReference() {
        return I::i;
    }

    public F0<Object> lambdaBackedByConstructorCall() {
        return String::new;
    }

    public static MethodHandles.Lookup lookup() {
        return MethodHandles.lookup();
    }
}

interface I {
    default String i() { return "i"; };
}
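The tests above obtain a `SerializedLambda` by calling `writeReplace` reflectively and hand it straight to `LambdaDeserializer`, so no object streams are involved. For comparison, here is a sketch of a full round trip through standard Java serialization. It exercises the `$deserializeLambda$` method that the Java compiler generates, not the deserializer added in this commit; `StreamRoundTrip` and `SerFn` are hypothetical names.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.function.Function;

public class StreamRoundTrip {
    // Hypothetical serializable functional interface, used only for this illustration.
    public interface SerFn<A, B> extends Function<A, B>, Serializable {}

    public static SerFn<Boolean, String> makeLambda(int local) {
        return b -> local + ":" + b; // captures `local`
    }

    @SuppressWarnings("unchecked")
    public static <T> T roundTrip(T value) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                // writeReplace on the lambda class substitutes a SerializedLambda here.
                out.writeObject(value);
            }
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                // SerializedLambda.readResolve calls $deserializeLambda$ on the capturing class.
                return (T) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        SerFn<Boolean, String> original = makeLambda(42);
        SerFn<Boolean, String> copy = roundTrip(original);
        System.out.println(original.apply(true).equals(copy.apply(true)));
    }
}
```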
