59
59
return zeros
60
60
} )
61
61
}
62
+
63
+ @inlinable
64
+ @derivative ( of: subscript ( _: ) )
65
+ internal func _jvpSubscript( index: Int )
66
+ -> ( value: Scalar , differential: ( TangentVector ) -> Scalar . TangentVector )
67
+ {
68
+ return ( self [ index] , { v in
69
+ return . init( v [ index] )
70
+ } )
71
+ }
62
72
}
63
73
64
74
% end
82
92
} )
83
93
}
84
94
95
+ @inlinable
96
+ @derivative ( of: + )
97
+ static func _jvpAdd( lhs: Self , rhs: Self )
98
+ -> (
99
+ value: Self , differential: ( TangentVector , TangentVector ) -> TangentVector
100
+ )
101
+ {
102
+ return ( lhs + rhs, { ltan, rtan in
103
+ return ltan + rtan
104
+ } )
105
+ }
106
+
85
107
@inlinable
86
108
@derivative ( of: - )
87
109
static func _vjpSubtract( lhs: Self , rhs: Self )
@@ -94,6 +116,18 @@ where
94
116
} )
95
117
}
96
118
119
+ @inlinable
120
+ @derivative ( of: - )
121
+ static func _jvpSubtract( lhs: Self , rhs: Self )
122
+ -> (
123
+ value: Self , differential: ( TangentVector , TangentVector ) -> TangentVector
124
+ )
125
+ {
126
+ return ( lhs - rhs, { ltan, rtan in
127
+ return ltan - rtan
128
+ } )
129
+ }
130
+
97
131
@inlinable
98
132
@derivative ( of: - )
99
133
static func _vjpNegate( rhs: Self )
@@ -103,6 +137,16 @@ where
103
137
return - v
104
138
} )
105
139
}
140
+
141
+ @inlinable
142
+ @derivative ( of: - )
143
+ static func _jvpNegate( rhs: Self )
144
+ -> ( value: Self , differential: ( TangentVector ) -> ( TangentVector ) )
145
+ {
146
+ return ( - rhs, { v in
147
+ return - v
148
+ } )
149
+ }
106
150
}
107
151
108
152
extension SIMD
@@ -124,6 +168,18 @@ where
124
168
} )
125
169
}
126
170
171
+ @inlinable
172
+ @derivative ( of: * )
173
+ static func _jvpMultiply( lhs: Self , rhs: Self )
174
+ -> (
175
+ value: Self , differential: ( TangentVector , TangentVector ) -> TangentVector
176
+ )
177
+ {
178
+ return ( lhs * rhs, { ltan, rtan in
179
+ return lhs * ltan + rtan * rhs
180
+ } )
181
+ }
182
+
127
183
@inlinable
128
184
@derivative ( of: / )
129
185
static func _vjpDivide( lhs: Self , rhs: Self )
@@ -135,6 +191,18 @@ where
135
191
( v / rhs, - lhs / ( rhs * rhs) * v)
136
192
} )
137
193
}
194
+
195
+ @inlinable
196
+ @derivative ( of: / )
197
+ static func _jvpDivide( lhs: Self , rhs: Self )
198
+ -> (
199
+ value: Self , differential: ( TangentVector , TangentVector ) -> TangentVector
200
+ )
201
+ {
202
+ return ( lhs / rhs, { ltan, rtan in
203
+ ( ltan * rhs - lhs * rtan) / ( rhs * rhs)
204
+ } )
205
+ }
138
206
}
139
207
140
208
extension SIMD
@@ -156,6 +224,17 @@ where
156
224
} )
157
225
}
158
226
227
+ @inlinable
228
+ @derivative ( of: + )
229
+ static func _jvpAdd( lhs: Scalar , rhs: Self ) -> (
230
+ value: Self ,
231
+ differential: ( Scalar . TangentVector , TangentVector ) -> TangentVector
232
+ ) {
233
+ return ( lhs + rhs, { ltan, rtan in
234
+ return ltan + rtan
235
+ } )
236
+ }
237
+
159
238
@inlinable
160
239
@derivative ( of: - )
161
240
static func _vjpSubtract( lhs: Scalar , rhs: Self ) -> (
@@ -167,6 +246,17 @@ where
167
246
} )
168
247
}
169
248
249
+ @inlinable
250
+ @derivative ( of: - )
251
+ static func _jvpSubtract( lhs: Scalar , rhs: Self ) -> (
252
+ value: Self ,
253
+ differential: ( Scalar . TangentVector , TangentVector ) -> TangentVector
254
+ ) {
255
+ return ( lhs - rhs, { ltan, rtan in
256
+ return ltan - rtan
257
+ } )
258
+ }
259
+
170
260
@inlinable
171
261
@derivative ( of: + )
172
262
static func _vjpAdd( lhs: Self , rhs: Scalar ) -> (
@@ -178,6 +268,17 @@ where
178
268
} )
179
269
}
180
270
271
+ @inlinable
272
+ @derivative ( of: + )
273
+ static func _jvpAdd( lhs: Self , rhs: Scalar ) -> (
274
+ value: Self ,
275
+ differential: ( TangentVector , Scalar . TangentVector ) -> TangentVector
276
+ ) {
277
+ return ( lhs + rhs, { ltan, rtan in
278
+ return ltan + rtan
279
+ } )
280
+ }
281
+
181
282
@inlinable
182
283
@derivative ( of: - )
183
284
static func _vjpSubtract( lhs: Self , rhs: Scalar ) -> (
@@ -188,6 +289,17 @@ where
188
289
return ( v, - v. sum ( ) )
189
290
} )
190
291
}
292
+
293
+ @inlinable
294
+ @derivative ( of: - )
295
+ static func _jvpSubtract( lhs: Self , rhs: Scalar ) -> (
296
+ value: Self ,
297
+ differential: ( TangentVector , Scalar . TangentVector ) -> TangentVector
298
+ ) {
299
+ return ( lhs - rhs, { ltan, rtan in
300
+ return ltan - rtan
301
+ } )
302
+ }
191
303
}
192
304
193
305
extension SIMD
@@ -209,6 +321,17 @@ where
209
321
} )
210
322
}
211
323
324
+ @inlinable
325
+ @derivative ( of: * )
326
+ static func _jvpMultiply( lhs: Self , rhs: Scalar ) -> (
327
+ value: Self ,
328
+ differential: ( TangentVector , Scalar . TangentVector ) -> TangentVector
329
+ ) {
330
+ return ( lhs * rhs, { ltan, rtan in
331
+ return lhs * rtan + ltan * rhs
332
+ } )
333
+ }
334
+
212
335
@inlinable
213
336
@derivative ( of: / )
214
337
static func _vjpDivide( lhs: Self , rhs: Scalar ) -> (
@@ -220,6 +343,17 @@ where
220
343
} )
221
344
}
222
345
346
+ @inlinable
347
+ @derivative ( of: / )
348
+ static func _jvpDivide( lhs: Self , rhs: Scalar ) -> (
349
+ value: Self ,
350
+ differential: ( TangentVector , Scalar . TangentVector ) -> TangentVector
351
+ ) {
352
+ return ( lhs / rhs, { ltan, rtan in
353
+ ( ltan * rhs - lhs * rtan) / ( rhs * rhs)
354
+ } )
355
+ }
356
+
223
357
@inlinable
224
358
@derivative ( of: * )
225
359
static func _vjpMultiply( lhs: Scalar , rhs: Self ) -> (
@@ -231,6 +365,17 @@ where
231
365
} )
232
366
}
233
367
368
+ @inlinable
369
+ @derivative ( of: * )
370
+ static func _jvpMultiply( lhs: Scalar , rhs: Self ) -> (
371
+ value: Self ,
372
+ differential: ( Scalar . TangentVector , TangentVector ) -> TangentVector
373
+ ) {
374
+ return ( lhs * rhs, { ltan, rtan in
375
+ return lhs * rtan + ltan * rhs
376
+ } )
377
+ }
378
+
234
379
@inlinable
235
380
@derivative ( of: / )
236
381
static func _vjpDivide( lhs: Scalar , rhs: Self ) -> (
@@ -241,6 +386,17 @@ where
241
386
( ( v / rhs) . sum ( ) , - lhs / ( rhs * rhs) * v)
242
387
} )
243
388
}
389
+
390
+ @inlinable
391
+ @derivative ( of: / )
392
+ static func _jvpDivide( lhs: Scalar , rhs: Self ) -> (
393
+ value: Self ,
394
+ differential: ( Scalar . TangentVector , TangentVector ) -> TangentVector
395
+ ) {
396
+ return ( lhs / rhs, { ltan, rtan in
397
+ ( ltan * rhs - lhs * rtan) / ( rhs * rhs)
398
+ } )
399
+ }
244
400
}
245
401
246
402
// FIXME(TF-1103): Derivative registration does not yet support
@@ -261,6 +417,14 @@ where
261
417
) {
262
418
return (sum(), { v in Self(repeating: Scalar(v)) })
263
419
}
420
+
421
+ @inlinable
422
+ @derivative(of: sum)
423
+ func _jvpSum() -> (
424
+ value: Scalar, differential (TangentVector) -> Scalar.TangentVector
425
+ ) {
426
+ return (sum(), { v in v.sum() }
427
+ }
264
428
}
265
429
*/
266
430
@@ -279,4 +443,12 @@ where
279
443
{
280
444
return ( Self ( repeating: value) , { v in v. sum ( ) } )
281
445
}
446
+
447
+ @inlinable
448
+ @derivative ( of: init ( repeating: ) )
449
+ static func _jvpInit( repeating value: Scalar )
450
+ -> ( value: Self , differential: ( Scalar . TangentVector ) -> TangentVector )
451
+ {
452
+ return ( Self ( repeating: value) , { v in Self ( repeating: v) } )
453
+ }
282
454
}
0 commit comments