Skip to content

Commit f6ac98c

Browse files
committed
add batching backend
1 parent ac01a05 commit f6ac98c

File tree

5 files changed

+198
-32
lines changed

5 files changed

+198
-32
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+156-28
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::ModuleCodegen;
55
use rustc_codegen_ssa::back::write::ModuleConfig;
6-
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods as _;
76
use rustc_errors::FatalError;
7+
use rustc_middle::bug;
88
use tracing::{debug, trace};
99

1010
use crate::back::write::llvm_err;
@@ -28,11 +28,32 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
2828
}
2929
}
3030

31+
/// Reports whether the LLVM function `fnc` takes a struct-return (sret) pointer.
///
/// Returns `false` when `fnc` has no parameters at all; otherwise returns whether the
/// `StructRet` attribute is present at index 0.
fn has_sret(fnc: &Value) -> bool {
32+
let num_args = unsafe { llvm::LLVMCountParams(fnc) as usize };
33+
if num_args == 0 {
34+
false
35+
} else {
36+
// Only index 0 is inspected: callers treat the first outer argument as the sret pointer.
unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
37+
}
38+
}
39+
40+
// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
41+
// original inputs, as well as metadata and the additional shadow arguments.
42+
// This function matches the arguments from the outer function to the inner enzyme call.
43+
//
44+
// This function also accounts for the fact that Rust level arguments do not always match the llvm-ir level
45+
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
46+
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
47+
// need to match those.
48+
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
49+
// using iterators and peek()?
3150
fn match_args_from_caller_to_enzyme<'ll>(
3251
cx: &SimpleCx<'ll>,
52+
width: u32,
3353
args: &mut Vec<&'ll llvm::Value>,
3454
inputs: &[DiffActivity],
3555
outer_args: &[&'ll llvm::Value],
56+
has_sret: bool,
3657
) {
3758
debug!("matching autodiff arguments");
3859
// We now handle the issue that Rust level arguments do not always match the llvm-ir level
@@ -44,6 +65,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
4465
let mut outer_pos: usize = 0;
4566
let mut activity_pos = 0;
4667

68+
if has_sret {
69+
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
70+
// inner function will still return something. We increase our outer_pos by one,
71+
// and once we're done with all other args we will take the return of the inner call and
72+
// update the sret pointer with it
73+
outer_pos = 1;
74+
}
75+
4776
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
4877
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
4978
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
@@ -95,20 +124,25 @@ fn match_args_from_caller_to_enzyme<'ll>(
95124
assert!(unsafe {
96125
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer
97126
});
98-
let next_outer_arg2 = outer_args[outer_pos + 2];
99-
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
100-
assert!(unsafe {
101-
llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer
102-
});
103-
let next_outer_arg3 = outer_args[outer_pos + 3];
104-
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
105-
assert!(unsafe {
106-
llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer
107-
});
108-
args.push(next_outer_arg2);
127+
128+
// A duplicated slice occupies `2 + 2 * width` outer arguments: the primal (ptr, len) pair
// followed by one shadow (ptr, len) pair per batch lane. Index by the lane number so that
// every lane forwards its own shadow pointer instead of re-reading the first shadow pair.
for i in 0..(width as usize) {
129+
// Shadow pointer of lane `i`.
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
130+
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
131+
assert!(
132+
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty2) }
133+
== llvm::TypeKind::Pointer
134+
);
135+
// Shadow length of lane `i`; only type-checked here, it is not forwarded.
let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
136+
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
137+
assert!(
138+
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty3) }
139+
== llvm::TypeKind::Integer
140+
);
141+
args.push(next_outer_arg2);
142+
}
109143
args.push(cx.get_metadata_value(enzyme_const));
110144
args.push(next_outer_arg);
111-
outer_pos += 4;
145+
outer_pos += 2 + 2 * width as usize;
112146
activity_pos += 2;
113147
} else {
114148
// A duplicated pointer will have the following two outer_fn arguments:
@@ -125,6 +159,13 @@ fn match_args_from_caller_to_enzyme<'ll>(
125159
args.push(next_outer_arg);
126160
outer_pos += 2;
127161
activity_pos += 1;
162+
163+
// Now, if width > 1, we need to account for that
164+
for _ in 1..width {
165+
let next_outer_arg = outer_args[outer_pos];
166+
args.push(next_outer_arg);
167+
outer_pos += 1;
168+
}
128169
}
129170
} else {
130171
// We do not differentiate with respect to this argument.
@@ -135,6 +176,74 @@ fn match_args_from_caller_to_enzyme<'ll>(
135176
}
136177
}
137178

