Skip to content

Commit

Permalink
Update with suggestions from @pearu.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Sep 24, 2024
1 parent 4bc9b4c commit c243a10
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 41 deletions.
18 changes: 5 additions & 13 deletions sparse/numba_backend/_compressed/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,11 +847,8 @@ def isnan(self):
# `GCXS` is a reshaped/transposed `CSR`, but it can't (usually)
# be expressed in the `binsparse` 0.1 language.
# We are missing index maps.
def __binsparse_descriptor__(self) -> dict:
return super().__binsparse_descriptor__()

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
return super().__binsparse_dlpack__()
def __binsparse__(self) -> tuple[dict, list[np.ndarray]]:
return super().__binsparse__()

Check warning on line 851 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L851

Added line #L851 was not covered by tests


class _Compressed2d(GCXS):
Expand Down Expand Up @@ -892,13 +889,13 @@ def from_numpy(cls, x, fill_value=0, idx_dtype=None):
coo = COO.from_numpy(x, fill_value=fill_value, idx_dtype=idx_dtype)
return cls.from_coo(coo, cls.class_compressed_axes, idx_dtype)

def __binsparse_descriptor__(self) -> dict:
def __binsparse__(self) -> tuple[dict, list[np.ndarray]]:
from sparse._version import __version__

Check warning on line 893 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L893

Added line #L893 was not covered by tests

data_dt = str(self.data.dtype)
if np.issubdtype(data_dt, np.complexfloating):
data_dt = f"complex[float{self.data.dtype.itemsize // 2}]"
return {
descriptor = {

Check warning on line 898 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L895-L898

Added lines #L895 - L898 were not covered by tests
"binsparse": {
"version": "0.1",
"format": self.format.upper(),
Expand All @@ -913,12 +910,7 @@ def __binsparse_descriptor__(self) -> dict:
"original_source": f"`sparse`, version {__version__}",
}

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
return {
"pointers_to_1": self.indices,
"indices_1": self.indptr,
"values": self.data,
}
return descriptor, [self.indices, self.indptr, self.data]

Check warning on line 913 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L913

Added line #L913 was not covered by tests


class CSR(_Compressed2d):
Expand Down
11 changes: 3 additions & 8 deletions sparse/numba_backend/_coo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,13 +1537,13 @@ def isnan(self):
prune=True,
)

def __binsparse_descriptor__(self) -> dict:
def __binsparse__(self) -> tuple[dict, list[np.ndarray]]:
from sparse._version import __version__

Check warning on line 1541 in sparse/numba_backend/_coo/core.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_coo/core.py#L1541

Added line #L1541 was not covered by tests

data_dt = str(self.data.dtype)
if np.issubdtype(data_dt, np.complexfloating):
data_dt = f"complex[float{self.data.dtype.itemsize // 2}]"
return {
descriptor = {

Check warning on line 1546 in sparse/numba_backend/_coo/core.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_coo/core.py#L1543-L1546

Added lines #L1543 - L1546 were not covered by tests
"binsparse": {
"version": "0.1",
"format": {
Expand All @@ -1568,12 +1568,7 @@ def __binsparse_descriptor__(self) -> dict:
"original_source": f"`sparse`, version {__version__}",
}

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
return {
"pointers_to_1": np.array([0, self.nnz], dtype=np.uint8),
"indices_1": self.coords,
"values": self.data,
}
return descriptor, [np.array([0, self.nnz], dtype=np.uint8), self.coords, self.data]

Check warning on line 1571 in sparse/numba_backend/_coo/core.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_coo/core.py#L1571

Added line #L1571 was not covered by tests


def as_coo(x, shape=None, fill_value=None, idx_dtype=None):
Expand Down
7 changes: 2 additions & 5 deletions sparse/numba_backend/_dok.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,11 +548,8 @@ def reshape(self, shape, order="C"):

return DOK.from_coo(self.to_coo().reshape(shape))

def __binsparse_descriptor__(self) -> dict:
raise RuntimeError("`DOK` doesn't support the `__binsparse_descriptor__` protocol.")

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
raise RuntimeError("`DOK` doesn't support the `__binsparse_dlpack__` protocol.")
def __binsparse__(self) -> tuple[dict, list[np.ndarray]]:
raise RuntimeError("`DOK` doesn't support the `__binsparse__` protocol.")

Check warning on line 552 in sparse/numba_backend/_dok.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_dok.py#L552

Added line #L552 was not covered by tests


def to_slice(k):
Expand Down
21 changes: 6 additions & 15 deletions sparse/numba_backend/_sparse_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,27 +219,18 @@ def _str_impl(self, summary):
return summary

@abstractmethod
def __binsparse_descriptor__(self) -> dict:
"""Return a `dict` equivalent to a parsed JSON [`binsparse` descriptor](https://graphblas.org/binsparse-specification/#descriptor)
def __binsparse__(self) -> tuple[dict, list[np.ndarray]]:
"""Return a 2-tuple:
* First element is a `dict` equivalent to a parsed JSON [`binsparse` descriptor](https://graphblas.org/binsparse-specification/#descriptor)
of this array.
* Second element is a `list[np.ndarray]` of the constituent arrays.
Returns
-------
dict
Parsed `binsparse` descriptor.
"""
raise NotImplementedError

@abstractmethod
def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
"""A `dict` containing the constituent arrays of this sparse array. The keys are compatible with the
[`binsparse`](https://graphblas.org/binsparse-specification/) scheme, and the values are [`__dlpack__`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html)
compatible objects.
Returns
-------
dict[str, np.ndarray]
The constituent arrays.
list[np.ndarray]
The constituent arrays
"""
raise NotImplementedError

Expand Down

0 comments on commit c243a10

Please sign in to comment.