Skip to content

Commit 722b3d0

Browse files
committed
batching mvp
1 parent aa8f0fd commit 722b3d0

File tree

17 files changed

+569
-183
lines changed

17 files changed

+569
-183
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ pub struct AutoDiffAttrs {
7777
/// e.g. in the [JAX
7878
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
7979
pub mode: DiffMode,
80+
pub width: u32,
8081
pub ret_activity: DiffActivity,
8182
pub input_activity: Vec<DiffActivity>,
8283
}
@@ -222,13 +223,15 @@ impl AutoDiffAttrs {
222223
pub const fn error() -> Self {
223224
AutoDiffAttrs {
224225
mode: DiffMode::Error,
226+
width: 0,
225227
ret_activity: DiffActivity::None,
226228
input_activity: Vec::new(),
227229
}
228230
}
229231
pub fn source() -> Self {
230232
AutoDiffAttrs {
231233
mode: DiffMode::Source,
234+
width: 0,
232235
ret_activity: DiffActivity::None,
233236
input_activity: Vec::new(),
234237
}

compiler/rustc_builtin_macros/messages.ftl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ builtin_macros_autodiff_not_build = this rustc version does not support autodiff
7777
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
7878
builtin_macros_autodiff_ret_activity = invalid return activity {$act} in {$mode} Mode
7979
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
80+
builtin_macros_autodiff_width = autodiff width must fit u32, but is {$width}
8081
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
8182
8283
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 206 additions & 126 deletions
Large diffs are not rendered by default.

compiler/rustc_builtin_macros/src/errors.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,14 @@ mod autodiff {
202202
pub(crate) mode: String,
203203
}
204204

205+
// Diagnostic for the `builtin_macros_autodiff_width` message ("autodiff width must fit u32").
// `span` is the primary span of the error; `width` is the rejected value, carried as a `u128`
// so that it can be shown even though it does not fit into `u32`.
#[derive(Diagnostic)]
206+
#[diag(builtin_macros_autodiff_width)]
207+
pub(crate) struct AutoDiffInvalidWidth {
208+
#[primary_span]
209+
pub(crate) span: Span,
210+
pub(crate) width: u128,
211+
}
212+
205213
#[derive(Diagnostic)]
206214
#[diag(builtin_macros_autodiff)]
207215
pub(crate) struct AutoDiffInvalidApplication {

compiler/rustc_codegen_llvm/messages.ftl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
codegen_llvm_autodiff_unused_args = implementation bug, failed to match all args on llvm level
12
codegen_llvm_autodiff_without_enable = using the autodiff feature requires -Z autodiff=Enable
23
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
34

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,10 @@ pub(crate) fn run_pass_manager(
655655
unsafe {
656656
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
657657
}
658+
// This is the final IR, so people should be able to inspect the optimized autodiff output.
659+
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
660+
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
661+
}
658662

659663
if cfg!(llvm_enzyme) && enable_ad {
660664
let opt_stage = llvm::OptStage::FatLTO;

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 120 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
use std::ptr;
2-
32
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
43
use rustc_codegen_ssa::ModuleCodegen;
54
use rustc_codegen_ssa::back::write::ModuleConfig;
6-
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods as _;
7-
use rustc_errors::FatalError;
5+
use rustc_errors::{DiagCtxt, FatalError};
6+
use rustc_middle::bug;
87
use tracing::{debug, trace};
98

109
use crate::back::write::llvm_err;
1110
use crate::builder::SBuilder;
1211
use crate::context::SimpleCx;
1312
use crate::declare::declare_simple_fn;
14-
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
13+
use crate::errors::{AutoDiffUnusedArgs, AutoDiffWithoutEnable, LlvmError};
1514
use crate::llvm::AttributePlace::Function;
1615
use crate::llvm::{Metadata, True};
1716
use crate::value::Value;
17+
1818
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
1919

2020
fn get_params(fnc: &Value) -> Vec<&Value> {
@@ -28,6 +28,25 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
2828
}
2929
}
3030

31+
// Reports whether `fnc` carries the `StructRet` attribute at index 0.
// A function without any parameters is reported as `false` without querying the attribute.
fn has_sret(fnc: &Value) -> bool {
    let param_count = unsafe { llvm::LLVMCountParams(fnc) as usize };
    param_count != 0
        && unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
}
39+
40+
// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
41+
// original inputs, as well as metadata and the additional shadow arguments.
42+
// This function matches the arguments from the outer function to the inner enzyme call.
43+
//
44+
// This function also considers that Rust level arguments do not always match the llvm-ir level
45+
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
46+
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
47+
// need to match those.
48+
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
49+
// using iterators and peek()?
3150
fn match_args_from_caller_to_enzyme<'ll>(
3251
cx: &SimpleCx<'ll>,
3352
args: &mut Vec<&'ll llvm::Value>,
@@ -135,6 +154,78 @@ fn match_args_from_caller_to_enzyme<'ll>(
135154
}
136155
}
137156

157+
158+
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
159+
// arguments. We do however need to declare them with their correct return type.
160+
// We already figured the correct return type out in our frontend, when generating the outer_fn,
161+
// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
162+
// Beyond sret, this article describes our challenges nicely:
163+
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
164+
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
165+
fn compute_enzyme_fn_ty<'ll>(
166+
cx: &SimpleCx<'ll>,
167+
attrs: &AutoDiffAttrs,
168+
fn_to_diff: &'ll Value,
169+
outer_fn: &'ll Value,
170+
) -> &'ll llvm::Type {
171+
let fn_ty = cx.get_type_of_global(outer_fn);
172+
let mut ret_ty = cx.get_return_type(fn_ty);
173+
174+
let has_sret = has_sret(outer_fn);
175+
176+
if has_sret {
177+
// Now we don't just forward the return type, so we have to figure it out based on the
178+
// primal return type, in combination with the autodiff settings.
179+
let fn_ty = cx.get_type_of_global(fn_to_diff);
180+
let inner_ret_ty = cx.get_return_type(fn_ty);
181+
182+
let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
183+
if inner_ret_ty == void_ty {
184+
dbg!(&fn_to_diff);
185+
// This indicates that even the inner function has an sret.
186+
// Right now I only look for an sret in the outer function.
187+
// This *probably* needs some extra handling, but I never ran
188+
// into such a case. So I'll wait for user reports to have a test case.
189+
bug!("sret in inner function");
190+
}
191+
192+
if attrs.width == 1 {
193+
todo!("Handle sret for scalar ad");
194+
} else {
195+
// First we check if we also have to deal with the primal return.
196+
if attrs.mode.is_fwd() {
197+
match attrs.ret_activity {
198+
DiffActivity::Dual => {
199+
let arr_ty =
200+
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
201+
ret_ty = arr_ty;
202+
}
203+
DiffActivity::DualOnly => {
204+
let arr_ty =
205+
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
206+
ret_ty = arr_ty;
207+
}
208+
DiffActivity::Const => {
209+
todo!("Not sure, do we need to do something here?");
210+
}
211+
_ => {
212+
bug!("unreachable");
213+
}
214+
}
215+
} else if attrs.mode.is_rev() {
216+
todo!("Handle sret for reverse mode");
217+
} else {
218+
bug!("unreachable");
219+
}
220+
}
221+
}
222+
223+
dbg!(&outer_fn);
224+
225+
// LLVM can figure out the input types on it's own, so we take a shortcut here.
226+
unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
227+
}
228+
138229
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
139230
/// function with expected naming and calling conventions[^1] which will be
140231
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -145,6 +236,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
145236
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
146237
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
147238
fn generate_enzyme_call<'ll>(
239+
_dcx: &DiagCtxt,
148240
cx: &SimpleCx<'ll>,
149241
fn_to_diff: &'ll Value,
150242
outer_fn: &'ll Value,
@@ -197,17 +289,9 @@ fn generate_enzyme_call<'ll>(
197289
// }
198290
// ```
199291
unsafe {
200-
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
201-
// arguments. We do however need to declare them with their correct return type.
202-
// We already figured the correct return type out in our frontend, when generating the outer_fn,
203-
// so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
204-
let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn);
205-
let ret_ty = llvm::LLVMGetReturnType(fn_ty);
292+
let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
206293

