diff --git a/benchmarks/ops/bench_convolution.py b/benchmarks/ops/bench_convolution.py index 37e3c27e7..c7fa4d945 100644 --- a/benchmarks/ops/bench_convolution.py +++ b/benchmarks/ops/bench_convolution.py @@ -139,6 +139,7 @@ def __init__( stride: tuple[int, int], padding: tuple[int, int], dilation: tuple[int, int], + groups: int, dtype: torch.dtype, ) -> None: self.n = n @@ -150,12 +151,13 @@ def __init__( self.stride = stride self.padding = padding self.dilation = dilation + self.groups = groups self.dtype = dtype def gen_inputs(self) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: x = torch.randn(self.n, self.c_in, self.h, self.w, device="cuda", dtype=self.dtype).contiguous() weight = torch.randn( - self.c_out, self.c_in, self.kernel_size[0], self.kernel_size[1], + self.c_out, self.c_in // self.groups, self.kernel_size[0], self.kernel_size[1], device="cuda", dtype=self.dtype, ).contiguous() bias = torch.zeros(self.c_out, device="cuda", dtype=self.dtype).contiguous() @@ -174,7 +176,7 @@ def ref_program( stride=self.stride, padding=self.padding, dilation=self.dilation, - groups=1, + groups=self.groups, ) class Conv2dBenchmark(BenchmarkBase[Conv2dBenchCase]): @@ -183,41 +185,47 @@ def calculate_flops(self) -> Optional[float]: t = self.workload out_h = (t.h + 2 * t.padding[0] - t.dilation[0] * (t.kernel_size[0] - 1) - 1) // t.stride[0] + 1 out_w = (t.w + 2 * t.padding[1] - t.dilation[1] * (t.kernel_size[1] - 1) - 1) // t.stride[1] + 1 - return 2.0 * t.n * t.c_out * out_h * out_w * t.c_in * t.kernel_size[0] * t.kernel_size[1] + c_in_g = t.c_in // t.groups + return 2.0 * t.n * t.c_out * out_h * out_w * c_in_g * t.kernel_size[0] * t.kernel_size[1] def calculate_memory(self) -> Optional[float]: t = self.workload out_h = (t.h + 2 * t.padding[0] - t.dilation[0] * (t.kernel_size[0] - 1) - 1) // t.stride[0] + 1 out_w = (t.w + 2 * t.padding[1] - t.dilation[1] * (t.kernel_size[1] - 1) - 1) // t.stride[1] + 1 + c_in_g = t.c_in // t.groups bytes_ = ( t.n * t.c_in * t.h * t.w - + t.c_out * t.c_in * t.kernel_size[0] * t.kernel_size[1] + + t.c_out * c_in_g * t.kernel_size[0] * t.kernel_size[1] + t.n * t.c_out * out_h * out_w ) * t.dtype.itemsize return bytes_ _CONV2D_BENCH_PARAMS = [ - pytest.param(2, 64, 56, 56, 64, (3, 3), (1, 1), (1, 1), (1, 1), torch.float16, True, id="resnet-3x3-fp16"), - pytest.param(1, 3, 112, 112, 64, (3, 3), (2, 2), (1, 1), (1, 1), torch.float16, True, id="stem-3x3-s2-fp16"), - pytest.param(1, 128, 56, 56, 256, (3, 3), (2, 2), (1, 1), (1, 1), torch.float16, True, id="stage-transition-3x3-s2-fp16"), - pytest.param(1, 256, 112, 112, 512, (3, 3), (1, 1), (1, 1), (1, 1), torch.float16, True, id="highres-3x3-s1-fp16"), - pytest.param(1, 64, 56, 56, 128, (5, 5), (1, 1), (2, 2), (1, 1), torch.float16, True, id="midres-5x5-s1-fp16"), - pytest.param(1, 128, 56, 56, 256, (5, 5), (2, 2), (2, 2), (1, 1), torch.float16, True, id="stage-transition-5x5-s2-fp16"), - pytest.param(1, 128, 28, 28, 128, (3, 3), (2, 2), (1, 1), (1, 1), torch.bfloat16, True, id="stride2-bf16"), - pytest.param(2, 64, 56, 56, 256, (1, 1), (1, 1), (0, 0), (1, 1), torch.float16, True, id="resnet-1x1-fp16"), - pytest.param(2, 128, 28, 28, 512, (1, 1), (1, 1), (0, 0), (1, 1), torch.float16, True, id="bottleneck-expand-1x1-fp16"), - pytest.param(2, 512, 28, 28, 128, (1, 1), (1, 1), (0, 0), (1, 1), torch.float16, True, id="bottleneck-reduce-1x1-fp16"), - pytest.param(1, 256, 14, 14, 1024, (1, 1), (1, 1), (0, 0), (1, 1), torch.float16, True, id="late-stage-1x1-fp16"), - pytest.param(1, 512, 7, 7, 2048, (1, 1), (1, 1), (0, 0), (1, 1), torch.float16, True, id="classifier-1x1-fp16"), - pytest.param(2, 64, 56, 56, 256, (1, 1), (1, 1), (0, 0), (1, 1), torch.bfloat16, True, id="resnet-1x1-bf16"), + pytest.param(2, 64, 56, 56, 64, (3, 3), (1, 1), (1, 1), (1, 1), 1, torch.float16, True, id="resnet-3x3-fp16"), + pytest.param(1, 3, 112, 112, 64, (3, 3), (2, 2), (1, 1), (1, 1), 1, torch.float16, True, id="stem-3x3-s2-fp16"), + pytest.param(1, 128, 56, 56, 256, (3, 3), (2, 2), (1, 1), (1, 1), 1, torch.float16, True, id="stage-transition-3x3-s2-fp16"), + pytest.param(1, 256, 112, 112, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, torch.float16, True, id="highres-3x3-s1-fp16"), + pytest.param(1, 64, 56, 56, 128, (5, 5), (1, 1), (2, 2), (1, 1), 1, torch.float16, True, id="midres-5x5-s1-fp16"), + pytest.param(1, 128, 56, 56, 256, (5, 5), (2, 2), (2, 2), (1, 1), 1, torch.float16, True, id="stage-transition-5x5-s2-fp16"), + pytest.param(1, 128, 28, 28, 128, (3, 3), (2, 2), (1, 1), (1, 1), 1, torch.bfloat16, True, id="stride2-bf16"), + pytest.param(2, 64, 56, 56, 256, (1, 1), (1, 1), (0, 0), (1, 1), 1, torch.float16, True, id="resnet-1x1-fp16"), + pytest.param(2, 128, 28, 28, 512, (1, 1), (1, 1), (0, 0), (1, 1), 1, torch.float16, True, id="bottleneck-expand-1x1-fp16"), + pytest.param(2, 512, 28, 28, 128, (1, 1), (1, 1), (0, 0), (1, 1), 1, torch.float16, True, id="bottleneck-reduce-1x1-fp16"), + pytest.param(1, 256, 14, 14, 1024, (1, 1), (1, 1), (0, 0), (1, 1), 1, torch.float16, True, id="late-stage-1x1-fp16"), + pytest.param(1, 512, 7, 7, 2048, (1, 1), (1, 1), (0, 0), (1, 1), 1, torch.float16, True, id="classifier-1x1-fp16"), + pytest.param(2, 64, 56, 56, 256, (1, 1), (1, 1), (0, 0), (1, 1), 1, torch.bfloat16, True, id="resnet-1x1-bf16"), # DeepLabV3/DeepLabV3+ ASPP branch: 3x3 atrous conv on stride-16 encoder features. - pytest.param(1, 2048, 32, 32, 256, (3, 3), (1, 1), (12, 12), (12, 12), torch.float16, True, id="deeplabv3-aspp-3x3-rate12-fp16"), + pytest.param(1, 2048, 32, 32, 256, (3, 3), (1, 1), (12, 12), (12, 12), 1, torch.float16, True, id="deeplabv3-aspp-3x3-rate12-fp16"), + # MobileNetV2 inverted residual depthwise 3x3 convolution. + pytest.param(1, 32, 56, 56, 32, (3, 3), (1, 1), (1, 1), (1, 1), 32, torch.float16, True, id="mobilenetv2-depthwise-fp16"), + # ResNeXt bottleneck grouped 3x3 convolution. + pytest.param(1, 128, 28, 28, 256, (3, 3), (1, 1), (1, 1), (1, 1), 32, torch.float16, True, id="resnext-grouped-3x3-fp16"), ] @pytest.mark.parametrize( - "n, c_in, h, w, c_out, kernel_size, stride, padding, dilation, dtype, tune", + "n, c_in, h, w, c_out, kernel_size, stride, padding, dilation, groups, dtype, tune", _CONV2D_BENCH_PARAMS, ) def test_conv2d_bench( @@ -230,10 +238,11 @@ def test_conv2d_bench( stride: tuple[int, int], padding: tuple[int, int], dilation: tuple[int, int], + groups: int, dtype: torch.dtype, tune: bool, ) -> None: - test = Conv2dBenchCase(n, c_in, h, w, c_out, kernel_size, stride, padding, dilation, dtype) + test = Conv2dBenchCase(n, c_in, h, w, c_out, kernel_size, stride, padding, dilation, groups, dtype) bm = Conv2dBenchmark(test) inputs = test.gen_inputs() x, weight, bias = inputs @@ -248,6 +257,7 @@ def test_conv2d_bench( stride=stride, padding=padding, dilation=dilation, + groups=groups, dtype=dtype, tune=tune, ) @@ -272,6 +282,7 @@ def __init__( stride: tuple[int, int, int], padding: tuple[int, int, int], dilation: tuple[int, int, int], + groups: int, dtype: torch.dtype, ) -> None: self.n = n @@ -284,6 +295,7 @@ def __init__( self.stride = stride self.padding = padding self.dilation = dilation + self.groups = groups self.dtype = dtype def gen_inputs(self) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: @@ -292,7 +304,11 @@ def gen_inputs(self) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor] device="cuda", dtype=self.dtype, ).contiguous() weight = torch.randn( - self.c_out, self.c_in, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2], + self.c_out, + self.c_in // self.groups, + self.kernel_size[0], + self.kernel_size[1], + self.kernel_size[2], device="cuda", dtype=self.dtype, ).contiguous() bias = torch.zeros(self.c_out, device="cuda", dtype=self.dtype).contiguous() @@ -311,7 +327,7 @@ def ref_program( stride=self.stride, padding=self.padding, dilation=self.dilation, - groups=1, + groups=self.groups, ) class Conv3dBenchmark(BenchmarkBase[Conv3dBenchCase]): @@ -321,32 +337,38 @@ def calculate_flops(self) -> Optional[float]: out_d = (t.d + 2 * t.padding[0] - t.dilation[0] * (t.kernel_size[0] - 1) - 1) // t.stride[0] + 1 out_h = (t.h + 2 * t.padding[1] - t.dilation[1] * (t.kernel_size[1] - 1) - 1) // t.stride[1] + 1 out_w = (t.w + 2 * t.padding[2] - t.dilation[2] * (t.kernel_size[2] - 1) - 1) // t.stride[2] + 1 - return 2.0 * t.n * t.c_out * out_d * out_h * out_w * t.c_in * t.kernel_size[0] * t.kernel_size[1] * t.kernel_size[2] + c_in_g = t.c_in // t.groups + return 2.0 * t.n * t.c_out * out_d * out_h * out_w * c_in_g * t.kernel_size[0] * t.kernel_size[1] * t.kernel_size[2] def calculate_memory(self) -> Optional[float]: t = self.workload out_d = (t.d + 2 * t.padding[0] - t.dilation[0] * (t.kernel_size[0] - 1) - 1) // t.stride[0] + 1 out_h = (t.h + 2 * t.padding[1] - t.dilation[1] * (t.kernel_size[1] - 1) - 1) // t.stride[1] + 1 out_w = (t.w + 2 * t.padding[2] - t.dilation[2] * (t.kernel_size[2] - 1) - 1) // t.stride[2] + 1 + c_in_g = t.c_in // t.groups bytes_ = ( t.n * t.c_in * t.d * t.h * t.w - + t.c_out * t.c_in * t.kernel_size[0] * t.kernel_size[1] * t.kernel_size[2] + + t.c_out * c_in_g * t.kernel_size[0] * t.kernel_size[1] * t.kernel_size[2] + t.n * t.c_out * out_d * out_h * out_w ) * t.dtype.itemsize return bytes_ _CONV3D_BENCH_PARAMS = [ - pytest.param(1, 3, 16, 112, 112, 64, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), torch.float16, True, id="r3d-stem-k3-s1-fp16"), - pytest.param(1, 64, 8, 56, 56, 128, (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1), torch.float16, True, id="video-stage-downsample-k3-s2-fp16"), - pytest.param(1, 32, 32, 64, 64, 64, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), torch.bfloat16, True, id="unet-encoder-k3-s1-bf16"), + pytest.param(1, 3, 16, 112, 112, 64, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), 1, torch.float16, True, id="r3d-stem-k3-s1-fp16"), + pytest.param(1, 64, 8, 56, 56, 128, (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1), 1, torch.float16, True, id="video-stage-downsample-k3-s2-fp16"), + pytest.param(1, 32, 32, 64, 64, 64, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), 1, torch.bfloat16, True, id="unet-encoder-k3-s1-bf16"), # 3D U-Net + 3D ASPP medical segmentation branch: 3x3x3 atrous conv on low-resolution volume features. - pytest.param(1, 256, 8, 16, 16, 256, (3, 3, 3), (1, 1, 1), (6, 6, 6), (6, 6, 6), torch.float16, True, id="3d-unet-aspp-3x3x3-rate6-fp16"), + pytest.param(1, 256, 8, 16, 16, 256, (3, 3, 3), (1, 1, 1), (6, 6, 6), (6, 6, 6), 1, torch.float16, True, id="3d-unet-aspp-3x3x3-rate6-fp16"), + # 3D-ResNeXt/video backbone grouped 3x3x3 convolution. + pytest.param(1, 64, 8, 28, 28, 128, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), 32, torch.float16, True, id="3d-resnext-grouped-k3-fp16"), + # 3D-ResNeXt/video backbone grouped 3x3x3 convolution, larger batch for throughput. + pytest.param(8, 64, 8, 28, 28, 128, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), 32, torch.float16, True, id="3d-resnext-grouped-k3-b8-fp16"), ] @pytest.mark.parametrize( - "n, c_in, d, h, w, c_out, kernel_size, stride, padding, dilation, dtype, tune", + "n, c_in, d, h, w, c_out, kernel_size, stride, padding, dilation, groups, dtype, tune", _CONV3D_BENCH_PARAMS, ) def test_conv3d_bench( @@ -360,10 +382,11 @@ def test_conv3d_bench( stride: tuple[int, int, int], padding: tuple[int, int, int], dilation: tuple[int, int, int], + groups: int, dtype: torch.dtype, tune: bool, ) -> None: - test = Conv3dBenchCase(n, c_in, d, h, w, c_out, kernel_size, stride, padding, dilation, dtype) + test = Conv3dBenchCase(n, c_in, d, h, w, c_out, kernel_size, stride, padding, dilation, groups, dtype) bm = Conv3dBenchmark(test) inputs = test.gen_inputs() x, weight, bias = inputs @@ -379,6 +402,7 @@ def test_conv3d_bench( stride=stride, padding=padding, dilation=dilation, + groups=groups, dtype=dtype, tune=tune, ) diff --git a/tests/ops/test_convolution.py b/tests/ops/test_convolution.py index a8bba7caf..094c3e451 100644 --- a/tests/ops/test_convolution.py +++ b/tests/ops/test_convolution.py @@ -12,6 +12,8 @@ Conv2dKernel, Conv3dKernel, GroupConv1dKernel, + GroupConv2dKernel, + GroupConv3dKernel, ) from tileops.ops import ( Conv1dBiasFwdOp, @@ -324,67 +326,79 @@ def test_conv1d_dispatches_kernel( class Conv2dFixture(FixtureBase): PARAMS = [ - ("n, c_in, h, w, c_out, kernel_size, stride, padding, dilation, dtype, tune", [ + ("n, c_in, h, w, c_out, kernel_size, stride, padding, dilation, groups, dtype, tune", [ pytest.param( - 2, 32, 32, 32, 64, (3, 3), (1, 1), (1, 1), (1, 1), torch.float16, False, + 2, 32, 32, 32, 64, (3, 3), (1, 1), (1, 1), (1, 1), 1, torch.float16, False, marks=pytest.mark.smoke, id="smoke-fp16-3x3", ), pytest.param( - 2, 32, 32, 32, 64, (3, 3), (1, 1), (1, 1), (1, 1), torch.bfloat16, False, + 2, 32, 32, 32, 64, (3, 3), (1, 1), (1, 1), (1, 1), 1, torch.bfloat16, False, marks=pytest.mark.smoke, id="smoke-bf16-3x3", ), + # MobileNetV2 depthwise 3x3 block, reduced spatial size for smoke cost. pytest.param( - 1, 3, 112, 112, 64, (3, 3), (2, 2), (1, 1), (1, 1), torch.float16, False, + 1, 16, 16, 16, 16, (3, 3), (1, 1), (1, 1), (1, 1), 16, torch.float16, False, + marks=pytest.mark.smoke, + id="smoke-mobilenetv2-depthwise-small-fp16", + ), + pytest.param( + 1, 3, 112, 112, 64, (3, 3), (2, 2), (1, 1), (1, 1), 1, torch.float16, False, marks=pytest.mark.full, id="full-stem-3x3-s2-fp16", ), pytest.param( - 1, 64, 56, 56, 64, (3, 3), (1, 1), (1, 1), (1, 1), torch.float16, False, + 1, 64, 56, 56, 64, (3, 3), (1, 1), (1, 1), (1, 1), 1, torch.float16, False, marks=pytest.mark.full, id="full-resblock-3x3-s1-fp16", ), pytest.param( - 1, 128, 56, 56, 256, (3, 3), (2, 2), (1, 1), (1, 1), torch.float16, False, + 1, 128, 56, 56, 256, (3, 3), (2, 2), (1, 1), (1, 1), 1, torch.float16, False, marks=pytest.mark.full, id="full-stage-transition-3x3-s2-fp16", ), pytest.param( - 1, 32, 28, 28, 64, (5, 5), (1, 1), (2, 2), (1, 1), torch.float16, False, + 1, 32, 28, 28, 64, (5, 5), (1, 1), (2, 2), (1, 1), 1, torch.float16, False, marks=pytest.mark.full, id="full-small-5x5-s1-fp16", ), pytest.param( - 1, 64, 28, 28, 128, (5, 5), (2, 2), (2, 2), (1, 1), torch.float16, False, + 1, 64, 28, 28, 128, (5, 5), (2, 2), (2, 2), (1, 1), 1, torch.float16, False, marks=pytest.mark.full, id="full-small-5x5-s2-fp16", ), pytest.param( - 2, 32, 32, 32, 64, (1, 1), (1, 1), (0, 0), (1, 1), torch.float16, True, + 2, 32, 32, 32, 64, (1, 1), (1, 1), (0, 0), (1, 1), 1, torch.float16, True, marks=pytest.mark.full, id="full-fp16-1x1-tuned", ), pytest.param( - 1, 64, 28, 28, 128, (3, 3), (2, 2), (1, 1), (1, 1), torch.float16, False, + 1, 64, 28, 28, 128, (3, 3), (2, 2), (1, 1), (1, 1), 1, torch.float16, False, marks=pytest.mark.full, id="full-fp16-stride2", ), pytest.param( - 1, 64, 56, 56, 128, (3, 3), (2, 2), (1, 1), (1, 1), torch.bfloat16, False, + 1, 64, 56, 56, 128, (3, 3), (2, 2), (1, 1), (1, 1), 1, torch.bfloat16, False, marks=pytest.mark.full, id="full-bf16-3x3-s2", ), pytest.param( - 1, 64, 28, 28, 64, (1, 1), (1, 1), (0, 0), (1, 1), torch.bfloat16, False, + 1, 64, 28, 28, 64, (1, 1), (1, 1), (0, 0), (1, 1), 1, torch.bfloat16, False, marks=pytest.mark.full, id="full-bf16-1x1", ), pytest.param( - 1, 64, 32, 32, 128, (3, 3), (1, 1), (2, 2), (2, 2), torch.float16, False, + 1, 64, 32, 32, 128, (3, 3), (1, 1), (2, 2), (2, 2), 1, torch.float16, False, marks=pytest.mark.full, id="full-deeplab-aspp-3x3-d2-fp16", ), + # ResNeXt bottleneck grouped 3x3 convolution. + pytest.param( + 1, 128, 28, 28, 256, (3, 3), (1, 1), (1, 1), (1, 1), 32, torch.float16, False, + marks=pytest.mark.full, + id="full-resnext-grouped-3x3-fp16", + ), ]), ] @@ -402,6 +416,7 @@ def __init__( stride: tuple[int, int], padding: tuple[int, int], dilation: tuple[int, int], + groups: int, dtype: torch.dtype, ) -> None: self.n = n @@ -413,12 +428,13 @@ def __init__( self.stride = stride self.padding = padding self.dilation = dilation + self.groups = groups self.dtype = dtype def gen_inputs(self) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: x = torch.randn(self.n, self.c_in, self.h, self.w, device="cuda", dtype=self.dtype).contiguous() weight = torch.randn( - self.c_out, self.c_in, self.kernel_size[0], self.kernel_size[1], + self.c_out, self.c_in // self.groups, self.kernel_size[0], self.kernel_size[1], device="cuda", dtype=self.dtype, ).contiguous() bias = torch.zeros(self.c_out, device="cuda", dtype=self.dtype).contiguous() @@ -437,7 +453,7 @@ def ref_program( stride=self.stride, padding=self.padding, dilation=self.dilation, - groups=1, + groups=self.groups, ) return out.contiguous() @@ -453,10 +469,11 @@ def test_conv2d( stride: tuple[int, int], padding: tuple[int, int], dilation: tuple[int, int], + groups: int, dtype: torch.dtype, tune: bool, ) -> None: - test = Conv2dTest(n, c_in, h, w, c_out, kernel_size, stride, padding, dilation, dtype) + test = Conv2dTest(n, c_in, h, w, c_out, kernel_size, stride, padding, dilation, groups, dtype) op = Conv2dBiasFwdOp( n=n, c_in=c_in, @@ -467,9 +484,12 @@ def test_conv2d( stride=stride, padding=padding, dilation=dilation, + groups=groups, dtype=dtype, tune=tune, ) + if groups > 1: + assert isinstance(op.kernel, GroupConv2dKernel) atol, rtol = ((1e-3, 1e-3) if dtype == torch.float16 else (1.6e-2, 1.6e-2)) test.check(op, *test.gen_inputs(), atol=atol, rtol=rtol) @@ -502,6 +522,26 @@ def test_conv2d_no_bias_matches_torch() -> None: torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3) +@pytest.mark.smoke +def test_conv2d_no_bias_grouped_matches_torch() -> None: + groups = 8 + op = Conv2dFwdOp( + n=1, + c_in=16, + h=16, + w=16, + c_out=32, + kernel_size=3, + padding=1, + groups=groups, + ) + x = torch.randn(1, 16, 16, 16, device="cuda", dtype=torch.float16).contiguous() + weight = torch.randn(32, 2, 3, 3, device="cuda", dtype=torch.float16).contiguous() + out = op(x, weight) + ref = F.conv2d(x, weight, bias=None, padding=1, groups=groups).contiguous() + torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3) + + @pytest.mark.smoke def test_conv2d_dispatches_1x1_kernel() -> None: op = Conv2dFwdOp( @@ -559,37 +599,49 @@ def test_conv2d_dispatches_5x5_kernel() -> None: class Conv3dFixture(FixtureBase): PARAMS = [ - ("n, c_in, d, h, w, c_out, kernel_size, stride, padding, dilation, dtype, tune", [ + ("n, c_in, d, h, w, c_out, kernel_size, stride, padding, dilation, groups, dtype, tune", [ pytest.param( - 1, 16, 8, 32, 32, 32, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), torch.float16, False, + 1, 16, 8, 32, 32, 32, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), 1, torch.float16, False, marks=pytest.mark.smoke, id="smoke-3d-unet-k3-s1-fp16", ), pytest.param( - 1, 16, 8, 32, 32, 32, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), torch.bfloat16, False, + 1, 16, 8, 32, 32, 32, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), 1, torch.bfloat16, False, marks=pytest.mark.smoke, id="smoke-3d-unet-k3-s1-bf16", ), + # Video depthwise 3D block, reduced size for smoke cost. pytest.param( - 1, 3, 16, 112, 112, 64, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), torch.float16, False, + 1, 8, 4, 12, 12, 8, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), 8, torch.float16, False, + marks=pytest.mark.smoke, + id="smoke-video-depthwise3d-small-fp16", + ), + pytest.param( + 1, 3, 16, 112, 112, 64, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), 1, torch.float16, False, marks=pytest.mark.full, id="full-r3d-stem-k3-s1-fp16", ), pytest.param( - 1, 64, 8, 56, 56, 128, (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1), torch.float16, False, + 1, 64, 8, 56, 56, 128, (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1), 1, torch.float16, False, marks=pytest.mark.full, id="full-video-stage-downsample-k3-s2-fp16", ), pytest.param( - 1, 32, 32, 64, 64, 64, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), torch.bfloat16, False, + 1, 32, 32, 64, 64, 64, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), 1, torch.bfloat16, False, marks=pytest.mark.full, id="full-unet-encoder-k3-s1-bf16", ), pytest.param( - 1, 16, 8, 32, 32, 32, (3, 3, 3), (1, 1, 1), (2, 2, 2), (2, 2, 2), torch.float16, False, + 1, 16, 8, 32, 32, 32, (3, 3, 3), (1, 1, 1), (2, 2, 2), (2, 2, 2), 1, torch.float16, False, marks=pytest.mark.full, id="full-3d-aspp-3x3x3-d2-fp16", ), + # 3D-ResNeXt/video backbone grouped 3x3x3 convolution. + pytest.param( + 1, 64, 8, 28, 28, 128, (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1), 32, torch.float16, False, + marks=pytest.mark.full, + id="full-3d-resnext-grouped-k3-fp16", + ), ]), ] @@ -608,6 +660,7 @@ def __init__( stride: tuple[int, int, int], padding: tuple[int, int, int], dilation: tuple[int, int, int], + groups: int, dtype: torch.dtype, ) -> None: self.n = n @@ -620,6 +673,7 @@ def __init__( self.stride = stride self.padding = padding self.dilation = dilation + self.groups = groups self.dtype = dtype def gen_inputs(self) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: @@ -628,7 +682,11 @@ def gen_inputs(self) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor] device="cuda", dtype=self.dtype, ).contiguous() weight = torch.randn( - self.c_out, self.c_in, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2], + self.c_out, + self.c_in // self.groups, + self.kernel_size[0], + self.kernel_size[1], + self.kernel_size[2], device="cuda", dtype=self.dtype, ).contiguous() bias = torch.zeros(self.c_out, device="cuda", dtype=self.dtype).contiguous() @@ -647,7 +705,7 @@ def ref_program( stride=self.stride, padding=self.padding, dilation=self.dilation, - groups=1, + groups=self.groups, ) return out.contiguous() @@ -664,10 +722,11 @@ def test_conv3d( stride: tuple[int, int, int], padding: tuple[int, int, int], dilation: tuple[int, int, int], + groups: int, dtype: torch.dtype, tune: bool, ) -> None: - test = Conv3dTest(n, c_in, d, h, w, c_out, kernel_size, stride, padding, dilation, dtype) + test = Conv3dTest(n, c_in, d, h, w, c_out, kernel_size, stride, padding, dilation, groups, dtype) op = Conv3dBiasFwdOp( n=n, c_in=c_in, @@ -679,9 +738,12 @@ def test_conv3d( stride=stride, padding=padding, dilation=dilation, + groups=groups, dtype=dtype, tune=tune, ) + if groups > 1: + assert isinstance(op.kernel, GroupConv3dKernel) atol, rtol = ((1e-3, 1e-3) if dtype == torch.float16 else (1.6e-2, 1.6e-2)) test.check(op, *test.gen_inputs(), atol=atol, rtol=rtol) @@ -715,6 +777,27 @@ def test_conv3d_no_bias_matches_torch() -> None: torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3) +@pytest.mark.smoke +def test_conv3d_no_bias_grouped_matches_torch() -> None: + groups = 4 + op = Conv3dFwdOp( + n=1, + c_in=8, + d=4, + h=12, + w=12, + c_out=16, + kernel_size=3, + padding=1, + groups=groups, + ) + x = torch.randn(1, 8, 4, 12, 12, device="cuda", dtype=torch.float16).contiguous() + weight = torch.randn(16, 2, 3, 3, 3, device="cuda", dtype=torch.float16).contiguous() + out = op(x, weight) + ref = F.conv3d(x, weight, bias=None, padding=1, groups=groups).contiguous() + torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3) + + @pytest.mark.smoke def test_conv3d_accepts_zero_bias() -> None: op = Conv3dBiasFwdOp( diff --git a/tileops/kernels/__init__.py b/tileops/kernels/__init__.py index 7d73d4ec5..fbd0e8776 100644 --- a/tileops/kernels/__init__.py +++ b/tileops/kernels/__init__.py @@ -41,6 +41,8 @@ Conv2dKernel, Conv3dKernel, GroupConv1dKernel, + GroupConv2dKernel, + GroupConv3dKernel, ) from .deltanet import DeltaNetBwdKernel, DeltaNetFwdKernel from .deltanet_recurrence import DeltaNetDecodeFP32Kernel, DeltaNetDecodeKernel @@ -135,6 +137,8 @@ "GemmKernel", "GemvKernel", "GroupConv1dKernel", + "GroupConv2dKernel", + "GroupConv3dKernel", "GroupNormKernel", "GroupedGemmKernel", "Kernel", diff --git a/tileops/kernels/convolution.py b/tileops/kernels/convolution.py index bc50cbb8d..de8de4c81 100644 --- a/tileops/kernels/convolution.py +++ b/tileops/kernels/convolution.py @@ -16,6 +16,8 @@ "Conv2dKernel", "Conv3dKernel", "GroupConv1dKernel", + "GroupConv2dKernel", + "GroupConv3dKernel", ] @@ -1316,6 +1318,143 @@ def _conv2d_main( return _conv2d_func +@functools.lru_cache(maxsize=64) +def _conv2d_group_kernel( + n: int, + c_in: int, + h: int, + w: int, + c_out: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + has_bias: bool, + dtype: str = "float16", + groups: int = 1, + c_in_g: int = 0, + c_out_g: int = 0, +): + accum_dtype = "float" + out_h = (h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1 + out_w = (w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1 + out_hw = out_h * out_w + c_in_g = c_in_g if c_in_g > 0 else c_in // groups + c_out_g = c_out_g if c_out_g > 0 else c_out // groups + k_total = kernel_h * kernel_w * c_in_g + + @tilelang.jit( + out_idx=[2], + compile_flags=["-O3", "-DENABLE_BF16"], + pass_configs={"tl.enable_async_copy": False}, + ) + def _conv2d_group_func( + block_m: int, + block_n: int, + block_k: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + ): + @T.prim_func + def _conv2d_group_main( + x: T.Tensor((n, c_in, h, w), dtype), # type: ignore + weight: T.Tensor((c_out, c_in_g, kernel_h, kernel_w), dtype), # type: ignore + out: T.Tensor((n, c_out, out_h, out_w), dtype), # type: ignore + bias: T.Tensor((c_out,), dtype), # type: ignore + ): + with T.Kernel( + T.ceildiv(out_hw, block_n), + T.ceildiv(c_out_g, block_m), + n * groups, + threads=threads, + ) as (bx, by, bz): + weight_shared = T.alloc_shared((block_m, block_k), dtype) + data_shared = T.alloc_shared((block_k, block_n), dtype) + out_local = T.alloc_fragment((block_m, block_n), accum_dtype) + out_shared = T.alloc_shared((block_m, block_n), dtype) + + T.use_swizzle(10, enable=enable_rasterization) + T.clear(out_local) + + batch_id = bz // groups + group_id = bz % groups + + for k_iter in T.Pipelined(T.ceildiv(k_total, block_k), num_stages=num_stages): + for i, k in T.Parallel(block_m, block_k): + oc_g = by * block_m + i + oc = group_id * c_out_g + oc_g + k_idx = k_iter * block_k + k + ci_g = k_idx // (kernel_h * kernel_w) + kernel_idx = k_idx % (kernel_h * kernel_w) + kh = kernel_idx // kernel_w + kw = kernel_idx % kernel_w + weight_shared[i, k] = T.if_then_else( + (oc_g < c_out_g) & (k_idx < k_total), + weight[oc, ci_g, kh, kw], + T.cast(0.0, dtype), + ) + + for k, j in T.Parallel(block_k, block_n): + k_idx = k_iter * block_k + k + spatial_idx = bx * block_n + j + ci_g = k_idx // (kernel_h * kernel_w) + ci = group_id * c_in_g + ci_g + kernel_idx = k_idx % (kernel_h * kernel_w) + kh = kernel_idx // kernel_w + kw = kernel_idx % kernel_w + oh = spatial_idx // out_w + ow = spatial_idx % out_w + ih = oh * stride_h + kh * dilation_h - pad_h + iw = ow * stride_w + kw * dilation_w - pad_w + data_shared[k, j] = T.if_then_else( + (spatial_idx < out_hw) + & (k_idx < k_total) + & (ih >= 0) + & (iw >= 0) + & (ih < h) + & (iw < w), + x[batch_id, ci, ih, iw], + T.cast(0.0, dtype), + ) + + T.gemm(weight_shared, data_shared, out_local) + + for i, j in T.Parallel(block_m, block_n): + oc_g = by * block_m + i + oc = group_id * c_out_g + oc_g + spatial_idx = bx * block_n + j + if has_bias: + out_shared[i, j] = T.if_then_else( + (oc_g < c_out_g) & (spatial_idx < out_hw), + T.cast(out_local[i, j] + T.cast(bias[oc], accum_dtype), dtype), + T.cast(0.0, dtype), + ) + else: + out_shared[i, j] = T.if_then_else( + (oc_g < c_out_g) & (spatial_idx < out_hw), + T.cast(out_local[i, j], dtype), + T.cast(0.0, dtype), + ) + + for i, j in T.Parallel(block_m, block_n): + oc_g = by * block_m + i + oc = group_id * c_out_g + oc_g + spatial_idx = bx * block_n + j + oh = spatial_idx // out_w + ow = spatial_idx % out_w + if oc_g < c_out_g and spatial_idx < out_hw: + out[batch_id, oc, oh, ow] = out_shared[i, j] + + return _conv2d_group_main + + return _conv2d_group_func + + @torch.library.custom_op("top::conv2d_1x1_wrapped_kernel", mutates_args=()) def _conv2d_1x1_wrapped_kernel( n: int, @@ -1392,6 +1531,58 @@ def _conv2d_wrapped_kernel( )(block_m, block_n, block_k, num_stages, threads, enable_rasterization)(x, weight, bias) +@torch.library.custom_op("top::conv2d_group_wrapped_kernel", mutates_args=()) +def _conv2d_group_wrapped_kernel( + n: int, + c_in: int, + h: int, + w: int, + c_out: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + has_bias: bool, + dtype: str, + groups: int, + c_in_g: int, + c_out_g: int, + block_m: int, + block_n: int, + block_k: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return _conv2d_group_kernel( + n, + c_in, + h, + w, + c_out, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + has_bias, + dtype, + groups, + c_in_g, + c_out_g, + )(block_m, block_n, block_k, num_stages, threads, enable_rasterization)(x, weight, bias) + + @_conv2d_wrapped_kernel.register_fake def _( n: int, @@ -1421,6 +1612,41 @@ def _( out_w = (w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1 return torch.empty((n, c_out, out_h, out_w), dtype=inputs[0].dtype, device=inputs[0].device) + +@_conv2d_group_wrapped_kernel.register_fake +def _( + n: int, + c_in: int, + h: int, + w: int, + c_out: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + has_bias: bool, + dtype: str, + groups: int, + c_in_g: int, + c_out_g: int, + block_m: int, + block_n: int, + block_k: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + *inputs: tuple[torch.Tensor, ...], +) -> torch.Tensor: + del groups, c_in_g, c_out_g + out_h = (h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1 + out_w = (w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1 + return torch.empty((n, c_out, out_h, out_w), dtype=inputs[0].dtype, device=inputs[0].device) + + class Conv2dKernel(Kernel): supported_archs: list[int] = [80, 86, 89, 90] @@ -1577,7 +1803,7 @@ def forward( ) -class Conv2d1x1Kernel(Kernel): +class GroupConv2dKernel(Kernel): supported_archs: list[int] = [80, 86, 89, 90] def __init__( @@ -1587,12 +1813,19 @@ def __init__( h: int, w: int, c_out: int, + kernel_h: int, + kernel_w: int, stride_h: int, stride_w: int, pad_h: int, pad_w: int, + dilation_h: int, + dilation_w: int, dtype: torch.dtype, has_bias: bool = False, + groups: int = 1, + c_in_g: Optional[int] = None, + c_out_g: Optional[int] = None, config: Optional[dict] = None, tune: bool = False, ) -> None: @@ -1602,55 +1835,74 @@ def __init__( self.h = h self.w = w self.c_out = c_out + self.kernel_h = kernel_h + self.kernel_w = kernel_w self.stride_h = stride_h self.stride_w = stride_w self.pad_h = pad_h self.pad_w = pad_w + self.dilation_h = dilation_h + self.dilation_w = dilation_w + self.groups = groups + self.c_in_g = c_in_g if c_in_g is not None else c_in // groups + self.c_out_g = c_out_g if c_out_g is not None else c_out // groups self.dtype = dtype self.has_bias = has_bias + self.out_h = (h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1 + self.out_w = (w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1 + self.m = n * self.groups * self.out_h * self.out_w + self.k_total = self.c_in_g * kernel_h * kernel_w + self._validate_group_shape() - self.kernel = _conv2d_1x1_kernel( + self.kernel = _conv2d_group_kernel( n, c_in, h, w, c_out, + kernel_h, + kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, + dilation_w, has_bias, self.dtype_str, + groups, + self.c_in_g, + self.c_out_g, ) self.init_config(config, tune) + def _validate_group_shape(self) -> None: + if self.groups <= 1: + raise ValueError("GroupConv2dKernel requires groups > 1") + if self.c_in % self.groups != 0 or self.c_out % self.groups != 0: + raise ValueError( + f"GroupConv2dKernel requires c_in and c_out divisible by groups; " + f"got c_in={self.c_in}, c_out={self.c_out}, groups={self.groups}" + ) + @property def default_config(self) -> dict: sm_version = get_sm_version() - if sm_version in {80}: + if sm_version in {90}: return { "block_m": 64, "block_n": 64, "block_k": 64, - "num_stages": 1, - "threads": 128, - "enable_rasterization": True, - } - if sm_version in {90}: - return { - "block_m": 64, - "block_n": 128, - "block_k": 128, - "num_stages": 2, + "num_stages": 3, "threads": 128, - "enable_rasterization": True, + "enable_rasterization": False, } return { "block_m": 64, "block_n": 64, "block_k": 64, - "num_stages": 1, "threads": 128, + "num_stages": 2, "enable_rasterization": True, } @@ -1658,7 +1910,7 @@ def default_config(self) -> dict: def autotune_configs(self) -> list[dict]: shared_memory_limit_bytes = get_shared_memory_limit_bytes() configs = itertools.product( - [64, 128, 256], + [32, 64, 128], [64, 128, 256], [32, 64, 128], [2, 3], @@ -1689,9 +1941,152 @@ def forward( ) -> torch.Tensor: if bias is None: bias = torch.zeros(self.c_out, device=x.device, dtype=x.dtype) - # OIHW -> OC,IC since the 1x1 kernel consumes a dense [C_out, C_in] weight matrix. - weight_oc_ci = weight.view(self.c_out, self.c_in).contiguous() - return _conv2d_1x1_wrapped_kernel( + return _conv2d_group_wrapped_kernel( + self.n, + self.c_in, + self.h, + self.w, + self.c_out, + self.kernel_h, + self.kernel_w, + self.stride_h, + self.stride_w, + self.pad_h, + self.pad_w, + self.dilation_h, + self.dilation_w, + self.has_bias, + self.dtype_str, + self.groups, + self.c_in_g, + self.c_out_g, + self.config["block_m"], + self.config["block_n"], + self.config["block_k"], + self.config["num_stages"], + self.config["threads"], + self.config["enable_rasterization"], + x, + weight, + bias, + ) + + +class Conv2d1x1Kernel(Kernel): + supported_archs: list[int] = [80, 86, 89, 90] + + def __init__( + self, + n: int, + c_in: int, + h: int, + w: int, + c_out: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dtype: torch.dtype, + has_bias: bool = False, + config: Optional[dict] = None, + tune: bool = False, + ) -> None: + super().__init__() + self.n = n + self.c_in = c_in + self.h = h + self.w = w + self.c_out = c_out + self.stride_h = stride_h + self.stride_w = stride_w + self.pad_h = pad_h + self.pad_w = pad_w + self.dtype = dtype + self.has_bias = has_bias + + self.kernel = _conv2d_1x1_kernel( + n, + c_in, + h, + w, + c_out, + stride_h, + stride_w, + pad_h, + pad_w, + has_bias, + self.dtype_str, + ) + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + sm_version = get_sm_version() + if sm_version in {80}: + return { + "block_m": 64, + "block_n": 64, + "block_k": 64, + "num_stages": 1, + "threads": 128, + "enable_rasterization": True, + } + if sm_version in {90}: + return { + "block_m": 64, + "block_n": 128, + "block_k": 128, + "num_stages": 2, + "threads": 128, + "enable_rasterization": True, + } + return { + "block_m": 64, + "block_n": 64, + "block_k": 64, + "num_stages": 1, + "threads": 128, + "enable_rasterization": True, + } + + @property + def autotune_configs(self) -> list[dict]: + shared_memory_limit_bytes = get_shared_memory_limit_bytes() + configs = itertools.product( + [64, 128, 256], + [64, 128, 256], + [32, 64, 128], + [2, 3], + [128, 256], + [True], + ) + valid_configs = [] + for block_m, block_n, block_k, num_stages, threads, enable_rasterization in configs: + shared_memory_bytes = conv_shared_memory_bytes( + block_m, block_n, block_k, num_stages, self.dtype) + if shared_memory_bytes > shared_memory_limit_bytes: + continue + valid_configs.append({ + "block_m": block_m, + "block_n": block_n, + "block_k": block_k, + "num_stages": num_stages, + "threads": threads, + "enable_rasterization": enable_rasterization, + }) + return valid_configs + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if bias is None: + bias = torch.zeros(self.c_out, device=x.device, dtype=x.dtype) + # OIHW -> OC,IC since the 1x1 kernel consumes a dense [C_out, C_in] weight matrix. + weight_oc_ci = weight.view(self.c_out, self.c_in).contiguous() + return _conv2d_1x1_wrapped_kernel( self.n, self.c_in, self.h, @@ -1841,8 +2236,214 @@ def _conv3d_main( return _conv3d_func -@torch.library.custom_op("top::conv3d_wrapped_kernel", mutates_args=()) -def _conv3d_wrapped_kernel( +@functools.lru_cache(maxsize=64) +def _conv3d_group_kernel( + n: int, + c_in: int, + d_in: int, + h_in: int, + w_in: int, + c_out: int, + kernel_d: int, + kernel_h: int, + kernel_w: int, + stride_d: int, + stride_h: int, + stride_w: int, + pad_d: int, + pad_h: int, + pad_w: int, + dilation_d: int, + dilation_h: int, + dilation_w: int, + has_bias: bool, + dtype: str = "float16", + groups: int = 1, + c_in_g: int = 0, + c_out_g: int = 0, +): + accum_dtype = "float" + out_d = (d_in + 2 * pad_d - dilation_d * (kernel_d - 1) - 1) // stride_d + 1 + out_h = (h_in + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1 + out_w = (w_in + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1 + out_dhw = out_d * out_h * out_w + c_in_g = c_in_g if c_in_g > 0 else c_in // groups + c_out_g = c_out_g if c_out_g > 0 else c_out // groups + k_total = kernel_d * kernel_h * kernel_w * c_in_g + + @tilelang.jit( + out_idx=[2], + compile_flags=["-O3", "-DENABLE_BF16"], + pass_configs={"tl.enable_async_copy": False}, + ) + def _conv3d_group_func( + block_m: int, + block_n: int, + block_k: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + ): + @T.prim_func + def _conv3d_group_main( + x: T.Tensor((n, c_in, d_in, h_in, w_in), dtype), # type: ignore + weight: T.Tensor((c_out, c_in_g, kernel_d, kernel_h, kernel_w), dtype), # type: ignore + out: T.Tensor((n, c_out, out_d, out_h, out_w), dtype), # type: ignore + bias: T.Tensor((c_out,), dtype), # type: ignore + ): + with T.Kernel( + T.ceildiv(out_dhw, block_n), + T.ceildiv(c_out_g, block_m), + n * groups, + threads=threads, + ) as (bx, by, bz): + weight_shared = T.alloc_shared((block_m, block_k), dtype) + data_shared = T.alloc_shared((block_k, block_n), dtype) + out_local = T.alloc_fragment((block_m, block_n), accum_dtype) + out_shared = T.alloc_shared((block_m, block_n), dtype) + + T.use_swizzle(10, enable=enable_rasterization) + T.clear(out_local) + + batch_id = bz // groups + group_id = bz % groups + + for k_iter in T.Pipelined(T.ceildiv(k_total, block_k), num_stages=num_stages): + for i, k in T.Parallel(block_m, block_k): + oc_g = by * block_m + i + oc = group_id * c_out_g + oc_g + k_idx = k_iter * block_k + k + ci_g = k_idx // (kernel_d * kernel_h * kernel_w) + kernel_idx = k_idx % (kernel_d * kernel_h * kernel_w) + kd = kernel_idx // (kernel_h * kernel_w) + kh = (kernel_idx // kernel_w) % kernel_h + kw = kernel_idx % kernel_w + weight_shared[i, k] = T.if_then_else( + (oc_g < c_out_g) & (k_idx < k_total), + weight[oc, ci_g, kd, kh, kw], + T.cast(0.0, dtype), + ) + + for k, j in T.Parallel(block_k, block_n): + k_idx = k_iter * block_k + k + spatial_idx = bx * block_n + j + ci_g = k_idx // (kernel_d * kernel_h * kernel_w) + ci = group_id * c_in_g + ci_g + kernel_idx = k_idx % (kernel_d * kernel_h * kernel_w) + kd = kernel_idx // (kernel_h * kernel_w) + kh = (kernel_idx // kernel_w) % kernel_h + kw = kernel_idx % kernel_w + od = spatial_idx // (out_h * out_w) + oh = (spatial_idx // out_w) % out_h + ow = spatial_idx % out_w + id_ = od * stride_d + kd * dilation_d - pad_d + ih = oh * stride_h + kh * dilation_h - pad_h + iw = ow * stride_w + kw * dilation_w - pad_w + data_shared[k, j] = T.if_then_else( + (spatial_idx < out_dhw) + & (k_idx < k_total) + & (id_ >= 0) + & (ih >= 0) + & (iw >= 0) + & (id_ < d_in) + & (ih < h_in) + & (iw < w_in), + x[batch_id, ci, id_, ih, iw], + T.cast(0.0, dtype), + ) + + T.gemm(weight_shared, data_shared, out_local) + + for i, j in T.Parallel(block_m, block_n): + oc_g = by * block_m + i + oc = group_id * c_out_g + oc_g + spatial_idx = bx * block_n + j + if has_bias: + out_shared[i, j] = T.if_then_else( + (oc_g < c_out_g) & (spatial_idx < out_dhw), + T.cast(out_local[i, j] + T.cast(bias[oc], accum_dtype), dtype), + T.cast(0.0, dtype), + ) + else: + out_shared[i, j] = T.if_then_else( + (oc_g < c_out_g) & (spatial_idx < out_dhw), + T.cast(out_local[i, j], dtype), + T.cast(0.0, dtype), + ) + + for i, j in T.Parallel(block_m, block_n): + oc_g = by * block_m + i + oc = group_id * c_out_g + oc_g + spatial_idx = bx * block_n + j + od = spatial_idx // (out_h * out_w) + oh = (spatial_idx // out_w) % out_h + ow = spatial_idx % out_w + if oc_g < c_out_g and spatial_idx < out_dhw: + out[batch_id, oc, od, oh, ow] = out_shared[i, j] + + return _conv3d_group_main + + return _conv3d_group_func + + +@torch.library.custom_op("top::conv3d_wrapped_kernel", mutates_args=()) +def _conv3d_wrapped_kernel( + n: int, + c_in: int, + d_in: int, + h_in: int, + w_in: int, + c_out: int, + kernel_d: int, + kernel_h: int, + kernel_w: int, + stride_d: int, + stride_h: int, + stride_w: int, + pad_d: int, + pad_h: int, + pad_w: int, + dilation_d: int, + dilation_h: int, + dilation_w: int, + has_bias: bool, + dtype: str, + block_m: int, + block_n: int, + block_k: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return _conv3d_kernel( + n, + c_in, + d_in, + h_in, + w_in, + c_out, + kernel_d, + kernel_h, + kernel_w, + stride_d, + stride_h, + stride_w, + pad_d, + pad_h, + pad_w, + dilation_d, + dilation_h, + dilation_w, + has_bias, + dtype, + )(block_m, block_n, block_k, num_stages, threads, enable_rasterization)(x, weight, bias) + + +@torch.library.custom_op("top::conv3d_group_wrapped_kernel", mutates_args=()) +def _conv3d_group_wrapped_kernel( n: int, c_in: int, d_in: int, @@ -1863,6 +2464,9 @@ def _conv3d_wrapped_kernel( dilation_w: int, has_bias: bool, dtype: str, + groups: int, + c_in_g: int, + c_out_g: int, block_m: int, block_n: int, block_k: int, @@ -1873,7 +2477,7 @@ def _conv3d_wrapped_kernel( weight: torch.Tensor, bias: torch.Tensor, ) -> torch.Tensor: - return _conv3d_kernel( + return _conv3d_group_kernel( n, c_in, d_in, @@ -1894,6 +2498,9 @@ def _conv3d_wrapped_kernel( dilation_w, has_bias, dtype, + groups, + c_in_g, + c_out_g, )(block_m, block_n, block_k, num_stages, threads, enable_rasterization)(x, weight, bias) @@ -1937,6 +2544,50 @@ def _( ) +@_conv3d_group_wrapped_kernel.register_fake +def _( + n: int, + c_in: int, + d_in: int, + h_in: int, + w_in: int, + c_out: int, + kernel_d: int, + kernel_h: int, + kernel_w: int, + stride_d: int, + stride_h: int, + stride_w: int, + pad_d: int, + pad_h: int, + pad_w: int, + dilation_d: int, + dilation_h: int, + dilation_w: int, + has_bias: bool, + dtype: str, + groups: int, + c_in_g: int, + c_out_g: int, + block_m: int, + block_n: int, + block_k: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + *inputs: tuple[torch.Tensor, ...], +) -> torch.Tensor: + del groups, c_in_g, c_out_g + out_d = (d_in + 2 * pad_d - dilation_d * (kernel_d - 1) - 1) // stride_d + 1 + out_h = (h_in + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1 + out_w = (w_in + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1 + return torch.empty( + (n, c_out, out_d, out_h, out_w), + dtype=inputs[0].dtype, + device=inputs[0].device, + ) + + class Conv3dKernel(Kernel): supported_archs: list[int] = [80, 86, 89, 90] @@ -2103,3 +2754,193 @@ def forward( weight, bias, ) + + +class GroupConv3dKernel(Kernel): + supported_archs: list[int] = [80, 86, 89, 90] + + def __init__( + self, + n: int, + c_in: int, + d_in: int, + h_in: int, + w_in: int, + c_out: int, + kernel_d: int, + kernel_h: int, + kernel_w: int, + stride_d: int, + stride_h: int, + stride_w: int, + pad_d: int, + pad_h: int, + pad_w: int, + dilation_d: int, + dilation_h: int, + dilation_w: int, + dtype: torch.dtype, + has_bias: bool = False, + groups: int = 1, + c_in_g: Optional[int] = None, + c_out_g: Optional[int] = None, + config: Optional[dict] = None, + tune: bool = False, + ) -> None: + super().__init__() + self.n = n + self.c_in = c_in + self.d_in = d_in + self.h_in = h_in + self.w_in = w_in + self.c_out = c_out + self.kernel_d = kernel_d + self.kernel_h = kernel_h + self.kernel_w = kernel_w + self.stride_d = stride_d + self.stride_h = stride_h + self.stride_w = stride_w + self.pad_d = pad_d + self.pad_h = pad_h + self.pad_w = pad_w + self.dilation_d = dilation_d + self.dilation_h = dilation_h + self.dilation_w = dilation_w + self.groups = groups + self.c_in_g = c_in_g if c_in_g is not None else c_in // groups + self.c_out_g = c_out_g if c_out_g is not None else c_out // groups + self.dtype = dtype + self.has_bias = has_bias + self.out_d = (d_in + 2 * pad_d - dilation_d * (kernel_d - 1) - 1) // stride_d + 1 + self.out_h = (h_in + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1 + self.out_w = (w_in + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1 + self.m = n * self.groups * self.out_d * self.out_h * self.out_w + self.k_total = self.c_in_g * kernel_d * kernel_h * kernel_w + self._validate_group_shape() + + self.kernel = _conv3d_group_kernel( + n, + c_in, + d_in, + h_in, + w_in, + c_out, + kernel_d, + kernel_h, + kernel_w, + stride_d, + stride_h, + stride_w, + pad_d, + pad_h, + pad_w, + dilation_d, + dilation_h, + dilation_w, + has_bias, + self.dtype_str, + groups, + self.c_in_g, + self.c_out_g, + ) + self.init_config(config, tune) + + def _validate_group_shape(self) -> None: + if self.groups <= 1: + raise ValueError("GroupConv3dKernel requires groups > 1") + if self.c_in % self.groups != 0 or self.c_out % self.groups != 0: + raise ValueError( + f"GroupConv3dKernel requires c_in and c_out divisible by groups; " + f"got c_in={self.c_in}, c_out={self.c_out}, groups={self.groups}" + ) + + @property + def default_config(self) -> dict: + sm_version = get_sm_version() + if sm_version in {90}: + return { + "block_m": 64, + "block_n": 64, + "block_k": 64, + "num_stages": 3, + "threads": 128, + "enable_rasterization": True, + } + return { + "block_m": 64, + "block_n": 64, + "block_k": 64, + "num_stages": 2, + "threads": 128, + "enable_rasterization": True, + } + + @property + def autotune_configs(self) -> list[dict]: + shared_memory_limit_bytes = get_shared_memory_limit_bytes() + configs = itertools.product( + [32, 64, 128], + [32, 64, 128], + [32, 64, 128], + [2, 3], + [128, 256], + [True], + ) + valid_configs = [] + for block_m, block_n, block_k, num_stages, threads, enable_rasterization in configs: + shared_memory_bytes = conv_shared_memory_bytes( + block_m, block_n, block_k, num_stages, self.dtype) + if shared_memory_bytes > shared_memory_limit_bytes: + continue + valid_configs.append({ + "block_m": block_m, + "block_n": block_n, + "block_k": block_k, + "num_stages": num_stages, + "threads": threads, + "enable_rasterization": enable_rasterization, + }) + return valid_configs + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if bias is None: + bias = torch.zeros(self.c_out, device=x.device, dtype=x.dtype) + return _conv3d_group_wrapped_kernel( + self.n, + self.c_in, + self.d_in, + self.h_in, + self.w_in, + self.c_out, + self.kernel_d, + self.kernel_h, + self.kernel_w, + self.stride_d, + self.stride_h, + self.stride_w, + self.pad_d, + self.pad_h, + self.pad_w, + self.dilation_d, + self.dilation_h, + self.dilation_w, + self.has_bias, + self.dtype_str, + self.groups, + self.c_in_g, + self.c_out_g, + self.config["block_m"], + self.config["block_n"], + self.config["block_k"], + self.config["num_stages"], + self.config["threads"], + self.config["enable_rasterization"], + x, + weight, + bias, + ) diff --git a/tileops/ops/convolution.py b/tileops/ops/convolution.py index 6a64fbc3d..5dcdf99c7 100644 --- a/tileops/ops/convolution.py +++ b/tileops/ops/convolution.py @@ -9,6 +9,8 @@ Conv2dKernel, Conv3dKernel, GroupConv1dKernel, + GroupConv2dKernel, + GroupConv3dKernel, ) from tileops.kernels.kernel_base import Kernel @@ -501,13 +503,14 @@ def __init__( _validate_positive_int("w", w, "Conv2d") _validate_positive_int("c_out", c_out, "Conv2d") _validate_conv_groups("Conv2d", c_in, c_out, groups) - if groups != 1: - raise NotImplementedError("Conv2d currently supports groups=1 only") self.n = n self.c_in = c_in self.h = h self.w = w self.c_out = c_out + self.groups = groups + self.c_in_g = c_in // groups + self.c_out_g = c_out // groups self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) dilation_tuple = _conv_tuple(dilation, 2, "dilation", "Conv2d") @@ -523,7 +526,6 @@ def __init__( dilation=dilation_tuple, ) self.dilation = dilation_tuple - self.groups = groups self.has_bias = _has_bias self.dtype = dtype @@ -543,14 +545,15 @@ def __init__( tune=tune, ) if ( - self.kernel_size == (1, 1) + self.groups == 1 + and self.kernel_size == (1, 1) and self.stride == (1, 1) and self.padding == (0, 0) and self.dilation == (1, 1) and "conv2d_1x1_kernel" in self.kernel_map ): self.kernel = self.kernel_map["conv2d_1x1_kernel"](**kernel_kwargs) - elif "conv2d_kernel" in self.kernel_map: + elif self.groups == 1 and "conv2d_kernel" in self.kernel_map: self.kernel = self.kernel_map["conv2d_kernel"]( **kernel_kwargs, kernel_h=self.kernel_size[0], @@ -558,9 +561,21 @@ def __init__( dilation_h=self.dilation[0], dilation_w=self.dilation[1], ) + elif self.groups > 1 and "group_conv2d_kernel" in self.kernel_map: + self.kernel = self.kernel_map["group_conv2d_kernel"]( + **kernel_kwargs, + kernel_h=self.kernel_size[0], + kernel_w=self.kernel_size[1], + dilation_h=self.dilation[0], + dilation_w=self.dilation[1], + groups=self.groups, + c_in_g=self.c_in_g, + c_out_g=self.c_out_g, + ) else: raise NotImplementedError( - "Conv2dFwdOp requires 'conv2d_1x1_kernel' or 'conv2d_kernel' in kernel_map" + "Conv2dFwdOp requires 'conv2d_1x1_kernel', 'conv2d_kernel', " + "or 'group_conv2d_kernel' in kernel_map" ) @property @@ -568,6 +583,7 @@ def default_kernel_map(self) -> Dict[str, Kernel]: return { "conv2d_1x1_kernel": Conv2d1x1Kernel, "conv2d_kernel": Conv2dKernel, + "group_conv2d_kernel": GroupConv2dKernel, } def forward( @@ -580,7 +596,7 @@ def forward( "Conv2d", "weight", weight, - (self.c_out, self.c_in, self.kernel_size[0], self.kernel_size[1]), + (self.c_out, self.c_in_g, self.kernel_size[0], self.kernel_size[1]), ) return self.kernel(input, weight, None) @@ -642,7 +658,7 @@ def forward( "Conv2d", "weight", weight, - (self.c_out, self.c_in, self.kernel_size[0], self.kernel_size[1]), + (self.c_out, self.c_in_g, self.kernel_size[0], self.kernel_size[1]), ) _validate_tensor_shape("Conv2d", "bias", bias, (self.c_out,)) return self.kernel(input, weight, bias) @@ -679,14 +695,15 @@ def __init__( _validate_positive_int("w", w, "Conv3d") _validate_positive_int("c_out", c_out, "Conv3d") _validate_conv_groups("Conv3d", c_in, c_out, groups) - if groups != 1: - raise NotImplementedError("Conv3d currently supports groups=1 only") self.n = n self.c_in = c_in self.d = d self.h = h self.w = w self.c_out = c_out + self.groups = groups + self.c_in_g = c_in // groups + self.c_out_g = c_out // groups self.kernel_size = _triple(kernel_size) self.stride = _triple(stride) dilation_tuple = _conv_tuple(dilation, 3, "dilation", "Conv3d") @@ -702,14 +719,11 @@ def __init__( dilation=dilation_tuple, ) self.dilation = dilation_tuple - self.groups = groups self.has_bias = _has_bias self.dtype = dtype self.dispatch_kernel(kernel_map) - if "conv3d_kernel" not in self.kernel_map: - raise NotImplementedError("Conv3dFwdOp requires 'conv3d_kernel' in kernel_map") - self.kernel = self.kernel_map["conv3d_kernel"]( + kernel_kwargs = dict( n=n, c_in=c_in, d_in=d, @@ -732,10 +746,26 @@ def __init__( has_bias=_has_bias, tune=tune, ) + if self.groups == 1 and "conv3d_kernel" in self.kernel_map: + self.kernel = self.kernel_map["conv3d_kernel"](**kernel_kwargs) + elif self.groups > 1 and "group_conv3d_kernel" in self.kernel_map: + self.kernel = self.kernel_map["group_conv3d_kernel"]( + **kernel_kwargs, + groups=self.groups, + c_in_g=self.c_in_g, + c_out_g=self.c_out_g, + ) + else: + raise NotImplementedError( + "Conv3dFwdOp requires 'conv3d_kernel' or 'group_conv3d_kernel' in kernel_map" + ) @property def default_kernel_map(self) -> Dict[str, Kernel]: - return {"conv3d_kernel": Conv3dKernel} + return { + "conv3d_kernel": Conv3dKernel, + "group_conv3d_kernel": GroupConv3dKernel, + } def forward( self, @@ -754,7 +784,7 @@ def forward( weight, ( self.c_out, - self.c_in, + self.c_in_g, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2], @@ -829,7 +859,7 @@ def forward( weight, ( self.c_out, - self.c_in, + self.c_in_g, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2],