-
Notifications
You must be signed in to change notification settings - Fork 1
/
warpscan.py
132 lines (106 loc) · 4.92 KB
/
warpscan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from numba import cuda
import numpy as np
def make_warp_scan(
    STEPS_PER_THREAD,  # sequential steps in the beginning of the algorithm
    WARPS_PER_BLOCK,
):
    """Build a single-block CUDA inclusive-scan kernel for an affine monoid.

    Pairs ``(f, x)`` are combined with
    ``mappend((fl, xl), (fr, xr)) = (fl * fr, xl * fr + xr)``,
    i.e. composition of the affine maps ``y -> f*y + x``.  Scanning
    ``(gates[i], tokens[i])`` with this monoid evaluates the linear
    recurrence ``h[i] = gates[i] * h[i-1] + tokens[i]`` (with ``h[-1] = 0``),
    and the kernel writes each ``h[i]`` into ``result``.

    The kernel proceeds in three phases:
      1. each thread scans STEPS_PER_THREAD consecutive elements
         sequentially in registers;
      2. threads within a warp are stitched together with shuffle
         instructions (log2(32) = 5 rounds);
      3. warps within the block are stitched through shared memory
         (sequentially for <= 4 warps, otherwise via a warp scan on warp 0).

    Parameters
    ----------
    STEPS_PER_THREAD : int
        Number of input elements each thread processes sequentially.
    WARPS_PER_BLOCK : int
        Number of 32-lane warps per block.

    Returns
    -------
    (scan, block_dim, N) : tuple
        ``scan(gates, tokens, result)`` is the compiled kernel; ``block_dim``
        is the required threads-per-block; ``N = block_dim * STEPS_PER_THREAD``
        is the required length of all three float32 arrays.  The kernel
        indexes with ``threadIdx`` only, so it must be launched with a
        single block.

    Raises
    ------
    ValueError
        If the requested block size exceeds the CUDA limit of 1024 threads.
    """
    MAX_THREADS_PER_BLOCK = 1024  # hardware limit on threads per block
    WARP_SIZE = 32
    LOG_WARP_SIZE = 5  # log2(WARP_SIZE)
    block_dim = WARP_SIZE * WARPS_PER_BLOCK
    # Validate with a real exception: `assert` is stripped under `python -O`.
    if block_dim > MAX_THREADS_PER_BLOCK:
        raise ValueError(
            f"WARPS_PER_BLOCK={WARPS_PER_BLOCK} gives block_dim={block_dim}, "
            f"which exceeds the CUDA limit of {MAX_THREADS_PER_BLOCK} threads per block"
        )
    N = block_dim * STEPS_PER_THREAD

    @cuda.jit(fastmath=False, lineinfo=True)
    def scan(gates, tokens, result):
        """Inclusive scan of (gates, tokens) under the affine monoid (single block)."""
        def mempty():
            # Monoid identity: the affine map y -> 1*y + 0.
            return np.float32(1.), np.float32(0.)
        def mappend(fl, xl, fr, xr):
            # Compose affine maps: apply (fl, xl) first, then (fr, xr).
            return fl * fr, xl * fr + xr
        thread_id = cuda.threadIdx.x
        warp_id = thread_id // WARP_SIZE
        lane_id = thread_id % WARP_SIZE
        #
        # Perform a little bit of sequential computation in thread registers.
        # After this loop thread_acc_*[i] holds the inclusive prefix of this
        # thread's own STEPS_PER_THREAD-element chunk, and acc_* holds the
        # chunk's total (its last prefix).
        #
        thread_acc_f = cuda.local.array(STEPS_PER_THREAD, np.float32)
        thread_acc_x = cuda.local.array(STEPS_PER_THREAD, np.float32)
        acc_f, acc_x = mempty()
        for i in range(0, STEPS_PER_THREAD):
            g = np.float32(gates[thread_id * STEPS_PER_THREAD + i])
            t = np.float32(tokens[thread_id * STEPS_PER_THREAD + i])
            acc_f, acc_x = mappend(acc_f,
                                   acc_x,
                                   g,
                                   t)
            thread_acc_f[i], thread_acc_x[i] = acc_f, acc_x
        cuda.syncthreads()
        #
        # Stitch threads in a warp using shuffling.
        #
        for e in range(0, LOG_WARP_SIZE):
            delta = 1 << e
            # At the same time:
            # Send acc to the thread with id (lane_id + delta)
            # Receive acc of the thread with id (lane_id - delta)
            recv_f = cuda.shfl_up_sync(0xffffffff, acc_f, delta)
            recv_x = cuda.shfl_up_sync(0xffffffff, acc_x, delta)
            if lane_id >= delta:
                # Prepend the received prefix to every element this thread
                # holds; acc_* ends the loop as the updated last prefix,
                # which is what the next shuffle round must send.
                for i in range(0, STEPS_PER_THREAD):
                    acc_f, acc_x = mappend(thread_acc_f[i], thread_acc_x[i], recv_f, recv_x)
                    thread_acc_f[i], thread_acc_x[i] = acc_f, acc_x
            cuda.syncthreads()
        #
        # Stitch warps in a block using shared memory.
        # The last lane of each warp publishes the warp's total, then the
        # totals are scanned so warp_last_*[w] becomes the inclusive prefix
        # over warps 0..w.
        #
        warp_last_f = cuda.shared.array(shape=WARPS_PER_BLOCK, dtype=np.float32)
        warp_last_x = cuda.shared.array(shape=WARPS_PER_BLOCK, dtype=np.float32)
        if lane_id == WARP_SIZE - 1:
            warp_last_f[warp_id] = acc_f
            warp_last_x[warp_id] = acc_x
        cuda.syncthreads()
        if WARPS_PER_BLOCK <= 4:
            # if there are at most 4 warps per block, do a sequential scan over shared memory
            if thread_id == 0:
                warp_acc_f, warp_acc_x = mempty()
                for w in range(0, WARPS_PER_BLOCK):
                    warp_acc_f, warp_acc_x = mappend(warp_acc_f,
                                                     warp_acc_x,
                                                     warp_last_f[w],
                                                     warp_last_x[w])
                    warp_last_f[w] = warp_acc_f
                    warp_last_x[w] = warp_acc_x
        else:
            # otherwise do a warp scan on warp 0
            if warp_id == 0:
                # Lanes beyond WARPS_PER_BLOCK carry the identity so they do
                # not perturb the scan of the real warp totals.
                if lane_id < WARPS_PER_BLOCK:
                    warp_acc_f = warp_last_f[lane_id]
                    warp_acc_x = warp_last_x[lane_id]
                else:
                    warp_acc_f, warp_acc_x = mempty()
                for e in range(0, LOG_WARP_SIZE):
                    delta = 1 << e
                    recv_f = cuda.shfl_up_sync(0xffffffff, warp_acc_f, delta)
                    recv_x = cuda.shfl_up_sync(0xffffffff, warp_acc_x, delta)
                    if lane_id >= delta:
                        warp_acc_f, warp_acc_x = mappend(warp_acc_f, warp_acc_x, recv_f, recv_x)
                if lane_id < WARPS_PER_BLOCK:
                    warp_last_f[lane_id] = warp_acc_f
                    warp_last_x[lane_id] = warp_acc_x
        cuda.syncthreads()
        # Add the last element of the previous warp to each element of the current warp.
        if warp_id > 0:
            warp_from_left_f = warp_last_f[warp_id - 1]
            warp_from_left_x = warp_last_x[warp_id - 1]
        else:
            warp_from_left_f, warp_from_left_x = mempty()
        for i in range(0, STEPS_PER_THREAD):
            thread_acc_f[i], thread_acc_x[i] = mappend(thread_acc_f[i],
                                                       thread_acc_x[i],
                                                       warp_from_left_f,
                                                       warp_from_left_x)
            result[thread_id * STEPS_PER_THREAD + i] = thread_acc_x[i]
    return scan, block_dim, N