Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions benchmarks/ops/bench_fused_add_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@

class FusedAddRMSNormBenchmark(BenchmarkBase[FusedAddRMSNormTest]):

_roofline_cache: Optional[tuple[float, float]] = None

def __init__(self, test, op):
super().__init__(test)
self._op = op
self._roofline_cache: Optional[tuple[float, float]] = None

def _get_roofline(self) -> tuple[float, float]:
if self._roofline_cache is None:
Expand Down
30 changes: 18 additions & 12 deletions benchmarks/ops/bench_fused_moe_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,21 @@ def ref_program(self, *args):

class MoEExpertsBenchmark(BenchmarkBase[MoEExpertsTest]):

def __init__(self, test, op):
super().__init__(test)
self._op = op
self._roofline_cache: Optional[tuple[float, float]] = None

def _get_roofline(self) -> tuple[float, float]:
if self._roofline_cache is None:
self._roofline_cache = self._op.eval_roofline()
return self._roofline_cache

def calculate_flops(self) -> Optional[float]:
t = self.workload
return t.num_tokens * t.top_k * 6 * t.ffn_size * t.hidden_size
return self._get_roofline()[0]

def calculate_memory(self) -> Optional[float]:
t = self.workload
elem = 2 # bfloat16
weights = t.num_experts * 3 * t.ffn_size * t.hidden_size * elem
tokens = 2 * t.num_tokens * t.hidden_size * elem
return weights + tokens
return self._get_roofline()[1]


# ---------------------------------------------------------------------------
Expand All @@ -123,9 +128,10 @@ def _manifest_params():
for w in load_workloads(_OP_NAME):
label = w.get("label", "unlabeled")
for dtype_str in w["dtypes"]:
dtype = getattr(torch, dtype_str)
params.append(pytest.param(
w["num_tokens"], w["num_experts"], w["top_k"],
w["hidden_size"], w["ffn_size"],
w["hidden_size"], w["ffn_size"], dtype,
id=f"{label}-{dtype_str}",
))
return params
Expand All @@ -137,16 +143,15 @@ def _manifest_params():


@pytest.mark.parametrize(
"num_tokens, num_experts, top_k, hidden_size, ffn_size",
"num_tokens, num_experts, top_k, hidden_size, ffn_size, dtype",
_manifest_params(),
)
def test_moe_experts_nopad_bench(
num_tokens: int, num_experts: int, top_k: int, hidden_size: int, ffn_size: int,
num_tokens: int, num_experts: int, top_k: int, hidden_size: int,
ffn_size: int, dtype: torch.dtype,
) -> None:
dtype = torch.bfloat16
test = MoEExpertsTest(num_tokens, num_experts, top_k, hidden_size, ffn_size, dtype)
hidden, w1, w2, topk_weights, topk_ids = test.gen_inputs()
bm = MoEExpertsBenchmark(test)

