Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 72 additions & 46 deletions fu/single/ExclusiveDivRTL.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pymtl3 import *
from ..basic.Fu import Fu
from ...lib.opt_type import *
from pymtl3.passes.backends.verilog import *

class ExclusiveDivRTL(Fu):

Expand All @@ -34,32 +33,34 @@ def construct(s, CtrlPktType, num_inports, num_outports, latency = 4, vector_fac
idx_nbits = clog2(num_inports)
s.in0_idx = Wire(idx_nbits)
s.in1_idx = Wire(idx_nbits)
LatencyType = mk_bits(clog2(latency) + 1)
s.cur_cycle = Wire(LatencyType)

s.in0_idx //= s.in0[0:idx_nbits]
s.in1_idx //= s.in1[0:idx_nbits]

s.recv_all_val = Wire(1)
s.do_div = Wire(1)
s.accept_input = Wire(1)
s.launch_msg = Wire(s.DataType)
s.pipe_valid = [Wire(1) for _ in range(latency)]
s.next_pipe_valid = [Wire(1) for _ in range(latency)]
s.pipe_msg = [Wire(s.DataType) for _ in range(latency)]
s.next_pipe_msg = [Wire(s.DataType) for _ in range(latency)]
s.stage_can_advance = [Wire(1) for _ in range(latency)]

@update_ff
def comb_ff():
if (s.cur_cycle == latency - 1) & s.send_out[0].rdy:
s.cur_cycle <<= 0
elif s.cur_cycle == latency - 1:
s.cur_cycle <<= s.cur_cycle
elif s.do_div:
s.cur_cycle <<= s.cur_cycle + 1
elif (s.recv_all_val & (s.recv_opt.msg.operation == OPT_DIV)):
s.cur_cycle <<= 1
else:
s.cur_cycle <<= 0
for i in range(latency):
if s.reset | s.clear:
s.pipe_valid[i] <<= 0
s.pipe_msg[i] <<= s.DataType()
else:
s.pipe_valid[i] <<= s.next_pipe_valid[i]
s.pipe_msg[i] <<= s.next_pipe_msg[i]

@update
def comb_logic():

s.recv_all_val @= 0
s.accept_input @= 0
# For pick input register
s.in0 @= 0
s.in1 @= 0
Expand All @@ -68,48 +69,79 @@ def comb_logic():
for i in range(num_outports):
s.send_out[i].val @= 0
s.send_out[i].msg @= s.DataType()
s.launch_msg @= s.DataType()

s.recv_const.rdy @= 0
s.recv_opt.rdy @= 0

s.send_to_ctrl_mem.val @= 0
s.send_to_ctrl_mem.msg @= s.CgraPayloadType(0, 0, 0, 0, 0)
s.recv_from_ctrl_mem.rdy @= 0
for i in range(latency):
s.next_pipe_valid[i] @= s.pipe_valid[i]
s.next_pipe_msg[i] @= s.pipe_msg[i]
s.stage_can_advance[i] @= 0

if s.recv_opt.val:
if s.recv_opt.msg.fu_in[0] != 0:
s.in0 @= zext(s.recv_opt.msg.fu_in[0] - 1, FuInType)
if s.recv_opt.msg.fu_in[1] != 0:
s.in1 @= zext(s.recv_opt.msg.fu_in[1] - 1, FuInType)

s.send_out[0].val @= s.pipe_valid[latency - 1]
s.send_out[0].msg @= s.pipe_msg[latency - 1]

s.stage_can_advance[latency - 1] @= (~s.pipe_valid[latency - 1]) | s.send_out[0].rdy
for i in range(latency - 2, -1, -1):
s.stage_can_advance[i] @= (~s.pipe_valid[i]) | s.stage_can_advance[i + 1]

for i in range(latency - 1, 0, -1):
if s.stage_can_advance[i]:
s.next_pipe_valid[i] @= s.pipe_valid[i - 1]
s.next_pipe_msg[i] @= s.pipe_msg[i - 1]

if s.stage_can_advance[0]:
s.next_pipe_valid[0] @= 0
s.next_pipe_msg[0] @= s.DataType()

if s.recv_opt.val:
if (s.recv_opt.msg.operation == OPT_DIV) | (s.recv_opt.msg.operation == OPT_REM):
if (s.recv_opt.msg.operation == OPT_DIV) | \
(s.recv_opt.msg.operation == OPT_REM) | \
(s.recv_opt.msg.operation == OPT_DIV_CONST):
s.div.dividend @= s.recv_in[s.in0_idx].msg.payload
s.div.divisor @= s.recv_in[s.in1_idx].msg.payload
if s.recv_opt.msg.operation == OPT_DIV:
s.send_out[0].msg.payload @= s.div.quotient
if s.recv_opt.msg.operation == OPT_DIV_CONST:
s.div.divisor @= s.recv_const.msg.payload
s.launch_msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
s.recv_const.msg.predicate & \
s.reached_vector_factor
s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_const.val
s.recv_const.rdy @= s.recv_all_val & s.stage_can_advance[0]
else:
s.send_out[0].msg.payload @= s.div.remainder
s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
s.recv_in[s.in1_idx].msg.predicate & \
s.reached_vector_factor
s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_in[s.in1_idx].val
s.send_out[0].val @= (latency - 1 == s.cur_cycle)
s.recv_in[s.in0_idx].rdy @= s.send_out[0].val & s.send_out[0].rdy
s.recv_in[s.in1_idx].rdy @= s.send_out[0].val & s.send_out[0].rdy
s.do_div @= 1
s.recv_opt.rdy @= s.send_out[0].val & s.send_out[0].rdy
s.div.divisor @= s.recv_in[s.in1_idx].msg.payload
s.launch_msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
s.recv_in[s.in1_idx].msg.predicate & \
s.reached_vector_factor
s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_in[s.in1_idx].val
s.recv_in[s.in1_idx].rdy @= s.recv_all_val & s.stage_can_advance[0]
if (s.recv_opt.msg.operation == OPT_DIV) | \
(s.recv_opt.msg.operation == OPT_DIV_CONST):
s.launch_msg.payload @= s.div.quotient
else:
s.launch_msg.payload @= s.div.remainder
s.accept_input @= s.recv_all_val & s.stage_can_advance[0]
s.recv_in[s.in0_idx].rdy @= s.accept_input
s.recv_opt.rdy @= s.accept_input
if s.accept_input:
s.next_pipe_valid[0] @= 1
s.next_pipe_msg[0] @= s.launch_msg
else:
for j in range(num_outports):
s.send_out[j].val @= b1(0)
s.recv_opt.rdy @= 0
s.recv_in[s.in0_idx].rdy @= 0
s.recv_in[s.in1_idx].rdy @= 0
s.do_div @= 0
else:
s.do_div @= 0

class Div( VerilogPlaceholder, Component ):
class Div( Component ):

# Constructor
def construct( s, WIDTH = 32, CYCLE = 8 ):
Expand All @@ -121,17 +153,11 @@ def construct( s, WIDTH = 32, CYCLE = 8 ):
s.quotient = OutPort ( WIDTH )
s.remainder = OutPort ( WIDTH )

# Configurations
from os import path
srcdir = path.dirname(__file__) + path.sep

s.set_metadata( VerilogPlaceholderPass.src_file, srcdir + 'division.v' )
s.set_metadata( VerilogPlaceholderPass.top_module, 'pipeline_division' )
s.set_metadata( VerilogPlaceholderPass.v_include, [ srcdir ] )
# s.set_metadata( VerilogPlaceholderPass.v_libs, [
# srcdir + 'division.v',
# ])
s.set_metadata( VerilogPlaceholderPass.has_clk, True )
s.set_metadata( VerilogPlaceholderPass.has_reset, True )

s.set_metadata( VerilogVerilatorImportPass.vl_Wno_list, ['WIDTH'] )
@update
def comb_div():
if s.divisor == 0:
s.quotient @= 0
s.remainder @= 0
else:
s.quotient @= s.dividend // s.divisor
s.remainder @= s.dividend % s.divisor
13 changes: 9 additions & 4 deletions fu/single/test/ExclusiveDivRTL_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ def construct(s, FunctionUnit, IntraCgraPktType, DataType, ConfigType,
data_bitwidth,
num_inports, num_outports, data_mem_size,
src0_msgs, src1_msgs, src_const, ctrl_msgs,
sink_msgs):
sink_msgs, sink_initial_delay = 0):

