Skip to content

Commit

Permalink
New Combinational blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
kir486680 committed Nov 26, 2023
1 parent 640334d commit 91acb68
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 83 deletions.
41 changes: 15 additions & 26 deletions src/block_add.v
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ module block_add(
// Internal signal to keep track of addition completion for each element
reg [`J*`K-1:0] addition_complete;


// Combinational logic for block_add_done
always @(*) begin
if (rst || start) begin
// Reset the block_add_done when reset or a new start is signaled
if (rst) begin
// Reset block_add_done only on reset
block_add_done = 1'b0;
end else if (&addition_complete) begin
// If all additions are complete, set block_add_done high
// Set block_add_done high when all additions are complete
block_add_done = 1'b1;
end else begin
// Keep block_add_done low otherwise
block_add_done = 1'b0;
end
end
Expand All @@ -38,36 +38,25 @@ generate
.b_in(multiplied_block[i * `K + j]),
.result(fadd_result)
);
// Use non-blocking assignment inside always block to assign the result to buffer_result
always @(posedge clk or posedge rst) begin

// Combinational logic for assigning results to buffer_result
always @(*) begin
if (rst) begin
buffer_result[((start_row + i) * `B_N) + start_col + j] <= 0;
addition_complete[i * `K + j] <= 0;
buffer_result[((start_row + i) * `B_N) + start_col + j] = 0;
addition_complete[i * `K + j] = 0;
end else if(start) begin
// Check if the result indices are within the matrix dimensions
if ((start_row + i) < `A_M && (start_col + j) < `B_N) begin
buffer_result[((start_row + i) * `B_N) + start_col + j] <= fadd_result;
addition_complete[i * `K + j] <= 1'b1;
buffer_result[((start_row + i) * `B_N) + start_col + j] = fadd_result;
addition_complete[i * `K + j] = 1'b1;
end else begin
addition_complete[i * `K + j] = 1'b0;
end
end else begin
// No action if not reset or start
end
end
end
end
endgenerate

// Sequential logic for handling addition_complete
always @(posedge clk or posedge rst) begin
if (rst) begin
addition_complete <= `J*`K'd0; // Reset the addition_complete flags
end else if (start) begin
// When a new addition starts, reset the addition_complete flags
addition_complete <= `J*`K'd0;
end else begin
// Logic to update addition_complete flags based on the completion of each fadd operation
// ...
end
end