207-
// LLVM can figure out the input types on it's own, so we take a shortcut here.
208-
let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True);
209-
210-
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
294+
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
211295
// think a bit more about what should go here.
212296
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
213297
let ad_fn = declare_simple_fn(
@@ -268,12 +352,31 @@ fn generate_enzyme_call<'ll>(
268352
// Now that we copied the metadata, get rid of dummy code.
269353
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
270354

271-
if cx.val_ty(call) == cx.type_void() {
355+
let has_sret = has_sret(outer_fn);
356+
357+
if cx.val_ty(call) == cx.type_void() || has_sret {
358+
if has_sret {
359+
// This is what we already have in our outer_fn (shortened):
360+
// define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
361+
// %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
362+
// <Here we are, we want to add the following two lines>
363+
// store [4 x double] %7, ptr %0, align 8
364+
// ret void
365+
// }
366+
367+
// now store the result of the enzyme call into the sret pointer.
368+
let sret_ptr = outer_args[0];
369+
let call_ty = cx.val_ty(call);
370+
assert!(llvm::LLVMRustIsArrayTy(call_ty));
371+
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
372+
}
272373
builder.ret_void();
273374
} else {
274375
builder.ret(call);
275376
}
276377

378+
dbg!(&outer_fn);
379+
277380
// Let's crash in case that we messed something up above and generated invalid IR.
278381
llvm::LLVMRustVerifyFunction(
279382
outer_fn,
@@ -300,8 +403,7 @@ pub(crate) fn differentiate<'ll>(
300403
if !diff_items.is_empty()
301404
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
302405
{
303-
let dcx = cgcx.create_dcx();
304-
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
406+
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
305407
}
306408

307409
// Before dumping the module, we want all the TypeTrees to become part of the module.
@@ -331,7 +433,7 @@ pub(crate) fn differentiate<'ll>(
331433
));
332434
};
333435

