@@ -154,8 +154,13 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
154
154
let hasVerifier = 1;
155
155
}
156
156
157
- def Linalg_WinogradFilterTransformOp :
158
- Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
157
+ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158
+ [AllElementTypesMatch<["filter", "output"]>,
159
+ DeclareOpInterfaceMethods<TilingInterface,
160
+ ["getIterationDomain",
161
+ "getLoopIteratorTypes",
162
+ "getResultTilePosition",
163
+ "getTiledImplementation"]>]> {
159
164
let summary = "Winograd filter transform operator";
160
165
let description = [{
161
166
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -190,11 +195,42 @@ def Linalg_WinogradFilterTransformOp :
190
195
`outs` `(` $output `:` type($output) `)`
191
196
`->` type($result)
192
197
}];
198
+ let extraClassDeclaration = [{
199
+ ShapedType getFilterOperandType() {
200
+ return cast<ShapedType>(getFilter().getType());
201
+ }
202
+ ShapedType getOutputOperandType() {
203
+ return cast<ShapedType>(getOutput().getType());
204
+ }
205
+ int64_t getFilterOperandRank() {
206
+ return getFilterOperandType().getRank();
207
+ }
208
+ int64_t getOutputOperandRank() {
209
+ return getOutputOperandType().getRank();
210
+ }
211
+ int64_t getFilterFDim() {
212
+ return 0;
213
+ }
214
+ int64_t getFilterHDim() {
215
+ return 1;
216
+ }
217
+ int64_t getFilterWDim() {
218
+ return 2;
219
+ }
220
+ int64_t getFilterCDim() {
221
+ return 3;
222
+ }
223
+ }];
193
224
let hasVerifier = 1;
194
225
}
195
226
196
- def Linalg_WinogradInputTransformOp :
197
- Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
227
+ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
228
+ [AllElementTypesMatch<["input", "output"]>,
229
+ DeclareOpInterfaceMethods<TilingInterface,
230
+ ["getIterationDomain",
231
+ "getLoopIteratorTypes",
232
+ "getResultTilePosition",
233
+ "getTiledImplementation"]>]> {
198
234
let summary = "Winograd input transform operator";
199
235
let description = [{
200
236
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -229,11 +265,60 @@ def Linalg_WinogradInputTransformOp :
229
265
`outs` `(` $output `:` type($output) `)`
230
266
`->` type($result)
231
267
}];
268
+ let extraClassDeclaration = [{
269
+ ShapedType getInputOperandType() {
270
+ return cast<ShapedType>(getInput().getType());
271
+ }
272
+ ShapedType getOutputOperandType() {
273
+ return cast<ShapedType>(getOutput().getType());
274
+ }
275
+ int64_t getInputOperandRank() {
276
+ return getInputOperandType().getRank();
277
+ }
278
+ int64_t getOutputOperandRank() {
279
+ return getOutputOperandType().getRank();
280
+ }
281
+ int64_t getInputNDim() {
282
+ return 0;
283
+ }
284
+ int64_t getInputHDim() {
285
+ return 1;
286
+ }
287
+ int64_t getInputWDim() {
288
+ return 2;
289
+ }
290
+ int64_t getInputCDim() {
291
+ return 3;
292
+ }
293
+ int64_t getOutputAlphaHDim() {
294
+ return 0;
295
+ }
296
+ int64_t getOutputAlphaWDim() {
297
+ return 1;
298
+ }
299
+ int64_t getOutputTileHDim() {
300
+ return 2;
301
+ }
302
+ int64_t getOutputTileWDim() {
303
+ return 3;
304
+ }
305
+ int64_t getOutputNDim() {
306
+ return 4;
307
+ }
308
+ int64_t getOutputCDim() {
309
+ return 5;
310
+ }
311
+ }];
232
312
let hasVerifier = 1;
233
313
}
234
314
235
- def Linalg_WinogradOutputTransformOp :
236
- Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
315
+ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
316
+ [AllElementTypesMatch<["value", "output"]>,
317
+ DeclareOpInterfaceMethods<TilingInterface,
318
+ ["getIterationDomain",
319
+ "getLoopIteratorTypes",
320
+ "getResultTilePosition",
321
+ "getTiledImplementation"]>]> {
237
322
let summary = "Winograd output transform operator";
238
323
let description = [{
239
324
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -268,6 +353,50 @@ def Linalg_WinogradOutputTransformOp :
268
353
`outs` `(` $output `:` type($output) `)`
269
354
`->` type($result)
270
355
}];
356
+ let extraClassDeclaration = [{
357
+ ShapedType getValueOperandType() {
358
+ return cast<ShapedType>(getValue().getType());
359
+ }
360
+ ShapedType getOutputOperandType() {
361
+ return cast<ShapedType>(getOutput().getType());
362
+ }
363
+ int64_t getValueOperandRank() {
364
+ return getValueOperandType().getRank();
365
+ }
366
+ int64_t getOutputOperandRank() {
367
+ return getOutputOperandType().getRank();
368
+ }
369
+ int64_t getValueAlphaHDim() {
370
+ return 0;
371
+ }
372
+ int64_t getValueAlphaWDim() {
373
+ return 1;
374
+ }
375
+ int64_t getValueTileHDim() {
376
+ return 2;
377
+ }
378
+ int64_t getValueTileWDim() {
379
+ return 3;
380
+ }
381
+ int64_t getValueNDim() {
382
+ return 4;
383
+ }
384
+ int64_t getValueFDim() {
385
+ return 5;
386
+ }
387
+ int64_t getOutputNDim() {
388
+ return 0;
389
+ }
390
+ int64_t getOutputHDim() {
391
+ return 1;
392
+ }
393
+ int64_t getOutputWDim() {
394
+ return 2;
395
+ }
396
+ int64_t getOutputFDim() {
397
+ return 3;
398
+ }
399
+ }];
271
400
let hasVerifier = 1;
272
401
}
273
402
0 commit comments