diff --git a/benchmarks/ops/bench_fused_add_rms_norm.py b/benchmarks/ops/bench_fused_add_rms_norm.py index 0bd50ecb3..9d4ccad89 100644 --- a/benchmarks/ops/bench_fused_add_rms_norm.py +++ b/benchmarks/ops/bench_fused_add_rms_norm.py @@ -13,11 +13,10 @@ class FusedAddRMSNormBenchmark(BenchmarkBase[FusedAddRMSNormTest]): - _roofline_cache: Optional[tuple[float, float]] = None - def __init__(self, test, op): super().__init__(test) self._op = op + self._roofline_cache: Optional[tuple[float, float]] = None def _get_roofline(self) -> tuple[float, float]: if self._roofline_cache is None: diff --git a/benchmarks/ops/bench_fused_moe_experts.py b/benchmarks/ops/bench_fused_moe_experts.py index 3c8cef4ca..338ac6b24 100644 --- a/benchmarks/ops/bench_fused_moe_experts.py +++ b/benchmarks/ops/bench_fused_moe_experts.py @@ -101,16 +101,21 @@ def ref_program(self, *args): class MoEExpertsBenchmark(BenchmarkBase[MoEExpertsTest]): + def __init__(self, test, op): + super().__init__(test) + self._op = op + self._roofline_cache: Optional[tuple[float, float]] = None + + def _get_roofline(self) -> tuple[float, float]: + if self._roofline_cache is None: + self._roofline_cache = self._op.eval_roofline() + return self._roofline_cache + def calculate_flops(self) -> Optional[float]: - t = self.workload - return t.num_tokens * t.top_k * 6 * t.ffn_size * t.hidden_size + return self._get_roofline()[0] def calculate_memory(self) -> Optional[float]: - t = self.workload - elem = 2 # bfloat16 - weights = t.num_experts * 3 * t.ffn_size * t.hidden_size * elem - tokens = 2 * t.num_tokens * t.hidden_size * elem - return weights + tokens + return self._get_roofline()[1] # --------------------------------------------------------------------------- @@ -123,9 +128,10 @@ def _manifest_params(): for w in load_workloads(_OP_NAME): label = w.get("label", "unlabeled") for dtype_str in w["dtypes"]: + dtype = getattr(torch, dtype_str) params.append(pytest.param( w["num_tokens"], w["num_experts"], w["top_k"], - w["hidden_size"], w["ffn_size"], + w["hidden_size"], w["ffn_size"], dtype, id=f"{label}-{dtype_str}", )) return params @@ -137,16 +143,15 @@ def _manifest_params(): @pytest.mark.parametrize( - "num_tokens, num_experts, top_k, hidden_size, ffn_size", + "num_tokens, num_experts, top_k, hidden_size, ffn_size, dtype", _manifest_params(), ) def test_moe_experts_nopad_bench( - num_tokens: int, num_experts: int, top_k: int, hidden_size: int, ffn_size: int, + num_tokens: int, num_experts: int, top_k: int, hidden_size: int, + ffn_size: int, dtype: torch.dtype, ) -> None: - dtype = torch.bfloat16 test = MoEExpertsTest(num_tokens, num_experts, top_k, hidden_size, ffn_size, dtype) hidden, w1, w2, topk_weights, topk_ids = test.gen_inputs() - bm = MoEExpertsBenchmark(test) kwargs = dict( num_tokens=num_tokens, num_experts=num_experts, top_k=top_k, @@ -158,6 +163,7 @@ def test_moe_experts_nopad_bench( # -- TileOPs nopad (3WG persistent) -------------------------------------- nopad = FusedMoEExpertsNopadPersistent3WGFwdOp(**kwargs) + bm = MoEExpertsBenchmark(test, nopad) def _nopad_fn(hidden, w1, w2, topk_weights, topk_ids): nopad.forward( diff --git a/tileops/manifest/moe.yaml b/tileops/manifest/moe.yaml index 03d9ea1e6..1334ac37e 100644 --- a/tileops/manifest/moe.yaml +++ b/tileops/manifest/moe.yaml @@ -198,6 +198,7 @@ MoeGroupedGemmNopadFwdOp: op: tileops/ops/moe/routed_expert/moe_grouped_gemm_nopad.py test: tests/ops/test_moe_grouped_gemm_nopad.py bench: benchmarks/ops/bench_moe_grouped_gemm_nopad.py + bench_manifest_driven: true kernel_map: moe_grouped_gemm_kernel: MoeGroupedGemmNopadKernel @@ -266,7 +267,7 @@ FusedMoEExpertsNopadPersistent3WGFwdOp: op: tileops/ops/moe/routed_expert/fused_routed_expert.py test: tests/ops/test_fused_moe_experts.py bench: benchmarks/ops/bench_fused_moe_experts.py - bench_manifest_driven: false + bench_manifest_driven: true kernel_map: permute_nopad_kernel: MoePermuteNopadKernel moe_grouped_gemm_kernel: GroupedGemmPersistent3WGKernel diff --git a/tileops/manifest/normalization.yaml b/tileops/manifest/normalization.yaml index e09ea6268..cc2b40a1f 100644 --- a/tileops/manifest/normalization.yaml +++ b/tileops/manifest/normalization.yaml @@ -51,6 +51,7 @@ RMSNormFwdOp: op: tileops/ops/norm/rms_norm.py test: tests/ops/test_rms_norm.py bench: benchmarks/ops/bench_rms_norm.py + bench_manifest_driven: true LayerNormFwdOp: ref_api: "torch.nn.functional.layer_norm" @@ -101,6 +102,7 @@ LayerNormFwdOp: op: tileops/ops/norm/layer_norm.py test: tests/ops/test_layer_norm.py bench: benchmarks/ops/bench_layer_norm.py + bench_manifest_driven: true AdaLayerNormFwdOp: ref_api: "none" @@ -146,6 +148,7 @@ AdaLayerNormFwdOp: op: tileops/ops/norm/ada_layer_norm.py test: tests/ops/test_ada_layer_norm.py bench: benchmarks/ops/bench_ada_layer_norm.py + bench_manifest_driven: true AdaLayerNormZeroFwdOp: ref_api: "none" @@ -193,6 +196,7 @@ AdaLayerNormZeroFwdOp: op: tileops/ops/norm/ada_layer_norm_zero.py test: tests/ops/test_ada_layer_norm_zero.py bench: benchmarks/ops/bench_ada_layer_norm.py + bench_manifest_driven: true FusedAddLayerNormFwdOp: ref_api: "none" @@ -243,6 +247,7 @@ FusedAddLayerNormFwdOp: op: tileops/ops/norm/fused_add_layer_norm.py test: tests/ops/test_fused_add_layer_norm.py bench: benchmarks/ops/bench_fused_add_layer_norm.py + bench_manifest_driven: true FusedAddRMSNormFwdOp: ref_api: "none" @@ -352,6 +357,7 @@ BatchNormFwdOp: op: tileops/ops/norm/batch_norm.py test: tests/ops/test_batch_norm.py bench: benchmarks/ops/bench_batch_norm.py + bench_manifest_driven: true BatchNormBwdOp: ref_api: "torch.nn.functional.batch_norm" @@ -406,6 +412,7 @@ BatchNormBwdOp: op: tileops/ops/norm/batch_norm.py test: tests/ops/test_batch_norm.py bench: benchmarks/ops/bench_batch_norm.py + bench_manifest_driven: true GroupNormFwdOp: ref_api: "torch.nn.functional.group_norm" @@ -453,6 +460,7 @@ GroupNormFwdOp: op: tileops/ops/norm/group_norm.py test: tests/ops/test_group_norm.py bench: benchmarks/ops/bench_group_norm.py + bench_manifest_driven: true GroupNormFwdOpNoAffine: ref_api: "torch.nn.functional.group_norm" @@ -504,6 +512,7 @@ GroupNormFwdOpNoAffine: op: tileops/ops/norm/group_norm.py test: tests/ops/test_group_norm.py bench: benchmarks/ops/bench_group_norm.py + bench_manifest_driven: true InstanceNormFwdOp: ref_api: "torch.nn.functional.instance_norm" @@ -551,6 +560,7 @@ InstanceNormFwdOp: op: tileops/ops/norm/instance_norm.py test: tests/ops/test_instance_norm.py bench: benchmarks/ops/bench_instance_norm.py + bench_manifest_driven: true InstanceNormFwdOpNoAffine: ref_api: "torch.nn.functional.instance_norm" @@ -605,3 +615,4 @@ InstanceNormFwdOpNoAffine: op: tileops/ops/norm/instance_norm.py test: tests/ops/test_instance_norm.py bench: benchmarks/ops/bench_instance_norm.py + bench_manifest_driven: true