Skip to content

Commit afa2ff9

Browse files
committed
[indylambda] Support lambda {de}serialization
To support serialization, we use the alternative lambda metafactory that lets us specify that our anonymous functions should extend the marker interface `scala.Serializable`. They will also have a `writeObject` method added that implements the serialization proxy pattern using `j.l.invoke.SerializedLamba`. To support deserialization, we synthesize a `$deserializeLamba$` method in each class with lambdas. This will be called reflectively by `SerializedLambda#readResolve`. This method in turn delegates to `LambdaDeserializer`, currently defined [1] in `scala-java8-compat`, that uses `LambdaMetafactory` to spin up the anonymous class and instantiate it with the deserialized environment. Note: `LambdaDeserializer` can reuses the anonymous class on subsequent deserializations of a given lambda, in the same spirit as an invokedynamic call site only spins up the class on the first time it is run. But first we'll need to host a cache in a static field of each lambda hosting class. This is noted as a TODO and a failing test, and will be updated in the next commit. `LambdaDeserializer` will be moved into our standard library in the 2.12.x branch, where we can introduce dependencies on the Java 8 standard library. The enclosed test cases must be manually run with indylambda enabled. Once we enable indylambda by default on 2.12.x, the test will actually test the new feature. ``` % echo $INDYLAMBDA -Ydelambdafy:method -Ybackend:GenBCode -target:jvm-1.8 -classpath .:scala-java8-compat_2.11-0.5.0-SNAPSHOT.jar % qscala $INDYLAMBDA -e "println((() => 42).getClass)" class Main$$anon$1$$Lambda$1/1183231938 % qscala $INDYLAMBDA -e "assert(classOf[scala.Serializable].isInstance(() => 42))" % qscalac $INDYLAMBDA test/files/run/lambda-serialization.scala && qscala $INDYLAMBDA Test ``` This commit contains a few minor refactorings to the code that generates the invokedynamic instruction to use more meaningful names and to reuse Java signature generation code in ASM rather than the DIY approach. [1] scala/scala-java8-compat#37
1 parent 6ad9b44 commit afa2ff9

File tree

9 files changed

+105
-22
lines changed

9 files changed

+105
-22
lines changed

src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
3333
* Functionality to build the body of ASM MethodNode, except for `synchronized` and `try` expressions.
3434
*/
3535
abstract class PlainBodyBuilder(cunit: CompilationUnit) extends PlainSkelBuilder(cunit) {
36-
3736
import icodes.TestOp
3837
import icodes.opcodes.InvokeStyle
3938

@@ -1287,38 +1286,42 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
12871286

12881287
def genInvokeDynamicLambda(lambdaTarget: Symbol, arity: Int, functionalInterface: Symbol) {
12891288
val isStaticMethod = lambdaTarget.hasFlag(Flags.STATIC)
1289+
def asmType(sym: Symbol) = classBTypeFromSymbol(sym).toASMType
12901290

1291-
val targetHandle =
1291+
val implMethodHandle =
12921292
new asm.Handle(if (lambdaTarget.hasFlag(Flags.STATIC)) asm.Opcodes.H_INVOKESTATIC else asm.Opcodes.H_INVOKEVIRTUAL,
12931293
classBTypeFromSymbol(lambdaTarget.owner).internalName,
12941294
lambdaTarget.name.toString,
12951295
asmMethodType(lambdaTarget).descriptor)
1296-
val receiver = if (isStaticMethod) None else Some(lambdaTarget.owner)
1296+
val receiver = if (isStaticMethod) Nil else lambdaTarget.owner :: Nil
12971297
val (capturedParams, lambdaParams) = lambdaTarget.paramss.head.splitAt(lambdaTarget.paramss.head.length - arity)
12981298
// Requires https://github.com/scala/scala-java8-compat on the runtime classpath
1299-
val returnUnit = lambdaTarget.info.resultType.typeSymbol == UnitClass
1300-
val functionalInterfaceDesc: String = classBTypeFromSymbol(functionalInterface).descriptor
1301-
val desc = (receiver.toList ::: capturedParams).map(sym => toTypeKind(sym.info)).mkString(("("), "", ")") + functionalInterfaceDesc
1299+
val invokedType = asm.Type.getMethodDescriptor(asmType(functionalInterface), (receiver ::: capturedParams).map(sym => toTypeKind(sym.info).toASMType): _*)
13021300

1303-
// TODO specialization
13041301
val constrainedType = new MethodBType(lambdaParams.map(p => toTypeKind(p.tpe)), toTypeKind(lambdaTarget.tpe.resultType)).toASMType
1305-
val abstractMethod = functionalInterface.info.decls.find(_.isDeferred).getOrElse(functionalInterface.info.member(nme.apply))
1306-
val methodName = abstractMethod.name.toString
1307-
val applyN = {
1308-
val mt = asmMethodType(abstractMethod)
1309-
mt.toASMType
1310-
}
1311-
1312-
bc.jmethod.visitInvokeDynamicInsn(methodName, desc, lambdaMetaFactoryBootstrapHandle,
1313-
// boostrap args
1314-
applyN, targetHandle, constrainedType
1302+
val sam = functionalInterface.info.decls.find(_.isDeferred).getOrElse(functionalInterface.info.member(nme.apply))
1303+
val samName = sam.name.toString
1304+
val samMethodType = asmMethodType(sam).toASMType
1305+
1306+
val flags = 3 // TODO 2.12.x Replace with LambdaMetafactory.FLAG_SERIALIZABLE | LambdaMetafactory.FLAG_MARKERS
1307+
1308+
val ScalaSerializable = classBTypeFromSymbol(definitions.SerializableClass).toASMType
1309+
bc.jmethod.visitInvokeDynamicInsn(samName, invokedType, lambdaMetaFactoryBootstrapHandle,
1310+
/* samMethodType = */ samMethodType,
1311+
/* implMethod = */ implMethodHandle,
1312+
/* instantiatedMethodType = */ constrainedType,
1313+
/* flags = */ flags.asInstanceOf[AnyRef],
1314+
/* markerInterfaceCount = */ 1.asInstanceOf[AnyRef],
1315+
/* markerInterfaces[0] = */ ScalaSerializable,
1316+
/* bridgeCount = */ 0.asInstanceOf[AnyRef]
13151317
)
1318+
indyLambdaHosts += this.claszSymbol
13161319
}
13171320
}
13181321

1319-
val lambdaMetaFactoryBootstrapHandle =
1322+
lazy val lambdaMetaFactoryBootstrapHandle =
13201323
new asm.Handle(asm.Opcodes.H_INVOKESTATIC,
1321-
"java/lang/invoke/LambdaMetafactory", "metafactory",
1322-
"(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;")
1324+
definitions.LambdaMetaFactory.fullName('/'), sn.AltMetafactory.toString,
1325+
"(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;")
13231326

