Skip to content

Commit 2e075b7

Browse files
authored
Fix powi nan (rust-lang#507)
1 parent cfd4b6b commit 2e075b7

File tree

5 files changed, with 76 additions and 92 deletions.

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3328,10 +3328,10 @@ class AdjointGenerator
33283328
if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
33293329
Value *op0 = gutils->getNewFromOriginal(orig_ops[0]);
33303330
Value *op1 = gutils->getNewFromOriginal(orig_ops[1]);
3331+
Value *nop1 = lookup(op1, Builder2);
33313332
SmallVector<Value *, 2> args = {
33323333
lookup(op0, Builder2),
3333-
Builder2.CreateSub(lookup(op1, Builder2),
3334-
ConstantInt::get(op1->getType(), 1))};
3334+
Builder2.CreateSub(nop1, ConstantInt::get(op1->getType(), 1))};
33353335
auto &CI = cast<CallInst>(I);
33363336
#if LLVM_VERSION_MAJOR >= 11
33373337
auto *PowF = CI.getCalledOperand();
@@ -3349,6 +3349,9 @@ class AdjointGenerator
33493349
Builder2.CreateFMul(vdiff, cal),
33503350
Builder2.CreateSIToFP(lookup(op1, Builder2),
33513351
op0->getType()->getScalarType()));
3352+
dif0 = Builder2.CreateSelect(
3353+
Builder2.CreateICmpEQ(ConstantInt::get(nop1->getType(), 0), nop1),
3354+
Constant::getNullValue(dif0->getType()), dif0);
33523355
addToDiffe(orig_ops[0], dif0, Builder2, I.getType());
33533356
}
33543357
return;
@@ -3808,8 +3811,12 @@ class AdjointGenerator
38083811
Builder2.CreateSIToFP(op1, op0->getType()->getScalarType());
38093812
Value *op = diffe(orig_ops[0], Builder2);
38103813

3814+
Value *cmp = Builder2.CreateICmpEQ(
3815+
ConstantInt::get(args[1]->getType(), 0), op1);
38113816
auto rule = [&](Value *op) {
3812-
return Builder2.CreateFMul(Builder2.CreateFMul(op, cal), cast);
3817+
return Builder2.CreateSelect(
3818+
cmp, Constant::getNullValue(op->getType()),
3819+
Builder2.CreateFMul(Builder2.CreateFMul(op, cal), cast));
38133820
};
38143821

38153822
Value *dif0 = applyChainRule(I.getType(), Builder2, rule, op);

