Skip to content

Commit a480ab6

Browse files
committed
Allow specializing on const trait bounds
1 parent e1c28e0 commit a480ab6

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

compiler/rustc_typeck/src/impl_wf_check/min_specialization.rs

+4-7
Original file line numberDiff line numberDiff line change
@@ -423,13 +423,10 @@ fn trait_predicate_kind<'tcx>(
423423
predicate: ty::Predicate<'tcx>,
424424
) -> Option<TraitSpecializationKind> {
425425
match predicate.kind().skip_binder() {
426-
ty::PredicateKind::Trait(ty::TraitPredicate {
427-
trait_ref,
428-
constness: ty::BoundConstness::NotConst,
429-
polarity: _,
430-
}) => Some(tcx.trait_def(trait_ref.def_id).specialization_kind),
431-
ty::PredicateKind::Trait(_)
432-
| ty::PredicateKind::RegionOutlives(_)
426+
ty::PredicateKind::Trait(ty::TraitPredicate { trait_ref, constness: _, polarity: _ }) => {
427+
Some(tcx.trait_def(trait_ref.def_id).specialization_kind)
428+
}
429+
ty::PredicateKind::RegionOutlives(_)
433430
| ty::PredicateKind::TypeOutlives(_)
434431
| ty::PredicateKind::Projection(_)
435432
| ty::PredicateKind::WellFormed(_)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// check-pass
2+
#![feature(const_trait_impl, min_specialization, rustc_attrs)]
3+
4+
#[rustc_specialization_trait]
5+
#[const_trait]
6+
pub unsafe trait Sup {
7+
fn foo() -> u32;
8+
}
9+
10+
#[rustc_specialization_trait]
11+
#[const_trait]
12+
pub unsafe trait Sub: ~const Sup {}
13+
14+
unsafe impl const Sup for u8 {
15+
default fn foo() -> u32 {
16+
1
17+
}
18+
}
19+
20+
unsafe impl const Sup for () {
21+
fn foo() -> u32 {
22+
42
23+
}
24+
}
25+
26+
unsafe impl const Sub for () {}
27+
28+
#[const_trait]
29+
pub trait A {
30+
fn a() -> u32;
31+
}
32+
33+
impl<T: ~const Default> const A for T {
34+
default fn a() -> u32 {
35+
2
36+
}
37+
}
38+
39+
impl<T: ~const Default + ~const Sup> const A for T {
40+
default fn a() -> u32 {
41+
3
42+
}
43+
}
44+
45+
impl<T: ~const Default + ~const Sub> const A for T {
46+
fn a() -> u32 {
47+
T::foo()
48+
}
49+
}
50+
51+
const _: () = assert!(<()>::a() == 42);
52+
const _: () = assert!(<u8>::a() == 3);
53+
const _: () = assert!(<u16>::a() == 2);
54+
55+
fn main() {}

0 commit comments

Comments
 (0)