179+
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
180+
// arguments. We do however need to declare them with their correct return type.
181+
// We already figured the correct return type out in our frontend, when generating the outer_fn,
182+
// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
183+
// Beyond sret, this article describes our challenges nicely:
184+
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
185+
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
186+
// Builds the LLVM function type used to declare the `__enzyme_*` function for one autodiff
// request.
//
// * `cx`: context used to query and create LLVM types.
// * `attrs`: autodiff settings; `width`, `mode` and `ret_activity` are read here.
// * `fn_to_diff`: the function being differentiated. Its return type becomes the array
//   element type when `outer_fn` returns through an sret pointer.
// * `outer_fn`: the outer function. Its return type is forwarded unchanged unless its first
//   parameter is an sret pointer.
//
// The returned function type declares no fixed parameters and is variadic.
//
// Panics (`todo!`/`bug!`) for sret combinations that are not handled yet: an inner function
// returning void, `width == 1`, reverse mode, and a `Const` return activity.
fn compute_enzyme_fn_ty<'ll>(
187+
cx: &SimpleCx<'ll>,
188+
attrs: &AutoDiffAttrs,
189+
fn_to_diff: &'ll Value,
190+
outer_fn: &'ll Value,
191+
) -> &'ll llvm::Type {
192+
// Default: reuse the return type of the outer function as-is.
let fn_ty = cx.get_type_of_global(outer_fn);
193+
let mut ret_ty = cx.get_return_type(fn_ty);
194+
195+
let has_sret = has_sret(outer_fn);
196+
197+
if has_sret {
198+
// Now we don't just forward the return type, so we have to figure it out based on the
199+
// primal return type, in combination with the autodiff settings.
200+
let fn_ty = cx.get_type_of_global(fn_to_diff);
201+
let inner_ret_ty = cx.get_return_type(fn_ty);
202+
203+
let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
204+
if inner_ret_ty == void_ty {
205+
// This indicates that even the inner function has an sret.
206+
// Right now I only look for an sret in the outer function.
207+
// This *probably* needs some extra handling, but I never ran
208+
// into such a case. So I'll wait for user reports to have a test case.
209+
bug!("sret in inner function");
210+
}
211+
212+
if attrs.width == 1 {
213+
todo!("Handle sret for scalar ad");
214+
} else {
215+
// First we check if we also have to deal with the primal return.
216+
if attrs.mode.is_fwd() {
217+
match attrs.ret_activity {
218+
DiffActivity::Dual => {
219+
// `width + 1` elements of the inner return type.
let arr_ty =
220+
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
221+
ret_ty = arr_ty;
222+
}
223+
DiffActivity::DualOnly => {
224+
// Exactly `width` elements of the inner return type.
let arr_ty =
225+
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
226+
ret_ty = arr_ty;
227+
}
228+
DiffActivity::Const => {
229+
todo!("Not sure, do we need to do something here?");
230+
}
231+
_ => {
232+
bug!("unreachable");
233+
}
234+
}
235+
} else if attrs.mode.is_rev() {
236+
todo!("Handle sret for reverse mode");
237+
} else {
238+
bug!("unreachable");
239+
}
240+
}
241+
}
242+
243+
// LLVM can figure out the input types on its own, so we take a shortcut here.
244+
unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
245+
}
246+
138247
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
139248
/// function with expected naming and calling conventions[^1] which will be
140249
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -197,17 +306,9 @@ fn generate_enzyme_call<'ll>(
197306
// }
198307
// ```
199308
unsafe {
200-
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
201-
// arguments. We do however need to declare them with their correct return type.
202-
// We already figured the correct return type out in our frontend, when generating the outer_fn,
203-
// so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
204-
let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn);
205-
let ret_ty = llvm::LLVMGetReturnType(fn_ty);
309+
let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
206310