enzyme/test/Enzyme/ForwardMode/powi13.ll

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: if [ %llvmver -ge 13 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
22

33
; Function Attrs: noinline nounwind readnone uwtable
44
define double @tester(double %x, i32 %y) {
@@ -22,10 +22,12 @@ declare double @__enzyme_fwddiff(double (double, i32)*, ...)
2222
; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", i32 %y)
2323
; CHECK-NEXT: entry:
2424
; CHECK-NEXT: %[[ym1:.+]] = sub i32 %y, 1
25-
; CHECK-NEXT: %[[newpow:.+]] = call fast double @llvm.powi.f64.i32(double %x, i32 %[[ym1]])
25+
; CHECK-NEXT: %[[newpow:.+]] = call fast double @llvm.powi.f64{{(\.i32)?}}(double %x, i32 %[[ym1]])
2626
; CHECK-DAG: %[[sitofp:.+]] = sitofp i32 %y to double
27+
; CHECK-DAG: %[[cmp:.+]] = icmp eq i32 0, %y
2728
; CHECK-DAG: %[[newpowdret:.+]] = fmul fast double %"x'", %[[newpow]]
2829
; CHECK-NEXT: %[[dx:.+]] = fmul fast double %[[newpowdret]], %[[sitofp]]
29-
; CHECK-NEXT: ret double %[[dx:.+]]
30+
; CHECK-NEXT: %[[res:.+]] = select {{(fast )?}}i1 %[[cmp]], double 0.000000e+00, double %[[dx]]
31+
; CHECK-NEXT: ret double %[[res]]
3032
; CHECK-NEXT: }
3133

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: if [ %llvmver -ge 13 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
22

33
%struct.Gradients = type { double, double, double }
44

@@ -22,22 +22,26 @@ entry:
2222
declare double @llvm.powi.f64.i32(double, i32)
2323

2424

25-
; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'", i32 %y) #1 {
26-
; CHECK-NEXT entry:
27-
; CHECK-NEXT %0 = sub i32 %y, 1
28-
; CHECK-NEXT %1 = call fast double @llvm.powi.f64.i32(double %x, i32 %0)
29-
; CHECK-NEXT %2 = sitofp i32 %y to double
30-
; CHECK-NEXT %3 = extractvalue [3 x double] %"x'", 0
31-
; CHECK-NEXT %4 = fmul fast double %3, %1
32-
; CHECK-NEXT %5 = fmul fast double %4, %2
33-
; CHECK-NEXT %6 = insertvalue [3 x double] undef, double %5, 0
34-
; CHECK-NEXT %7 = extractvalue [3 x double] %"x'", 1
35-
; CHECK-NEXT %8 = fmul fast double %7, %1
36-
; CHECK-NEXT %9 = fmul fast double %8, %2
37-
; CHECK-NEXT %10 = insertvalue [3 x double] %6, double %9, 1
38-
; CHECK-NEXT %11 = extractvalue [3 x double] %"x'", 2
39-
; CHECK-NEXT %12 = fmul fast double %11, %1
40-
; CHECK-NEXT %13 = fmul fast double %12, %2
41-
; CHECK-NEXT %14 = insertvalue [3 x double] %10, double %13, 2
42-
; CHECK-NEXT ret [3 x double] %14
43-
; CHECK-NEXT }
25+
; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'", i32 %y)
26+
; CHECK-NEXT: entry:
27+
; CHECK-NEXT: %0 = sub i32 %y, 1
28+
; CHECK-NEXT: %1 = call fast double @llvm.powi.f64{{(\.i32)?}}(double %x, i32 %0)
29+
; CHECK-NEXT: %2 = sitofp i32 %y to double
30+
; CHECK-NEXT: %3 = icmp eq i32 0, %y
31+
; CHECK-NEXT: %4 = extractvalue [3 x double] %"x'", 0
32+
; CHECK-NEXT: %5 = fmul fast double %4, %1
33+
; CHECK-NEXT: %6 = fmul fast double %5, %2
34+
; CHECK-NEXT: %7 = select {{(fast )?}}i1 %3, double 0.000000e+00, double %6
35+
; CHECK-NEXT: %8 = insertvalue [3 x double] undef, double %7, 0
36+
; CHECK-NEXT: %9 = extractvalue [3 x double] %"x'", 1
37+
; CHECK-NEXT: %10 = fmul fast double %9, %1
38+
; CHECK-NEXT: %11 = fmul fast double %10, %2
39+
; CHECK-NEXT: %12 = select {{(fast )?}}i1 %3, double 0.000000e+00, double %11
40+
; CHECK-NEXT: %13 = insertvalue [3 x double] %8, double %12, 1
41+
; CHECK-NEXT: %14 = extractvalue [3 x double] %"x'", 2
42+
; CHECK-NEXT: %15 = fmul fast double %14, %1
43+
; CHECK-NEXT: %16 = fmul fast double %15, %2
44+
; CHECK-NEXT: %17 = select {{(fast )?}}i1 %3, double 0.000000e+00, double %16
45+
; CHECK-NEXT: %18 = insertvalue [3 x double] %13, double %17, 2
46+
; CHECK-NEXT: ret [3 x double] %18
47+
; CHECK-NEXT }
Lines changed: 32 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,43 @@
1-
; RUN: if [ %llvmver -ge 13 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
23
source_filename = "text"
34
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
45
target triple = "x86_64-pc-linux-gnu"
56

6-
define private fastcc double @julia___2797(double %0, i64 signext %1) unnamed_addr #0 !dbg !7 {
7+
define private fastcc double @julia___2797(double %x0, i64 signext %x1) unnamed_addr #0 {
78
top:
8-
switch i64 %1, label %L20 [
9+
switch i64 %x1, label %L20 [
910
i64 -1, label %L3
1011
i64 0, label %L7
1112
i64 1, label %L7.fold.split
1213
i64 2, label %L13
1314
i64 3, label %L17
14-
], !dbg !9
15+
]
1516

1617
L3: ; preds = %top
17-
%2 = fdiv double 1.000000e+00, %0, !dbg !10
18-
ret double %2, !dbg !9
18+
%x2 = fdiv double 1.000000e+00, %x0
19+
ret double %x2
1920

2021
L7.fold.split: ; preds = %top
21-
br label %L7, !dbg !16
22+
br label %L7
2223

2324
L7: ; preds = %top, %L7.fold.split
24-
%merge = phi double [ 1.000000e+00, %top ], [ %0, %L7.fold.split ]
25-
ret double %merge, !dbg !16
25+
%merge = phi double [ 1.000000e+00, %top ], [ %x0, %L7.fold.split ]
26+
ret double %merge
2627

2728
L13: ; preds = %top
28-
%3 = fmul double %0, %0, !dbg !17
29-
ret double %3, !dbg !19
29+
%x3 = fmul double %x0, %x0
30+
ret double %x3
3031

3132
L17: ; preds = %top
32-
%4 = fmul double %0, %0, !dbg !20
33-
%5 = fmul double %4, %0, !dbg !20
34-
ret double %5, !dbg !24
33+
%x4 = fmul double %x0, %x0
34+
%x5 = fmul double %x4, %x0
35+
ret double %x5
3536

3637
L20: ; preds = %top
37-
%6 = sitofp i64 %1 to double, !dbg !25
38-
%7 = call double @llvm.pow.f64(double %0, double %6), !dbg !27
39-
ret double %7, !dbg !27
38+
%x6 = sitofp i64 %x1 to double
39+
%x7 = call double @llvm.pow.f64(double %x0, double %x6)
40+
ret double %x7
4041
}
4142

