Skip to content

Commit cff144a

Browse files
authored
[AutoDiff][stdlib] Add JVPs to SIMDDifferentiation.swift.gyb (#32854)
1 parent f807297 commit cff144a

File tree

2 files changed

+410
-0
lines changed

2 files changed

+410
-0
lines changed

stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ where
5959
return zeros
6060
})
6161
}
62+
63+
@inlinable
64+
@derivative(of: subscript(_:))
65+
internal func _jvpSubscript(index: Int)
66+
-> (value: Scalar, differential: (TangentVector) -> Scalar.TangentVector)
67+
{
68+
return (self[index], { v in
69+
return .init(v[index])
70+
})
71+
}
6272
}
6373

6474
%end
@@ -82,6 +92,18 @@ where
8292
})
8393
}
8494

95+
@inlinable
96+
@derivative(of: +)
97+
static func _jvpAdd(lhs: Self, rhs: Self)
98+
-> (
99+
value: Self, differential: (TangentVector, TangentVector) -> TangentVector
100+
)
101+
{
102+
return (lhs + rhs, { ltan, rtan in
103+
return ltan + rtan
104+
})
105+
}
106+
85107
@inlinable
86108
@derivative(of: -)
87109
static func _vjpSubtract(lhs: Self, rhs: Self)
@@ -94,6 +116,18 @@ where
94116
})
95117
}
96118

119+
@inlinable
120+
@derivative(of: -)
121+
static func _jvpSubtract(lhs: Self, rhs: Self)
122+
-> (
123+
value: Self, differential: (TangentVector, TangentVector) -> TangentVector
124+
)
125+
{
126+
return (lhs - rhs, { ltan, rtan in
127+
return ltan - rtan
128+
})
129+
}
130+
97131
@inlinable
98132
@derivative(of: -)
99133
static func _vjpNegate(rhs: Self)
@@ -103,6 +137,16 @@ where
103137
return -v
104138
})
105139
}
140+
141+
@inlinable
142+
@derivative(of: -)
143+
static func _jvpNegate(rhs: Self)
144+
-> (value: Self, differential: (TangentVector) -> (TangentVector))
145+
{
146+
return (-rhs, { v in
147+
return -v
148+
})
149+
}
106150
}
107151

108152
extension SIMD
@@ -124,6 +168,18 @@ where
124168
})
125169
}
126170

171+
@inlinable
172+
@derivative(of: *)
173+
static func _jvpMultiply(lhs: Self, rhs: Self)
174+
-> (
175+
value: Self, differential: (TangentVector, TangentVector) -> TangentVector
176+
)
177+
{
178+
return (lhs * rhs, { ltan, rtan in
179+
return lhs * ltan + rtan * rhs
180+
})
181+
}
182+
127183
@inlinable
128184
@derivative(of: /)
129185
static func _vjpDivide(lhs: Self, rhs: Self)
@@ -135,6 +191,18 @@ where
135191
(v / rhs, -lhs / (rhs * rhs) * v)
136192
})
137193
}
194+
195+
@inlinable
196+
@derivative(of: /)
197+
static func _jvpDivide(lhs: Self, rhs: Self)
198+
-> (
199+
value: Self, differential: (TangentVector, TangentVector) -> TangentVector
200+
)
201+
{
202+
return ( lhs / rhs, { ltan, rtan in
203+
(ltan * rhs - lhs * rtan) / (rhs * rhs)
204+
})
205+
}
138206
}
139207

140208
extension SIMD
@@ -156,6 +224,17 @@ where
156224
})
157225
}
158226

227+
@inlinable
228+
@derivative(of: +)
229+
static func _jvpAdd(lhs: Scalar, rhs: Self) -> (
230+
value: Self,
231+
differential: (Scalar.TangentVector, TangentVector) -> TangentVector
232+
) {
233+
return (lhs + rhs, { ltan, rtan in
234+
return ltan + rtan
235+
})
236+
}
237+
159238
@inlinable
160239
@derivative(of: -)
161240
static func _vjpSubtract(lhs: Scalar, rhs: Self) -> (
@@ -167,6 +246,17 @@ where
167246
})
168247
}
169248

249+
@inlinable
250+
@derivative(of: -)
251+
static func _jvpSubtract(lhs: Scalar, rhs: Self) -> (
252+
value: Self,
253+
differential: (Scalar.TangentVector, TangentVector) -> TangentVector
254+
) {
255+
return (lhs - rhs, { ltan, rtan in
256+
return ltan - rtan
257+
})
258+
}
259+
170260
@inlinable
171261
@derivative(of: +)
172262
static func _vjpAdd(lhs: Self, rhs: Scalar) -> (
@@ -178,6 +268,17 @@ where
178268
})
179269
}
180270

