# main_swiglu.py (from a fork of KULeuven-MICAS/stream)
# Original listing: 135 lines (121 loc), 5.7 KB.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import argparse
import logging as _logging
import os
import re
from stream.api import optimize_allocation_co
from stream.inputs.aie.mapping.make_swiglu_mapping import make_swiglu_mapping2
from stream.inputs.aie.workload.make_onnx_swiglu import make_swiglu_workload
_logging_level = _logging.INFO
_logging_format = "%(asctime)s - %(name)s.%(funcName)s +%(lineno)s - %(levelname)s - %(message)s"
def _derive_experiment_id(accelerator, workload_path, rows, cols):
    """Return the ``<hw>-<workload>-<rows>_row_<cols>_col`` experiment identifier.

    Args:
        accelerator: Path to the hardware description file.
        workload_path: Path returned by the workload builder.
        rows: Number of AIE rows (only used in the identifier).
        cols: Number of AIE columns (only used in the identifier).
    """
    # Hardware name: the accelerator file name up to its first dot.
    hw_name = accelerator.split("/")[-1].split(".")[0]
    # Workload name: the last "/"- or "."-separated token of the path; when the
    # path ends in ".onnx", take the token just before that extension instead.
    wl_parts = re.split(r"/|\.", workload_path)
    wl_name = wl_parts[-2] if wl_parts[-1] == "onnx" else wl_parts[-1]
    mapping_name = f"{rows}_row_{cols}_col"
    return f"{hw_name}-{wl_name}-{mapping_name}"


def _configure_file_logging(log_path):
    """Make ``log_path`` the root logger's only handler and return the root logger."""
    # Ensure the output directory exists.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    logger = _logging.getLogger()
    logger.setLevel(_logging_level)
    # Remove all existing handlers (e.g. ones added by Snakemake or libraries) and
    # close them too: removeHandler() alone leaves a FileHandler's file open, so a
    # repeated call in the same process would otherwise leak the previous log file.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()
    file_handler = _logging.FileHandler(log_path)
    file_handler.setFormatter(_logging.Formatter(_logging_format))
    logger.addHandler(file_handler)
    return logger


def run_main_aie_codegen_swiglu(  # noqa: PLR0913
    seq_len,
    embedding_dim,
    hidden_dim,
    in_dtype,
    out_dtype,
    trace_size,
    rows,
    cols,
    npu,
    last_gemm_down: bool = False,
):
    """Run Stream's allocation optimization with code generation for a SwiGLU workload.

    Builds the SwiGLU ONNX workload and its mapping, then points logging at
    ``outputs/<experiment_id>/stream.log`` under the current working directory.
    Finally it calls ``optimize_allocation_co`` with code generation enabled,
    using the ``whole_array_strix.yaml`` hardware description next to this file.

    Side effect: every handler on the root logger is replaced by a single file
    handler for this experiment's log.

    Args:
        seq_len: Sequence length of the input.
        embedding_dim: Embedding dimension of the input.
        hidden_dim: Hidden dimension of the SwiGLU block.
        in_dtype: Input data type name (the CLI default is "bf16").
        out_dtype: Output data type name (the CLI default is "bf16").
        trace_size: Trace buffer size forwarded to the optimizer.
        rows: Number of AIE rows; only used to build the experiment identifier.
        cols: Number of AIE columns; forwarded as ``nb_cols_to_use``.
        npu: NPU target forwarded to the optimizer.
        last_gemm_down: Forwarded to the workload and mapping builders; per the CLI
            help, False means the last down-projection GEMM is skipped.

    Returns:
        The ``"module"`` entry of the context returned by ``optimize_allocation_co``,
        or ``None`` if the context has no such entry.
    """
    accelerator = os.path.join(
        os.path.dirname(__file__), "stream/inputs/aie/hardware/whole_array_strix.yaml"
    )
    # Create the SwiGLU ONNX model and its mapping.
    workload_path = make_swiglu_workload(
        seq_len, embedding_dim, hidden_dim, in_dtype, out_dtype, last_gemm_down=last_gemm_down
    )
    mapping_path = make_swiglu_mapping2(seq_len, embedding_dim, hidden_dim, last_gemm_down)

    experiment_id = _derive_experiment_id(accelerator, workload_path, rows, cols)

    # The log lives under the same "outputs" root that is passed to the optimizer below.
    log_path = os.path.join(os.getcwd(), f"outputs/{experiment_id}/stream.log")
    logger = _configure_file_logging(log_path)
    logger.info(
        f"Running AIE code generation for Swiglu with "
        f"seq_len={seq_len}, embedding_dim={embedding_dim}, hidden_dim={hidden_dim}"
    )

    ctx = optimize_allocation_co(
        hardware=accelerator,
        workload=workload_path,
        mapping=mapping_path,
        experiment_id=experiment_id,
        output_path="outputs",
        skip_if_exists=False,
        enable_codegen=True,
        trace_size=trace_size,
        nb_cols_to_use=cols,
        npu=npu,
    )
    # Post-processing (Perfetto JSON export, memory-usage plots) is currently disabled.
    return ctx.get("module")
if __name__ == "__main__":
    # Command-line entry point: parse the SwiGLU problem/hardware parameters, run code
    # generation, and write the resulting module to outputs/ as an .mlir file.
    parser = argparse.ArgumentParser(description="Run AIE code generation for SwiGLU")
    parser.add_argument("--seq_len", type=int, required=True, help="Sequence length (seq_len dimension of the input)")
    parser.add_argument(
        "--embedding_dim", type=int, required=True, help="Embedding dimension (embedding_dim dimension of the input)"
    )
    parser.add_argument(
        "--hidden_dim", type=int, required=True, help="Hidden dimension (hidden_dim dimension of the output)"
    )
    parser.add_argument("--in_dtype", type=str, default="bf16", help="Input data type (default: bf16)")
    parser.add_argument("--out_dtype", type=str, default="bf16", help="Output data type (default: bf16)")
    parser.add_argument("--trace_size", type=int, default=1048576, help="Size of the trace buffer (default: 1048576)")
    parser.add_argument("--rows", type=int, default=4, help="Number of AIE rows to use (has to be 4)")
    parser.add_argument("--cols", type=int, default=1, help="Number of AIE columns to use (default: 1)")
    parser.add_argument("--npu", type=str, default="npu2", help="NPU type to target (default: npu2)")
    # store_false: passing the flag sets last_gemm_down=False; omitting it keeps True.
    parser.add_argument(
        "--no_last_gemm_down",
        dest="last_gemm_down",
        action="store_false",
        default=True,
        help="If set, the last gemm down projection is skipped",
    )
    args = parser.parse_args()
    module = run_main_aie_codegen_swiglu(
        args.seq_len,
        args.embedding_dim,
        args.hidden_dim,
        args.in_dtype,
        args.out_dtype,
        args.trace_size,
        args.rows,
        args.cols,
        args.npu,
        last_gemm_down=args.last_gemm_down,
    )
    # The optimizer context may lack a "module" entry; without this check the
    # script would write the literal text "None" into the .mlir file.
    if module is None:
        raise SystemExit("Code generation produced no module; nothing was saved.")
    save_path = f"outputs/swiglu_module_{args.seq_len}_{args.embedding_dim}_{args.hidden_dim}.mlir"
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(str(module))
    print(f"Saved generated module to {save_path}")