Skip to content

Commit af0e504

Browse files
committed
Fix widenUnion with flexible types
1 parent 47923f2 commit af0e504

File tree

11 files changed

+44
-30
lines changed

11 files changed

+44
-30
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -696,9 +696,11 @@ trait ConstraintHandling {
696696
tp.rebind(tp.parent.hardenUnions)
697697
case tp: HKTypeLambda =>
698698
tp.derivedLambdaType(resType = tp.resType.hardenUnions)
699+
case tp: FlexibleType =>
700+
tp.derivedFlexibleType(tp.hi.hardenUnions)
699701
case tp: OrType =>
700-
val tp1 = tp.stripNull
701-
if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType)
702+
val tp1 = tp.stripNull(stripFlexibleTypes = false)
703+
if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType, soft = false)
702704
else tp.derivedOrType(tp.tp1.hardenUnions, tp.tp2.hardenUnions, soft = false)
703705
case _ =>
704706
tp

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ class Definitions {
648648
@tu lazy val StringModule: Symbol = StringClass.linkedClass
649649
@tu lazy val String_+ : TermSymbol = enterMethod(StringClass, nme.raw.PLUS, methOfAny(StringType), Final)
650650
@tu lazy val String_valueOf_Object: Symbol = StringModule.info.member(nme.valueOf).suchThat(_.info.firstParamTypes match {
651-
case List(pt) => pt.isAny || pt.stripNull.isAnyRef
651+
case List(pt) => pt.isAny || pt.stripNull().isAnyRef
652652
case _ => false
653653
}).symbol
654654

@@ -660,13 +660,13 @@ class Definitions {
660660
@tu lazy val ClassCastExceptionClass: ClassSymbol = requiredClass("java.lang.ClassCastException")
661661
@tu lazy val ClassCastExceptionClass_stringConstructor: TermSymbol = ClassCastExceptionClass.info.member(nme.CONSTRUCTOR).suchThat(_.info.firstParamTypes match {
662662
case List(pt) =>
663-
pt.stripNull.isRef(StringClass)
663+
pt.stripNull().isRef(StringClass)
664664
case _ => false
665665
}).symbol.asTerm
666666
@tu lazy val ArithmeticExceptionClass: ClassSymbol = requiredClass("java.lang.ArithmeticException")
667667
@tu lazy val ArithmeticExceptionClass_stringConstructor: TermSymbol = ArithmeticExceptionClass.info.member(nme.CONSTRUCTOR).suchThat(_.info.firstParamTypes match {
668668
case List(pt) =>
669-
pt.stripNull.isRef(StringClass)
669+
pt.stripNull().isRef(StringClass)
670670
case _ => false
671671
}).symbol.asTerm
672672

compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,13 @@ import Types.*
88
object NullOpsDecorator:
99

1010
extension (self: Type)
11-
def stripFlexible(using Context): Type = self match
12-
case FlexibleType(_, tp) => tp
13-
case _ => self
14-
1511
/** Syntactically strips the nullability from this type.
1612
* If the type is `T1 | ... | Tn`, and `Ti` references to `Null`,
1713
* then return `T1 | ... | Ti-1 | Ti+1 | ... | Tn`.
1814
* If this type isn't (syntactically) nullable, then returns the type unchanged.
1915
* The type will not be changed if explicit-nulls is not enabled.
2016
*/
21-
def stripNull(using Context): Type = {
17+
def stripNull(stripFlexibleTypes: Boolean = true)(using Context): Type = {
2218
def strip(tp: Type): Type =
2319
val tpWiden = tp.widenDealias
2420
val tpStripped = tpWiden match {
@@ -37,7 +33,9 @@ object NullOpsDecorator:
3733
if (tp1s ne tp1) && (tp2s ne tp2) then
3834
tp.derivedAndType(tp1s, tp2s)
3935
else tp
40-
case tp: FlexibleType => tp.hi
36+
case tp: FlexibleType =>
37+
val hi1 = strip(tp.hi)
38+
if stripFlexibleTypes then hi1 else tp.derivedFlexibleType(hi1)
4139
case tp @ TypeBounds(lo, hi) =>
4240
tp.derivedTypeBounds(strip(lo), strip(hi))
4341
case tp => tp
@@ -49,7 +47,7 @@ object NullOpsDecorator:
4947

5048
/** Is self (after widening and dealiasing) a type of the form `T | Null`? */
5149
def isNullableUnion(using Context): Boolean = {
52-
val stripped = self.stripNull
50+
val stripped = self.stripNull()
5351
stripped ne self
5452
}
5553
end extension

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import reporting.trace
2424
import annotation.constructorOnly
2525
import cc.*
2626
import NameKinds.WildcardParamName
27-
import NullOpsDecorator.stripFlexible
2827

2928
/** Provides methods to compare types.
3029
*/

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -928,7 +928,7 @@ object Types extends TypeUtils {
928928
// Selecting `name` from a type `T | Null` is like selecting `name` from `T`, if
929929
// unsafeNulls is enabled and T is a subtype of AnyRef.
930930
// This can throw at runtime, but we trade soundness for usability.
931-
tp1.findMember(name, pre.stripNull, required, excluded)
931+
tp1.findMember(name, pre.stripNull(), required, excluded)
932932
case _ =>
933933
searchAfterJoin
934934
else searchAfterJoin
@@ -1352,13 +1352,13 @@ object Types extends TypeUtils {
13521352
* then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
13531353
*/
13541354
def widenUnion(using Context): Type = widen match
1355-
case tp: OrType => tp match
1356-
case OrNull(tp1) =>
1357-
// Don't widen `T|Null`, since otherwise we wouldn't be able to infer nullable unions.
1355+
case tp: OrType =>
1356+
val tp1 = tp.stripNull(stripFlexibleTypes = false)
1357+
if tp1 ne tp then
13581358
val tp1Widen = tp1.widenUnionWithoutNull
1359-
if (tp1Widen.isRef(defn.AnyClass)) tp1Widen
1359+
if tp1Widen.isRef(defn.AnyClass) then tp1Widen
13601360
else tp.derivedOrType(tp1Widen, defn.NullType)
1361-
case _ =>
1361+
else
13621362
tp.widenUnionWithoutNull
13631363
case tp =>
13641364
tp.widenUnionWithoutNull
@@ -1373,6 +1373,8 @@ object Types extends TypeUtils {
13731373
tp.rebind(tp.parent.widenUnion)
13741374
case tp: HKTypeLambda =>
13751375
tp.derivedLambdaType(resType = tp.resType.widenUnion)
1376+
case tp: FlexibleType =>
1377+
tp.derivedFlexibleType(tp.hi.widenUnionWithoutNull)
13761378
case tp =>
13771379
tp
13781380

@@ -3476,7 +3478,7 @@ object Types extends TypeUtils {
34763478
def apply(tp: Type)(using Context): Type = tp match {
34773479
case ft: FlexibleType => ft
34783480
case _ =>
3479-
// val tp1 = tp.stripNull
3481+
// val tp1 = tp.stripNull()
34803482
// if tp1.isNullType then
34813483
// // (Null)? =:= ? >: Null <: (Object & Null)
34823484
// FlexibleType(tp, AndType(defn.ObjectType, defn.NullType))
@@ -3749,7 +3751,8 @@ object Types extends TypeUtils {
37493751
assert(!ctx.isAfterTyper, s"$tp in $where") // we check correct kinds at PostTyper
37503752
throw TypeError(em"$tp is not a value type, cannot be used $where")
37513753

3752-
/** An extractor object to pattern match against a nullable union.
3754+
/** An extractor object to pattern match against a nullable union
3755+
* (including flexible types).
37533756
* e.g.
37543757
*
37553758
* (tp: Type) match
@@ -3760,7 +3763,7 @@ object Types extends TypeUtils {
37603763
def apply(tp: Type)(using Context) =
37613764
if tp.isNullType then tp else OrType(tp, defn.NullType, soft = false)
37623765
def unapply(tp: Type)(using Context): Option[Type] =
3763-
val tp1 = tp.stripNull
3766+
val tp1 = tp.stripNull()
37643767
if tp1 ne tp then Some(tp1) else None
37653768
}
37663769

compiler/src/dotty/tools/dotc/transform/ElimRepeated.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
293293
val element = array.elemType.hiBound // T
294294

295295
if element <:< defn.AnyRefType
296-
|| ctx.mode.is(Mode.SafeNulls) && element.stripNull <:< defn.AnyRefType
296+
|| ctx.mode.is(Mode.SafeNulls) && element.stripNull() <:< defn.AnyRefType
297297
|| element.typeSymbol.isPrimitiveValueClass
298298
then array
299299
else defn.ArrayOf(TypeBounds.upper(AndType(element, defn.AnyRefType))) // Array[? <: T & AnyRef]

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ abstract class Recheck extends Phase, SymTransformer:
427427
TypeComparer.lub(bodyType :: casesTypes)
428428

429429
def recheckSeqLiteral(tree: SeqLiteral, pt: Type)(using Context): Type =
430-
val elemProto = pt.stripNull.elemType match
430+
val elemProto = pt.stripNull().elemType match
431431
case NoType => WildcardType
432432
case bounds: TypeBounds => WildcardType(bounds)
433433
case elemtp => elemtp

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
252252
// Second constructor of ioob that takes a String argument
253253
def filterStringConstructor(s: Symbol): Boolean = s.info match {
254254
case m: MethodType if s.isConstructor && m.paramInfos.size == 1 =>
255-
m.paramInfos.head.stripNull == defn.StringType
255+
m.paramInfos.head.stripNull() == defn.StringType
256256
case _ => false
257257
}
258258
val constructor = ioob.typeSymbol.info.decls.find(filterStringConstructor _).asTerm

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ trait Applications extends Compatibility {
971971
// one can imagine the original signature-polymorphic method as
972972
// being infinitely overloaded, with each individual overload only
973973
// being brought into existence as needed
974-
val originalResultType = funRef.symbol.info.resultType.stripNull
974+
val originalResultType = funRef.symbol.info.resultType.stripNull()
975975
val resultType =
976976
if !originalResultType.isRef(defn.ObjectClass) then originalResultType
977977
else AvoidWildcardsMap()(proto.resultType.deepenProtoTrans) match

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,7 +1085,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
10851085
* with annotation contructor, as named arguments are not allowed anywhere else in Java.
10861086
* Under explicit nulls, the pt could be nullable. We need to strip `Null` type first.
10871087
*/
1088-
val arg1 = pt.stripNull match {
1088+
val arg1 = pt.stripNull() match {
10891089
case AppliedType(a, typ :: Nil) if ctx.isJava && a.isRef(defn.ArrayClass) =>
10901090
tryAlternatively { typed(tree.arg, pt) } {
10911091
val elemTp = untpd.TypedSplice(TypeTree(typ))
@@ -1914,7 +1914,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19141914
val case1 = typedCase(cas, sel, wideSelType, tpe)(using caseCtx)
19151915
caseCtx = Nullables.afterPatternContext(sel, case1.pat)
19161916
if !alreadyStripped && Nullables.matchesNull(case1) then
1917-
wideSelType = wideSelType.stripNull
1917+
wideSelType = wideSelType.stripNull()
19181918
alreadyStripped = true
19191919
case1
19201920
}
@@ -1937,7 +1937,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19371937
val case1 = typedCase(cas, sel, wideSelType, pt)(using caseCtx)
19381938
caseCtx = Nullables.afterPatternContext(sel, case1.pat)
19391939
if !alreadyStripped && Nullables.matchesNull(case1) then
1940-
wideSelType = wideSelType.stripNull
1940+
wideSelType = wideSelType.stripNull()
19411941
alreadyStripped = true
19421942
case1
19431943
}
@@ -2140,7 +2140,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
21402140
else res
21412141

21422142
def typedSeqLiteral(tree: untpd.SeqLiteral, pt: Type)(using Context): SeqLiteral = {
2143-
val elemProto = pt.stripNull.elemType match {
2143+
val elemProto = pt.stripNull().elemType match {
21442144
case NoType => WildcardType
21452145
case bounds: TypeBounds => WildcardType(bounds)
21462146
case elemtp => elemtp

tests/explicit-nulls/pos/widen-nullable-union.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,16 @@ class Test {
3939
val y = x
4040
val _: (A & B) | Null = y
4141
}
42+
43+
def test1(s: String): String =
44+
val ss = if !s.isEmpty() then s.trim() else s
45+
ss + "!"
46+
47+
def test2(s: String): String =
48+
val ss = if !s.isEmpty() then s.trim().nn else s
49+
ss + "!"
50+
51+
def test3(s: String): String =
52+
val ss: String = if !s.isEmpty() then s.trim().nn else s
53+
ss + "!"
4254
}

0 commit comments

Comments
 (0)