Skip to content

Commit aa5dddb

Browse files
asavonicasl
andauthored
[AutoDiff] Fix custom derivative thunk for Optional (#71721)
Enable the nil coalescing operator (aka `??`) for Optional type. Fixes #55882 Co-authored-by: Anton Korobeynikov <[email protected]>
1 parent bffb878 commit aa5dddb

File tree

3 files changed

+49
-14
lines changed

3 files changed

+49
-14
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6305,10 +6305,13 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
63056305
arguments.push_back(indErrorRes.getLValueAddress());
63066306
forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments);
63076307

6308+
SubstitutionMap subs = thunk->getForwardingSubstitutionMap();
6309+
SILType substFnType = fnRef->getType().substGenericArgs(
6310+
M, subs, thunk->getTypeExpansionContext());
6311+
63086312
// Apply function argument.
6309-
auto apply = thunkSGF.emitApplyWithRethrow(
6310-
loc, fnRef, /*substFnType*/ fnRef->getType(),
6311-
thunk->getForwardingSubstitutionMap(), arguments);
6313+
auto apply =
6314+
thunkSGF.emitApplyWithRethrow(loc, fnRef, substFnType, subs, arguments);
63126315

63136316
// Self reordering thunk is necessary if wrt at least two parameters,
63146317
// including self.

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,15 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
906906
traceMessage.c_str(), witness->getOriginalFunction());
907907

908908
assert(witness->isDefinition());
909+
SILFunction *orig = witness->getOriginalFunction();
910+
911+
// We can generate empty JVP / VJP for functions available externally. These
912+
// functions have the same linkage as the original ones sans `external`
913+
// flag. Important exception here hidden_external functions as they are
914+
// serializable but corresponding hidden ones would be not and the SIL
915+
// verifier will fail. Patch `serializeFunctions` for this case.
916+
if (orig->getLinkage() == SILLinkage::HiddenExternal)
917+
serializeFunctions = IsNotSerialized;
909918

910919
// If the JVP doesn't exist, need to synthesize it.
911920
if (!witness->getJVP()) {
@@ -914,9 +923,8 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
914923
// - Functions with unsupported control flow.
915924
if (context.getASTContext()
916925
.LangOpts.hasFeature(Feature::ForwardModeDifferentiation) &&
917-
(diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) ||
918-
diagnoseUnsupportedControlFlow(
919-
context, witness->getOriginalFunction(), invoker)))
926+
(diagnoseNoReturn(context, orig, invoker) ||
927+
diagnoseUnsupportedControlFlow(context, orig, invoker)))
920928
return true;
921929

922930
// Create empty JVP.
@@ -933,10 +941,10 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
933941
!witness->getVJP()) {
934942
// JVP and differential generation do not currently support functions with
935943
// multiple basic blocks.
936-
if (witness->getOriginalFunction()->size() > 1) {
937-
context.emitNondifferentiabilityError(
938-
witness->getOriginalFunction()->getLocation().getSourceLoc(),
939-
invoker, diag::autodiff_jvp_control_flow_not_supported);
944+
if (orig->size() > 1) {
945+
context.emitNondifferentiabilityError(orig->getLocation().getSourceLoc(),
946+
invoker,
947+
diag::autodiff_jvp_control_flow_not_supported);
940948
return true;
941949
}
942950
// Emit JVP function.
@@ -950,7 +958,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
950958
"_fatalErrorForwardModeDifferentiationDisabled");
951959
LLVM_DEBUG(getADDebugStream()
952960
<< "Generated empty JVP for "
953-
<< witness->getOriginalFunction()->getName() << ":\n"
961+
<< orig->getName() << ":\n"
954962
<< *jvp);
955963
}
956964
}
@@ -960,9 +968,8 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
960968
// Diagnose:
961969
// - Functions with no return.
962970
// - Functions with unsupported control flow.
963-
if (diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) ||
964-
diagnoseUnsupportedControlFlow(
965-
context, witness->getOriginalFunction(), invoker))
971+
if (diagnoseNoReturn(context, orig, invoker) ||
972+
diagnoseUnsupportedControlFlow(context, orig, invoker))
966973
return true;
967974

968975
// Create empty VJP.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s
2+
3+
import _Differentiation
4+
5+
// CHECK: sil @test_nil_coalescing
6+
// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional<T>, %[[ARG_PB:.*]] :
7+
// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>):
8+
// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional<T>
9+
// CHECK: copy_addr %[[ARG_OPT]] to [init] %[[ALLOC_OPT]] : $*Optional<T>
10+
// CHECK: switch_enum_addr %[[ALLOC_OPT]] : $*Optional<T>, case #Optional.some!enumelt: {{.*}}, case #Optional.none!enumelt: {{.*}}
11+
// CHECK: try_apply %[[ARG_PB]](%{{.*}}) : $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>, normal {{.*}}, error {{.*}}
12+
//
13+
@_silgen_name("test_nil_coalescing")
14+
@derivative(of: ??)
15+
@usableFromInline
16+
func nilCoalescing<T: Differentiable>(optional: T?, defaultValue: @autoclosure () throws -> T)
17+
rethrows -> (value: T, pullback: (T.TangentVector) -> Optional<T>.TangentVector)
18+
{
19+
let hasValue = optional != nil
20+
let value = try optional ?? defaultValue()
21+
func pullback(_ v: T.TangentVector) -> Optional<T>.TangentVector {
22+
return hasValue ? .init(v) : .zero
23+
}
24+
return (value, pullback)
25+
}

0 commit comments

Comments
 (0)