adding rotary embedding example, with graph rewrite for complex subgraph [WIP] #3570
base: main
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py 2025-06-13 00:06:13.440339+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py 2025-06-13 00:06:42.189786+00:00
@@ -98,16 +98,17 @@
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def is_impure(self, node: torch.fx.node.Node) -> bool:
- # Set of known quantization ops to be excluded from constant folding.
+ # Set of known quantization ops to be excluded from constant folding.
# Currently, we exclude all quantization ops coming from modelopt library.
quantization_ops = {}
try:
- # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
+ # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
import modelopt.torch.quantization as mtq
+
assert torch.ops.tensorrt.quantize_op.default
quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
except Exception as e:
pass
if quantization_ops and node.target in quantization_ops:
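One subtlety worth flagging in the diff above (a reviewer-style observation, not part of the lint output): `quantization_ops = {}` creates a dict, not a set, so the later `.add(...)` raises `AttributeError` inside the `try` block and is silently swallowed by the bare `except`, leaving the op set empty. A minimal self-contained demonstration of the failure mode and the intended fix (the op name string here is a stand-in for the real `torch.ops` handle):

```python
# Buggy version: {} builds an empty dict, which has no .add method.
quantization_ops = {}
try:
    quantization_ops.add("tensorrt.quantize_op.default")  # AttributeError
except Exception:
    pass  # error swallowed, op never recorded
assert quantization_ops == {}  # still empty: the exclusion list is a no-op

# Intended version: use a set, so .add works and the op is registered.
quantization_ops = set()
try:
    quantization_ops.add("tensorrt.quantize_op.default")
except Exception:
    pass
assert "tensorrt.quantize_op.default" in quantization_ops
```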
add tests for this pass
"""Constructs the original and replacement functions for complex multiplication. | ||
|
||
The original functions correspond to native complex multiplication | ||
via torch.mul or operator.mul on complex tensors. | ||
|
||
The replacement function assumes x and y are real tensors with the last | ||
dimension size 2 representing real and imaginary parts, and performs | ||
complex multiplication manually returning the same shaped tensor. | ||
""" |
If you know the limitations of this pass, please document them for future reference.
class ComplexOpDetector:
    def __init__(self, logger):
        self.logger = logger
        pass
Is this needed?
This PR:
Please note that this PR currently handles only the single-GPU case, where there is no DTensor among the inputs of the torch module. Ideally this should not require runtime changes.
This should avoid the graph breaks caused by view_as_complex and view_as_real nodes.
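For context on why this matters for the rotary-embedding example: rotary embeddings apply a per-position rotation, which is naturally written as multiplication by e^{i*theta} on complex tensors (hence the view_as_complex/view_as_real nodes that trigger graph breaks). A minimal plain-Python sketch of the same rotation expressed directly on (real, imag) pairs, as the rewritten graph does (function name hypothetical):

```python
import math

def apply_rotary(pair, theta):
    # Rotary embedding rotation: multiply (r + i*j) by e^{j*theta},
    # i.e. complex multiplication by (cos(theta), sin(theta)),
    # written on real (real, imag) pairs so no complex dtype is needed.
    xr, xi = pair
    c, s = math.cos(theta), math.sin(theta)
    return (xr * c - xi * s, xr * s + xi * c)
```

Rotating (1, 0) by pi/2 gives (0, 1), matching i * 1 = i in complex form.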