
Mechanisms to register components of the code base that have debugging functionality #3549


Closed
wants to merge 1 commit into from

Conversation

narendasan
Collaborator

Description

This PR demonstrates how we can use a context manager to automatically patch the compiler to toggle debugging settings scattered throughout the codebase. It operates through three main contact points for developers: an optional kwarg, a settings struct, and a decorator that declares a component part of the debugging system.

  1. A component developer should add an optional kwarg to their entry point and decorate the entry point with the appropriate decorator:

# For functions
@fn_supports_debugger
def func(..., *, _debugger_settings: Optional[DebuggerConfig] = None):
    ...


# For classes
@cls_supports_debugger
class MyClass:
    def __init__(self, ..., *, _debugger_settings: Optional[DebuggerConfig] = None):
        ...

This config is then available to be stored as an attribute on the class or to be used within the function.

  2. Any specific settings should be added to DebuggerConfig. This becomes the user-facing API for the Debugger class. The developer can use these settings to manage debugging behavior.

  3. Users set the config values using a context manager:

with torch_tensorrt.dynamo._Debugger.Debugger(
    engine_builder_monitor=True, break_in_remove_assert_nodes=True
):
    optimized_model = torch_tensorrt.compile(
        model,
        # ir="torch_compile",
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        workspace_size=workspace_size,
        min_block_size=min_block_size,
    )

Here, engine_builder_monitor is a feature of TRTInterpreter, and break_in_remove_assert_nodes is a feature of the remove_assert_nodes pass.
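Settings like the two flags above are plain fields on a dataclass. A minimal self-contained sketch of what adding a setting looks like (the existing field names mirror the _DebuggerConfig.py diff in this PR; dump_layer_info is a hypothetical new setting added purely for illustration):

```python
import logging
from dataclasses import dataclass, field
from typing import List


@dataclass
class DebuggerConfig:
    # Fields mirroring the _DebuggerConfig.py diff in this PR
    log_level: int = logging.getLevelName("DEBUG")
    capture_fx_graph_before: List[str] = field(default_factory=list)
    capture_fx_graph_after: List[str] = field(default_factory=list)
    save_engine_profile: bool = False
    engine_builder_monitor: bool = True
    break_in_remove_assert_nodes: bool = True
    # Hypothetical new setting a component developer might register
    dump_layer_info: bool = False


# Debugger(**kwargs) forwards the user's kwargs straight into this dataclass,
# so every new field automatically becomes a user-facing knob
cfg = DebuggerConfig(dump_layer_info=True, engine_builder_monitor=False)
```

Because the Debugger constructor just does `DebuggerConfig(**kwargs)`, adding a field here is the only step needed to expose a new user-facing option.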

Both of these components are registered with the debugging system. When the context is entered, the default kwargs of registered functions are edited to contain the DebuggerConfig; registered classes are mocked so that the argument is filled in at construction. On exit, these modifications are cleaned up.
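The enter/exit mechanics described above can be sketched in isolation. This is a simplified stand-in for the real code in this PR: the registry, config, and component names (lowering_pass, Interpreter) are illustrative, but the patching technique — editing `__kwdefaults__` for functions and mock-patching `__init__` via `functools.partialmethod` for classes — matches the approach in _Debugger.py:

```python
import contextlib
import functools
from dataclasses import dataclass
from typing import Optional
from unittest import mock

# Simplified stand-ins for the registry in _supports_debugger.py
_DEBUG_ENABLED_FUNCS = []
_DEBUG_ENABLED_CLS = []


def fn_supports_debugger(func):
    _DEBUG_ENABLED_FUNCS.append(func)
    return func


def cls_supports_debugger(cls):
    _DEBUG_ENABLED_CLS.append(cls)
    return cls


@dataclass
class DebuggerConfig:
    engine_builder_monitor: bool = True


@fn_supports_debugger
def lowering_pass(graph, *, _debugger_settings: Optional[DebuggerConfig] = None):
    return _debugger_settings  # a real pass would branch on the settings


@cls_supports_debugger
class Interpreter:
    def __init__(self, *, _debugger_settings: Optional[DebuggerConfig] = None):
        self.settings = _debugger_settings


class Debugger:
    def __init__(self, **kwargs):
        self.cfg = DebuggerConfig(**kwargs)

    def __enter__(self):
        # Functions: overwrite the keyword-only default in place
        for f in _DEBUG_ENABLED_FUNCS:
            f.__kwdefaults__["_debugger_settings"] = self.cfg
        # Classes: patch __init__ so every construction receives the config
        self._stack = contextlib.ExitStack()
        for c in _DEBUG_ENABLED_CLS:
            self._stack.enter_context(
                mock.patch.object(
                    c,
                    "__init__",
                    functools.partialmethod(c.__init__, _debugger_settings=self.cfg),
                )
            )

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Restore defaults and unwind the class patches
        for f in _DEBUG_ENABLED_FUNCS:
            f.__kwdefaults__["_debugger_settings"] = None
        self._stack.close()


with Debugger(engine_builder_monitor=True):
    assert lowering_pass(None) is not None
    assert Interpreter().settings is not None
```

The key property is that components never import the Debugger; they only declare the optional kwarg, and the context manager reaches them through the registry.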

Note: This only addresses the stateless components of the system. The stateful parts (e.g. TorchTensorRTModule, MutableTorchTensorRTModule) still need a solution that integrates with the debugger.

In this PR (which needs to be cleaned up), the example debugging_example.py triggers both the TRTInterpreter setting that turns on the builder monitor and the lowering pass setting. Breakpoints are set to verify the behavior. Correct execution means that on the first compilation two breakpoints are hit, one in the lowering pass and one in the TRTBuilderMonitor; on the second compilation, no breakpoints are hit and the script runs to completion.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

…debugging features that will be active in a context manager (compile time only for now)
@narendasan narendasan requested review from peri044 and cehongwang June 3, 2025 04:51
@github-actions github-actions bot added component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: build system Issues re: Build system component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Jun 3, 2025

@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/debugging_example.py	2025-06-03 04:51:19.066284+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/debugging_example.py	2025-06-03 04:51:41.306273+00:00
@@ -27,14 +27,16 @@
# %%
# Compilation with `torch_tensorrt.compile`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Compile within the debugging context to control different aspects of the compilation pipeline
-with torch_tensorrt.dynamo._Debugger.Debugger(engine_builder_monitor=True, break_in_remove_assert_nodes=True):
+with torch_tensorrt.dynamo._Debugger.Debugger(
+    engine_builder_monitor=True, break_in_remove_assert_nodes=True
+):
    optimized_model = torch_tensorrt.compile(
        model,
-        #ir="torch_compile",
+        # ir="torch_compile",
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        workspace_size=workspace_size,
        min_block_size=min_block_size,
    )
@@ -50,11 +52,11 @@
print(new_outputs)


optimized_model = torch_tensorrt.compile(
    model,
-    #ir="torch_compile",
+    # ir="torch_compile",
    inputs=inputs,
    enabled_precisions=enabled_precisions,
    workspace_size=workspace_size,
    min_block_size=min_block_size,
)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_DebuggerConfig.py	2025-06-03 04:51:19.076284+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_DebuggerConfig.py	2025-06-03 04:51:41.757086+00:00
@@ -1,13 +1,14 @@
from typing import Any, List, Optional, Dict

import logging
from dataclasses import dataclass, field

+
@dataclass
class DebuggerConfig:
-    log_level: int = logging.getLevelName('DEBUG')
+    log_level: int = logging.getLevelName("DEBUG")
    capture_fx_graph_before: List[str] = field(default_factory=lambda: [])
    capture_fx_graph_after: List[str] = field(default_factory=lambda: [])
    save_engine_profile: bool = False
    engine_builder_monitor: bool = True
    break_in_remove_assert_nodes: bool = True
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_Debugger.py	2025-06-03 04:51:19.076284+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_Debugger.py	2025-06-03 04:51:41.803440+00:00
@@ -11,11 +11,14 @@
from logging.config import dictConfig

import torch
import torch_tensorrt
from torch_tensorrt.dynamo._DebuggerConfig import DebuggerConfig
-from torch_tensorrt.dynamo._supports_debugger import _DEBUG_ENABLED_CLS, _DEBUG_ENABLED_FUNCS
+from torch_tensorrt.dynamo._supports_debugger import (
+    _DEBUG_ENABLED_CLS,
+    _DEBUG_ENABLED_FUNCS,
+)
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo.lowering.passes.constant_folding import constant_fold
from torch_tensorrt.dynamo.lowering import (
    ATEN_POST_LOWERING_PASSES,
    ATEN_PRE_LOWERING_PASSES,
@@ -33,46 +36,52 @@
        self.cfg = DebuggerConfig(**kwargs)

    def __enter__(self) -> None:
        self.original_lvl = _LOGGER.getEffectiveLevel()
        self.rt_level = torch.ops.tensorrt.get_logging_level()
-        #dictConfig(self.get_config())
+        # dictConfig(self.get_config())

        self.old_pre_passes, self.old_post_passes = (
            ATEN_PRE_LOWERING_PASSES.passes,
            ATEN_POST_LOWERING_PASSES.passes,
        )
        pre_pass_names = [p.__name__ for p in self.old_pre_passes]
        post_pass_names = [p.__name__ for p in self.old_post_passes]
-        #path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization")
+        # path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization")

-        ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(self.cfg.capture_fx_graph_before)
-        ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(self.cfg.capture_fx_graph_before)
+        ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(
+            self.cfg.capture_fx_graph_before
+        )
+        ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(
+            self.cfg.capture_fx_graph_before
+        )

-        ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(self.cfg.capture_fx_graph_after)
-        ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(self.cfg.capture_fx_graph_after)
+        ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(
+            self.cfg.capture_fx_graph_after
+        )
+        ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(
+            self.cfg.capture_fx_graph_after
+        )

        self._context_stack = contextlib.ExitStack()

        for f in _DEBUG_ENABLED_FUNCS:
            f.__kwdefaults__["_debugger_settings"] = self.cfg

        [
            self._context_stack.enter_context(
                mock.patch.object(
                    c,
-                    '__init__',
-                    functools.partialmethod(
-                        c.__init__,
-                        _debugger_settings=self.cfg
-                    )
+                    "__init__",
+                    functools.partialmethod(c.__init__, _debugger_settings=self.cfg),
                )
-            ) for c in _DEBUG_ENABLED_CLS
+            )
+            for c in _DEBUG_ENABLED_CLS
        ]

    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
-        #dictConfig(self.get_default_config())
+        # dictConfig(self.get_default_config())
        torch.ops.tensorrt.set_logging_level(self.rt_level)

        ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
            self.old_pre_passes,
            self.old_post_passes,
@@ -80,12 +89,10 @@

        for f in _DEBUG_ENABLED_FUNCS:
            f.__kwdefaults__["_debugger_settings"] = None

        self._context_stack.close()
-
-

    # def get_config(self) -> dict[str, Any]:
    #     config = {
    #         "version": 1,
    #         "disable_existing_loggers": False,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_supports_debugger.py	2025-06-03 04:51:19.077284+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_supports_debugger.py	2025-06-03 04:51:41.917136+00:00
@@ -4,12 +4,14 @@


_DEBUG_ENABLED_FUNCS = []
_DEBUG_ENABLED_CLS = []

+
def fn_supports_debugger(func):
    _DEBUG_ENABLED_FUNCS.append(func)
    return func

-def cls_supports_debugger(cls:  Type[T]) ->  Type[T]:
+
+def cls_supports_debugger(cls: Type[T]) -> Type[T]:
    _DEBUG_ENABLED_CLS.append(cls)
    return cls
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py	2025-06-03 04:51:19.077284+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py	2025-06-03 04:51:42.103174+00:00
@@ -90,10 +90,11 @@
                clear_line()
                print()
            move_cursor_up(len(self._active_phases) + blank_lines)
            sys.stdout.flush()

+
# try:
#     from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn

#     class _RichMonitor(trt.IProgressMonitor):  # type: ignore
#         def __init__(self, engine_name: str = "") -> None:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2025-06-03 04:51:19.078285+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2025-06-03 04:51:42.829249+00:00
@@ -69,10 +69,11 @@
    input_names: Sequence[str]
    output_names: Sequence[str]
    weight_name_map: Optional[dict[Any, Any]]
    requires_output_allocator: bool

+
@cls_supports_debugger
class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
    def __init__(
        self,
        module: torch.fx.GraphModule,
@@ -207,11 +208,14 @@
        algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
        tactic_sources: Optional[int] = None,
    ) -> trt.IBuilderConfig:
        builder_config = self.builder.create_builder_config()

-        if self._debugger_settings is not None and self._debugger_settings.engine_builder_monitor:
+        if (
+            self._debugger_settings is not None
+            and self._debugger_settings.engine_builder_monitor
+        ):
            builder_config.progress_monitor = TRTBulderMonitor()

        if self.compilation_settings.workspace_size != 0:
            builder_config.set_memory_pool_limit(
                trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py	2025-06-03 04:51:19.081285+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py	2025-06-03 04:51:43.303996+00:00
@@ -9,16 +9,23 @@
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)

+
@fn_supports_debugger
def remove_assert_nodes(
-    gm: torch.fx.GraphModule, settings: CompilationSettings, *, _debugger_settings: Optional[DebuggerConfig]=None
+    gm: torch.fx.GraphModule,
+    settings: CompilationSettings,
+    *,
+    _debugger_settings: Optional[DebuggerConfig] = None,
) -> torch.fx.GraphModule:
    """Remove assert_scalar ops in the graph"""
-    if _debugger_settings is not None and _debugger_settings.break_in_remove_assert_nodes:
+    if (
+        _debugger_settings is not None
+        and _debugger_settings.break_in_remove_assert_nodes
+    ):
        breakpoint()
    count = 0
    for node in gm.graph.nodes:
        if (
            node.target == torch.ops.aten._assert_scalar.default

@narendasan narendasan closed this Jun 5, 2025
Labels
cla signed component: api [Python] Issues re: Python API component: build system Issues re: Build system component: conversion Issues re: Conversion stage component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes