Skip to content

Commit 0a73a26

Browse files
committed
Turn on separation checking for applications
- Use unsafeAssumeSeparate(...) as an escape hatch
1 parent 3a26fe8 commit 0a73a26

25 files changed

+259
-70
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+18-2
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,12 @@ object CheckCaptures:
239239

240240
/** Was a new type installed for this tree? */
241241
def hasNuType: Boolean
242+
243+
/** Is this tree passed to a parameter or assigned to a value with a type
244+
* that contains cap in no-flip covariant position, which will necessite
245+
* a separation check?
246+
*/
247+
def needsSepCheck: Boolean
242248
end CheckerAPI
243249

244250
class CheckCaptures extends Recheck, SymTransformer:
@@ -279,6 +285,12 @@ class CheckCaptures extends Recheck, SymTransformer:
279285
*/
280286
private val todoAtPostCheck = new mutable.ListBuffer[() => Unit]
281287

288+
/** Trees that will need a separation check because they contain cap */
289+
private val sepCheckable = util.EqHashSet[Tree]()
290+
291+
extension [T <: Tree](tree: T)
292+
def needsSepCheck: Boolean = sepCheckable.contains(tree)
293+
282294
/** Instantiate capture set variables appearing contra-variantly to their
283295
* upper approximation.
284296
*/
@@ -636,11 +648,11 @@ class CheckCaptures extends Recheck, SymTransformer:
636648
val meth = tree.fun.symbol
637649
if meth == defn.Caps_unsafeAssumePure then
638650
val arg :: Nil = tree.args: @unchecked
639-
val argType0 = recheck(arg, pt.capturing(CaptureSet.universal))
651+
val argType0 = recheck(arg, pt.stripCapturing.capturing(CaptureSet.universal))
640652
val argType =
641653
if argType0.captureSet.isAlwaysEmpty then argType0
642654
else argType0.widen.stripCapturing
643-
capt.println(i"rechecking $arg with $pt: $argType")
655+
capt.println(i"rechecking unsafeAssumePure of $arg with $pt: $argType")
644656
super.recheckFinish(argType, tree, pt)
645657
else
646658
val res = super.recheckApply(tree, pt)
@@ -660,6 +672,9 @@ class CheckCaptures extends Recheck, SymTransformer:
660672
capt.println(i"charging deep capture set of $arg: ${argType} = ${argType.deepCaptureSet}")
661673
markFree(argType.deepCaptureSet, arg.srcPos)
662674
case _ =>
675+
if formal.containsCap then
676+
arg.updNuType(freshenedFormal)
677+
sepCheckable += arg
663678
argType
664679

