Skip to content

Commit e47b507

Browse files
Simon Camphausenmgehre-amd
Simon Camphausen
andauthored
[mlir][EmitC] Model lvalues as a type in EmitC (#91475)
This adds an `emitc.lvalue` type which models assignable lvlaues in the type system. Operations modifying memory are restricted to this type accordingly. See also the discussion on [discourse](https://discourse.llvm.org/t/rfc-separate-variables-from-ssa-values-in-emitc/75224/9). The most notable changes are as follows. - `emitc.variable` and `emitc.global` ops are restricted to return `emitc.array` or `emitc.lvalue` types - Taking the address of a value is restricted to operands with lvalue type - Conversion from lvalues into SSA values is done with the new `emitc.load` op - The var operand of the `emitc.assign` op is restricted to lvalue type - The result of the `emitc.subscript` and `emitc.get_global` ops is a lvalue type - The operands and results of the `emitc.member` and `emitc.member_of_ptr` ops are restricted to lvalue types --------- Co-authored-by: Matthias Gehre <[email protected]>
1 parent 3c53745 commit e47b507

27 files changed

+810
-382
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 74 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,17 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
8888

8989
```mlir
9090
// Custom form of applying the & operator.
91-
%0 = emitc.apply "&"(%arg0) : (i32) -> !emitc.ptr<i32>
91+
%0 = emitc.apply "&"(%arg0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
9292

9393
// Generic form of the same operation.
9494
%0 = "emitc.apply"(%arg0) {applicableOperator = "&"}
95-
: (i32) -> !emitc.ptr<i32>
95+
: (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
9696

9797
```
9898
}];
9999
let arguments = (ins
100100
Arg<StrAttr, "the operator to apply">:$applicableOperator,
101-
EmitCType:$operand
101+
AnyTypeOf<[EmitCType, EmitC_LValueType]>:$operand
102102
);
103103
let results = (outs EmitCType:$result);
104104
let assemblyFormat = [{
@@ -836,6 +836,35 @@ def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> {
836836
let assemblyFormat = "operands attr-dict `:` type(operands)";
837837
}
838838

839+
def EmitC_LoadOp : EmitC_Op<"load", [
840+
TypesMatchWith<"result type matches value type of 'operand'",
841+
"operand", "result",
842+
"::llvm::cast<LValueType>($_self).getValueType()">
843+
]> {
844+
let summary = "Load an lvalue into an SSA value.";
845+
let description = [{
846+
This operation loads the content of a modifiable lvalue into an SSA value.
847+
Modifications of the lvalue executed after the load are not observable on
848+
the produced value.
849+
850+
Example:
851+
852+
```mlir
853+
%1 = emitc.load %0 : !emitc.lvalue<i32>
854+
```
855+
```c++
856+
// Code emitted for the operation above.
857+
int32_t v2 = v1;
858+
```
859+
}];
860+
861+
let arguments = (ins
862+
Res<EmitC_LValueType, "", [MemRead<DefaultResource, 0, FullEffect>]>:$operand);
863+
let results = (outs AnyType:$result);
864+
865+
let assemblyFormat = "$operand attr-dict `:` type($operand)";
866+
}
867+
839868
def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> {
840869
let summary = "Multiplication operation";
841870
let description = [{
@@ -918,15 +947,15 @@ def EmitC_MemberOp : EmitC_Op<"member"> {
918947

919948
```mlir
920949
%0 = "emitc.member" (%arg0) {member = "a"}
921-
: (!emitc.opaque<"mystruct">) -> i32
950+
: (!emitc.lvalue<!emitc.opaque<"mystruct">>) -> !emitc.lvalue<i32>
922951
```
923952
}];
924953

925954
let arguments = (ins
926955
Arg<StrAttr, "the member to access">:$member,
927-
EmitC_OpaqueType:$operand
956+
EmitC_LValueOf<[EmitC_OpaqueType]>:$operand
928957
);
929-
let results = (outs EmitCType);
958+
let results = (outs EmitC_LValueOf<[EmitCType]>);
930959
}
931960

932961
def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> {
@@ -939,15 +968,16 @@ def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> {
939968

940969
```mlir
941970
%0 = "emitc.member_of_ptr" (%arg0) {member = "a"}
942-
: (!emitc.ptr<!emitc.opaque<"mystruct">>) -> i32
971+
: (!emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>)
972+
-> !emitc.lvalue<i32>
943973
```
944974
}];
945975

946976
let arguments = (ins
947977
Arg<StrAttr, "the member to access">:$member,
948-
AnyTypeOf<[EmitC_OpaqueType,EmitC_PointerType]>:$operand
978+
EmitC_LValueOf<[EmitC_OpaqueType,EmitC_PointerType]>:$operand
949979
);
950-
let results = (outs EmitCType);
980+
let results = (outs EmitC_LValueOf<[EmitCType]>);
951981
}
952982

953983
def EmitC_ConditionalOp : EmitC_Op<"conditional",
@@ -1031,28 +1061,29 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
10311061

10321062
```mlir
10331063
// Integer variable
1034-
%0 = "emitc.variable"(){value = 42 : i32} : () -> i32
1064+
%0 = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32>
10351065

10361066
// Variable emitted as `int32_t* = NULL;`
10371067
%1 = "emitc.variable"() {value = #emitc.opaque<"NULL">}
1038-
: () -> !emitc.ptr<!emitc.opaque<"int32_t">>
1068+
: () -> !emitc.lvalue<!emitc.ptr<!emitc.opaque<"int32_t">>>
10391069
```
10401070

10411071
Since folding is not supported, it can be used with pointers.
10421072
As an example, it is valid to create pointers to `variable` operations
10431073
by using `apply` operations and pass these to a `call` operation.
10441074
```mlir
1045-
%0 = "emitc.variable"() {value = 0 : i32} : () -> i32
1046-
%1 = "emitc.variable"() {value = 0 : i32} : () -> i32
1047-
%2 = emitc.apply "&"(%0) : (i32) -> !emitc.ptr<i32>
1048-
%3 = emitc.apply "&"(%1) : (i32) -> !emitc.ptr<i32>
1075+
%0 = "emitc.variable"() {value = 0 : i32} : () -> !emitc.lvalue<i32>
1076+
%1 = "emitc.variable"() {value = 0 : i32} : () -> !emitc.lvalue<i32>
1077+
%2 = emitc.apply "&"(%0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
1078+
%3 = emitc.apply "&"(%1) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
10491079
emitc.call_opaque "write"(%2, %3)
10501080
: (!emitc.ptr<i32>, !emitc.ptr<i32>) -> ()
10511081
```
10521082
}];
10531083

