@@ -34,17 +34,21 @@ void FromTensorOp::build(OpBuilder &builder, OperationState &result,
34
34
LogicalResult FromTensorOp::verify () {
35
35
ArrayRef<int64_t > tensorShape = getInput ().getType ().getShape ();
36
36
RingAttr ring = getOutput ().getType ().getRing ();
37
- unsigned polyDegree = ring.getPolynomialModulus ().getPolynomial ().getDegree ();
38
- bool compatible = tensorShape.size () == 1 && tensorShape[0 ] <= polyDegree;
39
- if (!compatible) {
40
- InFlightDiagnostic diag = emitOpError ()
41
- << " input type " << getInput ().getType ()
42
- << " does not match output type "
43
- << getOutput ().getType ();
44
- diag.attachNote () << " the input type must be a tensor of shape [d] where d "
45
- " is at most the degree of the polynomialModulus of "
46
- " the output type's ring attribute" ;
47
- return diag;
37
+ IntPolynomialAttr polyMod = ring.getPolynomialModulus ();
38
+ if (polyMod) {
39
+ unsigned polyDegree = polyMod.getPolynomial ().getDegree ();
40
+ bool compatible = tensorShape.size () == 1 && tensorShape[0 ] <= polyDegree;
41
+ if (!compatible) {
42
+ InFlightDiagnostic diag = emitOpError ()
43
+ << " input type " << getInput ().getType ()
44
+ << " does not match output type "
45
+ << getOutput ().getType ();
46
+ diag.attachNote ()
47
+ << " the input type must be a tensor of shape [d] where d "
48
+ " is at most the degree of the polynomialModulus of "
49
+ " the output type's ring attribute" ;
50
+ return diag;
51
+ }
48
52
}
49
53
50
54
unsigned inputBitWidth = getInput ().getType ().getElementTypeBitWidth ();
@@ -64,24 +68,27 @@ LogicalResult FromTensorOp::verify() {
64
68
65
69
LogicalResult ToTensorOp::verify () {
66
70
ArrayRef<int64_t > tensorShape = getOutput ().getType ().getShape ();
67
- unsigned polyDegree = getInput ()
68
- .getType ()
69
- .getRing ()
70
- .getPolynomialModulus ()
71
- .getPolynomial ()
72
- .getDegree ();
73
- bool compatible = tensorShape.size () == 1 && tensorShape[0 ] == polyDegree;
71
+ IntPolynomialAttr polyMod =
72
+ getInput ().getType ().getRing ().getPolynomialModulus ();
73
+ if (polyMod) {
74
+ unsigned polyDegree = polyMod.getPolynomial ().getDegree ();
75
+ bool compatible = tensorShape.size () == 1 && tensorShape[0 ] == polyDegree;
74
76
75
- if (compatible)
76
- return success ();
77
+ if (compatible)
78
+ return success ();
77
79
78
- InFlightDiagnostic diag =
79
- emitOpError () << " input type " << getInput ().getType ()
80
- << " does not match output type " << getOutput ().getType ();
81
- diag.attachNote () << " the output type must be a tensor of shape [d] where d "
82
- " is at most the degree of the polynomialModulus of "
83
- " the input type's ring attribute" ;
84
- return diag;
80
+ InFlightDiagnostic diag = emitOpError ()
81
+ << " input type " << getInput ().getType ()
82
+ << " does not match output type "
83
+ << getOutput ().getType ();
84
+ diag.attachNote ()
85
+ << " the output type must be a tensor of shape [d] where d "
86
+ " is at most the degree of the polynomialModulus of "
87
+ " the input type's ring attribute" ;
88
+ return diag;
89
+ }
90
+
91
+ return success ();
85
92
}
86
93
87
94
LogicalResult MulScalarOp::verify () {
0 commit comments