Skip to content

Commit dcd164f

Browse files
authored
Merge pull request #28756 from rxwei/default-derivative
2 parents aa9281a + aa5ad26 commit dcd164f

File tree

1 file changed

+77
-7
lines changed

1 file changed

+77
-7
lines changed

docs/DifferentiableProgramming.md

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,8 +1452,11 @@ making the other function linear.
14521452

14531453
A protocol requirement or class method/property/subscript can be made
14541454
differentiable via a derivative function or transpose function defined in an
1455-
extension. A dispatched call to such a member can be differentiated even if the
1456-
concrete implementation is not differentiable.
1455+
extension. When a protocol requirement is not marked with `@differentiable` but
1456+
has been made differentiable by a `@derivative` or `@transpose` declaration in a
1457+
protocol extension, a dispatched call to such a member can be differentiated,
1458+
and the derivative or transpose is always the one provided in the protocol
1459+
extension.
14571460

14581461
#### Linear maps
14591462

@@ -1732,8 +1735,8 @@ where Self: Differentiable & FloatingPoint, Self == Self.TangentVector {
17321735

17331736
@inlinable
17341737
@derivative(of: log)
1735-
static func _(_ x: Self) -> (value: Self, differential: @differentiable(linear) (Self) -> Self) { dx in
1736-
(log(x), { 1 / x * dx })
1738+
static func _(_ x: Self) -> (value: Self, differential: @differentiable(linear) (Self) -> Self) {
1739+
(log(x), { dx in 1 / x * dx })
17371740
}
17381741

17391742
@inlinable
@@ -1750,6 +1753,73 @@ where Self: Differentiable & FloatingPoint, Self == Self.TangentVector {
17501753
}
17511754
```
17521755

1756+
#### Default derivatives
1757+
1758+
In a protocol extension, class definition, or class extension, providing a
1759+
derivative or transpose for a protocol extension or a non-final class member is
1760+
considered as providing a default derivative for that member. Types that conform
1761+
to the protocol or inherit from the class can inherit the default derivative.
1762+
1763+
If the original member does not have a `@differentiable` attribute, a default
1764+
derivative is implicitly added to all conforming/overriding implementations.
1765+
1766+
```swift
1767+
protocol P {
1768+
func foo(_ x: Float) -> Float
1769+
}
1770+
1771+
extension P {
1772+
@derivative(of: foo(x:))
1773+
func _(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
1774+
(value: foo(x), differential: { _ in 42 })
1775+
}
1776+
}
1777+
1778+
struct S: P {
1779+
func foo(_ x: Float) -> Float {
1780+
33
1781+
}
1782+
}
1783+
1784+
let s = S()
1785+
let d = derivative(at: 0) { x in
1786+
s.foo(x)
1787+
} // ==> 42
1788+
```
1789+
1790+
When a protocol requirement or class member is marked with `@differentiable`, it
1791+
is considered as a _differentiability customization point_. This means that all
1792+
conforming/overriding implementation must provide a corresponding
1793+
`@differentiable` attribute, which causes the implementation to be
1794+
differentiated. To inherit the default derivative without differentiating the
1795+
implementation, add `default` to the `@differentiable` attribute.
1796+
1797+
```swift
1798+
protocol P {
1799+
@differentiable
1800+
func foo(_ x: Float) -> Float
1801+
}
1802+
1803+
extension P {
1804+
@derivative(of: foo(x:))
1805+
func _(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
1806+
(value: foo(x), differential: { _ in 42 })
1807+
}
1808+
}
1809+
1810+
struct S: P {
1811+
@differentiable(default) // Inherits default derivative for `P.foo(_:)`.
1812+
func foo(_ x: Float) -> Float {
1813+
33
1814+
}
1815+
}
1816+
1817+
let s = S()
1818+
let d = derivative(at: 0) { x in
1819+
s.foo(x)
1820+
} // ==> 42
1821+
```
1822+
17531823
### Differentiable function types
17541824

17551825
Differentiability is a fundamental mathematical concept that applies not only to
@@ -2241,13 +2311,13 @@ whether the derivative is always zero and warns the user.
22412311

22422312
```swift
22432313
let grad = gradient(at: 1.0) { x in
2244-
3.squareRoot()
2314+
Double(3).squareRoot()
22452315
}
22462316
```
22472317

22482318
```console
2249-
test.swift:4:18: warning: result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()' to make it explicit?
2250-
3.squareRoot()
2319+
test.swift:4:18: warning: result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)' to make it explicit?
2320+
Double(3).squareRoot()
22512321
^
22522322
withoutDerivative(at:)
22532323
```

0 commit comments

Comments
 (0)