Skip to content

Commit 0f4dacd

Browse files
aslasavonic
andauthored
[AutoDiff] Fix custom derivative thunk for Optional (#74378)
Enable the nil coalescing operator (aka `??`) for Optional type. Fixes #55882 Co-authored-by: Andrew Savonichev <[email protected]>
1 parent dd0831f commit 0f4dacd

File tree

3 files changed

+52
-14
lines changed

3 files changed

+52
-14
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6347,10 +6347,13 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
63476347
arguments.push_back(indErrorRes.getLValueAddress());
63486348
forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments);
63496349

6350+
SubstitutionMap subs = thunk->getForwardingSubstitutionMap();
6351+
SILType substFnType = fnRef->getType().substGenericArgs(
6352+
M, subs, thunk->getTypeExpansionContext());
6353+
63506354
// Apply function argument.
6351-
auto apply = thunkSGF.emitApplyWithRethrow(
6352-
loc, fnRef, /*substFnType*/ fnRef->getType(),
6353-
thunk->getForwardingSubstitutionMap(), arguments);
6355+
auto apply =
6356+
thunkSGF.emitApplyWithRethrow(loc, fnRef, substFnType, subs, arguments);
63546357

63556358
// Self reordering thunk is necessary if wrt at least two parameters,
63566359
// including self.

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,15 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
906906
traceMessage.c_str(), witness->getOriginalFunction());
907907

908908
assert(witness->isDefinition());
909+
SILFunction *orig = witness->getOriginalFunction();
910+
911+
// We can generate empty JVP / VJP for functions available externally. These
912+
// functions have the same linkage as the original ones sans `external`
913+
// flag. Important exception here hidden_external functions as they are
914+
// serializable but corresponding hidden ones would be not and the SIL
915+
// verifier will fail. Patch `serializeFunctions` for this case.
916+
if (orig->getLinkage() == SILLinkage::HiddenExternal)
917+
serializeFunctions = IsNotSerialized;
909918

910919
// If the JVP doesn't exist, need to synthesize it.
911920
if (!witness->getJVP()) {
@@ -914,9 +923,8 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
914923
// - Functions with unsupported control flow.
915924
if (context.getASTContext()
916925
.LangOpts.hasFeature(Feature::ForwardModeDifferentiation) &&
917-
(diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) ||
918-
diagnoseUnsupportedControlFlow(
919-
context, witness->getOriginalFunction(), invoker)))
926+
(diagnoseNoReturn(context, orig, invoker) ||
927+
diagnoseUnsupportedControlFlow(context, orig, invoker)))
920928
return true;
921929

922930
// Create empty JVP.
@@ -933,10 +941,10 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
933941
!witness->getVJP()) {
934942
// JVP and differential generation do not currently support functions with
935943
// multiple basic blocks.
936-
if (witness->getOriginalFunction()->size() > 1) {
937-
context.emitNondifferentiabilityError(
938-
witness->getOriginalFunction()->getLocation().getSourceLoc(),
939-
invoker, diag::autodiff_jvp_control_flow_not_supported);
944+
if (orig->size() > 1) {
945+
context.emitNondifferentiabilityError(orig->getLocation().getSourceLoc(),
946+
invoker,
947+
diag::autodiff_jvp_control_flow_not_supported);
940948
return true;
941949
}
942950
// Emit JVP function.
@@ -950,7 +958,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
950958
"_fatalErrorForwardModeDifferentiationDisabled");
951959
LLVM_DEBUG(getADDebugStream()
952960
<< "Generated empty JVP for "
953-
<< witness->getOriginalFunction()->getName() << ":\n"
961+
<< orig->getName() << ":\n"
954962
<< *jvp);
955963
}
956964
}
@@ -960,9 +968,8 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
960968
// Diagnose:
961969
// - Functions with no return.
962970
// - Functions with unsupported control flow.
963-
if (diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) ||
964-
diagnoseUnsupportedControlFlow(
965-
context, witness->getOriginalFunction(), invoker))
971+
if (diagnoseNoReturn(context, orig, invoker) ||
972+
diagnoseUnsupportedControlFlow(context, orig, invoker))
966973
return true;
967974

968975
// Create empty VJP.
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s
2+
3+
import _Differentiation
4+
5+
// CHECK: sil @test_nil_coalescing
6+
// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional<T>, %[[ARG_PB:.*]] :
7+
// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>):
8+
// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional<T>
9+
// CHECK: copy_addr %[[ARG_OPT]] to [init] %[[ALLOC_OPT]] : $*Optional<T>
10+
// We'd need to check that ALLOC_OPT is an argument of switch_enum_addr below. However, this code
11+
// is inlined from the standard library and therefore could have a sequence of copies in between
12+
// depending whether we're compiling against debug or release stdlib
13+
// CHECK: switch_enum_addr %{{.*}} : $*Optional<T>, case #Optional.some!enumelt: {{.*}}, case #Optional.none!enumelt: {{.*}}
14+
// CHECK: try_apply %[[ARG_PB]](%{{.*}}) : $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>, normal {{.*}}, error {{.*}}
15+
//
16+
@_silgen_name("test_nil_coalescing")
17+
@derivative(of: ??)
18+
@usableFromInline
19+
func nilCoalescing<T: Differentiable>(optional: T?, defaultValue: @autoclosure () throws -> T)
20+
rethrows -> (value: T, pullback: (T.TangentVector) -> Optional<T>.TangentVector)
21+
{
22+
let hasValue = optional != nil
23+
let value = try optional ?? defaultValue()
24+
func pullback(_ v: T.TangentVector) -> Optional<T>.TangentVector {
25+
return hasValue ? .init(v) : .zero
26+
}
27+
return (value, pullback)
28+
}

0 commit comments

Comments
 (0)