Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lake/dsl/dsl_examples/memtile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from lake.dsl.lake_imports import *
from frail.ast import *

# example of DSL (makes current mem tile with 2 agg,
# wide SRAM, 2 tb
Expand Down
119 changes: 80 additions & 39 deletions lake/dsl/hw_top_lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from lake.dsl.memory import mem_inst
from lake.modules.for_loop import ForLoop
from lake.modules.addr_gen import AddrGen
from lake.modules.addressor import *
from lake.modules.spec.sched_gen import SchedGen
from lake.passes.passes import lift_config_reg
from lake.utils.util import safe_wire, trim_config_list
Expand All @@ -23,7 +24,8 @@ def __init__(self,
input_ports,
output_ports,
memories,
edges):
edges,
addressor_info):

super().__init__("LakeTop", debug=True)

Expand All @@ -41,6 +43,12 @@ def __init__(self,
self.memories = memories
self.edges = edges

# addressor_info
self.use_default_addr = addressor_info["use_default"]
if not self.use_default_addr:
self.addressor = addressor_info["addressor"]
self.addressor_name = addressor_info["name"]

# tile enable and clock
self.tile_en = self.input("tile_en", 1)
self.tile_en.add_attribute(ConfigRegAttr("Tile logic enable manifested as clock gate"))
Expand Down Expand Up @@ -184,15 +192,23 @@ def increment_cycle_count(self):
rst_n=self.rst_n,
step=self.valid)

newAG = AddrGen(iterator_support=input_dim,
config_width=max(1, clog2(input_stride))) # self.default_config_width)
self.add_child(f"input_port{input_port_index}_2{in_mem}_write_addr_gen",
newAG,
clk=self.gclk,
rst_n=self.rst_n,
step=self.valid,
mux_sel=forloop.ports.mux_sel_out,
restart=forloop.ports.restart)
if self.use_default_addr:
newAG = AddrGen(iterator_support=input_dim,
config_width=max(1, clog2(input_stride))) # self.default_config_width)
self.add_child(f"input_port{input_port_index}_2{in_mem}_write_addr_gen",
newAG,
clk=self.gclk,
rst_n=self.rst_n,
step=self.valid,
mux_sel=forloop.ports.mux_sel_out,
restart=forloop.ports.restart)

else:
newAG = AddressorWrapper(self.addressor_name)
self.add_child(f"input_port{input_port_index}_2{in_mem}_write_addr_gen",
newAG,
clk=self.gclk,
step=self.valid)

if self.memories[in_mem]["num_read_write_ports"] == 0:
safe_wire(self, self.mem_insts[in_mem].ports.write_addr[0], newAG.ports.addr_out)
Expand Down Expand Up @@ -243,15 +259,23 @@ def increment_cycle_count(self):
rst_n=self.rst_n,
step=self.valid)

newAG = AddrGen(iterator_support=output_dim,
config_width=max(1, clog2(output_stride))) # self.default_config_width)
self.add_child(f"{out_mem}2output_port{output_port_index}_read_addr_gen",
newAG,
clk=self.gclk,
rst_n=self.rst_n,
step=self.valid,
mux_sel=forloop.ports.mux_sel_out,
restart=forloop.ports.restart)
if self.use_default_addr:
newAG = AddrGen(iterator_support=output_dim,
config_width=max(1, clog2(output_stride))) # self.default_config_width)
self.add_child(f"{out_mem}2output_port{output_port_index}_read_addr_gen",
newAG,
clk=self.gclk,
rst_n=self.rst_n,
step=self.valid,
mux_sel=forloop.ports.mux_sel_out,
restart=forloop.ports.restart)

else:
newAG = AddressorWrapper(self.addressor_name)
self.add_child(f"{out_mem}2output_port{output_port_index}_read_addr_gen",
newAG,
clk=self.gclk,
step=self.valid)

if self.memories[out_mem]["num_read_write_ports"] == 0:
safe_wire(self, self.mem_insts[out_mem].ports.read_addr[0], newAG.ports.addr_out)
Expand Down Expand Up @@ -296,15 +320,22 @@ def increment_cycle_count(self):
step=self.valid)

# create input addressor
readAG = AddrGen(iterator_support=edge["dim"],
config_width=self.default_config_width)
self.add_child(f"{edge_name}_read_addr_gen",
readAG,
clk=self.gclk,
rst_n=self.rst_n,
step=self.valid,
mux_sel=forloop.ports.mux_sel_out,
restart=forloop.ports.restart)
if self.use_default_addr:
readAG = AddrGen(iterator_support=edge["dim"],
config_width=self.default_config_width)
self.add_child(f"{edge_name}_read_addr_gen",
readAG,
clk=self.gclk,
rst_n=self.rst_n,
step=self.valid,
mux_sel=forloop.ports.mux_sel_out,
restart=forloop.ports.restart)
else:
readAG = AddressorWrapper(self.addressor_name)
self.add_child(f"{edge_name}_read_addr_gen",
readAG,
clk=self.gclk,
step=self.valid)

