From 6d8be4d2a4f44fda49c8ef502b6fe4ce3e071eef Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Tue, 7 Apr 2026 14:44:53 +0800 Subject: [PATCH 1/9] [FEAT][POOL] add max_pool2d operator --- benchmarks/ops/bench_max_pool2d.py | 135 +++++++++++++ tests/ops/test_max_pool2d.py | 292 +++++++++++++++++++++++++++++ tileops/kernels/__init__.py | 3 +- tileops/kernels/pool/__init__.py | 3 +- tileops/kernels/pool/common.py | 22 ++- tileops/kernels/pool/max_pool2d.py | 259 +++++++++++++++++++++++++ tileops/ops/__init__.py | 2 + tileops/ops/max_pool2d.py | 100 ++++++++++ 8 files changed, 811 insertions(+), 5 deletions(-) create mode 100644 benchmarks/ops/bench_max_pool2d.py create mode 100644 tests/ops/test_max_pool2d.py create mode 100644 tileops/kernels/pool/max_pool2d.py create mode 100644 tileops/ops/max_pool2d.py diff --git a/benchmarks/ops/bench_max_pool2d.py b/benchmarks/ops/bench_max_pool2d.py new file mode 100644 index 000000000..ad6174200 --- /dev/null +++ b/benchmarks/ops/bench_max_pool2d.py @@ -0,0 +1,135 @@ +from typing import Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from benchmarks.benchmark import BenchmarkBase, BenchmarkReport +from tileops.kernels.pool.common import pool_output_dim +from tileops.ops import MaxPool2dOp + + +class MaxPool2dBenchCase: + def __init__( + self, + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + ceil_mode: bool, + dtype: torch.dtype, + ) -> None: + self.n = n + self.c_in = c_in + self.h_in = h_in + self.w_in = w_in + self.kernel_size = kernel_size + self.stride = kernel_size if stride is None else stride + self.padding = padding + self.dilation = dilation + self.ceil_mode = ceil_mode + self.dtype = dtype + + def gen_inputs(self) -> tuple[torch.Tensor]: + x = torch.randn(self.n, self.h_in, self.w_in, self.c_in, device="cuda", dtype=self.dtype).contiguous() + return (x,) + + def ref_program(self, x: torch.Tensor) -> torch.Tensor: + return F.max_pool2d( + x, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ceil_mode=self.ceil_mode, + ) + + +class MaxPool2dBenchmark(BenchmarkBase): + def calculate_flops(self) -> Optional[float]: + t = self.test + out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) + out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) + return t.n * t.c_in * out_h * out_w * t.kernel_size[0] * t.kernel_size[1] + + def calculate_memory(self) -> Optional[float]: + t = self.test + out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) + out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) + return (t.n * t.c_in * t.h_in * t.w_in + t.n * t.c_in * out_h * out_w) * t.dtype.itemsize + + +_MAX_POOL2D_BASE_CASES = [ + (2, 64, 112, 112, (3, 3), (2, 2), (1, 1), (1, 1), False, "vision-3x3-s2"), + (2, 128, 56, 56, (5, 5), (2, 2), (2, 2), (1, 1), False, "vision-5x5-s2"), + (3, 96, 55, 57, (3, 3), (2, 2), (1, 1), (2, 1), True, "ceil-dilation-nonpow2"), +] + +_MAX_POOL2D_BENCH_PARAMS = [ + pytest.param(*case[:-1], dtype, True, id=f"{case[-1]}-{str(dtype).split('.')[-1]}") + for case in _MAX_POOL2D_BASE_CASES + for dtype in (torch.float16, torch.bfloat16) +] + + +@pytest.mark.parametrize( + "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, ceil_mode, dtype, tune", + _MAX_POOL2D_BENCH_PARAMS, +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_bench( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + ceil_mode: bool, + dtype: torch.dtype, + tune: bool, +) -> None: + test = MaxPool2dBenchCase( + n, + c_in, + h_in, + w_in, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + dtype, + ) + bm = MaxPool2dBenchmark(test) + inputs = test.gen_inputs() + (x,) = inputs + x_nchw = x.permute(0, 3, 1, 2).contiguous() + + op = MaxPool2dOp( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + dtype=dtype, + tune=tune, + ) + result = bm.profile(op, *inputs) + BenchmarkReport.record(op, locals(), result, tag="tileops") + + result_bl = bm.profile(test.ref_program, x_nchw) + BenchmarkReport.record(op, locals(), result_bl, tag="torch") + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_max_pool2d.py b/tests/ops/test_max_pool2d.py new file mode 100644 index 000000000..150ed9028 --- /dev/null +++ b/tests/ops/test_max_pool2d.py @@ -0,0 +1,292 @@ +from typing import Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from tests.test_base import FixtureBase, TestBase +from tileops.kernels.kernel import Kernel +from tileops.kernels.pool import MaxPool2dKernel +from tileops.ops import MaxPool2dOp + + +class _DummyKernel(Kernel): + supported_archs = [80] + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return x, torch.zeros_like(x, dtype=torch.int64) + + +class MaxPool2dFixture(FixtureBase): + PARAMS = [ + ( + "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, return_indices, ceil_mode, dtype, tune", + [ + pytest.param( + 2, 64, 56, 56, (3, 3), None, (1, 1), (1, 1), False, False, torch.float16, False, + marks=[pytest.mark.smoke, pytest.mark.packaging], + id="smoke-3x3-default-stride-fp16", + ), + pytest.param( + 1, 96, 29, 31, (3, 5), (2, 2), (1, 2), (1, 1), False, True, torch.float16, False, + marks=pytest.mark.full, + id="full-ceil-nonpow2-fp16", + ), + pytest.param( + 1, 80, 28, 30, (3, 3), (2, 2), (1, 1), (2, 1), False, False, torch.bfloat16, False, + marks=pytest.mark.full, + id="full-dilation-bf16", + ), + pytest.param( + 1, 32, 16, 18, (2, 3), (2, 2), (0, 1), (1, 1), True, False, torch.float16, False, + marks=pytest.mark.full, + id="full-return-indices-fp16", + ), + ], + ), + ] + + +class MaxPool2dTest(TestBase): + def __init__( + self, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + return_indices: bool, + ceil_mode: bool, + dtype: torch.dtype, + ) -> None: + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.return_indices = return_indices + self.ceil_mode = ceil_mode + self.dtype = dtype + + def gen_inputs(self, n: int, c_in: int, h_in: int, w_in: int) -> tuple[torch.Tensor]: + x = torch.randn(n, h_in, w_in, c_in, device="cuda", dtype=self.dtype).contiguous() + return (x,) + + def ref_program(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + out = F.max_pool2d( + x.permute(0, 3, 1, 2).contiguous(), + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + return_indices=self.return_indices, + ceil_mode=self.ceil_mode, + ) + if self.return_indices: + values, indices = out + return ( + values.permute(0, 2, 3, 1).contiguous(), + indices.permute(0, 2, 3, 1).contiguous(), + ) + return out.permute(0, 2, 3, 1).contiguous() + + +@MaxPool2dFixture +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + return_indices: bool, + ceil_mode: bool, + dtype: torch.dtype, + tune: bool, +) -> None: + test = MaxPool2dTest( + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + dtype=dtype, + ) + op = MaxPool2dOp( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + dtype=dtype, + tune=tune, + ) + atol = 1e-3 if dtype == torch.float16 else 1.6e-2 + rtol = 1e-3 if dtype == torch.float16 else 1.6e-2 + test.check(op, *test.gen_inputs(n, c_in, h_in, w_in), atol=atol, rtol=rtol) + + +@pytest.mark.smoke +def test_max_pool2d_dispatches_kernel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=32, + h_in=28, + w_in=28, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + assert isinstance(op.kernel, MaxPool2dKernel) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_returns_indices_when_requested(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + return_indices=True, + kernel_map={"max_pool2d_kernel": _DummyKernel}, + ) + x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) + values, indices = op(x) + assert values is x + assert indices.dtype == torch.int64 + assert indices.shape == x.shape + + +@pytest.mark.smoke +def test_max_pool2d_rejects_non_positive_dilation() -> None: + with pytest.raises(ValueError, match="dilation must be greater than zero"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + dilation=(1, 0), + ) + + +@pytest.mark.smoke +def test_max_pool2d_rejects_invalid_padding_for_effective_kernel() -> None: + with pytest.raises(ValueError, match="padding must be at most half"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + padding=(3, 1), + dilation=(2, 1), + ) + + +@pytest.mark.smoke +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"dilation": True}, "dilation must be an int or a tuple of 2 ints"), + ({"dilation": (1, True)}, "dilation must contain only ints"), + ({"kernel_size": True}, "kernel_size must be an int or a tuple of 2 ints"), + ({"stride": True}, "stride must be an int or a tuple of 2 ints"), + ({"padding": True}, "padding must be an int or a tuple of 2 ints"), + ], +) +def test_max_pool2d_rejects_invalid_param_types(kwargs: dict[str, object], match: str) -> None: + base_kwargs = { + "n": 1, + "c_in": 8, + "h_in": 16, + "w_in": 16, + "kernel_size": (3, 3), + } + base_kwargs.update(kwargs) + with pytest.raises((TypeError, ValueError), match=match): + MaxPool2dOp(**base_kwargs) + + +@pytest.mark.smoke +def test_max_pool2d_rejects_unsupported_dtype(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + with pytest.raises(ValueError, match="only supports dtypes"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + dtype=torch.float32, + ) + + +@pytest.mark.smoke +def test_max_pool2d_forward_rejects_non_cuda_input(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyKernel}, + ) + x = torch.randn(1, 8, 8, 4) + with pytest.raises(ValueError, match="CUDA"): + op(x) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_forward_rejects_nchw_shape(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyKernel}, + ) + x = torch.randn(1, 4, 8, 8, device="cuda", dtype=torch.float16) + with pytest.raises(ValueError, match="NHWC"): + op(x) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_forward_warns_on_ambiguous_nhwc_shape(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=8, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyKernel}, + ) + x = torch.randn(1, 8, 8, 8, device="cuda", dtype=torch.float16) + with pytest.warns(UserWarning, match="ambiguous NHWC shape"): + out = op(x) + assert out is x + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tileops/kernels/__init__.py b/tileops/kernels/__init__.py index 1c1e8645a..18d38bb9f 100644 --- a/tileops/kernels/__init__.py +++ b/tileops/kernels/__init__.py @@ -58,7 +58,7 @@ LayerNormKernel, RmsNormKernel, ) -from .pool import AvgPool1dKernel, AvgPool2dKernel, AvgPool3dKernel +from .pool import AvgPool1dKernel, AvgPool2dKernel, AvgPool3dKernel, MaxPool2dKernel from .rope import ( RopeLlama31Kernel, RopeLongRopeKernel, @@ -72,6 +72,7 @@ "AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel", + "MaxPool2dKernel", "BatchNormBwdKernel", "BatchNormFwdInferKernel", "BatchNormFwdTrainKernel", diff --git a/tileops/kernels/pool/__init__.py b/tileops/kernels/pool/__init__.py index bf026df3d..98fbb94fa 100644 --- a/tileops/kernels/pool/__init__.py +++ b/tileops/kernels/pool/__init__.py @@ -1,5 +1,6 @@ from .avg_pool1d import AvgPool1dKernel from .avg_pool2d import AvgPool2dKernel from .avg_pool3d import AvgPool3dKernel +from .max_pool2d import MaxPool2dKernel -__all__ = ["AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel"] +__all__ = ["AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel", "MaxPool2dKernel"] diff --git a/tileops/kernels/pool/common.py b/tileops/kernels/pool/common.py index 9a4bf20d9..fef3073c4 100644 --- a/tileops/kernels/pool/common.py +++ b/tileops/kernels/pool/common.py @@ -40,10 +40,13 @@ def validate_pool_params( kernel_size: tuple[int, ...], stride: tuple[int, ...], padding: tuple[int, ...], + dilation: tuple[int, ...] | None = None, divisor_override: int | None = None, ) -> None: if len(kernel_size) != ndim or len(stride) != ndim or len(padding) != ndim: raise ValueError("kernel_size, stride, and padding must match pooling dimensionality") + if dilation is not None and len(dilation) != ndim: + raise ValueError("dilation must match pooling dimensionality") for name, values in ( ("kernel_size", kernel_size), @@ -62,7 +65,18 @@ def validate_pool_params( if any(v < 0 for v in padding): raise ValueError("padding must be non-negative") - for pad, kernel in zip(padding, kernel_size, strict=True): + if dilation is not None: + if not all(isinstance(v, int) and not isinstance(v, bool) for v in dilation): + raise TypeError("dilation must contain only ints") + if any(v <= 0 for v in dilation): + raise ValueError("dilation must be greater than zero") + effective_kernel = tuple( + (kernel - 1) * step + 1 for kernel, step in zip(kernel_size, dilation, strict=True) + ) + else: + effective_kernel = kernel_size + + for pad, kernel in zip(padding, effective_kernel, strict=True): if pad > kernel // 2: raise ValueError("padding must be at most half of the effective kernel size") @@ -103,11 +117,13 @@ def pool_output_dim( stride: int, padding: int, ceil_mode: bool, + dilation: int = 1, ) -> int: + effective_kernel = (kernel_size - 1) * dilation + 1 if ceil_mode: - out = (input_size + 2 * padding - kernel_size + stride - 1) // stride + 1 + out = (input_size + 2 * padding - effective_kernel + stride - 1) // stride + 1 else: - out = (input_size + 2 * padding - kernel_size) // stride + 1 + out = (input_size + 2 * padding - effective_kernel) // stride + 1 if ceil_mode and out > 0 and (out - 1) * stride >= input_size + padding: out -= 1 diff --git a/tileops/kernels/pool/max_pool2d.py b/tileops/kernels/pool/max_pool2d.py new file mode 100644 index 000000000..cdc2d424f --- /dev/null +++ b/tileops/kernels/pool/max_pool2d.py @@ -0,0 +1,259 @@ +import functools +import itertools +from typing import Optional + +import tilelang +import tilelang.language as T +import torch + +from tileops.kernels.kernel import Kernel +from tileops.kernels.pool.common import pool_output_dim + +__all__ = ["MaxPool2dKernel"] + +_SUPPORTED_DTYPES = (torch.float16, torch.bfloat16) + + +@functools.lru_cache(maxsize=64) +def _max_pool2d_kernel( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str = "float16", +): + accum_dtype = "float32" + out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + + @tilelang.jit(out_idx=[1, 2], compile_flags=["-O3", "-DENABLE_BF16"]) + def _max_pool2d_func(block_m: int, block_c: int, threads: int): + @T.prim_func + def _max_pool2d_main( + x: T.Tensor((n, h_in, w_in, c_in), dtype), # type: ignore + out: T.Tensor((n, out_h, out_w, c_in), dtype), # type: ignore + out_indices: T.Tensor((n, out_h, out_w, c_in), "int64"), # type: ignore + ): + with T.Kernel( + T.ceildiv(c_in, block_c), + T.ceildiv(n * out_h * out_w, block_m), + threads=threads, + ) as (bx, by): + out_flat = T.Tensor((n * out_h * out_w, c_in), dtype, out.data) + indices_flat = T.Tensor((n * out_h * out_w, c_in), "int64", out_indices.data) + + for i, j in T.Parallel(block_m, block_c): + m_idx = by * block_m + i + c_idx = bx * block_c + j + if m_idx < n * out_h * out_w and c_idx < c_in: + batch = m_idx // (out_h * out_w) + out_idx = m_idx % (out_h * out_w) + oh = out_idx // out_w + ow = out_idx % out_w + max_val = T.alloc_var(T.float32) + max_index = T.alloc_var(T.int64) + max_val = -T.infinity(accum_dtype) + max_index = T.int64(0) + + for kh in T.serial(kernel_h): + for kw in T.serial(kernel_w): + ih = oh * stride_h + kh * dilation_h - pad_h + iw = ow * stride_w + kw * dilation_w - pad_w + if ih >= 0 and ih < h_in and iw >= 0 and iw < w_in: + candidate = T.cast(x[batch, ih, iw, c_idx], accum_dtype) + candidate_index = T.cast(ih * w_in + iw, "int64") + should_update = candidate > max_val + max_val = T.if_then_else(should_update, candidate, max_val) + max_index = T.if_then_else(should_update, candidate_index, max_index) + + out_flat[m_idx, c_idx] = T.cast(max_val, dtype) + indices_flat[m_idx, c_idx] = max_index + + return _max_pool2d_main + + return _max_pool2d_func + + +@torch.library.custom_op("top::max_pool2d_wrapped_kernel", mutates_args=()) +def _max_pool2d_wrapped_kernel( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str, + block_m: int, + block_c: int, + threads: int, + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool2d_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + dtype, + )(block_m, block_c, threads)(x) + + +@_max_pool2d_wrapped_kernel.register_fake +def _( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str, + block_m: int, + block_c: int, + threads: int, + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + _ = (dtype, block_m, block_c, threads) + out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + return ( + torch.empty((n, out_h, out_w, c_in), dtype=x.dtype, device=x.device), + torch.empty((n, out_h, out_w, c_in), dtype=torch.int64, device=x.device), + ) + + +class MaxPool2dKernel(Kernel): + supported_archs: list[int] = [80, 86, 89, 90] + SUPPORTED_DTYPES = _SUPPORTED_DTYPES + + def __init__( + self, + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: torch.dtype, + config: Optional[dict] = None, + tune: bool = False, + ) -> None: + super().__init__() + if self.SUPPORTED_DTYPES is not None and dtype not in self.SUPPORTED_DTYPES: + supported = ", ".join(str(dt) for dt in self.SUPPORTED_DTYPES) + raise ValueError( + f"{self.__class__.__name__} only supports dtypes [{supported}], got {dtype}" + ) + self.n = n + self.c_in = c_in + self.h_in = h_in + self.w_in = w_in + self.kernel_h = kernel_h + self.kernel_w = kernel_w + self.stride_h = stride_h + self.stride_w = stride_w + self.pad_h = pad_h + self.pad_w = pad_w + self.dilation_h = dilation_h + self.dilation_w = dilation_w + self.ceil_mode = ceil_mode + self.dtype = dtype + self.out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + self.out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + + self.kernel = _max_pool2d_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + self.dtype_str, + ) + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_m": 128, + "block_c": 64, + "threads": 128, + } + + @property + def autotune_configs(self) -> list[dict]: + configs = itertools.product([64, 128, 256], [32, 64, 128], [128, 256]) + return [ + { + "block_m": block_m, + "block_c": block_c, + "threads": threads, + } + for block_m, block_c, threads in configs + ] + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool2d_wrapped_kernel( + self.n, + self.c_in, + self.h_in, + self.w_in, + self.kernel_h, + self.kernel_w, + self.stride_h, + self.stride_w, + self.pad_h, + self.pad_w, + self.dilation_h, + self.dilation_w, + self.ceil_mode, + self.dtype_str, + self.config["block_m"], + self.config["block_c"], + self.config["threads"], + x, + ) diff --git a/tileops/ops/__init__.py b/tileops/ops/__init__.py index f7c8a6d50..684b2f425 100644 --- a/tileops/ops/__init__.py +++ b/tileops/ops/__init__.py @@ -1,6 +1,7 @@ from .avg_pool1d import AvgPool1dOp from .avg_pool2d import AvgPool2dOp from .avg_pool3d import AvgPool3dOp +from .max_pool2d import MaxPool2dOp from .conv1d import Conv1dOp from .conv2d import Conv2dOp from .conv3d import Conv3dOp @@ -95,6 +96,7 @@ "AvgPool1dOp", "AvgPool2dOp", "AvgPool3dOp", + "MaxPool2dOp", "AdaLayerNormOp", "AdaLayerNormZeroOp", "BatchNormBwdOp", diff --git a/tileops/ops/max_pool2d.py b/tileops/ops/max_pool2d.py new file mode 100644 index 000000000..609be2204 --- /dev/null +++ b/tileops/ops/max_pool2d.py @@ -0,0 +1,100 @@ +from typing import Dict, Optional, Tuple + +import torch + +from tileops.kernels.kernel import Kernel +from tileops.kernels.pool import MaxPool2dKernel +from tileops.kernels.pool.common import ( + _normalize_pool_dims, + validate_channels_last_input, + validate_pool_params, +) + +from .op import Op + +__all__ = ["MaxPool2dOp"] + + +class MaxPool2dOp(Op): + """Max pooling over channels-last `NHWC` inputs.""" + + def __init__( + self, + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: int | Tuple[int, int], + stride: Optional[int | Tuple[int, int]] = None, + padding: int | Tuple[int, int] = 0, + dilation: int | Tuple[int, int] = 1, + return_indices: bool = False, + ceil_mode: bool = False, + dtype: torch.dtype = torch.float16, + kernel_map: Optional[Dict[str, Kernel]] = None, + tune: bool = False, + ) -> None: + self.n = n + self.c_in = c_in + self.h_in = h_in + self.w_in = w_in + self.kernel_size = _normalize_pool_dims("kernel_size", kernel_size, 2) + self.stride = ( + self.kernel_size + if stride is None + else _normalize_pool_dims("stride", stride, 2) + ) + self.padding = _normalize_pool_dims("padding", padding, 2) + self.dilation = _normalize_pool_dims("dilation", dilation, 2) + self.return_indices = return_indices + self.ceil_mode = ceil_mode + self.dtype = dtype + validate_pool_params( + ndim=2, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ) + + self.dispatch_kernel(kernel_map) + if "max_pool2d_kernel" not in self.kernel_map: + raise NotImplementedError("MaxPool2dOp requires 'max_pool2d_kernel' in kernel_map") + self.kernel = self.kernel_map["max_pool2d_kernel"]( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_h=self.kernel_size[0], + kernel_w=self.kernel_size[1], + stride_h=self.stride[0], + stride_w=self.stride[1], + pad_h=self.padding[0], + pad_w=self.padding[1], + dilation_h=self.dilation[0], + dilation_w=self.dilation[1], + ceil_mode=ceil_mode, + dtype=dtype, + tune=tune, + ) + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"max_pool2d_kernel": MaxPool2dKernel} + + def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if not x.is_cuda: + raise ValueError("Input must be a CUDA tensor") + if x.dtype != self.dtype: + raise ValueError(f"Expected x.dtype {self.dtype}, got {x.dtype}") + validate_channels_last_input( + op_name=type(self).__name__, + x_shape=tuple(x.shape), + expected_shape=(self.n, self.h_in, self.w_in, self.c_in), + layout="NHWC", + ambiguous_layout_shape=(self.n, self.c_in, self.h_in, self.w_in), + ) + values, indices = self.kernel(x) + if self.return_indices: + return values, indices + return values From 7a8647dcea9b8f7bf184aeab674d5828b0152546 Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Tue, 7 Apr 2026 14:44:53 +0800 Subject: [PATCH 2/9] [FEAT][POOL] add max_pool2d operator --- benchmarks/ops/bench_max_pool2d.py | 135 +++++++++++++ tests/ops/test_max_pool2d.py | 292 +++++++++++++++++++++++++++++ tileops/kernels/__init__.py | 3 +- tileops/kernels/pool/__init__.py | 3 +- tileops/kernels/pool/common.py | 22 ++- tileops/kernels/pool/max_pool2d.py | 259 +++++++++++++++++++++++++ tileops/ops/__init__.py | 2 + tileops/ops/max_pool2d.py | 100 ++++++++++ 8 files changed, 811 insertions(+), 5 deletions(-) create mode 100644 benchmarks/ops/bench_max_pool2d.py create mode 100644 tests/ops/test_max_pool2d.py create mode 100644 tileops/kernels/pool/max_pool2d.py create mode 100644 tileops/ops/max_pool2d.py diff --git a/benchmarks/ops/bench_max_pool2d.py b/benchmarks/ops/bench_max_pool2d.py new file mode 100644 index 000000000..ad6174200 --- /dev/null +++ b/benchmarks/ops/bench_max_pool2d.py @@ -0,0 +1,135 @@ +from typing import Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from benchmarks.benchmark import BenchmarkBase, BenchmarkReport +from tileops.kernels.pool.common import pool_output_dim +from tileops.ops import MaxPool2dOp + + +class MaxPool2dBenchCase: + def __init__( + self, + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + ceil_mode: bool, + dtype: torch.dtype, + ) -> None: + self.n = n + self.c_in = c_in + self.h_in = h_in + self.w_in = w_in + self.kernel_size = kernel_size + self.stride = kernel_size if stride is None else stride + self.padding = padding + self.dilation = dilation + self.ceil_mode = ceil_mode + self.dtype = dtype + + def gen_inputs(self) -> tuple[torch.Tensor]: + x = torch.randn(self.n, self.h_in, self.w_in, self.c_in, device="cuda", dtype=self.dtype).contiguous() + return (x,) + + def ref_program(self, x: torch.Tensor) -> torch.Tensor: + return F.max_pool2d( + x, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ceil_mode=self.ceil_mode, + ) + + +class MaxPool2dBenchmark(BenchmarkBase): + def calculate_flops(self) -> Optional[float]: + t = self.test + out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) + out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) + return t.n * t.c_in * out_h * out_w * t.kernel_size[0] * t.kernel_size[1] + + def calculate_memory(self) -> Optional[float]: + t = self.test + out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) + out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) + return (t.n * t.c_in * t.h_in * t.w_in + t.n * t.c_in * out_h * out_w) * t.dtype.itemsize + + +_MAX_POOL2D_BASE_CASES = [ + (2, 64, 112, 112, (3, 3), (2, 2), (1, 1), (1, 1), False, "vision-3x3-s2"), + (2, 128, 56, 56, (5, 5), (2, 2), (2, 2), (1, 1), False, "vision-5x5-s2"), + (3, 96, 55, 57, (3, 3), (2, 2), (1, 1), (2, 1), True, "ceil-dilation-nonpow2"), +] + +_MAX_POOL2D_BENCH_PARAMS = [ + pytest.param(*case[:-1], dtype, True, id=f"{case[-1]}-{str(dtype).split('.')[-1]}") + for case in _MAX_POOL2D_BASE_CASES + for dtype in (torch.float16, torch.bfloat16) +] + + +@pytest.mark.parametrize( + "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, ceil_mode, dtype, tune", + _MAX_POOL2D_BENCH_PARAMS, +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_bench( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + ceil_mode: bool, + dtype: torch.dtype, + tune: bool, +) -> None: + test = MaxPool2dBenchCase( + n, + c_in, + h_in, + w_in, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + dtype, + ) + bm = MaxPool2dBenchmark(test) + inputs = test.gen_inputs() + (x,) = inputs + x_nchw = x.permute(0, 3, 1, 2).contiguous() + + op = MaxPool2dOp( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + dtype=dtype, + tune=tune, + ) + result = bm.profile(op, *inputs) + BenchmarkReport.record(op, locals(), result, tag="tileops") + + result_bl = bm.profile(test.ref_program, x_nchw) + BenchmarkReport.record(op, locals(), result_bl, tag="torch") + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_max_pool2d.py b/tests/ops/test_max_pool2d.py new file mode 100644 index 000000000..150ed9028 --- /dev/null +++ b/tests/ops/test_max_pool2d.py @@ -0,0 +1,292 @@ +from typing import Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from tests.test_base import FixtureBase, TestBase +from tileops.kernels.kernel import Kernel +from tileops.kernels.pool import MaxPool2dKernel +from tileops.ops import MaxPool2dOp + + +class _DummyKernel(Kernel): + supported_archs = [80] + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return x, torch.zeros_like(x, dtype=torch.int64) + + +class MaxPool2dFixture(FixtureBase): + PARAMS = [ + ( + "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, return_indices, ceil_mode, dtype, tune", + [ + pytest.param( + 2, 64, 56, 56, (3, 3), None, (1, 1), (1, 1), False, False, torch.float16, False, + marks=[pytest.mark.smoke, pytest.mark.packaging], + id="smoke-3x3-default-stride-fp16", + ), + pytest.param( + 1, 96, 29, 31, (3, 5), (2, 2), (1, 2), (1, 1), False, True, torch.float16, False, + marks=pytest.mark.full, + id="full-ceil-nonpow2-fp16", + ), + pytest.param( + 1, 80, 28, 30, (3, 3), (2, 2), (1, 1), (2, 1), False, False, torch.bfloat16, False, + marks=pytest.mark.full, + id="full-dilation-bf16", + ), + pytest.param( + 1, 32, 16, 18, (2, 3), (2, 2), (0, 1), (1, 1), True, False, torch.float16, False, + marks=pytest.mark.full, + id="full-return-indices-fp16", + ), + ], + ), + ] + + +class MaxPool2dTest(TestBase): + def __init__( + self, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + return_indices: bool, + ceil_mode: bool, + dtype: torch.dtype, + ) -> None: + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.return_indices = return_indices + self.ceil_mode = ceil_mode + self.dtype = dtype + + def gen_inputs(self, n: int, c_in: int, h_in: int, w_in: int) -> tuple[torch.Tensor]: + x = torch.randn(n, h_in, w_in, c_in, device="cuda", dtype=self.dtype).contiguous() + return (x,) + + def ref_program(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + out = F.max_pool2d( + x.permute(0, 3, 1, 2).contiguous(), + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + return_indices=self.return_indices, + ceil_mode=self.ceil_mode, + ) + if self.return_indices: + values, indices = out + return ( + values.permute(0, 2, 3, 1).contiguous(), + indices.permute(0, 2, 3, 1).contiguous(), + ) + return out.permute(0, 2, 3, 1).contiguous() + + +@MaxPool2dFixture +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + return_indices: bool, + ceil_mode: bool, + dtype: torch.dtype, + tune: bool, +) -> None: + test = MaxPool2dTest( + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + dtype=dtype, + ) + op = MaxPool2dOp( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + dtype=dtype, + tune=tune, + ) + atol = 1e-3 if dtype == torch.float16 else 1.6e-2 + rtol = 1e-3 if dtype == torch.float16 else 1.6e-2 + test.check(op, *test.gen_inputs(n, c_in, h_in, w_in), atol=atol, rtol=rtol) + + +@pytest.mark.smoke +def test_max_pool2d_dispatches_kernel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=32, + h_in=28, + w_in=28, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + assert isinstance(op.kernel, MaxPool2dKernel) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_returns_indices_when_requested(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + return_indices=True, + kernel_map={"max_pool2d_kernel": _DummyKernel}, + ) + x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) + values, indices = op(x) + assert values is x + assert indices.dtype == torch.int64 + assert indices.shape == x.shape + + +@pytest.mark.smoke +def test_max_pool2d_rejects_non_positive_dilation() -> None: + with pytest.raises(ValueError, match="dilation must be greater than zero"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + dilation=(1, 0), + ) + + +@pytest.mark.smoke +def test_max_pool2d_rejects_invalid_padding_for_effective_kernel() -> None: + with pytest.raises(ValueError, match="padding must be at most half"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + padding=(3, 1), + dilation=(2, 1), + ) + + +@pytest.mark.smoke +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"dilation": True}, "dilation must be an int or a tuple of 2 ints"), + ({"dilation": (1, True)}, "dilation must contain only ints"), + ({"kernel_size": True}, "kernel_size must be an int or a tuple of 2 ints"), + ({"stride": True}, "stride must be an int or a tuple of 2 ints"), + ({"padding": True}, "padding must be an int or a tuple of 2 ints"), + ], +) +def test_max_pool2d_rejects_invalid_param_types(kwargs: dict[str, object], match: str) -> None: + base_kwargs = { + "n": 1, + "c_in": 8, + "h_in": 16, + "w_in": 16, + "kernel_size": (3, 3), + } + base_kwargs.update(kwargs) + with pytest.raises((TypeError, ValueError), match=match): + MaxPool2dOp(**base_kwargs) + + +@pytest.mark.smoke +def test_max_pool2d_rejects_unsupported_dtype(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + with pytest.raises(ValueError, match="only supports dtypes"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + dtype=torch.float32, + ) + + +@pytest.mark.smoke +def test_max_pool2d_forward_rejects_non_cuda_input(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyKernel}, + ) + x = torch.randn(1, 8, 8, 4) + with pytest.raises(ValueError, match="CUDA"): + op(x) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_forward_rejects_nchw_shape(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyKernel}, + ) + x = torch.randn(1, 4, 8, 8, device="cuda", dtype=torch.float16) + with pytest.raises(ValueError, match="NHWC"): + op(x) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_forward_warns_on_ambiguous_nhwc_shape(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=8, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyKernel}, + ) + x = torch.randn(1, 8, 8, 8, device="cuda", dtype=torch.float16) + with pytest.warns(UserWarning, match="ambiguous NHWC shape"): + out = op(x) + assert out is x + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tileops/kernels/__init__.py b/tileops/kernels/__init__.py index 1c1e8645a..18d38bb9f 100644 --- a/tileops/kernels/__init__.py +++ b/tileops/kernels/__init__.py @@ -58,7 +58,7 @@ LayerNormKernel, RmsNormKernel, ) -from .pool import AvgPool1dKernel, AvgPool2dKernel, AvgPool3dKernel +from .pool import AvgPool1dKernel, AvgPool2dKernel, AvgPool3dKernel, MaxPool2dKernel from .rope import ( RopeLlama31Kernel, RopeLongRopeKernel, @@ -72,6 +72,7 @@ "AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel", + "MaxPool2dKernel", "BatchNormBwdKernel", "BatchNormFwdInferKernel", "BatchNormFwdTrainKernel", diff --git a/tileops/kernels/pool/__init__.py b/tileops/kernels/pool/__init__.py index bf026df3d..98fbb94fa 100644 --- a/tileops/kernels/pool/__init__.py +++ b/tileops/kernels/pool/__init__.py @@ -1,5 +1,6 @@ from .avg_pool1d import AvgPool1dKernel from .avg_pool2d import AvgPool2dKernel from .avg_pool3d import AvgPool3dKernel +from .max_pool2d import MaxPool2dKernel -__all__ = ["AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel"] +__all__ = ["AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel", "MaxPool2dKernel"] diff --git a/tileops/kernels/pool/common.py b/tileops/kernels/pool/common.py index 9a4bf20d9..fef3073c4 100644 --- a/tileops/kernels/pool/common.py +++ b/tileops/kernels/pool/common.py @@ -40,10 +40,13 @@ def validate_pool_params( kernel_size: tuple[int, ...], stride: tuple[int, ...], padding: tuple[int, ...], + dilation: tuple[int, ...] | None = None, divisor_override: int | None = None, ) -> None: if len(kernel_size) != ndim or len(stride) != ndim or len(padding) != ndim: raise ValueError("kernel_size, stride, and padding must match pooling dimensionality") + if dilation is not None and len(dilation) != ndim: + raise ValueError("dilation must match pooling dimensionality") for name, values in ( ("kernel_size", kernel_size), @@ -62,7 +65,18 @@ def validate_pool_params( if any(v < 0 for v in padding): raise ValueError("padding must be non-negative") - for pad, kernel in zip(padding, kernel_size, strict=True): + if dilation is not None: + if not all(isinstance(v, int) and not isinstance(v, bool) for v in dilation): + raise TypeError("dilation must contain only ints") + if any(v <= 0 for v in dilation): + raise ValueError("dilation must be greater than zero") + effective_kernel = tuple( + (kernel - 1) * step + 1 for kernel, step in zip(kernel_size, dilation, strict=True) + ) + else: + effective_kernel = kernel_size + + for pad, kernel in zip(padding, effective_kernel, strict=True): if pad > kernel // 2: raise ValueError("padding must be at most half of the effective kernel size") @@ -103,11 +117,13 @@ def pool_output_dim( stride: int, padding: int, ceil_mode: bool, + dilation: int = 1, ) -> int: + effective_kernel = (kernel_size - 1) * dilation + 1 if ceil_mode: - out = (input_size + 2 * padding - kernel_size + stride - 1) // stride + 1 + out = (input_size + 2 * padding - effective_kernel + stride - 1) // stride + 1 else: - out = (input_size + 2 * padding - kernel_size) // stride + 1 + out = (input_size + 2 * padding - effective_kernel) // stride + 1 if ceil_mode and out > 0 and (out - 1) * stride >= input_size + padding: out -= 1 diff --git a/tileops/kernels/pool/max_pool2d.py b/tileops/kernels/pool/max_pool2d.py new file mode 100644 index 000000000..cdc2d424f --- /dev/null +++ b/tileops/kernels/pool/max_pool2d.py @@ -0,0 +1,259 @@ +import functools +import itertools +from typing import Optional + +import tilelang +import tilelang.language as T +import torch + +from tileops.kernels.kernel import Kernel +from tileops.kernels.pool.common import pool_output_dim + +__all__ = ["MaxPool2dKernel"] + +_SUPPORTED_DTYPES = (torch.float16, torch.bfloat16) + + +@functools.lru_cache(maxsize=64) +def _max_pool2d_kernel( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str = "float16", +): + accum_dtype = "float32" + out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + + @tilelang.jit(out_idx=[1, 2], compile_flags=["-O3", "-DENABLE_BF16"]) + def _max_pool2d_func(block_m: int, block_c: int, threads: int): + @T.prim_func + def _max_pool2d_main( + x: T.Tensor((n, h_in, w_in, c_in), dtype), # type: ignore + out: T.Tensor((n, out_h, out_w, c_in), dtype), # type: ignore + out_indices: T.Tensor((n, out_h, out_w, c_in), "int64"), # type: ignore + ): + with T.Kernel( + T.ceildiv(c_in, block_c), + T.ceildiv(n * out_h * out_w, block_m), + threads=threads, + ) as (bx, by): + out_flat = T.Tensor((n * out_h * out_w, c_in), dtype, out.data) + indices_flat = T.Tensor((n * out_h * out_w, c_in), "int64", out_indices.data) + + for i, j in T.Parallel(block_m, block_c): + m_idx = by * block_m + i + c_idx = bx * block_c + j + if m_idx < n * out_h * out_w and c_idx < c_in: + batch = m_idx // (out_h * out_w) + out_idx = m_idx % (out_h * out_w) + oh = out_idx // out_w + ow = out_idx % out_w + max_val = T.alloc_var(T.float32) + max_index = T.alloc_var(T.int64) + max_val = -T.infinity(accum_dtype) + max_index = T.int64(0) + + for kh in T.serial(kernel_h): + for kw in T.serial(kernel_w): + ih = oh * stride_h + kh * dilation_h - pad_h + iw = ow * stride_w + kw * dilation_w - pad_w + if ih >= 0 and ih < h_in and iw >= 0 and iw < w_in: + candidate = T.cast(x[batch, ih, iw, c_idx], accum_dtype) + candidate_index = T.cast(ih * w_in + iw, "int64") + should_update = candidate > max_val + max_val = T.if_then_else(should_update, candidate, max_val) + max_index = T.if_then_else(should_update, candidate_index, max_index) + + out_flat[m_idx, c_idx] = T.cast(max_val, dtype) + indices_flat[m_idx, c_idx] = max_index + + return _max_pool2d_main + + return _max_pool2d_func + + +@torch.library.custom_op("top::max_pool2d_wrapped_kernel", mutates_args=()) +def _max_pool2d_wrapped_kernel( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str, + block_m: int, + block_c: int, + threads: int, + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool2d_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + dtype, + )(block_m, block_c, threads)(x) + + +@_max_pool2d_wrapped_kernel.register_fake +def _( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str, + block_m: int, + block_c: int, + threads: int, + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + _ = (dtype, block_m, block_c, threads) + out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + return ( + torch.empty((n, out_h, out_w, c_in), dtype=x.dtype, device=x.device), + torch.empty((n, out_h, out_w, c_in), dtype=torch.int64, device=x.device), + ) + + +class MaxPool2dKernel(Kernel): + supported_archs: list[int] = [80, 86, 89, 90] + SUPPORTED_DTYPES = _SUPPORTED_DTYPES + + def __init__( + self, + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: torch.dtype, + config: Optional[dict] = None, + tune: bool = False, + ) -> None: + super().__init__() + if self.SUPPORTED_DTYPES is not None and dtype not in self.SUPPORTED_DTYPES: + supported = ", ".join(str(dt) for dt in self.SUPPORTED_DTYPES) + raise ValueError( + f"{self.__class__.__name__} only supports dtypes [{supported}], got {dtype}" + ) + self.n = n + self.c_in = c_in + self.h_in = h_in + self.w_in = w_in + self.kernel_h = kernel_h + self.kernel_w = kernel_w + self.stride_h = stride_h + self.stride_w = stride_w + self.pad_h = pad_h + self.pad_w = pad_w + self.dilation_h = dilation_h + self.dilation_w = dilation_w + self.ceil_mode = ceil_mode + self.dtype = dtype + self.out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + self.out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + + self.kernel = _max_pool2d_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + self.dtype_str, + ) + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_m": 128, + "block_c": 64, + "threads": 128, + } + + @property + def autotune_configs(self) -> list[dict]: + configs = itertools.product([64, 128, 256], [32, 64, 128], [128, 256]) + return [ + { + "block_m": block_m, + "block_c": block_c, + "threads": threads, + } + for block_m, block_c, threads in configs + ] + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool2d_wrapped_kernel( + self.n, + self.c_in, + self.h_in, + self.w_in, + self.kernel_h, + self.kernel_w, + self.stride_h, + self.stride_w, + self.pad_h, + self.pad_w, + self.dilation_h, + self.dilation_w, + self.ceil_mode, + self.dtype_str, + self.config["block_m"], + self.config["block_c"], + self.config["threads"], + x, + ) diff --git a/tileops/ops/__init__.py b/tileops/ops/__init__.py index f7c8a6d50..684b2f425 100644 --- a/tileops/ops/__init__.py +++ b/tileops/ops/__init__.py @@ -1,6 +1,7 @@ from .avg_pool1d import AvgPool1dOp from .avg_pool2d import AvgPool2dOp from .avg_pool3d import AvgPool3dOp +from .max_pool2d import MaxPool2dOp from .conv1d import Conv1dOp from .conv2d import Conv2dOp from .conv3d import Conv3dOp @@ -95,6 +96,7 @@ "AvgPool1dOp", "AvgPool2dOp", "AvgPool3dOp", + "MaxPool2dOp", "AdaLayerNormOp", "AdaLayerNormZeroOp", "BatchNormBwdOp", diff --git a/tileops/ops/max_pool2d.py b/tileops/ops/max_pool2d.py new file mode 100644 index 000000000..609be2204 --- /dev/null +++ b/tileops/ops/max_pool2d.py @@ -0,0 +1,100 @@ +from typing import Dict, Optional, Tuple + +import torch + +from tileops.kernels.kernel import Kernel +from tileops.kernels.pool import MaxPool2dKernel +from tileops.kernels.pool.common import ( + _normalize_pool_dims, + validate_channels_last_input, + validate_pool_params, +) + +from .op import Op + +__all__ = ["MaxPool2dOp"] + + +class MaxPool2dOp(Op): + """Max pooling over channels-last `NHWC` inputs.""" + + def __init__( + self, + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: int | Tuple[int, int], + stride: Optional[int | Tuple[int, int]] = None, + padding: int | Tuple[int, int] = 0, + dilation: int | Tuple[int, int] = 1, + return_indices: bool = False, + ceil_mode: bool = False, + dtype: torch.dtype = torch.float16, + kernel_map: Optional[Dict[str, Kernel]] = None, + tune: bool = False, + ) -> None: + self.n = n + self.c_in = c_in + self.h_in = h_in + self.w_in = w_in + self.kernel_size = _normalize_pool_dims("kernel_size", kernel_size, 2) + self.stride = ( + self.kernel_size + if stride is None + else _normalize_pool_dims("stride", stride, 2) + ) + self.padding = _normalize_pool_dims("padding", padding, 2) + self.dilation = _normalize_pool_dims("dilation", dilation, 2) + self.return_indices = return_indices + self.ceil_mode = ceil_mode + self.dtype = dtype + validate_pool_params( + ndim=2, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ) + + self.dispatch_kernel(kernel_map) + if "max_pool2d_kernel" not in self.kernel_map: + raise NotImplementedError("MaxPool2dOp requires 'max_pool2d_kernel' in kernel_map") + self.kernel = self.kernel_map["max_pool2d_kernel"]( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_h=self.kernel_size[0], + kernel_w=self.kernel_size[1], + stride_h=self.stride[0], + stride_w=self.stride[1], + pad_h=self.padding[0], + pad_w=self.padding[1], + dilation_h=self.dilation[0], + dilation_w=self.dilation[1], + ceil_mode=ceil_mode, + dtype=dtype, + tune=tune, + ) + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"max_pool2d_kernel": MaxPool2dKernel} + + def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if not x.is_cuda: + raise ValueError("Input must be a CUDA tensor") + if x.dtype != self.dtype: + raise ValueError(f"Expected x.dtype {self.dtype}, got {x.dtype}") + validate_channels_last_input( + op_name=type(self).__name__, + x_shape=tuple(x.shape), + expected_shape=(self.n, self.h_in, self.w_in, self.c_in), + layout="NHWC", + ambiguous_layout_shape=(self.n, self.c_in, self.h_in, self.w_in), + ) + values, indices = self.kernel(x) + if self.return_indices: + return values, indices + return values From 808ee2ce67d895c5288d8347d3debc73aed88337 Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Tue, 7 Apr 2026 21:27:58 +0800 Subject: [PATCH 3/9] [Fix][POOL] split max_pool2d index path --- benchmarks/ops/bench_max_pool2d.py | 5 +- tests/ops/test_max_pool2d.py | 93 ++++++++++++- tileops/kernels/pool/max_pool2d.py | 216 +++++++++++++++++++++++++---- tileops/ops/max_pool2d.py | 5 +- 4 files changed, 286 insertions(+), 33 deletions(-) diff --git a/benchmarks/ops/bench_max_pool2d.py b/benchmarks/ops/bench_max_pool2d.py index ad6174200..ee96ab710 100644 --- a/benchmarks/ops/bench_max_pool2d.py +++ b/benchmarks/ops/bench_max_pool2d.py @@ -51,13 +51,13 @@ def ref_program(self, x: torch.Tensor) -> torch.Tensor: class MaxPool2dBenchmark(BenchmarkBase): def calculate_flops(self) -> Optional[float]: - t = self.test + t = self.workload out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) return t.n * t.c_in * out_h * out_w * t.kernel_size[0] * t.kernel_size[1] def calculate_memory(self) -> Optional[float]: - t = self.test + t = self.workload out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) return (t.n * t.c_in * t.h_in * t.w_in + t.n * t.c_in * out_h * out_w) * t.dtype.itemsize @@ -120,6 +120,7 @@ def test_max_pool2d_bench( stride=stride, padding=padding, dilation=dilation, + return_indices=False, ceil_mode=ceil_mode, dtype=dtype, tune=tune, diff --git a/tests/ops/test_max_pool2d.py b/tests/ops/test_max_pool2d.py index 150ed9028..73def78a5 100644 --- a/tests/ops/test_max_pool2d.py +++ b/tests/ops/test_max_pool2d.py @@ -10,7 +10,14 @@ from tileops.ops import MaxPool2dOp -class _DummyKernel(Kernel): +class _DummyValuesKernel(Kernel): + supported_archs = [80] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class _DummyValuesIndicesKernel(Kernel): supported_archs = [80] def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -37,6 +44,16 @@ class MaxPool2dFixture(FixtureBase): marks=pytest.mark.full, id="full-dilation-bf16", ), + pytest.param( + 2, 64, 56, 56, (3, 3), (2, 2), (1, 1), (2, 2), False, False, torch.float16, False, + marks=pytest.mark.full, + id="full-dilated-maxpool-2x2-fp16", + ), + pytest.param( + 1, 48, 35, 35, (3, 3), (1, 1), (1, 1), (3, 3), False, False, torch.float16, False, + marks=pytest.mark.full, + id="full-dilated-maxpool-3x3-fp16", + ), pytest.param( 1, 32, 16, 18, (2, 3), (2, 2), (0, 1), (1, 1), True, False, torch.float16, False, marks=pytest.mark.full, @@ -160,7 +177,7 @@ def test_max_pool2d_returns_indices_when_requested(monkeypatch: pytest.MonkeyPat kernel_size=(2, 2), stride=(2, 2), return_indices=True, - kernel_map={"max_pool2d_kernel": _DummyKernel}, + kernel_map={"max_pool2d_kernel": _DummyValuesIndicesKernel}, ) x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) values, indices = op(x) @@ -169,6 +186,72 @@ def test_max_pool2d_returns_indices_when_requested(monkeypatch: pytest.MonkeyPat assert indices.shape == x.shape +@pytest.mark.smoke +def test_max_pool2d_default_path_uses_values_only_kernel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + + def fail_if_called(*args, **kwargs): + raise AssertionError("indices kernel should not be used when return_indices=False") + + def return_values(*args, **kwargs): + x = args[-1] + return x + + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_indices_wrapped_kernel", + fail_if_called, + ) + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_wrapped_kernel", + return_values, + ) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + ) + x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) + out = op(x) + assert out is x + + +@pytest.mark.smoke +def test_max_pool2d_indices_path_uses_values_indices_kernel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + + def fail_if_called(*args, **kwargs): + raise AssertionError("values-only kernel should not be used when return_indices=True") + + def return_values_indices(*args, **kwargs): + x = args[-1] + return x, torch.zeros_like(x, dtype=torch.int64) + + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_wrapped_kernel", + fail_if_called, + ) + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_indices_wrapped_kernel", + return_values_indices, + ) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + return_indices=True, + ) + x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) + values, indices = op(x) + assert values is x + assert indices.dtype == torch.int64 + + @pytest.mark.smoke def test_max_pool2d_rejects_non_positive_dilation() -> None: with pytest.raises(ValueError, match="dilation must be greater than zero"): @@ -244,7 +327,7 @@ def test_max_pool2d_forward_rejects_non_cuda_input(monkeypatch: pytest.MonkeyPat w_in=8, kernel_size=(2, 2), stride=(2, 2), - kernel_map={"max_pool2d_kernel": _DummyKernel}, + kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, ) x = torch.randn(1, 8, 8, 4) with pytest.raises(ValueError, match="CUDA"): @@ -262,7 +345,7 @@ def test_max_pool2d_forward_rejects_nchw_shape(monkeypatch: pytest.MonkeyPatch) w_in=8, kernel_size=(2, 2), stride=(2, 2), - kernel_map={"max_pool2d_kernel": _DummyKernel}, + kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, ) x = torch.randn(1, 4, 8, 8, device="cuda", dtype=torch.float16) with pytest.raises(ValueError, match="NHWC"): @@ -280,7 +363,7 @@ def test_max_pool2d_forward_warns_on_ambiguous_nhwc_shape(monkeypatch: pytest.Mo w_in=8, kernel_size=(2, 2), stride=(2, 2), - kernel_map={"max_pool2d_kernel": _DummyKernel}, + kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, ) x = torch.randn(1, 8, 8, 8, device="cuda", dtype=torch.float16) with pytest.warns(UserWarning, match="ambiguous NHWC shape"): diff --git a/tileops/kernels/pool/max_pool2d.py b/tileops/kernels/pool/max_pool2d.py index cdc2d424f..addb8aaa7 100644 --- a/tileops/kernels/pool/max_pool2d.py +++ b/tileops/kernels/pool/max_pool2d.py @@ -15,7 +15,68 @@ @functools.lru_cache(maxsize=64) -def _max_pool2d_kernel( +def _max_pool2d_values_kernel( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str = "float16", +): + accum_dtype = "float32" + out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + + @tilelang.jit(out_idx=[1], compile_flags=["-O3", "-DENABLE_BF16"]) + def _max_pool2d_func(block_m: int, block_c: int, threads: int): + @T.prim_func + def _max_pool2d_main( + x: T.Tensor((n, h_in, w_in, c_in), dtype), # type: ignore + out: T.Tensor((n, out_h, out_w, c_in), dtype), # type: ignore + ): + with T.Kernel( + T.ceildiv(c_in, block_c), + T.ceildiv(n * out_h * out_w, block_m), + threads=threads, + ) as (bx, by): + out_flat = T.Tensor((n * out_h * out_w, c_in), dtype, out.data) + + for i, j in T.Parallel(block_m, block_c): + m_idx = by * block_m + i + c_idx = bx * block_c + j + if m_idx < n * out_h * out_w and c_idx < c_in: + batch = m_idx // (out_h * out_w) + out_idx = m_idx % (out_h * out_w) + oh = out_idx // out_w + ow = out_idx % out_w + max_val = T.alloc_var(T.float32) + max_val = -T.infinity(accum_dtype) + + for kh in T.serial(kernel_h): + for kw in T.serial(kernel_w): + ih = oh * stride_h + kh * dilation_h - pad_h + iw = ow * stride_w + kw * dilation_w - pad_w + if ih >= 0 and ih < h_in and iw >= 0 and iw < w_in: + candidate = T.cast(x[batch, ih, iw, c_idx], accum_dtype) + max_val = T.max(max_val, candidate) + + out_flat[m_idx, c_idx] = T.cast(max_val, dtype) + + return _max_pool2d_main + + return _max_pool2d_func + + +@functools.lru_cache(maxsize=64) +def _max_pool2d_values_indices_kernel( n: int, c_in: int, h_in: int, @@ -83,8 +144,74 @@ def _max_pool2d_main( return _max_pool2d_func -@torch.library.custom_op("top::max_pool2d_wrapped_kernel", mutates_args=()) -def _max_pool2d_wrapped_kernel( +@torch.library.custom_op("top::max_pool2d_values_wrapped_kernel", mutates_args=()) +def _max_pool2d_values_wrapped_kernel( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str, + block_m: int, + block_c: int, + threads: int, + x: torch.Tensor, +) -> torch.Tensor: + return _max_pool2d_values_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + dtype, + )(block_m, block_c, threads)(x) + + +@_max_pool2d_values_wrapped_kernel.register_fake +def _max_pool2d_values_wrapped_kernel_fake( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str, + block_m: int, + block_c: int, + threads: int, + x: torch.Tensor, +) -> torch.Tensor: + _ = (dtype, block_m, block_c, threads) + out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + return torch.empty((n, out_h, out_w, c_in), dtype=x.dtype, device=x.device) + + +@torch.library.custom_op("top::max_pool2d_values_indices_wrapped_kernel", mutates_args=()) +def _max_pool2d_values_indices_wrapped_kernel( n: int, c_in: int, h_in: int, @@ -104,7 +231,7 @@ def _max_pool2d_wrapped_kernel( threads: int, x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - return _max_pool2d_kernel( + return _max_pool2d_values_indices_kernel( n, c_in, h_in, @@ -122,8 +249,8 @@ def _max_pool2d_wrapped_kernel( )(block_m, block_c, threads)(x) -@_max_pool2d_wrapped_kernel.register_fake -def _( +@_max_pool2d_values_indices_wrapped_kernel.register_fake +def _max_pool2d_values_indices_wrapped_kernel_fake( n: int, c_in: int, h_in: int, @@ -171,6 +298,7 @@ def __init__( dilation_h: int, dilation_w: int, ceil_mode: bool, + return_indices: bool, dtype: torch.dtype, config: Optional[dict] = None, tune: bool = False, @@ -194,26 +322,45 @@ def __init__( self.dilation_h = dilation_h self.dilation_w = dilation_w self.ceil_mode = ceil_mode + self.return_indices = return_indices self.dtype = dtype self.out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) self.out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) - self.kernel = _max_pool2d_kernel( - n, - c_in, - h_in, - w_in, - kernel_h, - kernel_w, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - ceil_mode, - self.dtype_str, - ) + if return_indices: + self.kernel = _max_pool2d_values_indices_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + self.dtype_str, + ) + else: + self.kernel = _max_pool2d_values_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + self.dtype_str, + ) self.init_config(config, tune) @property @@ -236,8 +383,29 @@ def autotune_configs(self) -> list[dict]: for block_m, block_c, threads in configs ] - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - return _max_pool2d_wrapped_kernel( + def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if self.return_indices: + return _max_pool2d_values_indices_wrapped_kernel( + self.n, + self.c_in, + self.h_in, + self.w_in, + self.kernel_h, + self.kernel_w, + self.stride_h, + self.stride_w, + self.pad_h, + self.pad_w, + self.dilation_h, + self.dilation_w, + self.ceil_mode, + self.dtype_str, + self.config["block_m"], + self.config["block_c"], + self.config["threads"], + x, + ) + return _max_pool2d_values_wrapped_kernel( self.n, self.c_in, self.h_in, diff --git a/tileops/ops/max_pool2d.py b/tileops/ops/max_pool2d.py index 609be2204..8f1b2e577 100644 --- a/tileops/ops/max_pool2d.py +++ b/tileops/ops/max_pool2d.py @@ -74,6 +74,7 @@ def __init__( dilation_h=self.dilation[0], dilation_w=self.dilation[1], ceil_mode=ceil_mode, + return_indices=return_indices, dtype=dtype, tune=tune, ) @@ -94,7 +95,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.T layout="NHWC", ambiguous_layout_shape=(self.n, self.c_in, self.h_in, self.w_in), ) - values, indices = self.kernel(x) if self.return_indices: + values, indices = self.kernel(x) return values, indices - return values + return self.kernel(x) From 2a0aa195bfa3c694901d8bd60950c8edac75eb9d Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Wed, 8 Apr 2026 11:48:42 +0800 Subject: [PATCH 4/9] [Fix][POOL] widen max_pool2d indices math --- tileops/kernels/pool/max_pool2d.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tileops/kernels/pool/max_pool2d.py b/tileops/kernels/pool/max_pool2d.py index addb8aaa7..ab7b74b45 100644 --- a/tileops/kernels/pool/max_pool2d.py +++ b/tileops/kernels/pool/max_pool2d.py @@ -131,7 +131,10 @@ def _max_pool2d_main( iw = ow * stride_w + kw * dilation_w - pad_w if ih >= 0 and ih < h_in and iw >= 0 and iw < w_in: candidate = T.cast(x[batch, ih, iw, c_idx], accum_dtype) - candidate_index = T.cast(ih * w_in + iw, "int64") + candidate_index = ( + T.cast(ih, "int64") * T.cast(w_in, "int64") + + T.cast(iw, "int64") + ) should_update = candidate > max_val max_val = T.if_then_else(should_update, candidate, max_val) max_index = T.if_then_else(should_update, candidate_index, max_index) From ae476b0f614d952eff1a70b4af1dbbb0e8b08f42 Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Wed, 8 Apr 2026 11:58:51 +0800 Subject: [PATCH 5/9] [Chore][Lint] fix pre-commit import order --- tileops/ops/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tileops/ops/__init__.py b/tileops/ops/__init__.py index 684b2f425..a5069b29f 100644 --- a/tileops/ops/__init__.py +++ b/tileops/ops/__init__.py @@ -1,7 +1,6 @@ from .avg_pool1d import AvgPool1dOp from .avg_pool2d import AvgPool2dOp from .avg_pool3d import AvgPool3dOp -from .max_pool2d import MaxPool2dOp from .conv1d import Conv1dOp from .conv2d import Conv2dOp from .conv3d import Conv3dOp @@ -32,6 +31,7 @@ from .gqa_sliding_window_fwd import GqaSlidingWindowFwdOp from .gqa_sliding_window_varlen_fwd import GqaSlidingWindowVarlenFwdOp from .grouped_gemm import GroupedGemmOp +from .max_pool2d import MaxPool2dOp from .mha import MultiHeadAttentionBwdOp, MultiHeadAttentionFwdOp from .mha_decode import MultiHeadAttentionDecodeWithKVCacheOp from .mha_decode_paged import MultiHeadAttentionDecodePagedWithKVCacheOp From e529d7eda5538266d27e03ea3a1b05774422370f Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Wed, 8 Apr 2026 14:31:45 +0800 Subject: [PATCH 6/9] [Fix][POOL] align max_pool2d semantics --- tileops/kernels/pool/common.py | 9 ++------- tileops/kernels/pool/max_pool2d.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/tileops/kernels/pool/common.py b/tileops/kernels/pool/common.py index fef3073c4..78add9acc 100644 --- a/tileops/kernels/pool/common.py +++ b/tileops/kernels/pool/common.py @@ -70,15 +70,10 @@ def validate_pool_params( raise TypeError("dilation must contain only ints") if any(v <= 0 for v in dilation): raise ValueError("dilation must be greater than zero") - effective_kernel = tuple( - (kernel - 1) * step + 1 for kernel, step in zip(kernel_size, dilation, strict=True) - ) - else: - effective_kernel = kernel_size - for pad, kernel in zip(padding, effective_kernel, strict=True): + for pad, kernel in zip(padding, kernel_size, strict=True): if pad > kernel // 2: - raise ValueError("padding must be at most half of the effective kernel size") + raise ValueError("padding must be at most half of the kernel size") if divisor_override is not None and (not isinstance(divisor_override, int) or isinstance(divisor_override, bool)): raise TypeError("divisor_override must be an int or None") diff --git a/tileops/kernels/pool/max_pool2d.py b/tileops/kernels/pool/max_pool2d.py index ab7b74b45..a10d1c2ce 100644 --- a/tileops/kernels/pool/max_pool2d.py +++ b/tileops/kernels/pool/max_pool2d.py @@ -120,10 +120,19 @@ def _max_pool2d_main( out_idx = m_idx % (out_h * out_w) oh = out_idx // out_w ow = out_idx % out_w + window_h_start = oh * stride_h - pad_h + window_w_start = ow * stride_w - pad_w + first_kh = T.ceildiv(T.max(-window_h_start, 0), dilation_h) + first_kw = T.ceildiv(T.max(-window_w_start, 0), dilation_w) + first_ih = window_h_start + first_kh * dilation_h + first_iw = window_w_start + first_kw * dilation_w max_val = T.alloc_var(T.float32) max_index = T.alloc_var(T.int64) - max_val = -T.infinity(accum_dtype) - max_index = T.int64(0) + max_val = T.cast(x[batch, first_ih, first_iw, c_idx], accum_dtype) + max_index = ( + T.cast(first_ih, "int64") * T.cast(w_in, "int64") + + T.cast(first_iw, "int64") + ) for kh in T.serial(kernel_h): for kw in T.serial(kernel_w): @@ -135,9 +144,9 @@ def _max_pool2d_main( T.cast(ih, "int64") * T.cast(w_in, "int64") + T.cast(iw, "int64") ) - should_update = candidate > max_val - max_val = T.if_then_else(should_update, candidate, max_val) - max_index = T.if_then_else(should_update, candidate_index, max_index) + if candidate > max_val: + max_val = candidate + max_index = candidate_index out_flat[m_idx, c_idx] = T.cast(max_val, dtype) indices_flat[m_idx, c_idx] = max_index From 225aa4be679154b689a93cc78756691fabaa6ffb Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Wed, 8 Apr 2026 14:32:13 +0800 Subject: [PATCH 7/9] [Chore][POOL] split max_pool2d tests and benchmarks --- benchmarks/ops/bench_max_pool2d.py | 136 ----------- tests/ops/test_max_pool2d.py | 375 ----------------------------- 2 files changed, 511 deletions(-) delete mode 100644 benchmarks/ops/bench_max_pool2d.py delete mode 100644 tests/ops/test_max_pool2d.py diff --git a/benchmarks/ops/bench_max_pool2d.py b/benchmarks/ops/bench_max_pool2d.py deleted file mode 100644 index ee96ab710..000000000 --- a/benchmarks/ops/bench_max_pool2d.py +++ /dev/null @@ -1,136 +0,0 @@ -from typing import Optional, Tuple - -import pytest -import torch -import torch.nn.functional as F - -from benchmarks.benchmark import BenchmarkBase, BenchmarkReport -from tileops.kernels.pool.common import pool_output_dim -from tileops.ops import MaxPool2dOp - - -class MaxPool2dBenchCase: - def __init__( - self, - n: int, - c_in: int, - h_in: int, - w_in: int, - kernel_size: Tuple[int, int], - stride: Optional[Tuple[int, int]], - padding: Tuple[int, int], - dilation: Tuple[int, int], - ceil_mode: bool, - dtype: torch.dtype, - ) -> None: - self.n = n - self.c_in = c_in - self.h_in = h_in - self.w_in = w_in - self.kernel_size = kernel_size - self.stride = kernel_size if stride is None else stride - self.padding = padding - self.dilation = dilation - self.ceil_mode = ceil_mode - self.dtype = dtype - - def gen_inputs(self) -> tuple[torch.Tensor]: - x = torch.randn(self.n, self.h_in, self.w_in, self.c_in, device="cuda", dtype=self.dtype).contiguous() - return (x,) - - def ref_program(self, x: torch.Tensor) -> torch.Tensor: - return F.max_pool2d( - x, - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - ceil_mode=self.ceil_mode, - ) - - -class MaxPool2dBenchmark(BenchmarkBase): - def calculate_flops(self) -> Optional[float]: - t = self.workload - out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) - out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) - return t.n * t.c_in * out_h * out_w * t.kernel_size[0] * t.kernel_size[1] - - def calculate_memory(self) -> Optional[float]: - t = self.workload - out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) - out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) - return (t.n * t.c_in * t.h_in * t.w_in + t.n * t.c_in * out_h * out_w) * t.dtype.itemsize - - -_MAX_POOL2D_BASE_CASES = [ - (2, 64, 112, 112, (3, 3), (2, 2), (1, 1), (1, 1), False, "vision-3x3-s2"), - (2, 128, 56, 56, (5, 5), (2, 2), (2, 2), (1, 1), False, "vision-5x5-s2"), - (3, 96, 55, 57, (3, 3), (2, 2), (1, 1), (2, 1), True, "ceil-dilation-nonpow2"), -] - -_MAX_POOL2D_BENCH_PARAMS = [ - pytest.param(*case[:-1], dtype, True, id=f"{case[-1]}-{str(dtype).split('.')[-1]}") - for case in _MAX_POOL2D_BASE_CASES - for dtype in (torch.float16, torch.bfloat16) -] - - -@pytest.mark.parametrize( - "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, ceil_mode, dtype, tune", - _MAX_POOL2D_BENCH_PARAMS, -) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_max_pool2d_bench( - n: int, - c_in: int, - h_in: int, - w_in: int, - kernel_size: Tuple[int, int], - stride: Optional[Tuple[int, int]], - padding: Tuple[int, int], - dilation: Tuple[int, int], - ceil_mode: bool, - dtype: torch.dtype, - tune: bool, -) -> None: - test = MaxPool2dBenchCase( - n, - c_in, - h_in, - w_in, - kernel_size, - stride, - padding, - dilation, - ceil_mode, - dtype, - ) - bm = MaxPool2dBenchmark(test) - inputs = test.gen_inputs() - (x,) = inputs - x_nchw = x.permute(0, 3, 1, 2).contiguous() - - op = MaxPool2dOp( - n=n, - c_in=c_in, - h_in=h_in, - w_in=w_in, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - return_indices=False, - ceil_mode=ceil_mode, - dtype=dtype, - tune=tune, - ) - result = bm.profile(op, *inputs) - BenchmarkReport.record(op, locals(), result, tag="tileops") - - result_bl = bm.profile(test.ref_program, x_nchw) - BenchmarkReport.record(op, locals(), result_bl, tag="torch") - - -if __name__ == "__main__": - pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_max_pool2d.py b/tests/ops/test_max_pool2d.py deleted file mode 100644 index 73def78a5..000000000 --- a/tests/ops/test_max_pool2d.py +++ /dev/null @@ -1,375 +0,0 @@ -from typing import Optional, Tuple - -import pytest -import torch -import torch.nn.functional as F - -from tests.test_base import FixtureBase, TestBase -from tileops.kernels.kernel import Kernel -from tileops.kernels.pool import MaxPool2dKernel -from tileops.ops import MaxPool2dOp - - -class _DummyValuesKernel(Kernel): - supported_archs = [80] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x - - -class _DummyValuesIndicesKernel(Kernel): - supported_archs = [80] - - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - return x, torch.zeros_like(x, dtype=torch.int64) - - -class MaxPool2dFixture(FixtureBase): - PARAMS = [ - ( - "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, return_indices, ceil_mode, dtype, tune", - [ - pytest.param( - 2, 64, 56, 56, (3, 3), None, (1, 1), (1, 1), False, False, torch.float16, False, - marks=[pytest.mark.smoke, pytest.mark.packaging], - id="smoke-3x3-default-stride-fp16", - ), - pytest.param( - 1, 96, 29, 31, (3, 5), (2, 2), (1, 2), (1, 1), False, True, torch.float16, False, - marks=pytest.mark.full, - id="full-ceil-nonpow2-fp16", - ), - pytest.param( - 1, 80, 28, 30, (3, 3), (2, 2), (1, 1), (2, 1), False, False, torch.bfloat16, False, - marks=pytest.mark.full, - id="full-dilation-bf16", - ), - pytest.param( - 2, 64, 56, 56, (3, 3), (2, 2), (1, 1), (2, 2), False, False, torch.float16, False, - marks=pytest.mark.full, - id="full-dilated-maxpool-2x2-fp16", - ), - pytest.param( - 1, 48, 35, 35, (3, 3), (1, 1), (1, 1), (3, 3), False, False, torch.float16, False, - marks=pytest.mark.full, - id="full-dilated-maxpool-3x3-fp16", - ), - pytest.param( - 1, 32, 16, 18, (2, 3), (2, 2), (0, 1), (1, 1), True, False, torch.float16, False, - marks=pytest.mark.full, - id="full-return-indices-fp16", - ), - ], - ), - ] - - -class MaxPool2dTest(TestBase): - def __init__( - self, - kernel_size: Tuple[int, int], - stride: Optional[Tuple[int, int]], - padding: Tuple[int, int], - dilation: Tuple[int, int], - return_indices: bool, - ceil_mode: bool, - dtype: torch.dtype, - ) -> None: - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.return_indices = return_indices - self.ceil_mode = ceil_mode - self.dtype = dtype - - def gen_inputs(self, n: int, c_in: int, h_in: int, w_in: int) -> tuple[torch.Tensor]: - x = torch.randn(n, h_in, w_in, c_in, device="cuda", dtype=self.dtype).contiguous() - return (x,) - - def ref_program(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - out = F.max_pool2d( - x.permute(0, 3, 1, 2).contiguous(), - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - return_indices=self.return_indices, - ceil_mode=self.ceil_mode, - ) - if self.return_indices: - values, indices = out - return ( - values.permute(0, 2, 3, 1).contiguous(), - indices.permute(0, 2, 3, 1).contiguous(), - ) - return out.permute(0, 2, 3, 1).contiguous() - - -@MaxPool2dFixture -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_max_pool2d( - n: int, - c_in: int, - h_in: int, - w_in: int, - kernel_size: Tuple[int, int], - stride: Optional[Tuple[int, int]], - padding: Tuple[int, int], - dilation: Tuple[int, int], - return_indices: bool, - ceil_mode: bool, - dtype: torch.dtype, - tune: bool, -) -> None: - test = MaxPool2dTest( - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - return_indices=return_indices, - ceil_mode=ceil_mode, - dtype=dtype, - ) - op = MaxPool2dOp( - n=n, - c_in=c_in, - h_in=h_in, - w_in=w_in, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - return_indices=return_indices, - ceil_mode=ceil_mode, - dtype=dtype, - tune=tune, - ) - atol = 1e-3 if dtype == torch.float16 else 1.6e-2 - rtol = 1e-3 if dtype == torch.float16 else 1.6e-2 - test.check(op, *test.gen_inputs(n, c_in, h_in, w_in), atol=atol, rtol=rtol) - - -@pytest.mark.smoke -def test_max_pool2d_dispatches_kernel(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) - op = MaxPool2dOp( - n=1, - c_in=32, - h_in=28, - w_in=28, - kernel_size=(3, 3), - stride=(2, 2), - padding=(1, 1), - ) - assert isinstance(op.kernel, MaxPool2dKernel) - - -@pytest.mark.smoke -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_max_pool2d_returns_indices_when_requested(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) - op = MaxPool2dOp( - n=1, - c_in=4, - h_in=8, - w_in=8, - kernel_size=(2, 2), - stride=(2, 2), - return_indices=True, - kernel_map={"max_pool2d_kernel": _DummyValuesIndicesKernel}, - ) - x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) - values, indices = op(x) - assert values is x - assert indices.dtype == torch.int64 - assert indices.shape == x.shape - - -@pytest.mark.smoke -def test_max_pool2d_default_path_uses_values_only_kernel(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) - - def fail_if_called(*args, **kwargs): - raise AssertionError("indices kernel should not be used when return_indices=False") - - def return_values(*args, **kwargs): - x = args[-1] - return x - - monkeypatch.setattr( - "tileops.kernels.pool.max_pool2d._max_pool2d_values_indices_wrapped_kernel", - fail_if_called, - ) - monkeypatch.setattr( - "tileops.kernels.pool.max_pool2d._max_pool2d_values_wrapped_kernel", - return_values, - ) - op = MaxPool2dOp( - n=1, - c_in=4, - h_in=8, - w_in=8, - kernel_size=(2, 2), - stride=(2, 2), - ) - x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) - out = op(x) - assert out is x - - -@pytest.mark.smoke -def test_max_pool2d_indices_path_uses_values_indices_kernel(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) - - def fail_if_called(*args, **kwargs): - raise AssertionError("values-only kernel should not be used when return_indices=True") - - def return_values_indices(*args, **kwargs): - x = args[-1] - return x, torch.zeros_like(x, dtype=torch.int64) - - monkeypatch.setattr( - "tileops.kernels.pool.max_pool2d._max_pool2d_values_wrapped_kernel", - fail_if_called, - ) - monkeypatch.setattr( - "tileops.kernels.pool.max_pool2d._max_pool2d_values_indices_wrapped_kernel", - return_values_indices, - ) - op = MaxPool2dOp( - n=1, - c_in=4, - h_in=8, - w_in=8, - kernel_size=(2, 2), - stride=(2, 2), - return_indices=True, - ) - x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) - values, indices = op(x) - assert values is x - assert indices.dtype == torch.int64 - - -@pytest.mark.smoke -def test_max_pool2d_rejects_non_positive_dilation() -> None: - with pytest.raises(ValueError, match="dilation must be greater than zero"): - MaxPool2dOp( - n=1, - c_in=8, - h_in=16, - w_in=16, - kernel_size=(3, 3), - dilation=(1, 0), - ) - - -@pytest.mark.smoke -def test_max_pool2d_rejects_invalid_padding_for_effective_kernel() -> None: - with pytest.raises(ValueError, match="padding must be at most half"): - MaxPool2dOp( - n=1, - c_in=8, - h_in=16, - w_in=16, - kernel_size=(3, 3), - padding=(3, 1), - dilation=(2, 1), - ) - - -@pytest.mark.smoke -@pytest.mark.parametrize( - ("kwargs", "match"), - [ - ({"dilation": True}, "dilation must be an int or a tuple of 2 ints"), - ({"dilation": (1, True)}, "dilation must contain only ints"), - ({"kernel_size": True}, "kernel_size must be an int or a tuple of 2 ints"), - ({"stride": True}, "stride must be an int or a tuple of 2 ints"), - ({"padding": True}, "padding must be an int or a tuple of 2 ints"), - ], -) -def test_max_pool2d_rejects_invalid_param_types(kwargs: dict[str, object], match: str) -> None: - base_kwargs = { - "n": 1, - "c_in": 8, - "h_in": 16, - "w_in": 16, - "kernel_size": (3, 3), - } - base_kwargs.update(kwargs) - with pytest.raises((TypeError, ValueError), match=match): - MaxPool2dOp(**base_kwargs) - - -@pytest.mark.smoke -def test_max_pool2d_rejects_unsupported_dtype(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) - with pytest.raises(ValueError, match="only supports dtypes"): - MaxPool2dOp( - n=1, - c_in=8, - h_in=16, - w_in=16, - kernel_size=(3, 3), - dtype=torch.float32, - ) - - -@pytest.mark.smoke -def test_max_pool2d_forward_rejects_non_cuda_input(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) - op = MaxPool2dOp( - n=1, - c_in=4, - h_in=8, - w_in=8, - kernel_size=(2, 2), - stride=(2, 2), - kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, - ) - x = torch.randn(1, 8, 8, 4) - with pytest.raises(ValueError, match="CUDA"): - op(x) - - -@pytest.mark.smoke -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_max_pool2d_forward_rejects_nchw_shape(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) - op = MaxPool2dOp( - n=1, - c_in=4, - h_in=8, - w_in=8, - kernel_size=(2, 2), - stride=(2, 2), - kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, - ) - x = torch.randn(1, 4, 8, 8, device="cuda", dtype=torch.float16) - with pytest.raises(ValueError, match="NHWC"): - op(x) - - -@pytest.mark.smoke -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_max_pool2d_forward_warns_on_ambiguous_nhwc_shape(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) - op = MaxPool2dOp( - n=1, - c_in=8, - h_in=8, - w_in=8, - kernel_size=(2, 2), - stride=(2, 2), - kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, - ) - x = torch.randn(1, 8, 8, 8, device="cuda", dtype=torch.float16) - with pytest.warns(UserWarning, match="ambiguous NHWC shape"): - out = op(x) - assert out is x - - -if __name__ == "__main__": - pytest.main([__file__, "-vvs"]) From d654c32ef95b7aba0cf7936992db1fdfdd58cf76 Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Thu, 9 Apr 2026 15:27:53 +0800 Subject: [PATCH 8/9] [Fix][POOL] handle empty max_pool2d windows --- tileops/kernels/pool/max_pool2d.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tileops/kernels/pool/max_pool2d.py b/tileops/kernels/pool/max_pool2d.py index a10d1c2ce..f07448204 100644 --- a/tileops/kernels/pool/max_pool2d.py +++ b/tileops/kernels/pool/max_pool2d.py @@ -126,12 +126,26 @@ def _max_pool2d_main( first_kw = T.ceildiv(T.max(-window_w_start, 0), dilation_w) first_ih = window_h_start + first_kh * dilation_h first_iw = window_w_start + first_kw * dilation_w + has_valid = ( + first_kh < kernel_h + and first_kw < kernel_w + and first_ih >= 0 + and first_ih < h_in + and first_iw >= 0 + and first_iw < w_in + ) max_val = T.alloc_var(T.float32) max_index = T.alloc_var(T.int64) - max_val = T.cast(x[batch, first_ih, first_iw, c_idx], accum_dtype) - max_index = ( + max_val = T.if_then_else( + has_valid, + T.cast(x[batch, first_ih, first_iw, c_idx], accum_dtype), + -T.infinity(accum_dtype), + ) + max_index = T.if_then_else( + has_valid, T.cast(first_ih, "int64") * T.cast(w_in, "int64") - + T.cast(first_iw, "int64") + + T.cast(first_iw, "int64"), + T.cast(pad_h, "int64") * T.cast(w_in, "int64") + T.cast(pad_w, "int64"), ) for kh in T.serial(kernel_h): From 8a87232d8ff19d753d6dbd7a0377b94740cf611f Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Fri, 10 Apr 2026 14:58:57 +0800 Subject: [PATCH 9/9] [Fix][POOL] restore max_pool2d validation assets --- benchmarks/ops/bench_max_pool2d.py | 136 +++++++++ tests/ops/test_max_pool2d.py | 471 +++++++++++++++++++++++++++++ tileops/kernels/pool/max_pool2d.py | 14 + 3 files changed, 621 insertions(+) create mode 100644 benchmarks/ops/bench_max_pool2d.py create mode 100644 tests/ops/test_max_pool2d.py diff --git a/benchmarks/ops/bench_max_pool2d.py b/benchmarks/ops/bench_max_pool2d.py new file mode 100644 index 000000000..ee96ab710 --- /dev/null +++ b/benchmarks/ops/bench_max_pool2d.py @@ -0,0 +1,136 @@ +from typing import Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from benchmarks.benchmark import BenchmarkBase, BenchmarkReport +from tileops.kernels.pool.common import pool_output_dim +from tileops.ops import MaxPool2dOp + + +class MaxPool2dBenchCase: + def __init__( + self, + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + ceil_mode: bool, + dtype: torch.dtype, + ) -> None: + self.n = n + self.c_in = c_in + self.h_in = h_in + self.w_in = w_in + self.kernel_size = kernel_size + self.stride = kernel_size if stride is None else stride + self.padding = padding + self.dilation = dilation + self.ceil_mode = ceil_mode + self.dtype = dtype + + def gen_inputs(self) -> tuple[torch.Tensor]: + x = torch.randn(self.n, self.h_in, self.w_in, self.c_in, device="cuda", dtype=self.dtype).contiguous() + return (x,) + + def ref_program(self, x: torch.Tensor) -> torch.Tensor: + return F.max_pool2d( + x, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ceil_mode=self.ceil_mode, + ) + + +class MaxPool2dBenchmark(BenchmarkBase): + def calculate_flops(self) -> Optional[float]: + t = self.workload + out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) + out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) + return t.n * t.c_in * out_h * out_w * t.kernel_size[0] * t.kernel_size[1] + + def calculate_memory(self) -> Optional[float]: + t = self.workload + out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) + out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) + return (t.n * t.c_in * t.h_in * t.w_in + t.n * t.c_in * out_h * out_w) * t.dtype.itemsize + + +_MAX_POOL2D_BASE_CASES = [ + (2, 64, 112, 112, (3, 3), (2, 2), (1, 1), (1, 1), False, "vision-3x3-s2"), + (2, 128, 56, 56, (5, 5), (2, 2), (2, 2), (1, 1), False, "vision-5x5-s2"), + (3, 96, 55, 57, (3, 3), (2, 2), (1, 1), (2, 1), True, "ceil-dilation-nonpow2"), +] + +_MAX_POOL2D_BENCH_PARAMS = [ + pytest.param(*case[:-1], dtype, True, id=f"{case[-1]}-{str(dtype).split('.')[-1]}") + for case in _MAX_POOL2D_BASE_CASES + for dtype in (torch.float16, torch.bfloat16) +] + + +@pytest.mark.parametrize( + "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, ceil_mode, dtype, tune", + _MAX_POOL2D_BENCH_PARAMS, +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_bench( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + ceil_mode: bool, + dtype: torch.dtype, + tune: bool, +) -> None: + test = MaxPool2dBenchCase( + n, + c_in, + h_in, + w_in, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + dtype, + ) + bm = MaxPool2dBenchmark(test) + inputs = test.gen_inputs() + (x,) = inputs + x_nchw = x.permute(0, 3, 1, 2).contiguous() + + op = MaxPool2dOp( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=False, + ceil_mode=ceil_mode, + dtype=dtype, + tune=tune, + ) + result = bm.profile(op, *inputs) + BenchmarkReport.record(op, locals(), result, tag="tileops") + + result_bl = bm.profile(test.ref_program, x_nchw) + BenchmarkReport.record(op, locals(), result_bl, tag="torch") + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_max_pool2d.py b/tests/ops/test_max_pool2d.py new file mode 100644 index 000000000..a3cf124b9 --- /dev/null +++ b/tests/ops/test_max_pool2d.py @@ -0,0 +1,471 @@ +from typing import Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from tests.test_base import FixtureBase, TestBase +from tileops.kernels.kernel import Kernel +from tileops.kernels.pool import MaxPool2dKernel +from tileops.ops import MaxPool2dOp + + +class _DummyValuesKernel(Kernel): + supported_archs = [80] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class _DummyValuesIndicesKernel(Kernel): + supported_archs = [80] + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return x, torch.zeros_like(x, dtype=torch.int64) + + +class MaxPool2dFixture(FixtureBase): + PARAMS = [ + ( + "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, return_indices, ceil_mode, dtype, tune", + [ + pytest.param( + 2, 64, 56, 56, (3, 3), None, (1, 1), (1, 1), False, False, torch.float16, False, + marks=[pytest.mark.smoke, pytest.mark.packaging], + id="smoke-3x3-default-stride-fp16", + ), + pytest.param( + 1, 96, 29, 31, (3, 5), (2, 2), (1, 2), (1, 1), False, True, torch.float16, False, + marks=pytest.mark.full, + id="full-ceil-nonpow2-fp16", + ), + pytest.param( + 1, 80, 28, 30, (3, 3), (2, 2), (1, 1), (2, 1), False, False, torch.bfloat16, False, + marks=pytest.mark.full, + id="full-dilation-bf16", + ), + pytest.param( + 2, 64, 56, 56, (3, 3), (2, 2), (1, 1), (2, 2), False, False, torch.float16, False, + marks=pytest.mark.full, + id="full-dilated-maxpool-2x2-fp16", + ), + pytest.param( + 1, 48, 35, 35, (3, 3), (1, 1), (1, 1), (3, 3), False, False, torch.float16, False, + marks=pytest.mark.full, + id="full-dilated-maxpool-3x3-fp16", + ), + pytest.param( + 1, 32, 16, 18, (2, 3), (2, 2), (0, 1), (1, 1), True, False, torch.float16, False, + marks=pytest.mark.full, + id="full-return-indices-fp16", + ), + ], + ), + ] + + +class MaxPool2dTest(TestBase): + def __init__( + self, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + return_indices: bool, + ceil_mode: bool, + dtype: torch.dtype, + ) -> None: + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.return_indices = return_indices + self.ceil_mode = ceil_mode + self.dtype = dtype + + def gen_inputs(self, n: int, c_in: int, h_in: int, w_in: int) -> tuple[torch.Tensor]: + x = torch.randn(n, h_in, w_in, c_in, device="cuda", dtype=self.dtype).contiguous() + return (x,) + + def ref_program(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + out = F.max_pool2d( + x.permute(0, 3, 1, 2).contiguous(), + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + return_indices=self.return_indices, + ceil_mode=self.ceil_mode, + ) + if self.return_indices: + values, indices = out + return ( + values.permute(0, 2, 3, 1).contiguous(), + indices.permute(0, 2, 3, 1).contiguous(), + ) + return out.permute(0, 2, 3, 1).contiguous() + + +@MaxPool2dFixture +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + return_indices: bool, + ceil_mode: bool, + dtype: torch.dtype, + tune: bool, +) -> None: + test = MaxPool2dTest( + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + dtype=dtype, + ) + op = MaxPool2dOp( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + dtype=dtype, + tune=tune, + ) + atol = 1e-3 if dtype == torch.float16 else 1.6e-2 + rtol = 1e-3 if dtype == torch.float16 else 1.6e-2 + test.check(op, *test.gen_inputs(n, c_in, h_in, w_in), atol=atol, rtol=rtol) + + +@pytest.mark.smoke +def test_max_pool2d_dispatches_kernel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=32, + h_in=28, + w_in=28, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + assert isinstance(op.kernel, MaxPool2dKernel) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_returns_indices_when_requested(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + return_indices=True, + kernel_map={"max_pool2d_kernel": _DummyValuesIndicesKernel}, + ) + x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) + values, indices = op(x) + assert values is x + assert indices.dtype == torch.int64 + assert indices.shape == x.shape + + +@pytest.mark.smoke +def test_max_pool2d_default_path_uses_values_only_kernel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + + def fail_if_called(*args, **kwargs): + raise AssertionError("indices kernel should not be used when return_indices=False") + + def return_values(*args, **kwargs): + x = args[-1] + return x + + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_indices_wrapped_kernel", + fail_if_called, + ) + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_wrapped_kernel", + return_values, + ) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + ) + x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) + out = op(x) + assert out is x + + +@pytest.mark.smoke +def test_max_pool2d_indices_path_uses_values_indices_kernel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + + def fail_if_called(*args, **kwargs): + raise AssertionError("values-only kernel should not be used when return_indices=True") + + def return_values_indices(*args, **kwargs): + x = args[-1] + return x, torch.zeros_like(x, dtype=torch.int64) + + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_wrapped_kernel", + fail_if_called, + ) + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_indices_wrapped_kernel", + return_values_indices, + ) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + return_indices=True, + ) + x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) + values, indices = op(x) + assert values is x + assert indices.dtype == torch.int64 + + +@pytest.mark.smoke +def test_max_pool2d_rejects_non_positive_dilation() -> None: + with pytest.raises(ValueError, match="dilation must be greater than zero"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + dilation=(1, 0), + ) + + +@pytest.mark.smoke +def test_max_pool2d_rejects_invalid_padding_for_effective_kernel() -> None: + with pytest.raises(ValueError, match="padding must be at most half"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + padding=(3, 1), + dilation=(2, 1), + ) + + +@pytest.mark.smoke +def test_max_pool2d_rejects_padding_pyTorch_rejects_when_dilated() -> None: + with pytest.raises(ValueError, match="padding must be at most half"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + padding=(2, 2), + dilation=(2, 2), + ) + + +@pytest.mark.smoke +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"dilation": True}, "dilation must be an int or a tuple of 2 ints"), + ({"dilation": (1, True)}, "dilation must contain only ints"), + ({"kernel_size": True}, "kernel_size must be an int or a tuple of 2 ints"), + ({"stride": True}, "stride must be an int or a tuple of 2 ints"), + ({"padding": True}, "padding must be an int or a tuple of 2 ints"), + ], +) +def test_max_pool2d_rejects_invalid_param_types(kwargs: dict[str, object], match: str) -> None: + base_kwargs = { + "n": 1, + "c_in": 8, + "h_in": 16, + "w_in": 16, + "kernel_size": (3, 3), + } + base_kwargs.update(kwargs) + with pytest.raises((TypeError, ValueError), match=match): + MaxPool2dOp(**base_kwargs) + + +@pytest.mark.smoke +def test_max_pool2d_rejects_unsupported_dtype(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + with pytest.raises(ValueError, match="only supports dtypes"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + dtype=torch.float32, + ) + + +@pytest.mark.smoke +def test_max_pool2d_forward_rejects_non_cuda_input(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, + ) + x = torch.randn(1, 8, 8, 4) + with pytest.raises(ValueError, match="CUDA"): + op(x) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_forward_rejects_nchw_shape(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, + ) + x = torch.randn(1, 4, 8, 8, device="cuda", dtype=torch.float16) + with pytest.raises(ValueError, match="NHWC"): + op(x) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_forward_warns_on_ambiguous_nhwc_shape(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=8, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, + ) + x = torch.randn(1, 8, 8, 8, device="cuda", dtype=torch.float16) + with pytest.warns(UserWarning, match="ambiguous NHWC shape"): + out = op(x) + assert out is x + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_return_indices_handles_all_negative_infinity(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + x = torch.full((1, 3, 3, 1), float("-inf"), device="cuda", dtype=torch.float16) + op = MaxPool2dOp( + n=1, + c_in=1, + h_in=3, + w_in=3, + kernel_size=(2, 2), + stride=(1, 1), + return_indices=True, + ) + values, indices = op(x) + expected_values = torch.full((1, 2, 2, 1), float("-inf"), device="cuda", dtype=torch.float16) + expected_indices = torch.tensor([[[[0], [1]], [[3], [4]]]], device="cuda", dtype=torch.int64) + torch.testing.assert_close(values, expected_values) + torch.testing.assert_close(indices, expected_indices) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_returns_empty_tensor_for_zero_sized_output(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + + def fail_if_called(*args, **kwargs): + raise AssertionError("kernel wrapper should not be called for zero-sized outputs") + + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_wrapped_kernel", + fail_if_called, + ) + op = MaxPool2dOp( + n=1, + c_in=1, + h_in=1, + w_in=1, + kernel_size=(2, 2), + stride=(2, 2), + dilation=(2, 2), + return_indices=False, + ) + x = torch.randn(1, 1, 1, 1, device="cuda", dtype=torch.float16) + out = op(x) + assert out.shape == (1, 0, 0, 1) + assert out.numel() == 0 + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_return_indices_returns_empty_tensors_for_zero_sized_output( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + + def fail_if_called(*args, **kwargs): + raise AssertionError("kernel wrapper should not be called for zero-sized outputs") + + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_indices_wrapped_kernel", + fail_if_called, + ) + op = MaxPool2dOp( + n=1, + c_in=1, + h_in=1, + w_in=1, + kernel_size=(2, 2), + stride=(2, 2), + dilation=(2, 2), + return_indices=True, + ) + x = torch.randn(1, 1, 1, 1, device="cuda", dtype=torch.float16) + values, indices = op(x) + assert values.shape == (1, 0, 0, 1) + assert values.numel() == 0 + assert indices.shape == (1, 0, 0, 1) + assert indices.dtype == torch.int64 + assert indices.numel() == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tileops/kernels/pool/max_pool2d.py b/tileops/kernels/pool/max_pool2d.py index f07448204..86acbf3ad 100644 --- a/tileops/kernels/pool/max_pool2d.py +++ b/tileops/kernels/pool/max_pool2d.py @@ -410,6 +410,20 @@ def autotune_configs(self) -> list[dict]: ] def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if self.out_h == 0 or self.out_w == 0: + empty_values = torch.empty( + (self.n, self.out_h, self.out_w, self.c_in), + dtype=x.dtype, + device=x.device, + ) + if self.return_indices: + empty_indices = torch.empty( + (self.n, self.out_h, self.out_w, self.c_in), + dtype=torch.int64, + device=x.device, + ) + return empty_values, empty_indices + return empty_values if self.return_indices: return _max_pool2d_values_indices_wrapped_kernel( self.n,