Skip to content

Commit ed224ba

Browse files
zero9178matthias-springer
authored andcommitted
use universal references for map
1 parent 5b49dce commit ed224ba

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,15 @@ struct ConversionValueMapping {
169169
ValueVector lookupOrNull(const ValueVector &from,
170170
TypeRange desiredTypes = {}) const;
171171

172+
template <typename T>
173+
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
174+
172175
/// Map a value to the one provided.
173-
void map(const ValueVector &oldVal, const ValueVector &newVal) {
176+
template <typename OldVal, typename NewVal>
177+
std::enable_if_t<IsValueVector<OldVal>{} && IsValueVector<NewVal>{}>
178+
map(OldVal &&oldVal, NewVal &&newVal) {
174179
LLVM_DEBUG({
175-
ValueVector next = newVal;
180+
ValueVector next(newVal);
176181
while (true) {
177182
assert(next != oldVal && "inserting cyclic mapping");
178183
auto it = mapping.find(next);
@@ -181,9 +186,22 @@ struct ConversionValueMapping {
181186
next = it->second;
182187
}
183188
});
184-
mapping[oldVal] = newVal;
185189
for (Value v : newVal)
186190
mappedTo.insert(v);
191+
192+
mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
193+
}
194+
195+
template <typename OldVal, typename NewVal>
196+
std::enable_if_t<!IsValueVector<OldVal>{} || !IsValueVector<NewVal>{}>
197+
map(OldVal &&oldVal, NewVal &&newVal) {
198+
if constexpr (IsValueVector<OldVal>{}) {
199+
map(std::forward<OldVal>(oldVal), ValueVector{newVal});
200+
} else if constexpr (IsValueVector<NewVal>{}) {
201+
map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
202+
} else {
203+
map(ValueVector{oldVal}, ValueVector{newVal});
204+
}
187205
}
188206

189207
/// Drop the last mapping for the given values.
@@ -1405,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14051423
assert(inputMap->size == 0 &&
14061424
"invalid to provide a replacement value when the argument isn't "
14071425
"dropped");
1408-
mapping.map({origArg}, {repl});
1426+
mapping.map(origArg, repl);
14091427
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
14101428
continue;
14111429
}
@@ -1417,7 +1435,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14171435
auto replArgs =
14181436
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
14191437
ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
1420-
mapping.map({origArg}, replArgVals);
1438+
mapping.map(origArg, std::move(replArgVals));
14211439
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
14221440
}
14231441

@@ -1447,7 +1465,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14471465
// Avoid materializing an unnecessary cast.
14481466
if (TypeRange(inputs) == outputTypes) {
14491467
if (!valuesToMap.empty())
1450-
mapping.map(valuesToMap, inputs);
1468+
mapping.map(std::move(valuesToMap), inputs);
14511469
return inputs;
14521470
}
14531471

@@ -1501,7 +1519,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
15011519
/*outputType=*/value.getType(),
15021520
/*originalType=*/Type(), converter)
15031521
.front();
1504-
mapping.map({value}, {castValue});
1522+
mapping.map(value, castValue);
15051523
return castValue;
15061524
}
15071525

@@ -1571,7 +1589,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
15711589
// Remap result to replacement value.
15721590
if (repl.empty())
15731591
continue;
1574-
mapping.map({result}, repl);
1592+
mapping.map(result, repl);
15751593
}
15761594

15771595
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1724,7 +1742,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
17241742
});
17251743
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
17261744
impl->currentTypeConverter);
1727-
impl->mapping.map(impl->mapping.lookupOrDefault({from}), {to});
1745+
impl->mapping.map(impl->mapping.lookupOrDefault({from}), to);
17281746
}
17291747

17301748
Value ConversionPatternRewriter::getRemappedValue(Value key) {

0 commit comments

Comments
 (0)