271+
@inlinable
272+
@derivative(of: +)
273+
static func _jvpAdd(lhs: Self, rhs: Scalar) -> (
274+
value: Self,
275+
differential: (TangentVector, Scalar.TangentVector) -> TangentVector
276+
) {
277+
return (lhs + rhs, { ltan, rtan in
278+
return ltan + rtan
279+
})
280+
}
281+
181282
@inlinable
182283
@derivative(of: -)
183284
static func _vjpSubtract(lhs: Self, rhs: Scalar) -> (
@@ -188,6 +289,17 @@ where
188289
return (v, -v.sum())
189290
})
190291
}
292+
293+
@inlinable
294+
@derivative(of: -)
295+
static func _jvpSubtract(lhs: Self, rhs: Scalar) -> (
296+
value: Self,
297+
differential: (TangentVector, Scalar.TangentVector) -> TangentVector
298+
) {
299+
return (lhs - rhs, { ltan, rtan in
300+
return ltan - rtan
301+
})
302+
}
191303
}
192304

193305
extension SIMD
@@ -209,6 +321,17 @@ where
209321
})
210322
}
211323

324+
@inlinable
325+
@derivative(of: *)
326+
static func _jvpMultiply(lhs: Self, rhs: Scalar) -> (
327+
value: Self,
328+
differential: (TangentVector, Scalar.TangentVector) -> TangentVector
329+
) {
330+
return (lhs * rhs, { ltan, rtan in
331+
return lhs * rtan + ltan * rhs
332+
})
333+
}
334+
212335
@inlinable
213336
@derivative(of: /)
214337
static func _vjpDivide(lhs: Self, rhs: Scalar) -> (
@@ -220,6 +343,17 @@ where
220343
})
221344
}
222345

346+
@inlinable
347+
@derivative(of: /)
348+
static func _jvpDivide(lhs: Self, rhs: Scalar) -> (
349+
value: Self,
350+
differential: (TangentVector, Scalar.TangentVector) -> TangentVector
351+
) {
352+
return (lhs / rhs, { ltan, rtan in
353+
(ltan * rhs - lhs * rtan) / (rhs * rhs)
354+
})
355+
}
356+
223357
@inlinable
224358
@derivative(of: *)
225359
static func _vjpMultiply(lhs: Scalar, rhs: Self) -> (
@@ -231,6 +365,17 @@ where
231365
})
232366
}
233367

368+
@inlinable
369+
@derivative(of: *)
370+
static func _jvpMultiply(lhs: Scalar, rhs: Self) -> (
371+
value: Self,
372+
differential: (Scalar.TangentVector, TangentVector) -> TangentVector
373+
) {
374+
return (lhs * rhs, { ltan, rtan in
375+
return lhs * rtan + ltan * rhs
376+
})
377+
}
378+
234379
@inlinable
235380
@derivative(of: /)
236381
static func _vjpDivide(lhs: Scalar, rhs: Self) -> (
@@ -241,6 +386,17 @@ where
241386
((v / rhs).sum(), -lhs / (rhs * rhs) * v)
242387
})
243388
}
389+
390+
@inlinable
391+
@derivative(of: /)
392+
static func _jvpDivide(lhs: Scalar, rhs: Self) -> (
393+
value: Self,
394+
differential: (Scalar.TangentVector, TangentVector) -> TangentVector
395+
) {
396+
return (lhs / rhs, { ltan, rtan in
397+
(ltan * rhs - lhs * rtan) / (rhs * rhs)
398+
})
399+
}
244400
}
245401

246402
// FIXME(TF-1103): Derivative registration does not yet support
@@ -261,6 +417,14 @@ where
261417
) {
262418
return (sum(), { v in Self(repeating: Scalar(v)) })
263419
}
420+
421+
@inlinable
422+
@derivative(of: sum)
423+
func _jvpSum() -> (
424+
value: Scalar, differential (TangentVector) -> Scalar.TangentVector
425+
) {
426+
return (sum(), { v in v.sum() }
427+
}
264428
}
265429
*/
266430

@@ -279,4 +443,12 @@ where
279443
{
280444
return (Self(repeating: value), { v in v.sum() })
281445
}
446+
447+
@inlinable
448+
@derivative(of: init(repeating:))
449+
static func _jvpInit(repeating value: Scalar)
450+
-> (value: Self, differential: (Scalar.TangentVector) -> TangentVector)
451+
{
452+
return (Self(repeating: value), { v in Self(repeating: v) })
453+
}
282454
}

0 commit comments

Comments
 (0)