Skip to content

Commit

Permalink
Fix issue #224 / Improve handling of unlabled cells
Browse files Browse the repository at this point in the history
  • Loading branch information
moinfar committed Apr 5, 2024
1 parent 51a0294 commit 2f057a9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 3 additions & 0 deletions scarches/models/scpoli/scpoli.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ def add_new_cell_type(self, latent, cell_type_name, prototypes, classes_list=Non
self.cell_type_encoder = {
k: v for k, v in zip(self.cell_types, range(len(self.cell_types)))
}
if self.unknown_ct_names is not None:
for unknown_ct in self.unknown_ct_names:
self.cell_type_encoder[unknown_ct] = -1

# Add new celltype index to hierarchy index list of prototypes
classes_list = torch.cat(
Expand Down
3 changes: 1 addition & 2 deletions scarches/models/scpoli/scpoli_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
self.unknown_ct_names_ = unknown_ct_names

if labeled_indices is None:
self.labeled_indices_ = range(len(adata))
self.labeled_indices_ = np.argwhere(adata.obs[self.cell_type_keys_].isin(self.unknown_ct_names_).to_numpy().astype(int).min(axis=1) == 0).T[0]
else:
self.labeled_indices_ = labeled_indices

Expand Down Expand Up @@ -180,7 +180,6 @@ def __init__(
if unknown_ct in self.cell_types_:
del self.cell_types_[unknown_ct]


# store model parameters
if hidden_layer_sizes is None:
self.hidden_layer_sizes_ = [int(np.ceil(np.sqrt(adata.shape[1])))]
Expand Down

0 comments on commit 2f057a9

Please sign in to comment.