Skip to content

Commit

Permalink
Improvements and fixes to UsmNdArray type.
Browse files Browse the repository at this point in the history
  • Loading branch information
mingjie-intel authored and diptorupd committed Feb 5, 2023
1 parent 81cd967 commit 38083f2
Showing 1 changed file with 52 additions and 13 deletions.
65 changes: 52 additions & 13 deletions numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import dpctl
import dpctl.tensor
from numba.core.typeconv import Conversion
from numba.core.typeinfer import CallConstraint
from numba.core.types.npytypes import Array
from numba.np.numpy_support import from_dtype

from numba_dpex.utils import address_space

Expand All @@ -18,10 +20,10 @@ class USMNdArray(Array):

def __init__(
self,
dtype,
ndim,
layout,
usm_type="unknown",
layout="C",
dtype=None,
usm_type="device",
device="unknown",
queue=None,
readonly=False,
Expand All @@ -32,15 +34,51 @@ def __init__(
self.usm_type = usm_type
self.addrspace = addrspace

# Normalize the device filter string and get the fully qualified three
# tuple (backend:device_type:device_num) filter string from dpctl.
if device != "unknown":
_d = dpctl.SyclDevice(device)
self.device = _d.filter_string
if queue is not None and device is not None:
if not isinstance(device, str):
raise TypeError(
"The device keyword arg should be a str object specifying "
"a SYCL filter selector"
)
if not isinstance(queue, dpctl.SyclQueue):
raise TypeError(
"The queue keyword arg should be a dpctl.SyclQueue object"
)
d1 = queue.sycl_device
d2 = dpctl.SyclDevice(device)
if d1 != d2:
raise TypeError(
"The queue keyword arg and the device keyword arg specify "
"different SYCL devices"
)
self.queue = queue
self.device = device
elif queue is None and device is not None:
if not isinstance(device, str):
raise TypeError(
"The device keyword arg should be a str object specifying "
"a SYCL filter selector"
)
self.queue = dpctl.SyclQueue(device)
self.device = device
elif queue is not None and device is None:
if not isinstance(queue, dpctl.SyclQueue):
raise TypeError(
"The queue keyword arg should be a dpctl.SyclQueue object"
)
self.device = self.queue.sycl_device.filter_string
self.queue = queue
else:
self.device = "unknown"
self.queue = dpctl.SyclQueue()
self.device = self.queue.sycl_device.filter_string

self.queue = queue
if not dtype:
dummy_tensor = dpctl.tensor.empty(
sh=1, order=layout, usm_type=usm_type, sycl_queue=self.queue
)
# convert dpnp type to numba/numpy type
_dtype = dummy_tensor.dtype
self.dtype = from_dtype(_dtype)

if name is None:
type_name = "usm_ndarray"
Expand All @@ -50,20 +88,21 @@ def __init__(
type_name = "unaligned " + type_name
name_parts = (
type_name,
dtype,
self.dtype,
ndim,
layout,
self.addrspace,
usm_type,
self.device,
self.queue,
)
name = (
"%s(dtype=%s, ndim=%s, layout=%s, address_space=%s, "
"usm_type=%s, sycl_device=%s)" % name_parts
"usm_type=%s, device=%s, sycl_device=%s)" % name_parts
)

super().__init__(
dtype,
self.dtype,
ndim,
layout,
readonly=readonly,
Expand Down

0 comments on commit 38083f2

Please sign in to comment.