
adding rotary embedding example, with graph rewrite for complex subgraph [WIP] #3570

Draft · wants to merge 1 commit into base: main
Conversation

apbose (Collaborator) commented Jun 13, 2025

This PR:

  1. Adds an example for parallel rotary embedding.
  2. Adds logic for complex-graph detection.
  3. Adds a pass for complex-graph rewrite in aten_lowering_pass.

Note: this PR currently targets the single-GPU case, where there is no DTensor in the inputs of the torch module; ideally it should not require runtime changes. The rewrite should avoid the graph breaks caused by view_as_complex and view_as_real nodes.
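For intuition, the rotation that rotary embedding applies to each (real, imag) pair is multiplication by e^{iθ}. A minimal pure-Python sketch of that arithmetic (a hypothetical helper using plain tuples instead of torch tensors, not code from this PR):

```python
import math

def apply_rotary(x_pairs, theta):
    # Rotate each (re, im) pair by angle theta. This is equivalent to
    # multiplying the complex number re + i*im by e^{i*theta}, which is
    # the operation view_as_complex-based rotary implementations perform.
    c, s = math.cos(theta), math.sin(theta)
    return [(re * c - im * s, re * s + im * c) for re, im in x_pairs]
```

Rotating (1, 0) by π/2 yields (0, 1), matching multiplication of 1 + 0i by i.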

@apbose apbose marked this pull request as draft June 13, 2025 00:06
@github-actions github-actions bot added component: lowering Issues re: The lowering / preprocessing passes component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Jun 13, 2025
@github-actions github-actions bot requested a review from peri044 June 13, 2025 00:06

@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py	2025-06-13 00:06:13.440339+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py	2025-06-13 00:06:42.189786+00:00
@@ -98,16 +98,17 @@
class _TorchTensorRTConstantFolder(ConstantFolder):  # type: ignore[misc]
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    def is_impure(self, node: torch.fx.node.Node) -> bool:
-        # Set of known quantization ops to be excluded from constant folding. 
+        # Set of known quantization ops to be excluded from constant folding.
        # Currently, we exclude all quantization ops coming from modelopt library.
        quantization_ops = {}
        try:
-            # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered 
+            # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
            import modelopt.torch.quantization as mtq
+
            assert torch.ops.tensorrt.quantize_op.default
            quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
        except Exception as e:
            pass
        if quantization_ops and node.target in quantization_ops:

Collaborator


add tests for this pass

Comment on lines +247 to +255
"""Constructs the original and replacement functions for complex multiplication.

The original functions correspond to native complex multiplication
via torch.mul or operator.mul on complex tensors.

The replacement function assumes x and y are real tensors whose last
dimension has size 2 (real and imaginary parts), and performs
complex multiplication manually, returning a tensor of the same shape.
"""
Collaborator


If you know the limitations of this pass, please document them for future reference.

class ComplexOpDetector:
    def __init__(self, logger):
        self.logger = logger
        pass
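A toy illustration of what such a detector might scan for: the view_as_complex and view_as_real nodes that bracket a complex subgraph. The names and node structure below are illustrative assumptions, not the PR's actual FX-based implementation:

```python
class Node:
    # Minimal stand-in for a torch.fx node: just a name and a target.
    def __init__(self, name, target):
        self.name, self.target = name, target

def find_complex_boundaries(nodes):
    # A complex subgraph begins at view_as_complex nodes (real -> complex)
    # and ends at view_as_real nodes (complex -> real); everything between
    # them operates on complex tensors and is a rewrite candidate.
    entries = [n.name for n in nodes if n.target == "view_as_complex"]
    exits = [n.name for n in nodes if n.target == "view_as_real"]
    return entries, exits
```

On a toy graph [view_as_complex → mul → view_as_real], this returns the entry and exit node names bracketing the complex region.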
Collaborator


is this needed ?
