Skip to content

Commit 4d44efa

Browse files
pifon2aTensorFlow MLIR Team
authored and
TensorFlow MLIR Team
committed
Move vectorize_copy.cc and copy_removal.cc passes out of gml_st.
They don't use gml_st dialect. PiperOrigin-RevId: 585099478
1 parent 0378129 commit 4d44efa

File tree

11 files changed

+38
-51
lines changed

11 files changed

+38
-51
lines changed

BUILD

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,11 +1224,13 @@ cc_library(
12241224
"transforms/detensorize_scf_ops.cc",
12251225
"transforms/generic_host_to_llvm.cc",
12261226
"transforms/lower_index_cast_pass.cc",
1227+
"transforms/naive_copy_removal.cc",
12271228
"transforms/propagate_static_shapes_to_kernel.cc",
12281229
"transforms/test_hlo_transform_dialect_interpreter.cc",
12291230
"transforms/tile_loops_pass.cc",
12301231
"transforms/unbufferize_pass.cc",
12311232
"transforms/unroll_loops.cc",
1233+
"transforms/vectorize_copy.cc",
12321234
],
12331235
hdrs = [
12341236
"transforms/passes.h",
@@ -1286,6 +1288,7 @@ cc_library(
12861288
"@llvm-project//mlir:MemRefDialect",
12871289
"@llvm-project//mlir:MemRefToLLVM",
12881290
"@llvm-project//mlir:MemRefTransforms",
1291+
"@llvm-project//mlir:MemRefUtils",
12891292
"@llvm-project//mlir:NVVMDialect",
12901293
"@llvm-project//mlir:PDLDialect",
12911294
"@llvm-project//mlir:Pass",
@@ -1481,7 +1484,6 @@ cc_library(
14811484
"gml_st/transforms/collapse_shape/collapse_shape.cc",
14821485
"gml_st/transforms/collect_stats/collect_stats.cc",
14831486
"gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc",
1484-
"gml_st/transforms/copy_removal/copy_removal.cc",
14851487
"gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc",
14861488
"gml_st/transforms/cpu_tiling/fusion_outlining.cc",
14871489
"gml_st/transforms/cpu_tiling/fusion_planning_for_cpu.cc",
@@ -1505,7 +1507,6 @@ cc_library(
15051507
"gml_st/transforms/transforms.h",
15061508
"gml_st/transforms/vectorization/lower_vectors.cc",
15071509
"gml_st/transforms/vectorization/vectorization.cc",
1508-
"gml_st/transforms/vectorization/vectorize_copy.cc",
15091510
"gml_st/transforms/vectorization/vectorize_for_cpu.cc",
15101511
"gml_st/utils/linalg_utils.cc",
15111512
"gml_st/utils/tensor_utils.cc",

gml_st/transforms/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ add_mlir_library(GmlStPasses
3232
collapse_shape/collapse_shape.cc
3333
collect_stats/collect_stats.cc
3434
compose_extract_insert_slice/compose_extract_insert_slice.cc
35-
copy_removal/copy_removal.cc
3635
cpu_tiling/cpu_tiling_pipeline.cc
3736
cpu_tiling/fusion_outlining.cc
3837
cpu_tiling/fusion_planning_for_cpu.cc
@@ -54,7 +53,6 @@ add_mlir_library(GmlStPasses
5453
tiling_softmax/tiling_softmax.cc
5554
vectorization/lower_vectors.cc
5655
vectorization/vectorization.cc
57-
vectorization/vectorize_copy.cc
5856
vectorization/vectorize_for_cpu.cc
5957

6058
DEPENDS

gml_st/transforms/passes.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,6 @@ createComposeExtractInsertSlicePass();
6161
std::unique_ptr<OperationPass<func::FuncOp>> createVectorizeForCPUPass(
6262
int64_t numElementsThreshold = 1024);
6363

64-
/// Pass to vectorize `memref.copy`.
65-
std::unique_ptr<OperationPass<func::FuncOp>> createVectorizeCopyPass(
66-
int64_t numElementsThreshold = 8);
67-
68-
/// Pass to remove redundant `memref.copy` ops.
69-
std::unique_ptr<OperationPass<func::FuncOp>> createNaiveCopyRemovalPass();
70-
7164
/// Pass to gradually lower vector ops to SCF.
7265
std::unique_ptr<OperationPass<func::FuncOp>> createLowerVectorsPass(
7366
bool enableAVX2 = true, bool flatten = false);

gml_st/transforms/passes.td

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -70,27 +70,6 @@ def VectorizeForCPUPass : Pass<"vectorize-for-cpu", "mlir::func::FuncOp"> {
7070
];
7171
}
7272

73-
def VectorizeCopyPass : Pass<"vectorize-copy", "mlir::func::FuncOp"> {
74-
let summary = "Pass to vectorize `memref.copy`.";
75-
let constructor = "::mlir::gml_st::createVectorizeCopyPass()";
76-
let dependentDialects = [
77-
"scf::SCFDialect",
78-
"vector::VectorDialect",
79-
];
80-
let options = [
81-
Option<"numElementsThreshold", "num-elements-threshold", "int64_t",
82-
/*default=*/"8",
83-
"Max number of elements in src and dst memref for a copy to be "
84-
"vectorized.">,
85-
];
86-
}
87-
88-
def NaiveCopyRemovalPass : Pass<"naive-copy-removal", "mlir::func::FuncOp"> {
89-
let summary = "Pass to remove redundant `memref.copy` ops.";
90-
let constructor = "::mlir::gml_st::createNaiveCopyRemovalPass()";
91-
let dependentDialects = ["::mlir::memref::MemRefDialect"];
92-
}
93-
9473
def LowerVectorsPass : Pass<"lower-vectors", "mlir::func::FuncOp"> {
9574
let summary = "Pass to lower vector operations progressively.";
9675
let constructor = "::mlir::gml_st::createLowerVectorsPass()";

tests/Dialect/gml_st/vectorize_copy.mlir renamed to tests/vectorize_copy.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-hlo-opt %s --vectorize-copy="num-elements-threshold=8" --split-input-file | FileCheck %s
1+
// RUN: mlir-hlo-opt %s --vectorize-copy --split-input-file | FileCheck %s
22

33
func.func @vectorize_copy(%arg: memref<2x2xf32>) -> memref<2x2xf32> {
44
%subview = memref.subview %arg[0, 0] [2, 2] [1, 1] : memref<2x2xf32> to memref<2x2xf32, strided<[16, 1]>>

transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ add_mlir_library(MLIRBufferTransforms
3131
detensorize_scf_ops.cc
3232
generic_host_to_llvm.cc
3333
lower_index_cast_pass.cc
34+
naive_copy_removal.cc
3435
propagate_static_shapes_to_kernel.cc
3536
test_hlo_transform_dialect_interpreter.cc
3637
tile_loops_pass.cc
38+
vectorize_copy.cc
3739
unbufferize_pass.cc
3840
unroll_loops.cc
3941

gml_st/transforms/copy_removal/copy_removal.cc renamed to transforms/naive_copy_removal.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -16,16 +16,18 @@ limitations under the License.
1616
#include <memory>
1717
#include <utility>
1818

19-
#include "gml_st/transforms/passes.h"
19+
#include "mlir/Dialect/Func/IR/FuncOps.h"
2020
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2121
#include "mlir/Dialect/MemRef/IR/MemRef.h"
22+
#include "mlir/Pass/Pass.h"
2223
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24+
#include "transforms/passes.h"
2325

24-
namespace mlir::gml_st {
26+
namespace mlir {
2527
namespace {
2628

2729
#define GEN_PASS_DEF_NAIVECOPYREMOVALPASS
28-
#include "gml_st/transforms/passes.h.inc"
30+
#include "transforms/passes.h.inc"
2931

3032
/// Remove memref::CopyOp whose target (can be either a memref::SubViewOp or
3133
/// memref::AllocOp) has no other users.
@@ -88,4 +90,4 @@ std::unique_ptr<OperationPass<func::FuncOp>> createNaiveCopyRemovalPass() {
8890
return std::make_unique<NaiveCopyRemovalPass>();
8991
}
9092

91-
} // namespace mlir::gml_st
93+
} // namespace mlir

transforms/passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ using BufferizePatternsCallback = std::function<void(
4949
#define GEN_PASS_DECL_PROPAGATESTATICSHAPESTOKERNELPASS
5050
#define GEN_PASS_DECL_TILELOOPSPASS
5151
#define GEN_PASS_DECL_GENERICHOSTTOLLVMPASS
52+
#define GEN_PASS_DECL_VECTORIZECOPYPASS
5253
#include "transforms/passes.h.inc"
5354

5455
/// Creates a pass that merges smaller buffer into bigger buffer to optimize
@@ -99,6 +100,12 @@ std::unique_ptr<OperationPass<func::FuncOp>> createTileLoopsPass(
99100
// and scf.if.
100101
std::unique_ptr<OperationPass<func::FuncOp>> createDetensorizeScfOpsPass();
101102

103+
/// Pass to remove redundant `memref.copy` ops.
104+
std::unique_ptr<OperationPass<func::FuncOp>> createNaiveCopyRemovalPass();
105+
106+
/// Pass to vectorize `memref.copy`.
107+
std::unique_ptr<OperationPass<func::FuncOp>> createVectorizeCopyPass();
108+
102109
/// Registers the test pass for erasing transform dialect ops.
103110
void registerTestHloTransformDialectEraseSchedulePass();
104111

transforms/passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,16 @@ def AllocToArgPass : Pass<"alloc-to-arg", "mlir::func::FuncOp"> {
172172
let constructor = "hlo::createAllocToArgPass()";
173173
}
174174

175+
def NaiveCopyRemovalPass : Pass<"naive-copy-removal", "mlir::func::FuncOp"> {
176+
let summary = "Pass to remove redundant `memref.copy` ops.";
177+
let constructor = "createNaiveCopyRemovalPass()";
178+
let dependentDialects = ["memref::MemRefDialect"];
179+
}
180+
181+
def VectorizeCopyPass : Pass<"vectorize-copy", "mlir::func::FuncOp"> {
182+
let summary = "Pass to vectorize `memref.copy`.";
183+
let constructor = "createVectorizeCopyPass()";
184+
let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
185+
}
186+
175187
#endif // TENSORFLOW_COMPILER_MLIR_HLO_TRANSFORMS_PASSES

gml_st/transforms/vectorization/vectorize_copy.cc renamed to transforms/vectorize_copy.cc

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,21 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include <algorithm>
17-
#include <limits>
1817
#include <memory>
19-
#include <optional>
2018
#include <utility>
2119

22-
#include "gml_st/transforms/passes.h"
23-
#include "gml_st/transforms/vectorization/vectorization.h"
20+
#include "mlir/Dialect/Func/IR/FuncOps.h"
2421
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
2522
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2623
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
24+
#include "mlir/Pass/Pass.h"
2725
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2826

2927
namespace mlir {
30-
namespace gml_st {
3128
namespace {
3229

3330
#define GEN_PASS_DEF_VECTORIZECOPYPASS
34-
#include "gml_st/transforms/passes.h.inc"
31+
#include "transforms/passes.h.inc"
3532

3633
/// Transforms a big non-contiguous `memref.copy` into a loop over smaller
3734
/// copies that are either contiguous or can be vectorized.
@@ -217,7 +214,7 @@ struct VectorizeCopyPass
217214

218215
RewritePatternSet patterns(ctx);
219216
patterns.add<TileCopyPattern, CopyVectorizationPattern>(
220-
ctx, numElementsThreshold);
217+
ctx, /*numElementsThreshold = */ 8);
221218
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
222219
return signalPassFailure();
223220
}
@@ -226,12 +223,8 @@ struct VectorizeCopyPass
226223

227224
} // namespace
228225

229-
std::unique_ptr<OperationPass<func::FuncOp>> createVectorizeCopyPass(
230-
int64_t numElementsThreshold) {
231-
VectorizeCopyPassOptions opts;
232-
opts.numElementsThreshold = numElementsThreshold;
233-
return std::make_unique<VectorizeCopyPass>(opts);
226+
std::unique_ptr<OperationPass<func::FuncOp>> createVectorizeCopyPass() {
227+
return std::make_unique<VectorizeCopyPass>();
234228
}
235229

236-
} // namespace gml_st
237230
} // namespace mlir

0 commit comments

Comments
 (0)