665680
/** Map existential captures in result to `cap` and implement the following
@@ -1785,6 +1800,7 @@ class CheckCaptures extends Recheck, SymTransformer:
17851800
end checker
17861801

17871802
checker.traverse(unit)(using ctx.withOwner(defn.RootClass))
1803+
if ccConfig.useFresh then SepChecker(this).traverse(unit)
17881804
if !ctx.reporter.errorsReported then
17891805
// We dont report errors here if previous errors were reported, because other
17901806
// errors often result in bad applied types, but flagging these bad types gives
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package dotty.tools
2+
package dotc
3+
package cc
4+
import ast.tpd
5+
import collection.mutable
6+
7+
import core.*
8+
import Symbols.*, Types.*
9+
import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
10+
import CaptureSet.{Refs, emptySet}
11+
import config.Printers.capt
12+
import StdNames.nme
13+
14+
class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
15+
import tpd.*
16+
import checker.*
17+
18+
extension (cs: CaptureSet)
19+
def footprint(using Context): CaptureSet =
20+
def recur(elems: CaptureSet.Refs, newElems: List[CaptureRef]): CaptureSet.Refs = newElems match
21+
case newElem :: newElems1 =>
22+
val superElems = newElem.captureSetOfInfo.elems.filter: superElem =>
23+
!superElem.isMaxCapability && !elems.contains(superElem)
24+
recur(superElems ++ elems, superElems.toList ++ newElems1)
25+
case Nil => elems
26+
val elems: CaptureSet.Refs = cs.elems.filter(!_.isMaxCapability)
27+
CaptureSet(recur(elems, elems.toList))
28+
29+
def overlapWith(other: CaptureSet)(using Context): CaptureSet.Refs =
30+
val refs1 = cs.elems
31+
val refs2 = other.elems
32+
def common(refs1: CaptureSet.Refs, refs2: CaptureSet.Refs) =
33+
refs1.filter: ref =>
34+
ref.isExclusive && refs2.exists(_.stripReadOnly eq ref)
35+
common(refs1, refs2) ++ common(refs2, refs1)
36+
37+
private def hidden(elem: CaptureRef)(using Context): CaptureSet.Refs = elem match
38+
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ hidden(hcs)
39+
case ReadOnlyCapability(ref) => hidden(ref).map(_.readOnly)
40+
case _ => emptySet
41+
42+
private def hidden(cs: CaptureSet)(using Context): CaptureSet.Refs =
43+
val seen: util.EqHashSet[CaptureRef] = new util.EqHashSet
44+
45+
def hiddenByElem(elem: CaptureRef): CaptureSet.Refs =
46+
if seen.add(elem) then elem match
47+
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ recur(hcs)
48+
case ReadOnlyCapability(ref) => hiddenByElem(ref).map(_.readOnly)
49+
case _ => emptySet
50+
else emptySet
51+
52+
def recur(cs: CaptureSet): CaptureSet.Refs =
53+
(emptySet /: cs.elems): (elems, elem) =>
54+
elems ++ hiddenByElem(elem)
55+
56+
recur(cs)
57+
end hidden
58+
59+
private def checkApply(fn: Tree, args: List[Tree])(using Context): Unit =
60+
val fnCaptures = fn.nuType.deepCaptureSet
61+
62+
def captures(arg: Tree) =
63+
val argType = arg.nuType
64+
argType match
65+
case AnnotatedType(formal1, ann) if ann.symbol == defn.UseAnnot =>
66+
argType.deepCaptureSet
67+
case _ =>
68+
argType.captureSet
69+
70+
val argCaptures = args.map(captures)
71+
capt.println(i"check separate $fn($args), fnCaptures = $fnCaptures, argCaptures = $argCaptures")
72+
var footprint = argCaptures.foldLeft(fnCaptures.footprint): (fp, ac) =>
73+
fp ++ ac.footprint
74+
val paramNames = fn.nuType.widen match
75+
case MethodType(pnames) => pnames
76+
case _ => args.indices.map(nme.syntheticParamName(_))
77+
for (arg, ac, pname) <- args.lazyZip(argCaptures).lazyZip(paramNames) do
78+
if arg.needsSepCheck then
79+
val hiddenInArg = CaptureSet(hidden(ac))
80+
//println(i"check sep $arg / $footprint / $hiddenInArg")
81+
val overlap = hiddenInArg.footprint.overlapWith(footprint)
82+
if !overlap.isEmpty then
83+
def whatStr = if overlap.size == 1 then "this capability" else "these capabilities"
84+
def funStr =
85+
if fn.symbol.exists then i"${fn.symbol}"
86+
else "the function"
87+
report.error(
88+
em"""Separation failure: argument to capture-polymorphic parameter $pname: ${arg.nuType}
89+
|captures ${CaptureSet(overlap)} and also passes $whatStr separately to $funStr""",
90+
arg.srcPos)
91+
footprint ++= hiddenInArg
92+
93+
private def traverseApply(tree: Tree, argss: List[List[Tree]])(using Context): Unit = tree match
94+
case Apply(fn, args) => traverseApply(fn, args :: argss)
95+
case TypeApply(fn, args) => traverseApply(fn, argss) // skip type arguments
96+
case _ =>
97+
if argss.nestedExists(_.needsSepCheck) then
98+
checkApply(tree, argss.flatten)
99+
100+
def traverse(tree: Tree)(using Context): Unit =
101+
tree match
102+
case tree: GenericApply =>
103+
if tree.symbol != defn.Caps_unsafeAssumeSeparate then
104+
tree.tpe match
105+
case _: MethodOrPoly =>
106+
case _ => traverseApply(tree, Nil)
107+
traverseChildren(tree)
108+
case _ =>
109+
traverseChildren(tree)
110+
end SepChecker
111+
112+
113+
114+
115+
116+

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,7 @@ class Definitions {
10001000
@tu lazy val Caps_Exists: ClassSymbol = requiredClass("scala.caps.Exists")
10011001
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
10021002
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
1003+
@tu lazy val Caps_unsafeAssumeSeparate: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumeSeparate")
10031004
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains")
10041005
@tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl")
10051006
@tu lazy val Caps_Mutable: ClassSymbol = requiredClass("scala.caps.Mutable")

compiler/src/dotty/tools/dotc/core/Types.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -4178,7 +4178,7 @@ object Types extends TypeUtils {
41784178
tl => params.map(p => tl.integrate(params, adaptParamInfo(p))),
41794179
tl => tl.integrate(params, resultType))
41804180

4181-
/** Adapt info of parameter symbol to be integhrated into corresponding MethodType
4181+
/** Adapt info of parameter symbol to be integrated into corresponding MethodType
41824182
* using the scheme described in `fromSymbols`.
41834183
*/
41844184
def adaptParamInfo(param: Symbol, pinfo: Type)(using Context): Type =

