
Commit 38bb2f9

Cleaning up tests in chlo_ops.mlir backported twice (#2676)
1 parent 8c7946c commit 38bb2f9

File tree

1 file changed: +0, -216 lines

stablehlo/tests/ops_chlo.mlir

Lines changed: 0 additions & 216 deletions
@@ -289,222 +289,6 @@ func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tenso
 
 // -----
 
-// ragged_dot mode 1: [b,m,k], [g,b,k,n], [g] -> [b,m,n]
-func.func @ragged_dot_non_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> {
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [0],
-      rhs_batching_dimensions = [1],
-      lhs_contracting_dimensions = [2],
-      rhs_contracting_dimensions = [2],
-      lhs_ragged_dimensions = [1],
-      rhs_group_dimensions = [0]
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32>
-  func.return %0 : tensor<2x11x7xf32>
-}
-
-// -----
-
-// ragged_dot mode 2: [m,k], [k,n], [g] -> [g,m,n]
-func.func @ragged_dot_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x2x11x7xf32> {
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [0],
-      rhs_batching_dimensions = [0],
-      lhs_contracting_dimensions = [2],
-      rhs_contracting_dimensions = [1],
-      lhs_ragged_dimensions = [2],
-      rhs_group_dimensions = []
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<2x11x5xf32>, tensor<2x5x7xf32>, tensor<3xi64>) -> tensor<3x2x11x7xf32>
-  func.return %0 : tensor<3x2x11x7xf32>
-}
-
-// -----
-
-// ragged_dot mode 3: [b,m,k], [b,k,n], [g] -> [b,m,n]
-func.func @ragged_dot_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> {
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [0],
-      rhs_batching_dimensions = [0],
-      lhs_contracting_dimensions = [2],
-      rhs_contracting_dimensions = [1],
-      lhs_ragged_dimensions = [0],
-      rhs_group_dimensions = []
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32>
-  func.return %0 : tensor<3x11x7xf32>
-}
-
-// -----
-
-func.func @ragged_dot_incompatible_contracting_dims(%lhs : tensor<11x5xf32>, %rhs : tensor<3x2x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
-  // @expected-error@+1 {{contracting dimension sizes must match}}
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [],
-      rhs_batching_dimensions = [],
-      lhs_contracting_dimensions = [1],
-      rhs_contracting_dimensions = [1],
-      lhs_ragged_dimensions = [0],
-      rhs_group_dimensions = [0]
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<11x5xf32>, tensor<3x2x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
-  func.return %0 : tensor<11x7xf32>
-}
-
-// -----
-
-func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3x2xi64>) -> tensor<11x7xf32> {
-  // @expected-error@+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}}
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [],
-      rhs_batching_dimensions = [],
-      lhs_contracting_dimensions = [1],
-      rhs_contracting_dimensions = [1],
-      lhs_ragged_dimensions = [0],
-      rhs_group_dimensions = [0]
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3x2xi64>) -> tensor<11x7xf32>
-  func.return %0 : tensor<11x7xf32>
-}
-
-// -----
-
-func.func @ragged_dot_group_sizes_incorrect_shape(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> {
-  // @expected-error@+1 {{group_sizes is expected to have shape=[3], got [2]}}
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [],
-      rhs_batching_dimensions = [],
-      lhs_contracting_dimensions = [1],
-      rhs_contracting_dimensions = [1],
-      lhs_ragged_dimensions = [0],
-      rhs_group_dimensions = [0]
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<2xi64>) -> tensor<11x7xf32>
-  func.return %0 : tensor<11x7xf32>
-}
-
-// -----
-
-func.func @ragged_dot_incorrect_number_of_lhs_ragged_dimensions(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
-  // @expected-error@+1 {{There must be exactly one ragged dimension in the lhs}}
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [],
-      rhs_batching_dimensions = [],
-      lhs_contracting_dimensions = [1],
-      rhs_contracting_dimensions = [1],
-      lhs_ragged_dimensions = [0, 1],
-      rhs_group_dimensions = [0]
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
-  func.return %0 : tensor<11x7xf32>
-}
-
-// -----
-
-func.func @ragged_dot_rhs_group_dim_is_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> {
-  // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_batching_dimensions: 0}}
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [0],
-      rhs_batching_dimensions = [0],
-      lhs_contracting_dimensions = [2],
-      rhs_contracting_dimensions = [1],
-      lhs_ragged_dimensions = [1],
-      rhs_group_dimensions = [0]
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32>
-  func.return %0 : tensor<3x11x7xf32>
-}
-
-// -----
-
-func.func @ragged_dot_rhs_group_dim_is_contracting(%lhs : tensor<11x3xf32>, %rhs : tensor<3x3x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
-  // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_contracting_dimensions: 1}}
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [],
-      rhs_batching_dimensions = [],
-      lhs_contracting_dimensions = [1],
-      rhs_contracting_dimensions = [1],
-      lhs_ragged_dimensions = [0],
-      rhs_group_dimensions = [1]
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<11x3xf32>, tensor<3x3x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
-  func.return %0 : tensor<11x7xf32>
-}
-
-// -----
-
-func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_batch(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> {
-  // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}}
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [0],
-      rhs_batching_dimensions = [1],
-      lhs_contracting_dimensions = [2],
-      rhs_contracting_dimensions = [2],
-      lhs_ragged_dimensions = [0],
-      rhs_group_dimensions = [0]
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32>
-  func.return %0 : tensor<2x11x7xf32>
-}
-
-// -----
-
-func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_contracting(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
-  // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}}
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [],
-      rhs_batching_dimensions = [],
-      lhs_contracting_dimensions = [1],
-      rhs_contracting_dimensions = [1],
-      lhs_ragged_dimensions = [1],
-      rhs_group_dimensions = [0]
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
-  func.return %0 : tensor<11x7xf32>
-}
-
-// -----
-
-func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tensor<11x5xf32>, %rhs : tensor<5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
-  // @expected-error@+1 {{There must be exactly one group dimension in the rhs when the lhs ragged dimension is non-contracting}}
-  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
-    ragged_dot_dimension_numbers = #chlo.ragged_dot<
-      lhs_batching_dimensions = [],
-      rhs_batching_dimensions = [],
-      lhs_contracting_dimensions = [1],
-      rhs_contracting_dimensions = [0],
-      lhs_ragged_dimensions = [0],
-      rhs_group_dimensions = []
-    >,
-    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
-  } : (tensor<11x5xf32>, tensor<5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
-  func.return %0 : tensor<11x7xf32>
-}
-
-// -----
-
 func.func @top_k(%arg0 : tensor<f32>) {
   // expected-error @+2 {{failed to infer returned types}}
   // @expected-error @+1{{operand's rank must be at least 1}}
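
For context on how the deleted cases were exercised: this file is a standard MLIR lit test, where `// -----` separators split the file into independent pieces and `expected-error` annotations are matched against verifier diagnostics. A minimal sketch of the RUN line that drives such a file, assuming the usual StableHLO test harness (the exact flags in ops_chlo.mlir may differ):

// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file | FileCheck %s

With -split-input-file, each block between `// -----` markers is verified in isolation, so one failing ragged_dot case cannot mask the others; that is also why the duplicated blocks removed here were harmless but redundant.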

0 commit comments