Skip to content

Commit

Permalink
Fix bugs in PyTorch codegen. (halide#7443)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yongqi-Zhuo authored and ardier committed Mar 3, 2024
1 parent 4d97824 commit 902f764
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/CodeGen_PyTorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ void CodeGen_PyTorch::compile(const Module &module) {
"Please add \"-user_context\" to the generator's target options.\n";
}
stream << "#include \"ATen/cuda/CUDAContext.h\"\n";
stream << "#include \"HalidePyTorchCudaHelpers.h\"\n";
}
stream << "#include \"HalideBuffer.h\"\n";
stream << "#include \"HalidePyTorchHelpers.h\"\n";
Expand All @@ -43,6 +44,11 @@ void CodeGen_PyTorch::compile(const Module &module) {
}

for (const auto &f : module.functions()) {
// Don't put non-external function declarations in headers.
// We need to be consistent with CodeGen_C::compile.
if (f.linkage == LinkageType::Internal) {
continue;
}
if (target.has_feature(Target::CUDA)) {
compile(f, true);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/HalidePyTorchHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "HalideBuffer.h"

// Forward declare the cuda_device_interface, for tensor wrapper.
const halide_device_interface_t *halide_cuda_device_interface();
extern "C" const halide_device_interface_t *halide_cuda_device_interface();

#define HLPT_CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define HLPT_CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
Expand Down

0 comments on commit 902f764

Please sign in to comment.