Skip to content

Commit a274484

Browse files
committed
replace memremf.tensor_store with bufferization.materialize_in_destination as per llvm/llvm-project#71010
1 parent f90b031 commit a274484

File tree

55 files changed

+74
-74
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+74
-74
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ Important details to note:
102102
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [...] memref<*xf32> to memref<1024xf32>
103103
%extracted_slice = tensor.extract_slice %15[0] [%21] [1] : tensor<1024xf32> to tensor<?xf32>
104104
%subview = memref.subview %reinterpret_cast[0] [%21] [1] : memref<1024xf32> to memref<?xf32>
105-
memref.tensor_store %extracted_slice, %subview : memref<?xf32>
105+
bufferization.materialize_in_destination %extracted_slice in %subview
106106
```
107107

108108
+ element-wise `arith` and `math` operators are converted to their corresponding `linalg.generic` version.

lib/Conversion/TritonToLinalg/TritonToLinalg.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ struct StoreConverter : public OpConversionPattern<triton::StoreOp> {
730730

731731
// 1. Simple case where no mask is used.
732732
if (!mask) {
733-
rewriter.create<memref::TensorStoreOp>(loc, val, ptr);
733+
rewriter.create<bufferization::MaterializeInDestinationOp>(loc, val, ptr);
734734
rewriter.eraseOp(op);
735735
return success();
736736
}
@@ -747,7 +747,7 @@ struct StoreConverter : public OpConversionPattern<triton::StoreOp> {
747747
auto srcSlice = mstate.getExtractSlice(val, loc, rewriter);
748748
auto dstSubview = mstate.getSubview(ptr, loc, rewriter);
749749

750-
rewriter.create<memref::TensorStoreOp>(loc, srcSlice, dstSubview);
750+
rewriter.create<bufferization::MaterializeInDestinationOp>(loc, srcSlice, dstSubview);
751751
rewriter.eraseOp(op);
752752

753753
return success();

test/Conversion/TritonToLinalg/addptr_2d_example.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,6 @@ module {
6464
// CHECK: } -> tensor<4x256xbf16>
6565
// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
6666
// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_21]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
67-
// CHECK: memref.tensor_store %[[VAL_23:.*]], %[[VAL_22]] : memref<4x256xbf16, strided<[1, ?], offset: ?>>
67+
// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in %[[VAL_22]]
6868
// CHECK: return
6969
// CHECK: }

test/Conversion/TritonToLinalg/addptr_add_value.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,6 @@ module {
6363
// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<4x256xbf16>
6464
// CHECK: memref.copy %[[VAL_13]], %[[VAL_19]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16>
6565
// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_19]] restrict writable : memref<4x256xbf16>
66-
// CHECK: memref.tensor_store %[[VAL_20]], %[[VAL_18]] : memref<4x256xbf16, strided<[1, ?], offset: ?>>
66+
// CHECK: bufferization.materialize_in_destination %[[VAL_20]] in %[[VAL_18]]
6767
// CHECK: return
6868
// CHECK: }

test/Conversion/TritonToLinalg/addptr_dim1.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ module {
8181
// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<1x256xbf16>
8282
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_1_]] : i32 to index
8383
// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [1, 256], strides: [256, 1] : memref<*xbf16> to memref<1x256xbf16, strided<[256, 1], offset: ?>>
84-
// CHECK: memref.tensor_store [[VAR_2_]], [[VAR_reinterpret_cast_0_]] : memref<1x256xbf16, strided<[256, 1], offset: ?>>
84+
// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in [[VAR_reinterpret_cast_0_]]
8585
// CHECK-DAG: [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_1_]], [[VAR_arg7_:%.+]] = [[CST_0_]], [[VAR_arg8_:%.+]] = [[CST_0_]]) -> (tensor<4x256xbf16>, index, index) {
8686
// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_arg5_]] : index to i32
8787
// CHECK: [[VAR_6_:%.+]] = arith.muli [[VAR_5_]], [[CST_256_1_]] : i32
@@ -102,6 +102,6 @@ module {
102102
// CHECK: scf.yield [[VAR_10_]], [[VAR_12_]], [[CST_0_]] : tensor<4x256xbf16>, index, index
103103
// CHECK: }
104104
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<4x256xbf16, strided<[?, 1]>>
105-
// CHECK: memref.tensor_store [[VAR_4_]]#0, [[VAR_reinterpret_cast_1_]] : memref<4x256xbf16, strided<[?, 1]>>
105+
// CHECK: bufferization.materialize_in_destination [[VAR_4_]]#0 in [[VAR_reinterpret_cast_1_]]
106106
// CHECK: return
107107
// CHECK: }

test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,6 @@ module {
8888
// CHECK: }
8989
// CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
9090
// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_36]]], sizes: [4, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
91-
// CHECK: memref.tensor_store %[[VAL_38:.*]]#0, %[[VAL_37]] : memref<4x256xbf16, strided<[1, ?], offset: ?>>
91+
// CHECK: bufferization.materialize_in_destination %[[VAL_38:.*]]#0 in %[[VAL_37]]
9292
// CHECK: return
9393
// CHECK: }

test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ module {
6565
// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<256x256xbf16>
6666
// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<256x256xbf16, strided<[?, 1], offset: ?>> to memref<256x256xbf16>
6767
// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<256x256xbf16>
68-
// CHECK: memref.tensor_store %[[VAL_16]], %[[VAL_14]] : memref<256x256xbf16, strided<[?, 1], offset: ?>>
68+
// CHECK: bufferization.materialize_in_destination %[[VAL_16]] in %[[VAL_14]]
6969
// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_12]], %[[VAL_9]] : index
7070
// CHECK: scf.yield %[[VAL_17]] : index
7171
// CHECK: }

test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ module {
5555
// CHECK: %[[VAL_23:.*]] = memref.alloc() : memref<256xbf16>
5656
// CHECK: memref.copy %[[VAL_17]], %[[VAL_23]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16>
5757
// CHECK: %[[VAL_24:.*]] = bufferization.to_tensor %[[VAL_23]] restrict writable : memref<256xbf16>
58-
// CHECK: memref.tensor_store %[[VAL_24]], %[[VAL_19]] : memref<256xbf16, strided<[?], offset: ?>>
58+
// CHECK: bufferization.materialize_in_destination %[[VAL_24]] in %[[VAL_19]]
5959
// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_21]], %[[VAL_10]] : index
6060
// CHECK: %[[VAL_26:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_25]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
6161
// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_10]] : index

test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ module {
9191
// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<256xbf16>
9292
// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16>
9393
// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<256xbf16>
94-
// CHECK: memref.tensor_store %[[VAL_15]], %[[VAL_13]] : memref<256xbf16, strided<[?], offset: ?>>
94+
// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in %[[VAL_13]]
9595
// CHECK: scf.yield %[[VAL_12]] : index
9696
// CHECK: }
9797
// CHECK: return

test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ module {
4646
// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<256xbf16>
4747
// CHECK: memref.copy %[[VAL_12]], %[[VAL_14]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16>
4848
// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<256xbf16>
49-
// CHECK: memref.tensor_store %[[VAL_15]], %[[VAL_12]] : memref<256xbf16, strided<[?], offset: ?>>
49+
// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in %[[VAL_12]]
5050
// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index
5151
// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_16]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
5252
// CHECK: scf.yield %[[VAL_17]], %[[VAL_16]] : memref<256xbf16, strided<[?], offset: ?>>, index

test/Conversion/TritonToLinalg/addptr_loopback.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,6 @@ module {
4848
// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<4x256xbf16>
4949
// CHECK: memref.copy %[[VAL_8]], %[[VAL_11]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16>
5050
// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<4x256xbf16>
51-
// CHECK: memref.tensor_store %[[VAL_12]], %[[VAL_10]] : memref<4x256xbf16, strided<[1, ?], offset: ?>>
51+
// CHECK: bufferization.materialize_in_destination %[[VAL_12]] in %[[VAL_10]]
5252
// CHECK: return
5353
// CHECK: }

test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,6 @@ module {
4444
// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<1024xbf16>
4545
// CHECK: memref.copy %[[VAL_10]], %[[VAL_13]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16>
4646
// CHECK: %[[VAL_14:.*]] = bufferization.to_tensor %[[VAL_13]] restrict writable : memref<1024xbf16>
47-
// CHECK: memref.tensor_store %[[VAL_14]], %[[VAL_12]] : memref<1024xbf16, strided<[1], offset: ?>>
47+
// CHECK: bufferization.materialize_in_destination %[[VAL_14]] in %[[VAL_12]]
4848
// CHECK: return
4949
// CHECK: }

test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ module {
4646
// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<1024xbf16>
4747
// CHECK: memref.copy %[[VAL_15]], %[[VAL_18]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16>
4848
// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<1024xbf16>
49-
// CHECK: memref.tensor_store %[[VAL_19]], %[[VAL_17]] : memref<1024xbf16, strided<[1], offset: ?>>
49+
// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in %[[VAL_17]]
5050
// CHECK: return
5151
// CHECK: }

test/Conversion/TritonToLinalg/addptr_nested.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,6 @@ module {
6868
// CHECK: %[[VAL_26:.*]] = arith.index_cast %[[VAL_1]] : i32 to index
6969
// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index
7070
// CHECK: %[[VAL_28:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_27]]], sizes: [4, 256], strides: [3, %[[VAL_5]]] : memref<*xbf16> to memref<4x256xbf16, strided<[3, ?], offset: ?>>
71-
// CHECK: memref.tensor_store %[[VAL_29:.*]], %[[VAL_28]] : memref<4x256xbf16, strided<[3, ?], offset: ?>>
71+
// CHECK: bufferization.materialize_in_destination %[[VAL_29:.*]] in %[[VAL_28]]
7272
// CHECK: return
7373
// CHECK: }

test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,6 @@ module {
3838
// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<256x128xbf16>
3939
// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16>
4040
// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<256x128xbf16>
41-
// CHECK: memref.tensor_store %[[VAL_9]], %[[VAL_7]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>>
41+
// CHECK: bufferization.materialize_in_destination %[[VAL_9]] in %[[VAL_7]]
4242
// CHECK: return
4343
// CHECK: }

test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,6 @@ module {
6060
// CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32
6161
// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index
6262
// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024, 1024], strides: [1, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[1, 1], offset: ?>>
63-
// CHECK: memref.tensor_store %[[VAL_20:.*]], %[[VAL_19]] : memref<1024x1024xf32, strided<[1, 1], offset: ?>>
63+
// CHECK: bufferization.materialize_in_destination %[[VAL_20:.*]] in %[[VAL_19]]
6464
// CHECK: return
6565
// CHECK: }

test/Conversion/TritonToLinalg/addptr_scalar_for.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,6 @@ module {
6565
// CHECK: %[[VAL_35:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32
6666
// CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_35]] : i32 to index
6767
// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_36]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
68-
// CHECK: memref.tensor_store %[[VAL_38:.*]]#0, %[[VAL_37]] : memref<1024xf32, strided<[1], offset: ?>>
68+
// CHECK: bufferization.materialize_in_destination %[[VAL_38:.*]]#0 in %[[VAL_37]]
6969
// CHECK: return
7070
// CHECK: }

test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,6 @@ module {
8787
// CHECK: %[[VAL_38:.*]] = arith.index_cast %[[VAL_37]] : i32 to index
8888
// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_38]], %[[VAL_8]] : index
8989
// CHECK: %[[VAL_40:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_39]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>>
90-
// CHECK: memref.tensor_store %[[VAL_41:.*]]#0, %[[VAL_40]] : memref<128x128xf32, strided<[1, 1], offset: ?>>
90+
// CHECK: bufferization.materialize_in_destination %[[VAL_41:.*]]#0 in %[[VAL_40]]
9191
// CHECK: return
9292
// CHECK: }

test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,6 @@ module {
5252
// CHECK: %[[VAL_23:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32
5353
// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index
5454
// CHECK: %[[VAL_25:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_24]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
55-
// CHECK: memref.tensor_store %[[VAL_26:.*]], %[[VAL_25]] : memref<1024xf32, strided<[1], offset: ?>>
55+
// CHECK: bufferization.materialize_in_destination %[[VAL_26:.*]] in %[[VAL_25]]
5656
// CHECK: return
5757
// CHECK: }

test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,6 @@ module {
4040
// CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32
4141
// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index
4242
// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
43-
// CHECK: memref.tensor_store %[[VAL_20:.*]], %[[VAL_19]] : memref<1024xf32, strided<[1], offset: ?>>
43+
// CHECK: bufferization.materialize_in_destination %[[VAL_20:.*]] in %[[VAL_19]]
4444
// CHECK: return
4545
// CHECK: }

test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,6 @@ module {
5151
// CHECK: %[[VAL_20:.*]] = arith.index_cast %[[VAL_19]] : i32 to index
5252
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index
5353
// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>>
54-
// CHECK: memref.tensor_store %[[VAL_23:.*]], %[[VAL_22]] : memref<128x128xf32, strided<[1, 1], offset: ?>>
54+
// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in %[[VAL_22]]
5555
// CHECK: return
5656
// CHECK: }

test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ module {
3434
// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : i32
3535
// CHECK: linalg.yield %[[VAL_15]] : i32
3636
// CHECK: } -> tensor<1024xi32>
37-
// CHECK: memref.tensor_store %[[VAL_16:.*]], %[[VAL_6]] : memref<1024xi32, strided<[1]>>
37+
// CHECK: bufferization.materialize_in_destination %[[VAL_16:.*]] in %[[VAL_6]]
3838
// CHECK: return
3939
// CHECK: }

test/Conversion/TritonToLinalg/bitcast.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ module {
3737
// CHECK: [[VAR_5_:%.+]] = arith.bitcast %in : i32 to f32
3838
// CHECK: linalg.yield [[VAR_5_]] : f32
3939
// CHECK: } -> tensor<1024xf32>
40-
// CHECK: memref.tensor_store [[VAR_2_]], [[RC_0_]] : memref<1024xf32, strided<[1]>>
40+
// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in [[RC_0_]]
4141
// CHECK: return
4242
// CHECK: }
4343
// CHECK: }

test/Conversion/TritonToLinalg/block_ptr_advance.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ module {
8787
// CHECK: %15 = arith.muli %11, %13 : index
8888
// CHECK: %16 = arith.addi %14, %15 : index
8989
// CHECK: %reinterpret_cast_1 = memref.reinterpret_cast %arg2 to offset: [%16], sizes: [128, 64], strides: [%12, %13] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
90-
// CHECK: memref.tensor_store %8#0, %reinterpret_cast_1 : memref<128x64xbf16, strided<[?, ?], offset: ?>>
90+
// CHECK: bufferization.materialize_in_destination %8#0 in %reinterpret_cast_1
9191
// CHECK: return
9292
// CHECK: }
9393
// CHECK: }

test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,6 @@ module {
6767
// CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_44]], %[[VAL_45]], %[[VAL_46]] : f32
6868
// CHECK: linalg.yield %[[VAL_48]] : f32
6969
// CHECK: } -> tensor<1024xf32>
70-
// CHECK: memref.tensor_store %[[VAL_49:.*]], %[[VAL_2]] : memref<1024xf32>
70+
// CHECK: bufferization.materialize_in_destination %[[VAL_49:.*]] in %[[VAL_2]]
7171
// CHECK: return
7272
// CHECK: }

test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,6 @@ module {
4444
// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : f32
4545
// CHECK: linalg.yield %[[VAL_21]] : f32
4646
// CHECK: } -> tensor<1024xf32>
47-
// CHECK: memref.tensor_store %[[VAL_22:.*]], %[[VAL_3]] : memref<1024xf32>
47+
// CHECK: bufferization.materialize_in_destination %[[VAL_22:.*]] in %[[VAL_3]]
4848
// CHECK: return
4949
// CHECK: }

test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ module {
7979
// CHECK: %[[VAL_42:.*]] = math.sqrt %[[VAL_40]] : f32
8080
// CHECK: linalg.yield %[[VAL_42]] : f32
8181
// CHECK: } -> tensor<1024xf32>
82-
// CHECK: memref.tensor_store %[[VAL_43:.*]], %[[VAL_3]] : memref<1024xbf16>
83-
// CHECK: memref.tensor_store %[[VAL_44:.*]], %[[VAL_4]] : memref<1024xf32>
84-
// CHECK: memref.tensor_store %[[VAL_45:.*]], %[[VAL_5]] : memref<1024xf32>
85-
// CHECK: memref.tensor_store %[[VAL_46:.*]], %[[VAL_6]] : memref<1024xf32>
86-
// CHECK: memref.tensor_store %[[VAL_47:.*]], %[[VAL_7]] : memref<1024xf32>
82+
// CHECK: bufferization.materialize_in_destination %[[VAL_43:.*]] in %[[VAL_3]]
83+
// CHECK: bufferization.materialize_in_destination %[[VAL_44:.*]] in %[[VAL_4]]
84+
// CHECK: bufferization.materialize_in_destination %[[VAL_45:.*]] in %[[VAL_5]]
85+
// CHECK: bufferization.materialize_in_destination %[[VAL_46:.*]] in %[[VAL_6]]
86+
// CHECK: bufferization.materialize_in_destination %[[VAL_47:.*]] in %[[VAL_7]]
8787
// CHECK: return
8888
// CHECK: }

test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ module {
4949
// CHECK: %[[VAL_22:.*]] = arith.subf %[[VAL_19]], %[[VAL_20]] : f32
5050
// CHECK: linalg.yield %[[VAL_22]] : f32
5151
// CHECK: } -> tensor<128x128xf32>
52-
// CHECK: memref.tensor_store %[[VAL_23:.*]], %[[VAL_2]] : memref<128x128xf32>
53-
// CHECK: memref.tensor_store %[[VAL_24:.*]], %[[VAL_3]] : memref<128x128xf32>
52+
// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in %[[VAL_2]]
53+
// CHECK: bufferization.materialize_in_destination %[[VAL_24:.*]] in %[[VAL_3]]
5454
// CHECK: return
5555
// CHECK: }

test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ module {
5050
// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : f32
5151
// CHECK: linalg.yield %[[VAL_21]] : f32
5252
// CHECK: } -> tensor<128x128xf32>
53-
// CHECK: memref.tensor_store %[[VAL_22:.*]], %[[VAL_3]] : memref<128x128xf32>
53+
// CHECK: bufferization.materialize_in_destination %[[VAL_22:.*]] in %[[VAL_3]]
5454
// CHECK: return
5555
// CHECK: }

0 commit comments

Comments
 (0)