Skip to content

Commit 605d2f9

Browse files
[mlir][IR] Experiment: Allow ptr as vector element type
1 parent 491d3df commit 605d2f9

25 files changed

+172
-91
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
1515
#define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
1616

17+
#include "mlir/IR/BuiltinTypes.h"
1718
#include "mlir/IR/Types.h"
1819
#include "mlir/Interfaces/DataLayoutInterfaces.h"
1920
#include "mlir/Interfaces/MemorySlotInterfaces.h"

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
1313
include "mlir/IR/AttrTypeBase.td"
14+
include "mlir/IR/BuiltinTypes.td"
1415
include "mlir/Interfaces/DataLayoutInterfaces.td"
1516
include "mlir/Interfaces/MemorySlotInterfaces.td"
1617

@@ -259,7 +260,8 @@ def LLVMStructType : LLVMType<"LLVMStruct", "struct", [
259260
def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
260261
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
261262
"getIndexBitwidth", "areCompatible", "verifyEntries",
262-
"getPreferredAlignment"]>]> {
263+
"getPreferredAlignment"]>,
264+
PointerLike]> {
263265
let summary = "LLVM pointer type";
264266
let description = [{
265267
The `!llvm.ptr` type is an LLVM pointer type. This type typically represents

mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
include "mlir/Interfaces/DataLayoutInterfaces.td"
1313
include "mlir/IR/AttrTypeBase.td"
1414
include "mlir/IR/BuiltinTypeInterfaces.td"
15+
include "mlir/IR/BuiltinTypes.td"
1516
include "mlir/IR/OpBase.td"
1617

1718
//===----------------------------------------------------------------------===//
@@ -39,7 +40,8 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
3940
MemRefElementTypeInterface,
4041
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
4142
"areCompatible", "getIndexBitwidth", "verifyEntries",
42-
"getPreferredAlignment"]>
43+
"getPreferredAlignment"]>,
44+
PointerLike
4345
]> {
4446
let summary = "pointer type";
4547
let description = [{

mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_PTR_IR_PTRTYPES_H
1515

1616
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
17+
#include "mlir/IR/BuiltinTypes.h"
1718
#include "mlir/IR/Types.h"
1819
#include "mlir/Interfaces/DataLayoutInterfaces.h"
1920

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ template <typename ConcreteType>
4343
class ValueSemantics
4444
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
4545

46+
/// Type trait indicating that the type is a pointer-like type.
47+
template <typename ConcreteType>
48+
class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {};
49+
4650
//===----------------------------------------------------------------------===//
4751
// TensorType
4852
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
4040
let cppNamespace = "::mlir";
4141
}
4242

43+
/// Type trait indicating that the type is a pointer-like type.
44+
def PointerLike : NativeTypeTrait<"PointerLike"> {
45+
let cppNamespace = "::mlir";
46+
}
47+
4348
//===----------------------------------------------------------------------===//
4449
// ComplexType
4550
//===----------------------------------------------------------------------===//
@@ -1249,7 +1254,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
12491254
// VectorType
12501255
//===----------------------------------------------------------------------===//
12511256

1252-
def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
1257+
def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat, AnyPointerLike]> {
12531258
let cppFunctionName = "isValidVectorTypeElementType";
12541259
}
12551260

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
301301
"::mlir::IndexType">,
302302
BuildableType<"$_builder.getIndexType()">;
303303

304+
def AnyPointerLike : Type<CPred<"$_self.hasTrait<::mlir::PointerLike>()">, "pointer-like", "::mlir::Type">;
305+
304306
// Any signless integer type or index type.
305307
def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
306308
"signless integer or index">;

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,13 @@ static bool isSupportedTypeForConversion(Type type) {
140140
if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
141141
return false;
142142

143-
// Scalable types are not supported.
144-
if (auto vectorType = dyn_cast<VectorType>(type))
143+
if (auto vectorType = dyn_cast<VectorType>(type)) {
144+
// Vectors of pointers cannot be casted.
145+
if (isa<LLVM::LLVMPointerType>(vectorType.getElementType()))
146+
return false;
147+
// Scalable types are not supported.
145148
return !vectorType.isScalable();
149+
}
146150
return true;
147151
}
148152

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
690690
}
691691

692692
bool LLVMFixedVectorType::isValidElementType(Type type) {
693-
return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type);
693+
return llvm::isa<LLVMPPCFP128Type>(type);
694694
}
695695

696696
LogicalResult
@@ -890,7 +890,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
890890
if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
891891
return intType.isSignless();
892892
return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
893-
Float80Type, Float128Type>(elementType);
893+
Float80Type, Float128Type, LLVMPointerType>(elementType);
894894
}
895895
return false;
896896
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 73 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2002,8 +2002,8 @@ func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1
20022002
}
20032003

20042004
// CHECK-LABEL: func @gather
2005-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
2006-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2005+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
2006+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
20072007
// CHECK: return %[[G]] : vector<3xf32>
20082008

20092009
// -----
@@ -2015,8 +2015,8 @@ func.func @gather_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
20152015
}
20162016

20172017
// CHECK-LABEL: func @gather_scalable
2018-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
2019-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2018+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
2019+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
20202020
// CHECK: return %[[G]] : vector<[3]xf32>
20212021

