
🐛 [Bug] torch.Tensor.std doesn't support multiple dimensions #1182

Closed
@feldman-cortica

Description

Bug Description

torch.std fails to compile with Torch-TensorRT when dim is a tuple of several dimensions (here dim=(2,3)), while torch.mean with the same dims converts fine.
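For reference, PyTorch's default std is the Bessel-corrected (unbiased) estimate. On an input of shape (N, C, H, W), `X.std(dim=(2,3), keepdim=True)` reduces over all H·W spatial elements of each (n, c) slice and returns shape (N, C, 1, 1):

$$
\mu_{n,c} = \frac{1}{HW}\sum_{h=1}^{H}\sum_{w=1}^{W} X_{n,c,h,w}, \qquad
\sigma_{n,c} = \sqrt{\frac{1}{HW-1}\sum_{h=1}^{H}\sum_{w=1}^{W}\left(X_{n,c,h,w}-\mu_{n,c}\right)^{2}}
$$

For the 1×1×512×512 input in the reproducer below, HW = 262144.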

To Reproduce

convert.py

import torch_tensorrt
import torch
import torch.nn as nn

class F1(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        y = X.mean(dim=(2,3), keepdim=True)
        return X - y

class F2(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        y = X.std(dim=(2,3), keepdim=True)
        return X - y

f1 = F1()
f2 = F2()

q = torch.rand(1,1,512,512, device='cuda:0')
f1_jit = torch.jit.trace(f1, q)
f2_jit = torch.jit.trace(f2, q)

f1_trt = torch_tensorrt.compile(f1_jit,
    inputs = [
        torch_tensorrt.Input( 
            shape=[1, 1, 512, 512],
            dtype=torch.float) 
    ],
    enabled_precisions = {torch.float})

f2_trt = torch_tensorrt.compile(f2_jit,
    inputs = [
        torch_tensorrt.Input( 
            shape=[1, 1, 512, 512],
            dtype=torch.float) 
    ],
    enabled_precisions = {torch.float})

Output:

root@3f9bb1ccc8aa:/app# python convert.py 
WARNING: [Torch-TensorRT] - Mean converter disregards dtype
WARNING: [Torch-TensorRT] - Mean converter disregards dtype
WARNING: [Torch-TensorRT] - Mean converter disregards dtype
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
Traceback (most recent call last):
  File "convert2.py", line 36, in <module>
    f2_trt = torch_tensorrt.compile(f2_jit,
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 115, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 113, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at ./core/conversion/var/Var_inl.h:38] Expected ivalue->isInt() to be true but got false
Requested unwrapping of arg IValue assuming it was l however type is (int, int)
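The message suggests the std converter unwraps the dim argument as a single integer ("l" appears to be the C++ type name for long, i.e. int64 on Linux) and fails when it receives the tuple (2, 3). One possible workaround, untested with Torch-TensorRT and assuming the converter does handle a single integer dim and that flatten/unsqueeze convert, is to merge H and W into one axis before reducing. `F2Flattened` is a hypothetical name for this variant of F2:

```python
import torch
import torch.nn as nn


class F2Flattened(nn.Module):
    """Hypothetical variant of F2 that reduces over one merged axis instead of dims (2, 3)."""

    def forward(self, X):
        # (N, C, H, W) -> (N, C, H*W): the two reduced axes become a single axis.
        flat = X.flatten(start_dim=2)
        # Single-int dim; keepdim gives (N, C, 1), unsqueeze restores (N, C, 1, 1).
        y = flat.std(dim=2, keepdim=True).unsqueeze(-1)
        return X - y


if __name__ == "__main__":
    # CPU-only check that the rewrite matches the original multi-dim std.
    X = torch.rand(1, 1, 512, 512)
    expected = X - X.std(dim=(2, 3), keepdim=True)
    print(torch.allclose(F2Flattened()(X), expected, atol=1e-5))
```

Reducing over the flattened axis covers exactly the same H·W elements, so in eager PyTorch the result should match the original F2 up to floating-point rounding.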

Expected behavior

torch.std should accept a tuple of dimensions for dim, just as torch.mean does, and compile without error.
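Since the mean converter already accepts dim=(2, 3), another possible workaround is to build the standard deviation from means only, using σ² = N/(N−1) · mean((X − μ)²) with N = H·W reduced elements. This is a sketch assuming a fixed input shape: the hypothetical `num_reduced` argument is passed in at construction so the traced graph needs no shape ops, and the module has not been tested against Torch-TensorRT.

```python
import torch
import torch.nn as nn


class F2MeanBased(nn.Module):
    """Hypothetical variant of F2 that computes std(dim=(2, 3)) from means only."""

    def __init__(self, num_reduced):
        super().__init__()
        # N = H * W, fixed at construction; traced as a plain constant.
        self.correction = num_reduced / (num_reduced - 1)

    def forward(self, X):
        mu = X.mean(dim=(2, 3), keepdim=True)
        d = X - mu
        # Biased variance, rescaled by N / (N - 1) to match torch.std's default unbiased estimate.
        var = (d * d).mean(dim=(2, 3), keepdim=True) * self.correction
        return X - torch.sqrt(var)


# Usage mirroring the reproducer (CPU here; the original traces on cuda:0).
f2_mean = F2MeanBased(num_reduced=512 * 512)
q = torch.rand(1, 1, 512, 512)
f2_mean_jit = torch.jit.trace(f2_mean, q)
```

The traced graph contains only mean, sub, mul and sqrt, so no aten::std node with a tuple dim reaches the converter. The trade-off is that the module is tied to one spatial size.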

Environment

Docker image nvcr.io/nvidia/pytorch:22.06-py3 (library versions as listed in that container's release notes)

Metadata

Labels

bug (Something isn't working)
component: converters (Issues re: Specific op converters)
feature request (New feature or request)
