Skip to content

Commit

Permalink
Use nrt api to allocate meminfo object
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed May 14, 2024
1 parent 574daab commit af44736
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
22 changes: 16 additions & 6 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
bool value_is_float,
int64_t value,
const DPCTLSyclQueueRef qref);
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(NRT_api_functions *nrt,
PyObject *ndarrobj,
void *data,
npy_intp nitems,
npy_intp itemsize,
Expand All @@ -58,7 +59,8 @@ static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
int ndim,
PyArray_Descr *descr);

static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
static int DPEXRT_sycl_usm_ndarray_from_python(NRT_api_functions *nrt,
PyObject *obj,
usmarystruct_t *arystruct);
static PyObject *
DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
Expand Down Expand Up @@ -336,6 +338,11 @@ NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type)
static void usmndarray_meminfo_dtor(void *ptr, size_t size, void *info)
{
MemInfoDtorInfo *mi_dtor_info = NULL;
// Warning: we are destructing sycl memory. MI destructor is called
// separately by numba.
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: Call to "
"usmndarray_meminfo_dtor at %s, line %d\n",
__FILE__, __LINE__));

// Sanity-check to make sure the mi_dtor_info is an actual pointer.
if (!(mi_dtor_info = (MemInfoDtorInfo *)info)) {
Expand Down Expand Up @@ -416,7 +423,8 @@ static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner)
* of the dpnp.ndarray was allocated.
* @return {return} A new NRT_MemInfo object
*/
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(NRT_api_functions *nrt,
PyObject *ndarrobj,
void *data,
npy_intp nitems,
npy_intp itemsize,
Expand All @@ -428,7 +436,8 @@ static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
DPCTLSyclContextRef cref = NULL;

// Allocate a new NRT_MemInfo object
if (!(mi = (NRT_MemInfo *)malloc(sizeof(NRT_MemInfo)))) {
// By passing 0 we are just allocating MemInfo
if (!(mi = (NRT_MemInfo *)nrt->allocate(0))) {
DPEXRT_DEBUG(drt_debug_print(
"DPEXRT-ERROR: Could not allocate a new NRT_MemInfo "
"object at %s, line %d\n",
Expand Down Expand Up @@ -795,7 +804,8 @@ static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim)
* instance of a dpnp.ndarray
* @return {return} Error code representing success (0) or failure (-1).
*/
static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
static int DPEXRT_sycl_usm_ndarray_from_python(NRT_api_functions *nrt,
PyObject *obj,
usmarystruct_t *arystruct)
{
struct PyUSMArrayObject *arrayobj = NULL;
Expand Down Expand Up @@ -842,7 +852,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
}

if (!(arystruct->meminfo = NRT_MemInfo_new_from_usmndarray(
obj, data, nitems, itemsize, qref)))
nrt, obj, data, nitems, itemsize, qref)))
{
DPEXRT_DEBUG(drt_debug_print(
"DPEXRT-ERROR: NRT_MemInfo_new_from_usmndarray failed "
Expand Down
6 changes: 4 additions & 2 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,15 @@ def arraystruct_from_python(self, pyapi, obj, ptr):
"""
fnty = llvmir.FunctionType(
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
llvmir.IntType(32), [pyapi.voidptr, pyapi.pyobj, pyapi.voidptr]
)
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)
fn = pyapi._get_function(fnty, "DPEXRT_sycl_usm_ndarray_from_python")
fn.args[0].add_attribute("nocapture")
fn.args[1].add_attribute("nocapture")
fn.args[2].add_attribute("nocapture")

self.error = pyapi.builder.call(fn, (obj, ptr))
self.error = pyapi.builder.call(fn, (nrt_api, obj, ptr))

return self.error

Expand Down

0 comments on commit af44736

Please sign in to comment.