Skip to content

Commit

Permalink
[feature] Add single attention tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gamzeisl committed Oct 24, 2024
1 parent 74eb7a6 commit 67465f1
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 30 deletions.
42 changes: 40 additions & 2 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,33 +46,52 @@ run_sim:
F: 64
activation: gelu
no_stalls: 0
single_attention: 0
- S: 64
E: 64
P: 64
F: 64
activation: gelu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 1
single_attention: 0
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 1
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 0
single_attention: 1
script:
- make bender
- make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls
- make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention
- ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F ita_tb

run_hwpe_sim:
Expand All @@ -87,31 +106,50 @@ run_hwpe_sim:
F: 64
activation: gelu
no_stalls: 0
single_attention: 0
- S: 64
E: 64
P: 64
F: 64
activation: gelu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 1
single_attention: 0
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 1
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 0
single_attention: 1
script:
- make bender
- make sim VSIM_FLAGS=-c DEBUG=OFF target=sim_ita_hwpe_tb s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls
- make sim VSIM_FLAGS=-c DEBUG=OFF target=sim_ita_hwpe_tb s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention
- ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F hwpe_tb
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ BENDER_TARGETS = -t rtl -t test
target ?= sim_ita_tb

no_stalls ?= 0
single_attention ?= 0
s ?= 64
e ?= 128
p ?= 192
Expand All @@ -33,7 +34,7 @@ else ifeq ($(activation), relu)
else
activation_int = 0
endif
vlog_defs += -DNO_STALLS=$(no_stalls) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int)
vlog_defs += -DNO_STALLS=$(no_stalls) -DSINGLE_ATTENTION=$(single_attention) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int)

ifeq ($(target), sim_ita_hwpe_tb)
BENDER_TARGETS += -t ita_hwpe -t ita_hwpe_test
Expand Down
77 changes: 72 additions & 5 deletions src/hwpe/tb/ita_hwpe_tb.sv
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ module ita_hwpe_tb;
parameter integer EMBEDDING_SIZE = `ifdef EMBED_SIZE `EMBED_SIZE `else M_TILE_LEN `endif;
parameter integer FEEDFORWARD_SIZE = `ifdef FF_SIZE `FF_SIZE `else M_TILE_LEN `endif;
parameter activation_e ACTIVATION = `ifdef ACTIVATION `ACTIVATION `else Identity `endif;
parameter integer SINGLE_ATTENTION = `ifdef SINGLE_ATTENTION `SINGLE_ATTENTION `else 0 `endif;

