Skip to content

Commit 6e3ca78

Browse files
committed
add new tests for batching and update old ones for the new canonicalized form
1 parent f6ac98c commit 6e3ca78

File tree

6 files changed

+219
-53
lines changed

6 files changed

+219
-53
lines changed

tests/codegen/autodiff.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ fn square(x: &f64) -> f64 {
1111
x * x
1212
}
1313

14-
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture align 8 %"x'"
14+
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
1515
// CHECK-NEXT:invertstart:
1616
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
1717
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
@@ -22,7 +22,7 @@ fn square(x: &f64) -> f64 {
2222
// CHECK-NEXT:}
2323

2424
fn main() {
25-
let x = 3.0;
25+
let x = std::hint::black_box(3.0);
2626
let output = square(&x);
2727
assert_eq!(9.0, output);
2828

tests/codegen/autodiffv.rs

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
5+
#![feature(autodiff)]
6+
7+
use std::autodiff::autodiff;
8+
9+
#[autodiff(d_square3, Forward, Dual, DualOnly)]
10+
#[autodiff(d_square2, Forward, 4, Dual, DualOnly)]
11+
#[autodiff(d_square1, Forward, 4, Dual, Dual)]
12+
#[no_mangle]
13+
fn square(x: &f32) -> f32 {
14+
x * x
15+
}
16+
17+
// d_sqaure2
18+
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
19+
// CHECK-NEXT: start:
20+
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
21+
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4, !alias.scope !38, !noalias !39
22+
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
23+
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4, !alias.scope !40, !noalias !41
24+
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
25+
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4, !alias.scope !42, !noalias !43
26+
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
27+
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4, !alias.scope !44, !noalias !45
28+
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
29+
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
30+
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
31+
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
32+
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
33+
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
34+
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
35+
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
36+
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
37+
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
38+
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
39+
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
40+
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
41+
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
42+
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
43+
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
44+
// CHECK-NEXT: ret [4 x float] %19
45+
// CHECK-NEXT: }
46+
47+
// d_square3, the extra float is the original return value (x * x)
48+
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
49+
// CHECK-NEXT: start:
50+
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
51+
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4, !alias.scope !46, !noalias !47
52+
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
53+
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4, !alias.scope !48, !noalias !49
54+
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
55+
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4, !alias.scope !50, !noalias !51
56+
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
57+
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4, !alias.scope !52, !noalias !53
58+
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
59+
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
60+
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
61+
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
62+
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
63+
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
64+
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
65+
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
66+
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
67+
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
68+
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
69+
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
70+
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
71+
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
72+
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
73+
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
74+
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
75+
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
76+
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
77+
// CHECK-NEXT: ret { float, [4 x float] } %21
78+
// CHECK-NEXT: }
79+
80+
fn main() {
81+
let x = std::hint::black_box(3.0);
82+
let output = square(&x);
83+
dbg!(&output);
84+
assert_eq!(9.0, output);
85+
dbg!(square(&x));
86+
87+
let mut df_dx1 = 1.0;
88+
let mut df_dx2 = 2.0;
89+
let mut df_dx3 = 3.0;
90+
let mut df_dx4 = 0.0;
91+
let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
92+
dbg!(o1, o2, o3, o4);
93+
let [output2, o1, o2, o3, o4] =
94+
d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
95+
dbg!(o1, o2, o3, o4);
96+
assert_eq!(output, output2);
97+
assert!((6.0 - o1).abs() < 1e-10);
98+
assert!((12.0 - o2).abs() < 1e-10);
99+
assert!((18.0 - o3).abs() < 1e-10);
100+
assert!((0.0 - o4).abs() < 1e-10);
101+
assert_eq!(1.0, df_dx1);
102+
assert_eq!(2.0, df_dx2);
103+
assert_eq!(3.0, df_dx3);
104+
assert_eq!(0.0, df_dx4);
105+
assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1);
106+
assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2);
107+
assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3);
108+
assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4);
109+
}

tests/pretty/autodiff_forward.pp

+79-21
Original file line numberDiff line numberDiff line change
@@ -25,48 +25,52 @@
2525

2626
// We want to be sure that the same function can be differentiated in different ways
2727

