@@ -289,222 +289,6 @@ func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tenso
289
289
290
290
// -----
291
291
292
- // ragged_dot mode 1: [b,m,k], [g,b,k,n], [g] -> [b,m,n]
293
- func.func @ragged_dot_non_contracting (%lhs : tensor <2 x11 x5 xf32 >, %rhs : tensor <3 x2 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <2 x11 x7 xf32 > {
294
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
295
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
296
- lhs_batching_dimensions = [0 ],
297
- rhs_batching_dimensions = [1 ],
298
- lhs_contracting_dimensions = [2 ],
299
- rhs_contracting_dimensions = [2 ],
300
- lhs_ragged_dimensions = [1 ],
301
- rhs_group_dimensions = [0 ]
302
- >,
303
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
304
- } : (tensor <2 x11 x5 xf32 >, tensor <3 x2 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <2 x11 x7 xf32 >
305
- func.return %0 : tensor <2 x11 x7 xf32 >
306
- }
307
-
308
- // -----
309
-
310
- // ragged_dot mode 2: [m,k], [k,n], [g] -> [g,m,n]
311
- func.func @ragged_dot_contracting (%lhs : tensor <2 x11 x5 xf32 >, %rhs : tensor <2 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <3 x2 x11 x7 xf32 > {
312
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
313
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
314
- lhs_batching_dimensions = [0 ],
315
- rhs_batching_dimensions = [0 ],
316
- lhs_contracting_dimensions = [2 ],
317
- rhs_contracting_dimensions = [1 ],
318
- lhs_ragged_dimensions = [2 ],
319
- rhs_group_dimensions = []
320
- >,
321
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
322
- } : (tensor <2 x11 x5 xf32 >, tensor <2 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <3 x2 x11 x7 xf32 >
323
- func.return %0 : tensor <3 x2 x11 x7 xf32 >
324
- }
325
-
326
- // -----
327
-
328
- // ragged_dot mode 3: [b,m,k], [b,k,n], [g] -> [b,m,n]
329
- func.func @ragged_dot_batch (%lhs : tensor <3 x11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <3 x11 x7 xf32 > {
330
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
331
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
332
- lhs_batching_dimensions = [0 ],
333
- rhs_batching_dimensions = [0 ],
334
- lhs_contracting_dimensions = [2 ],
335
- rhs_contracting_dimensions = [1 ],
336
- lhs_ragged_dimensions = [0 ],
337
- rhs_group_dimensions = []
338
- >,
339
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
340
- } : (tensor <3 x11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <3 x11 x7 xf32 >
341
- func.return %0 : tensor <3 x11 x7 xf32 >
342
- }
343
-
344
- // -----
345
-
346
- func.func @ragged_dot_incompatible_contracting_dims (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <3 x2 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <11 x7 xf32 > {
347
- // @expected-error@+1 {{contracting dimension sizes must match}}
348
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
349
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
350
- lhs_batching_dimensions = [],
351
- rhs_batching_dimensions = [],
352
- lhs_contracting_dimensions = [1 ],
353
- rhs_contracting_dimensions = [1 ],
354
- lhs_ragged_dimensions = [0 ],
355
- rhs_group_dimensions = [0 ]
356
- >,
357
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
358
- } : (tensor <11 x5 xf32 >, tensor <3 x2 x7 xf32 >, tensor <3 xi64 >) -> tensor <11 x7 xf32 >
359
- func.return %0 : tensor <11 x7 xf32 >
360
- }
361
-
362
- // -----
363
-
364
- func.func @ragged_dot_group_sizes_incorrect_rank (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <3 x2 xi64 >) -> tensor <11 x7 xf32 > {
365
- // @expected-error@+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}}
366
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
367
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
368
- lhs_batching_dimensions = [],
369
- rhs_batching_dimensions = [],
370
- lhs_contracting_dimensions = [1 ],
371
- rhs_contracting_dimensions = [1 ],
372
- lhs_ragged_dimensions = [0 ],
373
- rhs_group_dimensions = [0 ]
374
- >,
375
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
376
- } : (tensor <11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <3 x2 xi64 >) -> tensor <11 x7 xf32 >
377
- func.return %0 : tensor <11 x7 xf32 >
378
- }
379
-
380
- // -----
381
-
382
- func.func @ragged_dot_group_sizes_incorrect_shape (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <2 xi64 >) -> tensor <11 x7 xf32 > {
383
- // @expected-error@+1 {{group_sizes is expected to have shape=[3], got [2]}}
384
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
385
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
386
- lhs_batching_dimensions = [],
387
- rhs_batching_dimensions = [],
388
- lhs_contracting_dimensions = [1 ],
389
- rhs_contracting_dimensions = [1 ],
390
- lhs_ragged_dimensions = [0 ],
391
- rhs_group_dimensions = [0 ]
392
- >,
393
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
394
- } : (tensor <11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <2 xi64 >) -> tensor <11 x7 xf32 >
395
- func.return %0 : tensor <11 x7 xf32 >
396
- }
397
-
398
- // -----
399
-
400
- func.func @ragged_dot_incorrect_number_of_lhs_ragged_dimensions (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <11 x7 xf32 > {
401
- // @expected-error@+1 {{There must be exactly one ragged dimension in the lhs}}
402
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
403
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
404
- lhs_batching_dimensions = [],
405
- rhs_batching_dimensions = [],
406
- lhs_contracting_dimensions = [1 ],
407
- rhs_contracting_dimensions = [1 ],
408
- lhs_ragged_dimensions = [0 , 1 ],
409
- rhs_group_dimensions = [0 ]
410
- >,
411
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
412
- } : (tensor <11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <11 x7 xf32 >
413
- func.return %0 : tensor <11 x7 xf32 >
414
- }
415
-
416
- // -----
417
-
418
- func.func @ragged_dot_rhs_group_dim_is_batch (%lhs : tensor <3 x11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <3 x11 x7 xf32 > {
419
- // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_batching_dimensions: 0}}
420
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
421
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
422
- lhs_batching_dimensions = [0 ],
423
- rhs_batching_dimensions = [0 ],
424
- lhs_contracting_dimensions = [2 ],
425
- rhs_contracting_dimensions = [1 ],
426
- lhs_ragged_dimensions = [1 ],
427
- rhs_group_dimensions = [0 ]
428
- >,
429
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
430
- } : (tensor <3 x11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <3 x11 x7 xf32 >
431
- func.return %0 : tensor <3 x11 x7 xf32 >
432
- }
433
-
434
- // -----
435
-
436
- func.func @ragged_dot_rhs_group_dim_is_contracting (%lhs : tensor <11 x3 xf32 >, %rhs : tensor <3 x3 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <11 x7 xf32 > {
437
- // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_contracting_dimensions: 1}}
438
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
439
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
440
- lhs_batching_dimensions = [],
441
- rhs_batching_dimensions = [],
442
- lhs_contracting_dimensions = [1 ],
443
- rhs_contracting_dimensions = [1 ],
444
- lhs_ragged_dimensions = [0 ],
445
- rhs_group_dimensions = [1 ]
446
- >,
447
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
448
- } : (tensor <11 x3 xf32 >, tensor <3 x3 x7 xf32 >, tensor <3 xi64 >) -> tensor <11 x7 xf32 >
449
- func.return %0 : tensor <11 x7 xf32 >
450
- }
451
-
452
- // -----
453
-
454
- func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_batch (%lhs : tensor <2 x11 x5 xf32 >, %rhs : tensor <3 x2 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <2 x11 x7 xf32 > {
455
- // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}}
456
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
457
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
458
- lhs_batching_dimensions = [0 ],
459
- rhs_batching_dimensions = [1 ],
460
- lhs_contracting_dimensions = [2 ],
461
- rhs_contracting_dimensions = [2 ],
462
- lhs_ragged_dimensions = [0 ],
463
- rhs_group_dimensions = [0 ]
464
- >,
465
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
466
- } : (tensor <2 x11 x5 xf32 >, tensor <3 x2 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <2 x11 x7 xf32 >
467
- func.return %0 : tensor <2 x11 x7 xf32 >
468
- }
469
-
470
- // -----
471
-
472
- func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_contracting (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <11 x7 xf32 > {
473
- // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}}
474
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
475
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
476
- lhs_batching_dimensions = [],
477
- rhs_batching_dimensions = [],
478
- lhs_contracting_dimensions = [1 ],
479
- rhs_contracting_dimensions = [1 ],
480
- lhs_ragged_dimensions = [1 ],
481
- rhs_group_dimensions = [0 ]
482
- >,
483
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
484
- } : (tensor <11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <11 x7 xf32 >
485
- func.return %0 : tensor <11 x7 xf32 >
486
- }
487
-
488
- // -----
489
-
490
- func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <11 x7 xf32 > {
491
- // @expected-error@+1 {{There must be exactly one group dimension in the rhs when the lhs ragged dimension is non-contracting}}
492
- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
493
- ragged_dot_dimension_numbers = #chlo.ragged_dot <
494
- lhs_batching_dimensions = [],
495
- rhs_batching_dimensions = [],
496
- lhs_contracting_dimensions = [1 ],
497
- rhs_contracting_dimensions = [0 ],
498
- lhs_ragged_dimensions = [0 ],
499
- rhs_group_dimensions = []
500
- >,
501
- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
502
- } : (tensor <11 x5 xf32 >, tensor <5 x7 xf32 >, tensor <3 xi64 >) -> tensor <11 x7 xf32 >
503
- func.return %0 : tensor <11 x7 xf32 >
504
- }
505
-
506
- // -----
507
-
508
292
func.func @top_k (%arg0 : tensor <f32 >) {
509
293
// expected-error @+2 {{failed to infer returned types}}
510
294
// @expected-error @+1{{operand's rank must be at least 1}}
0 commit comments