Skip to content

Commit b82ca5f

Browse files
committed
Use makeComposedFoldedAffineMax
1 parent cde70c4 commit b82ca5f

File tree

3 files changed

+31
-48
lines changed

3 files changed

+31
-48
lines changed

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -75,47 +75,30 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
7575
offsetValues[offsetIdx] = indicesVec[i];
7676
offsetValues[offsetIdx + 1] = strides[i];
7777
}
78-
7978
// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
8079
int64_t scaler = dstBits / srcBits;
80+
OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
81+
builder, loc, addMulMap.floorDiv(scaler), offsetValues);
82+
8183
size_t symbolIndex = 0;
82-
SmallVector<Value> values;
84+
SmallVector<OpFoldResult> values;
8385
SmallVector<AffineExpr> productExpressions;
8486
for (unsigned i = 0; i < sourceRank; ++i) {
85-
AffineExpr strideExpr, sizeExpr;
87+
AffineExpr strideExpr = symbols[symbolIndex++];
8688
OpFoldResult stride = strides[i];
87-
OpFoldResult size = sizes[i];
88-
if (auto constantStride = getConstantIntValue(stride)) {
89-
strideExpr = builder.getAffineConstantExpr(*constantStride);
90-
} else {
91-
strideExpr = symbols[symbolIndex++];
92-
values.push_back(getValueOrCreateConstantIndexOp(builder, loc, stride));
93-
}
89+
values.push_back(getValueOrCreateConstantIndexOp(builder, loc, stride));
9490

95-
if (auto constantSize = getConstantIntValue(size)) {
96-
sizeExpr = builder.getAffineConstantExpr(*constantSize);
97-
} else {
98-
sizeExpr = symbols[symbolIndex++];
99-
values.push_back(getValueOrCreateConstantIndexOp(builder, loc, size));
100-
}
91+
AffineExpr sizeExpr = symbols[symbolIndex++];
92+
OpFoldResult size = sizes[i];
93+
values.push_back(getValueOrCreateConstantIndexOp(builder, loc, size));
10194

10295
productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
10396
}
10497
AffineMap maxMap = AffineMap::get(
10598
/*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
10699
builder.getContext());
107-
108-
OpFoldResult linearizedSize;
109-
Value totalSize =
110-
builder.createOrFold<affine::AffineMaxOp>(loc, maxMap, values);
111-
if (auto constantSize = getConstantIntValue(totalSize)) {
112-
linearizedSize = builder.getIndexAttr(*constantSize);
113-
} else {
114-
linearizedSize = totalSize;
115-
}
116-
117-
OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
118-
builder, loc, addMulMap.floorDiv(scaler), offsetValues);
100+
OpFoldResult linearizedSize =
101+
affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);
119102

120103
// Adjust baseOffset by the scale factor (dstBits / srcBits).
121104
AffineExpr s0;

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,15 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
104104
%1 = memref.load %0[%arg2, %arg3] : memref<?x?xi4>
105105
return %1 : i4
106106
}
107-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
107+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
108108
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
109109
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
110110
// CHECK: func @memref_load_i4_dynamic(
111111
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
112112
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
113113
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
114114
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
115-
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
115+
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
116116
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
117117
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
118118
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -122,15 +122,15 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
122122
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
123123
// CHECK: return %[[TRUNC]]
124124

125-
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
125+
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
126126
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
127127
// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
128128
// CHECK32: func @memref_load_i4_dynamic(
129129
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
130130
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
131131
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
132132
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
133-
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
133+
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
134134
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
135135
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
136136
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -399,7 +399,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
399399
memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
400400
return
401401
}
402-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
402+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
403403
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
404404
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
405405
// CHECK: func @memref_store_i4_dynamic(
@@ -408,7 +408,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
408408
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
409409
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
410410
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
411-
// CHECK-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
411+
// CHECK-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
412412
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
413413
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
414414
// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
@@ -423,7 +423,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
423423
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
424424
// CHECK: return
425425

426-
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
426+
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
427427
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
428428
// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
429429
// CHECK32: func @memref_store_i4_dynamic(
@@ -432,7 +432,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
432432
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
433433
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
434434
// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
435-
// CHECK32-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
435+
// CHECK32-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
436436
// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
437437
// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
438438
// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,27 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
5858
%1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
5959
return %1 : vector<8xi4>
6060
}
61-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
61+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
6262
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
6363
// CHECK: func.func @vector_load_i4_dynamic(
6464
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
6565
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
6666
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
6767
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
68-
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
68+
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
6969
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
7070
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
7171
// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
7272
// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
7373

74-
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
74+
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
7575
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
7676
// CHECK32: func.func @vector_load_i4_dynamic(
7777
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
7878
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
7979
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
8080
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
81-
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
81+
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
8282
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
8383
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
8484
// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
@@ -450,29 +450,29 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
450450
return
451451
}
452452

453-
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
453+
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
454454
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
455455
// CHECK: func @vector_store_i4_dynamic
456456
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
457457
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
458458
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
459459
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
460460
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
461-
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
461+
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
462462
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
463463
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
464464
// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
465465
// CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi8>, vector<4xi8>
466466

467-
// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
467+
// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
468468
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
469469
// CHECK32: func @vector_store_i4_dynamic
470470
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
471471
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
472472
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
473473
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
474474
// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
475-
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
475+
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
476476
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
477477
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
478478
// CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
@@ -537,7 +537,7 @@ func.func @vector_maskedstore_i4(
537537
// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
538538
// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
539539

540-
// CHECK-LABEL: func.func @vector_maskedstore_i4(
540+
// CHECK: func.func @vector_maskedstore_i4(
541541
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
542542
// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
543543
// CHECK-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
@@ -557,7 +557,7 @@ func.func @vector_maskedstore_i4(
557557
// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
558558
// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
559559

560-
// CHECK32-LABEL: func.func @vector_maskedstore_i4(
560+
// CHECK32: func.func @vector_maskedstore_i4(
561561
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
562562
// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
563563
// CHECK32-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
@@ -623,7 +623,7 @@ func.func @vector_maskedstore_i4_constant_mask(
623623
}
624624

625625
// CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
626-
// CHECK-LABEL: func.func @vector_maskedstore_i4_constant_mask(
626+
// CHECK: func.func @vector_maskedstore_i4_constant_mask(
627627
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
628628
// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
629629
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
@@ -639,7 +639,7 @@ func.func @vector_maskedstore_i4_constant_mask(
639639
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
640640

641641
// CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
642-
// CHECK32-LABEL: func.func @vector_maskedstore_i4_constant_mask(
642+
// CHECK32: func.func @vector_maskedstore_i4_constant_mask(
643643
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
644644
// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
645645
// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {

0 commit comments

Comments
 (0)