Skip to content

Commit

Permalink
New algorithm for symmetrizing LD matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
shz9 committed May 21, 2024
1 parent 2e8013c commit 987770b
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 51 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [0.1.3] - 2024-05-15
## [0.1.3] - 2024-05-21

### Changed

Expand All @@ -22,6 +22,10 @@ name (before it was returning true for inliers...).

- Added `get_peak_memory_usage` to `system_utils` to inspect peak memory usage of a process.
- Placeholder method to perform QC on `SumstatsTable` objects (needs to be implemented still).
- New algorithm for symmetrizing upper triangular and block diagonal LD matrices.
- Much faster and more memory efficient than using `scipy`.
- New `LDMatrix` class has efficient data loading in `.load_data` method.
- We still retain `load_rows` because it is useful for loading a subset of rows.

## [0.1.2] - 2024-04-24

Expand Down
115 changes: 70 additions & 45 deletions magenpy/LDMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,6 @@ def set_mask(self, mask):
if self.in_memory:
self.load(force_reload=True,
return_symmetric=self.is_symmetric,
fill_diag=self.is_symmetric,
dtype=self.dtype)

def reset_mask(self):
Expand All @@ -858,7 +857,6 @@ def reset_mask(self):
if self.in_memory:
self.load(force_reload=True,
return_symmetric=self.is_symmetric,
fill_diag=self.is_symmetric,
dtype=self.dtype)

def to_snp_table(self, col_subset=None):
Expand Down Expand Up @@ -1103,35 +1101,27 @@ def update_rows_inplace(self, new_csr, start_row=None, end_row=None):
else:
self._zg['matrix/data'][data_start:data_end] = new_csr.data.astype(self.stored_dtype)

def low_memory_load(self, dtype=None):
def load_data(self,
return_symmetric=False,
dtype=None,
return_as_csr=False):
"""
A utility method to load the LD matrix in low-memory mode.
The method will load the entries of the upper triangular portion of the matrix,
perform filtering based on the mask (if set), and return the filtered data
and index pointer (`indptr`) arrays.
This is useful for some application, such as the `low_memory` version of
the `viprs` method, because it avoids reconstructing the `indices` array for the CSR matrix,
which can potentially be a very long array of large integers.
A utility function to load and process the LD matrix data.
This function is particularly useful for filtering, symmetrizing, and dequantizing the LD matrix
after it's loaded into memory.
!!! note
The method, by construction, does not support loading the full symmetric matrix. If
that's the goal, use the `.load()` or `.load_rows()` methods.
!!! seealso "See Also"
* [load_rows][magenpy.LDMatrix.LDMatrix.load_rows]
* [load][magenpy.LDMatrix.LDMatrix.load]
TODO: Support loading a subset of rows of the matrix, similar to `load_rows` (i.e. replace `load_rows`
with the newer implementation).
:param return_symmetric: If True, return a full symmetric representation of the LD matrix.
:param dtype: The data type for the entries of the LD matrix.
:param return_as_csr: If True, return the data in the CSR format.
:return: A tuple of the data and index pointer arrays for the LD matrix.
:return: A tuple of the data array, the index pointer array, and the leftmost index array (in that order). If
`return_as_csr` is set to True, we return the data as a CSR matrix instead.
"""

# Determine the final data type for the LD matrix entries
# and whether we need to perform dequantization or not depending on
# the stored data type and the requested data type.

# -------------- Step 1: Preparing input data type --------------
if dtype is None:
dtype = self.stored_dtype
dequantize_data = False
Expand All @@ -1142,9 +1132,13 @@ def low_memory_load(self, dtype=None):
else:
dequantize_data = False

# -------------- Step 2: Fetch the indptr array --------------

# Get the index pointer array:
indptr = self._zg['matrix/indptr'][:]

# -------------- Step 3: Checking data masking --------------

# Filter the index pointer array based on the mask:
if self._mask is not None:

Expand All @@ -1154,21 +1148,61 @@ def low_memory_load(self, dtype=None):
else:
mask = self._mask

from .stats.ld.c_utils import filter_ut_csr_matrix_low_memory
from .stats.ld.c_utils import filter_ut_csr_matrix

data_mask, indptr = filter_ut_csr_matrix_low_memory(indptr, mask)
# .oindex and .vindex are slow and likely convert to integer indices in the background,
data_mask, indptr = filter_ut_csr_matrix(indptr, mask)
# .oindex and .vindex are slow and convert to integer indices in the background,
# which unnecessarily increases memory usage. Unfortunately, here we have to load the entire
# data and index it using the boolean array afterward.
# Something to be improved in the future...
# Something to be improved in the future. :)
data = self._zg['matrix/data'][:][data_mask]
del data_mask
else:
data = self._zg['matrix/data'][:]

