diff --git a/csrc/adam/multi_tensor_adam.cu b/csrc/adam/multi_tensor_adam.cu
index d6b9b2f70710..a1fc7d15aec9 100644
--- a/csrc/adam/multi_tensor_adam.cu
+++ b/csrc/adam/multi_tensor_adam.cu
@@ -30,7 +30,7 @@ typedef enum : int {
 
 using MATH_T = float;
 
-template <typename T>
+template <typename T, typename index_t>
 struct AdamFunctor {
     __device__ __forceinline__ void operator()(int chunk_size,
                                                volatile int* noop_gmem,
@@ -48,13 +48,13 @@ struct AdamFunctor {
         // if(*noop_gmem == 1)
         //   return;
 
-        int tensor_loc = tl.block_to_tensor[blockIdx.x];
+        index_t tensor_loc = tl.block_to_tensor[blockIdx.x];
 
         // potentially use to pass in list of scalar
         // int tensor_num = tl.start_tensor_this_launch + tensor_loc;
 
-        int chunk_idx = tl.block_to_chunk[blockIdx.x];
-        int n = tl.sizes[tensor_loc];
+        index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
+        index_t n = tl.sizes[tensor_loc];
 
         T* g = (T*)tl.addresses[0][tensor_loc];
         g += chunk_idx * chunk_size;
@@ -71,7 +71,8 @@ struct AdamFunctor {
         n -= chunk_idx * chunk_size;
 
         // see note in multi_tensor_scale_kernel.cu
-        for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+        for (index_t i_start = 0; i_start < n && i_start < chunk_size;
+             i_start += blockDim.x * ILP) {
             MATH_T r_g[ILP];
             MATH_T r_p[ILP];
             MATH_T r_m[ILP];
@@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size,
         bias_correction2 = 1 - std::pow(beta2, step);
     }
 
+    size_t max_size = 0;
+    bool requires_64bit_indexing = false;
+    for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
+        for (auto it2 = it->begin(); it2 != it->end(); it2++) {
+            if (it2->numel() > max_size) {
+                max_size = it2->numel();
+                if (max_size >= INT_MAX) {
+                    requires_64bit_indexing = true;
+                    break;
+                }
+            }
+        }
+        if (requires_64bit_indexing) { break; }
+    }
+
     // Assume single type across p,g,m1,m2 now
-    DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
-                                   0,
-                                   "adam",
-                                   multi_tensor_apply<4>(BLOCK_SIZE,
-                                                         chunk_size,
-                                                         noop_flag,
-                                                         tensor_lists,
-                                                         AdamFunctor<scalar_t_0>(),
-                                                         beta1,
-                                                         beta2,
-                                                         bias_correction1,
-                                                         bias_correction2,
-                                                         epsilon,
-                                                         lr,
-                                                         (adamMode_t)mode,
-                                                         weight_decay);)
+    if (requires_64bit_indexing) {
+        DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
+                                       0,
+                                       "adam",
+                                       multi_tensor_apply<4>((int64_t)BLOCK_SIZE,
+                                                             (int64_t)chunk_size,
+                                                             noop_flag,
+                                                             tensor_lists,
+                                                             AdamFunctor<scalar_t_0, int64_t>(),
+                                                             beta1,
+                                                             beta2,
+                                                             bias_correction1,
+                                                             bias_correction2,
+                                                             epsilon,
+                                                             lr,
+                                                             (adamMode_t)mode,
+                                                             weight_decay);)
+    } else {
+        DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
+                                       0,
+                                       "adam",
+                                       multi_tensor_apply<4>(BLOCK_SIZE,
+                                                             chunk_size,
+                                                             noop_flag,
+                                                             tensor_lists,
+                                                             AdamFunctor<scalar_t_0, int32_t>(),
+                                                             beta1,
+                                                             beta2,
+                                                             bias_correction1,
+                                                             bias_correction2,
+                                                             epsilon,
+                                                             lr,
+                                                             (adamMode_t)mode,
+                                                             weight_decay);)
+    }
 
     AT_CUDA_CHECK(cudaGetLastError());
 }
diff --git a/csrc/adam/multi_tensor_apply.cuh b/csrc/adam/multi_tensor_apply.cuh
index 12f41cb49c6b..342376c141be 100644
--- a/csrc/adam/multi_tensor_apply.cuh
+++ b/csrc/adam/multi_tensor_apply.cuh
@@ -35,7 +35,7 @@ struct TensorListMetadata {
 };
 
 template <typename T, typename U, typename... ArgTypes>
-__global__ void multi_tensor_apply_kernel(int chunk_size,
+__global__ void multi_tensor_apply_kernel(int64_t chunk_size,
                                           volatile int* noop_flag,
                                           T tl,
                                           U callable,
@@ -46,8 +46,8 @@ __global__ void multi_tensor_apply_kernel(int chunk_size,
 }
 
 template <int depth, typename T, typename... ArgTypes>
-void multi_tensor_apply(int block_size,
-                        int chunk_size,
+void multi_tensor_apply(int64_t block_size,
+                        int64_t chunk_size,
                         const at::Tensor& noop_flag,
                         const std::vector<std::vector<at::Tensor>>& tensor_lists,
                         T callable,
@@ -91,9 +91,9 @@ void multi_tensor_apply(int block_size,
             tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
         loc_tensor_info++;
 
-        int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
+        auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
 
-        for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
+        for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
             // std::cout << chunks_this_tensor << std::endl;
             tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
             tl.block_to_chunk[loc_block_info] = chunk;
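Note on the dispatch above: the new host-side scan only switches to 64-bit indexing when some tensor in some list has numel() >= INT_MAX; otherwise the existing 32-bit path is kept unchanged. Below is a minimal standalone sketch of that decision, using plain element counts in place of at::Tensor; the helper name needs_64bit_indexing and the sample sizes are illustrative, not part of the patch.

// Sketch only: mirrors the requires_64bit_indexing scan in the patch,
// but over raw element counts instead of at::Tensor objects.
#include <climits>
#include <cstdint>
#include <iostream>
#include <vector>

// Returns true if any tensor (here: element count) would overflow 32-bit indexing.
static bool needs_64bit_indexing(const std::vector<std::vector<int64_t>>& numels)
{
    size_t max_size = 0;
    for (const auto& list : numels) {
        for (int64_t n : list) {
            if (static_cast<size_t>(n) > max_size) {
                max_size = static_cast<size_t>(n);
                if (max_size >= INT_MAX) return true;
            }
        }
    }
    return false;
}

int main()
{
    // Two hypothetical parameter groups; the second holds a tensor larger than INT_MAX elements.
    std::vector<std::vector<int64_t>> small = {{1 << 20, 1 << 22}};
    std::vector<std::vector<int64_t>> large = {{1 << 20}, {int64_t(3) << 30}};
    std::cout << needs_64bit_indexing(small) << "\n";  // 0 -> keep the int32_t dispatch path
    std::cout << needs_64bit_indexing(large) << "\n";  // 1 -> take the int64_t dispatch path
    return 0;
}

The >= INT_MAX check is deliberately coarse: a single oversized tensor moves every chunk launched by that multi_tensor_adam_cuda call onto 64-bit indexing, since one dispatch covers all tensor lists.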