Skip to content

Commit aa0908b

Browse files
committed
Merge pull request #37 from retronym/topic/lambda-deserialize
Add a generic deserializer for Java/Scala 2.12 lambdas
2 parents 82eba69 + 921b212 commit aa0908b

File tree

2 files changed

+325
-0
lines changed

2 files changed

+325
-0
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package scala.compat.java8.runtime
2+
3+
import java.lang.invoke._
4+
5+
/**
6+
* This class is only intended to be called by synthetic `$deserializeLambda$` method that the Scala 2.12
7+
* compiler will add to classes hosting lambdas.
8+
*
9+
* It is not intended to be consumed directly.
10+
*/
11+
object LambdaDeserializer {
12+
/**
13+
* Deserialize a lambda by calling `LambdaMetafactory.altMetafactory` to spin up a lambda class
14+
* and instantiating this class with the captured arguments.
15+
*
16+
* A cache may be provided to ensure that subsequent deserialization of the same lambda expression
17+
* is cheap, it amounts to a reflective call to the constructor of the previously created class.
18+
* However, deserialization of the same lambda expression is not guaranteed to use the same class,
19+
* concurrent deserialization of the same lambda expression may spin up more than one class.
20+
*
21+
* Assumptions:
22+
* - No additional marker interfaces are required beyond `{java.io,scala.}Serializable`. These are
23+
* not stored in `SerializedLambda`, so we can't reconstitute them.
24+
* - No additional bridge methods are passed to `altMetafactory`. Again, these are not stored.
25+
*
26+
* @param lookup The factory for method handles. Must have access to the implementation method, the
27+
* functional interface class, and `java.io.Serializable` or `scala.Serializable` as
28+
* required.
29+
* @param cache A cache used to avoid spinning up a class for each deserialization of a given lambda. May be `null`
30+
* @param serialized The lambda to deserialize. Note that this is typically created by the `readResolve`
31+
* member of the anonymous class created by `LambdaMetaFactory`.
32+
* @return An instance of the functional interface
33+
*/
34+
def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = {
35+
def slashDot(name: String) = name.replaceAll("/", ".")
36+
val loader = lookup.lookupClass().getClassLoader
37+
val implClass = loader.loadClass(slashDot(serialized.getImplClass))
38+
39+
def makeCallSite: CallSite = {
40+
import serialized._
41+
def parseDescriptor(s: String) =
42+
MethodType.fromMethodDescriptorString(s, loader)
43+
44+
val funcInterfaceSignature = parseDescriptor(getFunctionalInterfaceMethodSignature)
45+
val instantiated = parseDescriptor(getInstantiatedMethodType)
46+
val functionalInterfaceClass = loader.loadClass(slashDot(getFunctionalInterfaceClass))
47+
48+
val implMethodSig = parseDescriptor(getImplMethodSignature)
49+
// Construct the invoked type from the impl method type. This is the type of a factory
50+
// that will be generated by the meta-factory. It is a method type, with param types
51+
// coming form the types of the captures, and return type being the functional interface.
52+
val invokedType: MethodType = {
53+
// 1. Add receiver for non-static impl methods
54+
val withReceiver = getImplMethodKind match {
55+
case MethodHandleInfo.REF_invokeStatic | MethodHandleInfo.REF_newInvokeSpecial =>
56+
implMethodSig
57+
case _ =>
58+
implMethodSig.insertParameterTypes(0, implClass)
59+
}
60+
// 2. Remove lambda parameters, leaving only captures. Note: the receiver may be a lambda parameter,
61+
// such as in `Function<Object, String> s = Object::toString`
62+
val lambdaArity = funcInterfaceSignature.parameterCount()
63+
val from = withReceiver.parameterCount() - lambdaArity
64+
val to = withReceiver.parameterCount()
65+
66+
// 3. Drop the lambda return type and replace with the functional interface.
67+
withReceiver.dropParameterTypes(from, to).changeReturnType(functionalInterfaceClass)
68+
}
69+
70+
// Lookup the implementation method
71+
val implMethod: MethodHandle = try {
72+
findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig)
73+
} catch {
74+
case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e)
75+
}
76+
77+
val flags: Int = LambdaMetafactory.FLAG_SERIALIZABLE | LambdaMetafactory.FLAG_MARKERS
78+
val isScalaFunction = functionalInterfaceClass.getName.startsWith("scala.Function")
79+
val markerInterface: Class[_] = loader.loadClass(if (isScalaFunction) ScalaSerializable else JavaIOSerializable)
80+
81+
LambdaMetafactory.altMetafactory(
82+
lookup, getFunctionalInterfaceMethodName, invokedType,
83+
84+
/* samMethodType = */ funcInterfaceSignature,
85+
/* implMethod = */ implMethod,
86+
/* instantiatedMethodType = */ instantiated,
87+
/* flags = */ flags.asInstanceOf[AnyRef],
88+
/* markerInterfaceCount = */ 1.asInstanceOf[AnyRef],
89+
/* markerInterfaces[0] = */ markerInterface,
90+
/* bridgeCount = */ 0.asInstanceOf[AnyRef]
91+
)
92+
}
93+
94+
val key = serialized.getImplMethodName + " : " + serialized.getImplMethodSignature
95+
val factory: MethodHandle = if (cache == null) {
96+
makeCallSite.getTarget
97+
} else cache.get(key) match {
98+
case null =>
99+
val callSite = makeCallSite
100+
val temp = callSite.getTarget
101+
cache.put(key, temp)
102+
temp
103+
case target => target
104+
}
105+
106+
val captures = Array.tabulate(serialized.getCapturedArgCount)(n => serialized.getCapturedArg(n))
107+
factory.invokeWithArguments(captures: _*)
108+
}
109+
110+
private val ScalaSerializable = "scala.Serializable"
111+
112+
private val JavaIOSerializable = {
113+
// We could actually omit this marker interface as LambdaMetaFactory will add it if
114+
// the FLAG_SERIALIZABLE is set and of the provided markers extend it. But the code
115+
// is cleaner if we uniformly add a single marker, so I'm leaving it in place.
116+
"java.io.Serializable"
117+
}
118+
119+
private def findMember(lookup: MethodHandles.Lookup, kind: Int, owner: Class[_],
120+
name: String, signature: MethodType): MethodHandle = {
121+
kind match {
122+
case MethodHandleInfo.REF_invokeStatic =>
123+
lookup.findStatic(owner, name, signature)
124+
case MethodHandleInfo.REF_newInvokeSpecial =>
125+
lookup.findConstructor(owner, signature)
126+
case MethodHandleInfo.REF_invokeVirtual | MethodHandleInfo.REF_invokeInterface =>
127+
lookup.findVirtual(owner, name, signature)
128+
case MethodHandleInfo.REF_invokeSpecial =>
129+
lookup.findSpecial(owner, name, signature, owner)
130+
}
131+
}
132+
}
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
package scala.compat.java8.runtime;
2+
3+
import org.junit.Assert;
4+
import org.junit.Test;
5+
6+
import java.io.Serializable;
7+
import java.lang.invoke.MethodHandle;
8+
import java.lang.invoke.MethodHandles;
9+
import java.lang.invoke.SerializedLambda;
10+
import java.lang.reflect.Method;
11+
import java.util.Arrays;
12+
import java.util.HashMap;
13+
14+
public final class LambdaDeserializerTest {
15+
private LambdaHost lambdaHost = new LambdaHost();
16+
17+
@Test
18+
public void serializationPrivate() {
19+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByPrivateImplMethod();
20+
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
21+
}
22+
23+
@Test
24+
public void serializationStatic() {
25+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
26+
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
27+
}
28+
29+
@Test
30+
public void serializationVirtualMethodReference() {
31+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByVirtualMethodReference();
32+
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
33+
}
34+
35+
@Test
36+
public void serializationInterfaceMethodReference() {
37+
F1<I, Object> f1 = lambdaHost.lambdaBackedByInterfaceMethodReference();
38+
I i = new I() {
39+
};
40+
Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i));
41+
}
42+
43+
@Test
44+
public void serializationStaticMethodReference() {
45+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticMethodReference();
46+
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
47+
}
48+
49+
@Test
50+
public void serializationNewInvokeSpecial() {
51+
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
52+
Assert.assertEquals(f1.apply(), reconstitute(f1).apply());
53+
}
54+
55+
@Test
56+
public void uncached() {
57+
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
58+
F0<Object> reconstituted1 = reconstitute(f1);
59+
F0<Object> reconstituted2 = reconstitute(f1);
60+
Assert.assertNotEquals(reconstituted1.getClass(), reconstituted2.getClass());
61+
}
62+
63+
@Test
64+
public void cached() {
65+
HashMap<String, MethodHandle> cache = new HashMap<>();
66+
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
67+
F0<Object> reconstituted1 = reconstitute(f1, cache);
68+
F0<Object> reconstituted2 = reconstitute(f1, cache);
69+
Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass());
70+
}
71+
72+
@Test
73+
public void cachedStatic() {
74+
HashMap<String, MethodHandle> cache = new HashMap<>();
75+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
76+
// Check that deserialization of a static lambda always returns the
77+
// same instance.
78+
Assert.assertSame(reconstitute(f1, cache), reconstitute(f1, cache));
79+
80+
// (as is the case with regular invocation.)
81+
Assert.assertSame(f1, lambdaHost.lambdaBackedByStaticImplMethod());
82+
}
83+
84+
@Test
85+
public void implMethodNameChanged() {
86+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
87+
SerializedLambda sl = writeReplace(f1);
88+
checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
89+
}
90+
91+
@Test
92+
public void implMethodSignatureChanged() {
93+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
94+
SerializedLambda sl = writeReplace(f1);
95+
checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
96+
}
97+
98+
private void checkIllegalAccess(SerializedLambda serialized) {
99+
try {
100+
LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized);
101+
throw new AssertionError();
102+
} catch (IllegalArgumentException iae) {
103+
if (!iae.getMessage().contains("Illegal lambda deserialization")) {
104+
Assert.fail("Unexpected message: " + iae.getMessage());
105+
}
106+
}
107+
}
108+
109+
private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) {
110+
Object[] captures = new Object[sl.getCapturedArgCount()];
111+
for (int i = 0; i < captures.length; i++) {
112+
captures[i] = sl.getCapturedArg(i);
113+
}
114+
return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(),
115+
sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature,
116+
sl.getInstantiatedMethodType(), captures);
117+
}
118+
119+
private Class<?> loadClass(String className) {
120+
try {
121+
return Class.forName(className.replace('/', '.'));
122+
} catch (ClassNotFoundException e) {
123+
throw new RuntimeException(e);
124+
}
125+
}
126+
private <A, B> A reconstitute(A f1) {
127+
return reconstitute(f1, null);
128+
}
129+
130+
@SuppressWarnings("unchecked")
131+
private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) {
132+
try {
133+
return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1));
134+
} catch (Exception e) {
135+
throw new RuntimeException(e);
136+
}
137+
}
138+
139+
private <A> SerializedLambda writeReplace(A f1) {
140+
try {
141+
Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace");
142+
writeReplace.setAccessible(true);
143+
return (SerializedLambda) writeReplace.invoke(f1);
144+
} catch (Exception e) {
145+
throw new RuntimeException(e);
146+
}
147+
}
148+
}
149+
150+
151+
interface F1<A, B> extends Serializable {
152+
B apply(A a);
153+
}
154+
155+
interface F0<A> extends Serializable {
156+
A apply();
157+
}
158+
159+
class LambdaHost {
160+
public F1<Boolean, String> lambdaBackedByPrivateImplMethod() {
161+
int local = 42;
162+
return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString();
163+
}
164+
165+
@SuppressWarnings("Convert2MethodRef")
166+
public F1<Boolean, String> lambdaBackedByStaticImplMethod() {
167+
return (b) -> String.valueOf(b);
168+
}
169+
170+
public F1<Boolean, String> lambdaBackedByStaticMethodReference() {
171+
return String::valueOf;
172+
}
173+
174+
public F1<Boolean, String> lambdaBackedByVirtualMethodReference() {
175+
return Object::toString;
176+
}
177+
178+
public F1<I, Object> lambdaBackedByInterfaceMethodReference() {
179+
return I::i;
180+
}
181+
182+
public F0<Object> lambdaBackedByConstructorCall() {
183+
return String::new;
184+
}
185+
186+
public static MethodHandles.Lookup lookup() {
187+
return MethodHandles.lookup();
188+
}
189+
}
190+
191+
interface I {
192+
default String i() { return "i"; };
193+
}

0 commit comments

Comments
 (0)