-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[polynomial] Move primitive root attribute to ntt/intt ops. #93227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -108,14 +108,15 @@ LogicalResult MulScalarOp::verify() { | |
} | ||
|
||
/// Test if a value is a primitive nth root of unity modulo cmod. | ||
bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n, | ||
bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n, | ||
const APInt &cmod) { | ||
// Root bitwidth may be 1 less then cmod. | ||
APInt r = APInt(root).zext(cmod.getBitWidth()); | ||
assert(r.ule(cmod) && "root must be less than cmod"); | ||
unsigned upperBound = n.getZExtValue(); | ||
|
||
APInt a = r; | ||
for (size_t k = 1; k < n; k++) { | ||
for (size_t k = 1; k < upperBound; k++) { | ||
if (a.isOne()) | ||
return false; | ||
a = (a * r).urem(cmod); | ||
|
@@ -126,7 +127,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n, | |
/// Verify that the types involved in an NTT or INTT operation are | ||
/// compatible. | ||
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, | ||
RankedTensorType tensorType) { | ||
RankedTensorType tensorType, | ||
std::optional<PrimitiveRootAttr> root) { | ||
Attribute encoding = tensorType.getEncoding(); | ||
if (!encoding) { | ||
return op->emitOpError() | ||
|
@@ -157,33 +159,30 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, | |
return diag; | ||
} | ||
|
||
if (!ring.getPrimitiveRoot()) { | ||
return op->emitOpError() | ||
<< "ring type " << ring << " does not provide a primitive root " | ||
<< "of unity, which is required to express an NTT"; | ||
} | ||
|
||
if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree, | ||
ring.getCoefficientModulus().getValue())) { | ||
return op->emitOpError() | ||
<< "ring type " << ring << " has a primitiveRoot attribute '" | ||
<< ring.getPrimitiveRoot() | ||
<< "' that is not a primitive root of the coefficient ring"; | ||
if (root.has_value()) { | ||
APInt rootValue = root.value().getValue().getValue(); | ||
APInt rootDegree = root.value().getDegree().getValue(); | ||
APInt cmod = ring.getCoefficientModulus().getValue(); | ||
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe it's questionable convention in the case of nested conditionals? up to you. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is indeed questionable, I usually lean towards having braces for multi-line bodies even if there is a single statement. |
||
return op->emitOpError() | ||
<< "provided root " << rootValue.getZExtValue() | ||
<< " is not a primitive root " | ||
<< "of unity mod " << cmod.getZExtValue() | ||
<< ", with the specified degree " << rootDegree.getZExtValue(); | ||
} | ||
} | ||
|
||
return success(); | ||
} | ||
|
||
LogicalResult NTTOp::verify() { | ||
auto ring = getInput().getType().getRing(); | ||
auto tensorType = getOutput().getType(); | ||
return verifyNTTOp(this->getOperation(), ring, tensorType); | ||
return verifyNTTOp(this->getOperation(), getInput().getType().getRing(), | ||
getOutput().getType(), getRoot()); | ||
} | ||
|
||
LogicalResult INTTOp::verify() { | ||
auto tensorType = getInput().getType(); | ||
auto ring = getOutput().getType().getRing(); | ||
return verifyNTTOp(this->getOperation(), ring, tensorType); | ||
return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(), | ||
getInput().getType(), getRoot()); | ||
} | ||
|
||
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if mlir docs can render latex
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like yes:
