Add jpy.byte_buffer() function #112

Merged

Contributor

@jmao-denver jmao-denver commented Sep 22, 2023

Fixes #111
Depended on by deephaven/deephaven-core#4936

  1. Adds a byte_buffer() utility function that creates a Java direct ByteBuffer wrapper sharing the underlying memory of a Python object that implements the buffer protocol.
  2. When calling a Java method that takes a ByteBuffer argument, a Java ByteBuffer wrapper can be passed in.
  3. When calling a Java method whose last parameter is a variadic ByteBuffer, a sequence of ByteBuffer wrappers created with byte_buffer() is accepted.

Note: Java methods that receive ByteBuffer arguments are considered to be borrowing these buffers, not owning them, and the buffers are only guaranteed to be safe to access for the duration of the method call. If the Python buffers are to be used in Java beyond that call, it falls to the user to either keep the Python objects from being garbage-collected after the Java method returns, or copy the buffers before the method returns.

import jpyutil
jpyutil.init_jvm()
import jpy

def check_jbb(jbb):
    print(jbb.toString())
    print(f"isReadOnly: {jbb.isReadOnly()}")
    print("before scanning...")
    print(f"remaining: {jbb.remaining()}")
    print(f"position: {jbb.position()}")
    for i in range(jbb.remaining()):
        print(i, jbb.get())
    print("after scanning...")
    print(jbb.toString())
    print(f"remaining: {jbb.remaining()}")
    print(f"position: {jbb.position()}")


ba = b'abc'
jbb = jpy.byte_buffer(ba)
check_jbb(jbb)

import pyarrow as pa

data = [
    pa.array([1, 2, 3, 4]),
    pa.array(['foo', 'bar', 'baz', None]),
    pa.array([True, None, False, True])
]


batch = pa.record_batch(data, names=['f0', 'f1', 'f2'])
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    for i in range(5):
        writer.write_batch(batch)

buf = sink.getvalue()
jbb = jpy.byte_buffer(buf)
check_jbb(jbb)

buf = batch.schema.serialize()
jbb = jpy.byte_buffer(buf)
check_jbb(jbb)
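
The borrowing rule in the note above can be illustrated without a JVM. The following is a minimal sketch using only the Python standard library, not jpy: a memoryview stands in for the Java-side ByteBuffer. It shows that a consumer of the buffer protocol sees the exporter's memory without a copy, which is why the exporter has to outlive every consumer. The names `exporter` and `view` are illustrative only.

```python
# Stand-in for the memory sharing that jpy.byte_buffer() sets up; no JVM involved.
exporter = bytearray(b"abc")
view = memoryview(exporter)        # borrows exporter's memory, no copy is made

exporter[0] = ord("z")             # a write through the owner...
first_byte = view[0]               # ...is visible through the borrowed view

# While a view is exported, CPython refuses to resize the owner, because a
# resize could move the memory out from under the borrower.
try:
    exporter.extend(b"more")
    resized_while_borrowed = True
except BufferError:
    resized_while_borrowed = False

view.release()                     # the borrower is done
exporter.extend(b"more")           # the owner can be resized again
```

A Java ByteBuffer gets no such protection from the Python side once the wrapper is gone, which is the reason for the explicit lifetime rule in the note.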

@jmao-denver jmao-denver added the enhancement New feature or request label Sep 22, 2023
@jmao-denver jmao-denver added this to the Sep 2023 milestone Sep 22, 2023
@jmao-denver jmao-denver self-assigned this Sep 22, 2023
Contributor

@niloc132 niloc132 left a comment


Notes from call:

  • I like that we have a helper for this; it could be handy for explicit manipulation of the Java type from Python.
  • Please consider tweaking the Python-to-Java call code so that if the matching method has a ByteBuffer parameter and the provided argument is apparently a Python buffer, we automatically wrap it this way.
  • Likewise, when Java calls Python with a direct ByteBuffer, consider automatically providing it to Python as a buffer rather than assuming that the Python code being called already knows how to handle it. Non-direct buffers probably just have to be passed as-is.
  • Handling read-only should be as simple as calling .asReadOnlyBuffer() on the Java ByteBuffer instance. This shouldn't copy; it just provides a read-only view for Java while leaving the instance a direct ByteBuffer.
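
A rough sketch of the two Python-side checks these notes imply, using only the standard library: whether an argument "is apparently a Python buffer", and whether it is read-only (the case where the Java side would call .asReadOnlyBuffer()). The helper name `describe_buffer` is hypothetical and is not part of jpy; the actual check in this PR lives in the C layer.

```python
def describe_buffer(obj):
    """Return (is_buffer, is_readonly) for obj. Hypothetical helper, not jpy API."""
    try:
        # memoryview() succeeds only for objects that implement the buffer protocol.
        with memoryview(obj) as view:
            return True, view.readonly
    except TypeError:
        return False, False


bytes_info = describe_buffer(b"abc")                 # immutable exporter
bytearray_info = describe_buffer(bytearray(b"abc"))  # writable exporter
str_info = describe_buffer("abc")                    # str is not a buffer
```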

Member

@devinrsmith devinrsmith left a comment


This is a fun PR. I generally like the ideas.

There's one big gotcha: this assumes that python will keep the reference alive as long as java needs it, and there is no way for java to let python know if it still needs the reference in the current approach.

For example:

someJavaMethod(jpy.byte_buffer(some_buffer_obj))

will get a buffer, create a Python/Java ByteBuffer wrapper, and pass it to someJavaMethod. Upon return, it will delete the Python object, release the buffer, and tell JNI to delete (decrement) the ByteBuffer reference. Java might still hold a reference to the ByteBuffer, though, which now points to potentially invalid memory.
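
The lifetime problem can be reproduced in plain CPython, without jpy. In this sketch, `retain_weakly` is a hypothetical stand-in for a Java method that keeps using the buffer after it returns; the weak reference models the fact that Java holds an address but no Python reference. It relies on CPython's reference counting freeing the temporary as soon as the call returns.

```python
import weakref


class Buf(bytearray):
    """bytearray subclass, used only so that instances can be weakly referenced."""


def retain_weakly(buf):
    # Hypothetical stand-in for someJavaMethod(...): it remembers the buffer
    # but does nothing to keep the Python object alive.
    return weakref.ref(buf)


# Temporary argument: nothing on the Python side keeps it alive after the call.
dangling = retain_weakly(Buf(b"abc"))
temporary_is_gone = dangling() is None

# Named argument: the caller's variable keeps the buffer alive.
kept = Buf(b"abc")
handle = retain_weakly(kept)
named_is_alive = handle() is kept
```

With a real direct ByteBuffer there is no equivalent of the dead weak reference to warn the Java side; it would simply read freed memory.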

We may be okay with this limitation, as long as we are very clear. I don't know if there is an easy workaround. My initial thoughts go to a jpy-owned java class that more closely models the buffer protocol; where the java object owns a PyObject (memoryview probably) to ensure it stays alive for as long as java has a reference to it.

Looking deeper though, there may actually be a better solution:

    // Invoked to construct a direct ByteBuffer referring to the block of
    // memory. A given arbitrary object may also be attached to the buffer.
    //
    DirectByteBuffer(long addr, int cap, Object ob) {
        super(-1, 0, cap, cap);
        address = addr;
        cleaner = null;
        att = ob;
    }


    // Invoked only by JNI: NewDirectByteBuffer(void*, long)
    //
    private DirectByteBuffer(long addr, int cap) {
        super(-1, 0, cap, cap);
        address = addr;
        cleaner = null;
        att = null;
    }
    // An object attached to this buffer. If this buffer is a view of another
    // buffer then we use this field to keep a reference to that buffer to
    // ensure that its memory isn't freed before we are done with it.
    private final Object att;

This comment on att is exactly what we want. If it's a pointer to our Python object representation of the ByteBuffer (or if we use an auxiliary memoryview Python object), then I think that solves the problem.

I don't know exactly how to get at it yet, but there must be a way, perhaps via jdk.internal.misc.JavaNioAccess#newDirectByteBuffer.
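
The attachment idea can be sketched in pure Python as an analogy. Here `AttachedView` is a hypothetical class, not jpy or JDK code: its `att` field plays the role of DirectByteBuffer's `att`, holding a memoryview so the exporting object stays alive for as long as the wrapper does.

```python
class AttachedView:
    """Hypothetical wrapper that owns a reference to the memory it describes."""

    def __init__(self, exporter):
        # The memoryview pins the exporter, analogous to the 'att' field above.
        self.att = memoryview(exporter)
        self.capacity = self.att.nbytes


def make_view():
    data = bytearray(b"hello world")
    return AttachedView(data)      # the local name 'data' goes out of scope here


wrapper = make_view()
still_readable = bytes(wrapper.att)  # memory is still valid through the attachment
owner = wrapper.att.obj              # the exporter remains reachable via 'att'
```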

@devinrsmith
Member

Another small wrinkle for the future: Java 21 uses long for the capacity.

@jmao-denver jmao-denver changed the title Add jpy.byte_buffer() function Add jpy.byte_buffer() function and auto-convert Py buffer object to Java ByteBuffer in args Dec 11, 2023
@jmao-denver jmao-denver marked this pull request as ready for review December 14, 2023 16:18
Member

@devinrsmith devinrsmith left a comment


I've got strong reservations about the current impl. I think there are ways that we could cleanly fix it.

Also, I think there is a leak:

import jpyutil
jpyutil.init_jvm()

import ctypes
import gc
import jpy


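def get_refcount(obj_id):
    # CPython-specific: reads the reference count stored at the start of the
    # object header, whose address is what id() returns.
    return ctypes.c_long.from_address(obj_id).value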


def garbage_collect() -> None:
    gc.collect()
    _j_system = jpy.get_type("java.lang.System")
    _j_system.gc()
    _j_system.gc()


x = b'these are some bytes'
x_id = id(x)

print(f"get_refcount(x_id)={get_refcount(x_id)}")

bb = jpy.byte_buffer(x)

# this works correctly
# bb = memoryview(x)

print(f"get_refcount(x_id)={get_refcount(x_id)}")

del bb
garbage_collect()

print(f"get_refcount(x_id)={get_refcount(x_id)}")

results in

get_refcount(x_id)=1
get_refcount(x_id)=2
get_refcount(x_id)=2
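
For comparison, here is a hedged sketch of the behavior the script above expects, using the memoryview alternative from its commented-out line and no JVM. It uses sys.getrefcount rather than reading the object header through ctypes, and compares deltas because sys.getrefcount also counts its own argument. `measure` is an illustrative helper name.

```python
import sys


def measure(make_wrapper):
    """Return (refcount delta while the wrapper is alive, delta after deleting it)."""
    x = bytes(range(20))               # created at run time, not a shared constant
    base = sys.getrefcount(x)
    wrapper = make_wrapper(x)          # the wrapper is expected to hold one reference
    held = sys.getrefcount(x) - base
    del wrapper                        # a correct wrapper gives the reference back
    released = sys.getrefcount(x) - base
    return held, released


held, released = measure(memoryview)
```

A wrapper without the leak should produce the same pair of deltas, one while alive and zero after deletion, once any Java-side reference has also been collected.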

@devinrsmith
Member

Here's a script I've been using:

import jpy
from contextlib import contextmanager

_JByteBuffer = jpy.get_type('java.nio.ByteBuffer')
_JArrowToTableConverter = jpy.get_type('io.deephaven.extensions.barrage.util.ArrowToTableConverter')

@contextmanager
def jpy_flags(flags):
    orig_flags = jpy.diag.flags
    jpy.diag.flags = flags
    try:
        yield
    finally:
        jpy.diag.flags = orig_flags

def buffer_protocol():
    return jpy.byte_buffer(b'hello world')

def j_byte_buffer():
    return _JByteBuffer.allocate(42)

def j_buffer():
    # return type is java.nio.Buffer *not* ByteBuffer
    # impl is ByteBuffer.allocateDirect(size)
    return _JArrowToTableConverter.myBuffer(43)

def j_object():
    # return type is java.lang.Object 
    # impl is new Object()
    return _JArrowToTableConverter.newObject()

def print_info(name, x):
    print(f"{name}: type(x)={type(x)}" + "\n")
    print(f"{name}: x={x}" + "\n")

def create_print_del(name, fn):
    my_obj = fn()
    with jpy_flags(jpy.diag.F_OFF):
        print_info(name, my_obj)
    del my_obj

with jpy_flags(jpy.diag.F_MEM | jpy.diag.F_TYPE):
    create_print_del('j_object', j_object)
    create_print_del('buffer_protocol', buffer_protocol)
    create_print_del('j_byte_buffer', j_byte_buffer)
    create_print_del('j_buffer', j_buffer)

I've added some public static methods to ArrowToTableConverter, only because it was a convenient place to put some Java logic.

@devinrsmith
Member

https://github.com/devinrsmith/jpy/tree/111-DirectByteBuffer-support is the branch where I've added some logging.

Member

@devinrsmith devinrsmith left a comment


This resolves my memory safety / allocation / deallocation concerns, at least as far as the C layer itself is concerned; there are still memory safety concerns around how users must keep the jpy.byte_buffer alive for as long as the ByteBuffer is in use in Java. I think the big takeaway is that this struct logic needs to be based on the type and not the actual object. It's a nuanced point that could have led to very-hard-to-debug crashes.

I want to make sure we get the naming correct, and also see if Jianfeng wants to move this logic to type-creation time (instead of object-creation time), which would help solidify the correct pattern. Neither is a blocker for approval; I just want to poll / discuss.

@jmao-denver jmao-denver changed the title Add jpy.byte_buffer() function and auto-convert Py buffer object to Java ByteBuffer in args Add jpy.byte_buffer() function Dec 31, 2023
@jmao-denver jmao-denver merged commit 99418df into jpy-consortium:master Jan 3, 2024
27 checks passed
Development

Successfully merging this pull request may close these issues.

Add the ability to create Java DirectByteBuffer from Python objects that support the Buffer Protocol
6 participants