4243
; Function Attrs: nofree nosync nounwind readnone speculatable willreturn
@@ -46,10 +47,10 @@ declare double @llvm.pow.f64(double, double) #1
4647
declare double @__enzyme_autodiff(double (double, i64)*, ...)
4748

4849
; Function Attrs: alwaysinline nosync readnone
49-
define double @julia_f_2794(double %0, i64 signext %1) local_unnamed_addr #2 !dbg !28 {
50+
define double @julia_f_2794(double %y0, i64 signext %y1) {
5051
entry:
51-
%2 = call fastcc double @julia___2797(double %0, i64 signext %1) #5, !dbg !29
52-
ret double %2
52+
%y2 = call fastcc double @julia___2797(double %y0, i64 signext %y1) #5
53+
ret double %y2
5354
}
5455

5556
define double @test_derivative(double %x, i64 %y) {
@@ -58,13 +59,15 @@ entry:
5859
ret double %0
5960
}
6061

61-
; CHECK: define internal { double } @diffejulia_f_2794(double %0, i64 signext %1, double %differeturn) local_unnamed_addr #5 !dbg !35 {
62+
; CHECK: define internal { double } @diffejulia_f_2794(double %y0, i64 signext %y1, double %differeturn)
6263
; CHECK-NEXT: entry:
63-
; CHECK-NEXT: %2 = sub i64 %1, 1
64-
; CHECK-NEXT: %3 = call fast fastcc double @julia___2797(double %0, i64 %2), !dbg !36
65-
; CHECK-NEXT: %4 = sitofp i64 %1 to double
66-
; CHECK-NEXT: %5 = fmul fast double %differeturn, %3
67-
; CHECK-NEXT: %6 = fmul fast double %5, %4
64+
; CHECK-NEXT: %0 = sub i64 %y1, 1
65+
; CHECK-NEXT: %1 = call fast fastcc double @julia___2797(double %y0, i64 %0)
66+
; CHECK-NEXT: %2 = sitofp i64 %y1 to double
67+
; CHECK-NEXT: %3 = fmul fast double %differeturn, %1
68+
; CHECK-NEXT: %4 = fmul fast double %3, %2
69+
; CHECK-NEXT: %5 = icmp eq i64 0, %y1
70+
; CHECK-NEXT: %6 = select {{(fast )?}}i1 %5, double 0.000000e+00, double %4
6871
; CHECK-NEXT: %7 = insertvalue { double } undef, double %6, 0
6972
; CHECK-NEXT: ret { double } %7
7073
; CHECK-NEXT: }
@@ -78,44 +81,10 @@ declare noalias nonnull {} addrspace(10)* @jl_gc_pool_alloc(i8*, i32, i32) #4
7881
; Function Attrs: allocsize(1)
7982
declare noalias nonnull {} addrspace(10)* @jl_gc_big_alloc(i8*, i64) #4
8083

81-
attributes #0 = { noinline nosync readnone "enzyme_math"="powi" "enzyme_shouldrecompute"="powi" "probe-stack"="inline-asm" }
82-
attributes #1 = { nofree nosync nounwind readnone speculatable willreturn }
83-
attributes #2 = { alwaysinline nosync readnone "probe-stack"="inline-asm" }
84+
attributes #0 = { noinline readnone "enzyme_math"="powi" "enzyme_shouldrecompute"="powi"}
85+
attributes #1 = { nounwind readnone speculatable }
86+
attributes #2 = { alwaysinline readnone "probe-stack"="inline-asm" }
8487
attributes #3 = { inaccessiblemem_or_argmemonly }
8588
attributes #4 = { allocsize(1) }
8689
attributes #5 = { "probe-stack"="inline-asm" }
8790