207-
// LLVM can figure out the input types on it's own, so we take a shortcut here.
208-
let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True);
209-
210-
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
311+
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
211312
// think a bit more about what should go here.
212313
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
213314
let ad_fn = declare_simple_fn(
@@ -240,14 +341,27 @@ fn generate_enzyme_call<'ll>(
240341
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
241342
args.push(cx.get_metadata_value(enzyme_primal_ret));
242343
}
344+
if attrs.width > 1 {
345+
let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
346+
args.push(cx.get_metadata_value(enzyme_width));
347+
args.push(cx.get_const_i64(attrs.width as u64));
348+
}
243349

350+
let has_sret = has_sret(outer_fn);
244351
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
245-
match_args_from_caller_to_enzyme(&cx, &mut args, &attrs.input_activity, &outer_args);
352+
match_args_from_caller_to_enzyme(
353+
&cx,
354+
attrs.width,
355+
&mut args,
356+
&attrs.input_activity,
357+
&outer_args,
358+
has_sret,
359+
);
246360

247361
let call = builder.call(enzyme_ty, ad_fn, &args, None);
248362

249363
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
250-
// metadata attachted to it, but we just created this code oota. Given that the
364+
// metadata attached to it, but we just created this code oota. Given that the
251365
// differentiated function already has partly confusing metadata, and given that this
252366
// affects nothing but the autodiff IR, we take a shortcut and just steal metadata from the
253367
// dummy code which we inserted at a higher level.
@@ -268,7 +382,22 @@ fn generate_enzyme_call<'ll>(
268382
// Now that we copied the metadata, get rid of dummy code.
269383
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
270384