20222022
// -----
@@ -2028,8 +2028,8 @@ func.func @gather_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %
20282028
}
20292029

20302030
// CHECK-LABEL: func @gather_global_memory
2031-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
2032-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2031+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> vector<3x!llvm.ptr<1>>, f32
2032+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
20332033
// CHECK: return %[[G]] : vector<3xf32>
20342034

20352035
// -----
@@ -2041,8 +2041,8 @@ func.func @gather_global_memory_scalable(%arg0: memref<?xf32, 1>, %arg1: vector<
20412041
}
20422042

20432043
// CHECK-LABEL: func @gather_global_memory_scalable
2044-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr<1>>, f32
2045-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2044+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> vector<[3]x!llvm.ptr<1>>, f32
2045+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
20462046
// CHECK: return %[[G]] : vector<[3]xf32>
20472047

20482048
// -----
@@ -2055,8 +2055,8 @@ func.func @gather_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: v
20552055
}
20562056

20572057
// CHECK-LABEL: func @gather_index
2058-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
2059-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
2058+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
2059+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
20602060
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<3xi64> to vector<3xindex>
20612061

20622062
// -----
@@ -2068,12 +2068,57 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
20682068
}
20692069

20702070
// CHECK-LABEL: func @gather_index_scalable
2071-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
2072-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
2071+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> vector<[3]x!llvm.ptr>, i64
2072+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
20732073
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<[3]xi64> to vector<[3]xindex>
20742074

20752075
// -----
20762076

2077+
func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
2078+
%0 = arith.constant 0: index
2079+
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
2080+
return %1 : vector<2x3xf32>
2081+
}
2082+
2083+
// CHECK-LABEL: func @gather_2d_from_1d
2084+
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
2085+
// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
2086+
// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
2087+
// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
2088+
// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
2089+
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2090+
// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
2091+
// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
2092+
// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
2093+
// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
2094+
// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
2095+
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2096+
// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
2097+
2098+
// -----
2099+
2100+
func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
2101+
%0 = arith.constant 0: index
2102+
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
2103+
return %1 : vector<2x[3]xf32>
2104+
}
2105+
2106+
// CHECK-LABEL: func @gather_2d_from_1d_scalable
2107+
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
2108+
// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
2109+
// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
2110+
// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
2111+
// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
2112+
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2113+
// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
2114+
// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
2115+
// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
2116+
// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
2117+
// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
2118+
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2119+
// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
2120+
2121+
// -----
20772122

20782123
func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
20792124
%0 = arith.constant 3 : index
@@ -2083,8 +2128,8 @@ func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2
20832128

20842129
// CHECK-LABEL: func @gather_1d_from_2d
20852130
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
2086-
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
2087-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
2131+
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> vector<4x!llvm.ptr>, f32
2132+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<4x!llvm.ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
20882133
// CHECK: return %[[G]] : vector<4xf32>
20892134

20902135
// -----
@@ -2097,8 +2142,8 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
20972142

20982143
// CHECK-LABEL: func @gather_1d_from_2d_scalable
20992144
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
2100-
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
2101-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 4 x ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
2145+
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
2146+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[4]x!llvm.ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
21022147
// CHECK: return %[[G]] : vector<[4]xf32>
21032148

21042149
// -----
@@ -2114,8 +2159,8 @@ func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi
21142159
}
21152160

21162161
// CHECK-LABEL: func @scatter
2117-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
2118-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
2162+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
2163+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
21192164

21202165
// -----
21212166

@@ -2126,8 +2171,8 @@ func.func @scatter_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
21262171
}
21272172

21282173
// CHECK-LABEL: func @scatter_scalable
2129-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
2130-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
2174+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
2175+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
21312176

21322177
// -----
21332178

@@ -2138,8 +2183,8 @@ func.func @scatter_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2:
21382183
}
21392184

21402185
// CHECK-LABEL: func @scatter_index
2141-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
2142-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into !llvm.vec<3 x ptr>
2186+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
2187+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into vector<3x!llvm.ptr>
21432188

21442189
// -----
21452190

@@ -2150,8 +2195,8 @@ func.func @scatter_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xinde
21502195
}
21512196

21522197
// CHECK-LABEL: func @scatter_index_scalable
2153-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
2154-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
2198+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> vector<[3]x!llvm.ptr>, i64
2199+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
21552200

21562201
// -----
21572202

@@ -2163,8 +2208,8 @@ func.func @scatter_1d_into_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg
21632208

21642209
// CHECK-LABEL: func @scatter_1d_into_2d
21652210
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
2166-
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
2167-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into !llvm.vec<4 x ptr>
2211+
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> vector<4x!llvm.ptr>, f32
2212+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into vector<4x!llvm.ptr>
21682213

21692214
// -----
21702215

@@ -2176,8 +2221,8 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]
21762221

21772222
// CHECK-LABEL: func @scatter_1d_into_2d_scalable
21782223
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
2179-
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
2180-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.vec<? x 4 x ptr>
2224+
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
2225+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>
21812226

21822227
// -----
21832228

0 commit comments

Comments
 (0)