Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distributed Compressed Sparse Column Matrix #1377

Merged
merged 26 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
023f0a7
modified factory method to include compressed column type
Mystic-Slice Feb 3, 2024
48a5bc9
Created a base class for both DCSR_matrix and DCSC_matrix
Mystic-Slice Feb 5, 2024
00fd19f
Updated docstrings and type annotations. The class, methods and dense…
Mystic-Slice Feb 19, 2024
82a38d9
refactoring changes
Mystic-Slice Feb 20, 2024
f070d23
Arithmetric operations implemented
Mystic-Slice Feb 21, 2024
b362628
Merge branch 'main' of https://github.com/helmholtz-analytics/heat in…
Mystic-Slice Feb 21, 2024
f8cd97d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
fa2bdd4
Merge branch 'main' into sparse_csc
mrfh92 Mar 7, 2024
e7c4184
Merge branch 'main' into sparse_csc
ClaudiaComito Apr 22, 2024
c513bd3
Merge branch 'main' of https://github.com/helmholtz-analytics/heat in…
Mystic-Slice Apr 22, 2024
6d727af
PyTorch CSC tensors do not support arithmetic ops yet
Mystic-Slice Apr 22, 2024
3684130
Merge branch 'sparse_csc' of https://github.com/helmholtz-analytics/h…
Mystic-Slice Apr 22, 2024
f6db1bb
tests for csc matrix - manipulations
Mystic-Slice Apr 22, 2024
203c274
tests for DCSC_matrix class methods
Mystic-Slice Apr 23, 2024
c2ba977
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
9f6b509
tests for sparse_csc factory method
Mystic-Slice Apr 23, 2024
86e6b05
Merge branch 'sparse_csc' of https://github.com/helmholtz-analytics/h…
Mystic-Slice Apr 23, 2024
5abbab6
Merge branch 'main' into sparse_csc
ClaudiaComito May 8, 2024
bde047a
Merge branch 'main' into sparse_csc
ClaudiaComito May 27, 2024
9fec353
Merge branch 'main' into sparse_csc
ClaudiaComito Jun 3, 2024
b913228
added name to CITATION.cff
Mystic-Slice Jun 5, 2024
959557c
Merge branch 'sparse_csc' of https://github.com/helmholtz-analytics/h…
Mystic-Slice Jun 5, 2024
12b7dbc
Merge branch 'main' into sparse_csc
ClaudiaComito Jun 6, 2024
680f3ee
fix: fixed dtype conversion bug in astype method
Mystic-Slice Jun 7, 2024
28ff0fb
Merge branch 'sparse_csc' of https://github.com/helmholtz-analytics/h…
Mystic-Slice Jun 7, 2024
90dc1d6
skip type conversion test for DCSC_matrix if torch < 2.0
Mystic-Slice Jun 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ preferred-citation:
given-names: Achim
- family-names: Streit
given-names: Achim
- family-names: Vaithinathan Aravindan
given-names: Ashwath
year: 2020
collection-title: 2020 IEEE International Conference on Big Data (IEEE Big Data 2020)
collection-doi: 10.1109/BigData50022.2020.9378050
Expand Down
2 changes: 1 addition & 1 deletion heat/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""add sparse heat function to the ht.sparse namespace"""

from .arithmetics import *
from .dcsr_matrix import *
from .dcsx_matrix import *
from .factories import *
from ._operations import *
from .manipulations import *
116 changes: 78 additions & 38 deletions heat/sparse/_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@
import torch
import numpy as np

from heat.sparse.dcsr_matrix import DCSR_matrix
from heat.sparse.dcsx_matrix import DCSC_matrix, DCSR_matrix, __DCSX_matrix

from . import factories
from ..core.communication import MPI
from ..core.dndarray import DNDarray
from ..core import types

from typing import Callable, Optional, Dict

__all__ = []


def __binary_op_csr(
def __binary_op_csx(
operation: Callable,
t1: DCSR_matrix,
t2: DCSR_matrix,
out: Optional[DCSR_matrix] = None,
t1: __DCSX_matrix,
t2: __DCSX_matrix,
out: Optional[__DCSX_matrix] = None,
orientation: str = "row",
fn_kwargs: Optional[Dict] = {},
) -> DCSR_matrix:
) -> __DCSX_matrix:
"""
Generic wrapper for element-wise binary operations of two operands.
Takes the operation function and the two operands involved in the operation as arguments.
Expand All @@ -31,37 +31,60 @@ def __binary_op_csr(
operation : PyTorch function
The operation to be performed. Function that performs operation elements-wise on the involved tensors,
e.g. add values from other to self
t1: DCSR_matrix
t1: __DCSX_matrix or scalar
The first operand involved in the operation.
t2: DCSR_matrix
t2: __DCSX_matrix or scalar
The second operand involved in the operation.
out: DCSR_matrix, optional
out: __DCSX_matrix, optional
Output buffer in which the result is placed. If not provided, a freshly allocated matrix is returned.
orientation: str, optional
The orientation of the operation. Options: 'row' or 'col'
Default: 'row'
fn_kwargs: Dict, optional
keyword arguments used for the given operation
Default: {} (empty dictionary)

Returns
-------
result: ht.sparse.DCSR_matrix
A DCSR_matrix containing the results of element-wise operation.
result: ht.sparse.__DCSX_matrix
A __DCSX_matrix containing the results of element-wise operation.

Raises
------
ValueError
If the orientation is invalid
ValueError
If the input types are not supported
ValueError
If the input shapes are not compatible
ValueError
If the output buffer shape is not compatible with the result
"""
if not np.isscalar(t1) and not isinstance(t1, DCSR_matrix):
if orientation not in ["row", "col"]:
raise ValueError(f"Invalid orientation: '{orientation}'. Options: 'row' or 'col'")

if not np.isscalar(t1) and not isinstance(t1, __DCSX_matrix):
raise TypeError(
f"Only Dcsr_matrices and numeric scalars are supported, but input was {type(t1)}"
)
if not np.isscalar(t2) and not isinstance(t2, DCSR_matrix):
if not np.isscalar(t2) and not isinstance(t2, __DCSX_matrix):
raise TypeError(
f"Only Dcsr_matrices and numeric scalars are supported, but input was {type(t2)}"
)

if not isinstance(t1, DCSR_matrix) and not isinstance(t2, DCSR_matrix):
if not isinstance(t1, __DCSX_matrix) and not isinstance(t2, __DCSX_matrix):
raise TypeError(
f"Operator only to be used with Dcsr_matrices, but input types were {type(t1)} and {type(t2)}"
)

promoted_type = types.result_type(t1, t2).torch_type()

torch_constructor = torch.sparse_csr_tensor if orientation == "row" else torch.sparse_csc_tensor
factory_method = (
factories.sparse_csr_matrix if orientation == "row" else factories.sparse_csc_matrix
)
split_axis = 0 if orientation == "row" else 1

# If one of the inputs is a scalar
# just perform the operation on the data tensor
# and create a new sparse matrix
Expand All @@ -74,15 +97,15 @@ def __binary_op_csr(
scalar = t1

res_values = operation(matrix.larray.values().to(promoted_type), scalar, **fn_kwargs)
res_torch_sparse_csr = torch.sparse_csr_tensor(
res_torch_sparse_csx = torch_constructor(
matrix.lindptr,
matrix.lindices,
res_values,
size=matrix.lshape,
device=matrix.device.torch_device,
)
return factories.sparse_csr_matrix(
res_torch_sparse_csr, is_split=matrix.split, comm=matrix.comm, device=matrix.device
return factory_method(
res_torch_sparse_csx, is_split=matrix.split, comm=matrix.comm, device=matrix.device
)

if t1.shape != t2.shape:
Expand All @@ -93,10 +116,10 @@ def __binary_op_csr(

if t1.split is not None or t2.split is not None:
if t1.split is None:
t1 = factories.sparse_csr_matrix(t1.larray, split=0)
t1 = factory_method(t1.larray, split=split_axis)

if t2.split is None:
t2 = factories.sparse_csr_matrix(t2.larray, split=0)
t2 = factory_method(t2.larray, split=split_axis)

output_split = t1.split
output_device = t1.device
Expand All @@ -113,10 +136,10 @@ def __binary_op_csr(

if out.split != output_split:
if out.split is None:
out = factories.sparse_csr_matrix(out.larray, split=0)
out = factory_method(out.larray, split=split_axis)
else:
out = factories.sparse_csr_matrix(
torch.sparse_csr_tensor(
out = factory_method(
torch_constructor(
torch.tensor(out.indptr, dtype=torch.int64),
torch.tensor(out.indices, dtype=torch.int64),
torch.tensor(out.data),
Expand Down Expand Up @@ -146,21 +169,38 @@ def __binary_op_csr(
output_type = types.canonical_heat_type(result.dtype)

if out is None:
return DCSR_matrix(
array=torch.sparse_csr_tensor(
result.crow_indices().to(torch.int64),
result.col_indices().to(torch.int64),
result.values(),
size=output_lshape,
),
gnnz=output_gnnz,
gshape=output_shape,
dtype=output_type,
split=output_split,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)
if orientation == "row":
return DCSR_matrix(
array=torch_constructor(
result.crow_indices().to(torch.int64),
result.col_indices().to(torch.int64),
result.values(),
size=output_lshape,
),
gnnz=output_gnnz,
gshape=output_shape,
dtype=output_type,
split=output_split,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)
else:
return DCSC_matrix(
array=torch_constructor(
result.ccol_indices().to(torch.int64),
result.row_indices().to(torch.int64),
result.values(),
size=output_lshape,
),
gnnz=output_gnnz,
gshape=output_shape,
dtype=output_type,
split=output_split,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)

out.larray.copy_(result)
out.gnnz = output_gnnz
Expand Down
24 changes: 15 additions & 9 deletions heat/sparse/arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from .dcsr_matrix import DCSR_matrix
from .dcsx_matrix import DCSC_matrix, DCSR_matrix

from . import _operations

Expand All @@ -14,7 +14,7 @@
]


def add(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
def add(t1: DCSR_matrix, t2: DCSR_matrix, orientation: str = "row") -> DCSR_matrix:
"""
Element-wise addition of values from two operands, commutative.
Takes the first and second operand (scalar or :class:`~heat.sparse.DCSR_matrix`) whose elements are to be added
Expand All @@ -26,6 +26,9 @@ def add(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
The first operand involved in the addition
t2: DCSR_matrix
The second operand involved in the addition
orientation: str, optional
The orientation of the operation. Options: 'row' or 'col'
Default: 'row'

Examples
--------
Expand All @@ -43,16 +46,16 @@ def add(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
DNDarray([[2., 0., 4.],
[0., 0., 6.]], dtype=ht.float32, device=cpu:0, split=0)
"""
return _operations.__binary_op_csr(torch.add, t1, t2)
return _operations.__binary_op_csx(torch.add, t1, t2, orientation=orientation)


DCSR_matrix.__add__ = lambda self, other: add(self, other)
DCSR_matrix.__add__ = lambda self, other: add(self, other, orientation="row")
DCSR_matrix.__add__.__doc__ = add.__doc__
DCSR_matrix.__radd__ = lambda self, other: add(self, other)
DCSR_matrix.__radd__ = lambda self, other: add(self, other, orientation="row")
DCSR_matrix.__radd__.__doc__ = add.__doc__


def mul(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
def mul(t1: DCSR_matrix, t2: DCSR_matrix, orientation: str = "row") -> DCSR_matrix:
"""
Element-wise multiplication (NOT matrix multiplication) of values from two operands, commutative.
Takes the first and second operand (scalar or :class:`~heat.sparse.DCSR_matrix`) whose elements are to be
Expand All @@ -64,6 +67,9 @@ def mul(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
The first operand involved in the multiplication
t2: DCSR_matrix
The second operand involved in the multiplication
orientation: str, optional
The orientation of the operation. Options: 'row' or 'col'
Default: 'row'

Examples
--------
Expand All @@ -81,10 +87,10 @@ def mul(t1: DCSR_matrix, t2: DCSR_matrix) -> DCSR_matrix:
DNDarray([[1., 0., 4.],
[0., 0., 9.]], dtype=ht.float32, device=cpu:0, split=0)
"""
return _operations.__binary_op_csr(torch.mul, t1, t2)
return _operations.__binary_op_csx(torch.mul, t1, t2, orientation=orientation)


DCSR_matrix.__mul__ = lambda self, other: mul(self, other)
DCSR_matrix.__mul__ = lambda self, other: mul(self, other, orientation="row")
DCSR_matrix.__mul__.__doc__ = mul.__doc__
DCSR_matrix.__rmul__ = lambda self, other: mul(self, other)
DCSR_matrix.__rmul__ = lambda self, other: mul(self, other, orientation="row")
DCSR_matrix.__rmul__.__doc__ = mul.__doc__
Loading
Loading