# assign read address to all from memories
if self.memories[edge["from_signal"][0]]["num_read_write_ports"] == 0:
Expand Down Expand Up @@ -351,13 +382,20 @@ def increment_cycle_count(self):
self.mem_insts[edge["from_signal"][0]].ports.data_out)

# create output addressor
writeAG = AddrGen(iterator_support=edge["dim"],
config_width=self.default_config_width)
# step, mux_sel, restart may need delayed signals (assigned later)
self.add_child(f"{edge_name}_write_addr_gen",
writeAG,
clk=self.gclk,
rst_n=self.rst_n)
if self.use_default_addr:
writeAG = AddrGen(iterator_support=edge["dim"],
config_width=self.default_config_width)
# step, mux_sel, restart may need delayed signals (assigned later)
self.add_child(f"{edge_name}_write_addr_gen",
writeAG,
clk=self.gclk,
rst_n=self.rst_n)
else:
writeAG = AddressorWrapper(self.addressor_name)
self.add_child(f"{edge_name}_write_addr_gen",
writeAG,
clk=self.gclk,
step=self.valid)

# set write addr for to memories
if self.memories[edge["to_signal"][0]]["num_read_write_ports"] == 0:
Expand Down Expand Up @@ -440,12 +478,14 @@ def get_delayed_write(self):
# assign delayed signals for write addressor if needed
if self.delay == 0:
self.wire(writeAG.ports.step, self.valid)
self.wire(writeAG.ports.mux_sel, self.forloop.ports.mux_sel_out)
self.wire(writeAG.ports.restart, self.forloop.ports.restart)
if self.use_default_addr:
self.wire(writeAG.ports.mux_sel, self.forloop.ports.mux_sel_out)
self.wire(writeAG.ports.restart, self.forloop.ports.restart)
else:
self.wire(writeAG.ports.step, self.delayed_writes[self.delay - 1])
self.wire(writeAG.ports.mux_sel, self.delayed_mux_sels[self.delay - 1])
self.wire(writeAG.ports.restart, self.delayed_restarts[self.delay - 1])
if self.use_default_addr:
self.wire(writeAG.ports.mux_sel, self.delayed_mux_sels[self.delay - 1])
self.wire(writeAG.ports.restart, self.delayed_restarts[self.delay - 1])

