Skip to content

[AutoDiff] Support custom derivatives for @_alwaysEmitIntoClient functions #78908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -4299,6 +4299,9 @@ NOTE(derivative_attr_fix_access,none,
"mark the derivative function as "
"'%select{private|fileprivate|internal|package|@usableFromInline|@usableFromInline}0' "
"to match the original function", (AccessLevel))
ERROR(derivative_attr_always_emit_into_client_mismatch,none,
"either both or none of derivative and original function must have "
"@alwaysEmitIntoClient attribute", ())
ERROR(derivative_attr_static_method_mismatch_original,none,
"unexpected derivative function declaration; "
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",
Expand Down
20 changes: 17 additions & 3 deletions lib/SIL/IR/Linker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,23 @@ void SILLinkerVisitor::maybeAddFunctionToWorklist(
// HiddenExternal linkage when they are declarations, then they
// become Shared after the body has been deserialized.
// So try deserializing HiddenExternal functions too.
if (linkage == SILLinkage::HiddenExternal)
return deserializeAndPushToWorklist(F);

if (linkage == SILLinkage::HiddenExternal) {
deserializeAndPushToWorklist(F);
if (!F->markedAsAlwaysEmitIntoClient())
return;
// For @_alwaysEmitIntoClient functions, we need to lookup its
// differentiability witness and, if present, ask SILLoader to obtain its
// definition. Otherwise, a linker error would occur due to undefined
// reference to these symbols.
for (SILDifferentiabilityWitness *witness :
F->getModule().lookUpDifferentiabilityWitnessesForFunction(
F->getName())) {
F->getModule().getSILLoader()->lookupDifferentiabilityWitness(
witness->getKey());
}
return;
}

// Update the linkage of the function in case it's different in the serialized
// SIL than derived from the AST. This can be the case with cross-module-
// optimizations.
Expand Down
13 changes: 9 additions & 4 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1400,14 +1400,19 @@ void SILGenModule::emitDifferentiabilityWitness(
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
if (!diffWitness) {
// Differentiability witnesses have the same linkage as the original
// function, stripping external.
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
// function, stripping external. For @_alwaysEmitIntoClient original
// functions, force PublicNonABI linkage of the differentiability witness so
// we can serialize it (the original function itself might be HiddenExternal
// in this case if we only have declaration without definition).
auto linkage =
originalFunction->markedAsAlwaysEmitIntoClient()
? SILLinkage::PublicNonABI
: stripExternalFromLinkage(originalFunction->getLinkage());
diffWitness = SILDifferentiabilityWitness::createDefinition(
M, linkage, originalFunction, diffKind, silConfig.parameterIndices,
silConfig.resultIndices, config.derivativeGenericSignature,
/*jvp*/ nullptr, /*vjp*/ nullptr,
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
attr);
/*isSerialized*/ hasPublicVisibility(linkage), attr);
}

// Set derivative function in differentiability witness.
Expand Down
10 changes: 8 additions & 2 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6281,8 +6281,14 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
auto loc = customDerivativeFn->getLocation();
SILGenFunctionBuilder fb(*this);
// Derivative thunks have the same linkage as the original function, stripping
// external.
auto linkage = stripExternalFromLinkage(originalFn->getLinkage());
// external. For @_alwaysEmitIntoClient original functions, force PublicNonABI
// linkage of derivative thunks so we can serialize them (the original
// function itself might be HiddenExternal in this case if we only have
// declaration without definition).
auto linkage = originalFn->markedAsAlwaysEmitIntoClient()
? SILLinkage::PublicNonABI
: stripExternalFromLinkage(originalFn->getLinkage());

auto *thunk = fb.getOrCreateFunction(
loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent,
customDerivativeFn->getSerializedKind(),
Expand Down
11 changes: 8 additions & 3 deletions lib/SILOptimizer/Differentiation/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
"definitions with explicit differentiable attributes");

return SILDifferentiabilityWitness::createDeclaration(
module, SILLinkage::PublicExternal, original, kind,
minimalConfig->parameterIndices, minimalConfig->resultIndices,
minimalConfig->derivativeGenericSignature);
module,
// Witness for @_alwaysEmitIntoClient original function must be emitted,
// otherwise a linker error would occur due to undefined reference to the
// witness symbol.
original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI
: SILLinkage::PublicExternal,
original, kind, minimalConfig->parameterIndices,
minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
}

} // end namespace autodiff
Expand Down
12 changes: 8 additions & 4 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,10 +911,14 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(

// We can generate empty JVP / VJP for functions available externally. These
// functions have the same linkage as the original ones sans `external`
// flag. Important exception here hidden_external functions as they are
// serializable but corresponding hidden ones would be not and the SIL
// verifier will fail. Patch `serializeFunctions` for this case.
if (orig->getLinkage() == SILLinkage::HiddenExternal)
// flag. Important exception here hidden_external non-@_alwaysEmitIntoClient
// functions as they are serializable but corresponding hidden ones would be
// not and the SIL verifier will fail. Patch `serializeFunctions` for this
// case. For @_alwaysEmitIntoClient original functions (which might be
// HiddenExternal if we only have declaration without definition), we want
// derivatives to be serialized and do not patch `serializeFunctions`.
if (orig->getLinkage() == SILLinkage::HiddenExternal &&
!orig->markedAsAlwaysEmitIntoClient())
serializeFunctions = IsNotSerialized;

// If the JVP doesn't exist, need to synthesize it.
Expand Down
7 changes: 7 additions & 0 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6825,6 +6825,13 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
return true;
}