compiler/src/dotty/tools/dotc/transform/Recheck.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,11 @@ abstract class Recheck extends Phase, SymTransformer:
167167
* from the current type.
168168
*/
169169
def setNuType(tpe: Type): Unit =
170-
if nuTypes.lookup(tree) == null && (tpe ne tree.tpe) then nuTypes(tree) = tpe
170+
if nuTypes.lookup(tree) == null then updNuType(tpe)
171+
172+
/** Set new type of the tree unconditionally. */
173+
def updNuType(tpe: Type): Unit =
174+
if tpe ne tree.tpe then nuTypes(tree) = tpe
171175

172176
/** The new type of the tree, or if none was installed, the original type */
173177
def nuType(using Context): Type =

library/src/scala/caps.scala

+5
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,9 @@ import annotation.{experimental, compileTimeOnly, retainsCap}
7979
*/
8080
def unsafeAssumePure: T = x
8181

82+
/** A wrapper around code for which separation checks are suppressed.
83+
*/
84+
def unsafeAssumeSeparate[T](op: T): T = op
85+
8286
end unsafe
87+
end caps

scala2-library-cc/src/scala/collection/IndexedSeqView.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ object IndexedSeqView {
136136

137137
@SerialVersionUID(3L)
138138
class Concat[A](prefix: SomeIndexedSeqOps[A]^, suffix: SomeIndexedSeqOps[A]^)
139-
extends SeqView.Concat[A](prefix, suffix) with IndexedSeqView[A]
139+
extends SeqView.Concat[A](prefix, caps.unsafe.unsafeAssumeSeparate(suffix)) with IndexedSeqView[A]
140140

141141
@SerialVersionUID(3L)
142142
class Take[A](underlying: SomeIndexedSeqOps[A]^, n: Int)

scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala

+8-5
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
682682
remaining -= 1
683683
scout = scout.tail
684684
}
685-
dropRightState(scout)
685+
caps.unsafe.unsafeAssumeSeparate(dropRightState(scout))
686686
}
687687
}
688688