10541084
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
1055-
let results = (outs EmitCType);
1085+
let results = (outs Res<AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>, "",
1086+
[MemAlloc<DefaultResource, 0, FullEffect>]>);
10561087

10571088
let hasVerifier = 1;
10581089
}
@@ -1118,11 +1149,12 @@ def EmitC_GetGlobalOp : EmitC_Op<"get_global",
11181149

11191150
```mlir
11201151
%x = emitc.get_global @foo : !emitc.array<2xf32>
1152+
%y = emitc.get_global @bar : !emitc.lvalue<i32>
11211153
```
11221154
}];
11231155

11241156
let arguments = (ins FlatSymbolRefAttr:$name);
1125-
let results = (outs EmitCType:$result);
1157+
let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
11261158
let assemblyFormat = "$name `:` type($result) attr-dict";
11271159
}
11281160

@@ -1172,15 +1204,17 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
11721204

11731205
```mlir
11741206
// Integer variable
1175-
%0 = "emitc.variable"(){value = 42 : i32} : () -> i32
1207+
%0 = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32>
11761208
%1 = emitc.call_opaque "foo"() : () -> (i32)
11771209

11781210
// Assign emitted as `... = ...;`
1179-
"emitc.assign"(%0, %1) : (i32, i32) -> ()
1211+
"emitc.assign"(%0, %1) : (!emitc.lvalue<i32>, i32) -> ()
11801212
```
11811213
}];
11821214