334-
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
436+
generate_enzyme_call(&diag_handler, &cx, fn_def, fn_target, item.attrs.clone());
335437
}
336438

337439
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts

compiler/rustc_codegen_llvm/src/consts.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ impl<'ll> CodegenCx<'ll, '_> {
425425
let val_llty = self.val_ty(v);
426426

427427
let g = self.get_static_inner(def_id, val_llty);
428-
let llty = llvm::LLVMGlobalGetValueType(g);
428+
let llty = self.get_type_of_global(g);
429429

430430
let g = if val_llty == llty {
431431
g

compiler/rustc_codegen_llvm/src/context.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use crate::debuginfo::metadata::apply_vcall_visibility_metadata;
3838
use crate::llvm::Metadata;
3939
use crate::type_::Type;
4040
use crate::value::Value;
41-
use crate::{attributes, coverageinfo, debuginfo, llvm, llvm_util};
41+
use crate::{attributes, common, coverageinfo, debuginfo, llvm, llvm_util};
4242

4343
/// `TyCtxt` (and related cache datastructures) can't be moved between threads.
4444
/// However, there are various cx related functions which we want to be available to the builder and
@@ -643,7 +643,18 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
643643
llvm::set_section(g, c"llvm.metadata");
644644
}
645645
}
646-
646+
impl<'ll> SimpleCx<'ll> {
647+
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
648+
assert!(unsafe { llvm::LLVMRustIsFunctionTy(ty) });
649+
unsafe { llvm::LLVMGetReturnType(ty) }
650+
}
651+
pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type {
652+
unsafe { llvm::LLVMGlobalGetValueType(val) }
653+
}
654+
pub(crate) fn val_ty(&self, v: &'ll Value) -> &'ll Type {
655+
common::val_ty(v)
656+
}
657+
}
647658
impl<'ll> SimpleCx<'ll> {
648659
pub(crate) fn new(
649660
llmod: &'ll llvm::Module,
@@ -660,6 +671,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
660671
llvm::LLVMMetadataAsValue(self.llcx(), metadata)
661672
}
662673

674+
/// Creates a constant of the context's 64-bit integer type holding `n`.
/// `llvm::False` is passed as the last argument of `LLVMConstInt`.
pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
    unsafe {
        let i64_ty = llvm::LLVMInt64TypeInContext(self.llcx());
        llvm::LLVMConstInt(i64_ty, n, llvm::False)
    }
}
678+
663679
pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
664680
let name = SmallCStr::new(name);
665681
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }

compiler/rustc_codegen_llvm/src/errors.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {
9494
#[diag(codegen_llvm_autodiff_without_lto)]
9595
pub(crate) struct AutoDiffWithoutLTO;
9696

97+
// Diagnostic for the `codegen_llvm_autodiff_unused_args` message, which reports an
// implementation bug: not all arguments could be matched on the LLVM level.
#[derive(Diagnostic)]
98+
#[diag(codegen_llvm_autodiff_unused_args)]
99+
pub(crate) struct AutoDiffUnusedArgs;
100+
97101
#[derive(Diagnostic)]
98102
#[diag(codegen_llvm_autodiff_without_enable)]
99103
pub(crate) struct AutoDiffWithoutEnable;

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
use libc::{c_char, c_uint};
55

66
use super::MetadataKindId;
7-
use super::ffi::{BasicBlock, Metadata, Module, Type, Value};
7+
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
88
use crate::llvm::Bool;
99

1010
#[link(name = "llvm-wrapper", kind = "static")]
@@ -17,6 +17,11 @@ unsafe extern "C" {
1717
pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
1818
pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
1919
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
20+
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
21+
22+
pub(crate) fn LLVMRustIsFunctionTy(Ty: &Type) -> bool;
23+
pub(crate) fn LLVMRustIsArrayTy(Ty: &Type) -> bool;
24+
pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
2025
}
2126

2227
unsafe extern "C" {

0 commit comments

Comments
 (0)