13241327
}

src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,33 @@ abstract class BCodeHelpers extends BCodeIdiomatic with BytecodeWriters {
682682
new java.lang.Long(id)
683683
).visitEnd()
684684
}
685+
686+
/**
687+
* Add:
688+
*
689+
* private static Object $deserializeLambda$(SerializedLambda l) {
690+
* return scala.compat.java8.runtime.LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, l);
691+
* }
692+
* @param jclass
693+
*/
694+
// TODO add a static cache field to the class, and pass that as the second argument to `deserializeLambda`.
695+
// This will make the test at run/lambda-serialization.scala:15 work
696+
def addLambdaDeserialize(jclass: asm.ClassVisitor): Unit = {
697+
val cw = jclass
698+
import scala.tools.asm.Opcodes._
699+
cw.visitInnerClass("java/lang/invoke/MethodHandles$Lookup", "java/lang/invoke/MethodHandles", "Lookup", ACC_PUBLIC + ACC_FINAL + ACC_STATIC)
700+
701+
{
702+
val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", "(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", null, null)
703+
mv.visitCode()
704+
mv.visitMethodInsn(INVOKESTATIC, "java/lang/invoke/MethodHandles", "lookup", "()Ljava/lang/invoke/MethodHandles$Lookup;", false)
705+
mv.visitInsn(asm.Opcodes.ACONST_NULL)
706+
mv.visitVarInsn(ALOAD, 0)
707+
mv.visitMethodInsn(INVOKESTATIC, "scala/compat/java8/runtime/LambdaDeserializer", "deserializeLambda", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/util/Map;Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", false)
708+
mv.visitInsn(ARETURN)
709+
mv.visitEnd()
710+
}
711+
}
685712
} // end of trait BCClassGen
686713

687714
/* functionality for building plain and mirror classes */

src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ abstract class BCodeSkelBuilder extends BCodeHelpers {
6868
var isCZStaticModule = false
6969
var isCZRemote = false
7070

71+
protected val indyLambdaHosts = collection.mutable.Set[Symbol]()
72+
7173
/* ---------------- idiomatic way to ask questions to typer ---------------- */
7274

7375
def paramTKs(app: Apply): List[BType] = {
@@ -121,6 +123,16 @@ abstract class BCodeSkelBuilder extends BCodeHelpers {
121123

122124
innerClassBufferASM ++= classBType.info.get.nestedClasses
123125
gen(cd.impl)
126+
127+
128+
val shouldAddLambdaDeserialize = (
129+
settings.target.value == "jvm-1.8"
130+
&& settings.Ydelambdafy.value == "method"
131+
&& indyLambdaHosts.contains(claszSymbol))
132+
133+
if (shouldAddLambdaDeserialize)
134+
addLambdaDeserialize(cnode)
135+
124136
addInnerClassesASM(cnode, innerClassBufferASM.toList)
125137

126138
cnode.visitAttribute(classBType.inlineInfoAttribute.get)

src/compiler/scala/tools/nsc/transform/Delambdafy.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
146146
val isStatic = target.hasFlag(STATIC)
147147

148148
def createBoxingBridgeMethod(functionParamTypes: List[Type], functionResultType: Type): Tree = {
149+
// Note: we bail out of this method and return EmptyTree if we find there is no adaptation required.
150+
// If we need to improve performance, we could check the types first before creating the
151+
// method and parameter symbols.
149152
val methSym = oldClass.newMethod(target.name.append("$adapted").toTermName, target.pos, target.flags | FINAL | ARTIFACT)
150153
var neededAdaptation = false
151154
def boxedType(tpe: Type): Type = {

src/reflect/scala/reflect/internal/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ trait Definitions extends api.StandardDefinitions {
514514
lazy val ScalaSignatureAnnotation = requiredClass[scala.reflect.ScalaSignature]
515515
lazy val ScalaLongSignatureAnnotation = requiredClass[scala.reflect.ScalaLongSignature]
516516

517+
lazy val LambdaMetaFactory = getClassIfDefined("java.lang.invoke.LambdaMetafactory")
517518
lazy val MethodHandle = getClassIfDefined("java.lang.invoke.MethodHandle")
518519

519520
// Option classes

src/reflect/scala/reflect/internal/StdNames.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,8 @@ trait StdNames {
11671167
final val Invoke: TermName = newTermName("invoke")
11681168
final val InvokeExact: TermName = newTermName("invokeExact")
11691169

1170+
final val AltMetafactory: TermName = newTermName("altMetafactory")
1171+
11701172
val Boxed = immutable.Map[TypeName, TypeName](
11711173
tpnme.Boolean -> BoxedBoolean,
11721174
tpnme.Byte -> BoxedByte,

src/reflect/scala/reflect/internal/transform/PostErasure.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ trait PostErasure {
99
object elimErasedValueType extends TypeMap {
1010
def apply(tp: Type) = tp match {
1111
case ConstantType(Constant(tp: Type)) => ConstantType(Constant(apply(tp)))
12-
case ErasedValueType(_, underlying) =>
13-
underlying
12+
case ErasedValueType(_, underlying) => underlying
1413
case _ => mapOver(tp)
1514
}
1615
}

src/reflect/scala/reflect/runtime/JavaUniverseForce.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ trait JavaUniverseForce { self: runtime.JavaUniverse =>
310310
definitions.QuasiquoteClass_api_unapply
311311
definitions.ScalaSignatureAnnotation
312312
definitions.ScalaLongSignatureAnnotation
313+
definitions.LambdaMetaFactory
313314
definitions.MethodHandle
314315
definitions.OptionClass
315316
definitions.OptionModule
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream}
2+
3+
object Test {
4+
def main(args: Array[String]): Unit = {
5+
roundTrip
6+
}
7+
8+
def roundTrip(): Unit = {
9+
val c = new Capture("Capture")
10+
val lambda = (p: Param) => ("a", p, c)
11+
val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
12+
val p = new Param
13+
assert(reconstituted1.apply(p) == ("a", p, c))
14+
val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
15+
assert(reconstituted1.getClass == reconstituted2.getClass)
16+
17+
val reconstituted3 = serializeDeserialize(reconstituted1)
18+
assert(reconstituted3.apply(p) == ("a", p, c))
19+
20+
val specializedLambda = (p: Int) => List(p, c).length
21+
assert(serializeDeserialize(specializedLambda).apply(42) == 2)
22+
assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2)
23+
}
24+
25+
def serializeDeserialize[T <: AnyRef](obj: T) = {
26+
val buffer = new ByteArrayOutputStream
27+
val out = new ObjectOutputStream(buffer)
28+
out.writeObject(obj)
29+
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
30+
in.readObject.asInstanceOf[T]
31+
}
32+
}
33+
34+
case class Capture(s: String) extends Serializable
35+
class Param

0 commit comments

Comments
 (0)