
[MLIR] Preserve Encoding During TensorOp Creation #80871


Open · manbearian wants to merge 1 commit into main from users/ianb/tensor-fixes
Conversation

manbearian (Contributor):

Preserve the encoding field when deriving new TensorTypes from existing ones.

llvmbot (Member) commented Feb 6, 2024:

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Ian Bearman (manbearian)

Changes

Preserve the encoding field when deriving new TensorTypes from existing ones.
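Every touched call site rebuilds a RankedTensorType from a source type; the patch threads the source's encoding through the optional third argument of RankedTensorType::get. A minimal sketch of the before/after pattern (deriveType, srcType, and newShape are illustrative placeholders, not names from the patch):

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Illustrative helper: srcType is the pre-existing tensor type, newShape
// the shape of the derived type being built.
static RankedTensorType deriveType(RankedTensorType srcType,
                                   ArrayRef<int64_t> newShape) {
  // Before this patch, call sites dropped the encoding:
  //   return RankedTensorType::get(newShape, srcType.getElementType());
  // After, the source encoding (if any) is carried over:
  return RankedTensorType::get(newShape, srcType.getElementType(),
                               srcType.getEncoding());
}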


Full diff: https://github.com/llvm/llvm-project/pull/80871.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+28-10)
  • (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+7-7)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b21e89ae3a5713..8f117d9464f5f4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/TensorEncoding.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
@@ -1622,7 +1623,20 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
     currentDim += dim;
   }
 
-  return RankedTensorType::get(newShape, type.getElementType());
+  auto encoding = type.getEncoding();
+  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
+    auto ignoreError = [&] {
+      auto emitter = mlir::emitError(UnknownLoc::get(type.getContext()));
+      emitter.abandon();
+      return emitter;
+    };
+    if (failed(
+            v.verifyEncoding(newShape, type.getElementType(), ignoreError))) {
+      // strip the encoding if it is not valid for the new shape.
+      encoding = Attribute();
+    }
+  }
+  return RankedTensorType::get(newShape, type.getElementType(), encoding);
 }
 
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -1902,7 +1916,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
   assert(static_cast<int64_t>(staticSizes.size()) ==
              sourceTensorType.getRank() &&
          "unexpected staticSizes not equal to rank of source");
-  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
+  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
+                               sourceTensorType.getEncoding());
 }
 
 RankedTensorType ExtractSliceOp::inferResultType(
@@ -1943,7 +1958,8 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
       if (!dimsToProject.test(pos))
         projectedShape.push_back(shape[pos]);
     inferredType =
-        RankedTensorType::get(projectedShape, inferredType.getElementType());
+        RankedTensorType::get(projectedShape, inferredType.getElementType(),
+                              inferredType.getEncoding());
   }
   return inferredType;
 }
@@ -2663,8 +2679,8 @@ struct InsertSliceOpSourceCastInserter final
     if (!hasValidSizesOffsets(newSrcShape))
       return failure();
 
-    RankedTensorType newSrcType =
-        RankedTensorType::get(newSrcShape, srcType.getElementType());
+    RankedTensorType newSrcType = RankedTensorType::get(
+        newSrcShape, srcType.getElementType(), srcType.getEncoding());
     if (srcType == newSrcType ||
         !preservesStaticInformation(srcType, newSrcType) ||
         !tensor::CastOp::areCastCompatible(srcType, newSrcType))
@@ -2815,7 +2831,8 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
     }
   }
 
-  return RankedTensorType::get(inferredShape, sourceType.getElementType());
+  return RankedTensorType::get(inferredShape, sourceType.getElementType(),
+                               sourceType.getEncoding());
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -3601,9 +3618,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
         "tiling factors must equal the number of dimensions to tile");
   }
 
-  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
-                              ? packOrUnPack.getDestType()
-                              : packOrUnPack.getSourceType();
+  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+                                    ? packOrUnPack.getDestType()
+                                    : packOrUnPack.getSourceType();
   size_t packedRank = packedType.getRank();
   // Require output rank to match input rank + number of blocking factors.
   if (unpackedRank + mixedTiles.size() != packedRank) {
@@ -3870,7 +3887,8 @@ RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                          ArrayRef<int64_t> outerDimsPerm) {
   SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
-  return RankedTensorType::get(resultShape, sourceType.getElementType());
+  return RankedTensorType::get(resultShape, sourceType.getElementType(),
+                               sourceType.getEncoding());
 }
 
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 547320f5338747..dc3b202c8ea9c4 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -122,13 +122,13 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
 // CHECK-LABEL:   func.func @linalg_copy(
 // CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
-// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
+// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:         }
 

manbearian force-pushed the users/ianb/tensor-fixes branch from 2fc76a1 to ac7fc12 on February 6, 2024 at 22:56
manbearian changed the title from "Preserve Encoding During TensorOp Creation" to "[MLIR] Preserve Encoding During TensorOp Creation" on Feb 8, 2024
Review thread on mlir/lib/Dialect/Tensor/IR/TensorOps.cpp:

@@ -1622,7 +1623,20 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
     currentDim += dim;
   }

-  return RankedTensorType::get(newShape, type.getElementType());
+  auto encoding = type.getEncoding();
+  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
joker-eph (Collaborator):

Seems to me that the conservative option is that if the interface isn't implemented, the encoding should not be propagated.

manbearian (Contributor, Author):

Hi @joker-eph, thanks for your response. I think there may be a misunderstanding, as I don't think what you're saying will work, but I'm not sure if the misunderstanding is on my part or yours.

From what I read in the code, the previous behavior drops the encoding (it is not propagated) in all cases when creating a new type. I'm changing this to reuse the encoding on the newly created type. However, the special case handled here is that this particular kind of encoding (a VerifiableTensorEncoding) may not be propagatable, since it can be dimensionality specific.

I believe the best approach would be to update the encoding based on the new dimensionality, but since this isn't an area of the compiler I'm familiar with, and one we don't use in our code base, I'm instead falling back to the existing behavior of dropping the encoding.

Does this make sense?

I'm happy to discuss more if this approach is not what folks want to see.
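
A minimal sketch of the verify-then-drop fallback described above, factored into a free helper for illustration (the helper and its name propagateEncoding are hypothetical; the body mirrors the CollapseShapeOp hunk in the diff):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/TensorEncoding.h"

using namespace mlir;

// Decide which encoding a type derived from srcType with shape newShape
// should carry, following the policy in this PR.
static Attribute propagateEncoding(RankedTensorType srcType,
                                   ArrayRef<int64_t> newShape) {
  Attribute encoding = srcType.getEncoding();
  auto verifiable = dyn_cast_or_null<VerifiableTensorEncoding>(encoding);
  if (!verifiable)
    return encoding; // This PR's default: propagate opaque encodings as-is.
  // Diagnostic emitter that discards its output; only the pass/fail result
  // of verifyEncoding matters here.
  auto ignoreError = [&] {
    auto diag = emitError(UnknownLoc::get(srcType.getContext()));
    diag.abandon();
    return diag;
  };
  if (failed(verifiable.verifyEncoding(newShape, srcType.getElementType(),
                                       ignoreError)))
    return Attribute(); // Encoding invalid for the new shape: drop it.
  return encoding;
}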

joker-eph (Collaborator):

In general, dropping is the "safe" thing to do when we transform code and don't know whether the transformation can preserve the semantics of the encoding.

Now, there may be a specific argument you'd want to make here, but it should be made explicit: why would blindly propagating an opaque encoding be correct in general? If my attribute does not implement VerifiableTensorEncoding, you can't ensure that it is correct to just propagate it, right?

manbearian (Contributor, Author):

I admit I'm unfamiliar with VerifiableTensorEncoding and its intended use. (I think I originally wrote this code before it existed.)

> you can't ensure that it is correct to just propagate it, right?

My thinking (before this discussion) was that the encoding attribute is opaque, so the code shouldn't know anything about it. The encodings used in my compiler are not specific to the shape of the tensor, and I need to propagate them in all cases.

I suppose it comes down to the desired default behavior. Currently the default is "don't propagate the encoding", and the intent of my change was to make the default "always propagate if possible".

I think you're suggesting two things:

  • don't change the default, but instead, only transfer the encoding if it can be proved to be transferable (via the VerifiableTensorEncoding interface).
  • make sure to query verifyEncoding before propagating

Please let me know if I'm getting this right. I'm okay with doing one, both, or neither of these, so please let me know what you think is best.
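
For reference, a minimal sketch of what such an opt-in could look like, assuming a hypothetical encoding attribute MyLayoutAttr that is only meaningful for rank-2 tensors; the verifyEncoding signature is the one declared by the VerifiableTensorEncoding interface in mlir/IR/TensorEncoding.h:

#include "mlir/IR/TensorEncoding.h"

using namespace mlir;

// Hypothetical attribute implementing VerifiableTensorEncoding: a
// shape-changing op can ask whether this encoding still makes sense for a
// derived type, and drop it when verification fails (e.g. after collapsing
// a rank-2 tensor to rank 1).
LogicalResult MyLayoutAttr::verifyEncoding(
    ArrayRef<int64_t> shape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  if (shape.size() != 2)
    return emitError() << "my_layout encoding requires a rank-2 tensor";
  return success();
}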

joker-eph (Collaborator):

> My thinking (before this discussion) was that the encoding attribute is opaque, so the code shouldn't know anything about it.

But when something is "opaque", we have two approaches:

  1. We don't do the transformation (because we risk making an incorrect transformation).
  2. We drop the opaque attribute (we don't guarantee propagation; these attributes have to be droppable).

This comes down to the guarantee we intend to provide with this attribute, though.

> don't change the default, but instead, only transfer the encoding if it can be proved to be transferable (via the VerifiableTensorEncoding interface).

Yes: basically, we could consider that an attribute which implements VerifiableTensorEncoding opts in to being propagated, as long as it verifies.

But even this is sketchy: the "verifier" aspect does not guarantee that the transformation is semantics-preserving!

manbearian (Contributor, Author) commented Feb 19, 2024:

Dropping encodings through operations seems exceedingly dangerous and definitely semantically incorrect, at least for the use cases in my compiler.

First, in the tensor representation, an operation such as "SliceOp" should produce a tensor that has the same encoding as the original tensor.

Second, if we allow operations to simply drop encodings, then, at least in my case, bufferization will either fail (since it cannot reconcile the types) or produce suboptimal code that converts the types using copies.

Is it possible that I'm completely misusing the encoding field? For our compiler, it contains information on how the tensor will be laid out in memory when allocated.

joker-eph (Collaborator):

You're basically saying we should do:

> 1. We don't do the transformation (because we risk making an incorrect transformation)

I don't disagree; I'm not sure we have good documentation of the intent with respect to this tensor encoding.
(Many folks, including me, were a bit concerned when this was introduced that we might be better off introducing new types, because this seems incredibly unsafe and makes the codebase hard to get right: this is where we're at now.)