1183-
let arguments = (ins EmitCType:$var, EmitCType:$value);
1215+
let arguments = (ins
1216+
Res<EmitC_LValueType, "", [MemWrite<DefaultResource, 1, FullEffect>]>:$var,
1217+
EmitCType:$value);
11841218
let results = (outs);
11851219

11861220
let hasVerifier = 1;
@@ -1276,8 +1310,10 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
12761310
```mlir
12771311
%i = index.constant 1
12781312
%j = index.constant 7
1279-
%0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
1280-
%1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
1313+
%0 = emitc.subscript %arg0[%i, %j] : (!emitc.array<4x8xf32>, index, index)
1314+
-> !emitc.lvalue<f32>
1315+
%1 = emitc.subscript %arg1[%i] : (!emitc.ptr<i32>, index)
1316+
-> !emitc.lvalue<i32>
12811317
```
12821318
}];
12831319
let arguments = (ins Arg<AnyTypeOf<[
@@ -1286,15 +1322,26 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
12861322
EmitC_PointerType]>,
12871323
"the value to subscript">:$value,
12881324
Variadic<EmitCType>:$indices);
1289-
let results = (outs EmitCType:$result);
1325+
let results = (outs EmitC_LValueType:$result);
12901326

12911327
let builders = [
12921328
OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
1293-
build($_builder, $_state, array.getType().getElementType(), array, indices);
1329+
build(
1330+
$_builder,
1331+
$_state,
1332+
emitc::LValueType::get(array.getType().getElementType()),
1333+
array,
1334+
indices
1335+
);
12941336
}]>,
12951337
OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
1296-
build($_builder, $_state, pointer.getType().getPointee(), pointer,
1297-
ValueRange{index});
1338+
build(
1339+
$_builder,
1340+
$_state,
1341+
emitc::LValueType::get(pointer.getType().getPointee()),
1342+
pointer,
1343+
ValueRange{index}
1344+
);
12981345
}]>
12991346
];
13001347