271-
if cx.val_ty(call) == cx.type_void() {
385+
if cx.val_ty(call) == cx.type_void() || has_sret {
386+
if has_sret {
387+
// This is what we already have in our outer_fn (shortened):
388+
// define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
389+
// %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
390+
// <Here we are, we want to add the following two lines>
391+
// store [4 x double] %7, ptr %0, align 8
392+
// ret void
393+
// }
394+
395+
// now store the result of the enzyme call into the sret pointer.
396+
let sret_ptr = outer_args[0];
397+
let call_ty = cx.val_ty(call);
398+
assert!(llvm::LLVMRustIsArrayTy(call_ty));
399+
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
400+
}
272401
builder.ret_void();
273402
} else {
274403
builder.ret(call);
@@ -300,8 +429,7 @@ pub(crate) fn differentiate<'ll>(
300429
if !diff_items.is_empty()
301430
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
302431
{
303-
let dcx = cgcx.create_dcx();
304-
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
432+
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
305433
}
306434

307435
// Before dumping the module, we want all the TypeTrees to become part of the module.

compiler/rustc_codegen_llvm/src/consts.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ impl<'ll> CodegenCx<'ll, '_> {
425425
let val_llty = self.val_ty(v);
426426

427427
let g = self.get_static_inner(def_id, val_llty);
428-
let llty = llvm::LLVMGlobalGetValueType(g);
428+
let llty = self.get_type_of_global(g);
429429

430430
let g = if val_llty == llty {
431431
g

compiler/rustc_codegen_llvm/src/context.rs

+18-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use crate::debuginfo::metadata::apply_vcall_visibility_metadata;
3838
use crate::llvm::Metadata;
3939
use crate::type_::Type;
4040
use crate::value::Value;
41-
use crate::{attributes, coverageinfo, debuginfo, llvm, llvm_util};
41+
use crate::{attributes, common, coverageinfo, debuginfo, llvm, llvm_util};
4242

4343
/// `TyCtxt` (and related cache datastructures) can't be moved between threads.
4444
/// However, there are various cx related functions which we want to be available to the builder and
@@ -643,7 +643,18 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
643643
llvm::set_section(g, c"llvm.metadata");
644644
}
645645
}
646-
646+
// Small type-query helpers on `SimpleCx`.
impl<'ll> SimpleCx<'ll> {
647+
/// Returns the return type of the function type `ty`.
///
/// Panics if `ty` is not a function type.
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
648+
assert!(unsafe { llvm::LLVMRustIsFunctionTy(ty) });
649+
unsafe { llvm::LLVMGetReturnType(ty) }
650+
}
651+
/// Returns the value type of the global `val`, as reported by `LLVMGlobalGetValueType`.
pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type {
652+
unsafe { llvm::LLVMGlobalGetValueType(val) }
653+
}
654+
/// Returns the LLVM type of the value `v` by forwarding to `common::val_ty`.
pub(crate) fn val_ty(&self, v: &'ll Value) -> &'ll Type {
655+
common::val_ty(v)
656+
}
657+
}
647658
impl<'ll> SimpleCx<'ll> {
648659
pub(crate) fn new(
649660
llmod: &'ll llvm::Module,
@@ -660,6 +671,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
660671
llvm::LLVMMetadataAsValue(self.llcx(), metadata)
661672
}
662673

674+
/// Creates a 64-bit integer constant with value `n` in this context's LLVM context.
pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
675+
let ty = unsafe { llvm::LLVMInt64TypeInContext(self.llcx()) };
676+
// The last argument is `LLVMConstInt`'s sign-extension flag; `n` is passed without it.
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
677+
}
678+
663679
pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
664680
let name = SmallCStr::new(name);
665681
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
use libc::{c_char, c_uint};
55

66
use super::MetadataKindId;
7-
use super::ffi::{BasicBlock, Metadata, Module, Type, Value};
7+
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
88
use crate::llvm::Bool;
99

1010
#[link(name = "llvm-wrapper", kind = "static")]
@@ -17,6 +17,11 @@ unsafe extern "C" {
1717
pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
1818
pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
1919
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
20+
/// Returns whether the function `V` has the attribute `Kind` on the parameter at index `i`.
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
21+
22+
/// Returns whether `Ty` is a function type.
pub(crate) fn LLVMRustIsFunctionTy(Ty: &Type) -> bool;
23+
/// Returns whether `Ty` is an array type.
pub(crate) fn LLVMRustIsArrayTy(Ty: &Type) -> bool;
24+
/// Returns the number of elements of the array type `Ty`.
// NOTE(review): presumably only valid when `Ty` is an array type — confirm before calling
// it on arbitrary types.
pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
2025
}
2126

2227
unsafe extern "C" {

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,12 @@ static inline void AddAttributes(T *t, unsigned Index, LLVMAttributeRef *Attrs,
384384
t->setAttributes(PALNew);
385385
}
386386

387+
// Returns whether the parameter at position `Index` of the function `Fn` carries the
// attribute `RustAttr`. `Fn` is unwrapped as a `Function`.
extern "C" bool LLVMRustHasAttributeAtIndex(LLVMValueRef Fn, unsigned Index,
388+
LLVMRustAttributeKind RustAttr) {
389+
Function *F = unwrap<Function>(Fn);
390+
return F->hasParamAttribute(Index, fromRust(RustAttr));
391+
}
392+
387393
extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index,
388394
LLVMAttributeRef *Attrs,
389395
size_t AttrsLen) {
@@ -635,6 +641,17 @@ static InlineAsm::AsmDialect fromRust(LLVMRustAsmDialect Dialect) {
635641
report_fatal_error("bad AsmDialect.");
636642
}
637643
}
644+
// Returns whether `Ty` is a function type.
extern "C" bool LLVMRustIsFunctionTy(LLVMTypeRef Ty) {
645+
return unwrap(Ty)->isFunctionTy();
646+
}
647+
648+
// Returns whether `Ty` is an array type.
extern "C" bool LLVMRustIsArrayTy(LLVMTypeRef Ty) {
649+
return unwrap(Ty)->isArrayTy();
650+
}
651+
652+
// Returns the element count of the array type `Ty`.
// NOTE(review): no `isArrayTy()` check is done here; presumably callers must pass an array
// type — confirm.
extern "C" uint64_t LLVMRustGetArrayNumElements(LLVMTypeRef Ty) {
653+
return unwrap(Ty)->getArrayNumElements();
654+
}
638655

639656
extern "C" LLVMValueRef
640657
LLVMRustInlineAsm(LLVMTypeRef Ty, char *AsmString, size_t AsmStringLen,

0 commit comments

Comments
 (0)