Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions benchmarks/ops/bench_fused_moe_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,22 @@ def ref_program(self, *args):

class MoEExpertsBenchmark(BenchmarkBase[MoEExpertsTest]):

_roofline_cache: Optional[tuple[float, float]] = None

def __init__(self, test, op):
super().__init__(test)
self._op = op
Comment thread
lcy-seso marked this conversation as resolved.
Outdated

def _get_roofline(self) -> tuple[float, float]:
if self._roofline_cache is None:
self._roofline_cache = self._op.eval_roofline()
return self._roofline_cache

def calculate_flops(self) -> Optional[float]:
t = self.workload
return t.num_tokens * t.top_k * 6 * t.ffn_size * t.hidden_size
return self._get_roofline()[0]

def calculate_memory(self) -> Optional[float]:
t = self.workload
elem = 2 # bfloat16
weights = t.num_experts * 3 * t.ffn_size * t.hidden_size * elem
tokens = 2 * t.num_tokens * t.hidden_size * elem
return weights + tokens
return self._get_roofline()[1]


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -146,7 +152,6 @@ def test_moe_experts_nopad_bench(
dtype = torch.bfloat16
Comment thread
superAngGao marked this conversation as resolved.
Outdated
test = MoEExpertsTest(num_tokens, num_experts, top_k, hidden_size, ffn_size, dtype)
hidden, w1, w2, topk_weights, topk_ids = test.gen_inputs()
bm = MoEExpertsBenchmark(test)

kwargs = dict(
num_tokens=num_tokens, num_experts=num_experts, top_k=top_k,
Expand All @@ -158,6 +163,7 @@ def test_moe_experts_nopad_bench(

# -- TileOPs nopad (3WG persistent) --------------------------------------
nopad = FusedMoEExpertsNopadPersistent3WGFwdOp(**kwargs)
bm = MoEExpertsBenchmark(test, nopad)

def _nopad_fn(hidden, w1, w2, topk_weights, topk_ids):
nopad.forward(
Expand Down
3 changes: 2 additions & 1 deletion tileops/manifest/moe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ MoeGroupedGemmNopadFwdOp:
op: tileops/ops/moe/routed_expert/moe_grouped_gemm_nopad.py
test: tests/ops/test_moe_grouped_gemm_nopad.py
bench: benchmarks/ops/bench_moe_grouped_gemm_nopad.py
bench_manifest_driven: true
kernel_map:
moe_grouped_gemm_kernel: MoeGroupedGemmNopadKernel

Expand Down Expand Up @@ -266,7 +267,7 @@ FusedMoEExpertsNopadPersistent3WGFwdOp:
op: tileops/ops/moe/routed_expert/fused_routed_expert.py
test: tests/ops/test_fused_moe_experts.py
bench: benchmarks/ops/bench_fused_moe_experts.py
bench_manifest_driven: false
bench_manifest_driven: true
kernel_map:
permute_nopad_kernel: MoePermuteNopadKernel
moe_grouped_gemm_kernel: GroupedGemmPersistent3WGKernel
Expand Down
12 changes: 12 additions & 0 deletions tileops/manifest/normalization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ RMSNormFwdOp:
op: tileops/ops/norm/rms_norm.py
test: tests/ops/test_rms_norm.py
bench: benchmarks/ops/bench_rms_norm.py
bench_manifest_driven: true

LayerNormFwdOp:
ref_api: "torch.nn.functional.layer_norm"
Expand Down Expand Up @@ -101,6 +102,7 @@ LayerNormFwdOp:
op: tileops/ops/norm/layer_norm.py
test: tests/ops/test_layer_norm.py
bench: benchmarks/ops/bench_layer_norm.py
bench_manifest_driven: true

AdaLayerNormFwdOp:
ref_api: "none"
Expand Down Expand Up @@ -146,6 +148,7 @@ AdaLayerNormFwdOp:
op: tileops/ops/norm/ada_layer_norm.py
test: tests/ops/test_ada_layer_norm.py
bench: benchmarks/ops/bench_ada_layer_norm.py
bench_manifest_driven: true

AdaLayerNormZeroFwdOp:
ref_api: "none"
Expand Down Expand Up @@ -193,6 +196,7 @@ AdaLayerNormZeroFwdOp:
op: tileops/ops/norm/ada_layer_norm_zero.py
test: tests/ops/test_ada_layer_norm_zero.py
bench: benchmarks/ops/bench_ada_layer_norm.py
bench_manifest_driven: true

FusedAddLayerNormFwdOp:
ref_api: "none"
Expand Down Expand Up @@ -243,6 +247,7 @@ FusedAddLayerNormFwdOp:
op: tileops/ops/norm/fused_add_layer_norm.py
test: tests/ops/test_fused_add_layer_norm.py
bench: benchmarks/ops/bench_fused_add_layer_norm.py
bench_manifest_driven: true

FusedAddRMSNormFwdOp:
ref_api: "none"
Expand Down Expand Up @@ -294,6 +299,7 @@ FusedAddRMSNormFwdOp:
op: tileops/ops/norm/fused_add_rms_norm.py
test: tests/ops/test_fused_add_rms_norm.py
bench: benchmarks/ops/bench_fused_add_rms_norm.py
bench_manifest_driven: true
Comment thread
lcy-seso marked this conversation as resolved.
Outdated

# ---------------------------------------------------------------------------
# norm — spatial-norm ops (operate over spatial/channel dims, shape: [N, C, *spatial])
Expand Down Expand Up @@ -352,6 +358,7 @@ BatchNormFwdOp:
op: tileops/ops/norm/batch_norm.py
test: tests/ops/test_batch_norm.py
bench: benchmarks/ops/bench_batch_norm.py
bench_manifest_driven: true

BatchNormBwdOp:
ref_api: "torch.nn.functional.batch_norm"
Expand Down Expand Up @@ -406,6 +413,7 @@ BatchNormBwdOp:
op: tileops/ops/norm/batch_norm.py
test: tests/ops/test_batch_norm.py
bench: benchmarks/ops/bench_batch_norm.py
bench_manifest_driven: true

GroupNormFwdOp:
ref_api: "torch.nn.functional.group_norm"
Expand Down Expand Up @@ -453,6 +461,7 @@ GroupNormFwdOp:
op: tileops/ops/norm/group_norm.py
test: tests/ops/test_group_norm.py
bench: benchmarks/ops/bench_group_norm.py
bench_manifest_driven: true

GroupNormFwdOpNoAffine:
ref_api: "torch.nn.functional.group_norm"
Expand Down Expand Up @@ -504,6 +513,7 @@ GroupNormFwdOpNoAffine:
op: tileops/ops/norm/group_norm.py
test: tests/ops/test_group_norm.py
bench: benchmarks/ops/bench_group_norm.py
bench_manifest_driven: true

InstanceNormFwdOp:
ref_api: "torch.nn.functional.instance_norm"
Expand Down Expand Up @@ -551,6 +561,7 @@ InstanceNormFwdOp:
op: tileops/ops/norm/instance_norm.py
test: tests/ops/test_instance_norm.py
bench: benchmarks/ops/bench_instance_norm.py
bench_manifest_driven: true

InstanceNormFwdOpNoAffine:
ref_api: "torch.nn.functional.instance_norm"
Expand Down Expand Up @@ -605,3 +616,4 @@ InstanceNormFwdOpNoAffine:
op: tileops/ops/norm/instance_norm.py
test: tests/ops/test_instance_norm.py
bench: benchmarks/ops/bench_instance_norm.py
bench_manifest_driven: true
2 changes: 2 additions & 0 deletions tileops/manifest/reduction.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ ArgmaxFwdOp:
op: tileops/ops/reduction/argmax.py
test: tests/ops/test_argreduce.py
bench: benchmarks/ops/bench_argreduce.py
bench_manifest_driven: true
Comment thread
superAngGao marked this conversation as resolved.
Outdated

ArgminFwdOp:
ref_api: "torch.argmin"
Expand Down Expand Up @@ -628,6 +629,7 @@ ArgminFwdOp:
op: tileops/ops/reduction/argmin.py
test: tests/ops/test_argreduce.py
bench: benchmarks/ops/bench_argreduce.py
bench_manifest_driven: true
Comment thread
superAngGao marked this conversation as resolved.
Outdated

# ---------------------------------------------------------------------------
# reduction -- logical reductions (output dtype: bool)
Expand Down
Loading