Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for lpython decorator #1796

Merged
merged 5 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ inst/bin/*
*_lines.dat.txt
*__tmp__generated__.c
visualize*.html
lpython_decorator*/
a.c
a.h
a.py
Expand Down
6 changes: 3 additions & 3 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ RUN(NAME callback_01 LABELS cpython llvm)
# Intrinsic Functions
RUN(NAME intrinsics_01 LABELS cpython llvm) # any

COMPILE(NAME import_order_01 LABELS cpython llvm c) # any
# lpython decorator
RUN(NAME lpython_decorator_01 LABELS cpython)

# Jit
RUN(NAME test_lpython_decorator LABELS cpython)
COMPILE(NAME import_order_01 LABELS cpython llvm c) # any
4 changes: 2 additions & 2 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3589,8 +3589,8 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
is_inline = true;
} else if (name == "static") {
is_static = true;
} else if (name == "jit") {
throw SemanticError("`@lpython.jit` decorator must be "
} else if (name == "lpython") {
throw SemanticError("`@lpython` decorator must be "
"run from CPython, not compiled using LPython",
dec->base.loc);
} else {
Expand Down
36 changes: 22 additions & 14 deletions src/runtime/lpython/lpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import platform
from dataclasses import dataclass as py_dataclass, is_dataclass as py_is_dataclass
from goto import with_goto
from numpy import get_include
from distutils.sysconfig import get_python_inc

# TODO: this does not seem to restrict other imports
__slots__ = ["i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "c32", "c64", "CPtr",
Expand Down Expand Up @@ -572,11 +570,13 @@ def get_data_type(t):
source_code = getsource(function)
source_code = source_code[source_code.find('\n'):]

# TODO: Create a filename based on the function name
# filename = function.__name__ + ".py"
dir_name = "./lpython_decorator_" + self.fn_name
if not os.path.exists(dir_name):
os.mkdir(dir_name)
filename = dir_name + "/" + self.fn_name

# Open the file for writing
with open("a.py", "w") as file:
with open(filename + ".py", "w") as file:
# Write the Python source code to the file
file.write("@ccallable")
file.write(source_code)
Expand Down Expand Up @@ -682,7 +682,7 @@ def get_data_type(t):
#include <numpy/ndarrayobject.h>

// LPython generated C code
#include "a.h"
#include "{self.fn_name}.h"

// Define the Python module and method mappings
static PyObject* define_module(PyObject* self, PyObject* args) {{
Expand All @@ -700,13 +700,13 @@ def get_data_type(t):
// Define the module initialization function
static struct PyModuleDef module_def = {{
PyModuleDef_HEAD_INIT,
"lpython_jit_module",
"lpython_module_{self.fn_name}",
"Shared library to use LPython generated functions",
-1,
module_methods
}};

PyMODINIT_FUNC PyInit_lpython_jit_module(void) {{
PyMODINIT_FUNC PyInit_lpython_module_{self.fn_name}(void) {{
PyObject* module;

// Create the module object
Expand All @@ -720,33 +720,41 @@ def get_data_type(t):
"""
# ----------------------------------------------------------------------
# Write the C source code to the file
with open("a.c", "w") as file:
with open(filename + ".c", "w") as file:
file.write(template)

# ----------------------------------------------------------------------
# Generate the Shared library
# TODO: Use LLVM instead of C backend
r = os.system("lpython --show-c --disable-main a.py > a.h")
r = os.system("lpython --show-c --disable-main "
+ filename + ".py > " + filename + ".h")
assert r == 0, "Failed to create C file"

gcc_flags = ""
if platform.system() == "Linux":
gcc_flags = " -shared -fPIC "
elif platform.system() == "Darwin":
gcc_flags = " -bundle -flat_namespace -undefined suppress "
else:
raise NotImplementedError("Platform not implemented")

from numpy import get_include
from distutils.sysconfig import get_python_inc, get_python_lib
python_path = "-I" + get_python_inc() + " "
numpy_path = "-I" + get_include()
numpy_path = "-I" + get_include() + " "
rt_path_01 = "-I" + get_rtlib_dir() + "/../libasr/runtime "
rt_path_02 = "-L" + get_rtlib_dir() + " -Wl,-rpath " \
+ get_rtlib_dir() + " -llpython_runtime "
python_lib = "-L" "$CONDA_PREFIX/lib/ -lpython3.10 -lm"
python_lib = "-L" + get_python_lib() + "/../.. -lpython3.10 -lm"

r = os.system("gcc -g" + gcc_flags + python_path + numpy_path +
" a.c -o lpython_jit_module.so " + rt_path_01 + rt_path_02 + python_lib)
filename + ".c -o lpython_module_" + self.fn_name + ".so " +
rt_path_01 + rt_path_02 + python_lib)
assert r == 0, "Failed to create the shared library"

def __call__(self, *args, **kwargs):
import sys; sys.path.append('.')
# import the symbol from the shared library
function = getattr(__import__("lpython_jit_module"), self.fn_name)
function = getattr(__import__("lpython_module_" + self.fn_name),
self.fn_name)
return function(*args, **kwargs)