@@ -3,8 +3,8 @@ use std::ptr;
3
3
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
4
4
use rustc_codegen_ssa:: ModuleCodegen ;
5
5
use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6
- use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods as _;
7
6
use rustc_errors:: FatalError ;
7
+ use rustc_middle:: bug;
8
8
use tracing:: { debug, trace} ;
9
9
10
10
use crate :: back:: write:: llvm_err;
@@ -28,11 +28,32 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
28
28
}
29
29
}
30
30
31
+ fn has_sret ( fnc : & Value ) -> bool {
32
+ let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
33
+ if num_args == 0 {
34
+ false
35
+ } else {
36
+ unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
37
+ }
38
+ }
39
+
40
+ // When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
41
+ // original inputs, as well as metadata and the additional shadow arguments.
42
+ // This function matches the arguments from the outer function to the inner enzyme call.
43
+ //
44
+ // This function also considers that Rust level arguments not always match the llvm-ir level
45
+ // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
46
+ // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
47
+ // need to match those.
48
+ // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
49
+ // using iterators and peek()?
31
50
fn match_args_from_caller_to_enzyme < ' ll > (
32
51
cx : & SimpleCx < ' ll > ,
52
+ width : u32 ,
33
53
args : & mut Vec < & ' ll llvm:: Value > ,
34
54
inputs : & [ DiffActivity ] ,
35
55
outer_args : & [ & ' ll llvm:: Value ] ,
56
+ has_sret : bool ,
36
57
) {
37
58
debug ! ( "matching autodiff arguments" ) ;
38
59
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -44,6 +65,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
44
65
let mut outer_pos: usize = 0 ;
45
66
let mut activity_pos = 0 ;
46
67
68
+ if has_sret {
69
+ // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
70
+ // inner function will still return something. We increase our outer_pos by one,
71
+ // and once we're done with all other args we will take the return of the inner call and
72
+ // update the sret pointer with it
73
+ outer_pos = 1 ;
74
+ }
75
+
47
76
let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
48
77
let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
49
78
let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
@@ -95,20 +124,25 @@ fn match_args_from_caller_to_enzyme<'ll>(
95
124
assert ! ( unsafe {
96
125
llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Integer
97
126
} ) ;
98
- let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
99
- let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
100
- assert ! ( unsafe {
101
- llvm:: LLVMRustGetTypeKind ( next_outer_ty2) == llvm:: TypeKind :: Pointer
102
- } ) ;
103
- let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
104
- let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
105
- assert ! ( unsafe {
106
- llvm:: LLVMRustGetTypeKind ( next_outer_ty3) == llvm:: TypeKind :: Integer
107
- } ) ;
108
- args. push ( next_outer_arg2) ;
127
+
128
+ for _ in 0 ..width {
129
+ let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
130
+ let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131
+ assert ! (
132
+ unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty2) }
133
+ == llvm:: TypeKind :: Pointer
134
+ ) ;
135
+ let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
136
+ let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
137
+ assert ! (
138
+ unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty3) }
139
+ == llvm:: TypeKind :: Integer
140
+ ) ;
141
+ args. push ( next_outer_arg2) ;
142
+ }
109
143
args. push ( cx. get_metadata_value ( enzyme_const) ) ;
110
144
args. push ( next_outer_arg) ;
111
- outer_pos += 4 ;
145
+ outer_pos += 2 + 2 * width as usize ;
112
146
activity_pos += 2 ;
113
147
} else {
114
148
// A duplicated pointer will have the following two outer_fn arguments:
@@ -125,6 +159,13 @@ fn match_args_from_caller_to_enzyme<'ll>(
125
159
args. push ( next_outer_arg) ;
126
160
outer_pos += 2 ;
127
161
activity_pos += 1 ;
162
+
163
+ // Now, if width > 1, we need to account for that
164
+ for _ in 1 ..width {
165
+ let next_outer_arg = outer_args[ outer_pos] ;
166
+ args. push ( next_outer_arg) ;
167
+ outer_pos += 1 ;
168
+ }
128
169
}
129
170
} else {
130
171
// We do not differentiate with resprect to this argument.
@@ -135,6 +176,74 @@ fn match_args_from_caller_to_enzyme<'ll>(
135
176
}
136
177
}
137
178
179
+ // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
180
+ // arguments. We do however need to declare them with their correct return type.
181
+ // We already figured the correct return type out in our frontend, when generating the outer_fn,
182
+ // so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
183
+ // Beyond sret, this article describes our challenges nicely:
184
+ // <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
185
+ // I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
186
+ fn compute_enzyme_fn_ty < ' ll > (
187
+ cx : & SimpleCx < ' ll > ,
188
+ attrs : & AutoDiffAttrs ,
189
+ fn_to_diff : & ' ll Value ,
190
+ outer_fn : & ' ll Value ,
191
+ ) -> & ' ll llvm:: Type {
192
+ let fn_ty = cx. get_type_of_global ( outer_fn) ;
193
+ let mut ret_ty = cx. get_return_type ( fn_ty) ;
194
+
195
+ let has_sret = has_sret ( outer_fn) ;
196
+
197
+ if has_sret {
198
+ // Now we don't just forward the return type, so we have to figure it out based on the
199
+ // primal return type, in combination with the autodiff settings.
200
+ let fn_ty = cx. get_type_of_global ( fn_to_diff) ;
201
+ let inner_ret_ty = cx. get_return_type ( fn_ty) ;
202
+
203
+ let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( cx. llcx ) } ;
204
+ if inner_ret_ty == void_ty {
205
+ // This indicates that even the inner function has an sret.
206
+ // Right now I only look for an sret in the outer function.
207
+ // This *probably* needs some extra handling, but I never ran
208
+ // into such a case. So I'll wait for user reports to have a test case.
209
+ bug ! ( "sret in inner function" ) ;
210
+ }
211
+
212
+ if attrs. width == 1 {
213
+ todo ! ( "Handle sret for scalar ad" ) ;
214
+ } else {
215
+ // First we check if we also have to deal with the primal return.
216
+ if attrs. mode . is_fwd ( ) {
217
+ match attrs. ret_activity {
218
+ DiffActivity :: Dual => {
219
+ let arr_ty =
220
+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
221
+ ret_ty = arr_ty;
222
+ }
223
+ DiffActivity :: DualOnly => {
224
+ let arr_ty =
225
+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 ) } ;
226
+ ret_ty = arr_ty;
227
+ }
228
+ DiffActivity :: Const => {
229
+ todo ! ( "Not sure, do we need to do something here?" ) ;
230
+ }
231
+ _ => {
232
+ bug ! ( "unreachable" ) ;
233
+ }
234
+ }
235
+ } else if attrs. mode . is_rev ( ) {
236
+ todo ! ( "Handle sret for reverse mode" ) ;
237
+ } else {
238
+ bug ! ( "unreachable" ) ;
239
+ }
240
+ }
241
+ }
242
+
243
+ // LLVM can figure out the input types on it's own, so we take a shortcut here.
244
+ unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) }
245
+ }
246
+
138
247
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
139
248
/// function with expected naming and calling conventions[^1] which will be
140
249
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -197,17 +306,9 @@ fn generate_enzyme_call<'ll>(
197
306
// }
198
307
// ```
199
308
unsafe {
200
- // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
201
- // arguments. We do however need to declare them with their correct return type.
202
- // We already figured the correct return type out in our frontend, when generating the outer_fn,
203
- // so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
204
- let fn_ty = llvm:: LLVMGlobalGetValueType ( outer_fn) ;
205
- let ret_ty = llvm:: LLVMGetReturnType ( fn_ty) ;
309
+ let enzyme_ty = compute_enzyme_fn_ty ( cx, & attrs, fn_to_diff, outer_fn) ;
206
310
207
- // LLVM can figure out the input types on it's own, so we take a shortcut here.
208
- let enzyme_ty = llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) ;
209
-
210
- //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
311
+ // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
211
312
// think a bit more about what should go here.
212
313
let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
213
314
let ad_fn = declare_simple_fn (
@@ -240,14 +341,27 @@ fn generate_enzyme_call<'ll>(
240
341
if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
241
342
args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
242
343
}
344
+ if attrs. width > 1 {
345
+ let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
346
+ args. push ( cx. get_metadata_value ( enzyme_width) ) ;
347
+ args. push ( cx. get_const_i64 ( attrs. width as u64 ) ) ;
348
+ }
243
349
350
+ let has_sret = has_sret ( outer_fn) ;
244
351
let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
245
- match_args_from_caller_to_enzyme ( & cx, & mut args, & attrs. input_activity , & outer_args) ;
352
+ match_args_from_caller_to_enzyme (
353
+ & cx,
354
+ attrs. width ,
355
+ & mut args,
356
+ & attrs. input_activity ,
357
+ & outer_args,
358
+ has_sret,
359
+ ) ;
246
360
247
361
let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
248
362
249
363
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
250
- // metadata attachted to it, but we just created this code oota. Given that the
364
+ // metadata attached to it, but we just created this code oota. Given that the
251
365
// differentiated function already has partly confusing metadata, and given that this
252
366
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
253
367
// dummy code which we inserted at a higher level.
@@ -268,7 +382,22 @@ fn generate_enzyme_call<'ll>(
268
382
// Now that we copied the metadata, get rid of dummy code.
269
383
llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
270
384
271
- if cx. val_ty ( call) == cx. type_void ( ) {
385
+ if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
386
+ if has_sret {
387
+ // This is what we already have in our outer_fn (shortened):
388
+ // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
389
+ // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
390
+ // <Here we are, we want to add the following two lines>
391
+ // store [4 x double] %7, ptr %0, align 8
392
+ // ret void
393
+ // }
394
+
395
+ // now store the result of the enzyme call into the sret pointer.
396
+ let sret_ptr = outer_args[ 0 ] ;
397
+ let call_ty = cx. val_ty ( call) ;
398
+ assert ! ( llvm:: LLVMRustIsArrayTy ( call_ty) ) ;
399
+ llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
400
+ }
272
401
builder. ret_void ( ) ;
273
402
} else {
274
403
builder. ret ( call) ;
@@ -300,8 +429,7 @@ pub(crate) fn differentiate<'ll>(
300
429
if !diff_items. is_empty ( )
301
430
&& !cgcx. opts . unstable_opts . autodiff . contains ( & rustc_session:: config:: AutoDiff :: Enable )
302
431
{
303
- let dcx = cgcx. create_dcx ( ) ;
304
- return Err ( dcx. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
432
+ return Err ( diag_handler. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
305
433
}
306
434
307
435
// Before dumping the module, we want all the TypeTrees to become part of the module.
0 commit comments