diff --git a/fu/single/ExclusiveDivRTL.py b/fu/single/ExclusiveDivRTL.py index 1aca7b42..ecfcda23 100644 --- a/fu/single/ExclusiveDivRTL.py +++ b/fu/single/ExclusiveDivRTL.py @@ -11,7 +11,6 @@ from pymtl3 import * from ..basic.Fu import Fu from ...lib.opt_type import * -from pymtl3.passes.backends.verilog import * class ExclusiveDivRTL(Fu): @@ -34,32 +33,34 @@ def construct(s, CtrlPktType, num_inports, num_outports, latency = 4, vector_fac idx_nbits = clog2(num_inports) s.in0_idx = Wire(idx_nbits) s.in1_idx = Wire(idx_nbits) - LatencyType = mk_bits(clog2(latency) + 1) - s.cur_cycle = Wire(LatencyType) s.in0_idx //= s.in0[0:idx_nbits] s.in1_idx //= s.in1[0:idx_nbits] s.recv_all_val = Wire(1) - s.do_div = Wire(1) + s.accept_input = Wire(1) + s.launch_msg = Wire(s.DataType) + s.pipe_valid = [Wire(1) for _ in range(latency)] + s.next_pipe_valid = [Wire(1) for _ in range(latency)] + s.pipe_msg = [Wire(s.DataType) for _ in range(latency)] + s.next_pipe_msg = [Wire(s.DataType) for _ in range(latency)] + s.stage_can_advance = [Wire(1) for _ in range(latency)] @update_ff def comb_ff(): - if (s.cur_cycle == latency - 1) & s.send_out[0].rdy: - s.cur_cycle <<= 0 - elif s.cur_cycle == latency - 1: - s.cur_cycle <<= s.cur_cycle - elif s.do_div: - s.cur_cycle <<= s.cur_cycle + 1 - elif (s.recv_all_val & (s.recv_opt.msg.operation == OPT_DIV)): - s.cur_cycle <<= 1 - else: - s.cur_cycle <<= 0 + for i in range(latency): + if s.reset | s.clear: + s.pipe_valid[i] <<= 0 + s.pipe_msg[i] <<= s.DataType() + else: + s.pipe_valid[i] <<= s.next_pipe_valid[i] + s.pipe_msg[i] <<= s.next_pipe_msg[i] @update def comb_logic(): s.recv_all_val @= 0 + s.accept_input @= 0 # For pick input register s.in0 @= 0 s.in1 @= 0 @@ -68,6 +69,7 @@ def comb_logic(): for i in range(num_outports): s.send_out[i].val @= 0 s.send_out[i].msg @= s.DataType() + s.launch_msg @= s.DataType() s.recv_const.rdy @= 0 s.recv_opt.rdy @= 0 @@ -75,6 +77,10 @@ def comb_logic(): s.send_to_ctrl_mem.val @= 0 s.send_to_ctrl_mem.msg @= s.CgraPayloadType(0, 0, 0, 0, 0) s.recv_from_ctrl_mem.rdy @= 0 + for i in range(latency): + s.next_pipe_valid[i] @= s.pipe_valid[i] + s.next_pipe_msg[i] @= s.pipe_msg[i] + s.stage_can_advance[i] @= 0 if s.recv_opt.val: if s.recv_opt.msg.fu_in[0] != 0: @@ -82,34 +88,60 @@ def comb_logic(): if s.recv_opt.msg.fu_in[1] != 0: s.in1 @= zext(s.recv_opt.msg.fu_in[1] - 1, FuInType) + s.send_out[0].val @= s.pipe_valid[latency - 1] + s.send_out[0].msg @= s.pipe_msg[latency - 1] + + s.stage_can_advance[latency - 1] @= (~s.pipe_valid[latency - 1]) | s.send_out[0].rdy + for i in range(latency - 2, -1, -1): + s.stage_can_advance[i] @= (~s.pipe_valid[i]) | s.stage_can_advance[i + 1] + + for i in range(latency - 1, 0, -1): + if s.stage_can_advance[i]: + s.next_pipe_valid[i] @= s.pipe_valid[i - 1] + s.next_pipe_msg[i] @= s.pipe_msg[i - 1] + + if s.stage_can_advance[0]: + s.next_pipe_valid[0] @= 0 + s.next_pipe_msg[0] @= s.DataType() + if s.recv_opt.val: - if (s.recv_opt.msg.operation == OPT_DIV) | (s.recv_opt.msg.operation == OPT_REM): + if (s.recv_opt.msg.operation == OPT_DIV) | \ + (s.recv_opt.msg.operation == OPT_REM) | \ + (s.recv_opt.msg.operation == OPT_DIV_CONST): s.div.dividend @= s.recv_in[s.in0_idx].msg.payload - s.div.divisor @= s.recv_in[s.in1_idx].msg.payload - if s.recv_opt.msg.operation == OPT_DIV: - s.send_out[0].msg.payload @= s.div.quotient + if s.recv_opt.msg.operation == OPT_DIV_CONST: + s.div.divisor @= s.recv_const.msg.payload + s.launch_msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \ + s.recv_const.msg.predicate & \ + s.reached_vector_factor + s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_const.val + s.recv_const.rdy @= s.recv_all_val & s.stage_can_advance[0] else: - s.send_out[0].msg.payload @= s.div.remainder - s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \ - s.recv_in[s.in1_idx].msg.predicate & \ - s.reached_vector_factor - s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_in[s.in1_idx].val - s.send_out[0].val @= (latency - 1 == s.cur_cycle) - s.recv_in[s.in0_idx].rdy @= s.send_out[0].val & s.send_out[0].rdy - s.recv_in[s.in1_idx].rdy @= s.send_out[0].val & s.send_out[0].rdy - s.do_div @= 1 - s.recv_opt.rdy @= s.send_out[0].val & s.send_out[0].rdy + s.div.divisor @= s.recv_in[s.in1_idx].msg.payload + s.launch_msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \ + s.recv_in[s.in1_idx].msg.predicate & \ + s.reached_vector_factor + s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_in[s.in1_idx].val + s.recv_in[s.in1_idx].rdy @= s.recv_all_val & s.stage_can_advance[0] + if (s.recv_opt.msg.operation == OPT_DIV) | \ + (s.recv_opt.msg.operation == OPT_DIV_CONST): + s.launch_msg.payload @= s.div.quotient + else: + s.launch_msg.payload @= s.div.remainder + s.accept_input @= s.recv_all_val & s.stage_can_advance[0] + s.recv_in[s.in0_idx].rdy @= s.accept_input + s.recv_opt.rdy @= s.accept_input + if s.accept_input: + s.next_pipe_valid[0] @= 1 + s.next_pipe_msg[0] @= s.launch_msg else: for j in range(num_outports): s.send_out[j].val @= b1(0) s.recv_opt.rdy @= 0 s.recv_in[s.in0_idx].rdy @= 0 s.recv_in[s.in1_idx].rdy @= 0 - s.do_div @= 0 - else: - s.do_div @= 0 -class Div( VerilogPlaceholder, Component ): +class Div( Component ): # Constructor def construct( s, WIDTH = 32, CYCLE = 8 ): @@ -121,17 +153,11 @@ def construct( s, WIDTH = 32, CYCLE = 8 ): s.quotient = OutPort ( WIDTH ) s.remainder = OutPort ( WIDTH ) - # Configurations - from os import path - srcdir = path.dirname(__file__) + path.sep - - s.set_metadata( VerilogPlaceholderPass.src_file, srcdir + 'division.v' ) - s.set_metadata( VerilogPlaceholderPass.top_module, 'pipeline_division' ) - s.set_metadata( VerilogPlaceholderPass.v_include, [ srcdir ] ) - # s.set_metadata( VerilogPlaceholderPass.v_libs, [ - # srcdir + 'division.v', - # ]) - s.set_metadata( VerilogPlaceholderPass.has_clk, True ) - s.set_metadata( VerilogPlaceholderPass.has_reset, True ) - - s.set_metadata( VerilogVerilatorImportPass.vl_Wno_list, ['WIDTH'] ) + @update + def comb_div(): + if s.divisor == 0: + s.quotient @= 0 + s.remainder @= 0 + else: + s.quotient @= s.dividend // s.divisor + s.remainder @= s.dividend % s.divisor diff --git a/fu/single/test/ExclusiveDivRTL_test.py b/fu/single/test/ExclusiveDivRTL_test.py index eb74766a..5d44f4cd 100644 --- a/fu/single/test/ExclusiveDivRTL_test.py +++ b/fu/single/test/ExclusiveDivRTL_test.py @@ -32,13 +32,14 @@ def construct(s, FunctionUnit, IntraCgraPktType, DataType, ConfigType, data_bitwidth, num_inports, num_outports, data_mem_size, src0_msgs, src1_msgs, src_const, ctrl_msgs, - sink_msgs): + sink_msgs, sink_initial_delay = 0): s.src_in0 = TestSrcRTL(DataType, src0_msgs) s.src_in1 = TestSrcRTL(DataType, src1_msgs) s.src_in2 = TestSrcRTL(DataType, src1_msgs) s.src_opt = TestSrcRTL(ConfigType, ctrl_msgs) - s.sink_out = TestSinkRTL(DataType, sink_msgs) + s.sink_out = TestSinkRTL(DataType, sink_msgs, + initial_delay = sink_initial_delay) s.const_queue = ConstQueueRTL(DataType, src_const) s.dut = FunctionUnit(IntraCgraPktType, num_inports, num_outports) @@ -56,7 +57,11 @@ def done(s): def line_trace(s): return s.dut.line_trace() -def test_mul(): +# Concrete example: the second parametrized run stalls the output sink for +# 12 cycles. The completed 8/4 result must stay valid in the divider +# pipeline until the sink becomes ready instead of being overwritten or lost. +@pytest.mark.parametrize('sink_initial_delay', [0, 12]) +def test_div_pipeline_backpressure(sink_initial_delay): FU = ExclusiveDivRTL data_bitwidth = 32 DataType = mk_data(data_bitwidth, 1) @@ -84,5 +89,5 @@ def test_mul(): data_bitwidth, num_inports, num_outports, data_mem_size, src_in0, src_in1, src_const, src_opt, - sink_out) + sink_out, sink_initial_delay) run_sim(th)