if (originalAFD->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>() !=
derivative->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>()) {
diags.diagnose(derivative->getLoc(),
diag::derivative_attr_always_emit_into_client_mismatch);
return true;
}

// Get the resolved differentiability parameter indices.
auto *resolvedDiffParamIndices = attr->getParameterIndices();

Expand Down
6 changes: 2 additions & 4 deletions stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,6 @@ where
}
}

// FIXME(TF-1103): Derivative registration does not yet support
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
/*
extension SIMD
where
Self: Differentiable,
Expand All @@ -417,6 +414,7 @@ where
TangentVector == Self
{
@inlinable
@_alwaysEmitIntoClient
@derivative(of: sum)
func _vjpSum() -> (
value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
Expand All @@ -425,14 +423,14 @@ where
}

@inlinable
@_alwaysEmitIntoClient
@derivative(of: sum)
func _jvpSum() -> (
value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
) {
return (sum(), { v in Scalar.TangentVector(v.sum()) })
}
}
*/

extension SIMD
where
Expand Down
7 changes: 4 additions & 3 deletions test/AutoDiff/SILGen/nil_coalescing.swift
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify %s | %FileCheck %s
/// Note: -primary-file prevents non_abi->shared linkage change in `removeSerializedFlagFromAllFunctions`
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify -primary-file %s | %FileCheck %s

import _Differentiation

// CHECK: sil @test_nil_coalescing
// CHECK: sil non_abi @test_nil_coalescing
// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional<T>, %[[ARG_PB:.*]] :
// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>):
// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional<T>
Expand All @@ -15,7 +16,7 @@ import _Differentiation
//
@_silgen_name("test_nil_coalescing")
@derivative(of: ??)
@usableFromInline
@_alwaysEmitIntoClient
func nilCoalescing<T: Differentiable>(optional: T?, defaultValue: @autoclosure () throws -> T)
rethrows -> (value: T, pullback: (T.TangentVector) -> Optional<T>.TangentVector)
{
Expand Down
18 changes: 18 additions & 0 deletions test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,15 @@ func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullb
fatalError()
}

func internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
@_alwaysEmitIntoClient
@derivative(of: internal_original_alwaysemitintoclient_derivative_error)
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
func _internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

@_alwaysEmitIntoClient
func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
@_alwaysEmitIntoClient
@derivative(of: internal_original_alwaysemitintoclient_derivative)
Expand All @@ -1084,6 +1093,15 @@ package func _package_original_inlinable_derivative(_ x: Float) -> (value: Float
fatalError()
}

@_alwaysEmitIntoClient
package func package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
@derivative(of: package_original_alwaysemitintoclient_derivative_error)
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
package func _package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

@_alwaysEmitIntoClient
package func package_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
@_alwaysEmitIntoClient
@derivative(of: package_original_alwaysemitintoclient_derivative)
Expand Down
8 changes: 0 additions & 8 deletions test/AutoDiff/stdlib/simd.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ SIMDTests.test("init(repeating:)") {
expectEqual(8, pb1(g))
}

