Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions fu/single/DivRTL.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,34 @@ def construct(s, CtrlPktType, num_inports, num_outports, vector_factor_power = 0
s.in0_idx //= s.in0[0:idx_nbits]
s.in1_idx //= s.in1[0:idx_nbits]

PayloadType = s.DataType.get_field_type('payload')
RemainderType = mk_bits(PayloadType.nbits + 1)

s.recv_all_val = Wire(1)
s.dividend = Wire(PayloadType)
s.divisor = Wire(PayloadType)
s.div_quotient = Wire(PayloadType)
s.div_remainder = Wire(PayloadType)

@update
def comb_div_rem():
quotient = PayloadType(0)
remainder = RemainderType(0)

if s.divisor != 0:
for i in range(PayloadType.nbits):
remainder = (remainder << 1) | \
zext(s.dividend[PayloadType.nbits - 1 - i], RemainderType.nbits)
if remainder >= zext(s.divisor, RemainderType.nbits):
remainder = remainder - zext(s.divisor, RemainderType.nbits)
quotient = quotient | \
(PayloadType(1) << (PayloadType.nbits - 1 - i))

s.div_quotient @= quotient
if s.divisor != 0:
s.div_remainder @= s.dividend % s.divisor

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, we already performed above calculation, then the s.div_remainder can just be:

s.div_remainder @= remainder

no?

else:
s.div_remainder @= 0

@update
def comb_logic():
Expand All @@ -49,6 +76,8 @@ def comb_logic():

s.recv_const.rdy @= 0
s.recv_opt.rdy @= 0
s.dividend @= 0
s.divisor @= 0

s.send_to_ctrl_mem.val @= 0
s.send_to_ctrl_mem.msg @= s.CgraPayloadType(0, 0, 0, 0, 0)
Expand All @@ -62,7 +91,9 @@ def comb_logic():

if s.recv_opt.val:
if s.recv_opt.msg.operation == OPT_DIV:
s.send_out[0].msg.payload @= s.recv_in[s.in0_idx].msg.payload // s.recv_in[s.in1_idx].msg.payload
s.dividend @= s.recv_in[s.in0_idx].msg.payload
s.divisor @= s.recv_in[s.in1_idx].msg.payload
s.send_out[0].msg.payload @= s.div_quotient
s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
s.recv_in[s.in1_idx].msg.predicate & \
s.reached_vector_factor
Expand All @@ -73,7 +104,9 @@ def comb_logic():
s.recv_opt.rdy @= s.recv_all_val & s.send_out[0].rdy

elif s.recv_opt.msg.operation == OPT_DIV_CONST:
s.send_out[0].msg.payload @= s.recv_in[s.in0_idx].msg.payload // s.recv_const.msg.payload
s.dividend @= s.recv_in[s.in0_idx].msg.payload
s.divisor @= s.recv_const.msg.payload
s.send_out[0].msg.payload @= s.div_quotient
s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
s.reached_vector_factor
s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_const.val
Expand All @@ -82,6 +115,36 @@ def comb_logic():
s.recv_const.rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_opt.rdy @= s.recv_all_val & s.send_out[0].rdy

elif s.recv_opt.msg.operation == OPT_REM:
s.dividend @= s.recv_in[s.in0_idx].msg.payload
s.divisor @= s.recv_in[s.in1_idx].msg.payload
s.send_out[0].msg.payload @= s.div_remainder
s.send_out[0].msg.predicate @= \
s.recv_in[s.in0_idx].msg.predicate & \
s.recv_in[s.in1_idx].msg.predicate & \
s.reached_vector_factor
s.recv_all_val @= \
s.recv_in[s.in0_idx].val & s.recv_in[s.in1_idx].val
s.send_out[0].val @= s.recv_all_val
s.recv_in[s.in0_idx].rdy @= \
s.recv_all_val & s.send_out[0].rdy
s.recv_in[s.in1_idx].rdy @= \
s.recv_all_val & s.send_out[0].rdy
s.recv_opt.rdy @= s.recv_all_val & s.send_out[0].rdy

elif s.recv_opt.msg.operation == OPT_REM_CONST:
s.dividend @= s.recv_in[s.in0_idx].msg.payload
s.divisor @= s.recv_const.msg.payload
s.send_out[0].msg.payload @= s.div_remainder
s.send_out[0].msg.predicate @= \
s.recv_in[s.in0_idx].msg.predicate & s.reached_vector_factor
s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_const.val
s.send_out[0].val @= s.recv_all_val
s.recv_in[s.in0_idx].rdy @= \
s.recv_all_val & s.send_out[0].rdy
s.recv_const.rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_opt.rdy @= s.recv_all_val & s.send_out[0].rdy

else:
for j in range(num_outports):
s.send_out[j].val @= b1(0)
Expand Down
57 changes: 57 additions & 0 deletions fu/single/test/DivRTL_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,60 @@ def test_div0(input_a, input_b):
src_opt, sink_out)
run_sim(th)

