@@ -13,27 +13,119 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
13
13
14
14
val phaseName = " specializeFunction1"
15
15
16
+ // Setup ---------------------------------------------------------------------
17
+ private [this ] val functionName = " JFunction1" .toTermName
18
+ private [this ] val functionPkg = " scala.compat.java8." .toTermName
19
+ private [this ] var argTypes : Set [Symbol ] = _
20
+ private [this ] var retTypes : Set [Symbol ] = _
21
+
22
+ override def prepareForUnit (tree : Tree )(implicit ctx : Context ) = {
23
+ argTypes = Set (defn.DoubleClass ,
24
+ defn.FloatClass ,
25
+ defn.IntClass ,
26
+ defn.LongClass ,
27
+ defn.UnitClass ,
28
+ defn.BooleanClass )
29
+
30
+ retTypes = Set (defn.DoubleClass ,
31
+ defn.FloatClass ,
32
+ defn.IntClass ,
33
+ defn.LongClass )
34
+ this
35
+ }
36
+
37
+ // Transformations -----------------------------------------------------------
38
+
16
39
/** Transforms all classes extending `Function1[-T1, +R]` so that
17
40
* they instead extend the specialized version `JFunction$mp...`
18
41
*/
19
42
def transform (ref : SingleDenotation )(implicit ctx : Context ) = ref match {
20
- case ShouldTransformDenot (cref, t1, r, func1) =>
21
- transformDenot(cref, t1, r, func1)
43
+ case ShouldTransformDenot (cref, t1, r, func1) => {
44
+ val specializedFunction : Symbol =
45
+ ctx.getClassIfDefined(functionPkg ++ specializedName(functionName, t1, r))
46
+
47
+ def replaceFunction1 (in : List [TypeRef ]): List [TypeRef ] =
48
+ in.mapConserve { tp =>
49
+ if (tp.isRef(defn.FunctionClass (1 )) && (specializedFunction ne NoSymbol ))
50
+ specializedFunction.typeRef
51
+ else tp
52
+ }
53
+
54
+ def specializeApply (scope : Scope ): Scope =
55
+ if (specializedFunction ne NoSymbol ) {
56
+ def specializedApply : Symbol = {
57
+ val specializedMethodName = specializedName(nme.apply, t1, r)
58
+ ctx.newSymbol(
59
+ cref.symbol,
60
+ specializedMethodName,
61
+ Flags .Override | Flags .Method ,
62
+ specializedFunction.info.decls.lookup(specializedMethodName).info
63
+ )
64
+ }
65
+
66
+ val alteredScope = scope.cloneScope
67
+ alteredScope.enter(specializedApply)
68
+ alteredScope
69
+ }
70
+ else scope
71
+
72
+ val ClassInfo (prefix, cls, parents, decls, info) = cref.classInfo
73
+ val newInfo = ClassInfo (prefix, cls, replaceFunction1(in = parents), specializeApply(decls), info)
74
+ cref.copySymDenotation(info = newInfo)
75
+ }
22
76
case _ => ref
23
77
}
24
78
79
+ /** Transform the class definition's `Template`:
80
+ *
81
+ * - change the tree to have the correct parent
82
+ * - add the specialized apply method to the template body
83
+ * - forward the old `apply` to the specialized version
84
+ */
25
85
override def transformTemplate (tree : Template )(implicit ctx : Context , info : TransformerInfo ) =
26
86
tree match {
27
- case ShouldTransformTree (func1, t1, r) => transformTree(tree, func1, t1, r)
87
+ case tmpl @ ShouldTransformTree (func1, t1, r) => {
88
+ val specializedFunc1 =
89
+ TypeTree (ctx.requiredClassRef(functionPkg ++ specializedName(functionName, t1, r)))
90
+
91
+ val parents = tmpl.parents.foldRight(List .empty[Tree ]) { (t, trees) =>
92
+ (if (func1 eq t) specializedFunc1 else t) :: trees
93
+ }
94
+
95
+ val body = tmpl.body.foldRight(List .empty[Tree ]) {
96
+ case (tree : DefDef , acc) if tree.name == nme.apply => {
97
+ val specializedMethodName = specializedName(nme.apply, t1, r)
98
+ val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
99
+
100
+ val forwardingBody =
101
+ tpd.ref(specializedApply)
102
+ .appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
103
+
104
+ val applyWithForwarding = cpy.DefDef (tree)(rhs = forwardingBody)
105
+
106
+ val specializedApplyDefDef = polyDefDef(specializedApply, trefs => vrefss => {
107
+ tree.rhs
108
+ .changeOwner(tree.symbol, specializedApply)
109
+ .subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
110
+ })
111
+
112
+ applyWithForwarding :: specializedApplyDefDef :: acc
113
+ }
114
+ case (tree, acc) => tree :: acc
115
+ }
116
+
117
+ cpy.Template (tmpl)(parents = parents, body = body)
118
+ }
28
119
case _ => tree
29
120
}
30
121
122
+ /** Dispatch to specialized `apply`s in user code */
31
123
override def transformApply (tree : Apply )(implicit ctx : Context , info : TransformerInfo ) = {
32
124
import ast .Trees ._
33
125
tree match {
34
126
case Apply (select @ Select (id, nme.apply), arg :: Nil ) =>
35
127
val params = List (arg.tpe, tree.tpe)
36
- val specializedApply = nme.apply.specializedFor(params, params.map(_.typeSymbol.name))
128
+ val specializedApply = nme.apply.specializedFor(params, params.map(_.typeSymbol.name), Nil , Nil )
37
129
val hasOverridenSpecializedApply = id.tpe.decls.iterator.exists { sym =>
38
130
sym.is(Flags .Override ) && (sym.name eq specializedApply)
39
131
}
@@ -44,137 +136,32 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
44
136
}
45
137
}
46
138
47
- private [this ] val functionName = " JFunction1" .toTermName
48
- private [this ] val functionPkg = " scala.compat.java8." .toTermName
49
- private [this ] var Function1 : Symbol = _
50
- private [this ] var argTypes : Set [Symbol ] = _
51
- private [this ] var returnTypes : Set [Symbol ] = _
52
- private [this ] var blacklisted : Set [Symbol ] = _
53
-
54
- /** Do setup of `argTypes` and `returnTypes` */
55
- override def prepareForUnit (tree : Tree )(implicit ctx : Context ) = {
56
- argTypes = Set (defn.DoubleClass ,
57
- defn.FloatClass ,
58
- defn.IntClass ,
59
- defn.LongClass ,
60
- defn.UnitClass ,
61
- defn.BooleanClass )
62
-
63
- returnTypes = Set (defn.DoubleClass ,
64
- defn.FloatClass ,
65
- defn.IntClass ,
66
- defn.LongClass )
67
-
68
- Function1 = ctx.requiredClass(" scala.Function1" )
69
-
70
- blacklisted = Set (
71
- " scala.compat.java8.JFunction1" ,
72
- " scala.runtime.AbstractFunction1" ,
73
- " scala.PartialFunction" ,
74
- " scala.runtime.AbstractPartialFunction"
75
- ).map(ctx.requiredClass(_))
76
-
77
- this
78
- }
139
+ private def specializedName (name : Name , t1 : Type , r : Type )(implicit ctx : Context ): Name =
140
+ name.specializedFor(List (t1, r), List (t1, r).map(_.typeSymbol.name), Nil , Nil )
79
141
142
+ // Extractors ----------------------------------------------------------------
80
143
private object ShouldTransformDenot {
81
- def unapply (cref : ClassDenotation )(implicit ctx : Context ): Option [(ClassDenotation , Type , Type , Type )] = {
82
- def collectFunc1 (xs : List [Type ])(implicit ctx : Context ): Option [(Type , Type , Type )] =
83
- xs.collect {
84
- case func1 @ RefinedType (RefinedType (parent, _, t1), _, r)
85
- if func1.isRef(Function1 ) => (t1, r, func1)
86
- }.headOption
87
-
88
- collectFunc1(cref.info.parentsWithArgs).flatMap { case (t1, r, func1) =>
89
- if (
90
- ! argTypes.contains(t1.typeSymbol) ||
91
- ! returnTypes.contains(r.typeSymbol) ||
92
- blacklisted.exists(sym => cref.symbol.derivesFrom(sym))
93
- ) None
94
- else Some ((cref, t1, r, func1))
144
+ def unapply (cref : ClassDenotation )(implicit ctx : Context ): Option [(ClassDenotation , Type , Type , Type )] =
145
+ getFunc1(cref.typeRef).map {
146
+ case (t1, r, func1) => (cref, t1, r, func1)
95
147
}
96
- }
97
148
}
98
149
99
150
private object ShouldTransformTree {
100
151
def unapply (tree : Template )(implicit ctx : Context ): Option [(Tree , Type , Type )] =
101
152
tree.parents
102
- .map { t => (t.tpe, t) }
103
- .collect {
104
- case (tp @ RefinedType (RefinedType (parent, _, t1), _, r), func1)
105
- if tp.isRef(Function1 ) &&
106
- argTypes.contains(t1.typeSymbol) &&
107
- returnTypes.contains(r.typeSymbol) => (func1, t1, r)
108
- }
109
- .headOption
110
- }
111
-
112
- private def specializedName (name : Name , t1 : Type , r : Type )(implicit ctx : Context ): Name =
113
- name.specializedFor(List (t1, r), List (t1, r).map(_.typeSymbol.name))
114
-
115
- def transformDenot (cref : ClassDenotation , t1 : Type , r : Type , func1 : Type )(implicit ctx : Context ): SingleDenotation = {
116
- val specializedFunction : TypeRef =
117
- ctx.requiredClassRef(functionPkg ++ specializedName(functionName, t1, r))
118
-
119
- def replaceFunction1 (in : List [TypeRef ]): List [TypeRef ] =
120
- in.foldRight(List .empty[TypeRef ]) { (tp, acc) =>
121
- val newTp =
122
- if (tp.isRef(Function1 )) specializedFunction
123
- else tp
124
- newTp :: acc
125
- }
126
-
127
- def specializeApply (scope : Scope ): Scope = {
128
- def specializedApply : Symbol = {
129
- val specializedMethodName = specializedName(nme.apply, t1, r)
130
- ctx.newSymbol(
131
- cref.symbol,
132
- specializedMethodName,
133
- Flags .Override | Flags .Method ,
134
- specializedFunction.info.decls.lookup(specializedMethodName).info
135
- )
136
- }
137
-
138
- val alteredScope = scope.cloneScope
139
- alteredScope.enter(specializedApply)
140
- alteredScope
141
- }
142
-
143
- val ClassInfo (prefix, cls, parents, decls, info) = cref.classInfo
144
- val newInfo = ClassInfo (prefix, cls, replaceFunction1(in = parents), specializeApply(decls), info)
145
- cref.copySymDenotation(info = newInfo)
153
+ .map(t => getFunc1(t.tpe).map { case (t1, r, _) => (t, t1, r) })
154
+ .flatten.headOption
146
155
}
147
156
148
- private def transformTree (tmpl : Template , func1 : Tree , t1 : Type , r : Type )(implicit ctx : Context ) = {
149
- val specializedFunc1 =
150
- TypeTree (ctx.requiredClassRef(functionPkg ++ specializedName(functionName, t1, r)))
151
-
152
- val parents = tmpl.parents.foldRight(List .empty[Tree ]) { (t, trees) =>
153
- (if (func1 eq t) specializedFunc1 else t) :: trees
154
- }
155
-
156
- val body = tmpl.body.foldRight(List .empty[Tree ]) {
157
- case (tree : DefDef , acc) if tree.name == nme.apply => {
158
- val specializedMethodName = specializedName(nme.apply, t1, r)
159
- val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
160
-
161
- val forwardingBody =
162
- tpd.ref(specializedApply)
163
- .appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
164
-
165
- val applyWithForwarding = cpy.DefDef (tree)(rhs = forwardingBody)
166
-
167
- val specializedApplyDefDef = polyDefDef(specializedApply, trefs => vrefss => {
168
- tree.rhs
169
- .changeOwner(tree.symbol, specializedApply)
170
- .subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
171
- })
172
-
173
- applyWithForwarding :: specializedApplyDefDef :: acc
157
+ private def getFunc1 (tpe : Type )(implicit ctx : Context ): Option [(Type , Type , Type )] =
158
+ if (! tpe.derivesFrom(defn.FunctionClass (1 )))
159
+ None
160
+ else
161
+ tpe.baseTypeWithArgs(defn.FunctionClass (1 )) match {
162
+ case func1 @ RefinedType (RefinedType (parent, _, t1), _, r) if (
163
+ argTypes.contains(t1.typeSymbol) && retTypes.contains(r.typeSymbol)
164
+ ) => Some ((t1, r, func1))
165
+ case _ => None
174
166
}
175
- case (tree, acc) => tree :: acc
176
- }
177
-
178
- cpy.Template (tmpl)(parents = parents, body = body)
179
- }
180
167
}
0 commit comments