28+
29+
// Make sure, that we add the None for the default return.
30+
31+
2832
::core::panicking::panic("not implemented")
2933
}
30-
#[rustc_autodiff(Forward, Dual, Const, Dual,)]
34+
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
3135
#[inline(never)]
32-
pub fn df1(x: &[f64], bx: &[f64], y: f64) -> (f64, f64) {
36+
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
3337
unsafe { asm!("NOP", options(pure, nomem)); };
3438
::core::hint::black_box(f1(x, y));
35-
::core::hint::black_box((bx,));
36-
::core::hint::black_box((f1(x, y), f64::default()))
39+
::core::hint::black_box((bx_0,));
40+
::core::hint::black_box(<(f64, f64)>::default())
3741
}
3842
#[rustc_autodiff]
3943
#[inline(never)]
4044
pub fn f2(x: &[f64], y: f64) -> f64 {
4145
::core::panicking::panic("not implemented")
4246
}
43-
#[rustc_autodiff(Forward, Dual, Const, Const,)]
47+
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
4448
#[inline(never)]
45-
pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 {
49+
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
4650
unsafe { asm!("NOP", options(pure, nomem)); };
4751
::core::hint::black_box(f2(x, y));
48-
::core::hint::black_box((bx,));
52+
::core::hint::black_box((bx_0,));
4953
::core::hint::black_box(f2(x, y))
5054
}
5155
#[rustc_autodiff]
5256
#[inline(never)]
5357
pub fn f3(x: &[f64], y: f64) -> f64 {
5458
::core::panicking::panic("not implemented")
5559
}
56-
#[rustc_autodiff(Forward, Dual, Const, Const,)]
60+
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
5761
#[inline(never)]
58-
pub fn df3(x: &[f64], bx: &[f64], y: f64) -> f64 {
62+
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
5963
unsafe { asm!("NOP", options(pure, nomem)); };
6064
::core::hint::black_box(f3(x, y));
61-
::core::hint::black_box((bx,));
65+
::core::hint::black_box((bx_0,));
6266
::core::hint::black_box(f3(x, y))
6367
}
6468
#[rustc_autodiff]
6569
#[inline(never)]
6670
pub fn f4() {}
67-
#[rustc_autodiff(Forward, None)]
71+
#[rustc_autodiff(Forward, 1, None)]
6872
#[inline(never)]
69-
pub fn df4() {
73+
pub fn df4() -> () {
7074
unsafe { asm!("NOP", options(pure, nomem)); };
7175
::core::hint::black_box(f4());
7276
::core::hint::black_box(());
@@ -76,28 +80,82 @@
7680
pub fn f5(x: &[f64], y: f64) -> f64 {
7781
::core::panicking::panic("not implemented")
7882
}
79-
#[rustc_autodiff(Forward, Const, Dual, Const,)]
83+
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
8084
#[inline(never)]
81-
pub fn df5_y(x: &[f64], y: f64, by: f64) -> f64 {
85+
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
8286
unsafe { asm!("NOP", options(pure, nomem)); };
8387
::core::hint::black_box(f5(x, y));
84-
::core::hint::black_box((by,));
88+
::core::hint::black_box((by_0,));
8589
::core::hint::black_box(f5(x, y))
8690
}
87-
#[rustc_autodiff(Forward, Dual, Const, Const,)]
91+
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
8892
#[inline(never)]
89-
pub fn df5_x(x: &[f64], bx: &[f64], y: f64) -> f64 {
93+
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
9094
unsafe { asm!("NOP", options(pure, nomem)); };
9195
::core::hint::black_box(f5(x, y));
92-
::core::hint::black_box((bx,));
96+
::core::hint::black_box((bx_0,));
9397
::core::hint::black_box(f5(x, y))
9498
}
95-
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
99+
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
96100
#[inline(never)]
97-
pub fn df5_rev(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
101+
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
98102
unsafe { asm!("NOP", options(pure, nomem)); };
99103
::core::hint::black_box(f5(x, y));
100-
::core::hint::black_box((dx, dret));
104+
::core::hint::black_box((dx_0, dret));
101105
::core::hint::black_box(f5(x, y))
102106
}
107+
struct DoesNotImplDefault;
108+
#[rustc_autodiff]
109+
#[inline(never)]
110+
pub fn f6() -> DoesNotImplDefault {
111+
::core::panicking::panic("not implemented")
112+
}
113+
#[rustc_autodiff(Forward, 1, Const)]
114+
#[inline(never)]
115+
pub fn df6() -> DoesNotImplDefault {
116+
unsafe { asm!("NOP", options(pure, nomem)); };
117+
::core::hint::black_box(f6());
118+
::core::hint::black_box(());
119+
::core::hint::black_box(f6())
120+
}
121+
#[rustc_autodiff]
122+
#[inline(never)]
123+
pub fn f7(x: f32) -> () {}
124+
#[rustc_autodiff(Forward, 1, Const, None)]
125+
#[inline(never)]
126+
pub fn df7(x: f32) -> () {
127+
unsafe { asm!("NOP", options(pure, nomem)); };
128+
::core::hint::black_box(f7(x));
129+
::core::hint::black_box(());
130+
}
131+
#[no_mangle]
132+
#[rustc_autodiff]
133+
#[inline(never)]
134+
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
135+
#[rustc_autodiff(Forward, 4, Dual, Dual)]
136+
#[inline(never)]
137+
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
138+
-> [f32; 5usize] {
139+
unsafe { asm!("NOP", options(pure, nomem)); };
140+
::core::hint::black_box(f8(x));
141+
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
142+
::core::hint::black_box(<[f32; 5usize]>::default())
143+
}
144+
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
145+
#[inline(never)]
146+
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
147+
-> [f32; 4usize] {
148+
unsafe { asm!("NOP", options(pure, nomem)); };
149+
::core::hint::black_box(f8(x));
150+
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
151+
::core::hint::black_box(<[f32; 4usize]>::default())
152+
}
153+
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
154+
#[inline(never)]
155+
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
156+
unsafe { asm!("NOP", options(pure, nomem)); };
157+
::core::hint::black_box(f8(x));
158+
::core::hint::black_box((bx_0,));
159+
::core::hint::black_box(<f32>::default())
160+
}
103161
fn main() {}

tests/pretty/autodiff_forward.rs

+18
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,22 @@ pub fn f5(x: &[f64], y: f64) -> f64 {
3636
unimplemented!()
3737
}
3838

39+
struct DoesNotImplDefault;
40+
#[autodiff(df6, Forward, Const)]
41+
pub fn f6() -> DoesNotImplDefault {
42+
unimplemented!()
43+
}
44+
45+
// Make sure, that we add the None for the default return.
46+
#[autodiff(df7, Forward, Const)]
47+
pub fn f7(x: f32) -> () {}
48+
49+
#[autodiff(f8_1, Forward, Dual, DualOnly)]
50+
#[autodiff(f8_2, Forward, 4, Dual, DualOnly)]
51+
#[autodiff(f8_3, Forward, 4, Dual, Dual)]
52+
#[no_mangle]
53+
fn f8(x: &f32) -> f32 {
54+
unimplemented!()
55+
}
56+
3957
fn main() {}

tests/pretty/autodiff_reverse.pp

+11-11
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,18 @@
2828
2929
::core::panicking::panic("not implemented")
3030
}
31-
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
31+
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
3232
#[inline(never)]
33-
pub fn df1(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
33+
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
3434
unsafe { asm!("NOP", options(pure, nomem)); };
3535
::core::hint::black_box(f1(x, y));
36-
::core::hint::black_box((dx, dret));
36+
::core::hint::black_box((dx_0, dret));
3737
::core::hint::black_box(f1(x, y))
3838
}
3939
#[rustc_autodiff]
4040
#[inline(never)]
4141
pub fn f2() {}
42-
#[rustc_autodiff(Reverse, None)]
42+
#[rustc_autodiff(Reverse, 1, None)]
4343
#[inline(never)]
4444
pub fn df2() {
4545
unsafe { asm!("NOP", options(pure, nomem)); };
@@ -51,20 +51,20 @@
5151
pub fn f3(x: &[f64], y: f64) -> f64 {
5252
::core::panicking::panic("not implemented")
5353
}
54-
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
54+
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
5555
#[inline(never)]
56-
pub fn df3(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
56+
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
5757
unsafe { asm!("NOP", options(pure, nomem)); };
5858
::core::hint::black_box(f3(x, y));
59-
::core::hint::black_box((dx, dret));
59+
::core::hint::black_box((dx_0, dret));
6060
::core::hint::black_box(f3(x, y))
6161
}
6262
enum Foo { Reverse, }
6363
use Foo::Reverse;
6464
#[rustc_autodiff]
6565
#[inline(never)]
6666
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
67-
#[rustc_autodiff(Reverse, Const, None)]
67+
#[rustc_autodiff(Reverse, 1, Const, None)]
6868
#[inline(never)]
6969
pub fn df4(x: f32) {
7070
unsafe { asm!("NOP", options(pure, nomem)); };
@@ -76,11 +76,11 @@
7676
pub fn f5(x: *const f32, y: &f32) {
7777
::core::panicking::panic("not implemented")
7878
}
79-
#[rustc_autodiff(Reverse, DuplicatedOnly, Duplicated, None)]
79+
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
8080
#[inline(never)]
81-
pub unsafe fn df5(x: *const f32, dx: *mut f32, y: &f32, dy: &mut f32) {
81+
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
8282
unsafe { asm!("NOP", options(pure, nomem)); };
8383
::core::hint::black_box(f5(x, y));
84-
::core::hint::black_box((dx, dy));
84+
::core::hint::black_box((dx_0, dy_0));
8585
}
8686
fn main() {}

0 commit comments

Comments
 (0)