Skip to content

Support pt1.8 by removing OpMathType dependencies #212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def dcnv3_core_pytorch(
reshape(N_*group, group_channels, H_in, W_in)
# N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\
flatten(0, 1)
flatten(0, 1).to(input_.dtype)
# N_*group, group_channels, H_out*W_out, P_
sampling_input_ = F.grid_sample(
input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
// AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "ms_deform_attn_forward_cuda", ([&] {
dcnv3_im2col_cuda(
dcnv3_im2col_cuda<opmath_t>(
at::cuda::getCurrentCUDAStream(),
input.data<scalar_t>() + n * im2col_step_ * per_input_size,
offset.data<scalar_t>() +
Expand Down Expand Up @@ -124,9 +124,6 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
channels, group * group_channels);

auto dtype = input.dtype();
if (dtype == at::kHalf) {
dtype = at::kFloat;
}

auto grad_input = at::zeros_like(input, dtype);
auto grad_offset = at::zeros_like(offset, dtype);
Expand All @@ -146,7 +143,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
// AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "ms_deform_attn_backward_cuda", ([&] {
dcnv3_col2im_cuda(
dcnv3_col2im_cuda<opmath_t>(
at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
input.data<scalar_t>() + n * im2col_step_ * per_input_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <cstring>

#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>

Expand All @@ -27,7 +26,7 @@ inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}

#define opmath_t at::opmath_type<scalar_t>
#define opmath_t scalar_t

template <typename scalar_t>
__device__ opmath_t dcnv3_im2col_bilinear(const scalar_t *&bottom_data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def dcnv3_core_pytorch(
reshape(N_*group, group_channels, H_in, W_in)
# N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\
flatten(0, 1)
flatten(0, 1).to(input_.dtype)
# N_*group, group_channels, H_out*W_out, P_
sampling_input_ = F.grid_sample(
input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
// AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "ms_deform_attn_forward_cuda", ([&] {
dcnv3_im2col_cuda(
dcnv3_im2col_cuda<opmath_t>(
at::cuda::getCurrentCUDAStream(),
input.data<scalar_t>() + n * im2col_step_ * per_input_size,
offset.data<scalar_t>() +
Expand Down Expand Up @@ -124,9 +124,6 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
channels, group * group_channels);

auto dtype = input.dtype();
if (dtype == at::kHalf) {
dtype = at::kFloat;
}

auto grad_input = at::zeros_like(input, dtype);
auto grad_offset = at::zeros_like(offset, dtype);
Expand All @@ -146,7 +143,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
// AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "ms_deform_attn_backward_cuda", ([&] {
dcnv3_col2im_cuda(
dcnv3_col2im_cuda<opmath_t>(
at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
input.data<scalar_t>() + n * im2col_step_ * per_input_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <cstring>

#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>

Expand All @@ -27,7 +26,7 @@ inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}

#define opmath_t at::opmath_type<scalar_t>
#define opmath_t scalar_t

template <typename scalar_t>
__device__ opmath_t dcnv3_im2col_bilinear(const scalar_t *&bottom_data,
Expand Down
2 changes: 1 addition & 1 deletion classification/ops_dcnv3/functions/dcnv3_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def dcnv3_core_pytorch(
reshape(N_*group, group_channels, H_in, W_in)
# N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\
flatten(0, 1)
flatten(0, 1).to(input_.dtype)
# N_*group, group_channels, H_out*W_out, P_
sampling_input_ = F.grid_sample(
input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)
Expand Down
7 changes: 2 additions & 5 deletions classification/ops_dcnv3/src/cuda/dcnv3_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
// AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "ms_deform_attn_forward_cuda", ([&] {
dcnv3_im2col_cuda(
dcnv3_im2col_cuda<opmath_t>(
at::cuda::getCurrentCUDAStream(),
input.data<scalar_t>() + n * im2col_step_ * per_input_size,
offset.data<scalar_t>() +
Expand Down Expand Up @@ -124,9 +124,6 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
channels, group * group_channels);

auto dtype = input.dtype();
if (dtype == at::kHalf) {
dtype = at::kFloat;
}

auto grad_input = at::zeros_like(input, dtype);
auto grad_offset = at::zeros_like(offset, dtype);
Expand All @@ -146,7 +143,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
// AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "ms_deform_attn_backward_cuda", ([&] {
dcnv3_col2im_cuda(
dcnv3_col2im_cuda<opmath_t>(
at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
input.data<scalar_t>() + n * im2col_step_ * per_input_size,
Expand Down
3 changes: 1 addition & 2 deletions classification/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <cstring>

#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>

Expand All @@ -27,7 +26,7 @@ inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}

#define opmath_t at::opmath_type<scalar_t>
#define opmath_t scalar_t

template <typename scalar_t>
__device__ opmath_t dcnv3_im2col_bilinear(const scalar_t *&bottom_data,
Expand Down
95 changes: 95 additions & 0 deletions classification/ops_dcnv3/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,36 @@ def check_forward_equal_with_pytorch_double():
print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_half():
input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask /= mask.sum(-1, keepdim=True)
mask = mask.reshape(N, H_out, W_out, M*P)

output_pytorch = dcnv3_core_pytorch(
input.half(),
offset.half(),
mask.half(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center).detach().cpu()

im2col_step = 2
output_cuda = DCNv3Function.apply(
input.half(),
offset.half(),
mask.half(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center).detach().cpu()

fwdok = torch.allclose(output_cuda, output_pytorch)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() /
output_pytorch.abs()).max()
print('>>> forward half')
print(f'* {fwdok} check_forward_equal_with_pytorch_half: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
Expand Down Expand Up @@ -154,6 +184,68 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o
f'* {bwdok} mask_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_half(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
# H_in, W_in = 4, 4
N = 2
M = 2
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

D = channels
input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask0 /= mask0.sum(-1, keepdim=True)
mask0 = mask0.reshape(N, H_out, W_out, M*P)
input0.requires_grad = grad_input
offset0.requires_grad = grad_offset
mask0.requires_grad = grad_mask

output_pytorch = dcnv3_core_pytorch(
input0.half(),
offset0.half(),
mask0.half(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center)
output_pytorch.sum().backward()

input1 = input0.detach()
offset1 = offset0.detach()
mask1 = mask0.detach()
input1.requires_grad = grad_input
offset1.requires_grad = grad_offset
mask1.requires_grad = grad_mask

im2col_step = 2
output_cuda = DCNv3Function.apply(
input1.half(),
offset1.half(),
mask1.half(),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center)
output_cuda.sum().backward()

print(f'>>> backward half: channels {D}')
bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
max_abs_err = (input0.grad - input1.grad).abs().max()
max_rel_err = ((input0.grad - input1.grad).abs() /
input0.grad.abs()).max()
print(
f'* {bwdok} input_grad check_backward_equal_with_pytorch_half: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
max_abs_err = (offset0.grad - offset1.grad).abs().max()
max_rel_err = ((offset0.grad - offset1.grad).abs() /
offset0.grad.abs()).max()
print(
f'* {bwdok} offset_grad check_backward_equal_with_pytorch_half: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
max_abs_err = (mask0.grad - mask1.grad).abs().max()
max_rel_err = ((mask0.grad - mask1.grad).abs() /
mask0.grad.abs()).max()
print(
f'* {bwdok} mask_grad check_backward_equal_with_pytorch_half: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
# H_in, W_in = 4, 4
N = 2
Expand Down Expand Up @@ -254,9 +346,12 @@ def check_time_cost(im2col_step=128):

if __name__ == '__main__':
check_forward_equal_with_pytorch_double()
check_forward_equal_with_pytorch_half()
check_forward_equal_with_pytorch_float()
for channels in [1, 16, 30, 32, 64, 71, 1025]:
check_backward_equal_with_pytorch_double(channels, True, True, True)
for channels in [1, 16, 30, 32, 64, 71, 1025]:
check_backward_equal_with_pytorch_half(channels, True, True, True)
for channels in [1, 16, 30, 32, 64, 71, 1025]:
check_backward_equal_with_pytorch_float(channels, True, True, True)
for i in range(3):
Expand Down
2 changes: 1 addition & 1 deletion detection/ops_dcnv3/functions/dcnv3_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def dcnv3_core_pytorch(
reshape(N_*group, group_channels, H_in, W_in)
# N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\
flatten(0, 1)
flatten(0, 1).to(input_.dtype)
# N_*group, group_channels, H_out*W_out, P_
sampling_input_ = F.grid_sample(
input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)
Expand Down
7 changes: 2 additions & 5 deletions detection/ops_dcnv3/src/cuda/dcnv3_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
// AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "ms_deform_attn_forward_cuda", ([&] {
dcnv3_im2col_cuda(
dcnv3_im2col_cuda<opmath_t>(
at::cuda::getCurrentCUDAStream(),
input.data<scalar_t>() + n * im2col_step_ * per_input_size,
offset.data<scalar_t>() +
Expand Down Expand Up @@ -124,9 +124,6 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
channels, group * group_channels);

auto dtype = input.dtype();
if (dtype == at::kHalf) {
dtype = at::kFloat;
}

auto grad_input = at::zeros_like(input, dtype);
auto grad_offset = at::zeros_like(offset, dtype);
Expand All @@ -146,7 +143,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
// AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "ms_deform_attn_backward_cuda", ([&] {
dcnv3_col2im_cuda(
dcnv3_col2im_cuda<opmath_t>(
at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
input.data<scalar_t>() + n * im2col_step_ * per_input_size,
Expand Down
3 changes: 1 addition & 2 deletions detection/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <cstring>

#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>

Expand All @@ -27,7 +26,7 @@ inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}

#define opmath_t at::opmath_type<scalar_t>
#define opmath_t scalar_t

template <typename scalar_t>
__device__ opmath_t dcnv3_im2col_bilinear(const scalar_t *&bottom_data,
Expand Down
Loading