Skip to content

Commit 70cd9ba

Browse files
committed
SpecializeFunction1: don't roll over parents, use mapConserve
1 parent 03040b1 commit 70cd9ba

File tree

2 files changed

+114
-127
lines changed

2 files changed

+114
-127
lines changed

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ object NameOps {
256256
case nme.clone_ => nme.clone_
257257
}
258258

259-
def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type] = scala.Nil, methodTarsNames: List[Name] = scala.Nil)(implicit ctx: Context): name.ThisName = {
259+
def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): name.ThisName = {
260260

261261
def typeToTag(tp: Types.Type): Name = {
262262
tp.classSymbol match {

compiler/src/dotty/tools/dotc/transform/SpecializeFunction1.scala

Lines changed: 113 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,119 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
1313

1414
val phaseName = "specializeFunction1"
1515

16+
// Setup ---------------------------------------------------------------------
17+
private[this] val functionName = "JFunction1".toTermName
18+
private[this] val functionPkg = "scala.compat.java8.".toTermName
19+
private[this] var argTypes: Set[Symbol] = _
20+
private[this] var retTypes: Set[Symbol] = _
21+
22+
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
23+
argTypes = Set(defn.DoubleClass,
24+
defn.FloatClass,
25+
defn.IntClass,
26+
defn.LongClass,
27+
defn.UnitClass,
28+
defn.BooleanClass)
29+
30+
retTypes = Set(defn.DoubleClass,
31+
defn.FloatClass,
32+
defn.IntClass,
33+
defn.LongClass)
34+
this
35+
}
36+
37+
// Transformations -----------------------------------------------------------
38+
1639
/** Transforms all classes extending `Function1[-T1, +R]` so that
1740
* they instead extend the specialized version `JFunction$mp...`
1841
*/
1942
def transform(ref: SingleDenotation)(implicit ctx: Context) = ref match {
20-
case ShouldTransformDenot(cref, t1, r, func1) =>
21-
transformDenot(cref, t1, r, func1)
43+
case ShouldTransformDenot(cref, t1, r, func1) => {
44+
val specializedFunction: Symbol =
45+
ctx.getClassIfDefined(functionPkg ++ specializedName(functionName, t1, r))
46+
47+
def replaceFunction1(in: List[TypeRef]): List[TypeRef] =
48+
in.mapConserve { tp =>
49+
if (tp.isRef(defn.FunctionClass(1)) && (specializedFunction ne NoSymbol))
50+
specializedFunction.typeRef
51+
else tp
52+
}
53+
54+
def specializeApply(scope: Scope): Scope =
55+
if (specializedFunction ne NoSymbol) {
56+
def specializedApply: Symbol = {
57+
val specializedMethodName = specializedName(nme.apply, t1, r)
58+
ctx.newSymbol(
59+
cref.symbol,
60+
specializedMethodName,
61+
Flags.Override | Flags.Method,
62+
specializedFunction.info.decls.lookup(specializedMethodName).info
63+
)
64+
}
65+
66+
val alteredScope = scope.cloneScope
67+
alteredScope.enter(specializedApply)
68+
alteredScope
69+
}
70+
else scope
71+
72+
val ClassInfo(prefix, cls, parents, decls, info) = cref.classInfo
73+
val newInfo = ClassInfo(prefix, cls, replaceFunction1(in = parents), specializeApply(decls), info)
74+
cref.copySymDenotation(info = newInfo)
75+
}
2276
case _ => ref
2377
}
2478

79+
/** Transform the class definition's `Template`:
80+
*
81+
* - change the tree to have the correct parent
82+
* - add the specialized apply method to the template body
83+
* - forward the old `apply` to the specialized version
84+
*/
2585
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) =
2686
tree match {
27-
case ShouldTransformTree(func1, t1, r) => transformTree(tree, func1, t1, r)
87+
case tmpl @ ShouldTransformTree(func1, t1, r) => {
88+
val specializedFunc1 =
89+
TypeTree(ctx.requiredClassRef(functionPkg ++ specializedName(functionName, t1, r)))
90+
91+
val parents = tmpl.parents.foldRight(List.empty[Tree]) { (t, trees) =>
92+
(if (func1 eq t) specializedFunc1 else t) :: trees
93+
}
94+
95+
val body = tmpl.body.foldRight(List.empty[Tree]) {
96+
case (tree: DefDef, acc) if tree.name == nme.apply => {
97+
val specializedMethodName = specializedName(nme.apply, t1, r)
98+
val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
99+
100+
val forwardingBody =
101+
tpd.ref(specializedApply)
102+
.appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
103+
104+
val applyWithForwarding = cpy.DefDef(tree)(rhs = forwardingBody)
105+
106+
val specializedApplyDefDef = polyDefDef(specializedApply, trefs => vrefss => {
107+
tree.rhs
108+
.changeOwner(tree.symbol, specializedApply)
109+
.subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
110+
})
111+
112+
applyWithForwarding :: specializedApplyDefDef :: acc
113+
}
114+
case (tree, acc) => tree :: acc
115+
}
116+
117+
cpy.Template(tmpl)(parents = parents, body = body)
118+
}
28119
case _ => tree
29120
}
30121

122+
/** Dispatch to specialized `apply`s in user code */
31123
override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) = {
32124
import ast.Trees._
33125
tree match {
34126
case Apply(select @ Select(id, nme.apply), arg :: Nil) =>
35127
val params = List(arg.tpe, tree.tpe)
36-
val specializedApply = nme.apply.specializedFor(params, params.map(_.typeSymbol.name))
128+
val specializedApply = nme.apply.specializedFor(params, params.map(_.typeSymbol.name), Nil, Nil)
37129
val hasOverridenSpecializedApply = id.tpe.decls.iterator.exists { sym =>
38130
sym.is(Flags.Override) && (sym.name eq specializedApply)
39131
}
@@ -44,137 +136,32 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
44136
}
45137
}
46138