kwargs = dict(
num_tokens=num_tokens, num_experts=num_experts, top_k=top_k,
Expand All @@ -158,6 +163,7 @@ def test_moe_experts_nopad_bench(

# -- TileOPs nopad (3WG persistent) --------------------------------------
nopad = FusedMoEExpertsNopadPersistent3WGFwdOp(**kwargs)
bm = MoEExpertsBenchmark(test, nopad)

def _nopad_fn(hidden, w1, w2, topk_weights, topk_ids):
nopad.forward(
Expand Down
3 changes: 2 additions & 1 deletion tileops/manifest/moe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ MoeGroupedGemmNopadFwdOp:
op: tileops/ops/moe/routed_expert/moe_grouped_gemm_nopad.py
test: tests/ops/test_moe_grouped_gemm_nopad.py
bench: benchmarks/ops/bench_moe_grouped_gemm_nopad.py
bench_manifest_driven: true
kernel_map:
moe_grouped_gemm_kernel: MoeGroupedGemmNopadKernel

Expand Down Expand Up @@ -266,7 +267,7 @@ FusedMoEExpertsNopadPersistent3WGFwdOp:
op: tileops/ops/moe/routed_expert/fused_routed_expert.py
test: tests/ops/test_fused_moe_experts.py
bench: benchmarks/ops/bench_fused_moe_experts.py
bench_manifest_driven: false
bench_manifest_driven: true
kernel_map:
permute_nopad_kernel: MoePermuteNopadKernel
moe_grouped_gemm_kernel: GroupedGemmPersistent3WGKernel
Expand Down
11 changes: 11 additions & 0 deletions tileops/manifest/normalization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ RMSNormFwdOp:
op: tileops/ops/norm/rms_norm.py
test: tests/ops/test_rms_norm.py
bench: benchmarks/ops/bench_rms_norm.py
bench_manifest_driven: true

LayerNormFwdOp:
ref_api: "torch.nn.functional.layer_norm"
Expand Down Expand Up @@ -101,6 +102,7 @@ LayerNormFwdOp:
op: tileops/ops/norm/layer_norm.py
test: tests/ops/test_layer_norm.py
bench: benchmarks/ops/bench_layer_norm.py
bench_manifest_driven: true

AdaLayerNormFwdOp:
ref_api: "none"
Expand Down Expand Up @@ -146,6 +148,7 @@ AdaLayerNormFwdOp:
op: tileops/ops/norm/ada_layer_norm.py
test: tests/ops/test_ada_layer_norm.py
bench: benchmarks/ops/bench_ada_layer_norm.py
bench_manifest_driven: true

AdaLayerNormZeroFwdOp:
ref_api: "none"
Expand Down Expand Up @@ -193,6 +196,7 @@ AdaLayerNormZeroFwdOp:
op: tileops/ops/norm/ada_layer_norm_zero.py
test: tests/ops/test_ada_layer_norm_zero.py
bench: benchmarks/ops/bench_ada_layer_norm.py
bench_manifest_driven: true

FusedAddLayerNormFwdOp:
ref_api: "none"
Expand Down Expand Up @@ -243,6 +247,7 @@ FusedAddLayerNormFwdOp:
op: tileops/ops/norm/fused_add_layer_norm.py
test: tests/ops/test_fused_add_layer_norm.py
bench: benchmarks/ops/bench_fused_add_layer_norm.py
bench_manifest_driven: true

FusedAddRMSNormFwdOp:
ref_api: "none"
Expand Down Expand Up @@ -352,6 +357,7 @@ BatchNormFwdOp:
op: tileops/ops/norm/batch_norm.py
test: tests/ops/test_batch_norm.py
bench: benchmarks/ops/bench_batch_norm.py
bench_manifest_driven: true

BatchNormBwdOp:
ref_api: "torch.nn.functional.batch_norm"
Expand Down Expand Up @@ -406,6 +412,7 @@ BatchNormBwdOp:
op: tileops/ops/norm/batch_norm.py
test: tests/ops/test_batch_norm.py
bench: benchmarks/ops/bench_batch_norm.py
bench_manifest_driven: true

GroupNormFwdOp:
ref_api: "torch.nn.functional.group_norm"
Expand Down Expand Up @@ -453,6 +460,7 @@ GroupNormFwdOp:
op: tileops/ops/norm/group_norm.py
test: tests/ops/test_group_norm.py
bench: benchmarks/ops/bench_group_norm.py
bench_manifest_driven: true

GroupNormFwdOpNoAffine:
ref_api: "torch.nn.functional.group_norm"
Expand Down Expand Up @@ -504,6 +512,7 @@ GroupNormFwdOpNoAffine:
op: tileops/ops/norm/group_norm.py
test: tests/ops/test_group_norm.py
bench: benchmarks/ops/bench_group_norm.py
bench_manifest_driven: true

InstanceNormFwdOp:
ref_api: "torch.nn.functional.instance_norm"
Expand Down Expand Up @@ -551,6 +560,7 @@ InstanceNormFwdOp:
op: tileops/ops/norm/instance_norm.py
test: tests/ops/test_instance_norm.py
bench: benchmarks/ops/bench_instance_norm.py
bench_manifest_driven: true

InstanceNormFwdOpNoAffine:
ref_api: "torch.nn.functional.instance_norm"
Expand Down Expand Up @@ -605,3 +615,4 @@ InstanceNormFwdOpNoAffine:
op: tileops/ops/norm/instance_norm.py
test: tests/ops/test_instance_norm.py
bench: benchmarks/ops/bench_instance_norm.py
bench_manifest_driven: true
Loading