@@ -1338,7 +1385,7 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
13381385
emitc.yield
13391386
}
13401387
default {
1341-
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
1388+
%3 = "emitc.constant"(){value = 42.0 : f32} : () -> f32
13421389
emitc.call_opaque "func2" (%3) : (f32) -> ()
13431390
}
13441391
```

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,23 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
8484
let hasCustomAssemblyFormat = 1;
8585
}
8686

87+
def EmitC_LValueType : EmitC_Type<"LValue", "lvalue"> {
88+
let summary = "EmitC lvalue type";
89+
90+
let description = [{
91+
Values of this type can be assigned to and their address can be taken.
92+
}];
93+
94+
let parameters = (ins "Type":$valueType);
95+
let builders = [
96+
TypeBuilderWithInferredContext<(ins "Type":$valueType), [{
97+
return $_get(valueType.getContext(), valueType);
98+
}]>
99+
];
100+
let assemblyFormat = "`<` qualified($valueType) `>`";
101+
let genVerifyDecl = 1;
102+
}
103+
87104
def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
88105
let summary = "EmitC opaque type";
89106

@@ -129,6 +146,7 @@ def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
129146
}]>
130147
];
131148
let assemblyFormat = "`<` qualified($pointee) `>`";
149+
let genVerifyDecl = 1;
132150
}
133151

134152
def EmitC_SignedSizeT : EmitC_Type<"SignedSizeT", "ssize_t"> {
@@ -158,4 +176,13 @@ def EmitC_SizeT : EmitC_Type<"SizeT", "size_t"> {
158176
}];
159177
}
160178

179+
class EmitC_LValueOf<list<Type> allowedTypes> :
180+
ContainerType<
181+
AnyTypeOf<allowedTypes>,
182+
CPred<"::llvm::isa<::mlir::emitc::LValueType>($_self)">,
183+
"::llvm::cast<::mlir::emitc::LValueType>($_self).getValueType()",
184+
"emitc.lvalue",
185+
"::mlir::emitc::LValueType"
186+
>;
187+
161188
#endif // MLIR_DIALECT_EMITC_IR_EMITCTYPES

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,7 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
137137
auto subscript = rewriter.create<emitc::SubscriptOp>(
138138
op.getLoc(), arrayValue, operands.getIndices());
139139

140-
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
141-
auto var =
142-
rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);
143-
144-
rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
145-
rewriter.replaceOp(op, var);
140+
rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
146141
return success();
147142
}
148143
};

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ static SmallVector<Value> createVariablesForResults(T op,
6363

6464
for (OpResult result : op.getResults()) {
6565
Type resultType = result.getType();
66+
Type varType = emitc::LValueType::get(resultType);
6667
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
6768
emitc::VariableOp var =
68-
rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
69+
rewriter.create<emitc::VariableOp>(loc, varType, noInit);
6970
resultVariables.push_back(var);
7071
}
7172

@@ -80,6 +81,14 @@ static void assignValues(ValueRange values, SmallVector<Value> &variables,
8081
rewriter.create<emitc::AssignOp>(loc, var, value);
8182
}
8283

84+
SmallVector<Value> loadValues(const SmallVector<Value> &variables,
85+
PatternRewriter &rewriter, Location loc) {
86+
return llvm::map_to_vector<>(variables, [&](Value var) {
87+
Type type = cast<emitc::LValueType>(var.getType()).getValueType();
88+
return rewriter.create<emitc::LoadOp>(loc, type, var).getResult();
89+
});
90+
}
91+
8392
static void lowerYield(SmallVector<Value> &resultVariables,
8493
PatternRewriter &rewriter, scf::YieldOp yield) {
8594
Location loc = yield.getLoc();
@@ -126,15 +135,26 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
126135
// Erase the auto-generated terminator for the lowered for op.
127136
rewriter.eraseOp(loweredBody->getTerminator());
128137

138+
IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
139+
rewriter.setInsertionPointToEnd(loweredBody);
140+
141+
SmallVector<Value> iterArgsValues =
142+
loadValues(resultVariables, rewriter, loc);
143+
144+
rewriter.restoreInsertionPoint(ip);
145+
129146
SmallVector<Value> replacingValues;
130147
replacingValues.push_back(loweredFor.getInductionVar());
131-
replacingValues.append(resultVariables.begin(), resultVariables.end());
148+
replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
132149

133150
rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
134151
lowerYield(resultVariables, rewriter,
135152
cast<scf::YieldOp>(loweredBody->getTerminator()));
136153

137-
rewriter.replaceOp(forOp, resultVariables);
154+
// Load variables into SSA values after the for loop.
155+
SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
156+
157+
rewriter.replaceOp(forOp, resultValues);
138158
return success();
139159
}
140160

@@ -174,7 +194,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
174194
lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
175195
}
176196

177-
rewriter.replaceOp(ifOp, resultVariables);
197+
rewriter.setInsertionPointAfter(ifOp);
198+
SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
199+
200+
rewriter.replaceOp(ifOp, results);
178201
return success();
179202
}
180203

@@ -212,7 +235,10 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
212235
lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
213236
loweredSwitch.getDefaultRegion());
214237

215-
rewriter.replaceOp(indexSwitchOp, resultVariables);
238+
rewriter.setInsertionPointAfter(indexSwitchOp);
239+
SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
240+
241+
rewriter.replaceOp(indexSwitchOp, results);
216242
return success();
217243
}
218244

0 commit comments

Comments
 (0)