47-
private[this] val functionName = "JFunction1".toTermName
48-
private[this] val functionPkg = "scala.compat.java8.".toTermName
49-
private[this] var Function1: Symbol = _
50-
private[this] var argTypes: Set[Symbol] = _
51-
private[this] var returnTypes: Set[Symbol] = _
52-
private[this] var blacklisted: Set[Symbol] = _
53-
54-
/** Do setup of `argTypes` and `returnTypes` */
55-
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
56-
argTypes = Set(defn.DoubleClass,
57-
defn.FloatClass,
58-
defn.IntClass,
59-
defn.LongClass,
60-
defn.UnitClass,
61-
defn.BooleanClass)
62-
63-
returnTypes = Set(defn.DoubleClass,
64-
defn.FloatClass,
65-
defn.IntClass,
66-
defn.LongClass)
67-
68-
Function1 = ctx.requiredClass("scala.Function1")
69-
70-
blacklisted = Set(
71-
"scala.compat.java8.JFunction1",
72-
"scala.runtime.AbstractFunction1",
73-
"scala.PartialFunction",
74-
"scala.runtime.AbstractPartialFunction"
75-
).map(ctx.requiredClass(_))
76-
77-
this
78-
}
139+
private def specializedName(name: Name, t1: Type, r: Type)(implicit ctx: Context): Name =
140+
name.specializedFor(List(t1, r), List(t1, r).map(_.typeSymbol.name), Nil, Nil)
79141