88-
!llvm.module.flags = !{!0, !1}
89-
!llvm.dbg.cu = !{!2, !5}
90-
91-
!0 = !{i32 2, !"Dwarf Version", i32 4}
92-
!1 = !{i32 2, !"Debug Info Version", i32 3}
93-
!2 = distinct !DICompileUnit(language: DW_LANG_Julia, file: !3, producer: "julia", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly, enums: !4, nameTableKind: None)
94-
!3 = !DIFile(filename: "math.jl", directory: ".")
95-
!4 = !{}
96-
!5 = distinct !DICompileUnit(language: DW_LANG_Julia, file: !6, producer: "julia", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly, enums: !4, nameTableKind: None)
97-
!6 = !DIFile(filename: "REPL[3]", directory: ".")
98-
!7 = distinct !DISubprogram(name: "^", linkageName: "julia_^_2797", scope: null, file: !3, line: 922, type: !8, scopeLine: 922, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
99-
!8 = !DISubroutineType(types: !4)
100-
!9 = !DILocation(line: 923, scope: !7)
101-
!10 = !DILocation(line: 408, scope: !11, inlinedAt: !13)
102-
!11 = distinct !DISubprogram(name: "/;", linkageName: "/", scope: !12, file: !12, type: !8, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
103-
!12 = !DIFile(filename: "float.jl", directory: ".")
104-
!13 = !DILocation(line: 243, scope: !14, inlinedAt: !9)
105-
!14 = distinct !DISubprogram(name: "inv;", linkageName: "inv", scope: !15, file: !15, type: !8, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
106-
!15 = !DIFile(filename: "number.jl", directory: ".")
107-
!16 = !DILocation(line: 924, scope: !7)
108-
!17 = !DILocation(line: 405, scope: !18, inlinedAt: !19)
109-
!18 = distinct !DISubprogram(name: "*;", linkageName: "*", scope: !12, file: !12, type: !8, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
110-
!19 = !DILocation(line: 926, scope: !7)
111-
!20 = !DILocation(line: 405, scope: !18, inlinedAt: !21)
112-
!21 = !DILocation(line: 655, scope: !22, inlinedAt: !24)
113-
!22 = distinct !DISubprogram(name: "*;", linkageName: "*", scope: !23, file: !23, type: !8, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
114-
!23 = !DIFile(filename: "operators.jl", directory: ".")
115-
!24 = !DILocation(line: 927, scope: !7)
116-
!25 = !DILocation(line: 146, scope: !26, inlinedAt: !27)
117-
!26 = distinct !DISubprogram(name: "Float64;", linkageName: "Float64", scope: !12, file: !12, type: !8, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
118-
!27 = !DILocation(line: 928, scope: !7)
119-
!28 = distinct !DISubprogram(name: "f", linkageName: "julia_f_2794", scope: null, file: !6, line: 1, type: !8, scopeLine: 1, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !5, retainedNodes: !4)
120-
!29 = !DILocation(line: 1, scope: !28, inlinedAt: !30)
121-
!30 = distinct !DILocation(line: 0, scope: !28)

enzyme/test/Enzyme/ReverseMode/powi13.ll

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: if [ %llvmver -ge 13 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
22

33
; Function Attrs: noinline nounwind readnone uwtable
44
define double @tester(double %x, i32 %y) {
@@ -22,10 +22,12 @@ declare double @__enzyme_autodiff(double (double, i32)*, ...)
2222
; CHECK: define internal {{(dso_local )?}}{ double } @diffetester(double %x, i32 %y, double %differeturn)
2323
; CHECK-NEXT: entry:
2424
; CHECK-NEXT: %[[ym1:.+]] = sub i32 %y, 1
25-
; CHECK-NEXT: %[[newpow:.+]] = call fast double @llvm.powi.f64.i32(double %x, i32 %[[ym1]])
25+
; CHECK-NEXT: %[[newpow:.+]] = call fast double @llvm.powi.f64{{(\.i32)?}}(double %x, i32 %[[ym1]])
2626
; CHECK-DAG: %[[sitofp:.+]] = sitofp i32 %y to double
2727
; CHECK-DAG: %[[newpowdret:.+]] = fmul fast double %differeturn, %[[newpow]]
2828
; CHECK-NEXT: %[[dx:.+]] = fmul fast double %[[newpowdret]], %[[sitofp]]
29-
; CHECK-NEXT: %[[interres:.+]] = insertvalue { double } undef, double %[[dx:.+]], 0
29+
; CHECK-NEXT: %[[cmp:.+]] = icmp eq i32 0, %y
30+
; CHECK-NEXT: %[[res:.+]] = select {{(fast )?}}i1 %[[cmp]], double 0.000000e+00, double %[[dx]]
31+
; CHECK-NEXT: %[[interres:.+]] = insertvalue { double } undef, double %[[res]], 0
3032
; CHECK-NEXT: ret { double } %[[interres]]
3133
; CHECK-NEXT: }

0 commit comments

Comments
 (0)