1
1
use std:: ptr;
2
-
3
2
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
4
3
use rustc_codegen_ssa:: ModuleCodegen ;
5
4
use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6
- use rustc_codegen_ssa :: traits :: BaseTypeCodegenMethods as _ ;
7
- use rustc_errors :: FatalError ;
5
+ use rustc_errors :: { DiagCtxt , FatalError } ;
6
+ use rustc_middle :: bug ;
8
7
use tracing:: { debug, trace} ;
9
8
10
9
use crate :: back:: write:: llvm_err;
11
10
use crate :: builder:: SBuilder ;
12
11
use crate :: context:: SimpleCx ;
13
12
use crate :: declare:: declare_simple_fn;
14
- use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
13
+ use crate :: errors:: { AutoDiffUnusedArgs , AutoDiffWithoutEnable , LlvmError } ;
15
14
use crate :: llvm:: AttributePlace :: Function ;
16
15
use crate :: llvm:: { Metadata , True } ;
17
16
use crate :: value:: Value ;
17
+
18
18
use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
19
19
20
20
fn get_params ( fnc : & Value ) -> Vec < & Value > {
@@ -28,6 +28,25 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
28
28
}
29
29
}
30
30
31
+ fn has_sret ( fnc : & Value ) -> bool {
32
+ let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
33
+ if num_args == 0 {
34
+ false
35
+ } else {
36
+ unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
37
+ }
38
+ }
39
+
40
+ // When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
41
+ // original inputs, as well as metadata and the additional shadow arguments.
42
+ // This function matches the arguments from the outer function to the inner enzyme call.
43
+ //
44
+ // This function also considers that Rust level arguments not always match the llvm-ir level
45
+ // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
46
+ // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
47
+ // need to match those.
48
+ // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
49
+ // using iterators and peek()?
31
50
fn match_args_from_caller_to_enzyme < ' ll > (
32
51
cx : & SimpleCx < ' ll > ,
33
52
args : & mut Vec < & ' ll llvm:: Value > ,
@@ -135,6 +154,78 @@ fn match_args_from_caller_to_enzyme<'ll>(
135
154
}
136
155
}
137
156
157
+
158
+ // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
159
+ // arguments. We do however need to declare them with their correct return type.
160
+ // We already figured the correct return type out in our frontend, when generating the outer_fn,
161
+ // so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
162
+ // Beyond sret, this article describes our challenges nicely:
163
+ // <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
164
+ // I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
165
+ fn compute_enzyme_fn_ty < ' ll > (
166
+ cx : & SimpleCx < ' ll > ,
167
+ attrs : & AutoDiffAttrs ,
168
+ fn_to_diff : & ' ll Value ,
169
+ outer_fn : & ' ll Value ,
170
+ ) -> & ' ll llvm:: Type {
171
+ let fn_ty = cx. get_type_of_global ( outer_fn) ;
172
+ let mut ret_ty = cx. get_return_type ( fn_ty) ;
173
+
174
+ let has_sret = has_sret ( outer_fn) ;
175
+
176
+ if has_sret {
177
+ // Now we don't just forward the return type, so we have to figure it out based on the
178
+ // primal return type, in combination with the autodiff settings.
179
+ let fn_ty = cx. get_type_of_global ( fn_to_diff) ;
180
+ let inner_ret_ty = cx. get_return_type ( fn_ty) ;
181
+
182
+ let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( cx. llcx ) } ;
183
+ if inner_ret_ty == void_ty {
184
+ dbg ! ( & fn_to_diff) ;
185
+ // This indicates that even the inner function has an sret.
186
+ // Right now I only look for an sret in the outer function.
187
+ // This *probably* needs some extra handling, but I never ran
188
+ // into such a case. So I'll wait for user reports to have a test case.
189
+ bug ! ( "sret in inner function" ) ;
190
+ }
191
+
192
+ if attrs. width == 1 {
193
+ todo ! ( "Handle sret for scalar ad" ) ;
194
+ } else {
195
+ // First we check if we also have to deal with the primal return.
196
+ if attrs. mode . is_fwd ( ) {
197
+ match attrs. ret_activity {
198
+ DiffActivity :: Dual => {
199
+ let arr_ty =
200
+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
201
+ ret_ty = arr_ty;
202
+ }
203
+ DiffActivity :: DualOnly => {
204
+ let arr_ty =
205
+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 ) } ;
206
+ ret_ty = arr_ty;
207
+ }
208
+ DiffActivity :: Const => {
209
+ todo ! ( "Not sure, do we need to do something here?" ) ;
210
+ }
211
+ _ => {
212
+ bug ! ( "unreachable" ) ;
213
+ }
214
+ }
215
+ } else if attrs. mode . is_rev ( ) {
216
+ todo ! ( "Handle sret for reverse mode" ) ;
217
+ } else {
218
+ bug ! ( "unreachable" ) ;
219
+ }
220
+ }
221
+ }
222
+
223
+ dbg ! ( & outer_fn) ;
224
+
225
+ // LLVM can figure out the input types on it's own, so we take a shortcut here.
226
+ unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) }
227
+ }
228
+
138
229
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
139
230
/// function with expected naming and calling conventions[^1] which will be
140
231
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -145,6 +236,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
145
236
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
146
237
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
147
238
fn generate_enzyme_call < ' ll > (
239
+ _dcx : & DiagCtxt ,
148
240
cx : & SimpleCx < ' ll > ,
149
241
fn_to_diff : & ' ll Value ,
150
242
outer_fn : & ' ll Value ,
@@ -197,17 +289,9 @@ fn generate_enzyme_call<'ll>(
197
289
// }
198
290
// ```
199
291
unsafe {
200
- // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
201
- // arguments. We do however need to declare them with their correct return type.
202
- // We already figured the correct return type out in our frontend, when generating the outer_fn,
203
- // so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
204
- let fn_ty = llvm:: LLVMGlobalGetValueType ( outer_fn) ;
205
- let ret_ty = llvm:: LLVMGetReturnType ( fn_ty) ;
292
+ let enzyme_ty = compute_enzyme_fn_ty ( cx, & attrs, fn_to_diff, outer_fn) ;
206
293
207
- // LLVM can figure out the input types on it's own, so we take a shortcut here.
208
- let enzyme_ty = llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) ;
209
-
210
- //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
294
+ // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
211
295
// think a bit more about what should go here.
212
296
let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
213
297
let ad_fn = declare_simple_fn (
@@ -268,12 +352,31 @@ fn generate_enzyme_call<'ll>(
268
352
// Now that we copied the metadata, get rid of dummy code.
269
353
llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
270
354
271
- if cx. val_ty ( call) == cx. type_void ( ) {
355
+ let has_sret = has_sret ( outer_fn) ;
356
+
357
+ if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
358
+ if has_sret {
359
+ // This is what we already have in our outer_fn (shortened):
360
+ // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
361
+ // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
362
+ // <Here we are, we want to add the following two lines>
363
+ // store [4 x double] %7, ptr %0, align 8
364
+ // ret void
365
+ // }
366
+
367
+ // now store the result of the enzyme call into the sret pointer.
368
+ let sret_ptr = outer_args[ 0 ] ;
369
+ let call_ty = cx. val_ty ( call) ;
370
+ assert ! ( llvm:: LLVMRustIsArrayTy ( call_ty) ) ;
371
+ llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
372
+ }
272
373
builder. ret_void ( ) ;
273
374
} else {
274
375
builder. ret ( call) ;
275
376
}
276
377
378
+ dbg ! ( & outer_fn) ;
379
+
277
380
// Let's crash in case that we messed something up above and generated invalid IR.
278
381
llvm:: LLVMRustVerifyFunction (
279
382
outer_fn,
@@ -300,8 +403,7 @@ pub(crate) fn differentiate<'ll>(
300
403
if !diff_items. is_empty ( )
301
404
&& !cgcx. opts . unstable_opts . autodiff . contains ( & rustc_session:: config:: AutoDiff :: Enable )
302
405
{
303
- let dcx = cgcx. create_dcx ( ) ;
304
- return Err ( dcx. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
406
+ return Err ( diag_handler. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
305
407
}
306
408
307
409
// Before dumping the module, we want all the TypeTrees to become part of the module.
@@ -331,7 +433,7 @@ pub(crate) fn differentiate<'ll>(
331
433
) ) ;
332
434
} ;
333
435
334
- generate_enzyme_call ( & cx, fn_def, fn_target, item. attrs . clone ( ) ) ;
436
+ generate_enzyme_call ( & diag_handler , & cx, fn_def, fn_target, item. attrs . clone ( ) ) ;
335
437
}
336
438
337
439
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
0 commit comments