-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
122 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
const std = @import("std"); | ||
const py = @import("pydust"); | ||
|
||
pub fn sum(args: *const struct { arr: py.PyObject }) !u64 { | ||
var out: py.PyBuffer = try py.PyBuffer.get(args.arr); | ||
defer out.decref(); | ||
|
||
const values = try out.asSliceView(u64); | ||
var s: u64 = 0; | ||
for (values) |v| s += v; | ||
return s; | ||
} | ||
|
||
pub fn reverse(args: *const struct { arr: py.PyObject }) !py.PyObject { | ||
var out: py.PyBuffer = try py.PyBuffer.get(args.arr); | ||
// don't decref out because we return it as result | ||
|
||
// we can just work with slice, but this tests getPtr | ||
const length: usize = @intCast(out.shape[0]); | ||
const iter: usize = @divFloor(length, 2); | ||
for (0..iter) |i| { | ||
var left = try out.getPtr(u64, &[_]isize{@intCast(i)}); | ||
var right = try out.getPtr(u64, &[_]isize{@intCast(length - i - 1)}); | ||
const tmp: u64 = left.*; | ||
left.* = right.*; | ||
right.* = tmp; | ||
} | ||
|
||
return out.obj; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
const std = @import("std"); | ||
const py = @import("../pydust.zig"); | ||
const ffi = py.ffi; | ||
const PyError = @import("../errors.zig").PyError; | ||
|
||
/// Wrapper for Python Py_buffer. | ||
/// See: https://docs.python.org/3/c-api/buffer.html | ||
pub const PyBuffer = extern struct { | ||
const Self = @This(); | ||
|
||
buf: ?[*]u8, | ||
obj: py.PyObject, | ||
// product(shape) * itemsize. | ||
// For contiguous arrays, this is the length of the underlying memory block. | ||
// For non-contiguous arrays, it is the length that the logical structure would | ||
// have if it were copied to a contiguous representation. | ||
len: isize, | ||
itemsize: isize, | ||
readonly: c_int, | ||
ndim: c_int, | ||
format: [*:0]u8, | ||
shape: [*:0]isize, | ||
strides: [*:0]isize, | ||
suboffsets: [*:0]isize, | ||
internal: ?*anyopaque, | ||
|
||
pub fn get(obj: py.PyObject) !Self { | ||
return getWithFlag(obj, ffi.PyBUF_FULL); | ||
} | ||
|
||
pub fn getro(obj: py.PyObject) !Self { | ||
return getWithFlag(obj, ffi.PyBUF_FULL_RO); | ||
} | ||
|
||
pub fn getWithFlag(obj: py.PyObject, flag: c_int) !Self { | ||
if (ffi.PyObject_CheckBuffer(obj.py) != 1) { | ||
// TODO(marko): This should be an error once we figure out how to do it | ||
@panic("not a buffer"); | ||
} | ||
var out: Self = undefined; | ||
if (ffi.PyObject_GetBuffer(obj.py, @ptrCast(&out), flag) != 0) { | ||
// TODO(marko): This should be an error once we figure out how to do it | ||
@panic("unable to get buffer"); | ||
} | ||
return out; | ||
} | ||
|
||
pub fn asSliceView(self: *const Self, comptime value_type: type) ![]value_type { | ||
if (ffi.PyBuffer_IsContiguous(@ptrCast(self), 'C') != 1) { | ||
// TODO(marko): This should be an error once we figure out how to do it | ||
@panic("only continuous buffers are supported for view - use getPtr instead"); | ||
} | ||
return @alignCast(std.mem.bytesAsSlice(value_type, self.buf.?[0..@intCast(self.len)])); | ||
} | ||
|
||
pub fn fromOwnedSlice(comptime value_type: type, values: []value_type) !Self { | ||
_ = values; | ||
// TODO(marko): We need to create an object using PyType_FromSpec and register buffer release | ||
@panic("not implemented"); | ||
} | ||
|
||
pub fn getPtr(self: *const Self, comptime value_type: type, item: [*]const isize) !*value_type { | ||
var ptr: *anyopaque = ffi.PyBuffer_GetPointer(@ptrCast(self), item) orelse return PyError.Propagate; | ||
return @ptrCast(@alignCast(ptr)); | ||
} | ||
|
||
pub fn incref(self: *Self) void { | ||
self.obj.incref(); | ||
} | ||
|
||
pub fn decref(self: *Self) void { | ||
// decrefs the underlying object | ||
ffi.PyBuffer_Release(@ptrCast(self)); | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from example import modules | ||
from array import array # implements buffer protocol | ||
|
||
|
||
def test_sum(): | ||
arr = array("L", [1, 2, 3, 4, 5]) # uint64 | ||
assert modules.sum(arr) == 15 | ||
|
||
|
||
def test_reverse(): | ||
arr = array("L", [1, 2, 3, 4, 5]) # uint64 | ||
assert arr == modules.reverse(arr) |