Skip to content

Commit 1f04bba

Browse files
committed
Factor out variance manipulation in TypeMap and TypeAccumulator
1 parent ca30985 commit 1f04bba

File tree

1 file changed

+24
-41
lines changed

1 file changed

+24
-41
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3679,14 +3679,26 @@ object Types {
36793679

36803680
// ----- TypeMaps --------------------------------------------------------------------
36813681

3682-
abstract class TypeMap(implicit protected val ctx: Context) extends (Type => Type) { thisMap =>
3682+
/** Common base class of TypeMap and TypeAccumulator */
3683+
abstract class VariantTraversal {
3684+
protected[core] var variance = 1
3685+
3686+
@inline protected def atVariance[T](v: Int)(op: => T): T = {
3687+
val saved = variance
3688+
variance = v
3689+
val res = op
3690+
variance = saved
3691+
res
3692+
}
3693+
}
3694+
3695+
abstract class TypeMap(implicit protected val ctx: Context)
3696+
extends VariantTraversal with (Type => Type) { thisMap =>
36833697

36843698
protected def stopAtStatic = true
36853699

36863700
def apply(tp: Type): Type
36873701

3688-
protected[core] var variance = 1
3689-
36903702
protected def derivedSelect(tp: NamedType, pre: Type): Type =
36913703
tp.derivedSelect(pre)
36923704
protected def derivedRefinedType(tp: RefinedType, parent: Type, info: Type): Type =
@@ -3724,16 +3736,13 @@ object Types {
37243736
case tp: NamedType =>
37253737
if (stopAtStatic && tp.symbol.isStatic) tp
37263738
else {
3727-
val saved = variance
3728-
variance = variance max 0
3739+
val prefix1 = atVariance(variance max 0)(this(tp.prefix))
37293740
// A prefix is never contravariant. Even if say `p.A` is used in a contravariant
37303741
// context, we cannot assume contravariance for `p` because `p`'s lower
37313742
// bound might not have a binding for `A` (e.g. the lower bound could be `Nothing`).
37323743
// By contrast, covariance does translate to the prefix, since we have that
37333744
// if `p <: q` then `p.A <: q.A`, and well-formedness requires that `A` is a member
37343745
// of `p`'s upper bound.
3735-
val prefix1 = this(tp.prefix)
3736-
variance = saved
37373746
derivedSelect(tp, prefix1)
37383747
}
37393748
case _: ThisType
@@ -3744,11 +3753,7 @@ object Types {
37443753
derivedRefinedType(tp, this(tp.parent), this(tp.refinedInfo))
37453754

37463755
case tp: TypeAlias =>
3747-
val saved = variance
3748-
variance *= tp.variance
3749-
val alias1 = this(tp.alias)
3750-
variance = saved
3751-
derivedTypeAlias(tp, alias1)
3756+
derivedTypeAlias(tp, atVariance(variance * tp.variance)(this(tp.alias)))
37523757

37533758
case tp: TypeBounds =>
37543759
variance = -variance
@@ -3764,12 +3769,8 @@ object Types {
37643769
if (inst.exists) apply(inst) else tp
37653770

37663771
case tp: HKApply =>
3767-
def mapArg(arg: Type, tparam: ParamInfo): Type = {
3768-
val saved = variance
3769-
variance *= tparam.paramVariance
3770-
try this(arg)
3771-
finally variance = saved
3772-
}
3772+
def mapArg(arg: Type, tparam: ParamInfo): Type =
3773+
atVariance(variance * tparam.paramVariance)(this(arg))
37733774
derivedAppliedType(tp, this(tp.tycon),
37743775
tp.args.zipWithConserve(tp.typeParams)(mapArg))
37753776

@@ -3894,12 +3895,6 @@ object Types {
38943895
case _ => tp
38953896
}
38963897

3897-
protected def atVariance[T](v: Int)(op: => T): T = {
3898-
val saved = variance
3899-
variance = v
3900-
try op finally variance = saved
3901-
}
3902-
39033898
/** Derived selection.
39043899
* @pre the (upper bound of) prefix `pre` has a member named `tp.name`.
39053900
*/
@@ -4058,23 +4053,17 @@ object Types {
40584053

40594054
// ----- TypeAccumulators ----------------------------------------------------
40604055

4061-
abstract class TypeAccumulator[T](implicit protected val ctx: Context) extends ((T, Type) => T) {
4056+
abstract class TypeAccumulator[T](implicit protected val ctx: Context)
4057+
extends VariantTraversal with ((T, Type) => T) {
40624058

40634059
protected def stopAtStatic = true
40644060

40654061
def apply(x: T, tp: Type): T
40664062

40674063
protected def applyToAnnot(x: T, annot: Annotation): T = x // don't go into annotations
40684064

4069-
protected var variance = 1
4070-
4071-
protected final def applyToPrefix(x: T, tp: NamedType) = {
4072-
val saved = variance
4073-
variance = variance max 0 // see remark on NamedType case in TypeMap
4074-
val result = this(x, tp.prefix)
4075-
variance = saved
4076-
result
4077-
}
4065+
protected final def applyToPrefix(x: T, tp: NamedType) =
4066+
atVariance(variance max 0)(this(x, tp.prefix)) // see remark on NamedType case in TypeMap
40784067

40794068
def foldOver(x: T, tp: Type): T = tp match {
40804069
case tp: TypeRef =>
@@ -4095,13 +4084,7 @@ object Types {
40954084
this(this(x, tp.parent), tp.refinedInfo)
40964085

40974086
case bounds @ TypeBounds(lo, hi) =>
4098-
if (lo eq hi) {
4099-
val saved = variance
4100-
variance = variance * bounds.variance
4101-
val result = this(x, lo)
4102-
variance = saved
4103-
result
4104-
}
4087+
if (lo eq hi) atVariance(variance * bounds.variance)(this(x, lo))
41054088
else {
41064089
variance = -variance
41074090
val y = this(x, lo)

0 commit comments

Comments
 (0)