-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][scf] Extend option to yield replacement for multiple results case #93144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: None (Yun-Fly) ChangesCurrently, we only have an option to yield replacement for
However, it has no chance to yield replacement for multiple results as followed:
With this method, the original untiled Based on the earlier talk with @MaheshRavishankar in discourse, this PR extends the functionality of yielding replacement for multiple results case. NOTE that, it is still decided by the caller whether need to yield replacement as same as current status. Two major changes:
Considering downstream impact, not sure its better to break down this option and add another new one for current @MaheshRavishankar would you help to review this PR? Thanks. Full diff: https://github.com/llvm/llvm-project/pull/93144.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 6d567171e185a..32249b90644a8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -190,10 +190,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
/// where `%0` had other uses as well. If not reconstructed from within the loop
/// body, uses of `%0` could not be replaced, making it still live and the
/// fusion immaterial.
+///
+/// The @param `yieldResultNumber` decides which result would be yield. If not
+/// given, yield all `opResult` of fused producer.
LogicalResult yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
- MutableArrayRef<LoopLikeOpInterface> loops);
+ MutableArrayRef<LoopLikeOpInterface> loops,
+ std::optional<ArrayRef<unsigned>> yieldResultNumber = std::nullopt);
/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 14d775d986d20..5ac8eaac402b2 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -96,6 +96,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return tile offset and size of Iteration Domain
+ based on the given tile info from the certain result.
+ }],
+ /*retType=*/"LogicalResult",
+ /*methodName=*/"getIterationDomainTileFromResultTile",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$resultNumber,
+ "ArrayRef<OpFoldResult> ":$resultOffsets,
+ "ArrayRef<OpFoldResult> ":$resultSizes,
+ "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVector<OpFoldResult> &":$iterDomainSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f512be46cc13d..b46c8135d1c7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -160,10 +160,11 @@ struct LinalgOpTilingInterface
return success();
}
- FailureOr<TilingResult>
- generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes) const {
+ LogicalResult getIterationDomainTileFromResultTile(
+ Operation *op, OpBuilder &b, unsigned resultNumber,
+ ArrayRef<OpFoldResult> resultOffsets, ArrayRef<OpFoldResult> resultSizes,
+ SmallVector<OpFoldResult> &iterDomainOffsets,
+ SmallVector<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);
// Check that the indexing map used for the output is a projected
@@ -193,8 +194,27 @@ struct LinalgOpTilingInterface
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
unsigned dimPosition =
cast<AffineDimExpr>(resultExpr.value()).getPosition();
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
- iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
+ iterationTileOffsets[dimPosition] = resultOffsets[resultExpr.index()];
+ iterationTileSizes[dimPosition] = resultSizes[resultExpr.index()];
+ }
+
+ iterDomainOffsets = iterationTileOffsets;
+ iterDomainSizes = iterationTileSizes;
+
+ return success();
+ }
+
+ FailureOr<TilingResult>
+ generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) const {
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
+
+ SmallVector<OpFoldResult> iterationTileOffsets, iterationTileSizes;
+ if (failed(tilingInterfaceOp.getIterationDomainTileFromResultTile(
+ b, resultNumber, offsets, sizes, iterationTileOffsets,
+ iterationTileSizes))) {
+ return failure();
}
FailureOr<TilingResult> tilingResult =
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a72dafe725177..ddd0e94f9bd4c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -939,49 +939,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
- MutableArrayRef<LoopLikeOpInterface> loops) {
+ MutableArrayRef<LoopLikeOpInterface> loops,
+ std::optional<ArrayRef<unsigned>> yieldResultNumber) {
if (loops.empty())
return success();
- OpResult fusableProducer = fusedProducerInfo.origProducer;
- Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
- FailureOr<Value> initValue = tensor::getOrCreateDestination(
- rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
- if (succeeded(initValue)) {
-
- YieldTiledValuesFn newYieldValuesFn =
- [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
- ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
- SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
- SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
- -> LogicalResult {
- OpBuilder::InsertionGuard g(innerRewriter);
- if (auto tiledDestStyleOp =
- tiledAndFusedProducer
- .getDefiningOp<DestinationStyleOpInterface>()) {
- rewriter.setInsertionPoint(tiledDestStyleOp);
- Value newRegionArg = newRegionIterArgs.back();
+ Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
+ *tiledOwner = fusedProducerInfo.tiledOps[0];
+
+ Location loc = originalOwner->getLoc();
+ // a. collect all init Value to be appended
+ ArrayRef<unsigned> initNumberList =
+ yieldResultNumber ? yieldResultNumber.value()
+ : llvm::to_vector(llvm::seq<unsigned>(
+ 0, originalOwner->getNumResults()));
+ SmallVector<Value> initValueList;
+ for (const auto &resultNumber : initNumberList) {
+ FailureOr<Value> initValue = tensor::getOrCreateDestination(
+ rewriter, loc, originalOwner->getResult(resultNumber));
+ if (succeeded(initValue)) {
+ initValueList.push_back(initValue.value());
+ } else {
+ return failure();
+ }
+ }
+
+ YieldTiledValuesFn newYieldValuesFn =
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+ OpBuilder::InsertionGuard g(innerRewriter);
+
+ // get sliceOp tile information
+ SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
+ sliceSizes = sliceOp.getMixedSizes();
+
+ // expect all strides of sliceOp being 1
+ if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+ return !isConstantIntValue(ofr, 1);
+ }))
+ return failure();
+
+ unsigned sliceResultNumber =
+ fusedProducerInfo.origProducer.getResultNumber();
+
+ auto tilableOp = cast<TilingInterface>(originalOwner);
+ // b. get iterDomain Offset and Sizes based on sliceOp tile
+ SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
+ // skip tensor.pack/unpack/pad, which expects single opResult
+ if (tilableOp->getNumResults() > 1 &&
+ failed(tilableOp.getIterationDomainTileFromResultTile(
+ rewriter, sliceResultNumber, sliceOffset, sliceSizes,
+ iterDomainOffset, iterDomainSizes))) {
+ return failure();
+ }
+
+ // c. calculate offsets and sizes info of all OpResults respectively based
+ // on iteration Domain Tile
+ SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
+ for (const auto &resultNumber : initNumberList) {
+ if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) {
+ offsetList.push_back(sliceOffset);
+ sizesList.push_back(sliceSizes);
+ } else {
+ assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
+ // infer result tile according to the iteration domain tile
+ SmallVector<OpFoldResult> offset, sizes;
+ if (failed(tilableOp.getResultTilePosition(
+ rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
+ offset, sizes))) {
+ return failure();
+ }
+ offsetList.push_back(offset);
+ sizesList.push_back(sizes);
+ }
+ }
+
+ // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+ if (auto tiledDestStyleOp =
+ dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
+ rewriter.setInsertionPoint(tiledDestStyleOp);
+ for (const auto &&[index, newRegionArg] :
+ llvm::enumerate(newRegionIterArgs)) {
auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
- sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
- unsigned resultNumber = fusableProducer.getResultNumber();
+ loc, newRegionArg, offsetList[index], sizesList[index],
+ SmallVector<OpFoldResult>(offsetList[index].size(),
+ rewriter.getIndexAttr(1)));
+ unsigned resultNumber = initNumberList[index];
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
});
}
- Block *block = rewriter.getInsertionPoint()->getBlock();
- rewriter.setInsertionPoint(block->getTerminator());
- tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
- tiledOffset.emplace_back(sliceOp.getMixedOffsets());
- tiledSizes.emplace_back(sliceOp.getMixedSizes());
- return success();
- };
+ }
- return addInitOperandsToLoopNest(rewriter, loops,
- SmallVector<Value>{initValue.value()},
- newYieldValuesFn);
- }
- return success();
+ // e. prepare tiled offset and sizes for later `insert_slice` creation by
+ // caller
+ Block *block = rewriter.getInsertionPoint()->getBlock();
+ rewriter.setInsertionPoint(block->getTerminator());
+ for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
+ tiledResult.push_back(tiledOwner->getResult(resultNumber));
+ tiledOffset.emplace_back(offsetList[index]);
+ tiledSizes.emplace_back(sizesList[index]);
+ }
+ return success();
+ };
+
+ return addInitOperandsToLoopNest(rewriter, loops, initValueList,
+ newYieldValuesFn);
}
/// Implementation of tile consumer and fuse producer greedily.
@@ -1071,14 +1136,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
continue;
if (yieldReplacement) {
+ // Reconstruct and yield all opResult of fusableProducerOp by default. The
+ // caller can specific which one to yield by designating optional argument
+ // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+ Operation *fusableProducerOp = fusableProducer.getOwner();
if (failed(yieldReplacementForFusedProducer(
rewriter, candidateSliceOp, fusedResult.value(), loops))) {
return rewriter.notifyMatchFailure(
- fusableProducer.getOwner(), "failed to replacement value for this "
- "oepration from within the tiled loop");
+ fusableProducerOp, "failed to replacement value for this "
+ "operation from within the tiled loop");
+ }
+ for (const auto &result : fusableProducerOp->getResults()) {
+ origValToResultNumber[result] =
+ loops.front()->getNumResults() -
+ (fusableProducerOp->getNumResults() - result.getResultNumber());
}
- origValToResultNumber[fusableProducer] =
- loops.front()->getNumResults() - 1;
}
if (Operation *tiledAndFusedOp =
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
index 7356c11e85ac0..3c0ada9d2cabc 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]]
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
+
+// -----
+
+func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
+ %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
+ %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
+ -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
+ %out0, %out1 = linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (j, i)>],
+ iterator_types = ["parallel", "parallel"]
+ }
+ ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
+ outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
+ ^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
+ %4 = arith.mulf %0, %1 : f32
+ %5 = arith.addf %0, %1 : f32
+ linalg.yield %4, %5: f32, f32
+ } -> (tensor<32x32xf32>, tensor<32x32xf32>)
+
+ %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
+
+ return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %add = transform.structured.match ops{["linalg.add"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_and_yield %add [16]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @multiple_outputs_fusion_yield_all(
+// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
+// CHECK: %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
+// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
+// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
+// CHECK-DAG: %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+// CHECK: %[[ADD_TILE:.+]] = linalg.add
+// CHECK-SAME: ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
+// CHECK-SAME: outs(%[[INIT2_TILE]] :
+// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
+// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
+// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
+// CHECK: return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This mostly looks right to me. Just looking at this once more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, looking again this does look right to me. Thanks for the addition!
4235f25
to
4377bf0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This also looks very adhoc to me, same as the other PR I just looked at, this seems to want to split the op before fusion instead of adding more code to support ever more complex cases.
Its not always possible to "split the op" I dont know which other PR you are refering to, but as far as I can see this is OK to me. Please provide more targeted feedback to help navigate. |
Could you explain more about |
eb824e1
to
74d925e
Compare
Hi @MaheshRavishankar @nicolasvasilache, is there any update comments? Or shall we merge this patch? |
74d925e
to
eda4bf3
Compare
Hi, @MaheshRavishankar. I have updated document to align with your recent PR #95178 . If there is no new comment, It is planned to merge this patch ASAP in avoid of one more rebase. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me. I think I already approved it. @nicolasvasilache please comment if you have any more comments. If not maybe your comments can be addressed post landing.
@Yun-Fly Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested Please check whether problems have been caused by your change specifically, as How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
…ase (llvm#93144) This patch extends the functionality of yielding replacement for multiple results case and adds another optional argument called `yieldResultNumber` indicating which result(s) need yield. If not given, all of results will be yield by default.
Currently, we only have an option to yield replacement for
fusableProducer
like this:However, it has no chance to yield replacement for multiple results as followed:
With this method, the original untiled
op2
will has no uses any more and expect cleaned up later, otherwise leading an unnecessary computation.Based on the earlier talk with @MaheshRavishankar in discourse, this PR extends the functionality of yielding replacement for multiple results case. NOTE that, it is still decided by the caller whether need to yield replacement as same as current status.
Two major changes:
getIterationDomainTileFromResultTile
, which is used to compute other results tile according candidatesliceOp
. BTW, this utility is much similar to another one namedgetIterationDomainTileFromOperandTile
in this PR. I think they can be further unified when finally merged.yieldReplacementForFusedProducer
to deal with multipleOpResult
s and add another optional argument calledyieldResultNumber
indicating which result need yield. If not given, all of results will be yield by default.Considering downstream impact, not sure its better to break down current
yieldReplacement
option and add another new one forfusionControlFn
?@MaheshRavishankar would you help to review this PR? Thanks.