Skip to content

Check exhaustivity of any case class #22604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ trait BCodeSkelBuilder extends BCodeHelpers {

/* ---------------- helper utils for generating classes and fields ---------------- */

def genPlainClass(cd0: TypeDef) = cd0 match {
def genPlainClass(cd0: TypeDef) = (cd0: @unchecked) match {
case TypeDef(_, impl: Template) =>
assert(cnode == null, "GenBCode detected nested methods.")

Expand Down
13 changes: 4 additions & 9 deletions compiler/src/dotty/tools/dotc/core/Comments.scala
Original file line number Diff line number Diff line change
Expand Up @@ -405,15 +405,10 @@ object Comments {
val Trim = "(?s)^[\\s&&[^\n\r]]*(.*?)\\s*$".r

val raw = ctx.docCtx.flatMap(_.docstring(sym).map(_.raw)).getOrElse("")
defs(sym) ++= defines(raw).map {
str => {
val start = skipWhitespace(str, "@define".length)
val (key, value) = str.splitAt(skipVariable(str, start))
key.drop(start) -> value
}
} map {
case (key, Trim(value)) =>
variableName(key) -> value.replaceAll("\\s+\\*+$", "")
defs(sym) ++= defines(raw).map { str =>
val start = skipWhitespace(str, "@define".length)
val (key, Trim(value)) = str.splitAt(skipVariable(str, start)): @unchecked
variableName(key.drop(start)) -> value.replaceAll("\\s+\\*+$", "")
}
}

Expand Down
7 changes: 1 addition & 6 deletions compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -841,8 +841,6 @@ object SpaceEngine {
if Nullables.unsafeNullsEnabled then self.stripNull() else self

private def exhaustivityCheckable(sel: Tree)(using Context): Boolean = trace(i"exhaustivityCheckable($sel ${sel.className})") {
val seen = collection.mutable.Set.empty[Symbol]

// Possible to check everything, but be compatible with scalac by default
def isCheckable(tp: Type): Boolean = trace(i"isCheckable($tp ${tp.className})"):
val tpw = tp.widen.dealias.stripUnsafeNulls()
Expand All @@ -856,10 +854,7 @@ object SpaceEngine {
}) ||
tpw.isRef(defn.BooleanClass) ||
classSym.isAllOf(JavaEnum) ||
classSym.is(Case) && {
if seen.add(classSym) then productSelectorTypes(tpw, sel.srcPos).exists(isCheckable(_))
else true // recursive case class: return true and other members can still fail the check
}
classSym.is(Case)

!sel.tpe.hasAnnotation(defn.UncheckedAnnot)
&& !sel.tpe.hasAnnotation(defn.RuntimeCheckedAnnot)
Expand Down
49 changes: 25 additions & 24 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1383,30 +1383,31 @@ trait Implicits:
if alt1.isExtension then
// Fall back: if both results are extension method applications,
// compare the extension methods instead of their wrappers.
def stripExtension(alt: SearchSuccess) = methPart(stripApply(alt.tree)).tpe
(stripExtension(alt1), stripExtension(alt2)) match
case (ref1: TermRef, ref2: TermRef) =>
// ref1 and ref2 might refer to type variables owned by
// alt1.tstate and alt2.tstate respectively, to compare the
// alternatives correctly we need a TyperState that includes
// constraints from both sides, see
// tests/*/extension-specificity2.scala for test cases.
val constraintsIn1 = alt1.tstate.constraint ne ctx.typerState.constraint
val constraintsIn2 = alt2.tstate.constraint ne ctx.typerState.constraint
def exploreState(alt: SearchSuccess): TyperState =
alt.tstate.fresh(committable = false)
val comparisonState =
if constraintsIn1 && constraintsIn2 then
exploreState(alt1).mergeConstraintWith(alt2.tstate)
else if constraintsIn1 then
exploreState(alt1)
else if constraintsIn2 then
exploreState(alt2)
else
ctx.typerState

diff = inContext(searchContext().withTyperState(comparisonState)):
compare(ref1, ref2, preferGeneral = true)
def stripExtension(alt: SearchSuccess) =
methPart(stripApply(alt.tree)).tpe: @unchecked match { case ref: TermRef => ref }
val ref1 = stripExtension(alt1)
val ref2 = stripExtension(alt2)
// ref1 and ref2 might refer to type variables owned by
// alt1.tstate and alt2.tstate respectively, to compare the
// alternatives correctly we need a TyperState that includes
// constraints from both sides, see
// tests/*/extension-specificity2.scala for test cases.
val constraintsIn1 = alt1.tstate.constraint ne ctx.typerState.constraint
val constraintsIn2 = alt2.tstate.constraint ne ctx.typerState.constraint
def exploreState(alt: SearchSuccess): TyperState =
alt.tstate.fresh(committable = false)
val comparisonState =
if constraintsIn1 && constraintsIn2 then
exploreState(alt1).mergeConstraintWith(alt2.tstate)
else if constraintsIn1 then
exploreState(alt1)
else if constraintsIn2 then
exploreState(alt2)
else
ctx.typerState

diff = inContext(searchContext().withTyperState(comparisonState)):
compare(ref1, ref2, preferGeneral = true)
else // alt1 is a conversion, prefer extension alt2 over it
diff = -1
if diff < 0 then alt2
Expand Down
8 changes: 3 additions & 5 deletions scaladoc/src/dotty/tools/scaladoc/site/templates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,9 @@ case class TemplateFile(
ctx.layouts.getOrElse(name, throw new RuntimeException(s"No layouts named $name in ${ctx.layouts}"))
)

def asJavaElement(o: Object): Object = o match
case m: Map[?, ?] => m.transform {
case (k: String, v: Object) => asJavaElement(v)
}.asJava
case l: List[?] => l.map(x => asJavaElement(x.asInstanceOf[Object])).asJava
def asJavaElement(o: Any): Any = o match
case m: Map[?, ?] => m.transform { (k, v) => asJavaElement(v) }.asJava
case l: List[?] => l.map(asJavaElement).asJava
case other => other

// Library requires mutable maps..
Expand Down
1 change: 1 addition & 0 deletions tests/pos/switches.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class Test {
case IntAnyVal(100) => 2
case IntAnyVal(1000) => 3
case IntAnyVal(10000) => 4
case _ => -1
}
}

Expand Down
1 change: 1 addition & 0 deletions tests/warn/i15662.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ case class Composite[T](v: T)
def m(composite: Composite[?]): Unit =
composite match {
case Composite[Int](v) => println(v) // warn: cannot be checked at runtime
case _ => println("OTHER")
}

def m2(composite: Composite[?]): Unit =
Expand Down
15 changes: 15 additions & 0 deletions tests/warn/i22590.arity2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
sealed trait T_B
case class CC_A() extends T_B
case class CC_C() extends T_B

sealed trait T_A
case class CC_B[B](a: B,b:T_B) extends T_A


@main def test() = {
val v_a: CC_B[Int] = null
val v_b: Int = v_a match { // warn: match may not be exhaustive.
case CC_B(12, CC_A()) => 0
case CC_B(_, CC_C()) => 0
}
}
9 changes: 9 additions & 0 deletions tests/warn/i22590.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
sealed trait T_A
case class CC_B[T](a: T) extends T_A

@main def test() = {
val v_a: CC_B[Int] = CC_B(10)
val v_b: Int = v_a match{ // warn: match may not be exhaustive.
case CC_B(12) => 0
}
}
2 changes: 2 additions & 0 deletions tests/warn/opaque-match.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ def Test[T] =
case _: C => ??? // ok
C() match
case _: O.T => ??? // warn
case _ => ???
C() match
case _: T => ??? // warn
case _ => ???

(??? : Any) match
case _: List[O.T] => ??? // warn
Expand Down
4 changes: 2 additions & 2 deletions tests/pos/t10373.scala → tests/warn/t10373.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//> using options -Xfatal-warnings -deprecation -feature
//> using options -deprecation -feature

abstract class Foo {
def bar(): Unit = this match {
Expand All @@ -7,7 +7,7 @@ abstract class Foo {
// Works fine
}

def baz(that: Foo): Unit = (this, that) match {
def baz(that: Foo): Unit = (this, that) match { // warn: match may not be exhaustive.
case (Foo_1(), _) => //do something
case (Foo_2(), _) => //do something
// match may not be exhaustive
Expand Down
Loading