integer N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, N_TILES_FEEDFORWARD_DIM;
integer N_ELEMENTS_PER_TILE;
Expand Down Expand Up @@ -358,11 +359,27 @@ endfunction
ita_compute_step(Q, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 2: Step K
if (SINGLE_ATTENTION == 1) begin
// Shift ita_reg_rqs_val[0], [2] and [4] right by 8 bits before Step K, because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[0] >> 8;
ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8;
end
ita_compute_step(K, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 3: Step V
if (SINGLE_ATTENTION == 1) begin
// Shift ita_reg_rqs_val[0], [2] and [4] right by a further 8 bits before Step V, because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[0] >> 8;
ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8;
end
ita_compute_step(V, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

if (SINGLE_ATTENTION == 1) begin
// Reset the RQS values
ita_reg_eps_mult_val_compute(ita_reg_rqs_val);
end

for (int group = 0; group < N_TILES_SEQUENCE_DIM; group++) begin
BASE_PTR_INPUT[QK] = BASE_PTR[15] + group * N_TILES_INNER_DIM[QK] * N_ELEMENTS_PER_TILE;
Expand All @@ -384,12 +401,36 @@ endfunction
end

// 6: Step OW
if (SINGLE_ATTENTION == 1) begin
// Recompute the tile register with the dimensions passed in the order sequence, projection, embedding, feedforward
ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_PROJECTION_DIM, N_TILES_EMBEDDING_DIM, N_TILES_FEEDFORWARD_DIM, ita_reg_tiles_val);
// Load ita_reg_rqs_val[0], [2] and [4] from entries [1], [3] and [5] shifted right by 8 bits, because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 8;
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 8;
end
ita_compute_step(OW, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 7: Step FF1
if (SINGLE_ATTENTION == 1) begin
// Recompute the tile register with the dimensions passed in the order sequence, embedding, feedforward, projection
ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_FEEDFORWARD_DIM, N_TILES_PROJECTION_DIM, ita_reg_tiles_val);
// Load ita_reg_rqs_val[0], [2] and [4] from entries [1], [3] and [5] shifted right by 16 bits, because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 16;
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 16;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 16;
end
ita_compute_step(F1, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 8: Step FF2
if (SINGLE_ATTENTION == 1) begin
// Recompute the tile register with the dimensions passed in the order sequence, feedforward, embedding, projection
ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_FEEDFORWARD_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, ita_reg_tiles_val);
// Load ita_reg_rqs_val[0], [2] and [4] from entries [1], [3] and [5] shifted right by 24 bits, because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 24;
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 24;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 24;
end
ita_compute_step(F2, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// Wait for the last step to finish
Expand Down Expand Up @@ -454,8 +495,14 @@ endfunction
// Calculate input_ptr, weight_ptr0, weight_ptr1, bias_ptr, and output_ptr
ita_ptrs_compute(input_base_ptr, weight_base_ptr0, weight_base_ptr1, bias_base_ptr, output_base_ptr, step, tile, tile_x, tile_y, tile_inner, input_ptr, weight_ptr0, weight_ptr1, bias_ptr, output_ptr);

// Calculate ita_reg_en
ita_reg_en_compute(step, tile, ita_reg_en);
if (SINGLE_ATTENTION == 1) begin
// Enable ita_reg_en
ita_reg_en = 1'b1;
end else begin
// Calculate ita_reg_en
ita_reg_en_compute(step, tile, ita_reg_en);
end

// Calculate ctrl_stream_val, weight_ptr_en, and bias_ptr_en
ctrl_val_compute(step, tile, ctrl_engine_val, ctrl_stream_val, weight_ptr_en, bias_ptr_en);

Expand Down Expand Up @@ -566,7 +613,13 @@ endfunction
ctrl_stream_val = 32'h0;
reg_weight_en = 1'b0;
reg_bias_en = 1'b0;
layer_type = Attention;

if (SINGLE_ATTENTION == 1) begin
layer_type = Linear;
end else begin
layer_type = Attention;
end

activation_function = Identity;

ctrl_engine_val = layer_type | activation_function << 2;
Expand Down Expand Up @@ -600,11 +653,17 @@ endfunction
reg_bias_en = 1'b1;
end
QK : begin
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = SingleAttention | Identity << 2;
end
ctrl_stream_val = {28'b0, 4'b0110}; // weight nextload and disable bias
reg_weight_en = 1'b1;
reg_bias_en = 1'b0;
end
AV : begin
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = SingleAttention | Identity << 2;
end
ctrl_stream_val = {28'b0, 4'b0110}; // weight nextload and disable bias
reg_weight_en = 1'b1;
reg_bias_en = 1'b0;
Expand All @@ -620,7 +679,11 @@ endfunction
reg_bias_en = 1'b1;
end
F1 : begin
ctrl_engine_val = Feedforward | ACTIVATION << 2;
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = Linear | ACTIVATION << 2;
end else begin
ctrl_engine_val = Feedforward | ACTIVATION << 2;
end
if (tile == 0) begin
ctrl_stream_val = {28'b0, 4'b0011}; // weight preload and weight nextload
end else begin
Expand All @@ -630,7 +693,11 @@ endfunction
reg_bias_en = 1'b1;
end
F2 : begin
ctrl_engine_val = Feedforward | Identity << 2;
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = Linear | Identity << 2;
end else begin
ctrl_engine_val = Feedforward | Identity << 2;
end
if (tile == (N_TILES_OUTER_X[F2]*N_TILES_OUTER_Y[F2]*N_TILES_INNER_DIM[F2])-1) begin
ctrl_stream_val = {28'b0, 4'b0000};
reg_weight_en = 1'b0;
Expand Down
Loading

0 comments on commit 67465f1

Please sign in to comment.