142+
// Extractors ----------------------------------------------------------------
80143
private object ShouldTransformDenot {
81-
def unapply(cref: ClassDenotation)(implicit ctx: Context): Option[(ClassDenotation, Type, Type, Type)] = {
82-
def collectFunc1(xs: List[Type])(implicit ctx: Context): Option[(Type, Type, Type)] =
83-
xs.collect {
84-
case func1 @ RefinedType(RefinedType(parent, _, t1), _, r)
85-
if func1.isRef(Function1) => (t1, r, func1)
86-
}.headOption
87-
88-
collectFunc1(cref.info.parentsWithArgs).flatMap { case (t1, r, func1) =>
89-
if (
90-
!argTypes.contains(t1.typeSymbol) ||
91-
!returnTypes.contains(r.typeSymbol) ||
92-
blacklisted.exists(sym => cref.symbol.derivesFrom(sym))
93-
) None
94-
else Some((cref, t1, r, func1))
144+
def unapply(cref: ClassDenotation)(implicit ctx: Context): Option[(ClassDenotation, Type, Type, Type)] =
145+
getFunc1(cref.typeRef).map {
146+
case (t1, r, func1) => (cref, t1, r, func1)
95147
}
96-
}
97148
}
98149

99150
private object ShouldTransformTree {
100151
def unapply(tree: Template)(implicit ctx: Context): Option[(Tree, Type, Type)] =
101152
tree.parents
102-
.map { t => (t.tpe, t) }
103-
.collect {
104-
case (tp @ RefinedType(RefinedType(parent, _, t1), _, r), func1)
105-
if tp.isRef(Function1) &&
106-
argTypes.contains(t1.typeSymbol) &&
107-
returnTypes.contains(r.typeSymbol) => (func1, t1, r)
108-
}
109-
.headOption
110-
}
111-
112-
private def specializedName(name: Name, t1: Type, r: Type)(implicit ctx: Context): Name =
113-
name.specializedFor(List(t1, r), List(t1, r).map(_.typeSymbol.name))
114-
115-
def transformDenot(cref: ClassDenotation, t1: Type, r: Type, func1: Type)(implicit ctx: Context): SingleDenotation = {
116-
val specializedFunction: TypeRef =
117-
ctx.requiredClassRef(functionPkg ++ specializedName(functionName, t1, r))
118-
119-
def replaceFunction1(in: List[TypeRef]): List[TypeRef] =
120-
in.foldRight(List.empty[TypeRef]) { (tp, acc) =>
121-
val newTp =
122-
if (tp.isRef(Function1)) specializedFunction
123-
else tp
124-
newTp :: acc
125-
}
126-
127-
def specializeApply(scope: Scope): Scope = {
128-
def specializedApply: Symbol = {
129-
val specializedMethodName = specializedName(nme.apply, t1, r)
130-
ctx.newSymbol(
131-
cref.symbol,
132-
specializedMethodName,
133-
Flags.Override | Flags.Method,
134-
specializedFunction.info.decls.lookup(specializedMethodName).info
135-
)
136-
}
137-
138-
val alteredScope = scope.cloneScope
139-
alteredScope.enter(specializedApply)
140-
alteredScope
141-
}
142-
143-
val ClassInfo(prefix, cls, parents, decls, info) = cref.classInfo
144-
val newInfo = ClassInfo(prefix, cls, replaceFunction1(in = parents), specializeApply(decls), info)
145-
cref.copySymDenotation(info = newInfo)
153+
.map(t => getFunc1(t.tpe).map { case (t1, r, _) => (t, t1, r) })
154+
.flatten.headOption
146155
}
147156

148-
private def transformTree(tmpl: Template, func1: Tree, t1: Type, r: Type)(implicit ctx: Context) = {
149-
val specializedFunc1 =
150-
TypeTree(ctx.requiredClassRef(functionPkg ++ specializedName(functionName, t1, r)))
151-
152-
val parents = tmpl.parents.foldRight(List.empty[Tree]) { (t, trees) =>
153-
(if (func1 eq t) specializedFunc1 else t) :: trees
154-
}
155-
156-
val body = tmpl.body.foldRight(List.empty[Tree]) {
157-
case (tree: DefDef, acc) if tree.name == nme.apply => {
158-
val specializedMethodName = specializedName(nme.apply, t1, r)
159-
val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
160-
161-
val forwardingBody =
162-
tpd.ref(specializedApply)
163-
.appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
164-
165-
val applyWithForwarding = cpy.DefDef(tree)(rhs = forwardingBody)
166-
167-
val specializedApplyDefDef = polyDefDef(specializedApply, trefs => vrefss => {
168-
tree.rhs
169-
.changeOwner(tree.symbol, specializedApply)
170-
.subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
171-
})
172-
173-
applyWithForwarding :: specializedApplyDefDef :: acc
157+
private def getFunc1(tpe: Type)(implicit ctx: Context): Option[(Type, Type, Type)] =
158+
if (!tpe.derivesFrom(defn.FunctionClass(1)))
159+
None
160+
else
161+
tpe.baseTypeWithArgs(defn.FunctionClass(1)) match {
162+
case func1 @ RefinedType(RefinedType(parent, _, t1), _, r) if (
163+
argTypes.contains(t1.typeSymbol) && retTypes.contains(r.typeSymbol)
164+
) => Some((t1, r, func1))
165+
case _ => None
174166
}
175-
case (tree, acc) => tree :: acc
176-
}
177-
178-
cpy.Template(tmpl)(parents = parents, body = body)
179-
}
180167
}

0 commit comments

Comments
 (0)