// FIXME(TF-1103): Derivative registration does not yet support
// `@_alwaysEmitIntoClient` original functions.
/*
SIMDTests.test("Sum") {
let a = SIMD4<Float>(1, 2, 3, 4)

Expand All @@ -32,7 +29,6 @@ SIMDTests.test("Sum") {
expectEqual(10, val1)
expectEqual(SIMD4<Float>(3, 3, 3, 3), pb1(3))
}
*/

SIMDTests.test("Identity") {
let a = SIMD4<Float>(1, 2, 3, 4)
Expand Down Expand Up @@ -289,9 +285,6 @@ SIMDTests.test("Generics") {
expectEqual(SIMD3<Double>(5, 10, 15), val4)
expectEqual((SIMD3<Double>(5, 5, 5), 6), pb4(g))

// FIXME(TF-1103): Derivative registration does not yet support
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
/*
func testSum<Scalar, SIMDType: SIMD>(x: SIMDType) -> Scalar
where SIMDType.Scalar == Scalar,
SIMDType : Differentiable,
Expand All @@ -304,7 +297,6 @@ SIMDTests.test("Generics") {
let (val5, pb5) = valueWithPullback(at: a, of: simd3Sum)
expectEqual(6, val5)
expectEqual(SIMD3<Double>(7, 7, 7), pb5(7))
*/
}

runAllTests()
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
@_alwaysEmitIntoClient
public func f(_ x: Float) -> Float {
x
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import _Differentiation

@derivative(of: f)
@_alwaysEmitIntoClient
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { 42 * $0 })
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
@_alwaysEmitIntoClient
public func f(_ x: Float) -> Float {
x
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import MultiModule1
import _Differentiation

@derivative(of: f)
@_alwaysEmitIntoClient
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { 42 * $0 })
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import _Differentiation

public protocol Protocol {
var x : Float {get set}
init()
}

extension Protocol {
public init(_ val: Float) {
self.init()
x = val
}

@_alwaysEmitIntoClient
public func sum() -> Float { x }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import MultiModuleProtocol1
import _Differentiation

extension Protocol where Self: Differentiable, Self.TangentVector == Self {
@_alwaysEmitIntoClient
@derivative(of: sum)
public func _vjpSum() -> (
value: Float, pullback: (Float) -> Self.TangentVector
) {
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
}

@_alwaysEmitIntoClient
@derivative(of: sum)
public func _jvpSum() -> (
value: Float, differential: (Self.TangentVector) -> Float
) {
(value: self.x, differential: { 42 * $0.x })
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import MultiModuleProtocol1
import MultiModuleProtocol2
import _Differentiation

public struct Struct : Protocol {
private var _x : Float
public var x : Float {
get { _x }
set { _x = newValue }
}
public init() { _x = 0 }
}

extension Struct : AdditiveArithmetic {
public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) }
public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) }
public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x }
public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x }
public static var zero: Self { Self(0) }
}

extension Struct : Differentiable {
public typealias TangentVector = Self
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
public struct Struct {
public var x : Float
public typealias TangentVector = Self
public init() { x = 0 }
}

extension Struct {
public init(_ val: Float) {
self.init()
x = val
}

@_alwaysEmitIntoClient
public func sum() -> Float { x }
}

extension Struct : AdditiveArithmetic {
public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) }
public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) }
public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x }
public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x }
public static var zero: Self { Self(0) }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import MultiModuleStruct1
import _Differentiation

extension Struct : Differentiable {
@_alwaysEmitIntoClient
@derivative(of: sum)
public func _vjpSum() -> (
value: Float, pullback: (Float) -> Self.TangentVector
) {
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
}

@_alwaysEmitIntoClient
@derivative(of: sum)
public func _jvpSum() -> (
value: Float, differential: (Self.TangentVector) -> Float
) {
(value: self.x, differential: { 42 * $0.x })
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import MultiModuleStruct1
import _Differentiation

extension Struct : Differentiable {
@_alwaysEmitIntoClient
@derivative(of: sum)
public func _vjpSum() -> (
value: Float, pullback: (Float) -> Self.TangentVector
) {
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
}
}
Loading