@@ -879,6 +879,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
879879
if (!cursor.stateDefined) b.append(sep).append("<not computed>")
880880
} else {
881881
@inline def same(a: LazyListIterable[A]^, b: LazyListIterable[A]^): Boolean = (a eq b) || (a.state eq b.state)
882+
// !!!CC with qualifiers, same should have cap.rd parameters
882883
// Cycle.
883884
// If we have a prefix of length P followed by a cycle of length C,
884885
// the scout will be at position (P%C) in the cycle when the cursor
@@ -890,7 +891,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
890891
// the start of the loop.
891892
var runner = this
892893
var k = 0
893-
while (!same(runner, scout)) {
894+
while (!caps.unsafe.unsafeAssumeSeparate(same(runner, scout))) {
894895
runner = runner.tail
895896
scout = scout.tail
896897
k += 1
@@ -900,11 +901,11 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
900901
// everything once. If cursor is already at beginning, we'd better
901902
// advance one first unless runner didn't go anywhere (in which case
902903
// we've already looped once).
903-
if (same(cursor, scout) && (k > 0)) {
904+
if (caps.unsafe.unsafeAssumeSeparate(same(cursor, scout)) && (k > 0)) {
904905
appendCursorElement()
905906
cursor = cursor.tail
906907
}
907-
while (!same(cursor, scout)) {
908+
while (!caps.unsafe.unsafeAssumeSeparate(same(cursor, scout))) {
908909
appendCursorElement()
909910
cursor = cursor.tail
910911
}
@@ -1052,7 +1053,9 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
10521053
val head = it.next()
10531054
rest = rest.tail
10541055
restRef = rest // restRef.elem = rest
1055-
sCons(head, newLL(stateFromIteratorConcatSuffix(it)(flatMapImpl(rest, f).state)))
1056+
sCons(head, newLL(
1057+
caps.unsafe.unsafeAssumeSeparate(
1058+
stateFromIteratorConcatSuffix(it)(flatMapImpl(rest, f).state))))
10561059
} else State.Empty
10571060
}
10581061
}

scala2-library-cc/src/scala/collection/mutable/CheckedIndexedSeqView.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ private[mutable] object CheckedIndexedSeqView {
7575

7676
@SerialVersionUID(3L)
7777
class Concat[A](prefix: SomeIndexedSeqOps[A]^, suffix: SomeIndexedSeqOps[A]^)(protected val mutationCount: () => Int)
78-
extends IndexedSeqView.Concat[A](prefix, suffix) with CheckedIndexedSeqView[A]
78+
extends IndexedSeqView.Concat[A](prefix, caps.unsafe.unsafeAssumeSeparate(suffix)) with CheckedIndexedSeqView[A]
7979

8080
@SerialVersionUID(3L)
8181
class Take[A](underlying: SomeIndexedSeqOps[A]^, n: Int)(protected val mutationCount: () => Int)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
-- Error: tests/neg-custom-args/captures/cc-dep-param.scala:8:6 --------------------------------------------------------
2+
8 | foo(a, useA) // error: separation failure
3+
| ^
4+
| Separation failure: argument to capture-polymorphic parameter x$0: Foo[Int]^
5+
| captures {a} and also passes this capability separately to method foo
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import language.experimental.captureChecking
2+
3+
trait Foo[T]
4+
def test(): Unit =
5+
val a: Foo[Int]^ = ???
6+
val useA: () ->{a} Unit = ???
7+
def foo[X](x: Foo[X]^, op: () ->{x} Unit): Unit = ???
8+
foo(a, useA) // error: separation failure

tests/neg-custom-args/captures/cc-subst-param-exact.scala

+6-6
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,29 @@ trait Ref[T] { def set(x: T): T }
55
def test() = {
66

77
def swap[T](x: Ref[T]^)(y: Ref[T]^{x}): Unit = ???
8-
def foo[T](x: Ref[T]^): Unit =
8+
def foo[T](x: Ref[T]^{cap.rd}): Unit =
99
swap(x)(x)
1010

11-
def bar[T](x: () => Ref[T]^)(y: Ref[T]^{x}): Unit =
11+
def bar[T](x: () => Ref[T]^{cap.rd})(y: Ref[T]^{x}): Unit =
1212
swap(x())(y) // error
1313

14-
def baz[T](x: Ref[T]^)(y: Ref[T]^{x}): Unit =
14+
def baz[T](x: Ref[T]^{cap.rd})(y: Ref[T]^{x}): Unit =
1515
swap(x)(y)
1616
}
1717

1818
trait IO
1919
type Op = () -> Unit
2020
def test2(c: IO^, f: Op^{c}) = {
2121
def run(io: IO^)(op: Op^{io}): Unit = op()
22-
run(c)(f)
22+
run(c)(f) // error: separation failure
2323

2424
def bad(getIO: () => IO^, g: Op^{getIO}): Unit =
25-
run(getIO())(g) // error
25+
run(getIO())(g) // error // error: separation failure
2626
}
2727

2828
def test3() = {
2929
def run(io: IO^)(op: Op^{io}): Unit = ???
3030
val myIO: IO^ = ???
3131
val myOp: Op^{myIO} = ???
32-
run(myIO)(myOp)
32+
run(myIO)(myOp) // error: separation failure
3333
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
-- Error: tests/neg-custom-args/captures/filevar-expanded.scala:34:19 --------------------------------------------------
2+
34 | withFile(io3): f => // error: separation failure
3+
| ^
4+
| Separation failure: argument to capture-polymorphic parameter x$1: (f: test2.File^{io3}) => Unit
5+
| captures {io3} and also passes this capability separately to method withFile
6+
35 | val o = Service(io3)
7+
36 | o.file = f // this is a bit dubious. It's legal since we treat class refinements
8+
37 | // as capture set variables that can be made to include refs coming from outside.
9+
38 | o.log

tests/pos-custom-args/captures/filevar-expanded.scala renamed to tests/neg-custom-args/captures/filevar-expanded.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ object test2:
3131
op(new File)
3232

3333
def test(io3: IO^) =
34-
withFile(io3): f =>
34+
withFile(io3): f => // error: separation failure
3535
val o = Service(io3)
3636
o.file = f // this is a bit dubious. It's legal since we treat class refinements
3737
// as capture set variables that can be made to include refs coming from outside.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
-- Error: tests/neg-custom-args/captures/function-combinators.scala:15:22 ----------------------------------------------
2+
15 | val b2 = g1.andThen(g1); // error: separation failure
3+
| ^^
4+
| Separation failure: argument to capture-polymorphic parameter x$0: Int => Int
5+
| captures {ctx1} and also passes this capability separately to method andThen
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
class ContextClass
2+
type Context = ContextClass^
3+
import caps.unsafe.unsafeAssumePure
4+
5+
def Test(using ctx1: Context, ctx2: Context) =
6+
val f: Int => Int = identity
7+
val g1: Int ->{ctx1} Int = identity
8+
val g2: Int ->{ctx2} Int = identity
9+
val h: Int -> Int = identity
10+
val a1 = f.andThen(f); val _: Int ->{f} Int = a1
11+
val a2 = f.andThen(g1); val _: Int ->{f, g1} Int = a2
12+
val a3 = f.andThen(g2); val _: Int ->{f, g2} Int = a3
13+
val a4 = f.andThen(h); val _: Int ->{f} Int = a4
14+
val b1 = g1.andThen(f); val _: Int ->{f, g1} Int = b1
15+
val b2 = g1.andThen(g1); // error: separation failure
16+
val _: Int ->{g1} Int = b2
17+
val b3 = g1.andThen(g2); val _: Int ->{g1, g2} Int = b3
18+
val b4 = g1.andThen(h); val _: Int ->{g1} Int = b4
19+
val c1 = h.andThen(f); val _: Int ->{f} Int = c1
20+
val c2 = h.andThen(g1); val _: Int ->{g1} Int = c2
21+
val c3 = h.andThen(g2); val _: Int ->{g2} Int = c3
22+
val c4 = h.andThen(h); val _: Int -> Int = c4
23+
24+
val f2: (Int, Int) => Int = _ + _
25+
val f2c = f2.curried; val _: Int -> Int ->{f2} Int = f2c
26+
val f2t = f2.tupled; val _: ((Int, Int)) ->{f2} Int = f2t
27+
28+
val f3: (Int, Int, Int) => Int = ???
29+
val f3c = f3.curried; val _: Int -> Int -> Int ->{f3} Int = f3c
30+
val f3t = f3.tupled; val _: ((Int, Int, Int)) ->{f3} Int = f3t

0 commit comments

Comments
 (0)