Skip to content

Commit

Permalink
Merge branch 'ademeure-less_idle_more_brrr-3'
Browse files: browse the repository at this point in the history
  • Loading branch information
karpathy committed May 5, 2024
2 parents 6c179fa + ce333de commit 8168b78
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 124 deletions.
15 changes: 14 additions & 1 deletion profile_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,25 @@ int main() {
cudaCheck(cudaSetDevice(deviceIdx));
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, deviceIdx);
cuda_num_SMs = deviceProp.multiProcessorCount;
printf("[System]\n");
printf("Device %d: %s\n", deviceIdx, deviceProp.name);

cuda_num_SMs = deviceProp.multiProcessorCount;
cuda_threads_per_SM = deviceProp.maxThreadsPerMultiProcessor;
cuda_arch_major = deviceProp.major;
cuda_arch_minor = deviceProp.minor;

cudaCheck(cudaStreamCreate(&main_stream));
cudaEventCreateWithFlags(&main_event, cudaEventDisableTiming);
cudaEventCreateWithFlags(&loss_event, cudaEventDisableTiming);
for (int i = 0; i < num_parallel_streams; i++) {
cudaCheck(cudaStreamCreate(&parallel_streams[i]));
cudaEventCreateWithFlags(&parallel_events[i], cudaEventDisableTiming);
}

// setup cuBLAS and cuBLASLt
cublasCheck(cublasCreate(&cublas_handle));
cublasCheck(cublasSetStream(cublas_handle, main_stream));
cublasCheck(cublasLtCreate(&cublaslt_handle));
// TF32 precision is equivalent to torch.set_float32_matmul_precision('high')
int enable_tf32 = deviceProp.major >= 8 ? 1 : 0;
Expand Down
15 changes: 13 additions & 2 deletions test_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,25 @@ int main(int argc, char *argv[]) {
cudaCheck(cudaSetDevice(deviceIdx));
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, deviceIdx);
printf("[System]\n");
printf("Device %d: %s\n", deviceIdx, deviceProp.name);

cuda_num_SMs = deviceProp.multiProcessorCount;
cuda_threads_per_SM = deviceProp.maxThreadsPerMultiProcessor;
cuda_arch_major = deviceProp.major;
cuda_arch_minor = deviceProp.minor;
printf("[System]\n");
printf("Device %d: %s\n", deviceIdx, deviceProp.name);

cudaCheck(cudaStreamCreate(&main_stream));
cudaEventCreateWithFlags(&main_event, cudaEventDisableTiming);
cudaEventCreateWithFlags(&loss_event, cudaEventDisableTiming);
for (int i = 0; i < num_parallel_streams; i++) {
cudaCheck(cudaStreamCreate(&parallel_streams[i]));
cudaEventCreateWithFlags(&parallel_events[i], cudaEventDisableTiming);
}

// setup cuBLAS and cuBLASLt
cublasCheck(cublasCreate(&cublas_handle));
cublasCheck(cublasSetStream(cublas_handle, main_stream));
cublasCheck(cublasLtCreate(&cublaslt_handle));
// TF32 precision is equivalent to torch.set_float32_matmul_precision('high')
int enable_tf32 = cuda_arch_major >= 8 ? 1 : 0;
Expand Down
Loading

0 comments on commit 8168b78

Please sign in to comment.