Skip to content

Commit 9c05e0b

Browse files
committed
Rewrite SpecializeFunctions from DenotTransformer to InfoTransformer
1 parent 7a40ed1 commit 9c05e0b

File tree

1 file changed

+86
-159
lines changed

1 file changed

+86
-159
lines changed

compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala

Lines changed: 86 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
88
import Denotations._, SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
99
import ast.tpd
1010

11-
class SpecializeFunctions extends MiniPhaseTransform with DenotTransformer {
11+
import scala.collection.mutable
12+
import scala.annotation.tailrec
13+
14+
class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
1215
import ast.tpd._
1316

14-
val phaseName = "specializeFunction1"
17+
val phaseName = "specializeFunctions"
1518

1619
// Setup ---------------------------------------------------------------------
1720
private[this] val functionName = "JFunction".toTermName
@@ -41,191 +44,115 @@ class SpecializeFunctions extends MiniPhaseTransform with DenotTransformer {
4144

4245
// Transformations -----------------------------------------------------------
4346

44-
/** Transforms all classes extending `Function1[-T1, +R]` so that
45-
* they instead extend the specialized version `JFunction$mp...`
46-
*/
47-
def transform(ref: SingleDenotation)(implicit ctx: Context) = ref match {
48-
case cref @ ShouldTransformDenot(targets) => {
49-
val specializedSymbols: Map[Symbol, (Symbol, Symbol)] = (for (SpecializationTarget(target, args, ret, original) <- targets) yield {
50-
val arity = args.length
51-
val specializedParent = ctx.getClassIfDefined {
52-
functionPkg ++ specializedName(functionName ++ arity, args, ret)
53-
}
47+
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
48+
case tp: ClassInfo if !sym.is(Flags.Package) => {
49+
val newDecls = tp.decls.cloneScope
50+
def newParents = tp.parents.mapConserve { parent =>
51+
if (defn.isPlainFunctionClass(parent.symbol)) {
52+
val typeParams = tp.typeRef.baseArgTypes(parent.classSymbol)
53+
val interface = specInterface(typeParams)
54+
55+
if (interface.exists) {
56+
val specializedApply: Symbol = {
57+
val specializedMethodName = specializedName(nme.apply, typeParams)
58+
ctx.newSymbol(
59+
sym,
60+
specializedMethodName,
61+
Flags.Override | Flags.Method,
62+
interface.info.decls.lookup(specializedMethodName).info
63+
)
64+
}
5465

55-
val specializedApply: Symbol = {
56-
val specializedMethodName = specializedName(nme.apply, args, ret)
57-
ctx.newSymbol(
58-
cref.symbol,
59-
specializedMethodName,
60-
Flags.Override | Flags.Method,
61-
specializedParent.info.decls.lookup(specializedMethodName).info
62-
)
66+
newDecls.enter(specializedApply)
67+
interface.typeRef
68+
}
69+
else parent
6370
}
71+
else parent
72+
}
6473

65-
original -> (specializedParent, specializedApply)
66-
}).toMap
74+
tp.derivedClassInfo(classParents = newParents, decls = newDecls)
75+
}
6776

68-
def specializeApplys(scope: Scope): Scope = {
69-
val alteredScope = scope.cloneScope
70-
specializedSymbols.values.foreach { case (_, apply) =>
71-
alteredScope.enter(apply)
72-
}
73-
alteredScope
74-
}
77+
case _ => tp
78+
}
7579

76-
def replace(in: List[TypeRef]): List[TypeRef] =
77-
in.map { tref =>
78-
val sym = tref.symbol
79-
specializedSymbols.get(sym).map { case (specializedParent, _) =>
80-
specializedParent.typeRef
81-
}
82-
.getOrElse(tref)
80+
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = {
81+
val buf = new mutable.ListBuffer[Tree]
82+
val newBody = tree.body.mapConserve {
83+
case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => {
84+
val specializedApply = ctx.owner.info.decls.lookup {
85+
specializedName(
86+
nme.apply,
87+
dt.vparamss.head.map(_.symbol.info) :+ dt.tpe.widen.finalResultType
88+
)
8389
}
8490

85-
val ClassInfo(prefix, cls, parents, decls, info) = cref.classInfo
86-
val newParents = replace(parents)
87-
val newInfo = ClassInfo(prefix, cls, newParents, specializeApplys(decls), info)
88-
cref.copySymDenotation(info = newInfo)
91+
if (specializedApply.exists) {
92+
val apply = specializedApply.asTerm
93+
val specializedDecl =
94+
polyDefDef(apply, trefs => vrefss => {
95+
dt.rhs
96+
.changeOwner(dt.symbol, apply)
97+
.subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
98+
})
99+
100+
buf += specializedDecl
101+
102+
// create a forwarding to the specialized apply
103+
cpy.DefDef(dt)(rhs = {
104+
tpd
105+
.ref(apply)
106+
.appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
107+
})
108+
} else dt
109+
}
110+
case x => x
89111
}
90-
case _ => ref
91-
}
92112

93-
/** Transform the class definition's `Template`:
94-
*
95-
* - change the tree to have the correct parent
96-
* - add the specialized apply method to the template body
97-
* - forward the old `apply` to the specialized version
98-
*/
99-
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) =
100-
tree match {
101-
case tmpl @ ShouldTransformTree(targets) => {
102-
val symbolMap = (for ((tree, SpecializationTarget(target, args, ret, orig)) <- targets) yield {
103-
val arity = args.length
104-
val specializedParent = TypeTree {
105-
ctx.requiredClassRef(functionPkg ++ specializedName(functionName ++ arity, args, ret))
106-
}
107-
val specializedMethodName = specializedName(nme.apply, args, ret)
108-
val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName)
109-
110-
if (specializedApply.exists)
111-
Some(orig -> (specializedParent, specializedApply.asTerm))
112-
else None
113-
}).flatten.toMap
114-
115-
val body0 = tmpl.body.foldRight(List.empty[Tree]) {
116-
case (tree: DefDef, acc) if tree.name == nme.apply => {
117-
val inheritedFrom =
118-
tree.symbol.allOverriddenSymbols
119-
.map(_.owner)
120-
.map(symbolMap.get)
121-
.flatten
122-
.toList
123-
.headOption
124-
125-
inheritedFrom.map { case (parent, apply) =>
126-
val forwardingBody = tpd
127-
.ref(apply)
128-
.appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
129-
130-
val applyWithForwarding = cpy.DefDef(tree)(rhs = forwardingBody)
131-
132-
val specializedApplyDefDef =
133-
polyDefDef(apply, trefs => vrefss => {
134-
tree.rhs
135-
.changeOwner(tree.symbol, apply)
136-
.subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
137-
})
138-
139-
applyWithForwarding :: specializedApplyDefDef :: acc
140-
}
141-
.getOrElse(tree :: acc)
142-
}
143-
case (tree, acc) => tree :: acc
144-
}
113+
val newParents = tree.parents.mapConserve { parent =>
114+
if (defn.isPlainFunctionClass(parent.symbol)) {
115+
val typeParams = tree.tpe.baseArgTypes(parent.symbol)
116+
val interface = specInterface(typeParams)
145117

146-
val specializedParents = tree.parents.map { t =>
147-
symbolMap
148-
.get(t.symbol)
149-
.map { case (newSym, _) => newSym }
150-
.getOrElse(t)
118+
if (interface.exists) TypeTree(interface.info)
119+
else parent
151120
}
152-
153-
cpy.Template(tmpl)(parents = specializedParents, body = body0)
154-
}
155-
case _ => tree
121+
else parent
156122
}
157123

124+
cpy.Template(tree)(parents = newParents, body = buf.toList ++ newBody)
125+
}
126+
158127
/** Dispatch to specialized `apply`s in user code */
159128
override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) = {
160129
import ast.Trees._
161130
tree match {
162-
case Apply(select @ Select(id, nme.apply), arg :: Nil) =>
131+
case Apply(select @ Select(id, nme.apply), arg :: Nil) => {
163132
val params = List(arg.tpe, tree.tpe)
164-
val specializedApply = nme.apply.specializedFor(params, params.map(_.typeSymbol.name), Nil, Nil)
165-
val hasOverridenSpecializedApply = id.tpe.decls.iterator.exists { sym =>
166-
sym.is(Flags.Override) && (sym.name eq specializedApply)
133+
val specializedApply = specializedName(nme.apply, params)
134+
val hasOverridenSpecializedApply = id.tpe.decls.iterator.exists {
135+
sym => sym.is(Flags.Override) && (sym.name eq specializedApply)
167136
}
168137

169138
if (hasOverridenSpecializedApply) tpd.Apply(tpd.Select(id, specializedApply), arg :: Nil)
170139
else tree
140+
}
171141
case _ => tree
172142
}
173143
}
174144

175-
private def specializedName(name: Name, args: List[Type], ret: Type)(implicit ctx: Context): Name = {
176-
val typeParams = args :+ ret
177-
name.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil, Nil)
178-
}
179-
180-
// Extractors ----------------------------------------------------------------
181-
private object ShouldTransformDenot {
182-
def unapply(cref: ClassDenotation)(implicit ctx: Context): Option[Seq[SpecializationTarget]] =
183-
if (!cref.classParents.map(_.symbol).exists(defn.isPlainFunctionClass)) None
184-
else Some(getSpecTargets(cref.typeRef))
185-
}
145+
@inline private def specializedName(name: Name, args: List[Type])(implicit ctx: Context): Name =
146+
name.specializedFor(args, args.map(_.typeSymbol.name), Nil, Nil)
186147

187-
private object ShouldTransformTree {
188-
def unapply(tree: Template)(implicit ctx: Context): Option[Seq[(Tree, SpecializationTarget)]] = {
189-
val treeToTargets = tree.parents
190-
.map(t => (t, getSpecTargets(t.tpe)))
191-
.filter(_._2.nonEmpty)
192-
.map { case (t, xs) => (t, xs.head) }
148+
@inline private def specInterface(typeParams: List[Type])(implicit ctx: Context) = {
149+
val args = typeParams.init
150+
val ret = typeParams.last
193151

194-
if (treeToTargets.isEmpty) None else Some(treeToTargets)
195-
}
196-
}
152+
val specName =
153+
(functionName ++ args.length)
154+
.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil, Nil)
197155

198-
private case class SpecializationTarget(target: Symbol,
199-
params: List[Type],
200-
ret: Type,
201-
original: Symbol)
202-
203-
/** Gets all valid specialization targets on `tpe`, allowing multiple
204-
* implementations of FunctionX traits
205-
*/
206-
private def getSpecTargets(tpe: Type)(implicit ctx: Context): List[SpecializationTarget] = {
207-
val functionParents =
208-
tpe.classSymbols.iterator
209-
.flatMap(_.baseClasses)
210-
.filter(defn.isPlainFunctionClass)
211-
212-
val tpeCls = tpe.widenDealias
213-
functionParents.map { sym =>
214-
val typeParams = tpeCls.baseArgTypes(sym)
215-
val args = typeParams.init
216-
val ret = typeParams.last
217-
218-
val interfaceName =
219-
(functionName ++ args.length)
220-
.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil, Nil)
221-
222-
val interface = ctx.getClassIfDefined(functionPkg ++ interfaceName)
223-
224-
if (interface.exists) Some {
225-
SpecializationTarget(interface, args, ret, sym)
226-
}
227-
else None
228-
}
229-
.flatten.toList
156+
ctx.getClassIfDefined(functionPkg ++ specName)
230157
}
231158
}

0 commit comments

Comments
 (0)