endmodule
31 changes: 11 additions & 20 deletions src/block_get.v
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,34 @@ module block_get(
parameter J = `J; // Define the block rows
parameter K = `K; // Define the block columns

// Signal to count the number of elements read
reg [`J*`K-1:0] get_complete;

integer i, j;

// Combinational logic for block_get_done
// Combinational logic for block_get_done and block assignment
always @(*) begin
if (rst) begin
block_get_done = 1'b0;
end else if (&get_complete) begin
block_get_done = 1'b1;
end else begin
block_get_done = 1'b0;
end
end

// Sequential logic for reading the block and updating get_complete
always @(posedge clk or posedge rst) begin
if (rst) begin
for (i = 0; i < J*K; i = i + 1) begin
block[i] <= 0;
get_complete[i] <= 1'b0;
block[i] = 0;
get_complete[i] = 1'b0;
end
end else if(start) begin
end else if (start) begin
block_get_done = 1'b0;
for (i = 0; i < J; i = i + 1) begin
for (j = 0; j < K; j = j + 1) begin
if(start_row + i < (matrix_len / num_cols) && start_col + j < num_cols) begin
block[i*K + j] <= buffer[(start_row + i)*num_cols + (start_col + j)];
get_complete[i*K + j] <= 1'b1; // Mark element as read
if (start_row + i < (matrix_len / num_cols) && start_col + j < num_cols) begin
block[i*K + j] = buffer[(start_row + i)*num_cols + (start_col + j)];
get_complete[i*K + j] = 1'b1;
end else begin
get_complete[i*K + j] <= 1'b1; // We still need to count these
get_complete[i*K + j] = 1'b1;
end
end
end
block_get_done = &get_complete;
end else begin
for (i = 0; i < J*K; i = i + 1) begin
get_complete[i] <= 1'b0; // Reset the get_complete flags
get_complete[i] = 1'b0;
end
end
end
Expand Down
38 changes: 16 additions & 22 deletions src/matrix_mul.v
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@ module matrix_mul(

// Define states
localparam IDLE = 3'd0,
GET_BLOCK_A = 3'd1,
GET_BLOCK_B = 3'd2,
MULTIPLY_BLOCK = 3'd3,
ACCUMULATE = 3'd4,
WRITE_BACK = 3'd5,
DONE = 3'd6;
GET_BLOCKS = 3'd1,
MULTIPLY_BLOCK = 3'd2,
ACCUMULATE = 3'd3,
WRITE_BACK = 3'd4,
DONE = 3'd5;

reg [2:0] state, next_state;
reg [9:0] i, l, r, result_row, result_col; // this is for the loop counters
Expand Down Expand Up @@ -46,7 +45,7 @@ block_get block_get_A(
.block_get_done(block_get_a_done)
);

localparam B_cols = 10'd5;
localparam B_cols = 10'd2;
localparam B_size = 10'd10;

block_get block_get_B(
Expand All @@ -67,8 +66,8 @@ block_add block_add(
.clk(clk),
.rst(rst),
.start(add_block),
.start_row(result_row),
.start_col(result_col),
.start_row(i),
.start_col(l),
.multiplied_block(block_b),
.buffer_temp(block_a),
.buffer_result(matrix_C),
Expand Down Expand Up @@ -111,21 +110,16 @@ always @(posedge clk or posedge rst) begin
get_block_b <= 0;
add_block <= 0;
operation_complete <= 0;
next_state = GET_BLOCK_A;
next_state = GET_BLOCKS;
end
end
GET_BLOCK_A: begin
GET_BLOCKS: begin
get_block_a <= 1;
if (block_get_a_done) begin
get_block_a <= 0;
next_state = GET_BLOCK_B;
end
end
GET_BLOCK_B: begin
get_block_b <= 1;
if (block_get_b_done) begin
if (block_get_a_done && block_get_b_done) begin
get_block_a <= 0;
get_block_b <= 0;
next_state = IDLE;
next_state = ACCUMULATE;
end
end
MULTIPLY_BLOCK: begin
Expand All @@ -151,13 +145,13 @@ always @(posedge clk or posedge rst) begin
if (i >= `A_M) begin
next_state = DONE;
end else begin
next_state = GET_BLOCK_A;
next_state = GET_BLOCKS;
end
end else begin
next_state = GET_BLOCK_A;
next_state = GET_BLOCKS;
end
end else begin
next_state = GET_BLOCK_A;
next_state = GET_BLOCKS;
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions tests/test_block_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ async def test_get_block(dut):

dut.rst.value = 0
await RisingEdge(dut.clk)
print_matrix(dut.buffer, 2, 5, "Initial Buffer")
# # Initialize the buffer with sequential values

# Initialize the buffer with sequential values
for i in range(buffer_size):
dut.buffer[i].value = BinaryValue(value=float_to_float16(i), n_bits=16)
print("Initialized data")
Expand Down
22 changes: 9 additions & 13 deletions tests/test_matrix_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,20 @@ async def test_get_block(dut):
print("Current state", dut.state.value)
print_matrix(dut.block_get_A.buffer, 2, 5, "Input to block_get_A")
print_matrix(dut.block_get_A.block, J, K, "Output of block_get_A")
print_matrix(dut.block_a, J, K, "Block A")
await RisingEdge(dut.clk)
print("Current state", dut.state.value)
await RisingEdge(dut.clk)
print("Current state", dut.state.value)
await RisingEdge(dut.clk)
print("Current state", dut.state.value)
print_matrix(dut.block_a, J, K, "Block A") #it was not able to update yet here
print_matrix(dut.block_b, J, K, "Block B") #it was not able to update yet here
await RisingEdge(dut.clk)
print_matrix(dut.block_a, J, K, "Block A")
print_matrix(dut.block_b, J, K, "Block B")
print("Current state", dut.state.value)
await RisingEdge(dut.clk)
print("Current state", dut.state.value)
await RisingEdge(dut.clk)
print("Current state", dut.state.value)
print("Start of add", dut.block_add.start.value)
print_matrix(dut.block_add.buffer_temp, J, K, "Input to block_add")
print_matrix(dut.block_add.multiplied_block, J, K, "Input to block_add")
print(dut.block_add_done.value)
print_matrix(dut.block_add.buffer_result, 2,2, "Matrix C")
await RisingEdge(dut.clk)
print("Current state", dut.state.value)
print_matrix(dut.block_get_A.block, J, K, "Output of block_get_A")
print_matrix(dut.block_b, J, K, "Block B")
await RisingEdge(dut.clk)
print("Current state", dut.state.value)
await RisingEdge(dut.clk)
print("Current state", dut.state.value)

0 comments on commit 91acb68

Please sign in to comment.