Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Learnable emb and some fixes #6

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions scooby/data/scdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,15 @@ def __init__(
self,
adatas: dict,
neighbors: scipy.sparse.csr_matrix,
embedding: pd.DataFrame,
ds: GenomeIntervalDataset,
clip_soft,
embedding: pd.DataFrame = None,
cell_sample_size: int = 32,
get_targets: bool = True,
random_cells: bool = True,
cells_to_run: Optional[np.ndarray] = None,
cell_weights: Optional[np.ndarray] = None,
learnable_cell_embs: bool = False,
normalize_atac: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -383,6 +384,7 @@ def __init__(
self.cells_to_run = cells_to_run
self.embedding = embedding
self.get_targets = get_targets
self.learnable_cell_embs = learnable_cell_embs
self.random_cells = random_cells
if not self.random_cells and not cells_to_run:
self.cells_to_run = np.zeros(1, dtype=np.int64)
Expand Down Expand Up @@ -520,8 +522,10 @@ def __getitem__(self, idx):
idx_gene = idx
seq_coord = self.genome_ds.df[idx_gene]
inputs, _, rc_augs = self.genome_ds[idx_gene]
embeddings = torch.from_numpy(np.vstack(self.embedding.iloc[idx_cells]["embedding"].values))

if not self.learnable_cell_embs:
embeddings = torch.from_numpy(np.vstack(self.embedding.iloc[idx_cells]["embedding"].values))
else:
embeddings = [0]
if self.get_targets:
chrom_size = self.chrom_sizes[seq_coord["column_1"].item()]
chrom_start = chrom_size["offset"]
Expand All @@ -537,8 +541,8 @@ def __getitem__(self, idx):
neighbors_to_load = self._get_neighbors_for_cell(cell_idx)
targets.append(self._load_pseudobulk(neighbors_to_load, genome_data))
targets = torch.vstack(targets)
return inputs, rc_augs, targets.permute(1, 0), embeddings
return inputs, rc_augs, embeddings
return inputs, rc_augs, targets.permute(1, 0), embeddings, idx_cells
return inputs, rc_augs, embeddings, idx_cells


class onTheFlyExonMultiomePseudobulkDataset(Dataset):
Expand Down
59 changes: 26 additions & 33 deletions scooby/modeling/scooby.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
batch_conv = torch.vmap(F.conv1d, chunk_size = 1024)

class Scooby(Borzoi):
def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, disable_cache = False, use_transform_borzoi_emb = False, cachesize = 2, **params):
def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, disable_cache = False, use_transform_borzoi_emb = False, cachesize = 2, num_learnable_cell_embs = None, **params):
"""
Scooby model for predicting single-cell genomic profiles from DNA sequence.

Expand All @@ -27,13 +27,14 @@ def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, dis
use_transform_borzoi_emb: Whether to use an additional transformation layer on Borzoi embeddings (default: False).
cachesize: Size of the sequence embedding cache (default: 2).
"""
super(Scooby, self).__init__(config)
super().__init__(config)
self.cell_emb_dim = cell_emb_dim
self.cachesize = cachesize
self.use_transform_borzoi_emb = use_transform_borzoi_emb
self.n_tracks = n_tracks
self.embedding_dim = embedding_dim
self.disable_cache = disable_cache
self.num_learnable_cell_embs = num_learnable_cell_embs
dropout_modules = [module for module in self.modules() if isinstance(module, torch.nn.Dropout)]
batchnorm_modules = [module for module in self.modules() if isinstance(module, torch.nn.BatchNorm1d)]
[module.eval() for module in dropout_modules] # disable dropout
Expand All @@ -59,38 +60,27 @@ def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, dis
nn.init.zeros_(self.transform_borzoi_emb[-2].weight)
nn.init.zeros_(self.transform_borzoi_emb[-2].bias)
nn.init.zeros_(self.cell_state_to_conv[-1].bias)
self.cell_state_to_conv[-1].is_hf_initialized = True
if self.num_learnable_cell_embs is not None:
self.embedding = nn.Embedding(num_learnable_cell_embs, cell_emb_dim)
self.sequences, self.last_embs = [], []
del self.human_head

