Description
I tried this code:
```rust
#![feature(autodiff)]
use std::autodiff::autodiff;

// f(x) = x*x, f'(x) = 2.0 * x
#[autodiff(d_square, Reverse, Duplicated, Active)]
fn square(x: &f64) -> f64 {
    x * x
}

// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture align 8 %"x'"
// CHECK-NEXT:invertstart:
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
// CHECK-NEXT: %1 = load double, ptr %"x'", align 8
// CHECK-NEXT: %2 = fadd fast double %1, %0
// CHECK-NEXT: store double %2, ptr %"x'", align 8
// CHECK-NEXT: ret double %_0
// CHECK-NEXT:}

fn main() {
    let x = 3.0;
    let output = square(&x);
    assert_eq!(9.0, output);
    let mut df_dx = 0.0;
    let output_ = d_square(&x, &mut df_dx, 1.0);
    assert_eq!(output, output_);
    assert_eq!(6.0, df_dx);
}
```
I expected to see this happen: the code above works when pasted into main.rs or used as a codegen test.
Instead, this happened: if you copy `fn square` and its autodiff macro into lib.rs (or into a dependency), it no longer works. `d_square` should add 2.0 * x (thus 6.0) to the variable `df_dx`, but instead it does nothing.
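For reference, the behavior `d_square` is expected to have can be written out by hand. The sketch below is a manual reverse-mode adjoint of `square`, not the Enzyme-generated code, and the name `d_square_manual` is hypothetical. It accumulates f'(x) * seed = 2x * seed into the shadow argument; with a seed of 1.0 this equals the `x + x` that the CHECK lines add into the shadow.

```rust
/// Primal function from the report: f(x) = x * x.
fn square(x: &f64) -> f64 {
    x * x
}

/// Hypothetical hand-written reverse-mode adjoint of `square`, standing in
/// for what the generated `d_square` is expected to do.
/// `dx` is the shadow of `x` (the `Duplicated` argument) and `seed` is the
/// incoming adjoint of the `Active` return value (the trailing `1.0` in the report).
fn d_square_manual(x: &f64, dx: &mut f64, seed: f64) -> f64 {
    let out = square(x);
    // Reverse mode accumulates (+=) into the shadow instead of overwriting it:
    // the contribution is f'(x) * seed = 2x * seed.
    *dx += 2.0 * *x * seed;
    out
}
```

Because the shadow is accumulated rather than overwritten, starting `df_dx` at 0.0 (as in the report) yields exactly 6.0 for x = 3.0, while starting it at 1.0 would yield 7.0.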
Meta
rustc --version --verbose: self-built nightly
This was previously also reported here: EnzymeAD#173
The issue is that, in order to support dependencies, we need to handle encoding and decoding of our rustc_autodiff attribute correctly.
A first step towards that is this change: https://github.com/rust-lang/rust/pull/136428/files#diff-09c366d3ad3ec9a42125253b610ca83cad6b156aa2a723f6c7e83eddef7b1e8f, where I marked it as EncodeCrossCrate::Yes. There are likely some more changes needed in https://github.com/rust-lang/rust/tree/master/compiler/rustc_metadata/src. There is some documentation about it in the dev guide: https://rustc-dev-guide.rust-lang.org/serialization.html
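To make the encoding/decoding requirement concrete, here is a toy model: the attribute's payload (mode, return activity, input activities) has to be written into the upstream crate's metadata and read back unchanged by the downstream crate, otherwise the downstream crate has nothing to differentiate. All types, names and the byte layout below are hypothetical; this is not rustc's actual autodiff attribute type or its `rustc_serialize`-based encoders, only an illustration of the round-trip property those need to satisfy.

```rust
// Toy model only: hypothetical stand-ins, not rustc's real types or encoders.

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode {
    Forward,
    Reverse,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Activity {
    Const,
    Active,
    Duplicated,
}

/// Hypothetical payload of `#[autodiff(d_square, Reverse, Duplicated, Active)]`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AutoDiffAttr {
    mode: Mode,
    ret: Activity,
    inputs: Vec<Activity>,
}

fn encode_activity(a: Activity) -> u8 {
    match a {
        Activity::Const => 0,
        Activity::Active => 1,
        Activity::Duplicated => 2,
    }
}

fn decode_activity(b: u8) -> Option<Activity> {
    match b {
        0 => Some(Activity::Const),
        1 => Some(Activity::Active),
        2 => Some(Activity::Duplicated),
        _ => None,
    }
}

impl AutoDiffAttr {
    /// Upstream crate: append the attribute to its (toy) metadata blob.
    /// Layout: [mode, ret, input count, inputs...]; the count is a single
    /// byte, so this toy format only supports up to 255 inputs.
    fn encode(&self, out: &mut Vec<u8>) {
        out.push(match self.mode {
            Mode::Forward => 0,
            Mode::Reverse => 1,
        });
        out.push(encode_activity(self.ret));
        out.push(self.inputs.len() as u8);
        out.extend(self.inputs.iter().map(|&a| encode_activity(a)));
    }

    /// Downstream crate: read the attribute back; `None` on malformed input.
    fn decode(bytes: &[u8]) -> Option<AutoDiffAttr> {
        let (&mode, rest) = bytes.split_first()?;
        let mode = match mode {
            0 => Mode::Forward,
            1 => Mode::Reverse,
            _ => return None,
        };
        let (&ret, rest) = rest.split_first()?;
        let ret = decode_activity(ret)?;
        let (&n, rest) = rest.split_first()?;
        let n = n as usize;
        if rest.len() < n {
            return None;
        }
        let inputs = rest[..n]
            .iter()
            .map(|&b| decode_activity(b))
            .collect::<Option<Vec<_>>>()?;
        Some(AutoDiffAttr { mode, ret, inputs })
    }
}
```

The property that matters is that `decode(encode(attr)) == attr`; if the upstream side never writes the attribute, the downstream side decodes nothing and the derivative is silently not generated, which matches the no-op symptom above.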
I also went through some older related PRs, such as #96473, to figure out what I'd need to change to encode autodiff, but I haven't figured it out yet.
Relatedly, here are my changes to codegen_fn_attrs: https://github.com/rust-lang/rust/pull/133429/files#diff-1fc881800e5081e40249ec7bab8da6589b63fe074bfa2a03318b086ad84e6bc4 and https://github.com/rust-lang/rust/pull/133429/files#diff-a313c94149bd8e03a482dbf15c3d1b9414ca74c4f2e35f94253649c2be1ceb5c
I think Oli was hoping at some point that adding autodiff under codegen_fn_attrs could help with this issue, but I guess we need to investigate further. So far I have mostly added a lot of dbg! statements to figure out what needs to change.
cc @vayunbiyani
Tracking: