Skip to content

Commit 9d55e41

Browse files
committed
Add parent types explicitly when specializing
When a class directly extends a specialized function class, we need to replace the parent with the specialized interface. In other cases we don't replace it, even if the parent of a parent has a specialized apply - the symbols would propagate anyway.
1 parent 3f2652b commit 9d55e41

File tree

3 files changed

+89
-46
lines changed

3 files changed

+89
-46
lines changed

compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala

Lines changed: 87 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,43 +15,65 @@ class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
1515
import ast.tpd._
1616
val phaseName = "specializeFunctions"
1717

18+
private[this] var _blacklistedSymbols: List[Symbol] = _
19+
20+
private def blacklistedSymbols(implicit ctx: Context): List[Symbol] = {
21+
if (_blacklistedSymbols eq null) _blacklistedSymbols = List(
22+
ctx.getClassIfDefined("scala.math.Ordering").asClass.membersNamed("Ops".toTypeName).first.symbol
23+
)
24+
25+
_blacklistedSymbols
26+
}
27+
1828
/** Transforms the type to include decls for specialized applys and replace
1929
* the class parents with specialized versions.
2030
*/
2131
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
22-
case tp: ClassInfo if !sym.is(Flags.Package) => {
32+
case tp: ClassInfo
33+
if !sym.is(Flags.Package) &&
34+
!blacklistedSymbols.contains(sym) &&
35+
(tp.decls ne EmptyScope)
36+
=> {
2337
var newApplys: List[Symbol] = Nil
2438

2539
val newParents = tp.parents.mapConserve { parent =>
26-
if (defn.isPlainFunctionClass(parent.symbol)) {
27-
val typeParams = tp.typeRef.baseArgTypes(parent.classSymbol)
28-
val interface = specInterface(typeParams)
29-
30-
if (interface.exists) {
31-
val specializedApply: Symbol = {
32-
val specializedMethodName = specializedName(nme.apply, typeParams)
33-
ctx.newSymbol(
34-
sym,
35-
specializedMethodName,
36-
Flags.Override | Flags.Method,
37-
interface.info.decls.lookup(specializedMethodName).info
38-
)
40+
List(0, 1, 2, 3).flatMap { arity =>
41+
val func = defn.FunctionClass(arity)
42+
if (!parent.isRef(func)) Nil
43+
else {
44+
val typeParams = tp.typeRef.baseArgInfos(func)
45+
val interface = specInterface(typeParams)
46+
47+
if (interface.exists) {
48+
val specializedApply = {
49+
val specializedMethodName = specializedName(nme.apply, typeParams)
50+
ctx.newSymbol(
51+
sym,
52+
specializedMethodName,
53+
Flags.Override | Flags.Method,
54+
interface.info.decls.lookup(specializedMethodName).info
55+
)
56+
}
57+
newApplys = specializedApply :: newApplys
58+
List(interface.typeRef)
3959
}
40-
41-
newApplys = specializedApply :: newApplys
42-
interface.typeRef
60+
else Nil
4361
}
44-
else parent
4562
}
46-
else parent
63+
.headOption
64+
.getOrElse(parent)
4765
}
4866

49-
def newDecls = newApplys.foldLeft(tp.decls.cloneScope) {
50-
(scope, sym) => scope.enter(sym); scope
51-
}
67+
def newDecls =
68+
if (newApplys.isEmpty) tp.decls
69+
else newApplys.foldLeft(tp.decls.cloneScope) {
70+
(scope, sym) => scope.enter(sym); scope
71+
}
5272

53-
if (newApplys eq Nil) tp
54-
else tp.derivedClassInfo(classParents = newParents, decls = newDecls)
73+
tp.derivedClassInfo(
74+
classParents = newParents,
75+
decls = newDecls
76+
)
5577
}
5678

5779
case _ => tp
@@ -63,10 +85,10 @@ class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
6385
* template body.
6486
*/
6587
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = {
66-
val buf = new mutable.ListBuffer[Tree]
88+
val applyBuf = new mutable.ListBuffer[Tree]
6789
val newBody = tree.body.mapConserve {
6890
case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => {
69-
val specializedApply = ctx.owner.info.decls.lookup {
91+
val specializedApply = tree.symbol.enclosingClass.info.decls.lookup {
7092
specializedName(
7193
nme.apply,
7294
dt.vparamss.head.map(_.symbol.info) :+ dt.tpe.widen.finalResultType
@@ -82,7 +104,7 @@ class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
82104
.subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
83105
})
84106

85-
buf += specializedDecl
107+
applyBuf += specializedDecl
86108

87109
// create a forwarding to the specialized apply
88110
cpy.DefDef(dt)(rhs = {
@@ -95,29 +117,52 @@ class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
95117
case x => x
96118
}
97119

98-
val newParents = tree.parents.mapConserve { parent =>
99-
if (defn.isPlainFunctionClass(parent.symbol)) {
100-
val typeParams = tree.tpe.baseArgTypes(parent.symbol)
101-
val interface = specInterface(typeParams)
120+
val missing: List[TypeTree] = List(0, 1, 2, 3).flatMap { arity =>
121+
val func = defn.FunctionClass(arity)
122+
val tr = tree.symbol.enclosingClass.typeRef
102123

103-
if (interface.exists) TypeTree(interface.info)
104-
else parent
105-
}
106-
else parent
107-
}
124+
if (!tr.parents.exists(_.isRef(func))) Nil
125+
else {
126+
val typeParams = tr.baseArgInfos(func)
127+
val interface = specInterface(typeParams)
128+
129+
if (interface.exists) List(interface.info)
130+
else Nil
131+
}
132+
}.map(TypeTree)
108133

109-
cpy.Template(tree)(parents = newParents, body = buf.toList ++ newBody)
134+
cpy.Template(tree)(
135+
parents = tree.parents ++ missing,
136+
body = applyBuf.toList ++ newBody
137+
)
110138
}
111139

112140
/** Dispatch to specialized `apply`s in user code when available */
113141
override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) =
114142
tree match {
115-
case Apply(select @ Select(id, nme.apply), args) => {
116-
val params = args.map(_.tpe) :+ tree.tpe
143+
case app @ Apply(fun, args)
144+
if fun.symbol.name == nme.apply &&
145+
fun.symbol.owner.derivesFrom(defn.FunctionClass(args.length))
146+
=> {
147+
val params = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias)
117148
val specializedApply = specializedName(nme.apply, params)
118149

119-
if (tree.fun.symbol.owner.info.member(specializedApply).exists)
120-
tpd.Apply(tpd.Select(id, specializedApply), args)
150+
if (!params.exists(_.isInstanceOf[ExprType]) && defn.FunctionClass(args.length).info.member(specializedApply).exists) {
151+
val newSel = fun match {
152+
case Select(qual, _) =>
153+
qual.select(specializedApply)
154+
case _ => {
155+
(fun.tpe: @unchecked) match {
156+
case TermRef(prefix: ThisType, name) =>
157+
tpd.This(prefix.cls).select(specializedApply)
158+
case TermRef(prefix: NamedType, name) =>
159+
tpd.ref(prefix).select(specializedApply)
160+
}
161+
}
162+
}
163+
164+
newSel.appliedToArgs(args)
165+
}
121166
else tree
122167
}
123168
case _ => tree
@@ -128,7 +173,7 @@ class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
128173

129174
@inline private def specInterface(typeParams: List[Type])(implicit ctx: Context) = {
130175
val specName =
131-
("JFunction".toTermName ++ (typeParams.length - 1))
176+
("JFunction" + (typeParams.length - 1)).toTermName
132177
.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil, Nil)
133178

134179
ctx.getClassIfDefined("scala.compat.java8.".toTermName ++ specName)

compiler/src/dotty/tools/dotc/transform/SpecializedApplyMethods.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class SpecializedApplyMethods extends MiniPhaseTransform with InfoTransformer {
4444
t1 <- List(IntType, LongType, FloatType, DoubleType)
4545
} yield specApply(func1, List(t1), r)
4646

47-
func2 = defn.FunctionClass(2)
47+
func2 = FunctionClass(2)
4848
func2Applys = for {
4949
r <- List(UnitType, BooleanType, IntType, FloatType, LongType, DoubleType)
5050
t1 <- List(IntType, LongType, DoubleType)
@@ -54,7 +54,7 @@ class SpecializedApplyMethods extends MiniPhaseTransform with InfoTransformer {
5454

5555
/** Add symbols for specialized methods to FunctionN */
5656
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
57-
case tp: ClassInfo if defn.isFunctionClass(sym) => {
57+
case tp: ClassInfo if defn.isPlainFunctionClass(sym) => {
5858
init()
5959
val newDecls = sym.name.functionArity match {
6060
case 0 => func0Applys.foldLeft(tp.decls.cloneScope) {

tests/run/t2857.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,3 @@ object Test extends dotty.runtime.LegacyApp {
55
m.removeBinding(6, "Foo")
66
println(m.contains(6))
77
}
8-
9-

0 commit comments

Comments
 (0)