# create accessor for edge
newSG = SchedGen(iterator_support=edge["dim"],
Expand Down Expand Up @@ -482,6 +522,7 @@ def get_delayed_write(self):
flush_port = self.internal_generator.get_port("flush")

# bring config registers up to top level

lift_config_reg(self.internal_generator)

# formal subproblem annotations - uncomment to generate relevant files
Expand Down
39 changes: 37 additions & 2 deletions lake/dsl/top_lake.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import copy
import sys
import pathlib

from lake.dsl.memory import mem_inst, port_to_info
from lake.dsl.edge import edge_inst, get_full_edge_params
from lake.dsl.helper import *
from lake.dsl.hw_top_lake import TopLakeHW

from lake.utils.sram_macro import SRAMMacroInfo
from lake.passes.passes import change_sram_port_names
from lake.modules.cfg_reg_wrapper import CFGRegWrapper

from frail.verilog_printer import *


class Lake():
def __init__(self,
Expand All @@ -34,6 +38,10 @@ def __init__(self,
self.hw_edges = []
self.hardware_edges = []

self.addressor_info = {"use_default": True,
"addressor": None,
"name": "top"}

# mux info originally created for hardware, but not used
# keeping here in case logic is useful for the future
self.mux_count = 0
Expand Down Expand Up @@ -99,6 +107,16 @@ def add_input_output_edge(self, port, mem_name, dim=6, max_range=65536, max_stri
self.memories[mem_name][f"is_{io}"] = True
self.memories[mem_name][f"{io}_port"] = port

def set_addressor(self, addressor=None, name="top"):
self.addressor_info = {"use_default": True if addressor is None else False,
"addressor": addressor,
"name": name}

if not self.addressor_info["use_default"]:
lake_dir = pathlib.Path(__file__).parent.parent.absolute()
addr_file = f"{lake_dir}/modules/addressor.py"
self.print_verilog_helper(addr_file, "w+", False)

# called after all edges are added
def banking(self):
self.hw_memories = copy.deepcopy(self.memories)
Expand Down Expand Up @@ -263,11 +281,15 @@ def generate_hardware(self, wrap_cfg=True):
# print()
# print(self.hardware_edges)

# need this import to contain updated frail addressor
from lake.dsl.hw_top_lake import TopLakeHW

hw = TopLakeHW(self.word_width,
self.input_ports,
self.output_ports,
self.hw_memories,
self.hardware_edges)
self.hardware_edges,
self.addressor_info)

# Wrap it if we need to...
if wrap_cfg:
Expand Down Expand Up @@ -295,3 +317,16 @@ def construct_lake(self, filename="Lake_hw.sv", wrap_cfg=False):
optimize_if=False,
check_flip_flop_always_ff=False,
additional_passes={"change sram port names": sram_port_pass})

if not self.addressor_info["use_default"]:
self.print_verilog_helper(filename, "a", True)

def print_verilog_helper(self, filename, mode, get_verilog):
with open(filename, mode) as fi:
orig_stdout = sys.stdout
sys.stdout = fi
print_verilog(self.addressor_info["addressor"],
top_name=self.addressor_info["name"],
add_step=True,
get_verilog=get_verilog)
sys.stdout = orig_stdout
1 change: 1 addition & 0 deletions lake/passes/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self):
IRVisitor.__init__(self)

def visit(self, node):
# doesn't include external nodes...
if isinstance(node, _kratos.Generator):
ports_ = node.get_port_names()
for port_name in ports_:
Expand Down
17 changes: 13 additions & 4 deletions lake/utils/parse_clkwork_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def extract_controller(file_path):
return ctrl_info


def map_controller(controller, name):
def map_controller(controller, name, transform_data=True):
ctrl_dim = controller.dim
ctrl_ranges = controller.extent
ctrl_cyc_strides = controller.cyc_stride
Expand Down Expand Up @@ -148,15 +148,24 @@ def map_controller(controller, name):
(tform_extent, tform_cyc_strides) = transform_strides_and_ranges(ctrl_ranges, ctrl_cyc_strides, ctrl_dim)
tform_in_data_strides = None
if ctrl_in_data_strt is not None:
(tform_extent, tform_in_data_strides) = transform_strides_and_ranges(ctrl_ranges, ctrl_in_data_strides, ctrl_dim)
if transform_data:
(tform_extent, tform_in_data_strides) = transform_strides_and_ranges(ctrl_ranges, ctrl_in_data_strides, ctrl_dim)
else:
tform_extent, tform_in_data_strides = ctrl_ranges, ctrl_in_data_strides

tform_out_data_strides = None
if ctrl_out_data_strt is not None:
(tform_extent, tform_out_data_strides) = transform_strides_and_ranges(ctrl_ranges, ctrl_out_data_strides, ctrl_dim)
if transform_data:
(tform_extent, tform_out_data_strides) = transform_strides_and_ranges(ctrl_ranges, ctrl_out_data_strides, ctrl_dim)
else:
tform_extent, tform_out_data_strides = ctrl_ranges, ctrl_out_data_strides

tform_mux_data_strides = None
if ctrl_mux_data_strt is not None:
(tform_extent, tform_mux_data_strides) = transform_strides_and_ranges(ctrl_ranges, ctrl_mux_data_strides, ctrl_dim)
if transform_data:
(tform_extent, tform_mux_data_strides) = transform_strides_and_ranges(ctrl_ranges, ctrl_mux_data_strides, ctrl_dim)
else:
tform_extent, tform_mux_data_strides = ctrl_ranges, ctrl_mux_data_strides

# Basically give a starting margin for everything...
garnet_delay = 0
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"kratos",
"fault",
"magma-lang",
"pytest"
"pytest",
"frail"
]
)
17 changes: 13 additions & 4 deletions tests/test_dsl_memtile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import tempfile
import os
import shutil

from lake.passes.passes import lift_config_reg, change_sram_port_names
from lake.utils.sram_macro import SRAMMacroInfo
Expand All @@ -22,6 +23,7 @@ def base_lake_tester(config_path,
in_ports,
out_ports,
lt_dut,
use_default_addr=True,
stencil_valid=False):

configs = lt_dut.get_static_bitstream(config_path, in_file_name, out_file_name)
Expand All @@ -40,6 +42,7 @@ def base_lake_tester(config_path,
def gen_test_lake(config_path,
stream_path,
lt_dut,
tile,
in_file_name="input",
out_file_name="output",
in_ports=2,
Expand All @@ -51,7 +54,8 @@ def gen_test_lake(config_path,
out_file_name,
in_ports,
out_ports,
lt_dut)
lt_dut,
tile.addressor_info["use_default"])

tester.circuit.clk_en = 1
tester.circuit.clk_mem = 0
Expand Down Expand Up @@ -84,10 +88,14 @@ def gen_test_lake(config_path,
tester.step(2)

with tempfile.TemporaryDirectory() as tempdir:
tempdir = "hw"
# tempdir = "hw"
if not tile.addressor_info["use_default"]:
addr_verilog = tile.addressor_info["name"] + ".v"
tile.print_verilog_helper(addr_verilog, "w+", True)
shutil.copy(addr_verilog, tempdir)
tester.compile_and_run(target="verilator",
directory=tempdir,
flags=["-Wno-fatal", "--trace"])
flags=["-Wno-fatal"])


def test_conv_3_3():
Expand All @@ -99,7 +107,8 @@ def test_conv_3_3():

gen_test_lake(config_path=config_path,
stream_path=stream_path,
lt_dut=lt_dut)
lt_dut=lt_dut,
tile=tile)


if __name__ == "__main__":
Expand Down