[MLIR][Shape] Support >2 args in shape.broadcast folder (#126808)
```diff
@@ -649,24 +649,29 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
     return getShapes().front();
   }
 
-  // TODO: Support folding with more than 2 input shapes
-  if (getShapes().size() > 2)
-    return nullptr;
-
-  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
+  if (!adaptor.getShapes().front())
     return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
-          .getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
-          .getValues<int64_t>());
-  SmallVector<int64_t, 6> resultShape;
 
-  // If the shapes are not compatible, we can't fold it.
-  // TODO: Fold to an "error".
-  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
-    return nullptr;
+  SmallVector<int64_t, 6> resultShape(
+      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
+          .getValues<int64_t>());
+
+  for (auto next : adaptor.getShapes().drop_front()) {
+    if (!next)
+      return nullptr;
+    auto nextShape = llvm::to_vector<6>(
+        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+    SmallVector<int64_t, 6> tmpShape;
+    // If the shapes are not compatible, we can't fold it.
+    // TODO: Fold to an "error".
+    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+      return nullptr;
+
+    resultShape.clear();
+    std::copy(tmpShape.begin(), tmpShape.end(),
+              std::back_inserter(resultShape));
+  }
 
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
```

**Review comment (jpienaar), on the `std::copy` call:** Was this what clang-format produced?

**Reply:** @jpienaar Yes, that's correct; it was produced by clang-format. A similar formatting instance: llvm-project/clang/include/clang/Lex/MacroInfo.h, lines 536 to 537 in 74ca579.
**Review comment:** In the `getBroadcastedShape` implementation, the shape vector size is hardcoded to `6`, so I did it similarly here. Does it make sense? It looks like an arbitrary value from the outside.

**Reply:** Yes, semi. If I recall, it was either the default elsewhere in an ML framework where this was used, or the max rank across a set of ML models. But it is a bit arbitrary. Elsewhere folks also use the default of `SmallVector`. (The latter is probably a little bit more arbitrary, but neither is very fine-tuned.)