@@ -8,10 +8,13 @@ import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
8
8
import Denotations ._ , SymDenotations ._ , Scopes ._ , StdNames ._ , NameOps ._ , Names ._
9
9
import ast .tpd
10
10
11
- class SpecializeFunctions extends MiniPhaseTransform with DenotTransformer {
11
+ import scala .collection .mutable
12
+ import scala .annotation .tailrec
13
+
14
+ class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
12
15
import ast .tpd ._
13
16
14
- val phaseName = " specializeFunction1 "
17
+ val phaseName = " specializeFunctions "
15
18
16
19
// Setup ---------------------------------------------------------------------
17
20
private [this ] val functionName = " JFunction" .toTermName
@@ -41,191 +44,115 @@ class SpecializeFunctions extends MiniPhaseTransform with DenotTransformer {
41
44
42
45
// Transformations -----------------------------------------------------------
43
46
44
- /** Transforms all classes extending `Function1[-T1, +R]` so that
45
- * they instead extend the specialized version `JFunction$mp...`
46
- */
47
- def transform (ref : SingleDenotation )(implicit ctx : Context ) = ref match {
48
- case cref @ ShouldTransformDenot (targets) => {
49
- val specializedSymbols : Map [Symbol , (Symbol , Symbol )] = (for (SpecializationTarget (target, args, ret, original) <- targets) yield {
50
- val arity = args.length
51
- val specializedParent = ctx.getClassIfDefined {
52
- functionPkg ++ specializedName(functionName ++ arity, args, ret)
53
- }
47
+ def transformInfo (tp : Type , sym : Symbol )(implicit ctx : Context ) = tp match {
48
+ case tp : ClassInfo if ! sym.is(Flags .Package ) => {
49
+ val newDecls = tp.decls.cloneScope
50
+ def newParents = tp.parents.mapConserve { parent =>
51
+ if (defn.isPlainFunctionClass(parent.symbol)) {
52
+ val typeParams = tp.typeRef.baseArgTypes(parent.classSymbol)
53
+ val interface = specInterface(typeParams)
54
+
55
+ if (interface.exists) {
56
+ val specializedApply : Symbol = {
57
+ val specializedMethodName = specializedName(nme.apply, typeParams)
58
+ ctx.newSymbol(
59
+ sym,
60
+ specializedMethodName,
61
+ Flags .Override | Flags .Method ,
62
+ interface.info.decls.lookup(specializedMethodName).info
63
+ )
64
+ }
54
65
55
- val specializedApply : Symbol = {
56
- val specializedMethodName = specializedName(nme.apply, args, ret)
57
- ctx.newSymbol(
58
- cref.symbol,
59
- specializedMethodName,
60
- Flags .Override | Flags .Method ,
61
- specializedParent.info.decls.lookup(specializedMethodName).info
62
- )
66
+ newDecls.enter(specializedApply)
67
+ interface.typeRef
68
+ }
69
+ else parent
63
70
}
71
+ else parent
72
+ }
64
73
65
- original -> (specializedParent, specializedApply )
66
- }).toMap
74
+ tp.derivedClassInfo(classParents = newParents, decls = newDecls )
75
+ }
67
76
68
- def specializeApplys (scope : Scope ): Scope = {
69
- val alteredScope = scope.cloneScope
70
- specializedSymbols.values.foreach { case (_, apply) =>
71
- alteredScope.enter(apply)
72
- }
73
- alteredScope
74
- }
77
+ case _ => tp
78
+ }
75
79
76
- def replace (in : List [TypeRef ]): List [TypeRef ] =
77
- in.map { tref =>
78
- val sym = tref.symbol
79
- specializedSymbols.get(sym).map { case (specializedParent, _) =>
80
- specializedParent.typeRef
81
- }
82
- .getOrElse(tref)
80
+ override def transformTemplate (tree : Template )(implicit ctx : Context , info : TransformerInfo ) = {
81
+ val buf = new mutable.ListBuffer [Tree ]
82
+ val newBody = tree.body.mapConserve {
83
+ case dt : DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => {
84
+ val specializedApply = ctx.owner.info.decls.lookup {
85
+ specializedName(
86
+ nme.apply,
87
+ dt.vparamss.head.map(_.symbol.info) :+ dt.tpe.widen.finalResultType
88
+ )
83
89
}
84
90
85
- val ClassInfo (prefix, cls, parents, decls, info) = cref.classInfo
86
- val newParents = replace(parents)
87
- val newInfo = ClassInfo (prefix, cls, newParents, specializeApplys(decls), info)
88
- cref.copySymDenotation(info = newInfo)
91
+ if (specializedApply.exists) {
92
+ val apply = specializedApply.asTerm
93
+ val specializedDecl =
94
+ polyDefDef(apply, trefs => vrefss => {
95
+ dt.rhs
96
+ .changeOwner(dt.symbol, apply)
97
+ .subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
98
+ })
99
+
100
+ buf += specializedDecl
101
+
102
+ // create a forwarding to the specialized apply
103
+ cpy.DefDef (dt)(rhs = {
104
+ tpd
105
+ .ref(apply)
106
+ .appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
107
+ })
108
+ } else dt
109
+ }
110
+ case x => x
89
111
}
90
- case _ => ref
91
- }
92
112
93
- /** Transform the class definition's `Template`:
94
- *
95
- * - change the tree to have the correct parent
96
- * - add the specialized apply method to the template body
97
- * - forward the old `apply` to the specialized version
98
- */
99
- override def transformTemplate (tree : Template )(implicit ctx : Context , info : TransformerInfo ) =
100
- tree match {
101
- case tmpl @ ShouldTransformTree (targets) => {
102
- val symbolMap = (for ((tree, SpecializationTarget (target, args, ret, orig)) <- targets) yield {
103
- val arity = args.length
104
- val specializedParent = TypeTree {
105
- ctx.requiredClassRef(functionPkg ++ specializedName(functionName ++ arity, args, ret))
106
- }
107
- val specializedMethodName = specializedName(nme.apply, args, ret)
108
- val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName)
109
-
110
- if (specializedApply.exists)
111
- Some (orig -> (specializedParent, specializedApply.asTerm))
112
- else None
113
- }).flatten.toMap
114
-
115
- val body0 = tmpl.body.foldRight(List .empty[Tree ]) {
116
- case (tree : DefDef , acc) if tree.name == nme.apply => {
117
- val inheritedFrom =
118
- tree.symbol.allOverriddenSymbols
119
- .map(_.owner)
120
- .map(symbolMap.get)
121
- .flatten
122
- .toList
123
- .headOption
124
-
125
- inheritedFrom.map { case (parent, apply) =>
126
- val forwardingBody = tpd
127
- .ref(apply)
128
- .appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
129
-
130
- val applyWithForwarding = cpy.DefDef (tree)(rhs = forwardingBody)
131
-
132
- val specializedApplyDefDef =
133
- polyDefDef(apply, trefs => vrefss => {
134
- tree.rhs
135
- .changeOwner(tree.symbol, apply)
136
- .subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
137
- })
138
-
139
- applyWithForwarding :: specializedApplyDefDef :: acc
140
- }
141
- .getOrElse(tree :: acc)
142
- }
143
- case (tree, acc) => tree :: acc
144
- }
113
+ val newParents = tree.parents.mapConserve { parent =>
114
+ if (defn.isPlainFunctionClass(parent.symbol)) {
115
+ val typeParams = tree.tpe.baseArgTypes(parent.symbol)
116
+ val interface = specInterface(typeParams)
145
117
146
- val specializedParents = tree.parents.map { t =>
147
- symbolMap
148
- .get(t.symbol)
149
- .map { case (newSym, _) => newSym }
150
- .getOrElse(t)
118
+ if (interface.exists) TypeTree (interface.info)
119
+ else parent
151
120
}
152
-
153
- cpy.Template (tmpl)(parents = specializedParents, body = body0)
154
- }
155
- case _ => tree
121
+ else parent
156
122
}
157
123
124
+ cpy.Template (tree)(parents = newParents, body = buf.toList ++ newBody)
125
+ }
126
+
158
127
/** Dispatch to specialized `apply`s in user code */
159
128
override def transformApply (tree : Apply )(implicit ctx : Context , info : TransformerInfo ) = {
160
129
import ast .Trees ._
161
130
tree match {
162
- case Apply (select @ Select (id, nme.apply), arg :: Nil ) =>
131
+ case Apply (select @ Select (id, nme.apply), arg :: Nil ) => {
163
132
val params = List (arg.tpe, tree.tpe)
164
- val specializedApply = nme.apply.specializedFor(params , params.map(_.typeSymbol.name), Nil , Nil )
165
- val hasOverridenSpecializedApply = id.tpe.decls.iterator.exists { sym =>
166
- sym.is(Flags .Override ) && (sym.name eq specializedApply)
133
+ val specializedApply = specializedName( nme.apply, params)
134
+ val hasOverridenSpecializedApply = id.tpe.decls.iterator.exists {
135
+ sym => sym .is(Flags .Override ) && (sym.name eq specializedApply)
167
136
}
168
137
169
138
if (hasOverridenSpecializedApply) tpd.Apply (tpd.Select (id, specializedApply), arg :: Nil )
170
139
else tree
140
+ }
171
141
case _ => tree
172
142
}
173
143
}
174
144
175
- private def specializedName (name : Name , args : List [Type ], ret : Type )(implicit ctx : Context ): Name = {
176
- val typeParams = args :+ ret
177
- name.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil , Nil )
178
- }
179
-
180
- // Extractors ----------------------------------------------------------------
181
- private object ShouldTransformDenot {
182
- def unapply (cref : ClassDenotation )(implicit ctx : Context ): Option [Seq [SpecializationTarget ]] =
183
- if (! cref.classParents.map(_.symbol).exists(defn.isPlainFunctionClass)) None
184
- else Some (getSpecTargets(cref.typeRef))
185
- }
145
+ @ inline private def specializedName (name : Name , args : List [Type ])(implicit ctx : Context ): Name =
146
+ name.specializedFor(args, args.map(_.typeSymbol.name), Nil , Nil )
186
147
187
- private object ShouldTransformTree {
188
- def unapply (tree : Template )(implicit ctx : Context ): Option [Seq [(Tree , SpecializationTarget )]] = {
189
- val treeToTargets = tree.parents
190
- .map(t => (t, getSpecTargets(t.tpe)))
191
- .filter(_._2.nonEmpty)
192
- .map { case (t, xs) => (t, xs.head) }
148
+ @ inline private def specInterface (typeParams : List [Type ])(implicit ctx : Context ) = {
149
+ val args = typeParams.init
150
+ val ret = typeParams.last
193
151
194
- if (treeToTargets.isEmpty) None else Some (treeToTargets)
195
- }
196
- }
152
+ val specName =
153
+ (functionName ++ args.length)
154
+ .specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil , Nil )
197
155
198
- private case class SpecializationTarget (target : Symbol ,
199
- params : List [Type ],
200
- ret : Type ,
201
- original : Symbol )
202
-
203
- /** Gets all valid specialization targets on `tpe`, allowing multiple
204
- * implementations of FunctionX traits
205
- */
206
- private def getSpecTargets (tpe : Type )(implicit ctx : Context ): List [SpecializationTarget ] = {
207
- val functionParents =
208
- tpe.classSymbols.iterator
209
- .flatMap(_.baseClasses)
210
- .filter(defn.isPlainFunctionClass)
211
-
212
- val tpeCls = tpe.widenDealias
213
- functionParents.map { sym =>
214
- val typeParams = tpeCls.baseArgTypes(sym)
215
- val args = typeParams.init
216
- val ret = typeParams.last
217
-
218
- val interfaceName =
219
- (functionName ++ args.length)
220
- .specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil , Nil )
221
-
222
- val interface = ctx.getClassIfDefined(functionPkg ++ interfaceName)
223
-
224
- if (interface.exists) Some {
225
- SpecializationTarget (interface, args, ret, sym)
226
- }
227
- else None
228
- }
229
- .flatten.toList
156
+ ctx.getClassIfDefined(functionPkg ++ specName)
230
157
}
231
158
}
0 commit comments