@pytest.mark.parametrize(
'operation, input_a, input_b, expected',
[
(OPT_DIV, 7, 0, 0),
(OPT_REM, 7, 3, 1),
(OPT_REM, 7, 0, 0),
]
)
def test_div_rem_edge_cases(operation, input_a, input_b, expected):
DataType = mk_data(32, 1)
num_inports = 4
num_outports = 2
ConfigType = mk_ctrl(num_inports, num_outports)
FuInType = mk_bits(clog2(num_inports + 1))
DataAddrType = mk_bits(3)
CtrlAddrType = mk_bits(3)
CgraPayloadType = mk_cgra_payload(DataType, DataAddrType, ConfigType,
CtrlAddrType)
IntraCgraPktType = mk_intra_cgra_pkt(1, 1, 1, CgraPayloadType)
ctrl = ConfigType(
operation, [FuInType(1), FuInType(3), FuInType(0), FuInType(0)])
th = TestHarness(
DivRTL, IntraCgraPktType, DataType, ConfigType,
num_inports, num_outports, 8,
[DataType(input_a, 1)], [DataType(input_b, 1)],
[DataType(0, 1)], [ctrl], [DataType(expected, 1)],
)
run_sim(th)

@pytest.mark.parametrize(
"const_value, expected",
[
(3, 2),
(0, 0),
]
)
def test_rem_const(const_value, expected):
DataType = mk_data(32, 1)
num_inports = 4
num_outports = 2
ConfigType = mk_ctrl(num_inports, num_outports)
FuInType = mk_bits(clog2(num_inports + 1))
DataAddrType = mk_bits(3)
CtrlAddrType = mk_bits(3)
CgraPayloadType = mk_cgra_payload(DataType, DataAddrType, ConfigType,
CtrlAddrType)
IntraCgraPktType = mk_intra_cgra_pkt(1, 1, 1, CgraPayloadType)
ctrl = ConfigType(
OPT_REM_CONST,
[FuInType(1), FuInType(0), FuInType(0), FuInType(0)])
th = TestHarness(
DivRTL, IntraCgraPktType, DataType, ConfigType,
num_inports, num_outports, 8,
[DataType(8, 1)], [], [DataType(const_value, 1)],
[ctrl], [DataType(expected, 1)],
)
run_sim(th)
38 changes: 38 additions & 0 deletions fu/single/translate/DivRTL_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

from pathlib import Path

from pymtl3 import *
from pymtl3.passes.backends.verilog import VerilogTranslationPass
from pymtl3.passes.backends.verilog.errors import VerilogImportError
from pymtl3.stdlib.test_utils import config_model_with_cmdline_opts
from ..DivRTL import DivRTL
from ....lib.messages import *
from ....lib.opt_type import *