# -------------- Step 4: Symmetrizing input matrix --------------

if return_symmetric:

from .stats.ld.c_utils import symmetrize_ut_csr_matrix

if np.issubdtype(self.stored_dtype, np.integer):
fill_val = np.iinfo(self.stored_dtype).max
else:
fill_val = 1.

data, indptr, leftmost_idx = symmetrize_ut_csr_matrix(indptr, data, fill_val)
else:
leftmost_idx = np.arange(1, indptr.shape[0], dtype=np.int32)

# -------------- Step 5: Dequantizing/type cast requested data --------------

if dequantize_data:
return dequantize(data, float_dtype=dtype), indptr
data = dequantize(data, float_dtype=dtype)
else:
data = data.astype(dtype, copy=False)

# ---------------------------------------------------------------------------
# Return the requested data:

if return_as_csr:

from .stats.ld.c_utils import expand_ranges

indices = expand_ranges(leftmost_idx,
(np.diff(indptr) + leftmost_idx).astype(np.int32),
data.shape[0])

return csr_matrix(
(
data,
indices,
indptr
),
dtype=dtype
)
else:
return data.astype(dtype), indptr
return data, indptr, leftmost_idx

def load_rows(self,
start_row=None,
Expand All @@ -1182,14 +1216,12 @@ def load_rows(self,
By specifying `start_row` and `end_row`, the user can process or inspect small
blocks of the LD matrix without loading the whole thing into memory.
TODO: Consider using `low_memory_load` internally to avoid reconstructing the `indices` array.
!!! note
This method does not perform any filtering on the stored data.
To access the LD matrix with filtering, use `.load()` or `low_memory_load`.
To access the LD matrix with filtering, use `.load()` or `load_data`.
!!! seealso "See Also"
* [low_memory_load][magenpy.LDMatrix.LDMatrix.low_memory_load]
* [load_data][magenpy.LDMatrix.LDMatrix.load_data]
* [load][magenpy.LDMatrix.LDMatrix.load]
:param start_row: The start row to load to memory
Expand Down Expand Up @@ -1323,20 +1355,18 @@ def load_rows(self,
def load(self,
force_reload=False,
return_symmetric=True,
fill_diag=True,
dtype=None):

"""
Load the LD matrix from on-disk storage in the form of Zarr arrays to memory,
in the form of sparse CSR matrices.
!!! seealso "See Also"
* [low_memory_load][magenpy.LDMatrix.LDMatrix.low_memory_load]
* [load_data][magenpy.LDMatrix.LDMatrix.load_data]
* [load_rows][magenpy.LDMatrix.LDMatrix.load_rows]
:param force_reload: If True, it will reload the data even if it is already in memory.
:param return_symmetric: If True, return a full symmetric representation of the LD matrix.
:param fill_diag: If True, fill the diagonal elements of the LD matrix with ones.
:param dtype: The data type for the entries of the LD matrix.
:return: The LD matrix as a sparse CSR matrix.
Expand Down Expand Up @@ -1364,14 +1394,9 @@ def load(self,
# If we are re-loading the matrix, make sure to release the current one:
self.release()

self._mat = self.load_rows(return_symmetric=return_symmetric,
fill_diag=fill_diag,
self._mat = self.load_data(return_symmetric=return_symmetric,
dtype=dtype)

# If a mask is set, apply it:
if self._mask is not None:
self._mat = self._mat[self._mask, :][:, self._mask]

# Update the flags:
self.in_memory = True
self.is_symmetric = return_symmetric
Expand Down Expand Up @@ -1454,7 +1479,7 @@ def __setstate__(self, state):
self.set_mask(mask)

if in_mem:
self.load(return_symmetric=is_symmetric, fill_diag=is_symmetric, dtype=dtype)
self.load(return_symmetric=is_symmetric, dtype=dtype)

def __len__(self):
return self.n_snps
Expand All @@ -1467,7 +1492,7 @@ def __iter__(self):
TODO: Add a flag to allow for chunked iterator, with limited memory footprint.
"""
self.index = 0
self.load(return_symmetric=self.is_symmetric, fill_diag=self.is_symmetric)
self.load(return_symmetric=self.is_symmetric)
return self

def __next__(self):
Expand Down
2 changes: 1 addition & 1 deletion magenpy/plot/ld.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def plot_ld_matrix(ldm: LDMatrix,
row_subset = np.arange(ldm.shape[0])

# TODO: Figure out a way to do this without loading the entire matrix:
ldm.load(return_symmetric=True, fill_diag=True, dtype='float32')
ldm.load(return_symmetric=True, dtype='float32')

mat = ldm.csr_matrix[row_subset, :][:, row_subset].toarray()

Expand Down
Loading

0 comments on commit 987770b

Please sign in to comment.