def get_lora(self, lora_config, train):
"""
Applies Low-Rank Adaptation (LoRA) to the model.

This function integrates LoRA modules into specified layers of the model, enabling parameter-efficient
fine-tuning. If `train` is True, it sets the LoRA parameters and specific layers in the base model
to be trainable. Otherwise, it freezes all parameters.

Args:
lora_config (LoraConfig, optional): Configuration for LoRA. If None, uses a default configuration.
train (bool): Whether the model is being prepared for training.
"""
if lora_config is None:
lora_config = LoraConfig(
target_modules=r"(?!separable\d+).*conv_layer|.*to_q|.*to_v|transformer\.\d+\.1\.fn\.1|transformer\.\d+\.1\.fn\.4",
)
self = get_peft_model(self, lora_config) # get LoRA model
if train:
for params in self.base_model.cell_state_to_conv.parameters():
params.requires_grad = True
if self.use_transform_borzoi_emb:
for params in self.base_model.transform_borzoi_emb.parameters():
params.requires_grad = True
self.print_trainable_parameters()
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=1.0)
elif isinstance(module, (nn.Linear, nn.Conv1d)):
nn.init.xavier_normal_(module.weight)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()

else:
for params in self.parameters():
params.requires_grad = False


def forward_cell_embs_only(self, cell_emb):
"""
Expand Down Expand Up @@ -141,7 +131,6 @@ def forward_seq_to_emb(self, sequence):
x = self.final_joined_convs(x.permute(0, 2, 1))
if self.use_transform_borzoi_emb:
x = self.transform_borzoi_emb(x)
x = x.float()
if not self.training and not self.disable_cache:
if len(self.sequences) == self.cachesize:
self.sequences, self.last_embs = [], []
Expand Down Expand Up @@ -170,7 +159,6 @@ def forward_convs_on_emb(self, seq_emb, cell_emb_conv_weights, cell_emb_conv_bia
out = F.softplus(out)
return out.permute(0,2,1)


def forward_sequence_w_convs(self, sequence, cell_emb_conv_weights, cell_emb_conv_biases, bins_to_predict = None):
"""
Processes DNA sequence, applies cell-state-specific convolutions, and caches results.
Expand All @@ -184,24 +172,27 @@ def forward_sequence_w_convs(self, sequence, cell_emb_conv_weights, cell_emb_con
Returns:
Tensor: Predicted profiles.
"""
if self.sequences and not self.training and not self.disable_cache:

if self.sequences and not self.training and not self.disable_cache:
for i,s in enumerate(self.sequences):
if torch.equal(sequence,s):
cell_emb_conv_weights, cell_emb_conv_biases = cell_emb_conv_weights.to(self.last_embs[i].dtype), cell_emb_conv_biases.to(self.last_embs[i].dtype)
if bins_to_predict is not None: # unclear if this if is even needed or if self.last_embs[i][:,:,bins_to_predict] just also works when bins_to_predict is None
out = batch_conv(self.last_embs[i][:,:,bins_to_predict], cell_emb_conv_weights, cell_emb_conv_biases)
else:
out = batch_conv(self.last_embs[i], cell_emb_conv_weights, cell_emb_conv_biases)
out = F.softplus(out)
return out.permute(0,2,1)
x = self.forward_seq_to_emb(sequence)
cell_emb_conv_weights, cell_emb_conv_biases = cell_emb_conv_weights.to(x.dtype), cell_emb_conv_biases.to(x.dtype)
if bins_to_predict is not None:
out = batch_conv(x[:,:,bins_to_predict], cell_emb_conv_weights, cell_emb_conv_biases)
else:
out = batch_conv(x, cell_emb_conv_weights, cell_emb_conv_biases)
out = F.softplus(out)
return out.permute(0,2,1)

def forward(self, sequence, cell_emb):
def forward(self, sequence, cell_emb = None, cell_emb_idx = None):
"""
Forward pass of the scooby model.

Expand All @@ -212,6 +203,8 @@ def forward(self, sequence, cell_emb):
Returns:
Tensor: Predicted profiles for each cell (batch_size, num_cells, seq_len, n_tracks).
"""
if self.num_learnable_cell_embs is not None:
cell_emb = self.embedding(cell_emb_idx)
cell_emb_conv_weights,cell_emb_conv_biases = self.forward_cell_embs_only(cell_emb)
out = self.forward_sequence_w_convs(sequence, cell_emb_conv_weights, cell_emb_conv_biases)
return out
43 changes: 39 additions & 4 deletions scooby/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from matplotlib import pyplot as plt
import anndata as ad
from anndata.experimental import read_elem, sparse_dataset
from peft import get_peft_model, LoraConfig


def poisson_multinomial_torch(
Expand Down Expand Up @@ -230,9 +231,9 @@ def evaluate(accelerator, csb, val_loader):
csb.eval()
output_list, target_list, pearsons_per_track = [], [], []

stop_idx = 2
stop_idx = 1

for i, [inputs, rc_augs, targets, cell_emb_idx] in tqdm.tqdm(enumerate(val_loader)):
for i, [inputs, rc_augs, targets, cell_emb, cell_emb_idx] in tqdm.tqdm(enumerate(val_loader)):
if i < (stop_idx):
continue
if i == (stop_idx + 1):
Expand All @@ -241,7 +242,7 @@ def evaluate(accelerator, csb, val_loader):
target_list.append(targets.to(device, non_blocking=True))
with torch.no_grad():
with torch.autocast("cuda"):
output_list.append(csb(inputs, cell_emb_idx).detach())
output_list.append(csb(inputs, cell_emb = cell_emb, cell_emb_idx = cell_emb_idx).detach())
break
targets = torch.vstack(target_list).squeeze().numpy(force=True) # [reindex].flatten(0,1).numpy(force =True)
outputs = torch.vstack(output_list).squeeze().numpy(force=True) # [reindex].flatten(0,1).numpy(force =True)
Expand Down Expand Up @@ -728,7 +729,7 @@ def add_weight_decay(model, lr, weight_decay=1e-5, skip_list=()):
continue
if len(param.shape) == 1 or name in skip_list:
no_decay.append(param)
elif "cell_state_to_conv" in name:
elif "cell_state_to_conv" in name or "embedding" in name:
high_lr.append(param)
#accelerator.print ("setting to highlr", name)
else:
Expand All @@ -737,6 +738,40 @@ def add_weight_decay(model, lr, weight_decay=1e-5, skip_list=()):



def get_lora(model, lora_config = None, train = False):
"""
Applies Low-Rank Adaptation (LoRA) to the model.

This function integrates LoRA modules into specified layers of the model, enabling parameter-efficient
fine-tuning. If `train` is True, it sets the LoRA parameters and specific layers in the base model
to be trainable. Otherwise, it freezes all parameters.

Args:
lora_config (LoraConfig, optional): Configuration for LoRA. If None, uses a default configuration.
train (bool): Whether the model is being prepared for training.
"""
if lora_config is None:
lora_config = LoraConfig(
target_modules=r"(?!separable\d+).*conv_layer|.*to_q|.*to_v|transformer\.\d+\.1\.fn\.1|transformer\.\d+\.1\.fn\.4",
)
model = get_peft_model(model, lora_config) # get LoRA model
if train:
for params in model.base_model.cell_state_to_conv.parameters():
params.requires_grad = True
if model.use_transform_borzoi_emb:
for params in model.base_model.transform_borzoi_emb.parameters():
params.requires_grad = True
if model.num_learnable_cell_embs is not None:
for params in model.base_model.embedding.parameters():
params.requires_grad = True
model.print_trainable_parameters()

else:
for params in model.parameters():
params.requires_grad = False
return model


import matplotlib as mpl
from matplotlib.text import TextPath
from matplotlib.patches import PathPatch, Rectangle
Expand Down
15 changes: 15 additions & 0 deletions scripts/train_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: MULTI_GPU
downcast_bf16: 'yes'
fsdp_config: {}
gpu_ids: all
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
use_cpu: false
40 changes: 20 additions & 20 deletions scripts/train_multiome.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from enformer_pytorch.data import GenomeIntervalDataset

from scooby.modeling import Scooby
from scooby.utils.utils import poisson_multinomial_torch, evaluate, fix_rev_comp_multiome, read_backed, add_weight_decay
from scooby.utils.utils import poisson_multinomial_torch, evaluate, fix_rev_comp_multiome, read_backed, add_weight_decay, get_lora
from scooby.data import onTheFlyMultiomeDataset
import scanpy as sc
import h5py
Expand All @@ -39,8 +39,8 @@ def train(config):
cell_emb_dim = config["model"]["cell_emb_dim"]
num_tracks = config["model"]["num_tracks"]
batch_size = config["training"]["batch_size"]
lr = config["training"]["lr"]
wd = config["training"]["wd"]
lr = float(config["training"]["lr"])
wd = float(config["training"]["wd"])
clip_global_norm = config["training"]["clip_global_norm"]
warmup_steps = config["training"]["warmup_steps"] * local_world_size
num_epochs = config["training"]["num_epochs"] * local_world_size
Expand All @@ -56,14 +56,14 @@ def train(config):

# Load data
adatas = {
"rna_plus": read_backed(h5py.File(os.path.join(data_path, "scooby_training_data/snapatac_merged_plus.h5ad")), "fragment_single"),
"rna_minus": read_backed(h5py.File(os.path.join(data_path, "scooby_training_data/snapatac_merged_minus.h5ad")), "fragment_single"),
"atac": sc.read(os.path.join(data_path, "scooby_training_data/snapatac_merged_atac.h5ad")),
"rna_plus": read_backed(h5py.File(os.path.join(data_path, "snapatac_merged_fixed_plus.h5ad")), "fragment_single"),
"rna_minus": read_backed(h5py.File(os.path.join(data_path, "snapatac_merged_fixed_minus.h5ad")), "fragment_single"),
"atac": sc.read(os.path.join(data_path, "snapatac_merged_fixed_atac.h5ad")),
}

neighbors = scipy.sparse.load_npz(f"{data_path}scooby_training_data/no_neighbors.npz")
embedding = pd.read_parquet(f"{data_path}scooby_training_data/embedding_no_val_genes_new.pq")
cell_weights = np.load(f"{data_path}scooby_training_data/cell_weights_no_normoblast.npy")
neighbors = scipy.sparse.load_npz(f"/s/project/QNA/scborzoi/neurips_bone_marrow/borzoi_training_data_fixed/no_neighbors.npz")
# embedding = pd.read_parquet(f"{data_path}scooby_training_data/embedding_no_val_genes_new.pq")
# cell_weights = np.load(f"{data_path}scooby_training_data/cell_weights_no_normoblast.npy")

# Calculate training steps
num_steps = (45_000 * num_epochs) // (batch_size)
Expand All @@ -77,14 +77,15 @@ def train(config):
n_tracks=num_tracks,
return_center_bins_only=True,
disable_cache=True,
use_transform_borzoi_emb=True,
use_transform_borzoi_emb=False,
num_learnable_cell_embs = adatas['rna_plus'].shape[0]
)
scooby.get_lora(train=True)
parameters = add_weight_decay(scooby, lr = lr, weight_decay=wd)
scooby = get_lora(scooby, train=True)
parameters = add_weight_decay(scooby, lr = lr, weight_decay = wd)
optimizer = torch.optim.AdamW(parameters)

warmup_scheduler = LinearLR(optimizer, start_factor=0.0000001, total_iters=warmup_steps, verbose=False)
train_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.00, total_iters=num_steps - warmup_steps, verbose=False)
warmup_scheduler = LinearLR(optimizer, start_factor=0.0001, total_iters=warmup_steps)
train_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.00, total_iters=num_steps - warmup_steps)
scheduler = SequentialLR(optimizer, [warmup_scheduler, train_scheduler], [warmup_steps])

# Create datasets and dataloaders
Expand Down Expand Up @@ -119,24 +120,23 @@ def train(config):
otf_dataset = onTheFlyMultiomeDataset(
adatas=adatas,
neighbors=neighbors,
embedding=embedding,
ds=ds,
cell_sample_size=64,
cell_weights=None,
normalize_atac=True,
clip_soft=5,
learnable_cell_embs = True,
)
val_dataset = onTheFlyMultiomeDataset(
adatas=adatas,
neighbors=neighbors,
embedding=embedding,
ds=val_ds,
cell_sample_size=32,
cell_weights=None,
normalize_atac=True,
clip_soft=5,
learnable_cell_embs = True,
)

training_loader = DataLoader(otf_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

Expand All @@ -152,7 +152,7 @@ def train(config):

# Training loop
for epoch in range(40):
for i, [inputs, rc_augs, targets, cell_emb_idx] in tqdm.tqdm(enumerate(training_loader)):
for i, [inputs, rc_augs, targets, _, cell_emb_idx] in tqdm.tqdm(enumerate(training_loader)):
inputs = inputs.permute(0, 2, 1).to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
for rc_aug_idx in rc_augs.nonzero():
Expand All @@ -161,7 +161,7 @@ def train(config):
targets[rc_aug_idx] = fix_rev_comp_multiome(flipped_version)[0]
optimizer.zero_grad()
with torch.autocast("cuda"):
outputs = scooby(inputs, cell_emb_idx)
outputs = scooby(inputs, cell_emb_idx = cell_emb_idx)
loss = loss_fn(outputs, targets, total_weight=total_weight)
accelerator.log({"loss": loss})
accelerator.backward(loss)
Expand All @@ -183,4 +183,4 @@ def train(config):
config = yaml.safe_load(f)

# Train the model
train(config)
train(config)
Loading