Commit 9a62f7a

Author: Ferdinand Lemaire
Merge pull request #8 from Xilinx/ferdinand.FXML-1303_linearRelu
Add linearRelu op to linalg structured ops
2 parents 9eafef6 + 0fa6b0b commit 9a62f7a

4 files changed, +207 -10 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 117 additions & 0 deletions
@@ -5766,3 +5766,120 @@ structured_op: !LinalgStructuredOpConfig
               scalar_const: '2.3283063999999999E-10 : f64'
         - !ScalarExpression
           scalar_arg: min
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: linear_relu
+  cpp_class_name: LinearReluOp
+  doc: |-
+    Performs a linear/fully-connected + relu operation
+
+    This is a long description that I'll fill later
+
+    Layout:
+      * I: WH (Input)
+      * W: WH (Weights)
+      * B: H (Bias)
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: W
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s2)>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+  iterator_types:
+  - parallel
+  - reduction
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: add
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: mul
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+                - !ScalarExpression
+                  scalar_arg: W
+            - !ScalarExpression
+              scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: relu_nc
+  cpp_class_name: ReluNcOp
+  doc: |-
+    Applies the ReLU activation function to every value in the tensor.
+
+    Layout:
+      * Input: NC
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: IFM
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: OFM
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: OFM
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: max_signed
+        operands:
+        - !ScalarExpression
+          scalar_arg: IFM
+        - !ScalarExpression
+          scalar_fn:
+            kind: type
+            fn_name: cast_signed
+            type_var: T1
+            operands:
+            - !ScalarExpression
+              scalar_const: '0.000000e+00 : f64'
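
For orientation, the computation these two new definitions describe can be sketched in plain NumPy (a minimal illustration, not part of the commit; the function names are made up). Note that the assignment above folds B into the reduction, so the bias is accumulated once per reduction step; the sketch applies it once per output element, which is the intended linear/fully-connected semantics (the Python definition further down carries the author's own note on this).

import numpy as np

def linear_relu_reference(I, W, B):
    # I: (w, h), W: (k, h), B: (k,) -- matching the WH/KH/K shape_maps above.
    # Linear part: O[w, k] = sum_h I[w, h] * W[k, h] + B[k], bias applied once.
    O = I @ W.T + B
    # ReLU part, i.e. what relu_nc computes element-wise.
    return np.maximum(O, 0)

def relu_nc_reference(IFM):
    # Element-wise max(x, 0) over an NC tensor, as in the relu_nc assignment.
    return np.maximum(IFM, 0)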

mlir/lib/Dialect/Linalg/Transforms/Unfuse.cpp

Lines changed: 34 additions & 10 deletions
@@ -24,6 +24,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"

@@ -650,12 +651,9 @@ struct GlobalAveragePool2DLowering : OpRewritePattern<GlobalAveragePool2DOp> {
   }
 };
 
-/// Torch MLIR does a similar lowering for their Linear operator to linalg;
-/// here we implement the same so we can run tests using the unfused version.
-struct LinearLowering : OpRewritePattern<LinearOp> {
-  using OpRewritePattern<LinearOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(LinearOp op,
-                                PatternRewriter &rewriter) const override {
+template <class Linear>
+static Value unfuseLinear(Linear &op, PatternRewriter &rewriter) {
+
     Location loc = op.getLoc();
     Value input = op.getOperand(0);
     Value weights = op.getOperand(1);

@@ -690,10 +688,35 @@ struct LinearLowering : OpRewritePattern<LinearOp> {
        ->getResult(0);
 
   // Create the matmul operation that does the multiplication and addition
-  rewriter.replaceOpWithNewOp<MatmulOp>(op, output.getType(),
-                                        ValueRange{input, transposeWeightsOp},
-                                        broadcastBiasOp);
+  auto newOp = rewriter.create<MatmulOp>(loc, outputType, ValueRange{op.getOperand(0), transposeWeightsOp},
+                                         broadcastBiasOp).getResult(0);
+  return newOp;
+}
+/// Torch MLIR does a similar lowering for their Linear operator to linalg;
+/// here we implement the same so we can run tests using the unfused version.
+struct LinearLowering : OpRewritePattern<LinearOp> {
+  using OpRewritePattern<LinearOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(LinearOp op,
+                                PatternRewriter &rewriter) const override {
+    Value matmul = unfuseLinear<LinearOp>(op, rewriter);
+    rewriter.replaceOp(op, matmul);
+    return success();
+  }
+};
 
+
+struct LinearReluLowering : OpRewritePattern<LinearReluOp> {
+  using OpRewritePattern<LinearReluOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(LinearReluOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Value linearResult = unfuseLinear<LinearReluOp>(op, rewriter);
+
+    rewriter.replaceOpWithNewOp<ReluNcOp>(
+        op,
+        /*resultTensorTypes=*/linearResult.getType(),
+        /*inputs=*/linearResult,
+        /*outputs=*/linearResult);
     return success();
   }
 };

@@ -711,7 +734,8 @@ struct LinalgUnfusePass : public impl::LinalgUnfuseBase<LinalgUnfusePass> {
            Conv2DTensorAddLreluAveragePoolLowering,
            Conv2DActivationMaxpoolOpLowering<Conv2DLreluMaxpoolOp>,
            Conv2DActivationMaxpoolOpLowering<Conv2DReluMaxpoolOp>,
-           SoftmaxLowering, GlobalAveragePool2DLowering, LinearLowering>(
+           SoftmaxLowering, GlobalAveragePool2DLowering, LinearLowering,
+           LinearReluLowering>(
       &getContext());
 
   (void)applyPatternsAndFoldGreedily(getOperation().getBody(),
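
The net effect of the new LinearReluLowering pattern is easiest to see as a sequence of array operations: it reuses unfuseLinear to emit the transpose2d / broadcast_1d_to_2d / matmul chain and then applies relu_nc to the matmul result. A rough NumPy model of that sequence, using the shapes from the unfuse.mlir test below (variable names are illustrative, not part of the commit):

import numpy as np

x = np.zeros((1, 2048), dtype=np.float32)     # input
w = np.zeros((1000, 2048), dtype=np.float32)  # weights
b = np.zeros((1000,), dtype=np.float32)       # bias

tweights = w.T                          # linalg.transpose2d: 1000x2048 -> 2048x1000
bias2d = np.broadcast_to(b, (1, 1000))  # linalg.broadcast_1d_to_2d: 1000 -> 1x1000
matmul = x @ tweights + bias2d          # linalg.matmul accumulating onto the bias init
out = np.maximum(matmul, 0)             # linalg.relu_nc applied to the matmul result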

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 36 additions & 0 deletions
@@ -1377,3 +1377,39 @@ def fill_rng_2d(min=ScalarDef(F64),
   scaling = (max - min) * inv_range
   O[D.m, D.n] = TypeFn.cast_signed(
       T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
+
+@linalg_structured_op
+def linear_relu(
+    I=TensorDef(T1, S.W, S.H),
+    W=TensorDef(T1, S.K, S.H),
+    B=TensorDef(T1, S.K),
+    O=TensorDef(T1, S.W, S.K, output=True)):
+  """Performs a linear/fully-connected + relu operation.
+
+  Performs a linear operation followed by a ReLU.
+
+  Layout:
+    * I: WH (Input)
+    * W: WH (Weights)
+    * B: H (Bias)
+  """
+  domain(D.W, D.H, D.K)
+  # Implementation is incorrect: the addition of the bias should happen after
+  # the multiplication, not on each element.
+  O[D.W, D.K] += I[D.W, D.H] * W[D.K, D.H] + B[D.K]
+
+
+@linalg_structured_op
+def relu_nc(
+    IFM=TensorDef(T1, Batch, S.C),
+    OFM=TensorDef(T1, Batch, S.C, output=True)):
+  """Applies the ReLU activation function to every value in the tensor.
+
+  Layout:
+    * Input: NC
+  """
+  domain(D.b, D.c)
+  OFM[D.b, D.c] = BinaryFn.max_signed(
+      IFM[D.b, D.c], TypeFn.cast_signed(T1, const(0.0))
+  )
+
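
To make the caveat in the linear_relu comment concrete, here is a tiny worked example (values chosen purely for illustration): with H reduction steps, accumulating I*W + B per step counts the bias H times instead of once.

import numpy as np

I = np.array([[1.0, 2.0]])   # shape (W=1, H=2)
W = np.array([[3.0, 4.0]])   # shape (K=1, H=2)
B = np.array([10.0])         # shape (K=1,)

# What the opdsl assignment accumulates: bias added inside the h-reduction.
per_step = sum(I[0, h] * W[0, h] + B[0] for h in range(2))   # (3 + 10) + (8 + 10) = 31
# Intended fully-connected semantics: bias added once after the reduction.
intended = (I @ W.T + B)[0, 0]                               # 11 + 10 = 21
print(per_step, intended)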

mlir/test/Dialect/Linalg/unfuse.mlir

Lines changed: 20 additions & 0 deletions
@@ -448,6 +448,26 @@ func.func @unfuse_linear(%input: tensor<1x2048xf32>, %weights: tensor<1000x2048x
   // CHECK: %[[bias2dshape:.+]] = tensor.empty() : tensor<1x1000xf32>
   // CHECK: %[[bias2d:.+]] = linalg.broadcast_1d_to_2d ins(%arg2 : tensor<1000xf32>) outs(%2 : tensor<1x1000xf32>) -> tensor<1x1000xf32>
   // CHECK: %[[out:.+]] = linalg.matmul ins(%[[input]], %[[tweights]] : tensor<1x2048xf32>, tensor<2048x1000xf32>) outs(%[[bias2d]] : tensor<1x1000xf32>) -> tensor<1x1000xf32
+  // CHECK: return %[[out]]
+
+  return %result : tensor<1x1000xf32>
+}
+
+// -----
+
+// CHECK: func.func @unfuse_linearRelu
+// CHECK-SAME: %[[input:.+]]: tensor<1x2048xf32>, %[[weights:.+]]: tensor<1000x2048xf32>, %[[bias:.+]]: tensor<1000xf32>
+func.func @unfuse_linearRelu(%input: tensor<1x2048xf32>, %weights: tensor<1000x2048xf32>, %bias: tensor<1000xf32>) -> tensor<1x1000xf32> {
+  %zero = arith.constant 0.0 : f32
+  %init = tensor.splat %zero : tensor<1x1000xf32>
+  %result = linalg.linear_relu ins(%input, %weights, %bias: tensor<1x2048xf32>, tensor<1000x2048xf32>, tensor<1000xf32>) outs(%init: tensor<1x1000xf32>) -> tensor<1x1000xf32>
+
+  // CHECK: %[[tweightshape:.+]] = tensor.empty() : tensor<2048x1000xf32>
+  // CHECK: %[[tweights:.+]] = linalg.transpose2d ins(%arg1 : tensor<1000x2048xf32>) outs(%0 : tensor<2048x1000xf32>) -> tensor<2048x1000xf32>
+  // CHECK: %[[bias2dshape:.+]] = tensor.empty() : tensor<1x1000xf32>
+  // CHECK: %[[bias2d:.+]] = linalg.broadcast_1d_to_2d ins(%arg2 : tensor<1000xf32>) outs(%2 : tensor<1x1000xf32>) -> tensor<1x1000xf32>
+  // CHECK: %[[matmul:.+]] = linalg.matmul ins(%[[input]], %[[tweights]] : tensor<1x2048xf32>, tensor<2048x1000xf32>) outs(%[[bias2d]] : tensor<1x1000xf32>) -> tensor<1x1000xf32
+  // CHECK: %[[out:.*]] = linalg.relu_nc ins(%[[matmul]] : tensor<1x1000xf32>) outs(%[[matmul]] : tensor<1x1000xf32>) -> tensor<1x1000xf32>
   // CHECK: return %[[out]]
 
   return %result : tensor<1x1000xf32>
