Skip to content

Commit 692ae54

Browse files
authored
[mlir][polynomial] verify from_tensor coeff type (#93243)
Rebased over #93227 --------- Co-authored-by: Jeremy Kun <[email protected]>
1 parent 270d95b commit 692ae54

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,25 @@ void FromTensorOp::build(OpBuilder &builder, OperationState &result,
3434
LogicalResult FromTensorOp::verify() {
3535
ArrayRef<int64_t> tensorShape = getInput().getType().getShape();
3636
RingAttr ring = getOutput().getType().getRing();
37-
unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
38-
bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
39-
if (!compatible) {
40-
InFlightDiagnostic diag = emitOpError()
41-
<< "input type " << getInput().getType()
42-
<< " does not match output type "
43-
<< getOutput().getType();
44-
diag.attachNote() << "the input type must be a tensor of shape [d] where d "
45-
"is at most the degree of the polynomialModulus of "
46-
"the output type's ring attribute";
47-
return diag;
37+
IntPolynomialAttr polyMod = ring.getPolynomialModulus();
38+
if (polyMod) {
39+
unsigned polyDegree = polyMod.getPolynomial().getDegree();
40+
bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
41+
if (!compatible) {
42+
InFlightDiagnostic diag = emitOpError()
43+
<< "input type " << getInput().getType()
44+
<< " does not match output type "
45+
<< getOutput().getType();
46+
diag.attachNote()
47+
<< "the input type must be a tensor of shape [d] where d "
48+
"is at most the degree of the polynomialModulus of "
49+
"the output type's ring attribute";
50+
return diag;
51+
}
4852
}
4953

50-
APInt coefficientModulus = ring.getCoefficientModulus().getValue();
51-
unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
5254
unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
53-
54-
if (inputBitWidth > cmodBitWidth) {
55+
if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
5556
InFlightDiagnostic diag = emitOpError()
5657
<< "input tensor element type "
5758
<< getInput().getType().getElementType()
@@ -67,24 +68,27 @@ LogicalResult FromTensorOp::verify() {
6768

6869
LogicalResult ToTensorOp::verify() {
6970
ArrayRef<int64_t> tensorShape = getOutput().getType().getShape();
70-
unsigned polyDegree = getInput()
71-
.getType()
72-
.getRing()
73-
.getPolynomialModulus()
74-
.getPolynomial()
75-
.getDegree();
76-
bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
71+
IntPolynomialAttr polyMod =
72+
getInput().getType().getRing().getPolynomialModulus();
73+
if (polyMod) {
74+
unsigned polyDegree = polyMod.getPolynomial().getDegree();
75+
bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
7776

78-
if (compatible)
79-
return success();
77+
if (compatible)
78+
return success();
79+
80+
InFlightDiagnostic diag = emitOpError()
81+
<< "input type " << getInput().getType()
82+
<< " does not match output type "
83+
<< getOutput().getType();
84+
diag.attachNote()
85+
<< "the output type must be a tensor of shape [d] where d "
86+
"is at most the degree of the polynomialModulus of "
87+
"the input type's ring attribute";
88+
return diag;
89+
}
8090

81-
InFlightDiagnostic diag =
82-
emitOpError() << "input type " << getInput().getType()
83-
<< " does not match output type " << getOutput().getType();
84-
diag.attachNote() << "the output type must be a tensor of shape [d] where d "
85-
"is at most the degree of the polynomialModulus of "
86-
"the input type's ring attribute";
87-
return diag;
91+
return success();
8892
}
8993

9094
LogicalResult MulScalarOp::verify() {

mlir/test/Dialect/Polynomial/ops_errors.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: mlir-opt --split-input-file --verify-diagnostics %s
22

33
#my_poly = #polynomial.int_polynomial<1 + x**1024>
4-
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
4+
#ring = #polynomial.ring<coefficientType=i16>
55
!ty = !polynomial.polynomial<ring=#ring>
66

77
func.func @test_from_tensor_too_large_coeffs() {

0 commit comments

Comments
 (0)