64bit indexing fused adam #5187

Merged (11 commits) on Apr 22, 2024
77 changes: 56 additions & 21 deletions csrc/adam/multi_tensor_adam.cu
@@ -30,7 +30,7 @@ typedef enum : int {

using MATH_T = float;

- template <typename T>
+ template <typename T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
Contributor:
FYI this chunk_size is still an int and not int64/index_t

Contributor:
@garrett4wade - thoughts on this?
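For reference, a standalone sketch (illustrative values, not DeepSpeed code) of the offset arithmetic this thread touches: with chunk_idx widened to index_t, its product with a 32-bit chunk_size is computed in the wider type under the usual arithmetic conversions.

// Standalone sketch, not DeepSpeed code: illustrative values only.
#include <cstdint>
#include <cstdio>
#include <type_traits>

int main() {
    int chunk_size = 65536;    // stays a 32-bit int, as in the kernel signature
    int64_t chunk_idx = 40000; // index_t = int64_t on the 64-bit indexing path

    // The int operand is promoted, so the offset is computed in 64 bits and
    // can exceed INT_MAX without overflowing.
    auto offset = chunk_idx * chunk_size;
    static_assert(std::is_same<decltype(offset), int64_t>::value, "64-bit offset math");

    std::printf("offset = %lld\n", static_cast<long long>(offset)); // 2621440000
    return 0;
}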

volatile int* noop_gmem,
@@ -48,13 +48,13 @@ struct AdamFunctor {
// if(*noop_gmem == 1)
// return;

- int tensor_loc = tl.block_to_tensor[blockIdx.x];
+ index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;

- int chunk_idx = tl.block_to_chunk[blockIdx.x];
- int n = tl.sizes[tensor_loc];
+ index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
+ index_t n = tl.sizes[tensor_loc];

T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
@@ -71,7 +71,8 @@ struct AdamFunctor {
n -= chunk_idx * chunk_size;

// see note in multi_tensor_scale_kernel.cu
- for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+ for (index_t i_start = 0; i_start < n && i_start < chunk_size;
+ i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
@@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size,
bias_correction2 = 1 - std::pow(beta2, step);
}

+ size_t max_size = 0;
+ bool requires_64bit_indexing = false;
+ for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
+ for (auto it2 = it->begin(); it2 != it->end(); it2++) {
+ if (it2->numel() > max_size) {
+ max_size = it2->numel();
+ if (max_size >= INT_MAX) {
+ requires_64bit_indexing = true;
+ break;
+ }
+ }
+ }
+ if (requires_64bit_indexing) { break; }
+ }

// Assume single type across p,g,m1,m2 now
- DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
- 0,
- "adam",
- multi_tensor_apply<4>(BLOCK_SIZE,
- chunk_size,
- noop_flag,
- tensor_lists,
- AdamFunctor<scalar_t_0>(),
- beta1,
- beta2,
- bias_correction1,
- bias_correction2,
- epsilon,
- lr,
- (adamMode_t)mode,
- weight_decay);)
+ if (requires_64bit_indexing) {
+ DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
+ 0,
+ "adam",
+ multi_tensor_apply<4>((int64_t)BLOCK_SIZE,
+ (int64_t)chunk_size,
+ noop_flag,
+ tensor_lists,
+ AdamFunctor<scalar_t_0, int64_t>(),
+ beta1,
+ beta2,
+ bias_correction1,
+ bias_correction2,
+ epsilon,
+ lr,
+ (adamMode_t)mode,
+ weight_decay);)
+ } else {
+ DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
+ 0,
+ "adam",
+ multi_tensor_apply<4>(BLOCK_SIZE,
+ chunk_size,
+ noop_flag,
+ tensor_lists,
+ AdamFunctor<scalar_t_0, int32_t>(),
+ beta1,
+ beta2,
+ bias_correction1,
+ bias_correction2,
+ epsilon,
+ lr,
+ (adamMode_t)mode,
+ weight_decay);)
+ }

AT_CUDA_CHECK(cudaGetLastError());
}
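A rough illustration of the dispatch threshold above (a standalone sketch with assumed values, not DeepSpeed code): once any tensor in the list reaches INT_MAX elements, per-tensor sizes and chunk offsets stop fitting safely in 32-bit signed ints, so the int64_t functor is selected.

// Standalone sketch, not DeepSpeed code: assumed values for illustration.
#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t chunk_size = 2048 * 32;            // assumption: a typical chunk size
    const int64_t numel = 3LL * 1000 * 1000 * 1000;  // a 3-billion-element tensor

    const int64_t n_chunks = (numel + chunk_size - 1) / chunk_size;
    const int64_t last_offset = (n_chunks - 1) * chunk_size;  // element offset of the final chunk

    // Mirrors the spirit of the requires_64bit_indexing check: for tensors this
    // large, chunk offsets overflow a 32-bit signed int, so the int64_t path is used.
    std::printf("last chunk offset = %lld, fits in int32: %s\n",
                static_cast<long long>(last_offset),
                last_offset <= INT_MAX ? "yes" : "no");
    return 0;
}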
10 changes: 5 additions & 5 deletions csrc/adam/multi_tensor_apply.cuh
@@ -35,7 +35,7 @@ struct TensorListMetadata {
};

template <typename T, typename U, typename... ArgTypes>
- __global__ void multi_tensor_apply_kernel(int chunk_size,
+ __global__ void multi_tensor_apply_kernel(int64_t chunk_size,
volatile int* noop_flag,
T tl,
U callable,
@@ -46,8 +46,8 @@ __global__ void multi_tensor_apply_kernel(int chunk_size,
}

template <int depth, typename T, typename... ArgTypes>
- void multi_tensor_apply(int block_size,
- int chunk_size,
+ void multi_tensor_apply(int64_t block_size,
+ int64_t chunk_size,
const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
@@ -91,9 +91,9 @@ void multi_tensor_apply(int block_size,
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;

- int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
+ auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

- for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
+ for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
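For completeness, a standalone type-deduction sketch (not DeepSpeed code) for the two auto declarations introduced above: the chunk count deduces to int64_t because at::Tensor::numel() returns int64_t and chunk_size is now int64_t, while auto chunk = 0 deduces plain int.

// Standalone sketch, not DeepSpeed code: shows what `auto` deduces above.
#include <cstdint>
#include <type_traits>

int main() {
    const int64_t numel = 5LL * 1000 * 1000 * 1000;  // at::Tensor::numel() returns int64_t
    const int64_t chunk_size = 65536;                // now an int64_t parameter

    auto chunks_this_tensor = (numel + chunk_size - 1) / chunk_size;
    static_assert(std::is_same<decltype(chunks_this_tensor), int64_t>::value,
                  "chunk count is computed and stored in 64 bits");

    auto chunk = 0;  // the literal 0 is int, so the loop counter deduces to int
    static_assert(std::is_same<decltype(chunk), int>::value, "loop counter is int");

    (void)chunks_this_tensor;
    (void)chunk;
    return 0;
}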