s.src_in0 = TestSrcRTL(DataType, src0_msgs)
s.src_in1 = TestSrcRTL(DataType, src1_msgs)
s.src_in2 = TestSrcRTL(DataType, src1_msgs)
s.src_opt = TestSrcRTL(ConfigType, ctrl_msgs)
s.sink_out = TestSinkRTL(DataType, sink_msgs)
s.sink_out = TestSinkRTL(DataType, sink_msgs,
initial_delay = sink_initial_delay)

s.const_queue = ConstQueueRTL(DataType, src_const)
s.dut = FunctionUnit(IntraCgraPktType, num_inports, num_outports)
Expand All @@ -56,7 +57,11 @@ def done(s):
def line_trace(s):
return s.dut.line_trace()

def test_mul():
# Concrete example: the second parametrized run stalls the output sink for
# 12 cycles. The completed 8/4 result must stay valid in the divider
# pipeline until the sink becomes ready instead of being overwritten or lost.
@pytest.mark.parametrize('sink_initial_delay', [0, 12])
def test_div_pipeline_backpressure(sink_initial_delay):
FU = ExclusiveDivRTL
data_bitwidth = 32
DataType = mk_data(data_bitwidth, 1)
Expand Down Expand Up @@ -84,5 +89,5 @@ def test_mul():
data_bitwidth,
num_inports, num_outports, data_mem_size,
src_in0, src_in1, src_const, src_opt,
sink_out)
sink_out, sink_initial_delay)
run_sim(th)