DataType = mk_data(32, 1)
num_inports = 4
num_outports = 2
CtrlType = mk_ctrl(num_inports, num_outports)
DataAddrType = mk_bits(3)
CtrlAddrType = mk_bits(3)
CgraPayloadType = mk_cgra_payload(DataType, DataAddrType, CtrlType, CtrlAddrType)
IntraCgraPktType = mk_intra_cgra_pkt(1, 1, 1, CgraPayloadType)

def test_translate_rem_operator(cmdline_opts):
dut = DivRTL(IntraCgraPktType, num_inports, num_outports)
dut.set_metadata(VerilogTranslationPass.explicit_module_name, 'DivRTL')
dut.set_metadata(VerilogTranslationPass.explicit_file_name, 'DivRTL__pickled.v')

try:
config_model_with_cmdline_opts(dut, cmdline_opts, duts=[])
except VerilogImportError as e:
# Translation already emitted Verilog before the optional Verilator import.
# On machines without Verilator, still inspect the generated RTL.
assert 'verilator: not found' in str(e)

verilog = Path('DivRTL__pickled.v').read_text()
assert 'div_remainder = dividend % divisor' in verilog
assert "if ( divisor != 32'd0 )" in verilog
payload_assigns = [line for line in verilog.splitlines()
if 'payload =' in line]
assert all('/' not in line for line in payload_assigns)
48 changes: 48 additions & 0 deletions lib/opt_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,46 @@
OPT_GEP_CONST = OpCodeType( 89 )
OPT_GEP_2D = OpCodeType( 90 )
OPT_GEP_2D_CONST = OpCodeType( 91 )
OPT_GRT_ONCE_CONST = OpCodeType( 92 )
OPT_GTE_CONST = OpCodeType( 93 )
OPT_LT_CONST = OpCodeType( 94 )
OPT_GT_CONST = OpCodeType( 95 )
OPT_REM_CONST = OpCodeType( 96 )
OPT_AND_CONST = OpCodeType( 97 )
OPT_OR_CONST = OpCodeType( 98 )
OPT_LLS_CONST = OpCodeType( 99 )

OPT_USES_CONST_LIST = (
OPT_CONST,
OPT_ADD_CONST,
OPT_SUB_CONST,
OPT_DIV_CONST,
OPT_EQ_CONST,
OPT_NE_CONST,
OPT_PHI_CONST,
OPT_LD_CONST,
OPT_STR_CONST,
OPT_MUL_CONST,
OPT_MUL_CONST_ADD,
OPT_ADD_CONST_LD,
OPT_INC_NE_CONST_NOT_GRT,
OPT_FADD_CONST,
OPT_FMUL_CONST,
OPT_VEC_ADD_CONST,
OPT_VEC_SUB_CONST,
OPT_VEC_ADD_CONST_COMBINED,
OPT_VEC_SUB_CONST_COMBINED,
OPT_GRT_ONCE_CONST,
OPT_GTE_CONST,
OPT_LT_CONST,
OPT_GT_CONST,
OPT_AND_CONST,
OPT_OR_CONST,
OPT_LLS_CONST,
OPT_REM_CONST,
OPT_GEP_CONST,
OPT_GEP_2D_CONST,
)

OPT_SYMBOL_DICT = {
OPT_START : "(start)",
Expand Down Expand Up @@ -201,6 +241,14 @@
OPT_REM_INCLUSIVE_START : "(%st)",
OPT_DIV_INCLUSIVE_END : "(/ed)",
OPT_REM_INCLUSIVE_END : "(%ed)",
OPT_GRT_ONCE_CONST : "(grant_once')",
OPT_GTE_CONST : "(?>=')",
OPT_LT_CONST : "(?<')",
OPT_GT_CONST : "(?>')",
OPT_AND_CONST : "(&')",
OPT_OR_CONST : "(|')",
OPT_LLS_CONST : "(<<')",
OPT_REM_CONST : "(%')",

OPT_LOOP_CONTROL : "(loop_ctrl)",
OPT_STREAM_LD : "(streaming_ld)",
Expand Down
Loading