@@ -34,24 +34,25 @@ void FromTensorOp::build(OpBuilder &builder, OperationState &result,
34
34
LogicalResult FromTensorOp::verify () {
35
35
ArrayRef<int64_t > tensorShape = getInput ().getType ().getShape ();
36
36
RingAttr ring = getOutput ().getType ().getRing ();
37
- unsigned polyDegree = ring.getPolynomialModulus ().getPolynomial ().getDegree ();
38
- bool compatible = tensorShape.size () == 1 && tensorShape[0 ] <= polyDegree;
39
- if (!compatible) {
40
- InFlightDiagnostic diag = emitOpError ()
41
- << " input type " << getInput ().getType ()
42
- << " does not match output type "
43
- << getOutput ().getType ();
44
- diag.attachNote () << " the input type must be a tensor of shape [d] where d "
45
- " is at most the degree of the polynomialModulus of "
46
- " the output type's ring attribute" ;
47
- return diag;
37
+ IntPolynomialAttr polyMod = ring.getPolynomialModulus ();
38
+ if (polyMod) {
39
+ unsigned polyDegree = polyMod.getPolynomial ().getDegree ();
40
+ bool compatible = tensorShape.size () == 1 && tensorShape[0 ] <= polyDegree;
41
+ if (!compatible) {
42
+ InFlightDiagnostic diag = emitOpError ()
43
+ << " input type " << getInput ().getType ()
44
+ << " does not match output type "
45
+ << getOutput ().getType ();
46
+ diag.attachNote ()
47
+ << " the input type must be a tensor of shape [d] where d "
48
+ " is at most the degree of the polynomialModulus of "
49
+ " the output type's ring attribute" ;
50
+ return diag;
51
+ }
48
52
}
49
53
50
- APInt coefficientModulus = ring.getCoefficientModulus ().getValue ();
51
- unsigned cmodBitWidth = coefficientModulus.ceilLogBase2 ();
52
54
unsigned inputBitWidth = getInput ().getType ().getElementTypeBitWidth ();
53
-
54
- if (inputBitWidth > cmodBitWidth) {
55
+ if (inputBitWidth > ring.getCoefficientType ().getIntOrFloatBitWidth ()) {
55
56
InFlightDiagnostic diag = emitOpError ()
56
57
<< " input tensor element type "
57
58
<< getInput ().getType ().getElementType ()
@@ -67,24 +68,27 @@ LogicalResult FromTensorOp::verify() {
67
68
68
69
LogicalResult ToTensorOp::verify () {
69
70
ArrayRef<int64_t > tensorShape = getOutput ().getType ().getShape ();
70
- unsigned polyDegree = getInput ()
71
- .getType ()
72
- .getRing ()
73
- .getPolynomialModulus ()
74
- .getPolynomial ()
75
- .getDegree ();
76
- bool compatible = tensorShape.size () == 1 && tensorShape[0 ] == polyDegree;
71
+ IntPolynomialAttr polyMod =
72
+ getInput ().getType ().getRing ().getPolynomialModulus ();
73
+ if (polyMod) {
74
+ unsigned polyDegree = polyMod.getPolynomial ().getDegree ();
75
+ bool compatible = tensorShape.size () == 1 && tensorShape[0 ] == polyDegree;
77
76
78
- if (compatible)
79
- return success ();
77
+ if (compatible)
78
+ return success ();
79
+
80
+ InFlightDiagnostic diag = emitOpError ()
81
+ << " input type " << getInput ().getType ()
82
+ << " does not match output type "
83
+ << getOutput ().getType ();
84
+ diag.attachNote ()
85
+ << " the output type must be a tensor of shape [d] where d "
86
+ " is at most the degree of the polynomialModulus of "
87
+ " the input type's ring attribute" ;
88
+ return diag;
89
+ }
80
90
81
- InFlightDiagnostic diag =
82
- emitOpError () << " input type " << getInput ().getType ()
83
- << " does not match output type " << getOutput ().getType ();
84
- diag.attachNote () << " the output type must be a tensor of shape [d] where d "
85
- " is at most the degree of the polynomialModulus of "
86
- " the input type's ring attribute" ;
87
- return diag;
91
+ return success ();
88
92
}
89
93
90
94
LogicalResult MulScalarOp::verify () {
0 commit comments