diff --git a/neusomatic/python/call.py b/neusomatic/python/call.py
index be1b40a..81d1deb 100755
--- a/neusomatic/python/call.py
+++ b/neusomatic/python/call.py
@@ -1,8 +1,8 @@
 #!/usr/bin/env python
-#-------------------------------------------------------------------------
+# -------------------------------------------------------------------------
 # call.py
 # Call variants using model trained by NeuSomatic network
-#-------------------------------------------------------------------------
+# -------------------------------------------------------------------------
 
 import os
 import traceback
@@ -29,15 +29,19 @@
 from defaults import VARTYPE_CLASSES, NUM_ENS_FEATURES, NUM_ST_FEATURES, MAT_DTYPES
 
 import torch._utils
+
 try:
     torch._utils._rebuild_tensor_v2
 except AttributeError:
-    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
-        tensor = torch._utils._rebuild_tensor(
-            storage, storage_offset, size, stride)
+
+    def _rebuild_tensor_v2(
+        storage, storage_offset, size, stride, requires_grad, backward_hooks
+    ):
+        tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
         tensor.requires_grad = requires_grad
         tensor._backward_hooks = backward_hooks
         return tensor
+
     torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
 
 
@@ -69,13 +73,16 @@ def call_variants(net, call_loader, out_dir, model_tag, run_i, matrix_dtype, use
         max_norm = 65535.0
     else:
         logger.info(
-            "Wrong matrix_dtype {}. Choices are {}".format(matrix_dtype, MAT_DTYPES))
+            "Wrong matrix_dtype {}. Choices are {}".format(matrix_dtype, MAT_DTYPES)
+        )
 
     iii = 0
     j = 0
     for data in loader_:
-        (matrices, labels, var_pos_s, var_len_s,
-         non_transformed_matrices), (paths) = data
+        (
+            (matrices, labels, var_pos_s, var_len_s, non_transformed_matrices),
+            (paths),
+        ) = data
         paths_ = copy.deepcopy(paths)
         del paths
 
@@ -100,27 +107,47 @@ def call_variants(net, call_loader, out_dir, model_tag, run_i, matrix_dtype, use
             path = path_.split("/")[-1]
             preds[i] = [VARTYPE_CLASSES[predicted[i]], pos_pred[i], len_pred[i]]
             if VARTYPE_CLASSES[predicted[i]] != "NONE":
-                final_preds[path] = [VARTYPE_CLASSES[predicted[i]], pos_pred[i], len_pred[i],
-                                     list(map(lambda x: round(x, 4), F.softmax(
-                                         outputs1[i, :], 0).data.cpu().numpy())),
-                                     list(map(lambda x: round(x, 4), F.softmax(
-                                         outputs3[i, :], 0).data.cpu().numpy())),
-                                     list(map(lambda x: round(x, 4),
-                                              outputs1.data.cpu()[i].numpy())),
-                                     list(map(lambda x: round(x, 4),
-                                              outputs3.data.cpu()[i].numpy())),
-                                     np.array(non_transformed_matrices[i, :, :, 0:3]) / max_norm]
+                final_preds[path] = [
+                    VARTYPE_CLASSES[predicted[i]],
+                    pos_pred[i],
+                    len_pred[i],
+                    list(
+                        map(
+                            lambda x: round(x, 4),
+                            F.softmax(outputs1[i, :], 0).data.cpu().numpy(),
+                        )
+                    ),
+                    list(
+                        map(
+                            lambda x: round(x, 4),
+                            F.softmax(outputs3[i, :], 0).data.cpu().numpy(),
+                        )
+                    ),
+                    list(map(lambda x: round(x, 4), outputs1.data.cpu()[i].numpy())),
+                    list(map(lambda x: round(x, 4), outputs3.data.cpu()[i].numpy())),
+                    np.array(non_transformed_matrices[i, :, :, 0:3]) / max_norm,
+                ]
             else:
-                none_preds[path] = [VARTYPE_CLASSES[predicted[i]], pos_pred[i], len_pred[i],
-                                    list(map(lambda x: round(x, 4), F.softmax(
-                                        outputs1[i, :], 0).data.cpu().numpy())),
-                                    list(map(lambda x: round(x, 4), F.softmax(
-                                        outputs3[i, :], 0).data.cpu().numpy())),
-                                    list(map(lambda x: round(x, 4),
-                                             outputs1.data.cpu()[i].numpy())),
-                                    list(map(lambda x: round(x, 4),
-                                             outputs3.data.cpu()[i].numpy()))]
-            if (iii % 10 == 0):
+                none_preds[path] = [
+                    VARTYPE_CLASSES[predicted[i]],
+                    pos_pred[i],
+                    len_pred[i],
+                    list(
+ map( + lambda x: round(x, 4), + F.softmax(outputs1[i, :], 0).data.cpu().numpy(), + ) + ), + list( + map( + lambda x: round(x, 4), + F.softmax(outputs3[i, :], 0).data.cpu().numpy(), + ) + ), + list(map(lambda x: round(x, 4), outputs1.data.cpu()[i].numpy())), + list(map(lambda x: round(x, 4), outputs3.data.cpu()[i].numpy())), + ] + if iii % 10 == 0: logger.info("Called {} candidates in this batch.".format(j)) logger.info("Called {} candidates in this batch.".format(j)) return final_preds, none_preds @@ -129,24 +156,28 @@ def call_variants(net, call_loader, out_dir, model_tag, run_i, matrix_dtype, use def pred_vcf_records_path(record): path, pred_all, chroms, ref_file = record thread_logger = logging.getLogger( - "{} ({})".format(pred_vcf_records_path.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + pred_vcf_records_path.__name__, multiprocessing.current_process().name + ) + ) try: fasta_file = pysam.FastaFile(ref_file) ACGT = "ACGT" I = pred_all[-1] vcf_record = [] Ih, Iw, _ = I.shape - zref_pos = np.where((np.argmax(I[:, :, 0], 0) == 0) & ( - sum(I[:, :, 0], 0) > 0))[0] - nzref_pos = np.where( - (np.argmax(I[:, :, 0], 0) > 0) & (sum(I[:, :, 0], 0) > 0))[0] + zref_pos = np.where((np.argmax(I[:, :, 0], 0) == 0) & (sum(I[:, :, 0], 0) > 0))[ + 0 + ] + nzref_pos = np.where((np.argmax(I[:, :, 0], 0) > 0) & (sum(I[:, :, 0], 0) > 0))[ + 0 + ] # zref_pos_0 = np.where((I[0, :, 0] > 0) & (sum(I[:, :, 0], 0) > 0))[0] # nzref_pos_0 = np.where((I[0, :, 0] == 0) & (sum(I[:, :, 0], 0) > 0))[0] # assert(len(set(zref_pos_0)^set(zref_pos))==0) # assert(len(set(nzref_pos_0)^set(nzref_pos))==0) - chrom, pos, ref, alt, _, center, _, _, _ = path.split( - ".") + chrom, pos, ref, alt, _, center, _, _, _ = path.split(".") ref, alt = ref.upper(), alt.upper() center = int(center) pos = int(pos) @@ -175,8 +206,7 @@ def pred_vcf_records_path(record): if center in nzref_pos: center_ = center elif len(nzref_pos) > 0: - center_ = nzref_pos[ - np.argmin(abs(nzref_pos - center_pred))] + center_ = nzref_pos[np.argmin(abs(nzref_pos - center_pred))] else: break elif type_pred == "INS" and center_ in nzref_pos: @@ -189,8 +219,7 @@ def pred_vcf_records_path(record): if center in zref_pos: center_ = center else: - center_ = zref_pos[ - np.argmin(abs(zref_pos - center_pred))] + center_ = zref_pos[np.argmin(abs(zref_pos - center_pred))] if abs(center_ - center) > too_far_center: pred[3][amx_prob] = 0 # thread_logger.warning("Too far center: path:{}, pred:{}".format(path, pred)) @@ -218,23 +247,36 @@ def pred_vcf_records_path(record): col_2_pos[0] = -1 nzref_pos = np.array([0] + list(nzref_pos)) if anchor[1] not in col_2_pos: - if I[0, anchor[1], 0] > 0 and vartype_candidate == "INS" and type_pred == "INS": + if ( + I[0, anchor[1], 0] > 0 + and vartype_candidate == "INS" + and type_pred == "INS" + ): ins_no_zref_pos = True else: # thread_logger.info(["NNN", path, pred]) return vcf_record if not ins_no_zref_pos: - b = (anchor[0] - col_2_pos[anchor[1]]) + b = anchor[0] - col_2_pos[anchor[1]] for i in nzref_pos: col_2_pos[i] += b pos_2_col = {v: k for k, v in col_2_pos.items()} - if type_pred == "SNP" and len(ref) - len(alt) > 1 and abs(center_pred - center) < center_dist_roundback: + if ( + type_pred == "SNP" + and len(ref) - len(alt) > 1 + and abs(center_pred - center) < center_dist_roundback + ): thread_logger.info(["TBC", path, nzref_pos]) if abs(center_pred - center) < too_far_center: if type_pred == "SNP": - if abs(center_pred - center) < center_dist_roundback and len_pred == 1 and len(ref) == 1 and 
len(alt) == 1: + if ( + abs(center_pred - center) < center_dist_roundback + and len_pred == 1 + and len(ref) == 1 + and len(alt) == 1 + ): pos_, ref_, alt_ = pos, ref.upper(), alt.upper() else: pos_ = col_2_pos[center_] @@ -249,7 +291,11 @@ def pred_vcf_records_path(record): II = I.copy() II[rb + 1, center__, 1] = 0 if max(II[1:, center__, 1]) == 0: - if abs(center_pred - center) < center_dist_roundback * 3 and len_pred == 1: + if ( + abs(center_pred - center) + < center_dist_roundback * 3 + and len_pred == 1 + ): pos_, ref_, alt_ = pos, ref.upper(), alt.upper() break else: @@ -279,7 +325,7 @@ def pred_vcf_records_path(record): len_pred_ = len_pred if len_pred == 3: len_pred = max(len(alt) - len(ref), len_pred) - if (sum(I[1:, i_, 1]) == 0): + if sum(I[1:, i_, 1]) == 0: # thread_logger.info(["PPP-2", path, pred]) return vcf_record if len_pred == len(alt) - len(ref) and pos_ == pos: @@ -294,7 +340,11 @@ def pred_vcf_records_path(record): break if (len(alt_) - len(ref_)) >= len_pred: break - if len_pred_ == 3 and (len(alt_) - len(ref_)) < len_pred and pos_ == pos: + if ( + len_pred_ == 3 + and (len(alt_) - len(ref_)) < len_pred + and pos_ == pos + ): pos_, ref_, alt_ = pos, ref.upper(), alt.upper() elif type_pred == "DEL": pos_ = col_2_pos[center_] - 1 @@ -313,8 +363,10 @@ def pred_vcf_records_path(record): if (len(ref_) - len(alt_)) < len_pred: pos_, ref_, alt_ = pos, ref.upper(), alt.upper() chrom_ = chroms[int(chrom)] - if fasta_file.fetch(chrom_, pos_ - 1, pos_ + - len(ref_) - 1).upper() != ref_.upper(): + if ( + fasta_file.fetch(chrom_, pos_ - 1, pos_ + len(ref_) - 1).upper() + != ref_.upper() + ): # print "AAAA" return vcf_record if ref_ == alt_: @@ -330,8 +382,7 @@ def pred_vcf_records_path(record): prob = pred[3][1] else: prob = pred[3][0] * (1 - pred[4][0]) - vcf_record = [path, [chrom_, pos_, - ref_, alt_, prob, [path, pred]]] + vcf_record = [path, [chrom_, pos_, ref_, alt_, prob, [path, pred]]] else: return vcf_record return vcf_record @@ -343,12 +394,10 @@ def pred_vcf_records_path(record): def pred_vcf_records(ref_file, final_preds, chroms, num_threads): logger = logging.getLogger(pred_vcf_records.__name__) - logger.info( - "Prepare VCF records for predicted somatic variants in this batch.") + logger.info("Prepare VCF records for predicted somatic variants in this batch.") map_args = [] for path in final_preds.keys(): - map_args.append([path, final_preds[path], - chroms, ref_file]) + map_args.append([path, final_preds[path], chroms, ref_file]) if num_threads == 1: all_vcf_records = [] @@ -357,8 +406,7 @@ def pred_vcf_records(ref_file, final_preds, chroms, num_threads): else: pool = multiprocessing.Pool(num_threads) try: - all_vcf_records = pool.map_async( - pred_vcf_records_path, map_args).get() + all_vcf_records = pool.map_async(pred_vcf_records_path, map_args).get() pool.close() except Exception as inst: logger.error(inst) @@ -377,13 +425,11 @@ def pred_vcf_records(ref_file, final_preds, chroms, num_threads): def pred_vcf_records_none(none_preds, chroms): logger = logging.getLogger(pred_vcf_records_none.__name__) - logger.info( - "Prepare VCF records for predicted non-somatic variants in this batch.") + logger.info("Prepare VCF records for predicted non-somatic variants in this batch.") all_vcf_records = {} for path in none_preds.keys(): pred = none_preds[path] - chrom, pos, ref, alt, _, _, _, _, _ = path.split( - ".") + chrom, pos, ref, alt, _, _, _, _, _ = path.split(".") if len(ref) == len(alt): prob = pred[3][3] * (pred[4][1]) elif len(ref) < len(alt): @@ -392,7 +438,13 @@ 
def pred_vcf_records_none(none_preds, chroms): prob = pred[3][0] * (1 - pred[4][0]) if prob > 0.005: all_vcf_records[path] = [ - chroms[int(chrom)], pos, ref, alt, prob, [path, pred]] + chroms[int(chrom)], + pos, + ref, + alt, + prob, + [path, pred], + ] return all_vcf_records.items() @@ -420,11 +472,23 @@ def write_vcf(vcf_records, output_vcf, chroms_order, pass_threshold, lowqual_thr filter_ = "PASS" elif prob >= lowqual_threshold: filter_ = "LowQual" - line = "\t".join([chrom_, str(pos_), ".", ref_, alt_, - "{:.4f}".format(np.round(prob2phred(prob), 4)), - filter_, "SCORE={:.4f}".format( - np.round(prob, 4)), - "GT", "0/1"]) + "\n" + line = ( + "\t".join( + [ + chrom_, + str(pos_), + ".", + ref_, + alt_, + "{:.4f}".format(np.round(prob2phred(prob), 4)), + filter_, + "SCORE={:.4f}".format(np.round(prob, 4)), + "GT", + "0/1", + ] + ) + + "\n" + ) curr_pos = "-".join([chrom_, str(pos_)]) emit = False if old_pos != curr_pos: @@ -468,37 +532,61 @@ def write_merged_vcf(output_vcfs, output_vcf, chroms_order): def single_thread_call(record): thread_logger = logging.getLogger( - "{} ({})".format(single_thread_call.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + single_thread_call.__name__, multiprocessing.current_process().name + ) + ) try: torch.set_num_threads(1) - net, candidate_files, max_load_candidates, data_transform, \ - coverage_thr, max_cov, normalize_channels, zero_ann_cols, batch_size, \ - out_dir, model_tag, ref_file, chroms, tmp_preds_dir, chroms_order, \ - pass_threshold, lowqual_threshold, matrix_dtype, i = record - - call_set = NeuSomaticDataset(roots=candidate_files, - max_load_candidates=max_load_candidates, - transform=data_transform, is_test=True, - num_threads=1, - coverage_thr=coverage_thr, - max_cov=max_cov, - normalize_channels=normalize_channels, - zero_ann_cols=zero_ann_cols, - matrix_dtype=matrix_dtype) - call_loader = torch.utils.data.DataLoader(call_set, - batch_size=batch_size, - shuffle=True, # pin_memory=True, - num_workers=0) + ( + net, + candidate_files, + max_load_candidates, + data_transform, + coverage_thr, + max_cov, + normalize_channels, + zero_ann_cols, + batch_size, + out_dir, + model_tag, + ref_file, + chroms, + tmp_preds_dir, + chroms_order, + pass_threshold, + lowqual_threshold, + matrix_dtype, + i, + ) = record + + call_set = NeuSomaticDataset( + roots=candidate_files, + max_load_candidates=max_load_candidates, + transform=data_transform, + is_test=True, + num_threads=1, + coverage_thr=coverage_thr, + max_cov=max_cov, + normalize_channels=normalize_channels, + zero_ann_cols=zero_ann_cols, + matrix_dtype=matrix_dtype, + ) + call_loader = torch.utils.data.DataLoader( + call_set, + batch_size=batch_size, + shuffle=True, # pin_memory=True, + num_workers=0, + ) logger.info("N_dataset: {}".format(len(call_set))) if len(call_set) == 0: - logger.warning( - "Skip {} with 0 candidates".format(candidate_file)) + logger.warning("Skip {} with 0 candidates".format(candidate_file)) return [], [] final_preds_, none_preds_ = call_variants( - net, call_loader, out_dir, model_tag, i, matrix_dtype, use_cuda) - all_vcf_records = pred_vcf_records( - ref_file, final_preds_, chroms, 1) + net, call_loader, out_dir, model_tag, i, matrix_dtype, use_cuda + ) + all_vcf_records = pred_vcf_records(ref_file, final_preds_, chroms, 1) all_vcf_records_none = pred_vcf_records_none(none_preds_, chroms) all_vcf_records = dict(all_vcf_records) @@ -508,13 +596,19 @@ def single_thread_call(record): vcf_records_none = get_vcf_records(all_vcf_records_none) output_vcf 
= "{}/pred_{}.vcf".format(tmp_preds_dir, i) - write_vcf(var_vcf_records, output_vcf, chroms_order, - pass_threshold, lowqual_threshold) + write_vcf( + var_vcf_records, output_vcf, chroms_order, pass_threshold, lowqual_threshold + ) logger.info("Prepare Non-Somatics VCF") output_vcf_none = "{}/none_{}.vcf".format(tmp_preds_dir, i) - write_vcf(vcf_records_none, output_vcf_none, - chroms_order, pass_threshold, lowqual_threshold) + write_vcf( + vcf_records_none, + output_vcf_none, + chroms_order, + pass_threshold, + lowqual_threshold, + ) return output_vcf, output_vcf_none except Exception as ex: @@ -523,11 +617,20 @@ def single_thread_call(record): return None -def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, - batch_size, max_load_candidates, pass_threshold, lowqual_threshold, - force_zero_ann_cols, - max_cov, - use_cuda): +def call_neusomatic( + candidates_tsv, + ref_file, + out_dir, + checkpoint, + num_threads, + batch_size, + max_load_candidates, + pass_threshold, + lowqual_threshold, + force_zero_ann_cols, + max_cov, + use_cuda, +): logger = logging.getLogger(call_neusomatic.__name__) logger.info("-----------------Call Somatic Mutations--------------------") @@ -544,8 +647,7 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, data_transform = matrix_transform((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) logger.info("Load pretrained model from checkpoint {}".format(checkpoint)) - pretrained_dict = torch.load( - checkpoint, map_location=lambda storage, loc: storage) + pretrained_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage) pretrained_state_dict = pretrained_dict["state_dict"] model_tag = pretrained_dict["tag"] logger.info("tag: {}".format(model_tag)) @@ -573,12 +675,14 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, if force_zero_ann_cols: logger.info( - "Override zero_ann_cols from force_zero_ann_cols: {}".format(force_zero_ann_cols)) + "Override zero_ann_cols from force_zero_ann_cols: {}".format( + force_zero_ann_cols + ) + ) zero_ann_cols = force_zero_ann_cols if max_cov is not None: - logger.info( - "Set max_cov: {}".format(max_cov)) + logger.info("Set max_cov: {}".format(max_cov)) logger.info("coverage_thr: {}".format(coverage_thr)) logger.info("normalize_channels: {}".format(normalize_channels)) @@ -610,10 +714,12 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, break else: raise Exception( - "Wrong number of fields in {}: {}".format(tsv, len(x))) + "Wrong number of fields in {}: {}".format(tsv, len(x)) + ) - num_channels = expected_ens_fields + \ - NUM_ST_FEATURES if ensemble else NUM_ST_FEATURES + num_channels = ( + expected_ens_fields + NUM_ST_FEATURES if ensemble else NUM_ST_FEATURES + ) else: num_channels = 0 for tsv in candidates_tsv: @@ -640,16 +746,28 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, # 1. filter out unnecessary keys # pretrained_state_dict = { # k: v for k, v in pretrained_state_dict.items() if k in model_dict} - if "module." in list(pretrained_state_dict.keys())[0] and "module." not in list(model_dict.keys())[0]: - pretrained_state_dict = {k.split("module.")[1]: v for k, v in pretrained_state_dict.items( - ) if k.split("module.")[1] in model_dict} - elif "module." not in list(pretrained_state_dict.keys())[0] and "module." in list(model_dict.keys())[0]: + if ( + "module." in list(pretrained_state_dict.keys())[0] + and "module." 
not in list(model_dict.keys())[0] + ): pretrained_state_dict = { - ("module." + k): v for k, v in pretrained_state_dict.items() - if ("module." + k) in model_dict} + k.split("module.")[1]: v + for k, v in pretrained_state_dict.items() + if k.split("module.")[1] in model_dict + } + elif ( + "module." not in list(pretrained_state_dict.keys())[0] + and "module." in list(model_dict.keys())[0] + ): + pretrained_state_dict = { + ("module." + k): v + for k, v in pretrained_state_dict.items() + if ("module." + k) in model_dict + } else: - pretrained_state_dict = {k: v for k, - v in pretrained_state_dict.items() if k in model_dict} + pretrained_state_dict = { + k: v for k, v in pretrained_state_dict.items() if k in model_dict + } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_state_dict) @@ -662,7 +780,8 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, new_split_tsvs_dir = os.path.join(out_dir, "split_tsvs") if os.path.exists(new_split_tsvs_dir): logger.warning( - "Remove split candidates directory: {}".format(new_split_tsvs_dir)) + "Remove split candidates directory: {}".format(new_split_tsvs_dir) + ) shutil.rmtree(new_split_tsvs_dir) os.mkdir(new_split_tsvs_dir) Ls = [] @@ -673,30 +792,30 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, total_L += len(pickle.load(open(candidate_file + ".idx", "rb"))) logger.info("Total number of candidates: {}".format(total_L)) if not use_cuda: - max_load_candidates = min( - max_load_candidates, 3 * total_L // num_threads) + max_load_candidates = min(max_load_candidates, 3 * total_L // num_threads) for candidate_file in candidates_tsv: idx = pickle.load(open(candidate_file + ".idx", "rb")) if len(idx) > max_load_candidates / 2: - logger.info("Splitting {} of lenght {}".format( - candidate_file, len(idx))) + logger.info("Splitting {} of lenght {}".format(candidate_file, len(idx))) new_split_tsvs_dir_i = os.path.join( - new_split_tsvs_dir, "split_{}".format(split_i)) + new_split_tsvs_dir, "split_{}".format(split_i) + ) if os.path.exists(new_split_tsvs_dir_i): - logger.warning("Remove split candidates directory: {}".format( - new_split_tsvs_dir_i)) + logger.warning( + "Remove split candidates directory: {}".format(new_split_tsvs_dir_i) + ) shutil.rmtree(new_split_tsvs_dir_i) os.mkdir(new_split_tsvs_dir_i) - candidate_file_splits = merge_tsvs(input_tsvs=[candidate_file], - out=new_split_tsvs_dir_i, - candidates_per_tsv=max( - 1, max_load_candidates / 2), - max_num_tsvs=100000, - overwrite_merged_tsvs=True, - keep_none_types=True) + candidate_file_splits = merge_tsvs( + input_tsvs=[candidate_file], + out=new_split_tsvs_dir_i, + candidates_per_tsv=max(1, max_load_candidates / 2), + max_num_tsvs=100000, + overwrite_merged_tsvs=True, + keep_none_types=True, + ) for candidate_file_split in candidate_file_splits: - idx_split = pickle.load( - open(candidate_file_split + ".idx", "rb")) + idx_split = pickle.load(open(candidate_file_split + ".idx", "rb")) candidates_tsv_.append(candidate_file_split) Ls.append(len(idx_split) - 1) split_i += 1 @@ -710,41 +829,48 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, all_vcf_records_none = [] if use_cuda: run_i = -1 - for i, (candidate_file, L) in enumerate(sorted(zip(candidates_tsv_, Ls), key=lambda x: x[1])): + for i, (candidate_file, L) in enumerate( + sorted(zip(candidates_tsv_, Ls), key=lambda x: x[1]) + ): current_L += L candidate_files.append(candidate_file) if current_L > max_load_candidates / 10 
or i == len(candidates_tsv_) - 1: - logger.info( - "Run for candidate files: {}".format(candidate_files)) - call_set = NeuSomaticDataset(roots=candidate_files, - max_load_candidates=max_load_candidates, - transform=data_transform, is_test=True, - num_threads=num_threads, - coverage_thr=coverage_thr, - max_cov=max_cov, - normalize_channels=normalize_channels, - zero_ann_cols=zero_ann_cols, - matrix_dtype=matrix_dtype) - call_loader = torch.utils.data.DataLoader(call_set, - batch_size=batch_size, - shuffle=True, pin_memory=True, - num_workers=num_threads) + logger.info("Run for candidate files: {}".format(candidate_files)) + call_set = NeuSomaticDataset( + roots=candidate_files, + max_load_candidates=max_load_candidates, + transform=data_transform, + is_test=True, + num_threads=num_threads, + coverage_thr=coverage_thr, + max_cov=max_cov, + normalize_channels=normalize_channels, + zero_ann_cols=zero_ann_cols, + matrix_dtype=matrix_dtype, + ) + call_loader = torch.utils.data.DataLoader( + call_set, + batch_size=batch_size, + shuffle=True, + pin_memory=True, + num_workers=num_threads, + ) current_L = 0 candidate_files = [] run_i += 1 logger.info("N_dataset: {}".format(len(call_set))) if len(call_set) == 0: - logger.warning( - "Skip {} with 0 candidates".format(candidate_file)) + logger.warning("Skip {} with 0 candidates".format(candidate_file)) continue final_preds_, none_preds_ = call_variants( - net, call_loader, out_dir, model_tag, run_i, matrix_dtype, use_cuda) - all_vcf_records.extend(pred_vcf_records( - ref_file, final_preds_, chroms, num_threads)) - all_vcf_records_none.extend( - pred_vcf_records_none(none_preds_, chroms)) + net, call_loader, out_dir, model_tag, run_i, matrix_dtype, use_cuda + ) + all_vcf_records.extend( + pred_vcf_records(ref_file, final_preds_, chroms, num_threads) + ) + all_vcf_records_none.extend(pred_vcf_records_none(none_preds_, chroms)) all_vcf_records = dict(all_vcf_records) all_vcf_records_none = dict(all_vcf_records_none) @@ -752,36 +878,60 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, logger.info("Prepare Output VCF") output_vcf = "{}/pred.vcf".format(out_dir) var_vcf_records = get_vcf_records(all_vcf_records) - write_vcf(var_vcf_records, output_vcf, chroms_order, - pass_threshold, lowqual_threshold) + write_vcf( + var_vcf_records, output_vcf, chroms_order, pass_threshold, lowqual_threshold + ) logger.info("Prepare Non-Somatics VCF") output_vcf_none = "{}/none.vcf".format(out_dir) vcf_records_none = get_vcf_records(all_vcf_records_none) - write_vcf(vcf_records_none, output_vcf_none, - chroms_order, pass_threshold, lowqual_threshold) + write_vcf( + vcf_records_none, + output_vcf_none, + chroms_order, + pass_threshold, + lowqual_threshold, + ) else: tmp_preds_dir = os.path.join(out_dir, "tmp_preds") if os.path.exists(tmp_preds_dir): - logger.warning( - "Remove tmp_preds directory: {}".format(tmp_preds_dir)) + logger.warning("Remove tmp_preds directory: {}".format(tmp_preds_dir)) shutil.rmtree(tmp_preds_dir) os.mkdir(tmp_preds_dir) map_args = [] j = 0 - for i, (candidate_file, L) in enumerate(sorted(zip(candidates_tsv_, Ls), key=lambda x: x[1])): + for i, (candidate_file, L) in enumerate( + sorted(zip(candidates_tsv_, Ls), key=lambda x: x[1]) + ): current_L += L candidate_files.append(candidate_file) if current_L > max_load_candidates / 10 or i == len(candidates_tsv_) - 1: - logger.info( - "Run for candidate files: {}".format(candidate_files)) - - map_args.append([net, candidate_files, max_load_candidates, data_transform, - 
coverage_thr, max_cov, normalize_channels, zero_ann_cols, batch_size, - out_dir, - model_tag, ref_file, chroms, tmp_preds_dir, chroms_order, - pass_threshold, lowqual_threshold, matrix_dtype, j]) + logger.info("Run for candidate files: {}".format(candidate_files)) + + map_args.append( + [ + net, + candidate_files, + max_load_candidates, + data_transform, + coverage_thr, + max_cov, + normalize_channels, + zero_ann_cols, + batch_size, + out_dir, + model_tag, + ref_file, + chroms, + tmp_preds_dir, + chroms_order, + pass_threshold, + lowqual_threshold, + matrix_dtype, + j, + ] + ) j += 1 current_L = 0 candidate_files = [] @@ -812,56 +962,72 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, write_merged_vcf(output_vcfs_none, output_vcf_none, chroms_order) if os.path.exists(tmp_preds_dir): - logger.warning( - "Remove tmp_preds directory: {}".format(tmp_preds_dir)) + logger.warning("Remove tmp_preds directory: {}".format(tmp_preds_dir)) shutil.rmtree(tmp_preds_dir) if os.path.exists(new_split_tsvs_dir): logger.warning( - "Remove split candidates directory: {}".format(new_split_tsvs_dir)) + "Remove split candidates directory: {}".format(new_split_tsvs_dir) + ) shutil.rmtree(new_split_tsvs_dir) logger.info("Calling is Done.") return output_vcf -if __name__ == '__main__': - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +if __name__ == "__main__": + + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) - parser = argparse.ArgumentParser( - description='simple call variants from bam') - parser.add_argument('--candidates_tsv', nargs="*", - help=' test candidate tsv files', required=True) - parser.add_argument('--reference', type=str, - help='reference fasta filename', required=True) - parser.add_argument('--out', type=str, - help='output directory', required=True) - parser.add_argument('--checkpoint', type=str, - help='network model checkpoint path', required=True) - parser.add_argument('--ensemble', - help='Enable calling for ensemble mode', - action="store_true") - parser.add_argument('--num_threads', type=int, - help='number of threads', default=1) - parser.add_argument('--batch_size', type=int, - help='batch size', default=1000) - parser.add_argument('--max_load_candidates', type=int, - help='maximum candidates to load in memory', default=100000) - parser.add_argument('--pass_threshold', type=float, - help='SCORE for PASS (PASS for score => pass_threshold)', default=0.7) - parser.add_argument('--lowqual_threshold', type=float, - help='SCORE for LowQual (PASS for lowqual_threshold <= score < pass_threshold)', - default=0.4) - parser.add_argument('--force_zero_ann_cols', nargs="*", type=int, - help='force columns to be set to zero in the annotations. 
Higher priority than \ + parser = argparse.ArgumentParser(description="simple call variants from bam") + parser.add_argument( + "--candidates_tsv", nargs="*", help=" test candidate tsv files", required=True + ) + parser.add_argument( + "--reference", type=str, help="reference fasta filename", required=True + ) + parser.add_argument("--out", type=str, help="output directory", required=True) + parser.add_argument( + "--checkpoint", type=str, help="network model checkpoint path", required=True + ) + parser.add_argument( + "--ensemble", help="Enable calling for ensemble mode", action="store_true" + ) + parser.add_argument("--num_threads", type=int, help="number of threads", default=1) + parser.add_argument("--batch_size", type=int, help="batch size", default=1000) + parser.add_argument( + "--max_load_candidates", + type=int, + help="maximum candidates to load in memory", + default=100000, + ) + parser.add_argument( + "--pass_threshold", + type=float, + help="SCORE for PASS (PASS for score => pass_threshold)", + default=0.7, + ) + parser.add_argument( + "--lowqual_threshold", + type=float, + help="SCORE for LowQual (PASS for lowqual_threshold <= score < pass_threshold)", + default=0.4, + ) + parser.add_argument( + "--force_zero_ann_cols", + nargs="*", + type=int, + help="force columns to be set to zero in the annotations. Higher priority than \ --zero_ann_cols and pretrained setting.\ - idx starts from 5th column in candidate.tsv file', - default=[]) - parser.add_argument('--max_cov', type=int, - help='maximum coverage threshold.', default=None) + idx starts from 5th column in candidate.tsv file", + default=[], + ) + parser.add_argument( + "--max_cov", type=int, help="maximum coverage threshold.", default=None + ) args = parser.parse_args() logger.info(args) @@ -870,13 +1036,20 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, logger.info("use_cuda: {}".format(use_cuda)) try: - output_vcf = call_neusomatic(args.candidates_tsv, args.reference, args.out, - args.checkpoint, - args.num_threads, args.batch_size, args.max_load_candidates, - args.pass_threshold, args.lowqual_threshold, - args.force_zero_ann_cols, - args.max_cov, - use_cuda) + output_vcf = call_neusomatic( + args.candidates_tsv, + args.reference, + args.out, + args.checkpoint, + args.num_threads, + args.batch_size, + args.max_load_candidates, + args.pass_threshold, + args.lowqual_threshold, + args.force_zero_ann_cols, + args.max_cov, + use_cuda, + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") diff --git a/neusomatic/python/dataloader.py b/neusomatic/python/dataloader.py index 0fe5470..537433a 100755 --- a/neusomatic/python/dataloader.py +++ b/neusomatic/python/dataloader.py @@ -1,7 +1,7 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # dataloader.py # Data loader used by NeuSomatic network for datasets created by 'generate_dataset.py' -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import multiprocessing import pickle import zlib @@ -18,13 +18,12 @@ from utils import skip_empty from defaults import TYPE_CLASS_DICT, VARTYPE_CLASSES, MAT_DTYPES -FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = 
logging.getLogger(__name__) -class matrix_transform(): - +class matrix_transform: def __init__(self, mean, std): self.mean = mean self.std = std @@ -38,12 +37,17 @@ def __call__(self, matrix): def extract_zlib(zlib_compressed_im, matrix_dtype): if matrix_dtype == "uint8": - return np.fromstring(zlib.decompress(zlib_compressed_im), dtype="uint8").reshape((5, 32, 23)) + return np.fromstring( + zlib.decompress(zlib_compressed_im), dtype="uint8" + ).reshape((5, 32, 23)) elif matrix_dtype == "uint16": - return np.fromstring(zlib.decompress(zlib_compressed_im), dtype="uint16").reshape((5, 32, 23)) + return np.fromstring( + zlib.decompress(zlib_compressed_im), dtype="uint16" + ).reshape((5, 32, 23)) else: logger.info( - "Wrong matrix_dtype {}. Choices are {}".format(matrix_dtype, MAT_DTYPES)) + "Wrong matrix_dtype {}. Choices are {}".format(matrix_dtype, MAT_DTYPES) + ) raise Exception @@ -69,13 +73,16 @@ def candidate_loader_tsv(tsv, open_tsv, idx, i, matrix_dtype): def extract_info_tsv(record): i_b, tsv, idx, L, max_load_candidates, nclasses_t, nclasses_l, matrix_dtype = record thread_logger = logging.getLogger( - "{} ({})".format(extract_info_tsv.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + extract_info_tsv.__name__, multiprocessing.current_process().name + ) + ) try: n_none = 0 with open(tsv, "r") as i_f: for line in skip_empty(i_f): tag = line.strip().split()[2] - n_none += (1 if "NONE" in tag else 0) + n_none += 1 if "NONE" in tag else 0 n_var = L - n_none max_load_candidates_var = min(n_var, max_load_candidates) @@ -107,9 +114,9 @@ def extract_info_tsv(record): count_class_t[TYPE_CLASS_DICT[vartype]] += 1 count_class_l[min(int(length), 3)] += 1 if ((cnt_var < max_load_candidates_var) and ("NONE" not in tag)) or ( - (cnt_none < max_load_candidates_none) and ("NONE" in tag)): - im = extract_zlib(base64.b64decode( - fields[3]), matrix_dtype) + (cnt_none < max_load_candidates_none) and ("NONE" in tag) + ): + im = extract_zlib(base64.b64decode(fields[3]), matrix_dtype) label = TYPE_CLASS_DICT[tag.split(".")[4]] if len(fields) > 4: anns = list(map(float, fields[4:])) @@ -123,8 +130,7 @@ def extract_info_tsv(record): else: data.append([]) assert i + 1 == L - thread_logger.info("Loaded {} candidates for {}".format( - len(matrices), tsv)) + thread_logger.info("Loaded {} candidates for {}".format(len(matrices), tsv)) return matrices, data, none_ids, var_ids, count_class_t, count_class_l except Exception as ex: thread_logger.error(traceback.format_exc()) @@ -133,16 +139,25 @@ def extract_info_tsv(record): class NeuSomaticDataset(torch.utils.data.Dataset): - - def __init__(self, roots, max_load_candidates, transform=None, - loader=candidate_loader_tsv, is_test=False, - num_threads=1, disable_ensemble=False, data_augmentation=False, - nclasses_t=4, nclasses_l=4, coverage_thr=100, - max_cov=None, - normalize_channels=False, - zero_ann_cols=[], - matrix_dtype="uint8", - max_opended_tsv=-1): + def __init__( + self, + roots, + max_load_candidates, + transform=None, + loader=candidate_loader_tsv, + is_test=False, + num_threads=1, + disable_ensemble=False, + data_augmentation=False, + nclasses_t=4, + nclasses_l=4, + coverage_thr=100, + max_cov=None, + normalize_channels=False, + zero_ann_cols=[], + matrix_dtype="uint8", + max_opended_tsv=-1, + ): soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) logger.info(resource.getrlimit(resource.RLIMIT_NOFILE)) @@ -186,8 +201,11 @@ def __init__(self, roots, max_load_candidates, transform=None, new_batch = [] for i_b, L in 
enumerate(self.Ls): new_batch.append([i_b, L]) - if sum(map(lambda x: x[1], new_batch)) > 200000 or i_b == len(self.Ls) - 1 \ - or len(new_batch) > num_threads: + if ( + sum(map(lambda x: x[1], new_batch)) > 200000 + or i_b == len(self.Ls) - 1 + or len(new_batch) > num_threads + ): batches.append(new_batch) new_batch = [] @@ -197,10 +215,21 @@ def __init__(self, roots, max_load_candidates, transform=None, Ls_ = [] for i_b, _ in batch: tsv = self.tsvs[i_b] - max_load_ = self.Ls[i_b] * max_load_candidates // \ - total_L if total_L > 0 else 0 - map_args.append([i_b, tsv, self.idxs[i_b], self.Ls[i_b], - max_load_, nclasses_t, nclasses_l, self.matrix_dtype]) + max_load_ = ( + self.Ls[i_b] * max_load_candidates // total_L if total_L > 0 else 0 + ) + map_args.append( + [ + i_b, + tsv, + self.idxs[i_b], + self.Ls[i_b], + max_load_, + nclasses_t, + nclasses_l, + self.matrix_dtype, + ] + ) Ls_.append(self.Ls[i_b]) logger.info("Len's of tsv files in this batch: {}".format(Ls_)) if len(map_args) == 1: @@ -213,8 +242,7 @@ def __init__(self, roots, max_load_candidates, transform=None, else: pool = multiprocessing.Pool(num_threads) try: - records_ = pool.map_async( - extract_info_tsv, map_args).get() + records_ = pool.map_async(extract_info_tsv, map_args).get() pool.close() except Exception as inst: pool.close() @@ -230,7 +258,14 @@ def __init__(self, roots, max_load_candidates, transform=None, j = 0 for records_ in records_done: - for matrices, data, none_ids, var_ids, count_class_t, count_class_l in records_: + for ( + matrices, + data, + none_ids, + var_ids, + count_class_t, + count_class_l, + ) in records_: self.matrices += matrices self.data += data self.none_ids += list(map(lambda x: x + j, none_ids)) @@ -266,18 +301,24 @@ def __getitem__(self, index): if len(self.data[index]) == 0: i_b, i = self.matrices[index] if multiprocessing.current_process()._identity: - path, matrix, anns, label = candidate_loader_tsv(self.tsvs[i_b], - self.open_tsvs[ - int(multiprocessing.current_process()._identity[0] - ) % self.num_threads][i_b], - self.idxs[i_b], i, self.matrix_dtype) + path, matrix, anns, label = candidate_loader_tsv( + self.tsvs[i_b], + self.open_tsvs[ + int(multiprocessing.current_process()._identity[0]) + % self.num_threads + ][i_b], + self.idxs[i_b], + i, + self.matrix_dtype, + ) else: - path, matrix, anns, label = candidate_loader_tsv(self.tsvs[i_b], - self.open_tsvs[ - 0][i_b], - self.idxs[ - i_b], i, - self.matrix_dtype) + path, matrix, anns, label = candidate_loader_tsv( + self.tsvs[i_b], + self.open_tsvs[0][i_b], + self.idxs[i_b], + i, + self.matrix_dtype, + ) else: path, matrix, anns, label = self.data[index] @@ -291,8 +332,7 @@ def __getitem__(self, index): anns = anns.tolist() tag = path.split("/")[-1] - _, _, _, _, vartype, center, length, tumor_cov, normal_cov = tag.split( - ".") + _, _, _, _, vartype, center, length, tumor_cov, normal_cov = tag.split(".") tumor_cov = int(tumor_cov) normal_cov = int(normal_cov) if self.max_cov is not None: @@ -307,13 +347,17 @@ def __getitem__(self, index): h, w, _ = matrix.shape far_center = False - if (((center - 2) * 2 / 3) >= (center - 2)) or (((w - center - 2) * 2 / 3) - >= (w - center - 2)): + if (((center - 2) * 2 / 3) >= (center - 2)) or ( + ((w - center - 2) * 2 / 3) >= (w - center - 2) + ): far_center = True # Data augmentaion by shifting left or right - if self.data_augmentation and (not self.is_test) and (random.rand() < self.da_shift_p - and (not far_center)): + if ( + self.data_augmentation + and (not self.is_test) + and (random.rand() < 
self.da_shift_p and (not far_center)) + ): h, w, c = matrix.shape r = random.rand() if r < 0.6: @@ -321,22 +365,25 @@ def __getitem__(self, index): else: x_left = 0 if r > 0.4: - x_right = random.randint( - (w - center - 2) * 2 // 3, w - center - 2) + x_right = random.randint((w - center - 2) * 2 // 3, w - center - 2) else: x_right = 0 if x_left > 0: - matrix[:, 0:w - x_left, :] = matrix[:, x_left:, :] + matrix[:, 0 : w - x_left, :] = matrix[:, x_left:, :] matrix[:, -x_left:, :] = -1 center -= x_left if x_right > 0: - matrix[:, x_right:, :] = matrix[:, 0:w - x_right, :] + matrix[:, x_right:, :] = matrix[:, 0 : w - x_right, :] matrix[:, 0:x_right, :] = -1 center += x_right # Data augmentaion by switch bases - if self.data_augmentation and (not self.is_test) and random.rand() < self.da_base_p \ - and (vartype != "NONE"): + if ( + self.data_augmentation + and (not self.is_test) + and random.rand() < self.da_base_p + and (vartype != "NONE") + ): [i, j] = random.permutation(range(1, 5))[0:2] a = matrix[i, :, :] matrix[i, :, :] = matrix[j, :, :] @@ -346,8 +393,12 @@ def __getitem__(self, index): try: nt_matrix = matrix.copy() nt_center = int(center) - if self.data_augmentation and (not self.is_test) and random.rand() < self.da_rev_p \ - and (vartype not in ["DEL"]): + if ( + self.data_augmentation + and (not self.is_test) + and random.rand() < self.da_rev_p + and (vartype not in ["DEL"]) + ): h, w, c = matrix.shape refbase = np.nonzero(matrix[:, center, 0])[0] if len(refbase) > 1: @@ -387,25 +438,31 @@ def __getitem__(self, index): h += 1 else: mx_1 = np.max(matrix[:, i, 1]) - if matrix[0, i, 1] < mx_1 and matrix[hp_base, i, 1] < mx_1: + if ( + matrix[0, i, 1] < mx_1 + and matrix[hp_base, i, 1] < mx_1 + ): e = center + 1 break e += 1 if h == 1: e = center + 1 if (e - b) > 1: - matrix[:, b:e, :] = matrix[:, e - 1:b - 1:-1, :].copy() + matrix[:, b:e, :] = matrix[:, e - 1 : b - 1 : -1, :].copy() center = e - 1 - (center - b) matrix = matrix[:, ::-1, :].copy() center = w - center - 1 except: - logger.warning( - "Failed random flip center={} tag={}".format(center, tag)) + logger.warning("Failed random flip center={} tag={}".format(center, tag)) matrix = nt_matrix center = nt_center # Data augmentaion by changing coverage - if self.data_augmentation and (not self.is_test) and random.rand() < self.da_cov_p: + if ( + self.data_augmentation + and (not self.is_test) + and random.rand() < self.da_cov_p + ): r_cov = (1 - self.da_cov_e) + (random.rand() * 2 * self.da_cov_e) tumor_cov *= r_cov normal_cov *= r_cov @@ -416,26 +473,30 @@ def __getitem__(self, index): max_norm = 65535.0 else: logger.info( - "Wrong matrix_dtype {}. Choices are {}".format(self.matrix_dtype, MAT_DTYPES)) + "Wrong matrix_dtype {}. 
Choices are {}".format( + self.matrix_dtype, MAT_DTYPES + ) + ) # add COV channel matrix_ = np.zeros((matrix.shape[0], matrix.shape[1], 26 + len(anns))) matrix_[:, :, 0:23] = matrix if self.normalize_channels: - matrix_[:, :, 3:23:2] *= (matrix_[:, :, 1:2] / max_norm) - matrix_[:, :, 4:23:2] *= (matrix_[:, :, 2:3] / max_norm) + matrix_[:, :, 3:23:2] *= matrix_[:, :, 1:2] / max_norm + matrix_[:, :, 4:23:2] *= matrix_[:, :, 2:3] / max_norm matrix = matrix_ matrix[:, center, 23] = np.max(matrix[:, :, 0]) - matrix[:, :, 24] = (min(tumor_cov, self.coverage_thr) / - float(self.coverage_thr)) * max_norm + matrix[:, :, 24] = ( + min(tumor_cov, self.coverage_thr) / float(self.coverage_thr) + ) * max_norm matrix[:, :, 25] = ( - min(normal_cov, self.coverage_thr) / float(self.coverage_thr)) * max_norm + min(normal_cov, self.coverage_thr) / float(self.coverage_thr) + ) * max_norm for i, a in enumerate(anns): matrix[:, :, 26 + i] = a * max_norm if self.is_test: - orig_matrix_ = np.zeros( - (orig_matrix.shape[0], orig_matrix.shape[1], 3)) + orig_matrix_ = np.zeros((orig_matrix.shape[0], orig_matrix.shape[1], 3)) orig_matrix_[:, :, 0:2] = orig_matrix[:, :, 0:2] orig_matrix_[:, orig_center, 2] = np.max(orig_matrix[:, :, 0]) orig_matrix = orig_matrix_ @@ -453,7 +514,10 @@ def __getitem__(self, index): var_pos = torch.Tensor(var_pos) varlen_label = min(length, 3) - return (matrix, label, var_pos, varlen_label, non_transformed_matrix), [path, label] + return ( + (matrix, label, var_pos, varlen_label, non_transformed_matrix), + [path, label], + ) def __len__(self): return len(self.matrices) diff --git a/neusomatic/python/defaults.py b/neusomatic/python/defaults.py index a959cae..a383b32 100644 --- a/neusomatic/python/defaults.py +++ b/neusomatic/python/defaults.py @@ -2,5 +2,5 @@ NUM_ST_FEATURES = 26 VCF_HEADER = "##fileformat=VCFv4.2" TYPE_CLASS_DICT = {"DEL": 0, "INS": 1, "NONE": 2, "SNP": 3} -VARTYPE_CLASSES = ['DEL', 'INS', 'NONE', 'SNP'] +VARTYPE_CLASSES = ["DEL", "INS", "NONE", "SNP"] MAT_DTYPES = ["uint8", "uint16"] diff --git a/neusomatic/python/extend_features.py b/neusomatic/python/extend_features.py index 5d4b482..d51bc87 100755 --- a/neusomatic/python/extend_features.py +++ b/neusomatic/python/extend_features.py @@ -1,8 +1,8 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # extend_features.py # add extra features for standalone mode -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import argparse import traceback import logging @@ -20,9 +20,21 @@ def extract_features(candidate_record): - reference, tumor_bam, normal_bam, min_mapq, min_bq, dbsnp, no_seq_complexity, batch = candidate_record + ( + reference, + tumor_bam, + normal_bam, + min_mapq, + min_bq, + dbsnp, + no_seq_complexity, + batch, + ) = candidate_record thread_logger = logging.getLogger( - "{} ({})".format(extract_features.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + extract_features.__name__, multiprocessing.current_process().name + ) + ) try: tbam = pysam.AlignmentFile(tumor_bam) nbam = pysam.AlignmentFile(normal_bam) @@ -32,45 +44,69 @@ def extract_features(candidate_record): ext_features = [] for nei_cluster in batch: - n_cluster_reads = sequencing_features.ClusterReads( - nbam, nei_cluster) - t_cluster_reads = sequencing_features.ClusterReads( - tbam, nei_cluster) - for 
var_i, [chrom, pos, ref, alt, if_cosmic, num_cosmic_cases] in enumerate(nei_cluster): + n_cluster_reads = sequencing_features.ClusterReads(nbam, nei_cluster) + t_cluster_reads = sequencing_features.ClusterReads(tbam, nei_cluster) + for var_i, [chrom, pos, ref, alt, if_cosmic, num_cosmic_cases] in enumerate( + nei_cluster + ): var_id = "-".join([chrom, str(pos), ref, alt]) pos = int(pos) my_coordinate = [chrom, pos] nBamFeatures = n_cluster_reads.get_alignment_features( - var_i, ref, alt, min_mapq, min_bq) + var_i, ref, alt, min_mapq, min_bq + ) tBamFeatures = t_cluster_reads.get_alignment_features( - var_i, ref, alt, min_mapq, min_bq) + var_i, ref, alt, min_mapq, min_bq + ) - sor = sequencing_features.somaticOddRatio(nBamFeatures.nref, nBamFeatures.nalt, tBamFeatures.nref, - tBamFeatures.nalt) + sor = sequencing_features.somaticOddRatio( + nBamFeatures.nref, + nBamFeatures.nalt, + tBamFeatures.nref, + tBamFeatures.nalt, + ) try: - score_varscan2 = genome.p2phred(sequencing_features.fisher_exact_test( - ((tBamFeatures.nalt, nBamFeatures.nalt), - (tBamFeatures.nref, nBamFeatures.nref)), - alternative='greater')) + score_varscan2 = genome.p2phred( + sequencing_features.fisher_exact_test( + ( + (tBamFeatures.nalt, nBamFeatures.nalt), + (tBamFeatures.nref, nBamFeatures.nref), + ), + alternative="greater", + ) + ) except ValueError: - score_varscan2 = float('nan') + score_varscan2 = float("nan") - homopolymer_length, site_homopolymer_length = sequencing_features.from_genome_reference( - ref_fa, my_coordinate, ref, alt) + ( + homopolymer_length, + site_homopolymer_length, + ) = sequencing_features.from_genome_reference( + ref_fa, my_coordinate, ref, alt + ) indel_length = len(alt) - len(ref) if not no_seq_complexity: - seq_span_80bp = ref_fa.fetch(my_coordinate[0], max( - 0, my_coordinate[1] - 41), my_coordinate[1] + 40) - seq_left_80bp = ref_fa.fetch(my_coordinate[0], max( - 0, my_coordinate[1] - 81), my_coordinate[1]) - seq_right_80bp = ref_fa.fetch(my_coordinate[0], my_coordinate[ - 1], my_coordinate[1] + 81) + seq_span_80bp = ref_fa.fetch( + my_coordinate[0], + max(0, my_coordinate[1] - 41), + my_coordinate[1] + 40, + ) + seq_left_80bp = ref_fa.fetch( + my_coordinate[0], + max(0, my_coordinate[1] - 81), + my_coordinate[1], + ) + seq_right_80bp = ref_fa.fetch( + my_coordinate[0], my_coordinate[1], my_coordinate[1] + 81 + ) LC_spanning = sequencing_features.subLC(seq_span_80bp, 20) - LC_adjacent = min(sequencing_features.subLC( - seq_left_80bp, 20), sequencing_features.subLC(seq_right_80bp, 20)) + LC_adjacent = min( + sequencing_features.subLC(seq_left_80bp, 20), + sequencing_features.subLC(seq_right_80bp, 20), + ) LC_spanning_phred = genome.p2phred(1 - LC_spanning, 40) LC_adjacent_phred = genome.p2phred(1 - LC_adjacent, 40) @@ -80,12 +116,12 @@ def extract_features(candidate_record): region = "{}:{}-{}".format(chrom, pos, pos + 1) dbsnp_vars = {} for x in dbsnp_tb.fetch(region=region): - chrom_, pos_, _, ref_, alts_, _, _, info_ = x.strip().split("\t")[ - 0:8] + chrom_, pos_, _, ref_, alts_, _, _, info_ = x.strip().split( + "\t" + )[0:8] for alt_ in alts_.split(","): dbsnp_var_id = "-".join([chrom_, pos_, ref_, alt_]) - dbsnp_vars[ - dbsnp_var_id] = 1 if "COMMON=1" in info_ else 0 + dbsnp_vars[dbsnp_var_id] = 1 if "COMMON=1" in info_ else 0 if var_id in dbsnp_vars: if_dbsnp = 1 if_common = dbsnp_vars[var_id] @@ -106,32 +142,35 @@ def extract_features(candidate_record): Seq_Complexity_Adj = LC_adjacent_phred N_DP = nBamFeatures.dp - nBAM_REF_MQ = '%g' % nBamFeatures.ref_mq - nBAM_ALT_MQ = '%g' 
% nBamFeatures.alt_mq - nBAM_Z_Ranksums_MQ = '%g' % nBamFeatures.z_ranksums_mq - nBAM_REF_BQ = '%g' % nBamFeatures.ref_bq - nBAM_ALT_BQ = '%g' % nBamFeatures.alt_bq - nBAM_Z_Ranksums_BQ = '%g' % nBamFeatures.z_ranksums_bq - nBAM_REF_NM = '%g' % nBamFeatures.ref_NM - nBAM_ALT_NM = '%g' % nBamFeatures.alt_NM - nBAM_NM_Diff = '%g' % nBamFeatures.NM_Diff + nBAM_REF_MQ = "%g" % nBamFeatures.ref_mq + nBAM_ALT_MQ = "%g" % nBamFeatures.alt_mq + nBAM_Z_Ranksums_MQ = "%g" % nBamFeatures.z_ranksums_mq + nBAM_REF_BQ = "%g" % nBamFeatures.ref_bq + nBAM_ALT_BQ = "%g" % nBamFeatures.alt_bq + nBAM_Z_Ranksums_BQ = "%g" % nBamFeatures.z_ranksums_bq + nBAM_REF_NM = "%g" % nBamFeatures.ref_NM + nBAM_ALT_NM = "%g" % nBamFeatures.alt_NM + nBAM_NM_Diff = "%g" % nBamFeatures.NM_Diff nBAM_REF_Concordant = nBamFeatures.ref_concordant_reads nBAM_REF_Discordant = nBamFeatures.ref_discordant_reads nBAM_ALT_Concordant = nBamFeatures.alt_concordant_reads nBAM_ALT_Discordant = nBamFeatures.alt_discordant_reads nBAM_Concordance_FET = rescale( - nBamFeatures.concordance_fet, 'fraction', p_scale, 1001) + nBamFeatures.concordance_fet, "fraction", p_scale, 1001 + ) N_REF_FOR = nBamFeatures.ref_for N_REF_REV = nBamFeatures.ref_rev N_ALT_FOR = nBamFeatures.alt_for N_ALT_REV = nBamFeatures.alt_rev nBAM_StrandBias_FET = rescale( - nBamFeatures.strandbias_fet, 'fraction', p_scale, 1001) - nBAM_Z_Ranksums_EndPos = '%g' % nBamFeatures.z_ranksums_endpos + nBamFeatures.strandbias_fet, "fraction", p_scale, 1001 + ) + nBAM_Z_Ranksums_EndPos = "%g" % nBamFeatures.z_ranksums_endpos nBAM_REF_Clipped_Reads = nBamFeatures.ref_SC_reads nBAM_ALT_Clipped_Reads = nBamFeatures.alt_SC_reads nBAM_Clipping_FET = rescale( - nBamFeatures.clipping_fet, 'fraction', p_scale, 1001) + nBamFeatures.clipping_fet, "fraction", p_scale, 1001 + ) nBAM_MQ0 = nBamFeatures.MQ0 nBAM_Other_Reads = nBamFeatures.noise_read_count nBAM_Poor_Reads = nBamFeatures.poor_read_count @@ -144,34 +183,37 @@ def extract_features(candidate_record): SOR = sor MaxHomopolymer_Length = homopolymer_length SiteHomopolymer_Length = site_homopolymer_length - score_varscan2 = rescale(score_varscan2, 'phred', p_scale, 1001) + score_varscan2 = rescale(score_varscan2, "phred", p_scale, 1001) T_DP = tBamFeatures.dp - tBAM_REF_MQ = '%g' % tBamFeatures.ref_mq - tBAM_ALT_MQ = '%g' % tBamFeatures.alt_mq - tBAM_Z_Ranksums_MQ = '%g' % tBamFeatures.z_ranksums_mq - tBAM_REF_BQ = '%g' % tBamFeatures.ref_bq - tBAM_ALT_BQ = '%g' % tBamFeatures.alt_bq - tBAM_Z_Ranksums_BQ = '%g' % tBamFeatures.z_ranksums_bq - tBAM_REF_NM = '%g' % tBamFeatures.ref_NM - tBAM_ALT_NM = '%g' % tBamFeatures.alt_NM - tBAM_NM_Diff = '%g' % tBamFeatures.NM_Diff + tBAM_REF_MQ = "%g" % tBamFeatures.ref_mq + tBAM_ALT_MQ = "%g" % tBamFeatures.alt_mq + tBAM_Z_Ranksums_MQ = "%g" % tBamFeatures.z_ranksums_mq + tBAM_REF_BQ = "%g" % tBamFeatures.ref_bq + tBAM_ALT_BQ = "%g" % tBamFeatures.alt_bq + tBAM_Z_Ranksums_BQ = "%g" % tBamFeatures.z_ranksums_bq + tBAM_REF_NM = "%g" % tBamFeatures.ref_NM + tBAM_ALT_NM = "%g" % tBamFeatures.alt_NM + tBAM_NM_Diff = "%g" % tBamFeatures.NM_Diff tBAM_REF_Concordant = tBamFeatures.ref_concordant_reads tBAM_REF_Discordant = tBamFeatures.ref_discordant_reads tBAM_ALT_Concordant = tBamFeatures.alt_concordant_reads tBAM_ALT_Discordant = tBamFeatures.alt_discordant_reads tBAM_Concordance_FET = rescale( - tBamFeatures.concordance_fet, 'fraction', p_scale, 1001) + tBamFeatures.concordance_fet, "fraction", p_scale, 1001 + ) T_REF_FOR = tBamFeatures.ref_for T_REF_REV = tBamFeatures.ref_rev T_ALT_FOR = 
tBamFeatures.alt_for T_ALT_REV = tBamFeatures.alt_rev tBAM_StrandBias_FET = rescale( - tBamFeatures.strandbias_fet, 'fraction', p_scale, 1001) - tBAM_Z_Ranksums_EndPos = '%g' % tBamFeatures.z_ranksums_endpos + tBamFeatures.strandbias_fet, "fraction", p_scale, 1001 + ) + tBAM_Z_Ranksums_EndPos = "%g" % tBamFeatures.z_ranksums_endpos tBAM_REF_Clipped_Reads = tBamFeatures.ref_SC_reads tBAM_ALT_Clipped_Reads = tBamFeatures.alt_SC_reads tBAM_Clipping_FET = rescale( - tBamFeatures.clipping_fet, 'fraction', p_scale, 1001) + tBamFeatures.clipping_fet, "fraction", p_scale, 1001 + ) tBAM_MQ0 = tBamFeatures.MQ0 tBAM_Other_Reads = tBamFeatures.noise_read_count tBAM_Poor_Reads = tBamFeatures.poor_read_count @@ -183,24 +225,96 @@ def extract_features(candidate_record): tBAM_ALT_InDel_1bp = tBamFeatures.alt_indel_1bp InDel_Length = indel_length - features = [CHROM, POS, ".", REF, ALT, if_dbsnp, COMMON, if_COSMIC, COSMIC_CNT, - Consistent_Mates, Inconsistent_Mates] + features = [ + CHROM, + POS, + ".", + REF, + ALT, + if_dbsnp, + COMMON, + if_COSMIC, + COSMIC_CNT, + Consistent_Mates, + Inconsistent_Mates, + ] if not no_seq_complexity: features.extend([Seq_Complexity_Span, Seq_Complexity_Adj]) - features.extend([N_DP, nBAM_REF_MQ, nBAM_ALT_MQ, nBAM_Z_Ranksums_MQ, - nBAM_REF_BQ, nBAM_ALT_BQ, nBAM_Z_Ranksums_BQ, nBAM_REF_NM, nBAM_ALT_NM, nBAM_NM_Diff, - nBAM_REF_Concordant, nBAM_REF_Discordant, nBAM_ALT_Concordant, nBAM_ALT_Discordant, - nBAM_Concordance_FET, N_REF_FOR, N_REF_REV, N_ALT_FOR, N_ALT_REV, nBAM_StrandBias_FET, - nBAM_Z_Ranksums_EndPos, nBAM_REF_Clipped_Reads, nBAM_ALT_Clipped_Reads, nBAM_Clipping_FET, - nBAM_MQ0, nBAM_Other_Reads, nBAM_Poor_Reads, nBAM_REF_InDel_3bp, nBAM_REF_InDel_2bp, - nBAM_REF_InDel_1bp, nBAM_ALT_InDel_3bp, nBAM_ALT_InDel_2bp, nBAM_ALT_InDel_1bp, SOR, - MaxHomopolymer_Length, SiteHomopolymer_Length, score_varscan2, T_DP, tBAM_REF_MQ, tBAM_ALT_MQ, tBAM_Z_Ranksums_MQ, - tBAM_REF_BQ, tBAM_ALT_BQ, tBAM_Z_Ranksums_BQ, tBAM_REF_NM, tBAM_ALT_NM, tBAM_NM_Diff, - tBAM_REF_Concordant, tBAM_REF_Discordant, tBAM_ALT_Concordant, tBAM_ALT_Discordant, - tBAM_Concordance_FET, T_REF_FOR, T_REF_REV, T_ALT_FOR, T_ALT_REV, tBAM_StrandBias_FET, - tBAM_Z_Ranksums_EndPos, tBAM_REF_Clipped_Reads, tBAM_ALT_Clipped_Reads, tBAM_Clipping_FET, - tBAM_MQ0, tBAM_Other_Reads, tBAM_Poor_Reads, tBAM_REF_InDel_3bp, tBAM_REF_InDel_2bp, - tBAM_REF_InDel_1bp, tBAM_ALT_InDel_3bp, tBAM_ALT_InDel_2bp, tBAM_ALT_InDel_1bp, InDel_Length]) + features.extend( + [ + N_DP, + nBAM_REF_MQ, + nBAM_ALT_MQ, + nBAM_Z_Ranksums_MQ, + nBAM_REF_BQ, + nBAM_ALT_BQ, + nBAM_Z_Ranksums_BQ, + nBAM_REF_NM, + nBAM_ALT_NM, + nBAM_NM_Diff, + nBAM_REF_Concordant, + nBAM_REF_Discordant, + nBAM_ALT_Concordant, + nBAM_ALT_Discordant, + nBAM_Concordance_FET, + N_REF_FOR, + N_REF_REV, + N_ALT_FOR, + N_ALT_REV, + nBAM_StrandBias_FET, + nBAM_Z_Ranksums_EndPos, + nBAM_REF_Clipped_Reads, + nBAM_ALT_Clipped_Reads, + nBAM_Clipping_FET, + nBAM_MQ0, + nBAM_Other_Reads, + nBAM_Poor_Reads, + nBAM_REF_InDel_3bp, + nBAM_REF_InDel_2bp, + nBAM_REF_InDel_1bp, + nBAM_ALT_InDel_3bp, + nBAM_ALT_InDel_2bp, + nBAM_ALT_InDel_1bp, + SOR, + MaxHomopolymer_Length, + SiteHomopolymer_Length, + score_varscan2, + T_DP, + tBAM_REF_MQ, + tBAM_ALT_MQ, + tBAM_Z_Ranksums_MQ, + tBAM_REF_BQ, + tBAM_ALT_BQ, + tBAM_Z_Ranksums_BQ, + tBAM_REF_NM, + tBAM_ALT_NM, + tBAM_NM_Diff, + tBAM_REF_Concordant, + tBAM_REF_Discordant, + tBAM_ALT_Concordant, + tBAM_ALT_Discordant, + tBAM_Concordance_FET, + T_REF_FOR, + T_REF_REV, + T_ALT_FOR, + T_ALT_REV, + tBAM_StrandBias_FET, + 
tBAM_Z_Ranksums_EndPos, + tBAM_REF_Clipped_Reads, + tBAM_ALT_Clipped_Reads, + tBAM_Clipping_FET, + tBAM_MQ0, + tBAM_Other_Reads, + tBAM_Poor_Reads, + tBAM_REF_InDel_3bp, + tBAM_REF_InDel_2bp, + tBAM_REF_InDel_1bp, + tBAM_ALT_InDel_3bp, + tBAM_ALT_InDel_2bp, + tBAM_ALT_InDel_1bp, + InDel_Length, + ] + ) ext_features.append(features) return ext_features @@ -211,22 +325,29 @@ def extract_features(candidate_record): return None -def extend_features(candidates_vcf, - exclude_variants, - add_variants, - output_tsv, - reference, tumor_bam, normal_bam, - min_mapq, min_bq, - dbsnp, cosmic, - no_seq_complexity, - window_extend, - max_cluster_size, - num_threads): +def extend_features( + candidates_vcf, + exclude_variants, + add_variants, + output_tsv, + reference, + tumor_bam, + normal_bam, + min_mapq, + min_bq, + dbsnp, + cosmic, + no_seq_complexity, + window_extend, + max_cluster_size, + num_threads, +): logger = logging.getLogger(extend_features.__name__) logger.info( - "----------------------Extend Standalone Features------------------------") + "----------------------Extend Standalone Features------------------------" + ) if not os.path.exists(tumor_bam): logger.error("Aborting!") @@ -236,27 +357,28 @@ def extend_features(candidates_vcf, raise Exception("No normal BAM file {}".format(normal_bam)) if not os.path.exists(tumor_bam + ".bai"): logger.error("Aborting!") - raise Exception( - "No tumor .bai index file {}".format(tumor_bam + ".bai")) + raise Exception("No tumor .bai index file {}".format(tumor_bam + ".bai")) if not os.path.exists(normal_bam + ".bai"): logger.error("Aborting!") - raise Exception( - "No normal .bai index file {}".format(normal_bam + ".bai")) + raise Exception("No normal .bai index file {}".format(normal_bam + ".bai")) if dbsnp: if not os.path.exists(dbsnp): logger.error("Aborting!") - raise Exception( - "No dbSNP file {}".format(dbsnp)) + raise Exception("No dbSNP file {}".format(dbsnp)) if dbsnp[-6:] != "vcf.gz": logger.error("Aborting!") raise Exception( - "The dbSNP file should be a tabix indexed file with .vcf.gz format") + "The dbSNP file should be a tabix indexed file with .vcf.gz format" + ) if not os.path.exists(dbsnp + ".tbi"): logger.error("Aborting!") raise Exception( - "The dbSNP file should be a tabix indexed file with .vcf.gz format. No {}.tbi file exists.".format(dbsnp)) + "The dbSNP file should be a tabix indexed file with .vcf.gz format. 
No {}.tbi file exists.".format( + dbsnp + ) + ) chrom_order = get_chromosomes_order(reference) if cosmic: @@ -265,8 +387,11 @@ def extend_features(candidates_vcf, for line in skip_empty(i_f): x = line.strip().split("\t") chrom, pos, _, ref, alts, _, _, info = x[0:8] - num_cases = info.split("CNT=")[1].split( - ";")[0] if "CNT=" in info else float('nan') + num_cases = ( + info.split("CNT=")[1].split(";")[0] + if "CNT=" in info + else float("nan") + ) for alt in alts.split(","): var_id = "-".join([chrom, pos, ref, alt]) cosmic_vars[var_id] = num_cases @@ -304,29 +429,30 @@ def extend_features(candidates_vcf, if add_variants: if var_id in add_vars: add_vars = add_vars - set([var_id]) - num_cosmic_cases = float('nan') + num_cosmic_cases = float("nan") if_cosmic = 0 if cosmic and var_id in cosmic_vars: if_cosmic = 1 num_cosmic_cases = cosmic_vars[var_id] all_variants.append( - [chrom, int(pos), ref, alt, if_cosmic, num_cosmic_cases]) + [chrom, int(pos), ref, alt, if_cosmic, num_cosmic_cases] + ) if add_variants and len(add_vars) > 0: for var_id in add_vars - set(exclude_vars): v = var_id.split("-") pos, ref, alt = v[-3:] chrom = "-".join(v[:-3]) - num_cosmic_cases = float('nan') + num_cosmic_cases = float("nan") if_cosmic = 0 if cosmic and var_id in cosmic_vars: if_cosmic = 1 num_cosmic_cases = cosmic_vars[var_id] all_variants.append( - [chrom, int(pos), ref, alt, if_cosmic, num_cosmic_cases]) + [chrom, int(pos), ref, alt, if_cosmic, num_cosmic_cases] + ) - all_variants = sorted(all_variants, key=lambda x: [ - chrom_order[x[0]], x[1]]) + all_variants = sorted(all_variants, key=lambda x: [chrom_order[x[0]], x[1]]) n_variants = len(all_variants) logger.info("Number of variants: {}".format(n_variants)) split_len = (n_variants + num_threads - 1) // num_threads @@ -336,50 +462,136 @@ def extend_features(candidates_vcf, batch = [] n_batch = 0 curr_pos = None - for i, [chrom, pos, ref, alt, if_cosmic, num_cosmic_cases] in enumerate(all_variants): + for i, [chrom, pos, ref, alt, if_cosmic, num_cosmic_cases] in enumerate( + all_variants + ): if curr_pos is None: curr_pos = [chrom, pos] nei_cluster = [[chrom, pos, ref, alt, if_cosmic, num_cosmic_cases]] else: - if chrom == curr_pos[0] and abs(curr_pos[1] - pos) < window_extend and len(nei_cluster) < max_cluster_size: - nei_cluster.append( - [chrom, pos, ref, alt, if_cosmic, num_cosmic_cases]) + if ( + chrom == curr_pos[0] + and abs(curr_pos[1] - pos) < window_extend + and len(nei_cluster) < max_cluster_size + ): + nei_cluster.append([chrom, pos, ref, alt, if_cosmic, num_cosmic_cases]) else: batch.append(nei_cluster) n_batch += len(nei_cluster) curr_pos = [chrom, pos] - nei_cluster = [ - [chrom, pos, ref, alt, if_cosmic, num_cosmic_cases]] + nei_cluster = [[chrom, pos, ref, alt, if_cosmic, num_cosmic_cases]] if n_batch >= split_len or i == n_variants - 1: if i == n_variants - 1: batch.append(nei_cluster) curr_pos = None nei_cluster = [] if batch: - map_args.append((reference, tumor_bam, normal_bam, - min_mapq, min_bq, dbsnp, no_seq_complexity, batch)) + map_args.append( + ( + reference, + tumor_bam, + normal_bam, + min_mapq, + min_bq, + dbsnp, + no_seq_complexity, + batch, + ) + ) batch = [] - assert(n_variants == sum([len(y) for x in map_args for y in x[-1]])) + assert n_variants == sum([len(y) for x in map_args for y in x[-1]]) logger.info("Number of batches: {}".format(len(map_args))) - header = ["CHROM", "POS", "ID", "REF", "ALT", "if_dbsnp", "COMMON", "if_COSMIC", "COSMIC_CNT", - "Consistent_Mates", "Inconsistent_Mates"] + header = [ + "CHROM", + 
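# A minimal, self-contained sketch of the neighbour-clustering idea used above in
# extend_features(): sorted variants are grouped while they stay on the same
# chromosome, within window_extend bases of the cluster start, and below
# max_cluster_size; each finished cluster is later packed into per-thread
# batches for pool.map_async(). cluster_variants() is a hypothetical helper
# written only for illustration, not a function from this diff.
def cluster_variants(variants, window_extend=1000, max_cluster_size=300):
    clusters = []
    current = []
    start = None  # (chrom, pos) of the first variant in the open cluster
    for chrom, pos in variants:
        if (
            current
            and chrom == start[0]
            and abs(pos - start[1]) < window_extend
            and len(current) < max_cluster_size
        ):
            current.append((chrom, pos))
        else:
            if current:
                clusters.append(current)
            current = [(chrom, pos)]
            start = (chrom, pos)
    if current:
        clusters.append(current)
    return clusters


# Example: three variants close together on chr1 and one far away form two clusters.
print(cluster_variants([("1", 100), ("1", 150), ("1", 900), ("1", 5000)]))
# -> [[('1', 100), ('1', 150), ('1', 900)], [('1', 5000)]]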
"POS", + "ID", + "REF", + "ALT", + "if_dbsnp", + "COMMON", + "if_COSMIC", + "COSMIC_CNT", + "Consistent_Mates", + "Inconsistent_Mates", + ] if not no_seq_complexity: header.extend(["Seq_Complexity_Span", "Seq_Complexity_Adj"]) - header.extend(["N_DP", "nBAM_REF_MQ", "nBAM_ALT_MQ", "nBAM_Z_Ranksums_MQ", - "nBAM_REF_BQ", "nBAM_ALT_BQ", "nBAM_Z_Ranksums_BQ", "nBAM_REF_NM", "nBAM_ALT_NM", "nBAM_NM_Diff", - "nBAM_REF_Concordant", "nBAM_REF_Discordant", "nBAM_ALT_Concordant", "nBAM_ALT_Discordant", - "nBAM_Concordance_FET", "N_REF_FOR", "N_REF_REV", "N_ALT_FOR", "N_ALT_REV", "nBAM_StrandBias_FET", - "nBAM_Z_Ranksums_EndPos", "nBAM_REF_Clipped_Reads", "nBAM_ALT_Clipped_Reads", "nBAM_Clipping_FET", - "nBAM_MQ0", "nBAM_Other_Reads", "nBAM_Poor_Reads", "nBAM_REF_InDel_3bp", "nBAM_REF_InDel_2bp", - "nBAM_REF_InDel_1bp", "nBAM_ALT_InDel_3bp", "nBAM_ALT_InDel_2bp", "nBAM_ALT_InDel_1bp", "SOR", - "MaxHomopolymer_Length", "SiteHomopolymer_Length", "VarScan2_Score", "T_DP", "tBAM_REF_MQ", "tBAM_ALT_MQ", "tBAM_Z_Ranksums_MQ", - "tBAM_REF_BQ", "tBAM_ALT_BQ", "tBAM_Z_Ranksums_BQ", "tBAM_REF_NM", "tBAM_ALT_NM", "tBAM_NM_Diff", - "tBAM_REF_Concordant", "tBAM_REF_Discordant", "tBAM_ALT_Concordant", "tBAM_ALT_Discordant", - "tBAM_Concordance_FET", "T_REF_FOR", "T_REF_REV", "T_ALT_FOR", "T_ALT_REV", "tBAM_StrandBias_FET", - "tBAM_Z_Ranksums_EndPos", "tBAM_REF_Clipped_Reads", "tBAM_ALT_Clipped_Reads", "tBAM_Clipping_FET", - "tBAM_MQ0", "tBAM_Other_Reads", "tBAM_Poor_Reads", "tBAM_REF_InDel_3bp", "tBAM_REF_InDel_2bp", - "tBAM_REF_InDel_1bp", "tBAM_ALT_InDel_3bp", "tBAM_ALT_InDel_2bp", "tBAM_ALT_InDel_1bp", "InDel_Length"]) + header.extend( + [ + "N_DP", + "nBAM_REF_MQ", + "nBAM_ALT_MQ", + "nBAM_Z_Ranksums_MQ", + "nBAM_REF_BQ", + "nBAM_ALT_BQ", + "nBAM_Z_Ranksums_BQ", + "nBAM_REF_NM", + "nBAM_ALT_NM", + "nBAM_NM_Diff", + "nBAM_REF_Concordant", + "nBAM_REF_Discordant", + "nBAM_ALT_Concordant", + "nBAM_ALT_Discordant", + "nBAM_Concordance_FET", + "N_REF_FOR", + "N_REF_REV", + "N_ALT_FOR", + "N_ALT_REV", + "nBAM_StrandBias_FET", + "nBAM_Z_Ranksums_EndPos", + "nBAM_REF_Clipped_Reads", + "nBAM_ALT_Clipped_Reads", + "nBAM_Clipping_FET", + "nBAM_MQ0", + "nBAM_Other_Reads", + "nBAM_Poor_Reads", + "nBAM_REF_InDel_3bp", + "nBAM_REF_InDel_2bp", + "nBAM_REF_InDel_1bp", + "nBAM_ALT_InDel_3bp", + "nBAM_ALT_InDel_2bp", + "nBAM_ALT_InDel_1bp", + "SOR", + "MaxHomopolymer_Length", + "SiteHomopolymer_Length", + "VarScan2_Score", + "T_DP", + "tBAM_REF_MQ", + "tBAM_ALT_MQ", + "tBAM_Z_Ranksums_MQ", + "tBAM_REF_BQ", + "tBAM_ALT_BQ", + "tBAM_Z_Ranksums_BQ", + "tBAM_REF_NM", + "tBAM_ALT_NM", + "tBAM_NM_Diff", + "tBAM_REF_Concordant", + "tBAM_REF_Discordant", + "tBAM_ALT_Concordant", + "tBAM_ALT_Discordant", + "tBAM_Concordance_FET", + "T_REF_FOR", + "T_REF_REV", + "T_ALT_FOR", + "T_ALT_REV", + "tBAM_StrandBias_FET", + "tBAM_Z_Ranksums_EndPos", + "tBAM_REF_Clipped_Reads", + "tBAM_ALT_Clipped_Reads", + "tBAM_Clipping_FET", + "tBAM_MQ0", + "tBAM_Other_Reads", + "tBAM_Poor_Reads", + "tBAM_REF_InDel_3bp", + "tBAM_REF_InDel_2bp", + "tBAM_REF_InDel_1bp", + "tBAM_ALT_InDel_3bp", + "tBAM_ALT_InDel_2bp", + "tBAM_ALT_InDel_1bp", + "InDel_Length", + ] + ) try: ext_features = pool.map_async(extract_features, map_args).get() @@ -389,7 +601,8 @@ def extend_features(candidates_vcf, for features in ext_features: for w in features: o_f.write( - "\t".join(map(lambda x: str(x).replace("nan", "0"), w)) + "\n") + "\t".join(map(lambda x: str(x).replace("nan", "0"), w)) + "\n" + ) except Exception as inst: logger.error(inst) pool.close() @@ -400,67 +613,93 @@ def 
extend_features(candidates_vcf, return ext_features -if __name__ == '__main__': - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +if __name__ == "__main__": + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) parser = argparse.ArgumentParser( - description='extract extra features for standalone mode') - parser.add_argument('--candidates_vcf', type=str, help='candidates vcf', - required=True) - parser.add_argument('--exclude_variants', type=str, help='variants to exclude', - default=None) - parser.add_argument('--add_variants', type=str, help='variants to add if not exist in vcf. (Lower priority than --exclude_variants)', - default=None) - parser.add_argument('--output_tsv', type=str, help='output features tsv', - required=True) - parser.add_argument('--reference', type=str, help='reference fasta filename', - required=True) - parser.add_argument('--tumor_bam', type=str, - help='tumor bam', required=True) - parser.add_argument('--normal_bam', type=str, - help='normal bam', required=True) - parser.add_argument('--min_mapq', type=int, - help='minimum mapping quality', default=1) - parser.add_argument('--min_bq', type=float, - help='minimum base quality', default=5) - parser.add_argument('--dbsnp', type=str, - help='dbSNP vcf (to annotate candidate variants)', default=None) - parser.add_argument('--cosmic', type=str, - help='COSMIC vcf (to annotate candidate variants)', default=None) - parser.add_argument('--no_seq_complexity', - help='Dont compute linguistic sequence complexity features', - action="store_true") - parser.add_argument('--window_extend', type=int, - help='window size for extending input features (should be in the order of readlength)', - default=1000) - parser.add_argument('--max_cluster_size', type=int, - help='max cluster size for extending input features (should be in the order of readlength)', - default=300) - parser.add_argument('--num_threads', type=int, - help='number of threads', default=1) + description="extract extra features for standalone mode" + ) + parser.add_argument( + "--candidates_vcf", type=str, help="candidates vcf", required=True + ) + parser.add_argument( + "--exclude_variants", type=str, help="variants to exclude", default=None + ) + parser.add_argument( + "--add_variants", + type=str, + help="variants to add if not exist in vcf. 
(Lower priority than --exclude_variants)", + default=None, + ) + parser.add_argument( + "--output_tsv", type=str, help="output features tsv", required=True + ) + parser.add_argument( + "--reference", type=str, help="reference fasta filename", required=True + ) + parser.add_argument("--tumor_bam", type=str, help="tumor bam", required=True) + parser.add_argument("--normal_bam", type=str, help="normal bam", required=True) + parser.add_argument( + "--min_mapq", type=int, help="minimum mapping quality", default=1 + ) + parser.add_argument("--min_bq", type=float, help="minimum base quality", default=5) + parser.add_argument( + "--dbsnp", + type=str, + help="dbSNP vcf (to annotate candidate variants)", + default=None, + ) + parser.add_argument( + "--cosmic", + type=str, + help="COSMIC vcf (to annotate candidate variants)", + default=None, + ) + parser.add_argument( + "--no_seq_complexity", + help="Dont compute linguistic sequence complexity features", + action="store_true", + ) + parser.add_argument( + "--window_extend", + type=int, + help="window size for extending input features (should be in the order of readlength)", + default=1000, + ) + parser.add_argument( + "--max_cluster_size", + type=int, + help="max cluster size for extending input features (should be in the order of readlength)", + default=300, + ) + parser.add_argument("--num_threads", type=int, help="number of threads", default=1) args = parser.parse_args() logger.info(args) try: - output = extend_features(args.candidates_vcf, - args.exclude_variants, - args.add_variants, - args.output_tsv, - args.reference, args.tumor_bam, args.normal_bam, - args.min_mapq, args.min_bq, - args.dbsnp, args.cosmic, - args.no_seq_complexity, - args.window_extend, - args.max_cluster_size, - args.num_threads, - ) + output = extend_features( + args.candidates_vcf, + args.exclude_variants, + args.add_variants, + args.output_tsv, + args.reference, + args.tumor_bam, + args.normal_bam, + args.min_mapq, + args.min_bq, + args.dbsnp, + args.cosmic, + args.no_seq_complexity, + args.window_extend, + args.max_cluster_size, + args.num_threads, + ) if output is None: raise Exception("extend_features failed!") except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "extend_features.py failure on arguments: {}".format(args)) + logger.error("extend_features.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/extract_postprocess_targets.py b/neusomatic/python/extract_postprocess_targets.py index bdbe40d..2c5b8f8 100755 --- a/neusomatic/python/extract_postprocess_targets.py +++ b/neusomatic/python/extract_postprocess_targets.py @@ -1,9 +1,9 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # extract_postprocess_targets.py # Extract variants that need postprocessing (like larege INDELs) # from the output predictions of 'call.py'. 
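# The next hunks reformat check_rep() and extend_region_repeat() in
# extract_postprocess_targets.py: check_rep() asks whether the w bases at one
# edge of a fetched reference window repeat the adjacent w bases, and
# extend_region_repeat() keeps stepping a region boundary outwards (repeat
# units of 1-4 bp) while that holds, so postprocessing targets are not cut in
# the middle of a homopolymer or short tandem repeat.  A toy illustration on a
# plain string; walk_left() is a hypothetical stand-in for the pysam fetch loop.
def check_rep(ref_seq, left_right, w):
    # simplified copy of the diff's check_rep(), without logging
    if len(ref_seq) < 2 * w:
        return False
    if left_right == "left":
        return ref_seq[0:w] == ref_seq[w : 2 * w]
    return ref_seq[-w:] == ref_seq[-2 * w : -w]


def walk_left(seq, start, rep_len):
    # slide `start` left while the leading rep_len bases keep repeating
    while start >= rep_len and check_rep(seq[start:], "left", rep_len):
        start -= rep_len
    return start


seq = "ACGTATATATATGGC"          # (AT)x4 tract at positions 4..11 (0-based)
print(walk_left(seq, 8, 2))      # -> 2: the boundary has walked left across the tract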
-#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import argparse import traceback import logging @@ -19,16 +19,15 @@ def check_rep(ref_seq, left_right, w): if len(ref_seq) < 2 * w: return False if left_right == "left": - return ref_seq[0:w] == ref_seq[w:2 * w] + return ref_seq[0:w] == ref_seq[w : 2 * w] elif left_right == "right": - return ref_seq[-w:] == ref_seq[-2 * w:-w] + return ref_seq[-w:] == ref_seq[-2 * w : -w] else: logger.error("Wrong left/right value: {}".format(left_right)) raise Exception -def extend_region_repeat(chrom, start, end, ref_fasta, - chrom_length, pad): +def extend_region_repeat(chrom, start, end, ref_fasta, chrom_length, pad): logger = logging.getLogger(extend_region_repeat.__name__) new_start = start new_end = end @@ -36,8 +35,7 @@ def extend_region_repeat(chrom, start, end, ref_fasta, while True: changed = False new_start = max(new_start - pad - w, 1) - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() while True: cnt_s = 0 for rep_len in [1, 2, 3, 4]: @@ -45,8 +43,7 @@ def extend_region_repeat(chrom, start, end, ref_fasta, continue while check_rep(ref_seq, "left", rep_len) and new_start > rep_len: new_start -= rep_len - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_s += rep_len changed = True if cnt_s > 0: @@ -58,17 +55,18 @@ def extend_region_repeat(chrom, start, end, ref_fasta, while True: changed = False new_end = min(new_end + pad + w, chrom_length - 2) - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() while True: cnt_e = 0 for rep_len in [1, 2, 3, 4]: if cnt_e > 0: continue - while check_rep(ref_seq, "right", rep_len) and new_end < chrom_length - rep_len - 1: + while ( + check_rep(ref_seq, "right", rep_len) + and new_end < chrom_length - rep_len - 1 + ): new_end += rep_len - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_e += rep_len changed = True if cnt_e > 0: @@ -80,7 +78,9 @@ def extend_region_repeat(chrom, start, end, ref_fasta, return new_start, new_end -def extract_postprocess_targets(reference, input_vcf, min_len, max_dist, extend_repeats, pad): +def extract_postprocess_targets( + reference, input_vcf, min_len, max_dist, extend_repeats, pad +): logger = logging.getLogger(extract_postprocess_targets.__name__) logger.info("--------------Extract Postprocessing Targets---------------") @@ -107,16 +107,53 @@ def extract_postprocess_targets(reference, input_vcf, min_len, max_dist, extend_ if not record_set: record_set.append(record) continue - chrom_, pos_, ref_, alt_ = push_left_var( - ref_fasta, chrom, pos, ref, alt) - if len(list(filter(lambda x: (chrom == x[0] and - (min(abs(x[1] + len(x[2]) - (pos + len(ref))), - abs(x[1] - pos), - abs(min(x[1] + len(x[2]), pos + len(ref)) - max(x[1], pos))) <= max_dist)), record_set))) > 0 or len( - list(filter(lambda x: (chrom_ == x[0] and - (min(abs(x[1] + len(x[2]) - (pos_ + len(ref_))), - abs(x[1] - pos_), - abs(min(x[1] + len(x[2]), pos_ + len(ref_)) - max(x[1], pos_))) <= max_dist)), record_set))) > 0: + chrom_, pos_, ref_, alt_ = push_left_var(ref_fasta, chrom, pos, ref, alt) + if ( + len( + list( + filter( + lambda x: ( + chrom == x[0] + and ( + min( + 
abs(x[1] + len(x[2]) - (pos + len(ref))), + abs(x[1] - pos), + abs( + min(x[1] + len(x[2]), pos + len(ref)) + - max(x[1], pos) + ), + ) + <= max_dist + ) + ), + record_set, + ) + ) + ) + > 0 + or len( + list( + filter( + lambda x: ( + chrom_ == x[0] + and ( + min( + abs(x[1] + len(x[2]) - (pos_ + len(ref_))), + abs(x[1] - pos_), + abs( + min(x[1] + len(x[2]), pos_ + len(ref_)) + - max(x[1], pos_) + ), + ) + <= max_dist + ) + ), + record_set, + ) + ) + ) + > 0 + ): record_set.append(record) continue @@ -138,16 +175,21 @@ def extract_postprocess_targets(reference, input_vcf, min_len, max_dist, extend_ if len(varid_pos[vid]) > 1: multi_allelic = True - if list(filter(lambda x: len(x[2]) != len(x[3]), record_set)) or multi_allelic: + if ( + list(filter(lambda x: len(x[2]) != len(x[3]), record_set)) + or multi_allelic + ): for x in record_set: fields = x[-1].strip().split() # fields[2] = str(ii) if ii not in redo_vars: redo_vars[ii] = [] redo_vars[ii].append(fields) - redo_regions[ii] = [record_set[0][0], max(0, - min(map(lambda x:x[1], record_set)) - pad), - max(map(lambda x:x[1] + len(x[2]), record_set)) + pad] + redo_regions[ii] = [ + record_set[0][0], + max(0, min(map(lambda x: x[1], record_set)) - pad), + max(map(lambda x: x[1] + len(x[2]), record_set)) + pad, + ] else: for x in record_set: o_f.write(x[-1]) @@ -160,30 +202,32 @@ def extract_postprocess_targets(reference, input_vcf, min_len, max_dist, extend_ redo_vars[ii] = [] redo_vars[ii].append(fields) chrom_, pos_, ref_, alt_ = record_set[0][0:4] - redo_regions[ii] = [chrom_, max( - 0, pos_ - pad), pos_ + len(ref_) + pad] + redo_regions[ii] = [ + chrom_, + max(0, pos_ - pad), + pos_ + len(ref_) + pad, + ] else: o_f.write(record_set[0][-1]) - if extend_repeats: - chrom_lengths = dict( - zip(ref_fasta.references, ref_fasta.lengths)) + chrom_lengths = dict(zip(ref_fasta.references, ref_fasta.lengths)) tmp_ = get_tmp_file() with open(tmp_, "w") as o_f: for ii in redo_regions: chrom, st, en = redo_regions[ii] st, en = extend_region_repeat( - chrom, st, en, ref_fasta, chrom_lengths[chrom], 0) + chrom, st, en, ref_fasta, chrom_lengths[chrom], 0 + ) o_f.write("\t".join(list(map(str, [chrom, st, en, ii]))) + "\n") - tmp_=bedtools_sort(tmp_,run_logger=logger) - tmp_=bedtools_merge(tmp_,args="-c 4 -o collapse", run_logger=logger) + tmp_ = bedtools_sort(tmp_, run_logger=logger) + tmp_ = bedtools_merge(tmp_, args="-c 4 -o collapse", run_logger=logger) else: tmp_ = get_tmp_file() with open(tmp_, "w") as o_f: for ii in redo_regions: chrom, st, en = redo_regions[ii] - o_f.write("\t".join(list(map(str, [chrom, st, en, ii]))) + "\n") + o_f.write("\t".join(list(map(str, [chrom, st, en, ii]))) + "\n") j = 0 with open(tmp_) as i_f, open(redo_vcf, "w") as r_f, open(redo_bed, "w") as r_b: r_f.write("{}\n".format(VCF_HEADER)) @@ -198,35 +242,46 @@ def extract_postprocess_targets(reference, input_vcf, min_len, max_dist, extend_ j += 1 -if __name__ == '__main__': +if __name__ == "__main__": - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) - parser = argparse.ArgumentParser( - description='infer genotype by ao and ro counts') - parser.add_argument('--reference', type=str, - help='reference fasta filename', required=True) - parser.add_argument('--input_vcf', type=str, - help='input vcf', required=True) - parser.add_argument('--min_len', type=int, - help='minimum INDEL len to resolve', 
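# extract_postprocess_targets() above pools consecutive VCF records into one
# resolve target whenever the new record lies within --max_dist bases of a
# record already in the set (closest of: start-to-start, end-to-end, or the
# gap between facing ends); groups holding an indel or a multi-allelic site
# then become a padded redo region.  A standalone restatement of that distance
# test only; close_enough() is a hypothetical name, records are (chrom, pos, ref).
def close_enough(rec_a, rec_b, max_dist=5):
    chrom_a, pos_a, ref_a = rec_a
    chrom_b, pos_b, ref_b = rec_b
    if chrom_a != chrom_b:
        return False
    end_a, end_b = pos_a + len(ref_a), pos_b + len(ref_b)
    gap = abs(min(end_a, end_b) - max(pos_a, pos_b))
    return min(abs(end_a - end_b), abs(pos_a - pos_b), gap) <= max_dist


print(close_enough(("1", 100, "ACGT"), ("1", 107, "A")))  # True: only a 3 bp gap
print(close_enough(("1", 100, "A"), ("1", 120, "A")))     # False: 19 bp apart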
default=4) - parser.add_argument('--max_dist', type=int, - help='max distance to neighboring variant', default=5) - parser.add_argument('--extend_repeats', - help='extend resolve regions to repeat boundaries', - action='store_true') + parser = argparse.ArgumentParser(description="infer genotype by ao and ro counts") + parser.add_argument( + "--reference", type=str, help="reference fasta filename", required=True + ) + parser.add_argument("--input_vcf", type=str, help="input vcf", required=True) + parser.add_argument( + "--min_len", type=int, help="minimum INDEL len to resolve", default=4 + ) + parser.add_argument( + "--max_dist", type=int, help="max distance to neighboring variant", default=5 + ) + parser.add_argument( + "--extend_repeats", + help="extend resolve regions to repeat boundaries", + action="store_true", + ) parser.add_argument( - '--pad', type=int, help='padding to bed region for extracting reads', default=10) + "--pad", type=int, help="padding to bed region for extracting reads", default=10 + ) args = parser.parse_args() logger.info(args) try: extract_postprocess_targets( - args.reference, args.input_vcf, args.min_len, args.max_dist, args.extend_repeats, args.pad) + args.reference, + args.input_vcf, + args.min_len, + args.max_dist, + args.extend_repeats, + args.pad, + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") logger.error( - "extract_postprocess_targets.py failure on arguments: {}".format(args)) + "extract_postprocess_targets.py failure on arguments: {}".format(args) + ) raise e diff --git a/neusomatic/python/filter_candidates.py b/neusomatic/python/filter_candidates.py index aa7a063..d8bc5d0 100755 --- a/neusomatic/python/filter_candidates.py +++ b/neusomatic/python/filter_candidates.py @@ -1,8 +1,8 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # filter_candidates.py # filter raw candidates extracted by 'scan_alignments.py' using min_af and other cut-offs -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import os import argparse import traceback @@ -12,32 +12,58 @@ import pysam import numpy as np -from utils import safe_read_info_dict, run_bedtools_cmd, vcf_2_bed, write_tsv_file, bedtools_sort, get_tmp_file, skip_empty +from utils import ( + safe_read_info_dict, + run_bedtools_cmd, + vcf_2_bed, + write_tsv_file, + bedtools_sort, + get_tmp_file, + skip_empty, +) from defaults import VCF_HEADER def filter_candidates(candidate_record): - candidates_vcf, filtered_candidates_vcf, reference, min_dp, max_dp, good_ao, \ - min_ao, snp_min_af, snp_min_bq, snp_min_ao, ins_min_af, del_min_af, \ - del_merge_min_af, ins_merge_min_af, merge_r = candidate_record + ( + candidates_vcf, + filtered_candidates_vcf, + reference, + min_dp, + max_dp, + good_ao, + min_ao, + snp_min_af, + snp_min_bq, + snp_min_ao, + ins_min_af, + del_min_af, + del_merge_min_af, + ins_merge_min_af, + merge_r, + ) = candidate_record thread_logger = logging.getLogger( - "{} ({})".format(filter_candidates.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + filter_candidates.__name__, multiprocessing.current_process().name + ) + ) try: thread_logger.info( - "---------------------Filter Candidates---------------------") + "---------------------Filter Candidates---------------------" + ) records = {} with open(candidates_vcf) 
as v_f: for line in skip_empty(v_f): if len(line.strip().split()) != 10: - raise RuntimeError( - "Bad VCF line (<10 fields): {}".format(line)) + raise RuntimeError("Bad VCF line (<10 fields): {}".format(line)) chrom, pos, _, ref, alt, _, _, info_, _, info = line.strip().split() pos = int(pos) loc = "{}.{}".format(chrom, pos) dp, ro, ao = list(map(int, info.split(":")[1:4])) - info_dict = dict(map(lambda x: x.split( - "="), filter(None, info_.split(";")))) + info_dict = dict( + map(lambda x: x.split("="), filter(None, info_.split(";"))) + ) mq_ = safe_read_info_dict(info_dict, "MQ", int, -100) bq_ = safe_read_info_dict(info_dict, "BQ", int, -100) nm_ = safe_read_info_dict(info_dict, "NM", int, -100) @@ -56,38 +82,85 @@ def filter_candidates(candidate_record): if loc not in records: records[loc] = [] - if ref == "N" or "\t".join(line.split()[0:5]) \ - not in map(lambda x: "\t".join(x[-1].split()[0:5]), records[loc]): - records[loc].append([chrom, pos, ref, alt, dp, ro, ao, mq_, bq_, st_, ls_, rs_, - nm_, as_, xs_, pr_, cl_, line]) - elif "\t".join(line.split()[0:5]) \ - in map(lambda x: "\t".join(x[-1].split()[0:5]), records[loc]): + if ref == "N" or "\t".join(line.split()[0:5]) not in map( + lambda x: "\t".join(x[-1].split()[0:5]), records[loc] + ): + records[loc].append( + [ + chrom, + pos, + ref, + alt, + dp, + ro, + ao, + mq_, + bq_, + st_, + ls_, + rs_, + nm_, + as_, + xs_, + pr_, + cl_, + line, + ] + ) + elif "\t".join(line.split()[0:5]) in map( + lambda x: "\t".join(x[-1].split()[0:5]), records[loc] + ): for i, x in enumerate(records[loc]): - if "\t".join(line.split()[0:5]) == "\t".join(x[-1].split()[0:5]) \ - and ao / float(ro + 0.0001) > x[6] / float(x[5] + 0.0001): - records[loc][i] = [chrom, pos, ref, alt, dp, ro, ao, mq_, bq_, st_, ls_, - rs_, nm_, as_, xs_, pr_, cl_, line] + if "\t".join(line.split()[0:5]) == "\t".join( + x[-1].split()[0:5] + ) and ao / float(ro + 0.0001) > x[6] / float(x[5] + 0.0001): + records[loc][i] = [ + chrom, + pos, + ref, + alt, + dp, + ro, + ao, + mq_, + bq_, + st_, + ls_, + rs_, + nm_, + as_, + xs_, + pr_, + cl_, + line, + ] break fasta_file = pysam.Fastafile(reference) good_records = [] dels = [] - for loc, rs in sorted(records.items(), key=lambda x: x[1][0:2]) + \ - [["", [["", 0, "", "", 0, 0, 0, ""]]]]: + for loc, rs in sorted(records.items(), key=lambda x: x[1][0:2]) + [ + ["", [["", 0, "", "", 0, 0, 0, ""]]] + ]: ins = list(filter(lambda x: x[2] == "N", rs)) if len(ins) > 1: # emit ins afs = list(map(lambda x: x[6] / float(x[5] + x[6]), ins)) max_af = max(afs) - ins = list(filter(lambda x: x[6] / float(x[5] + - x[6]) >= (max_af * merge_r), ins)) + ins = list( + filter( + lambda x: x[6] / float(x[5] + x[6]) >= (max_af * merge_r), ins + ) + ) chrom, pos, ref = ins[0][0:3] dp = max(map(lambda x: x[4], ins)) ro = max(map(lambda x: x[5], ins)) ao = max(map(lambda x: x[6], ins)) mq_ = max(map(lambda x: x[7], ins)) bq_ = max(map(lambda x: x[8], ins)) - st_ = "{},{}".format(max(map(lambda x: int(x[9].split(",")[0]), ins)), - max(map(lambda x: int(x[9].split(",")[1]), ins))) + st_ = "{},{}".format( + max(map(lambda x: int(x[9].split(",")[0]), ins)), + max(map(lambda x: int(x[9].split(",")[1]), ins)), + ) ls_ = max(map(lambda x: x[10], ins)) rs_ = max(map(lambda x: x[11], ins)) nm_ = max(map(lambda x: x[12], ins)) @@ -97,8 +170,27 @@ def filter_candidates(candidate_record): cl_ = max(map(lambda x: x[16], ins)) alt = "".join(map(lambda x: x[3], ins)) if (max_af >= ins_merge_min_af) or (ao >= good_ao): - ins = [[chrom, pos, ref, alt, dp, ro, ao, mq_, bq_, 
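# filter_candidates() above merges neighbouring insertion (and deletion) pieces
# at one locus: it computes each piece's allele fraction ao / (ro + ao), keeps
# only pieces whose fraction is at least merge_r times the strongest one, and
# then concatenates their bases and takes maxima of the per-piece counts.
# A minimal restatement of the AF-ratio filter alone, on toy (ro, ao) tuples;
# keep_for_merge() is a hypothetical helper, merge_r defaults to 0.5 as in the CLI.
def keep_for_merge(pieces, merge_r=0.5):
    afs = [ao / float(ro + ao) for ro, ao in pieces]
    max_af = max(afs)
    return [p for p, af in zip(pieces, afs) if af >= merge_r * max_af]


print(keep_for_merge([(90, 10), (97, 3), (88, 12)]))  # -> [(90, 10), (88, 12)]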
st_, ls_, - rs_, nm_, as_, xs_, pr_, cl_]] + ins = [ + [ + chrom, + pos, + ref, + alt, + dp, + ro, + ao, + mq_, + bq_, + st_, + ls_, + rs_, + nm_, + as_, + xs_, + pr_, + cl_, + ] + ] else: ins = [] elif len(ins) == 1: @@ -109,7 +201,9 @@ def filter_candidates(candidate_record): else: ins = [ins[0][:-1]] good_records.extend(ins) - if dels and (ins or len(list(filter(lambda x: x[3] == "N" and x[2] != "N", rs))) == 0): + if dels and ( + ins or len(list(filter(lambda x: x[3] == "N" and x[2] != "N", rs))) == 0 + ): # emit del if len(dels) == 1: ro = dels[0][5] @@ -122,16 +216,19 @@ def filter_candidates(candidate_record): afs = list(map(lambda x: x[6] / float(x[5] + x[6]), dels)) max_af = max(afs) merge_r_thr = merge_r * max_af - dels = list(filter( - lambda x: x[6] / float(x[5] + x[6]) >= merge_r_thr, dels)) + dels = list( + filter(lambda x: x[6] / float(x[5] + x[6]) >= merge_r_thr, dels) + ) chrom, pos = dels[0][0:2] dp = max(map(lambda x: x[4], dels)) ro = max(map(lambda x: x[5], dels)) ao = max(map(lambda x: x[6], dels)) mq_ = max(map(lambda x: x[7], dels)) bq_ = max(map(lambda x: x[8], dels)) - st_ = "{},{}".format(max(map(lambda x: int(x[9].split(",")[0]), dels)), - max(map(lambda x: int(x[9].split(",")[1]), dels))) + st_ = "{},{}".format( + max(map(lambda x: int(x[9].split(",")[0]), dels)), + max(map(lambda x: int(x[9].split(",")[1]), dels)), + ) ls_ = max(map(lambda x: x[10], dels)) rs_ = max(map(lambda x: x[11], dels)) nm_ = max(map(lambda x: x[12], dels)) @@ -141,8 +238,27 @@ def filter_candidates(candidate_record): cl_ = max(map(lambda x: x[16], dels)) ref = "".join(map(lambda x: x[2], dels)) alt = "N" - good_records.append([chrom, pos, ref, alt, dp, ro, ao, mq_, bq_, st_, ls_, - rs_, nm_, as_, xs_, pr_, cl_]) + good_records.append( + [ + chrom, + pos, + ref, + alt, + dp, + ro, + ao, + mq_, + bq_, + st_, + ls_, + rs_, + nm_, + as_, + xs_, + pr_, + cl_, + ] + ) dels = [] if not loc: continue @@ -156,7 +272,9 @@ def filter_candidates(candidate_record): ro, ao = record[5:7] if record[2] != "N" and record[3] != "N" and record[2] != record[3]: bq = record[8] - if (ao / float(ro + ao) >= (snp_min_af) or ao >= snp_min_ao) and bq >= snp_min_bq: + if ( + ao / float(ro + ao) >= (snp_min_af) or ao >= snp_min_ao + ) and bq >= snp_min_bq: # emit SNP good_records.append(record[:-1]) elif record[2] != "N" and record[3] == "N": @@ -172,22 +290,28 @@ def filter_candidates(candidate_record): if ao / float(ro + ao) >= ((del_min_af)): good_records.extend(dels) else: - afs = list(map(lambda x: x[6] / - float(x[5] + x[6]), dels)) + afs = list( + map(lambda x: x[6] / float(x[5] + x[6]), dels) + ) max_af = max(afs) merge_r_thr = merge_r * max_af - dels = list(filter( - lambda x: x[6] / float(x[5] + x[6]) >= merge_r_thr, dels)) + dels = list( + filter( + lambda x: x[6] / float(x[5] + x[6]) + >= merge_r_thr, + dels, + ) + ) chrom, pos = dels[0][0:2] dp = max(map(lambda x: x[4], dels)) ro = max(map(lambda x: x[5], dels)) ao = max(map(lambda x: x[6], dels)) mq_ = max(map(lambda x: x[7], dels)) bq_ = max(map(lambda x: x[8], dels)) - st_ = "{},{}".format(max(map(lambda x: int(x[9].split(",")[0]), - dels)), - max(map(lambda x: int(x[9].split(",")[1]), - dels))) + st_ = "{},{}".format( + max(map(lambda x: int(x[9].split(",")[0]), dels)), + max(map(lambda x: int(x[9].split(",")[1]), dels)), + ) ls_ = max(map(lambda x: x[10], dels)) rs_ = max(map(lambda x: x[11], dels)) nm_ = max(map(lambda x: x[12], dels)) @@ -197,8 +321,27 @@ def filter_candidates(candidate_record): cl_ = max(map(lambda x: x[16], dels)) ref = 
"".join(map(lambda x: x[2], dels)) alt = "N" - good_records.append([chrom, pos, ref, alt, dp, ro, ao, mq_, bq_, - st_, ls_, rs_, nm_, as_, xs_, pr_, cl_]) + good_records.append( + [ + chrom, + pos, + ref, + alt, + dp, + ro, + ao, + mq_, + bq_, + st_, + ls_, + rs_, + nm_, + as_, + xs_, + pr_, + cl_, + ] + ) dels = [] # accumulate dels dels.append(record[:-1]) @@ -206,7 +349,25 @@ def filter_candidates(candidate_record): final_records = [] dels = [] for i, record in enumerate(good_records): - chrom, pos, ref, alt, dp, ro, ao, mq_, bq_, st_, ls_, rs_, nm_, as_, xs_, pr_, cl_ = record + ( + chrom, + pos, + ref, + alt, + dp, + ro, + ao, + mq_, + bq_, + st_, + ls_, + rs_, + nm_, + as_, + xs_, + pr_, + cl_, + ) = record ref = ref.upper() alt = alt.upper() info_str = "" @@ -234,33 +395,61 @@ def filter_candidates(candidate_record): af = np.round(ao / float(ao + ro), 4) info_str += ";AF={}".format(af) if ref != "N" and alt != "N": - line = "\t".join([chrom, str(pos), ".", ref, alt, "100", ".", - "DP={};RO={};AO={}".format( - dp, ro, ao) + info_str, - "GT:DP:RO:AO:AF", "0/1:{}:{}:{}:{}".format(dp, ro, ao, af)]) + line = "\t".join( + [ + chrom, + str(pos), + ".", + ref, + alt, + "100", + ".", + "DP={};RO={};AO={}".format(dp, ro, ao) + info_str, + "GT:DP:RO:AO:AF", + "0/1:{}:{}:{}:{}".format(dp, ro, ao, af), + ] + ) final_records.append([chrom, pos, ref, alt, line]) elif alt == "N": - ref = fasta_file.fetch( - chrom, pos - 2, pos + len(ref) - 1).upper() + ref = fasta_file.fetch(chrom, pos - 2, pos + len(ref) - 1).upper() alt = fasta_file.fetch(chrom, pos - 2, pos - 1).upper() - line = "\t".join([chrom, str(pos - 1), ".", ref, alt, "100", ".", - "DP={};RO={};AO={}".format( - dp, ro, ao) + info_str, - "GT:DP:RO:AO:AF", "0/1:{}:{}:{}:{}".format(dp, ro, ao, af)]) + line = "\t".join( + [ + chrom, + str(pos - 1), + ".", + ref, + alt, + "100", + ".", + "DP={};RO={};AO={}".format(dp, ro, ao) + info_str, + "GT:DP:RO:AO:AF", + "0/1:{}:{}:{}:{}".format(dp, ro, ao, af), + ] + ) final_records.append([chrom, pos - 1, ref, alt, line]) elif ref == "N": ref = fasta_file.fetch(chrom, pos - 2, pos - 1).upper() alt = ref + alt - line = "\t".join([chrom, str(pos - 1), ".", ref, alt, "100", ".", - "DP={};RO={};AO={}".format( - dp, ro, ao) + info_str, - "GT:DP:RO:AO:AF", "0/1:{}:{}:{}:{}".format(dp, ro, ao, af)]) + line = "\t".join( + [ + chrom, + str(pos - 1), + ".", + ref, + alt, + "100", + ".", + "DP={};RO={};AO={}".format(dp, ro, ao) + info_str, + "GT:DP:RO:AO:AF", + "0/1:{}:{}:{}:{}".format(dp, ro, ao, af), + ] + ) final_records.append([chrom, pos - 1, ref, alt, line]) final_records = sorted(final_records, key=lambda x: x[0:2]) with open(filtered_candidates_vcf, "w") as o_f: o_f.write("{}\n".format(VCF_HEADER)) - o_f.write( - "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") + o_f.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") for record in final_records: o_f.write(record[-1] + "\n") return filtered_candidates_vcf @@ -271,58 +460,94 @@ def filter_candidates(candidate_record): return None -if __name__ == '__main__': - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +if __name__ == "__main__": + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) - parser = argparse.ArgumentParser(description='filter candidates vcf') - parser.add_argument('--candidates_vcf', type=str, help='raw candidates vcf', - required=True) - parser.add_argument('--filtered_candidates_vcf', 
type=str, help='filtered candidates vcf', - required=True) - parser.add_argument('--reference', type=str, help='reference fasta filename', - required=True) - parser.add_argument('--good_ao', type=float, help='good alternate count (ignores maf)', - default=10) - parser.add_argument('--min_ao', type=float, - help='min alternate count', default=1) - parser.add_argument('--snp_min_af', type=float, - help='SNP min allele freq', default=0.05) - parser.add_argument('--snp_min_bq', type=float, - help='SNP min base quality', default=10) - parser.add_argument('--snp_min_ao', type=float, - help='SNP min alternate count for low AF candidates', default=3) - parser.add_argument('--ins_min_af', type=float, - help='INS min allele freq', default=0.05) - parser.add_argument('--del_min_af', type=float, - help='DEL min allele freq', default=0.05) - parser.add_argument('--del_merge_min_af', type=float, - help='min allele freq for merging DELs', default=0) - parser.add_argument('--ins_merge_min_af', type=float, - help='min allele freq for merging INSs', default=0) - parser.add_argument('--merge_r', type=float, - help='merge af ratio to the max af for merging adjacent variants', - default=0.5) - parser.add_argument('--min_dp', type=float, help='min depth', default=5) - parser.add_argument('--max_dp', type=float, - help='max depth', default=100000) + parser = argparse.ArgumentParser(description="filter candidates vcf") + parser.add_argument( + "--candidates_vcf", type=str, help="raw candidates vcf", required=True + ) + parser.add_argument( + "--filtered_candidates_vcf", + type=str, + help="filtered candidates vcf", + required=True, + ) + parser.add_argument( + "--reference", type=str, help="reference fasta filename", required=True + ) + parser.add_argument( + "--good_ao", type=float, help="good alternate count (ignores maf)", default=10 + ) + parser.add_argument("--min_ao", type=float, help="min alternate count", default=1) + parser.add_argument( + "--snp_min_af", type=float, help="SNP min allele freq", default=0.05 + ) + parser.add_argument( + "--snp_min_bq", type=float, help="SNP min base quality", default=10 + ) + parser.add_argument( + "--snp_min_ao", + type=float, + help="SNP min alternate count for low AF candidates", + default=3, + ) + parser.add_argument( + "--ins_min_af", type=float, help="INS min allele freq", default=0.05 + ) + parser.add_argument( + "--del_min_af", type=float, help="DEL min allele freq", default=0.05 + ) + parser.add_argument( + "--del_merge_min_af", + type=float, + help="min allele freq for merging DELs", + default=0, + ) + parser.add_argument( + "--ins_merge_min_af", + type=float, + help="min allele freq for merging INSs", + default=0, + ) + parser.add_argument( + "--merge_r", + type=float, + help="merge af ratio to the max af for merging adjacent variants", + default=0.5, + ) + parser.add_argument("--min_dp", type=float, help="min depth", default=5) + parser.add_argument("--max_dp", type=float, help="max depth", default=100000) args = parser.parse_args() logger.info(args) try: - output = filter_candidates((args.candidates_vcf, args.filtered_candidates_vcf, - args.reference, args.min_dp, args.max_dp, - args.good_ao, args.min_ao, - args.snp_min_af, args.snp_min_bq, args.snp_min_ao, - args.ins_min_af, args.del_min_af, - args.del_merge_min_af, args.ins_merge_min_af, args.merge_r)) + output = filter_candidates( + ( + args.candidates_vcf, + args.filtered_candidates_vcf, + args.reference, + args.min_dp, + args.max_dp, + args.good_ao, + args.min_ao, + args.snp_min_af, + args.snp_min_bq, + 
args.snp_min_ao, + args.ins_min_af, + args.del_min_af, + args.del_merge_min_af, + args.ins_merge_min_af, + args.merge_r, + ) + ) if output is None: raise Exception("filter_candidates failed!") except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "filter_candidates.py failure on arguments: {}".format(args)) + logger.error("filter_candidates.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/generate_dataset.py b/neusomatic/python/generate_dataset.py index 4a9e93d..170cb52 100755 --- a/neusomatic/python/generate_dataset.py +++ b/neusomatic/python/generate_dataset.py @@ -1,9 +1,9 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # generate_dataset.py # Use the input filtered candidates to prepare and extracted features to generate datasets to # be used by the NeuSomatic network. -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import argparse import base64 import multiprocessing @@ -21,7 +21,18 @@ from random import shuffle from split_bed import split_region -from utils import concatenate_vcfs, get_chromosomes_order, run_bedtools_cmd, vcf_2_bed, bedtools_sort, bedtools_window, bedtools_intersect, bedtools_slop, get_tmp_file, skip_empty +from utils import ( + concatenate_vcfs, + get_chromosomes_order, + run_bedtools_cmd, + vcf_2_bed, + bedtools_sort, + bedtools_window, + bedtools_intersect, + bedtools_slop, + get_tmp_file, + skip_empty, +) from defaults import NUM_ENS_FEATURES, VCF_HEADER, MAT_DTYPES NUC_to_NUM_tabix = {"A": 1, "C": 2, "G": 3, "T": 4, "-": 0} @@ -29,6 +40,7 @@ import time + def get_type(ref, alt): logger = logging.getLogger(get_type.__name__) len_diff = len(ref) - len(alt.split(",")[0]) @@ -39,7 +51,10 @@ def get_type(ref, alt): else: return "SNP" -def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom_lengths): + +def get_variant_matrix_tabix( + ref_seq, count_info, record, matrix_base_pad, chrom_lengths +): logger = logging.getLogger(get_variant_matrix_tabix.__name__) chrom, pos, ref, alt = record[0:4] # logger.info("vvv-1") @@ -54,9 +69,9 @@ def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom # pos - matrix_base_pad, pos + matrix_base_pad, count_info)) # tabix_records = [] - t1=time.time() + t1 = time.time() - tabix_records=[] + tabix_records = [] for pos_ in sorted(count_info.keys()): # print([record[0:4],pos_]) tabix_records.extend(count_info[pos_]) @@ -75,35 +90,35 @@ def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom curr_pos = s_pos # logger.info(["rrr-11",time.time()-t1]) - t1=time.time() - tabix_records=list(tabix_records) + t1 = time.time() + tabix_records = list(tabix_records) # logger.info(["rrr-12",time.time()-t1]) # logger.info("vvv-30") - t1=time.time() - t2=time.time() - t0=time.time() + t1 = time.time() + t2 = time.time() + t0 = time.time() for rec in tabix_records: # print(rec) # logger.info(["rrr-13",time.time()-t1]) - t1=time.time() + t1 = time.time() pos_ = int(rec[1]) if pos_ > pos + matrix_base_pad: # logger.info(["rrr-19",time.time()-t0]) - t0=time.time() - t1=time.time() + t0 = time.time() + t1 = time.time() continue ref_base = rec[3] if ref_base.upper() not in "ACGT-": ref_base = "-" if pos_ in col_pos_map and ref_base != "-": # 
logger.info(["rrr-19",time.time()-t0]) - t0=time.time() - t1=time.time() + t0 = time.time() + t1 = time.time() continue if pos_ > (curr_pos): # refs = fasta_file.fetch( # chrom, curr_pos - 1, pos_ - 1).upper().replace("N", "-") - refs = ref_seq[curr_pos-s_pos:pos_-s_pos] + refs = ref_seq[curr_pos - s_pos : pos_ - s_pos] for i in range(curr_pos, pos_): ref_base_ = refs[i - curr_pos] @@ -124,11 +139,11 @@ def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom cnt += 1 curr_pos = pos_ # logger.info(["rrr-14",time.time()-t1]) - t1=time.time() + t1 = time.time() if pos_ == (curr_pos) and ref_base == "-" and pos_ not in col_pos_map: # ref_base_ = fasta_file.fetch( # chrom, pos_ - 1, pos_).upper().replace("N", "-") - ref_base_ = ref_seq[pos_-s_pos:pos_-s_pos+1] + ref_base_ = ref_seq[pos_ - s_pos : pos_ - s_pos + 1] if ref_base_.upper() not in "ACGT-": ref_base_ = "-" matrix_.append([0, 0, 0, 0, 0]) @@ -147,8 +162,8 @@ def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom curr_pos = pos_ + 1 # logger.info(["rrr-15",time.time()-t1]) # logger.info(["rrr-19",time.time()-t0]) - t0=time.time() - t1=time.time() + t0 = time.time() + t1 = time.time() matrix_.append(list(map(int, rec[4].split(":")))) bq_matrix_.append(list(map(int, rec[5].split(":")))) mq_matrix_.append(list(map(int, rec[6].split(":")))) @@ -163,7 +178,7 @@ def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom cnt += 1 curr_pos = pos_ + 1 # logger.info(["rrr-16",time.time()-t1]) - t1=time.time() + t1 = time.time() # logger.info("vvv-32") end_pos = min(pos + matrix_base_pad, chrom_lengths[chrom] - 2) @@ -171,7 +186,7 @@ def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom if curr_pos < pos + matrix_base_pad + 1: # refs = fasta_file.fetch( # chrom, curr_pos - 1, end_pos).upper().replace("N", "-") - refs = ref_seq[curr_pos-s_pos:end_pos-s_pos+1] + refs = ref_seq[curr_pos - s_pos : end_pos - s_pos + 1] for i in range(curr_pos, end_pos + 1): ref_base_ = refs[i - curr_pos] if ref_base_.upper() not in "ACGT-": @@ -193,8 +208,7 @@ def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom # logger.info(["rrr-18",time.time()-t2]) - t1=time.time() - + t1 = time.time() matrix_ = np.array(matrix_).transpose() bq_matrix_ = np.array(bq_matrix_).transpose() @@ -210,8 +224,17 @@ def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom # logger.info(["rrr-17",time.time()-t1]) - return matrix_, bq_matrix_, mq_matrix_, st_matrix_, lsc_matrix_, rsc_matrix_, tag_matrices_, ref_array, col_pos_map - + return ( + matrix_, + bq_matrix_, + mq_matrix_, + st_matrix_, + lsc_matrix_, + rsc_matrix_, + tag_matrices_, + ref_array, + col_pos_map, + ) # def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom_lengths): @@ -325,8 +348,8 @@ def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom # # logger.info(["rrr-19",time.time()-t0]) # t0=time.time() # t1=time.time() -# z_s[i_rec][1]=n_ -# n_ += 1 +# z_s[i_rec][1]=n_ +# n_ += 1 # # matrix_.append(list(map(int, rec[4].split(":")))) # # bq_matrix_.append(list(map(int, rec[5].split(":")))) # # mq_matrix_.append(list(map(int, rec[6].split(":")))) @@ -411,39 +434,53 @@ def get_variant_matrix_tabix(ref_seq, count_info, record, matrix_base_pad, chrom # return matrix_, bq_matrix_, mq_matrix_, st_matrix_, lsc_matrix_, rsc_matrix_, tag_matrices_, ref_array, col_pos_map -def align_tumor_normal_matrices(record, tumor_matrix_, 
bq_tumor_matrix_, mq_tumor_matrix_, st_tumor_matrix_, - lsc_tumor_matrix_, rsc_tumor_matrix_, - tag_tumor_matrices_, tumor_ref_array, tumor_col_pos_map, normal_matrix_, - bq_normal_matrix_, mq_normal_matrix_, st_normal_matrix_, - lsc_normal_matrix_, rsc_normal_matrix_, tag_normal_matrices_, - normal_ref_array, normal_col_pos_map): + +def align_tumor_normal_matrices( + record, + tumor_matrix_, + bq_tumor_matrix_, + mq_tumor_matrix_, + st_tumor_matrix_, + lsc_tumor_matrix_, + rsc_tumor_matrix_, + tag_tumor_matrices_, + tumor_ref_array, + tumor_col_pos_map, + normal_matrix_, + bq_normal_matrix_, + mq_normal_matrix_, + st_normal_matrix_, + lsc_normal_matrix_, + rsc_normal_matrix_, + tag_normal_matrices_, + normal_ref_array, + normal_col_pos_map, +): logger = logging.getLogger(align_tumor_normal_matrices.__name__) if not tumor_col_pos_map: logger.error("record: {}".format(record)) - raise(RuntimeError("tumor_col_pos_map is empty.")) + raise (RuntimeError("tumor_col_pos_map is empty.")) - tumor_col_pos_map[max(tumor_col_pos_map.keys()) + - 1] = tumor_matrix_.shape[1] - normal_col_pos_map[max(normal_col_pos_map.keys()) + - 1] = normal_matrix_.shape[1] + tumor_col_pos_map[max(tumor_col_pos_map.keys()) + 1] = tumor_matrix_.shape[1] + normal_col_pos_map[max(normal_col_pos_map.keys()) + 1] = normal_matrix_.shape[1] if set(tumor_col_pos_map.keys()) ^ set(normal_col_pos_map.keys()): logger.error("record: {}".format(record)) logger.error("normal_col_pos_map: {}".format(normal_col_pos_map)) logger.error("tumor_col_pos_map: {}".format(tumor_col_pos_map)) - raise(RuntimeError( - "tumor_col_pos_map and normal_col_pos_map have different keys.")) + raise ( + RuntimeError( + "tumor_col_pos_map and normal_col_pos_map have different keys." + ) + ) - pT = list(map(lambda x: tumor_col_pos_map[ - x], sorted(tumor_col_pos_map.keys()))) - pN = list(map(lambda x: normal_col_pos_map[ - x], sorted(normal_col_pos_map.keys()))) + pT = list(map(lambda x: tumor_col_pos_map[x], sorted(tumor_col_pos_map.keys()))) + pN = list(map(lambda x: normal_col_pos_map[x], sorted(normal_col_pos_map.keys()))) if pT[0] != pN[0]: logger.error("record: {}".format(record)) logger.error("pT, pN: {}, {}".format(pT, pN)) - raise(RuntimeError( - "pT[0] != pN[0]")) + raise (RuntimeError("pT[0] != pN[0]")) min_i = pT[0] cols_T = np.ones(tumor_matrix_.shape[1] + 1, int) @@ -459,8 +496,7 @@ def align_tumor_normal_matrices(record, tumor_matrix_, bq_tumor_matrix_, mq_tumo current_col_N += max(0, current_col_T - current_col_N) if current_col_T != current_col_N: logger.error("record: {}".format(record)) - raise(RuntimeError( - "current_col_T != current_col_N")) + raise (RuntimeError("current_col_T != current_col_N")) del tumor_col_pos_map[max(tumor_col_pos_map.keys())] del normal_col_pos_map[max(normal_col_pos_map.keys())] @@ -471,14 +507,14 @@ def align_tumor_normal_matrices(record, tumor_matrix_, bq_tumor_matrix_, mq_tumo new_lsc_tumor_matrix_ = np.zeros((5, current_col_T - 1)) new_rsc_tumor_matrix_ = np.zeros((5, current_col_T - 1)) new_tag_tumor_matrices_ = [ - np.zeros((5, current_col_T - 1)) for i in range(len(tag_tumor_matrices_))] + np.zeros((5, current_col_T - 1)) for i in range(len(tag_tumor_matrices_)) + ] new_tumor_matrix_[0, :] = max(tumor_matrix_.sum(0)) new_bq_tumor_matrix_[0, :] = max(bq_tumor_matrix_[0, :]) new_mq_tumor_matrix_[0, :] = max(mq_tumor_matrix_[0, :]) new_st_tumor_matrix_[0, :] = max(st_tumor_matrix_[0, :]) for iii in range(len(tag_tumor_matrices_)): - new_tag_tumor_matrices_[iii][0, :] = max( - tag_tumor_matrices_[iii][0, 
:]) + new_tag_tumor_matrices_[iii][0, :] = max(tag_tumor_matrices_[iii][0, :]) new_normal_matrix_ = np.zeros((5, current_col_N - 1)) new_bq_normal_matrix_ = np.zeros((5, current_col_T - 1)) @@ -487,14 +523,14 @@ def align_tumor_normal_matrices(record, tumor_matrix_, bq_tumor_matrix_, mq_tumo new_lsc_normal_matrix_ = np.zeros((5, current_col_T - 1)) new_rsc_normal_matrix_ = np.zeros((5, current_col_T - 1)) new_tag_normal_matrices_ = [ - np.zeros((5, current_col_T - 1)) for i in range(len(tag_normal_matrices_))] + np.zeros((5, current_col_T - 1)) for i in range(len(tag_normal_matrices_)) + ] new_normal_matrix_[0, :] = max(normal_matrix_.sum(0)) new_bq_normal_matrix_[0, :] = max(bq_normal_matrix_[0, :]) new_mq_normal_matrix_[0, :] = max(mq_normal_matrix_[0, :]) new_st_normal_matrix_[0, :] = max(st_normal_matrix_[0, :]) for iii in range(len(tag_normal_matrices_)): - new_tag_normal_matrices_[iii][0, :] = max( - tag_normal_matrices_[iii][0, :]) + new_tag_normal_matrices_[iii][0, :] = max(tag_normal_matrices_[iii][0, :]) map_T = (np.cumsum(cols_T) - 1)[:-1] map_N = (np.cumsum(cols_N) - 1)[:-1] @@ -520,82 +556,145 @@ def align_tumor_normal_matrices(record, tumor_matrix_, bq_tumor_matrix_, mq_tumo new_tumor_ref_array[map_T] = tumor_ref_array new_normal_ref_array[map_N] = normal_ref_array - new_tumor_col_pos_map = {k: map_T[v] - for k, v in tumor_col_pos_map.items()} - new_normal_col_pos_map = {k: map_N[v] - for k, v in normal_col_pos_map.items()} + new_tumor_col_pos_map = {k: map_T[v] for k, v in tumor_col_pos_map.items()} + new_normal_col_pos_map = {k: map_N[v] for k, v in normal_col_pos_map.items()} if sum(new_normal_ref_array - new_tumor_ref_array) != 0: logger.error("record: {}".format(record)) - logger.error("new_normal_ref_array, new_tumor_ref_array: {}, {}".format( - new_normal_ref_array, new_tumor_ref_array)) - logger.error("new_normal_ref_array - new_tumor_ref_array: {}".format( - new_normal_ref_array - new_tumor_ref_array)) + logger.error( + "new_normal_ref_array, new_tumor_ref_array: {}, {}".format( + new_normal_ref_array, new_tumor_ref_array + ) + ) + logger.error( + "new_normal_ref_array - new_tumor_ref_array: {}".format( + new_normal_ref_array - new_tumor_ref_array + ) + ) raise Exception for k in new_tumor_col_pos_map: - assert(new_tumor_col_pos_map[k] == new_normal_col_pos_map[k]) - return [new_tumor_matrix_, new_bq_tumor_matrix_, new_mq_tumor_matrix_, new_st_tumor_matrix_, - new_lsc_tumor_matrix_, new_rsc_tumor_matrix_, new_tag_tumor_matrices_, new_normal_matrix_, - new_bq_normal_matrix_, new_mq_normal_matrix_, new_st_normal_matrix_, - new_lsc_normal_matrix_, new_rsc_normal_matrix_, - new_tag_normal_matrices_, new_tumor_ref_array, - new_tumor_col_pos_map] - - -def prepare_info_matrices_tabix(ref_seq, tumor_count_info, normal_count_info, record, rlen, rcenter, - matrix_base_pad, matrix_width, min_ev_frac_per_col, min_cov, chrom_lengths): + assert new_tumor_col_pos_map[k] == new_normal_col_pos_map[k] + return [ + new_tumor_matrix_, + new_bq_tumor_matrix_, + new_mq_tumor_matrix_, + new_st_tumor_matrix_, + new_lsc_tumor_matrix_, + new_rsc_tumor_matrix_, + new_tag_tumor_matrices_, + new_normal_matrix_, + new_bq_normal_matrix_, + new_mq_normal_matrix_, + new_st_normal_matrix_, + new_lsc_normal_matrix_, + new_rsc_normal_matrix_, + new_tag_normal_matrices_, + new_tumor_ref_array, + new_tumor_col_pos_map, + ] + + +def prepare_info_matrices_tabix( + ref_seq, + tumor_count_info, + normal_count_info, + record, + rlen, + rcenter, + matrix_base_pad, + matrix_width, + min_ev_frac_per_col, + 
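import numpy as np

# align_tumor_normal_matrices() above forces the tumor and normal count
# matrices onto one shared column grid: each side carries a col_pos_map from
# genomic position to its start column, the number of columns each position
# occupies is padded up to the elementwise maximum of the two sides, and
# cumulative sums then give every position's start column in the widened
# matrices (zero-filled gap columns are scattered in afterwards).  A
# stripped-down sketch of just that index arithmetic; aligned_maps() is a
# hypothetical helper, not the function from the diff.
def aligned_maps(pos_map_a, width_a, pos_map_b, width_b):
    keys = sorted(pos_map_a)                      # the diff asserts both maps share keys
    starts_a = [pos_map_a[k] for k in keys] + [width_a]
    starts_b = [pos_map_b[k] for k in keys] + [width_b]
    spans = np.maximum(np.diff(starts_a), np.diff(starts_b))
    new_starts = starts_a[0] + np.concatenate(([0], np.cumsum(spans)[:-1]))
    new_map = dict(zip(keys, (int(s) for s in new_starts)))
    return new_map, starts_a[0] + int(spans.sum())


tumor = {10: 0, 11: 1, 12: 2}     # one column per position
normal = {10: 0, 11: 1, 12: 4}    # position 11 spawned extra insertion columns
print(aligned_maps(tumor, 3, normal, 5))  # -> ({10: 0, 11: 1, 12: 4}, 5)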
min_cov, + chrom_lengths, +): logger = logging.getLogger(prepare_info_matrices_tabix.__name__) chrom, pos, ref, alt = record[0:4] - t1=time.time() + t1 = time.time() # logger.info("ttt-1") - tumor_matrix_, bq_tumor_matrix_, mq_tumor_matrix_, st_tumor_matrix_, lsc_tumor_matrix_, rsc_tumor_matrix_, tag_tumor_matrices_, tumor_ref_array, tumor_col_pos_map = get_variant_matrix_tabix( - ref_seq, tumor_count_info, record, matrix_base_pad, chrom_lengths) - normal_matrix_, bq_normal_matrix_, mq_normal_matrix_, st_normal_matrix_, lsc_normal_matrix_, rsc_normal_matrix_, tag_normal_matrices_, normal_ref_array, normal_col_pos_map = get_variant_matrix_tabix( - ref_seq, normal_count_info, record, matrix_base_pad, chrom_lengths) + ( + tumor_matrix_, + bq_tumor_matrix_, + mq_tumor_matrix_, + st_tumor_matrix_, + lsc_tumor_matrix_, + rsc_tumor_matrix_, + tag_tumor_matrices_, + tumor_ref_array, + tumor_col_pos_map, + ) = get_variant_matrix_tabix( + ref_seq, tumor_count_info, record, matrix_base_pad, chrom_lengths + ) + ( + normal_matrix_, + bq_normal_matrix_, + mq_normal_matrix_, + st_normal_matrix_, + lsc_normal_matrix_, + rsc_normal_matrix_, + tag_normal_matrices_, + normal_ref_array, + normal_col_pos_map, + ) = get_variant_matrix_tabix( + ref_seq, normal_count_info, record, matrix_base_pad, chrom_lengths + ) # logger.info(["rrr-8",time.time()-t1]) - t1=time.time() + t1 = time.time() if not tumor_col_pos_map: logger.warning("Skip {} for all N reference".format(record)) return None # logger.info("ttt-2") - bq_tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[ - 0]] = np.max(bq_tumor_matrix_) - bq_normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[ - 0]] = np.max(bq_normal_matrix_) - mq_tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[ - 0]] = np.max(mq_tumor_matrix_) - mq_normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[ - 0]] = np.max(mq_normal_matrix_) - st_tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[ - 0]] = np.max(st_tumor_matrix_) - st_normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[ - 0]] = np.max(st_normal_matrix_) - lsc_tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[ - 0]] = np.max(lsc_tumor_matrix_) - lsc_normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[ - 0]] = np.max(lsc_normal_matrix_) - rsc_tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[ - 0]] = np.max(rsc_tumor_matrix_) - rsc_normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[ - 0]] = np.max(rsc_normal_matrix_) + bq_tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[0]] = np.max( + bq_tumor_matrix_ + ) + bq_normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[0]] = np.max( + bq_normal_matrix_ + ) + mq_tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[0]] = np.max( + mq_tumor_matrix_ + ) + mq_normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[0]] = np.max( + mq_normal_matrix_ + ) + st_tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[0]] = np.max( + st_tumor_matrix_ + ) + st_normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[0]] = np.max( + st_normal_matrix_ + ) + lsc_tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[0]] = np.max( + lsc_tumor_matrix_ + ) + lsc_normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[0]] = np.max( + lsc_normal_matrix_ + ) + rsc_tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[0]] = np.max( + rsc_tumor_matrix_ + ) + rsc_normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[0]] = np.max( + rsc_normal_matrix_ + ) for iii in range(len(tag_tumor_matrices_)): - tag_tumor_matrices_[iii][0, np.where(tumor_matrix_.sum(0) == 0)[ - 0]] = 
np.max(tag_tumor_matrices_[iii]) + tag_tumor_matrices_[iii][0, np.where(tumor_matrix_.sum(0) == 0)[0]] = np.max( + tag_tumor_matrices_[iii] + ) for iii in range(len(tag_normal_matrices_)): - tag_normal_matrices_[iii][0, np.where(normal_matrix_.sum(0) == 0)[ - 0]] = np.max(tag_normal_matrices_[iii]) + tag_normal_matrices_[iii][0, np.where(normal_matrix_.sum(0) == 0)[0]] = np.max( + tag_normal_matrices_[iii] + ) # logger.info("ttt-3") - tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[ - 0]] = max(np.sum(tumor_matrix_, 0)) - normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[ - 0]] = max(np.sum(normal_matrix_, 0)) + tumor_matrix_[0, np.where(tumor_matrix_.sum(0) == 0)[0]] = max( + np.sum(tumor_matrix_, 0) + ) + normal_matrix_[0, np.where(normal_matrix_.sum(0) == 0)[0]] = max( + np.sum(normal_matrix_, 0) + ) if max(np.sum(normal_matrix_, 0)) == 0: normal_matrix_[0, :] = np.max(np.sum(tumor_matrix_, 0)) bq_normal_matrix_[0, :] = np.max(bq_tumor_matrix_) @@ -608,17 +707,47 @@ def prepare_info_matrices_tabix(ref_seq, tumor_count_info, normal_count_info, re # logger.info("ttt-4") - tumor_matrix_, bq_tumor_matrix_, mq_tumor_matrix_, st_tumor_matrix_, lsc_tumor_matrix_, rsc_tumor_matrix_, \ - tag_tumor_matrices_, normal_matrix_, bq_normal_matrix_, mq_normal_matrix_, st_normal_matrix_, \ - lsc_normal_matrix_, rsc_normal_matrix_, tag_normal_matrices_, \ - ref_array, col_pos_map = align_tumor_normal_matrices( - record, tumor_matrix_, bq_tumor_matrix_, mq_tumor_matrix_, st_tumor_matrix_, lsc_tumor_matrix_, rsc_tumor_matrix_, - tag_tumor_matrices_, tumor_ref_array, tumor_col_pos_map, normal_matrix_, - bq_normal_matrix_, mq_normal_matrix_, st_normal_matrix_, lsc_normal_matrix_, rsc_normal_matrix_, tag_normal_matrices_, - normal_ref_array, normal_col_pos_map) + ( + tumor_matrix_, + bq_tumor_matrix_, + mq_tumor_matrix_, + st_tumor_matrix_, + lsc_tumor_matrix_, + rsc_tumor_matrix_, + tag_tumor_matrices_, + normal_matrix_, + bq_normal_matrix_, + mq_normal_matrix_, + st_normal_matrix_, + lsc_normal_matrix_, + rsc_normal_matrix_, + tag_normal_matrices_, + ref_array, + col_pos_map, + ) = align_tumor_normal_matrices( + record, + tumor_matrix_, + bq_tumor_matrix_, + mq_tumor_matrix_, + st_tumor_matrix_, + lsc_tumor_matrix_, + rsc_tumor_matrix_, + tag_tumor_matrices_, + tumor_ref_array, + tumor_col_pos_map, + normal_matrix_, + bq_normal_matrix_, + mq_normal_matrix_, + st_normal_matrix_, + lsc_normal_matrix_, + rsc_normal_matrix_, + tag_normal_matrices_, + normal_ref_array, + normal_col_pos_map, + ) # logger.info(["rrr-9",time.time()-t1]) - t1=time.time() + t1 = time.time() tw = int(matrix_width) count_column = sum(tumor_matrix_[1:], 0) @@ -626,8 +755,7 @@ def prepare_info_matrices_tabix(ref_seq, tumor_count_info, normal_count_info, re n_row = np.max(np.sum(tumor_matrix_, 0)) if n_row < min_cov: - logger.warning("Skip {} for low cov {}<{}".format( - record, int(n_row), min_cov)) + logger.warning("Skip {} for low cov {}<{}".format(record, int(n_row), min_cov)) return None cols_not_to_del = [] largest_block = [] @@ -635,35 +763,65 @@ def prepare_info_matrices_tabix(ref_seq, tumor_count_info, normal_count_info, re z_ref_array = np.where((ref_array == 0))[0] if z_ref_array.shape[0] > 0: max_count = np.max(count_column[z_ref_array]) - cols_not_to_del = np.where(np.logical_and(count_column >= max_count * 0.7, - (ref_array == 0)))[0] + cols_not_to_del = np.where( + np.logical_and(count_column >= max_count * 0.7, (ref_array == 0)) + )[0] a = [-1000] + sorted(np.where((ref_array == 0))[0]) + [1000] b = 
np.diff((np.diff(a) == 1).astype(int)) a = a[1:-1] - blocks = list(map(lambda x: [a[x[0]], a[x[1]]], zip( - np.where(b == 1)[0], np.where(b == -1)[0]))) + blocks = list( + map( + lambda x: [a[x[0]], a[x[1]]], + zip(np.where(b == 1)[0], np.where(b == -1)[0]), + ) + ) if blocks: - largest_block = sorted( - blocks, key=lambda x: x[1] - x[0] + 1)[-1] - if np.max(count_column[range(largest_block[0], largest_block[1] + 1)]) > max_count * 0.05: + largest_block = sorted(blocks, key=lambda x: x[1] - x[0] + 1)[-1] + if ( + np.max(count_column[range(largest_block[0], largest_block[1] + 1)]) + > max_count * 0.05 + ): if (largest_block[1] - largest_block[0] + 1) > 2: - cols_not_to_del = sorted(list(set(cols_not_to_del) | set( - range(largest_block[0], largest_block[1] + 1)))) + cols_not_to_del = sorted( + list( + set(cols_not_to_del) + | set(range(largest_block[0], largest_block[1] + 1)) + ) + ) else: largest_block = [] - cols_to_del = sorted(list(set(np.where(np.logical_and(count_column <= (min_ev_frac_per_col * n_row), (ref_array == 0)))[0] - ) - set(cols_not_to_del))) + cols_to_del = sorted( + list( + set( + np.where( + np.logical_and( + count_column <= (min_ev_frac_per_col * n_row), (ref_array == 0) + ) + )[0] + ) + - set(cols_not_to_del) + ) + ) if n_col - len(cols_to_del) > tw: - mn = min(count_column[np.argsort(count_column)][ - n_col - tw - 1], max(n_row // 5, min_ev_frac_per_col * n_row)) - new_cols_to_del = set(np.where(np.logical_and(count_column <= mn, (ref_array == 0)))[ - 0]) - set(cols_to_del) - set(cols_not_to_del) - if n_col - (len(cols_to_del) + len(new_cols_to_del)) < tw and len(new_cols_to_del) > 3 and len(alt) > len(ref): - new_cols_to_del = list(map( - lambda x: [count_column[x], x], new_cols_to_del)) + mn = min( + count_column[np.argsort(count_column)][n_col - tw - 1], + max(n_row // 5, min_ev_frac_per_col * n_row), + ) + new_cols_to_del = ( + set(np.where(np.logical_and(count_column <= mn, (ref_array == 0)))[0]) + - set(cols_to_del) + - set(cols_not_to_del) + ) + if ( + n_col - (len(cols_to_del) + len(new_cols_to_del)) < tw + and len(new_cols_to_del) > 3 + and len(alt) > len(ref) + ): + new_cols_to_del = list(map(lambda x: [count_column[x], x], new_cols_to_del)) new_cols_to_del = sorted(new_cols_to_del, key=lambda x: [x[0], x[1]])[ - 0:len(new_cols_to_del) - 4] + 0 : len(new_cols_to_del) - 4 + ] new_cols_to_del = list(map(lambda x: x[1], new_cols_to_del)) cols_to_del = sorted(set(list(new_cols_to_del) + list(cols_to_del))) cols_to_del = list(set(cols_to_del)) @@ -671,25 +829,42 @@ def prepare_info_matrices_tabix(ref_seq, tumor_count_info, normal_count_info, re for i in range(n_col // 10): if i not in cols_to_del and sum(rcenter) > -3: ref_b = ref_array[i] - if tumor_matrix_[ref_b, i] > .8 * n_row: + if tumor_matrix_[ref_b, i] > 0.8 * n_row: cols_to_del.append(i) if n_col - len(cols_to_del) <= tw: break ii = n_col - i - 1 if ii not in cols_to_del and sum(rcenter) < 3: ref_b = ref_array[ii] - if tumor_matrix_[ref_b, ii] > .8 * n_row: + if tumor_matrix_[ref_b, ii] > 0.8 * n_row: cols_to_del.append(ii) if n_col - len(cols_to_del) <= tw: break if n_col - len(cols_to_del) > tw and len(largest_block) > 0: - block_len = (largest_block[1] - largest_block[0] + 1) + block_len = largest_block[1] - largest_block[0] + 1 if block_len > 2: - cols_to_del = sorted(list(set(cols_to_del) | - set(range(largest_block[0], largest_block[0] + min(n_col - len(cols_to_del) - tw, block_len - 3))))) + cols_to_del = sorted( + list( + set(cols_to_del) + | set( + range( + largest_block[0], + largest_block[0] + 
+ min(n_col - len(cols_to_del) - tw, block_len - 3), + ) + ) + ) + ) else: - cols_to_del = sorted(list(set(cols_to_del) | (set(np.where(count_column <= (min_ev_frac_per_col * n_row))[0]) & - set(range(largest_block[0], largest_block[1] + 1))))) + cols_to_del = sorted( + list( + set(cols_to_del) + | ( + set(np.where(count_column <= (min_ev_frac_per_col * n_row))[0]) + & set(range(largest_block[0], largest_block[1] + 1)) + ) + ) + ) cols_to_del.sort() for i, v in col_pos_map.items(): @@ -703,8 +878,7 @@ def prepare_info_matrices_tabix(ref_seq, tumor_count_info, normal_count_info, re rsc_tumor_matrix = np.delete(rsc_tumor_matrix_, cols_to_del, 1) tag_tumor_matrices = [] for iii in range(len(tag_tumor_matrices_)): - tag_tumor_matrices.append( - np.delete(tag_tumor_matrices_[iii], cols_to_del, 1)) + tag_tumor_matrices.append(np.delete(tag_tumor_matrices_[iii], cols_to_del, 1)) normal_matrix = np.delete(normal_matrix_, cols_to_del, 1) bq_normal_matrix = np.delete(bq_normal_matrix_, cols_to_del, 1) mq_normal_matrix = np.delete(mq_normal_matrix_, cols_to_del, 1) @@ -713,8 +887,7 @@ def prepare_info_matrices_tabix(ref_seq, tumor_count_info, normal_count_info, re rsc_normal_matrix = np.delete(rsc_normal_matrix_, cols_to_del, 1) tag_normal_matrices = [] for iii in range(len(tag_normal_matrices_)): - tag_normal_matrices.append( - np.delete(tag_normal_matrices_[iii], cols_to_del, 1)) + tag_normal_matrices.append(np.delete(tag_normal_matrices_[iii], cols_to_del, 1)) ref_array = np.delete(ref_array, cols_to_del, 0) ref_matrix = np.zeros(tumor_matrix.shape) @@ -722,97 +895,138 @@ def prepare_info_matrices_tabix(ref_seq, tumor_count_info, normal_count_info, re ncols = tumor_matrix.shape[1] if matrix_width >= ncols: - col_pos_map = {i: v + (matrix_width - ncols) // - 2 for i, v in col_pos_map.items()} + col_pos_map = { + i: v + (matrix_width - ncols) // 2 for i, v in col_pos_map.items() + } tumor_count_matrix = np.zeros((5, matrix_width)) - tumor_count_matrix[:, (matrix_width - ncols) // - 2:(matrix_width - ncols) // 2 + ncols] = tumor_matrix + tumor_count_matrix[ + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = tumor_matrix bq_tumor_count_matrix = np.zeros((5, matrix_width)) bq_tumor_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = bq_tumor_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = bq_tumor_matrix mq_tumor_count_matrix = np.zeros((5, matrix_width)) mq_tumor_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = mq_tumor_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = mq_tumor_matrix st_tumor_count_matrix = np.zeros((5, matrix_width)) st_tumor_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = st_tumor_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = st_tumor_matrix lsc_tumor_count_matrix = np.zeros((5, matrix_width)) lsc_tumor_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = lsc_tumor_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = lsc_tumor_matrix rsc_tumor_count_matrix = np.zeros((5, matrix_width)) rsc_tumor_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = rsc_tumor_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = rsc_tumor_matrix tag_tumor_count_matrices = [] for iii in range(len(tag_tumor_matrices)): 
tag_tumor_count_matrices.append(np.zeros((5, matrix_width))) tag_tumor_count_matrices[iii][ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = tag_tumor_matrices[iii] + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = tag_tumor_matrices[iii] normal_count_matrix = np.zeros((5, matrix_width)) normal_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = normal_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = normal_matrix bq_normal_count_matrix = np.zeros((5, matrix_width)) bq_normal_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = bq_normal_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = bq_normal_matrix mq_normal_count_matrix = np.zeros((5, matrix_width)) mq_normal_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = mq_normal_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = mq_normal_matrix st_normal_count_matrix = np.zeros((5, matrix_width)) st_normal_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = st_normal_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = st_normal_matrix lsc_normal_count_matrix = np.zeros((5, matrix_width)) lsc_normal_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = lsc_normal_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = lsc_normal_matrix rsc_normal_count_matrix = np.zeros((5, matrix_width)) rsc_normal_count_matrix[ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = rsc_normal_matrix + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = rsc_normal_matrix tag_normal_count_matrices = [] for iii in range(len(tag_normal_matrices)): tag_normal_count_matrices.append(np.zeros((5, matrix_width))) tag_normal_count_matrices[iii][ - :, (matrix_width - ncols) // 2:(matrix_width - ncols) // 2 + ncols] = tag_normal_matrices[iii] + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = tag_normal_matrices[iii] ref_count_matrix = np.zeros((5, matrix_width)) - ref_count_matrix[:, (matrix_width - ncols) // - 2:(matrix_width - ncols) // 2 + ncols] = ref_matrix + ref_count_matrix[ + :, (matrix_width - ncols) // 2 : (matrix_width - ncols) // 2 + ncols + ] = ref_matrix else: - col_pos_map = {i: int(round(v / float(ncols) * matrix_width)) - for i, v in col_pos_map.items()} - tumor_count_matrix = np.array(Image.fromarray( - tumor_matrix).resize((matrix_width, 5), 2)).astype(int) - bq_tumor_count_matrix = np.array(Image.fromarray( - bq_tumor_matrix).resize((matrix_width, 5), 2)).astype(int) - mq_tumor_count_matrix = np.array(Image.fromarray( - mq_tumor_matrix).resize((matrix_width, 5), 2)).astype(int) - st_tumor_count_matrix = np.array(Image.fromarray( - st_tumor_matrix).resize((matrix_width, 5), 2)).astype(int) - lsc_tumor_count_matrix = np.array(Image.fromarray( - lsc_tumor_matrix).resize((matrix_width, 5), 2)).astype(int) - rsc_tumor_count_matrix = np.array(Image.fromarray( - rsc_tumor_matrix).resize((matrix_width, 5), 2)).astype(int) + col_pos_map = { + i: int(round(v / float(ncols) * matrix_width)) + for i, v in col_pos_map.items() + } + tumor_count_matrix = np.array( + Image.fromarray(tumor_matrix).resize((matrix_width, 5), 2) + ).astype(int) + bq_tumor_count_matrix = np.array( + 
Image.fromarray(bq_tumor_matrix).resize((matrix_width, 5), 2) + ).astype(int) + mq_tumor_count_matrix = np.array( + Image.fromarray(mq_tumor_matrix).resize((matrix_width, 5), 2) + ).astype(int) + st_tumor_count_matrix = np.array( + Image.fromarray(st_tumor_matrix).resize((matrix_width, 5), 2) + ).astype(int) + lsc_tumor_count_matrix = np.array( + Image.fromarray(lsc_tumor_matrix).resize((matrix_width, 5), 2) + ).astype(int) + rsc_tumor_count_matrix = np.array( + Image.fromarray(rsc_tumor_matrix).resize((matrix_width, 5), 2) + ).astype(int) tag_tumor_count_matrices = [] for iii in range(len(tag_tumor_matrices)): tag_tumor_count_matrices.append( - np.array(Image.fromarray(tag_tumor_matrices[iii]).resize((matrix_width, 5), 2)).astype(int)) - - normal_count_matrix = np.array(Image.fromarray( - normal_matrix).resize((matrix_width, 5), 2)).astype(int) - bq_normal_count_matrix = np.array(Image.fromarray( - bq_normal_matrix).resize((matrix_width, 5), 2)).astype(int) - mq_normal_count_matrix = np.array(Image.fromarray( - mq_normal_matrix).resize((matrix_width, 5), 2)).astype(int) - st_normal_count_matrix = np.array(Image.fromarray( - st_normal_matrix).resize((matrix_width, 5), 2)).astype(int) - lsc_normal_count_matrix = np.array(Image.fromarray( - lsc_normal_matrix).resize((matrix_width, 5), 2)).astype(int) - rsc_normal_count_matrix = np.array(Image.fromarray( - rsc_normal_matrix).resize((matrix_width, 5), 2)).astype(int) + np.array( + Image.fromarray(tag_tumor_matrices[iii]).resize( + (matrix_width, 5), 2 + ) + ).astype(int) + ) + + normal_count_matrix = np.array( + Image.fromarray(normal_matrix).resize((matrix_width, 5), 2) + ).astype(int) + bq_normal_count_matrix = np.array( + Image.fromarray(bq_normal_matrix).resize((matrix_width, 5), 2) + ).astype(int) + mq_normal_count_matrix = np.array( + Image.fromarray(mq_normal_matrix).resize((matrix_width, 5), 2) + ).astype(int) + st_normal_count_matrix = np.array( + Image.fromarray(st_normal_matrix).resize((matrix_width, 5), 2) + ).astype(int) + lsc_normal_count_matrix = np.array( + Image.fromarray(lsc_normal_matrix).resize((matrix_width, 5), 2) + ).astype(int) + rsc_normal_count_matrix = np.array( + Image.fromarray(rsc_normal_matrix).resize((matrix_width, 5), 2) + ).astype(int) tag_normal_count_matrices = [] for iii in range(len(tag_normal_matrices)): tag_normal_count_matrices.append( - np.array(Image.fromarray(tag_normal_matrices[iii]).resize((matrix_width, 5), 2)).astype(int)) - ref_count_matrix = np.array(Image.fromarray( - ref_matrix).resize((matrix_width, 5), 2)).astype(int) + np.array( + Image.fromarray(tag_normal_matrices[iii]).resize( + (matrix_width, 5), 2 + ) + ).astype(int) + ) + ref_count_matrix = np.array( + Image.fromarray(ref_matrix).resize((matrix_width, 5), 2) + ).astype(int) if int(pos) + rcenter[0] not in col_pos_map: center = min(col_pos_map.values()) + rcenter[0] - 1 + rcenter[1] @@ -820,56 +1034,126 @@ def prepare_info_matrices_tabix(ref_seq, tumor_count_info, normal_count_info, re center = col_pos_map[int(pos) + rcenter[0]] + rcenter[1] if center > ref_count_matrix.shape[1] - 1: - center = min(max(0, min(col_pos_map.values()) + - rcenter[0] - 1 + rcenter[1]), ref_count_matrix.shape[1] - 1) + center = min( + max(0, min(col_pos_map.values()) + rcenter[0] - 1 + rcenter[1]), + ref_count_matrix.shape[1] - 1, + ) # logger.info("ttt-9") # logger.info(["rrr-10",time.time()-t1]) - return [tumor_matrix_, tumor_matrix, normal_matrix_, normal_matrix, ref_count_matrix, tumor_count_matrix, - bq_tumor_count_matrix, mq_tumor_count_matrix, 
st_tumor_count_matrix, lsc_tumor_count_matrix, rsc_tumor_count_matrix, - tag_tumor_count_matrices, normal_count_matrix, bq_normal_count_matrix, mq_normal_count_matrix, - st_normal_count_matrix, lsc_normal_count_matrix, rsc_normal_count_matrix, - tag_normal_count_matrices, center, rlen, col_pos_map] + return [ + tumor_matrix_, + tumor_matrix, + normal_matrix_, + normal_matrix, + ref_count_matrix, + tumor_count_matrix, + bq_tumor_count_matrix, + mq_tumor_count_matrix, + st_tumor_count_matrix, + lsc_tumor_count_matrix, + rsc_tumor_count_matrix, + tag_tumor_count_matrices, + normal_count_matrix, + bq_normal_count_matrix, + mq_normal_count_matrix, + st_normal_count_matrix, + lsc_normal_count_matrix, + rsc_normal_count_matrix, + tag_normal_count_matrices, + center, + rlen, + col_pos_map, + ] def prep_data_single_tabix(input_record): - ref_seq, tumor_count_info, normal_count_info, record, vartype, rlen, rcenter, ch_order, \ - matrix_base_pad, matrix_width, min_ev_frac_per_col, min_cov, ann, chrom_lengths, matrix_dtype, is_none = input_record + ( + ref_seq, + tumor_count_info, + normal_count_info, + record, + vartype, + rlen, + rcenter, + ch_order, + matrix_base_pad, + matrix_width, + min_ev_frac_per_col, + min_cov, + ann, + chrom_lengths, + matrix_dtype, + is_none, + ) = input_record thread_logger = logging.getLogger( - "{} ({})".format(prep_data_single_tabix.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + prep_data_single_tabix.__name__, multiprocessing.current_process().name + ) + ) try: chrom, pos, ref, alt = record[:4] pos = int(pos) # thread_logger.info("prep-1") - t1=time.time() - - matrices_info = prepare_info_matrices_tabix(ref_seq=ref_seq, - tumor_count_info=tumor_count_info, - normal_count_info=normal_count_info, record=record, rlen=rlen, rcenter=rcenter, - matrix_base_pad=matrix_base_pad, matrix_width=matrix_width, - min_ev_frac_per_col=min_ev_frac_per_col, - min_cov=min_cov, - chrom_lengths=chrom_lengths) + t1 = time.time() + + matrices_info = prepare_info_matrices_tabix( + ref_seq=ref_seq, + tumor_count_info=tumor_count_info, + normal_count_info=normal_count_info, + record=record, + rlen=rlen, + rcenter=rcenter, + matrix_base_pad=matrix_base_pad, + matrix_width=matrix_width, + min_ev_frac_per_col=min_ev_frac_per_col, + min_cov=min_cov, + chrom_lengths=chrom_lengths, + ) if matrices_info: - tumor_matrix_, tumor_matrix, normal_matrix_, normal_matrix, ref_count_matrix, tumor_count_matrix, \ - bq_tumor_count_matrix, mq_tumor_count_matrix, st_tumor_count_matrix, lsc_tumor_count_matrix, rsc_tumor_count_matrix, \ - tag_tumor_count_matrices, normal_count_matrix, bq_normal_count_matrix, mq_normal_count_matrix, st_normal_count_matrix, \ - lsc_normal_count_matrix, rsc_normal_count_matrix, tag_normal_count_matrices, center, rlen, col_pos_map = matrices_info + ( + tumor_matrix_, + tumor_matrix, + normal_matrix_, + normal_matrix, + ref_count_matrix, + tumor_count_matrix, + bq_tumor_count_matrix, + mq_tumor_count_matrix, + st_tumor_count_matrix, + lsc_tumor_count_matrix, + rsc_tumor_count_matrix, + tag_tumor_count_matrices, + normal_count_matrix, + bq_normal_count_matrix, + mq_normal_count_matrix, + st_normal_count_matrix, + lsc_normal_count_matrix, + rsc_normal_count_matrix, + tag_normal_count_matrices, + center, + rlen, + col_pos_map, + ) = matrices_info else: return [] - # thread_logger.info(["rrr-6",time.time()-t1]) - t1=time.time() + t1 = time.time() # thread_logger.info("prep-2") - candidate_mat = np.zeros((tumor_count_matrix.shape[0], tumor_count_matrix.shape[ - 1], 
13 + (len(tag_tumor_count_matrices) * 2))) + candidate_mat = np.zeros( + ( + tumor_count_matrix.shape[0], + tumor_count_matrix.shape[1], + 13 + (len(tag_tumor_count_matrices) * 2), + ) + ) candidate_mat[:, :, 0] = ref_count_matrix candidate_mat[:, :, 1] = tumor_count_matrix candidate_mat[:, :, 2] = normal_count_matrix @@ -885,8 +1169,7 @@ def prep_data_single_tabix(input_record): candidate_mat[:, :, 12] = rsc_normal_count_matrix for iii in range(len(tag_tumor_count_matrices)): candidate_mat[:, :, 13 + (iii * 2)] = tag_tumor_count_matrices[iii] - candidate_mat[:, :, 13 + (iii * 2) + - 1] = tag_normal_count_matrices[iii] + candidate_mat[:, :, 13 + (iii * 2) + 1] = tag_normal_count_matrices[iii] tumor_cov = int(round(max(np.sum(tumor_count_matrix, 0)))) normal_cov = int(round(max(np.sum(normal_count_matrix, 0)))) @@ -896,61 +1179,101 @@ def prep_data_single_tabix(input_record): max_norm = 65535.0 else: logger.info( - "Wrong matrix_dtype {}. Choices are {}".format(matrix_dtype, MAT_DTYPES)) - - candidate_mat[:, :, 0] = candidate_mat[ - :, :, 0] / (max(np.max(ref_count_matrix), np.max(tumor_count_matrix)) + 0.00001) * max_norm - candidate_mat[:, :, 1] = candidate_mat[:, :, 1] / \ - (np.max(tumor_count_matrix) + 0.00001) * max_norm - candidate_mat[:, :, 2] = candidate_mat[:, :, 2] / \ - (np.max(normal_count_matrix) + 0.00001) * max_norm - candidate_mat[:, :, 3] = candidate_mat[:, :, 3] / \ - (max(np.max(bq_tumor_count_matrix), 41.0)) * max_norm - candidate_mat[:, :, 4] = candidate_mat[:, :, 4] / \ - (max(np.max(bq_normal_count_matrix), 41.0)) * max_norm - candidate_mat[:, :, 5] = candidate_mat[:, :, 5] / \ - (max(np.max(mq_tumor_count_matrix), 70.0)) * max_norm - candidate_mat[:, :, 6] = candidate_mat[:, :, 6] / \ - (max(np.max(mq_normal_count_matrix), 70.0)) * max_norm - candidate_mat[:, :, 7] = candidate_mat[:, :, 7] / \ - (np.max(tumor_count_matrix) + 0.00001) * max_norm - candidate_mat[:, :, 8] = candidate_mat[:, :, 8] / \ - (np.max(normal_count_matrix) + 0.00001) * max_norm - candidate_mat[:, :, 9] = candidate_mat[:, :, 9] / \ - (np.max(tumor_count_matrix) + 0.00001) * max_norm - candidate_mat[:, :, 10] = candidate_mat[:, :, 10] / \ - (np.max(normal_count_matrix) + 0.00001) * max_norm - candidate_mat[:, :, 11] = candidate_mat[:, :, 11] / \ - (np.max(tumor_count_matrix) + 0.00001) * max_norm - candidate_mat[:, :, 12] = candidate_mat[:, :, 12] / \ - (np.max(normal_count_matrix) + 0.00001) * max_norm + "Wrong matrix_dtype {}. 
Choices are {}".format(matrix_dtype, MAT_DTYPES) + ) + + candidate_mat[:, :, 0] = ( + candidate_mat[:, :, 0] + / (max(np.max(ref_count_matrix), np.max(tumor_count_matrix)) + 0.00001) + * max_norm + ) + candidate_mat[:, :, 1] = ( + candidate_mat[:, :, 1] / (np.max(tumor_count_matrix) + 0.00001) * max_norm + ) + candidate_mat[:, :, 2] = ( + candidate_mat[:, :, 2] / (np.max(normal_count_matrix) + 0.00001) * max_norm + ) + candidate_mat[:, :, 3] = ( + candidate_mat[:, :, 3] + / (max(np.max(bq_tumor_count_matrix), 41.0)) + * max_norm + ) + candidate_mat[:, :, 4] = ( + candidate_mat[:, :, 4] + / (max(np.max(bq_normal_count_matrix), 41.0)) + * max_norm + ) + candidate_mat[:, :, 5] = ( + candidate_mat[:, :, 5] + / (max(np.max(mq_tumor_count_matrix), 70.0)) + * max_norm + ) + candidate_mat[:, :, 6] = ( + candidate_mat[:, :, 6] + / (max(np.max(mq_normal_count_matrix), 70.0)) + * max_norm + ) + candidate_mat[:, :, 7] = ( + candidate_mat[:, :, 7] / (np.max(tumor_count_matrix) + 0.00001) * max_norm + ) + candidate_mat[:, :, 8] = ( + candidate_mat[:, :, 8] / (np.max(normal_count_matrix) + 0.00001) * max_norm + ) + candidate_mat[:, :, 9] = ( + candidate_mat[:, :, 9] / (np.max(tumor_count_matrix) + 0.00001) * max_norm + ) + candidate_mat[:, :, 10] = ( + candidate_mat[:, :, 10] / (np.max(normal_count_matrix) + 0.00001) * max_norm + ) + candidate_mat[:, :, 11] = ( + candidate_mat[:, :, 11] / (np.max(tumor_count_matrix) + 0.00001) * max_norm + ) + candidate_mat[:, :, 12] = ( + candidate_mat[:, :, 12] / (np.max(normal_count_matrix) + 0.00001) * max_norm + ) for iii in range(len(tag_tumor_count_matrices)): - candidate_mat[:, :, 13 + (iii * 2)] = candidate_mat[:, :, 13 + (iii * 2)] / ( - max(np.max(tag_tumor_count_matrices[iii]), 100.0)) * max_norm - candidate_mat[:, :, 13 + (iii * 2) + 1] = candidate_mat[:, :, 13 + ( - iii * 2) + 1] / (max(np.max(tag_normal_count_matrices[iii]), 100.0)) * max_norm + candidate_mat[:, :, 13 + (iii * 2)] = ( + candidate_mat[:, :, 13 + (iii * 2)] + / (max(np.max(tag_tumor_count_matrices[iii]), 100.0)) + * max_norm + ) + candidate_mat[:, :, 13 + (iii * 2) + 1] = ( + candidate_mat[:, :, 13 + (iii * 2) + 1] + / (max(np.max(tag_normal_count_matrices[iii]), 100.0)) + * max_norm + ) # thread_logger.info("prep-3") if matrix_dtype == "uint8": - candidate_mat = np.maximum(0, np.minimum( - candidate_mat, max_norm)).astype(np.uint8) + candidate_mat = np.maximum(0, np.minimum(candidate_mat, max_norm)).astype( + np.uint8 + ) elif matrix_dtype == "uint16": - candidate_mat = np.maximum(0, np.minimum( - candidate_mat, max_norm)).astype(np.uint16) + candidate_mat = np.maximum(0, np.minimum(candidate_mat, max_norm)).astype( + np.uint16 + ) else: logger.info( - "Wrong matrix_dtype {}. Choices are {}".format(matrix_dtype, MAT_DTYPES)) + "Wrong matrix_dtype {}. 
Choices are {}".format(matrix_dtype, MAT_DTYPES) + ) raise Exception - tag = "{}.{}.{}.{}.{}.{}.{}.{}.{}".format(ch_order, pos, ref[0:55], alt[ - 0:55], vartype, center, rlen, tumor_cov, normal_cov) + tag = "{}.{}.{}.{}.{}.{}.{}.{}.{}".format( + ch_order, + pos, + ref[0:55], + alt[0:55], + vartype, + center, + rlen, + tumor_cov, + normal_cov, + ) # thread_logger.info("prep-4") - candidate_mat = base64.b64encode( - zlib.compress(candidate_mat)).decode('ascii') + candidate_mat = base64.b64encode(zlib.compress(candidate_mat)).decode("ascii") # thread_logger.info("prep-5") # thread_logger.info(["rrr-7",time.time()-t1]) - return tag, candidate_mat, record[0:4], ann, is_none except Exception as ex: thread_logger.error(traceback.format_exc()) @@ -968,13 +1291,12 @@ def push_lr(fasta_file, record, left_right_both): new_pos = pos new_ref = ref new_alt = alt - while(new_pos > 1): - l_base = fasta_file.fetch( - (chrom), new_pos - 2, new_pos - 1).upper() + while new_pos > 1: + l_base = fasta_file.fetch((chrom), new_pos - 2, new_pos - 1).upper() new_ref = l_base + new_ref new_alt = l_base + new_alt new_pos -= 1 - while(len(new_alt) > 1 and len(new_ref) > 1): + while len(new_alt) > 1 and len(new_ref) > 1: if new_alt[-1] == new_ref[-1]: new_alt = new_alt[:-1] new_ref = new_ref[:-1] @@ -994,13 +1316,12 @@ def push_lr(fasta_file, record, left_right_both): new_ref = ref new_alt = alt max_pos = fasta_file.lengths[fasta_file.references.index(chrom)] - while(new_pos < max_pos): - r_base = fasta_file.fetch( - (chrom), new_pos - 1, new_pos).upper() + while new_pos < max_pos: + r_base = fasta_file.fetch((chrom), new_pos - 1, new_pos).upper() new_ref = new_ref + r_base new_alt = new_alt + r_base new_pos += 1 - while(len(new_alt) > 1 and len(new_ref) > 1): + while len(new_alt) > 1 and len(new_ref) > 1: if new_alt[0] == new_ref[0] and new_alt[1] == new_ref[1]: new_alt = new_alt[1:] new_ref = new_ref[1:] @@ -1012,16 +1333,22 @@ def push_lr(fasta_file, record, left_right_both): new_pos -= 1 break else: - eqs.append([chrom, new_pos - len(new_ref), - new_ref, new_alt] + record[4:]) - record = [chrom, new_pos - - len(new_ref), new_ref, new_alt] + record[4:] - eqs = list(map(lambda x: eqs[x], dict( - map(lambda x: ["_".join(map(str, x[1])), x[0]], enumerate(eqs))).values())) + eqs.append( + [chrom, new_pos - len(new_ref), new_ref, new_alt] + record[4:] + ) + record = [chrom, new_pos - len(new_ref), new_ref, new_alt] + record[4:] + eqs = list( + map( + lambda x: eqs[x], + dict( + map(lambda x: ["_".join(map(str, x[1])), x[0]], enumerate(eqs)) + ).values(), + ) + ) for eq in eqs: c, p, r, a = eq[0:4] - assert(fasta_file.fetch((c), p - 1, p - 1 + len(r)).upper() == r) + assert fasta_file.fetch((c), p - 1, p - 1 + len(r)).upper() == r return record, eqs @@ -1035,13 +1362,12 @@ def push_left(fasta_file, record): new_pos = pos new_ref = ref new_alt = alt - while(new_pos > 1): - l_base = fasta_file.fetch( - (chrom), new_pos - 2, new_pos - 1).upper() + while new_pos > 1: + l_base = fasta_file.fetch((chrom), new_pos - 2, new_pos - 1).upper() new_ref = l_base + new_ref new_alt = l_base + new_alt new_pos -= 1 - while(len(new_alt) > 1 and len(new_ref) > 1): + while len(new_alt) > 1 and len(new_ref) > 1: if new_alt[-1] == new_ref[-1]: new_alt = new_alt[:-1] new_ref = new_ref[:-1] @@ -1089,14 +1415,14 @@ def merge_records(fasta_file, records): return None b = pos_ - 1 + len(ref_) - while(len(alt2_) > 1 and len(ref2_) > 1): + while len(alt2_) > 1 and len(ref2_) > 1: if alt2_[0] == ref2_[0] and (alt2_[1] == ref2_[1] or len(alt2_) == 
len(ref2_)):
                 alt2_ = alt2_[1:]
                 ref2_ = ref2_[1:]
                 pos_m += 1
             else:
                 break
-        while(len(alt2_) > 1 and len(ref2_) > 1):
+        while len(alt2_) > 1 and len(ref2_) > 1:
             if alt2_[-1] == ref2_[-1]:
                 alt2_ = alt2_[:-1]
                 ref2_ = ref2_[:-1]
@@ -1114,16 +1440,18 @@ def is_part_of(record1, record2):
         return False
     vartype1 = get_type(ref1, alt1)
     vartype2 = get_type(ref2, alt2)
-    if (vartype1 == "SNP" and vartype2 == "DEL"):
+    if vartype1 == "SNP" and vartype2 == "DEL":
         if pos2 < pos1 < pos2 + len(ref2):
             return True
-    elif (vartype2 == "SNP" and vartype1 == "DEL"):
+    elif vartype2 == "SNP" and vartype1 == "DEL":
         if pos1 < pos2 < pos1 + len(ref1):
             return True
     elif vartype1 == vartype2:
         if pos1 == pos2:
             return True
-        elif vartype1 == "DEL" and set(range(pos1 + 1, pos1 + len(ref1))) & set(range(pos2 + 1, pos2 + len(ref2))):
+        elif vartype1 == "DEL" and set(range(pos1 + 1, pos1 + len(ref1))) & set(
+            range(pos2 + 1, pos2 + len(ref2))
+        ):
             return True
     return False

@@ -1132,7 +1460,7 @@ def find_i_center(ref, alt):
     logger = logging.getLogger(find_i_center.__name__)
     i_center = 0
     if len(alt) != len(ref):
-        while(min(len(alt), len(ref)) > i_center and alt[i_center] == ref[i_center]):
+        while min(len(alt), len(ref)) > i_center and alt[i_center] == ref[i_center]:
             i_center += 1
     return [0, i_center] if (len(ref) < len(alt)) else [i_center, 0]

@@ -1140,12 +1468,15 @@ def find_len(ref, alt):
     logger = logging.getLogger(find_len.__name__)
     i_ = 0
-    while(min(len(alt), len(ref)) > i_ and alt[i_] == ref[i_]):
+    while min(len(alt), len(ref)) > i_ and alt[i_] == ref[i_]:
         i_ += 1
     ref_ = ref[i_:]
     alt_ = alt[i_:]

     i_ = 0
-    while (min(len(alt_), len(ref_)) > i_ and alt_[len(alt_) - i_ - 1] == ref_[len(ref_) - i_ - 1]):
+    while (
+        min(len(alt_), len(ref_)) > i_
+        and alt_[len(alt_) - i_ - 1] == ref_[len(ref_) - i_ - 1]
+    ):
         i_ += 1
     if i_ > 0:
         ref_ = ref_[:-i_]
@@ -1153,8 +1484,7 @@ def find_len(ref, alt):
     return max(len(ref_), len(alt_))


-def keep_in_region(input_file, region_bed,
-                   output_fn):
+def keep_in_region(input_file, region_bed, output_fn):
     logger = logging.getLogger(keep_in_region.__name__)
     i = 0
     tmp_ = get_tmp_file()
@@ -1162,18 +1492,16 @@ def keep_in_region(input_file, region_bed,
         for line in skip_empty(i_f):
             fields = line.strip().split()
             chrom, start, end = fields[0:3]
-            o_f.write(
-                "\t".join([chrom, start, str(int(start) + 1), str(i)]) + "\n")
+            o_f.write("\t".join([chrom, start, str(int(start) + 1), str(i)]) + "\n")
             i += 1

     good_i = set([])
-    tmp_ = bedtools_window(
-        tmp_, region_bed, args=" -w 1", run_logger=logger)
+    tmp_ = bedtools_window(tmp_, region_bed, args=" -w 1", run_logger=logger)
     with open(tmp_) as i_f:
         for line in skip_empty(i_f):
             fields = line.strip().split()
             chrom, start, end, i_, chrom_, start_, end_ = fields[0:7]
-            assert(chrom == chrom_)
+            assert chrom == chrom_
             if int(start_) <= int(start) <= int(end_):
                 good_i.add(int(i_))
     i = 0
@@ -1189,61 +1517,132 @@ def find_records(input_record):
-    work, split_region_file, truth_vcf_file, pred_vcf_file, ref_file, ensemble_bed, num_ens_features, strict_labeling, work_index = input_record
+    (
+        work,
+        split_region_file,
+        truth_vcf_file,
+        pred_vcf_file,
+        ref_file,
+        ensemble_bed,
+        num_ens_features,
+        strict_labeling,
+        work_index,
+    ) = input_record
     thread_logger = logging.getLogger(
-        "{} ({})".format(find_records.__name__, multiprocessing.current_process().name))
+        "{} ({})".format(find_records.__name__, multiprocessing.current_process().name)
+    )
     try:
-        thread_logger.info(
-            "Start 
find_records for worker {}".format(work_index)) + thread_logger.info("Start find_records for worker {}".format(work_index)) split_bed = bedtools_slop( - split_region_file, ref_file + ".fai", args=" -b 5", run_logger=thread_logger) - split_truth_vcf_file = os.path.join( - work, "truth_{}.vcf".format(work_index)) - split_pred_vcf_file = os.path.join( - work, "pred_{}.vcf".format(work_index)) + split_region_file, ref_file + ".fai", args=" -b 5", run_logger=thread_logger + ) + split_truth_vcf_file = os.path.join(work, "truth_{}.vcf".format(work_index)) + split_pred_vcf_file = os.path.join(work, "pred_{}.vcf".format(work_index)) split_ensemble_bed_file = os.path.join( - work, "ensemble_{}.bed".format(work_index)) + work, "ensemble_{}.bed".format(work_index) + ) split_missed_ensemble_bed_file = os.path.join( - work, "missed_ensemble_{}.bed".format(work_index)) + work, "missed_ensemble_{}.bed".format(work_index) + ) split_pred_with_missed_file = os.path.join( - work, "pred_with_missed_{}.bed".format(work_index)) + work, "pred_with_missed_{}.bed".format(work_index) + ) split_in_ensemble_bed = os.path.join( - work, "in_ensemble_{}.bed".format(work_index)) + work, "in_ensemble_{}.bed".format(work_index) + ) bedtools_intersect( - truth_vcf_file, split_bed, args=" -u", output_fn=split_truth_vcf_file, run_logger=thread_logger) + truth_vcf_file, + split_bed, + args=" -u", + output_fn=split_truth_vcf_file, + run_logger=thread_logger, + ) tmp_ = get_tmp_file() bedtools_intersect( - pred_vcf_file, split_bed, args=" -u", output_fn=tmp_, run_logger=thread_logger) - keep_in_region(input_file=tmp_, region_bed=split_region_file, - output_fn=split_pred_vcf_file) + pred_vcf_file, + split_bed, + args=" -u", + output_fn=tmp_, + run_logger=thread_logger, + ) + keep_in_region( + input_file=tmp_, region_bed=split_region_file, output_fn=split_pred_vcf_file + ) if ensemble_bed: tmp_ = get_tmp_file() bedtools_intersect( - ensemble_bed, split_bed, args=" -u", output_fn=tmp_, run_logger=thread_logger) - keep_in_region(input_file=tmp_, region_bed=split_region_file, - output_fn=split_ensemble_bed_file) + ensemble_bed, + split_bed, + args=" -u", + output_fn=tmp_, + run_logger=thread_logger, + ) + keep_in_region( + input_file=tmp_, + region_bed=split_region_file, + output_fn=split_ensemble_bed_file, + ) tmp_ = bedtools_window( - split_ensemble_bed_file, split_pred_vcf_file, args=" -w 5 -v", run_logger=thread_logger) - - vcf_2_bed(tmp_, split_missed_ensemble_bed_file, add_fields=[".", - ".", ".", ".", "."]) + split_ensemble_bed_file, + split_pred_vcf_file, + args=" -w 5 -v", + run_logger=thread_logger, + ) + + vcf_2_bed( + tmp_, + split_missed_ensemble_bed_file, + add_fields=[".", ".", ".", ".", "."], + ) concatenate_vcfs( - [split_pred_vcf_file, split_missed_ensemble_bed_file], split_pred_with_missed_file) + [split_pred_vcf_file, split_missed_ensemble_bed_file], + split_pred_with_missed_file, + ) tmp_ = get_tmp_file() with open(split_pred_with_missed_file) as i_f, open(tmp_, "w") as o_f: for line in skip_empty(i_f): x = line.strip().split("\t") - o_f.write("\t".join( - list(map(str, [x[0], x[1], ".", x[3], x[4], ".", ".", ".", ".", "."]))) + "\n") - bedtools_sort(tmp_, output_fn=split_pred_with_missed_file, - run_logger=thread_logger) + o_f.write( + "\t".join( + list( + map( + str, + [ + x[0], + x[1], + ".", + x[3], + x[4], + ".", + ".", + ".", + ".", + ".", + ], + ) + ) + ) + + "\n" + ) + bedtools_sort( + tmp_, output_fn=split_pred_with_missed_file, run_logger=thread_logger + ) not_in_ensemble_bed = bedtools_window( - 
split_pred_with_missed_file, split_ensemble_bed_file, args=" -w 1 -v", run_logger=thread_logger) + split_pred_with_missed_file, + split_ensemble_bed_file, + args=" -w 1 -v", + run_logger=thread_logger, + ) in_ensemble_bed = bedtools_window( - split_pred_with_missed_file, split_ensemble_bed_file, output_fn=split_in_ensemble_bed, args=" -w 1", run_logger=thread_logger) + split_pred_with_missed_file, + split_ensemble_bed_file, + output_fn=split_in_ensemble_bed, + args=" -w 1", + run_logger=thread_logger, + ) records = [] i = 0 @@ -1253,8 +1652,12 @@ def find_records(input_record): with open(not_in_ensemble_bed) as ni_f: for line in skip_empty(ni_f): record = line.strip().split("\t") - chrom, pos, ref, alt = [str(record[0]), int( - record[1]), record[3], record[4]] + chrom, pos, ref, alt = [ + str(record[0]), + int(record[1]), + record[3], + record[4], + ] r_ = [] if len(ref) == len(alt) and len(ref) > 1: for ii in range(len(ref)): @@ -1275,8 +1678,12 @@ def find_records(input_record): for line in skip_empty(ni_f): record = line.strip().split("\t") if curren_pos_records: - if (record[0] == curren_pos_records[0][0] and record[1] == curren_pos_records[0][1] and - record[3] == curren_pos_records[0][3] and record[4] == curren_pos_records[0][4]): + if ( + record[0] == curren_pos_records[0][0] + and record[1] == curren_pos_records[0][1] + and record[3] == curren_pos_records[0][3] + and record[4] == curren_pos_records[0][4] + ): curren_pos_records.append(record) else: emit_flag = True @@ -1287,18 +1694,25 @@ def find_records(input_record): if curren_pos_records: rrs = [] for record_ in curren_pos_records: - chrom, pos, ref, alt = [str(record_[0]), int( - record_[1]), record_[3], record_[4]] - ens_chrom, ens_pos, ens_ref, ens_alt = [str(record_[10]), int( - record_[11]), record_[13], record_[14]] + chrom, pos, ref, alt = [ + str(record_[0]), + int(record_[1]), + record_[3], + record_[4], + ] + ens_chrom, ens_pos, ens_ref, ens_alt = [ + str(record_[10]), + int(record_[11]), + record_[13], + record_[14], + ] r_ = [] if len(ref) == len(alt) and len(ref) > 1: for ii in range(len(ref)): ref_ = ref[ii] alt_ = alt[ii] if ref_ != alt_: - r_.append( - [chrom, pos + ii, ref_, alt_]) + r_.append([chrom, pos + ii, ref_, alt_]) else: r_ = [[chrom, pos, ref, alt]] @@ -1308,27 +1722,41 @@ def find_records(input_record): if ref == ens_ref and alt == ens_alt: ann = record_[15:] var_match = True - elif (len(ref) > len(alt) and len(ens_ref) > len(ens_alt) and - (alt) == (ens_alt)): - if ((len(ref) > len(ens_ref) and ref[0:len(ens_ref)] == ens_ref) or ( - len(ens_ref) > len(ref) and ens_ref[0:len(ref)] == ref)): + elif ( + len(ref) > len(alt) + and len(ens_ref) > len(ens_alt) + and (alt) == (ens_alt) + ): + if ( + len(ref) > len(ens_ref) + and ref[0 : len(ens_ref)] == ens_ref + ) or ( + len(ens_ref) > len(ref) + and ens_ref[0 : len(ref)] == ref + ): ann = record_[15:] - elif (len(ref) < len(alt) and len(ens_ref) < len(ens_alt) and - (ref) == (ens_ref)): - if ((len(alt) > len(ens_alt) and alt[0:len(ens_alt)] == ens_alt) or ( - len(ens_alt) > len(alt) and ens_alt[0:len(alt)] == alt)): + elif ( + len(ref) < len(alt) + and len(ens_ref) < len(ens_alt) + and (ref) == (ens_ref) + ): + if ( + len(alt) > len(ens_alt) + and alt[0 : len(ens_alt)] == ens_alt + ) or ( + len(ens_alt) > len(alt) + and ens_alt[0 : len(alt)] == alt + ): ann = record_[15:] if ann: ann = list(map(float, ann)) rrs.append([r_, ann, var_match]) has_var_match = sum(map(lambda x: x[2], rrs)) if has_var_match: - rrs = list( - filter(lambda x: x[2], rrs))[0:1] + 
rrs = list(filter(lambda x: x[2], rrs))[0:1] max_ann = max(map(lambda x: sum(x[1]), rrs)) if max_ann > 0: - rrs = list( - filter(lambda x: sum(x[1]) > 0, rrs)) + rrs = list(filter(lambda x: sum(x[1]) > 0, rrs)) elif max_ann == 0: rrs = rrs[0:1] for r_, ann, _ in rrs: @@ -1341,18 +1769,25 @@ def find_records(input_record): if curren_pos_records: rrs = [] for record_ in curren_pos_records: - chrom, pos, ref, alt = [str(record_[0]), int( - record_[1]), record_[3], record_[4]] - ens_chrom, ens_pos, ens_ref, ens_alt = [str(record_[10]), int( - record_[11]), record_[13], record_[14]] + chrom, pos, ref, alt = [ + str(record_[0]), + int(record_[1]), + record_[3], + record_[4], + ] + ens_chrom, ens_pos, ens_ref, ens_alt = [ + str(record_[10]), + int(record_[11]), + record_[13], + record_[14], + ] r_ = [] if len(ref) == len(alt) and len(ref) > 1: for ii in range(len(ref)): ref_ = ref[ii] alt_ = alt[ii] if ref_ != alt_: - r_.append( - [chrom, pos + ii, ref_, alt_]) + r_.append([chrom, pos + ii, ref_, alt_]) else: r_ = [[chrom, pos, ref, alt]] @@ -1362,23 +1797,38 @@ def find_records(input_record): if ref == ens_ref and alt == ens_alt: ann = record_[15:] var_match = True - elif (len(ref) > len(alt) and len(ens_ref) > len(ens_alt) and - (alt) == (ens_alt)): - if ((len(ref) > len(ens_ref) and ref[0:len(ens_ref)] == ens_ref) or ( - len(ens_ref) > len(ref) and ens_ref[0:len(ref)] == ref)): + elif ( + len(ref) > len(alt) + and len(ens_ref) > len(ens_alt) + and (alt) == (ens_alt) + ): + if ( + len(ref) > len(ens_ref) + and ref[0 : len(ens_ref)] == ens_ref + ) or ( + len(ens_ref) > len(ref) + and ens_ref[0 : len(ref)] == ref + ): ann = record_[15:] - elif (len(ref) < len(alt) and len(ens_ref) < len(ens_alt) and - (ref) == (ens_ref)): - if ((len(alt) > len(ens_alt) and alt[0:len(ens_alt)] == ens_alt) or ( - len(ens_alt) > len(alt) and ens_alt[0:len(alt)] == alt)): + elif ( + len(ref) < len(alt) + and len(ens_ref) < len(ens_alt) + and (ref) == (ens_ref) + ): + if ( + len(alt) > len(ens_alt) + and alt[0 : len(ens_alt)] == ens_alt + ) or ( + len(ens_alt) > len(alt) + and ens_alt[0 : len(alt)] == alt + ): ann = record_[15:] if ann: ann = list(map(float, ann)) rrs.append([r_, ann, var_match]) has_var_match = sum(map(lambda x: x[2], rrs)) if has_var_match: - rrs = list( - filter(lambda x: x[2], rrs))[0:1] + rrs = list(filter(lambda x: x[2], rrs))[0:1] max_ann = max(map(lambda x: sum(x[1]), rrs)) if max_ann > 0: rrs = list(filter(lambda x: sum(x[1]) > 0, rrs)) @@ -1391,11 +1841,15 @@ def find_records(input_record): i += 1 else: - with open(split_pred_vcf_file, 'r') as vcf_reader: + with open(split_pred_vcf_file, "r") as vcf_reader: for line in skip_empty(vcf_reader): record = line.strip().split() - chrom, pos, ref, alt = [record[0], int( - record[1]), record[3], record[4]] + chrom, pos, ref, alt = [ + record[0], + int(record[1]), + record[3], + record[4], + ] r_ = [] if len(ref) == len(alt) and len(ref) > 1: for ii in range(len(ref)): @@ -1414,18 +1868,25 @@ def find_records(input_record): with open(records_bed, "w") as r_b: for x in records: r_b.write( - "\t".join(map(str, [x[0], x[1], x[1] + len(x[2]), x[2], x[3], x[4]])) + "\n") + "\t".join( + map(str, [x[0], x[1], x[1] + len(x[2]), x[2], x[3], x[4]]) + ) + + "\n" + ) truth_records = [] i = 0 - with open(split_truth_vcf_file, 'r') as vcf_reader: + with open(split_truth_vcf_file, "r") as vcf_reader: for line in skip_empty(vcf_reader): record = line.strip().split() pos = int(record[1]) - if len(record[3]) != len(record[4]) and min(len(record[3]), len(record[4])) > 
0 and record[3][0] != record[4][0]: + if ( + len(record[3]) != len(record[4]) + and min(len(record[3]), len(record[4])) > 0 + and record[3][0] != record[4][0] + ): if pos > 1: - l_base = fasta_file.fetch( - record[0], pos - 2, pos - 1).upper() + l_base = fasta_file.fetch(record[0], pos - 2, pos - 1).upper() record[3] = l_base + record[3] record[4] = l_base + record[4] pos -= 1 @@ -1439,10 +1900,15 @@ def find_records(input_record): with open(truth_bed, "w") as t_b: for x in truth_records: t_b.write( - "\t".join(map(str, [x[0], x[1], x[1] + len(x[2]), x[2], x[3], x[4]])) + "\n") + "\t".join( + map(str, [x[0], x[1], x[1] + len(x[2]), x[2], x[3], x[4]]) + ) + + "\n" + ) none_records_0 = bedtools_window( - records_bed, truth_bed, args=" -w 5 -v", run_logger=thread_logger) + records_bed, truth_bed, args=" -w 5 -v", run_logger=thread_logger + ) none_records_ids = [] with open(none_records_0) as i_f: for line in skip_empty(i_f): @@ -1450,7 +1916,8 @@ def find_records(input_record): none_records_ids.append(int(x[5])) other_records = bedtools_window( - records_bed, truth_bed, args=" -w 5", run_logger=thread_logger) + records_bed, truth_bed, args=" -w 5", run_logger=thread_logger + ) map_pred_2_truth = {} map_truth_2_pred = {} @@ -1478,7 +1945,12 @@ def find_records(input_record): truth_record = truth_records[i] for j in js: record = records[j] - if record[0] == truth_record[0] and record[1] == truth_record[1] and record[2] == truth_record[2] and record[3] == truth_record[3]: + if ( + record[0] == truth_record[0] + and record[1] == truth_record[1] + and record[2] == truth_record[2] + and record[3] == truth_record[3] + ): assert int(record[4]) == j vartype = get_type(record[2], record[3]) if j not in good_records[vartype]: @@ -1490,8 +1962,9 @@ def find_records(input_record): perfect_t_idx.add(i) good_records_idx = [i for w in list(good_records.values()) for i in w] - remained_idx = sorted(set(range(len(records))) - - (set(good_records_idx) | set(none_records_ids))) + remained_idx = sorted( + set(range(len(records))) - (set(good_records_idx) | set(none_records_ids)) + ) done_js = list(good_records_idx) for i, js in map_truth_2_pred.items(): truth_record = truth_records[i] @@ -1507,16 +1980,15 @@ def find_records(input_record): for n_merge_i in range(0, len(i_s) - idx_ii): if done: break - t_i = i_s[idx_ii:idx_ii + n_merge_i + 1] + t_i = i_s[idx_ii : idx_ii + n_merge_i + 1] t_ = [truth_records[iii] for iii in t_i] mt = merge_records(fasta_file, t_) if mt: mt2, eqs2 = push_lr(fasta_file, mt, 2) - eqs2 = list( - map(lambda x: "_".join(map(str, x[0:4])), eqs2)) + eqs2 = list(map(lambda x: "_".join(map(str, x[0:4])), eqs2)) for idx_jj, jj in enumerate(js): for n_merge_j in range(0, len(js) - idx_jj): - r_j = js[idx_jj:idx_jj + n_merge_j + 1] + r_j = js[idx_jj : idx_jj + n_merge_j + 1] if set(r_j) & set(done_js_): continue r_ = [records[jjj] for jjj in r_j] @@ -1527,11 +1999,9 @@ def find_records(input_record): if record_str in eqs2: for j in r_j: record = records[j] - vartype = get_type( - record[2], record[3]) + vartype = get_type(record[2], record[3]) pos, ref, alt = record[1:4] - record_center[ - j] = find_i_center(ref, alt) + record_center[j] = find_i_center(ref, alt) record_len[j] = find_len(ref, alt) good_records[vartype].append(j) vtype[j] = vartype @@ -1544,8 +2014,9 @@ def find_records(input_record): perfect_idx = [i for w in list(good_records.values()) for i in w] good_records_idx = [i for w in list(good_records.values()) for i in w] - remained_idx = sorted(set(range(len(records))) - - 
(set(good_records_idx) | set(none_records_ids))) + remained_idx = sorted( + set(range(len(records))) - (set(good_records_idx) | set(none_records_ids)) + ) for j in remained_idx: record = records[j] pos, ref, alt = record[1:4] @@ -1573,7 +2044,8 @@ def find_records(input_record): break if not done: p_s = set([jj for i in i_s for jj in map_truth_2_pred[i]]) & set( - perfect_idx) + perfect_idx + ) for p in p_s: ref_p, alt_p = records[p][2:4] if not strict_labeling: @@ -1593,8 +2065,9 @@ def find_records(input_record): break good_records_idx = [i for w in list(good_records.values()) for i in w] - remained_idx = sorted(set(range(len(records))) - - (set(good_records_idx) | set(none_records_ids))) + remained_idx = sorted( + set(range(len(records))) - (set(good_records_idx) | set(none_records_ids)) + ) for i, js in map_truth_2_pred.items(): truth_record = truth_records[i] if set(js) & set(good_records_idx): @@ -1605,7 +2078,9 @@ def find_records(input_record): if len(js) == 2 and vartype_t == "SNP": vartype0 = get_type(records[js[0]][2], records[js[0]][3]) vartype1 = get_type(records[js[1]][2], records[js[1]][3]) - if (vartype0 == "DEL" and vartype1 == "INS") or (vartype1 == "DEL" and vartype0 == "INS"): + if (vartype0 == "DEL" and vartype1 == "INS") or ( + vartype1 == "DEL" and vartype0 == "INS" + ): for j in js: record = records[j] pos, ref, alt = record[1:4] @@ -1616,8 +2091,9 @@ def find_records(input_record): vtype[j] = vartype good_records_idx = [i for w in list(good_records.values()) for i in w] - remained_idx = sorted(set(range(len(records))) - - (set(good_records_idx) | set(none_records_ids))) + remained_idx = sorted( + set(range(len(records))) - (set(good_records_idx) | set(none_records_ids)) + ) for i, js in map_truth_2_pred.items(): truth_record = truth_records[i] if set(js) & set(good_records_idx): @@ -1630,15 +2106,20 @@ def find_records(input_record): vartype = get_type(record[2], record[3]) pos, ref, alt = record[1:4] rc = find_i_center(ref, alt) - if vartype_t == vartype and pos_t == pos and ((not strict_labeling) or vartype_t != "SNP"): + if ( + vartype_t == vartype + and pos_t == pos + and ((not strict_labeling) or vartype_t != "SNP") + ): good_records[vartype_t].append(j) vtype[j] = vartype_t record_len[j] = find_len(ref_t, alt_t) record_center[j] = rc good_records_idx = [i for w in list(good_records.values()) for i in w] - remained_idx = sorted(set(range(len(records))) - - (set(good_records_idx) | set(none_records_ids))) + remained_idx = sorted( + set(range(len(records))) - (set(good_records_idx) | set(none_records_ids)) + ) if not strict_labeling: for i, js in map_truth_2_pred.items(): truth_record = truth_records[i] @@ -1656,16 +2137,19 @@ def find_records(input_record): pos, ref, alt = record[1:4] rc = find_i_center(ref, alt) if pos_t + rct[0] + rct[1] == pos + rc[0] + rc[1]: - if (vartype_t == "INS" and vartype == "SNP") or (vartype == "INS" and vartype_t == "SNP"): + if (vartype_t == "INS" and vartype == "SNP") or ( + vartype == "INS" and vartype_t == "SNP" + ): good_records[vartype_t].append(j) vtype[j] = vartype_t record_len[j] = find_len(ref_t, alt_t) record_center[j] = rc - good_records_idx = [i for w in list( - good_records.values()) for i in w] - remained_idx = sorted(set(range(len(records))) - - (set(good_records_idx) | set(none_records_ids))) + good_records_idx = [i for w in list(good_records.values()) for i in w] + remained_idx = sorted( + set(range(len(records))) + - (set(good_records_idx) | set(none_records_ids)) + ) if not strict_labeling: for i, js in 
map_truth_2_pred.items(): @@ -1678,16 +2162,21 @@ def find_records(input_record): record = records[j] pos, ref, alt = record[1:4] vartype = get_type(record[2], record[3]) - if (vartype == vartype_t) and vartype_t != "SNP" and abs(pos - pos_t) < 2: + if ( + (vartype == vartype_t) + and vartype_t != "SNP" + and abs(pos - pos_t) < 2 + ): good_records[vartype_t].append(j) vtype[j] = vartype_t record_center[j] = find_i_center(ref, alt) record_len[j] = find_len(ref_t, alt_t) - good_records_idx = [i for w in list( - good_records.values()) for i in w] - remained_idx = sorted(set(range(len(records))) - - (set(good_records_idx) | set(none_records_ids))) + good_records_idx = [i for w in list(good_records.values()) for i in w] + remained_idx = sorted( + set(range(len(records))) + - (set(good_records_idx) | set(none_records_ids)) + ) for i, js in map_truth_2_pred.items(): truth_record = truth_records[i] @@ -1695,8 +2184,9 @@ def find_records(input_record): continue good_records_idx = [i for w in list(good_records.values()) for i in w] - remained_idx = sorted(set(range(len(records))) - - (set(good_records_idx) | set(none_records_ids))) + remained_idx = sorted( + set(range(len(records))) - (set(good_records_idx) | set(none_records_ids)) + ) for j in remained_idx: none_records_ids.append(j) @@ -1715,17 +2205,30 @@ def find_records(input_record): none_records = list(map(lambda x: records[x], none_records_ids)) none_records = sorted(none_records, key=lambda x: [x[0], int(x[1])]) - return records_r, none_records, vtype, record_len, record_center, chroms_order, anns + return ( + records_r, + none_records, + vtype, + record_len, + record_center, + chroms_order, + anns, + ) except Exception as ex: thread_logger.error(traceback.format_exc()) thread_logger.error(ex) return None -def extract_ensemble(ensemble_tsvs, ensemble_bed, no_seq_complexity, enforce_header, - custom_header, - zero_vscore, - is_extend): +def extract_ensemble( + ensemble_tsvs, + ensemble_bed, + no_seq_complexity, + enforce_header, + custom_header, + zero_vscore, + is_extend, +): logger = logging.getLogger(extract_ensemble.__name__) ensemble_data = [] ensemble_pos = [] @@ -1733,40 +2236,143 @@ def extract_ensemble(ensemble_tsvs, ensemble_bed, no_seq_complexity, enforce_hea header_pos = [] order_header = [] COV = 50 - expected_features = ["if_MuTect", "if_VarScan2", "if_JointSNVMix2", - "if_SomaticSniper", "if_VarDict", "MuSE_Tier", "if_LoFreq", "if_Scalpel", "if_Strelka", - "if_TNscope", "Strelka_Score", "Strelka_QSS", "Strelka_TQSS", "VarScan2_Score", "SNVMix2_Score", - "Sniper_Score", "VarDict_Score", "if_dbsnp", "COMMON", "if_COSMIC", "COSMIC_CNT", - "Consistent_Mates", "Inconsistent_Mates"] + expected_features = [ + "if_MuTect", + "if_VarScan2", + "if_JointSNVMix2", + "if_SomaticSniper", + "if_VarDict", + "MuSE_Tier", + "if_LoFreq", + "if_Scalpel", + "if_Strelka", + "if_TNscope", + "Strelka_Score", + "Strelka_QSS", + "Strelka_TQSS", + "VarScan2_Score", + "SNVMix2_Score", + "Sniper_Score", + "VarDict_Score", + "if_dbsnp", + "COMMON", + "if_COSMIC", + "COSMIC_CNT", + "Consistent_Mates", + "Inconsistent_Mates", + ] if not no_seq_complexity: expected_features += ["Seq_Complexity_Span", "Seq_Complexity_Adj"] - expected_features += ["N_DP", "nBAM_REF_MQ", "nBAM_ALT_MQ", - "nBAM_Z_Ranksums_MQ", "nBAM_REF_BQ", "nBAM_ALT_BQ", "nBAM_Z_Ranksums_BQ", "nBAM_REF_NM", - "nBAM_ALT_NM", "nBAM_NM_Diff", "nBAM_REF_Concordant", "nBAM_REF_Discordant", - "nBAM_ALT_Concordant", "nBAM_ALT_Discordant", "nBAM_Concordance_FET", "N_REF_FOR", "N_REF_REV", - 
"N_ALT_FOR", "N_ALT_REV", "nBAM_StrandBias_FET", "nBAM_Z_Ranksums_EndPos", - "nBAM_REF_Clipped_Reads", "nBAM_ALT_Clipped_Reads", "nBAM_Clipping_FET", "nBAM_MQ0", - "nBAM_Other_Reads", "nBAM_Poor_Reads", "nBAM_REF_InDel_3bp", "nBAM_REF_InDel_2bp", - "nBAM_REF_InDel_1bp", "nBAM_ALT_InDel_3bp", "nBAM_ALT_InDel_2bp", "nBAM_ALT_InDel_1bp", - "M2_NLOD", "M2_TLOD", "M2_STR", "M2_ECNT", "SOR", "MSI", "MSILEN", "SHIFT3", - "MaxHomopolymer_Length", "SiteHomopolymer_Length", "T_DP", "tBAM_REF_MQ", "tBAM_ALT_MQ", - "tBAM_Z_Ranksums_MQ", "tBAM_REF_BQ", "tBAM_ALT_BQ", "tBAM_Z_Ranksums_BQ", "tBAM_REF_NM", - "tBAM_ALT_NM", "tBAM_NM_Diff", "tBAM_REF_Concordant", "tBAM_REF_Discordant", - "tBAM_ALT_Concordant", "tBAM_ALT_Discordant", "tBAM_Concordance_FET", "T_REF_FOR", - "T_REF_REV", "T_ALT_FOR", "T_ALT_REV", "tBAM_StrandBias_FET", "tBAM_Z_Ranksums_EndPos", - "tBAM_REF_Clipped_Reads", "tBAM_ALT_Clipped_Reads", "tBAM_Clipping_FET", "tBAM_MQ0", - "tBAM_Other_Reads", "tBAM_Poor_Reads", "tBAM_REF_InDel_3bp", "tBAM_REF_InDel_2bp", - "tBAM_REF_InDel_1bp", "tBAM_ALT_InDel_3bp", "tBAM_ALT_InDel_2bp", "tBAM_ALT_InDel_1bp", - "InDel_Length"] - callers_features = ["if_MuTect", "if_VarScan2", "if_JointSNVMix2", "if_SomaticSniper", "if_VarDict", "MuSE_Tier", - "if_LoFreq", "if_Scalpel", "if_Strelka", "if_TNscope", "Strelka_Score", "Strelka_QSS", - "Strelka_TQSS", "SNVMix2_Score", "Sniper_Score", "VarDict_Score", - "M2_NLOD", "M2_TLOD", "M2_STR", "M2_ECNT", "MSI", "MSILEN", "SHIFT3"] + expected_features += [ + "N_DP", + "nBAM_REF_MQ", + "nBAM_ALT_MQ", + "nBAM_Z_Ranksums_MQ", + "nBAM_REF_BQ", + "nBAM_ALT_BQ", + "nBAM_Z_Ranksums_BQ", + "nBAM_REF_NM", + "nBAM_ALT_NM", + "nBAM_NM_Diff", + "nBAM_REF_Concordant", + "nBAM_REF_Discordant", + "nBAM_ALT_Concordant", + "nBAM_ALT_Discordant", + "nBAM_Concordance_FET", + "N_REF_FOR", + "N_REF_REV", + "N_ALT_FOR", + "N_ALT_REV", + "nBAM_StrandBias_FET", + "nBAM_Z_Ranksums_EndPos", + "nBAM_REF_Clipped_Reads", + "nBAM_ALT_Clipped_Reads", + "nBAM_Clipping_FET", + "nBAM_MQ0", + "nBAM_Other_Reads", + "nBAM_Poor_Reads", + "nBAM_REF_InDel_3bp", + "nBAM_REF_InDel_2bp", + "nBAM_REF_InDel_1bp", + "nBAM_ALT_InDel_3bp", + "nBAM_ALT_InDel_2bp", + "nBAM_ALT_InDel_1bp", + "M2_NLOD", + "M2_TLOD", + "M2_STR", + "M2_ECNT", + "SOR", + "MSI", + "MSILEN", + "SHIFT3", + "MaxHomopolymer_Length", + "SiteHomopolymer_Length", + "T_DP", + "tBAM_REF_MQ", + "tBAM_ALT_MQ", + "tBAM_Z_Ranksums_MQ", + "tBAM_REF_BQ", + "tBAM_ALT_BQ", + "tBAM_Z_Ranksums_BQ", + "tBAM_REF_NM", + "tBAM_ALT_NM", + "tBAM_NM_Diff", + "tBAM_REF_Concordant", + "tBAM_REF_Discordant", + "tBAM_ALT_Concordant", + "tBAM_ALT_Discordant", + "tBAM_Concordance_FET", + "T_REF_FOR", + "T_REF_REV", + "T_ALT_FOR", + "T_ALT_REV", + "tBAM_StrandBias_FET", + "tBAM_Z_Ranksums_EndPos", + "tBAM_REF_Clipped_Reads", + "tBAM_ALT_Clipped_Reads", + "tBAM_Clipping_FET", + "tBAM_MQ0", + "tBAM_Other_Reads", + "tBAM_Poor_Reads", + "tBAM_REF_InDel_3bp", + "tBAM_REF_InDel_2bp", + "tBAM_REF_InDel_1bp", + "tBAM_ALT_InDel_3bp", + "tBAM_ALT_InDel_2bp", + "tBAM_ALT_InDel_1bp", + "InDel_Length", + ] + callers_features = [ + "if_MuTect", + "if_VarScan2", + "if_JointSNVMix2", + "if_SomaticSniper", + "if_VarDict", + "MuSE_Tier", + "if_LoFreq", + "if_Scalpel", + "if_Strelka", + "if_TNscope", + "Strelka_Score", + "Strelka_QSS", + "Strelka_TQSS", + "SNVMix2_Score", + "Sniper_Score", + "VarDict_Score", + "M2_NLOD", + "M2_TLOD", + "M2_STR", + "M2_ECNT", + "MSI", + "MSILEN", + "SHIFT3", + ] if is_extend and custom_header: expected_features = list( - filter(lambda x: x not in 
callers_features, expected_features)) + filter(lambda x: x not in callers_features, expected_features) + ) n_vars = 0 all_headers = set([]) for ensemble_tsv in ensemble_tsvs: @@ -1781,16 +2387,22 @@ def extract_ensemble(ensemble_tsvs, ensemble_bed, no_seq_complexity, enforce_hea else: if is_extend and not custom_header: header_ += callers_features - header_en = list(filter( - lambda x: x[1] in expected_features, enumerate(header_))) + header_en = list( + filter( + lambda x: x[1] in expected_features, enumerate(header_) + ) + ) header = list(map(lambda x: x[1], header_en)) if not enforce_header: expected_features = header if set(expected_features) - set(header): - logger.error("The following features are missing from ensemble file {}: {}".format( - ensemble_tsv, - list(set(expected_features) - set(header)))) + logger.error( + "The following features are missing from ensemble file {}: {}".format( + ensemble_tsv, + list(set(expected_features) - set(header)), + ) + ) raise Exception order_header = [] for f in expected_features: @@ -1802,100 +2414,267 @@ def extract_ensemble(ensemble_tsvs, ensemble_bed, no_seq_complexity, enforce_hea features = fields[5:] if is_extend and not custom_header: features += ["0"] * len(callers_features) - features = list(map(lambda x: float( - x.replace("False", "0").replace("True", "1")), features)) + features = list( + map( + lambda x: float(x.replace("False", "0").replace("True", "1")), + features, + ) + ) if custom_header and not is_extend: if min(features) < 0 or max(features) > 1: logger.info( - "In --ensemble_custom_header mode, feature values in ensemble.tsv should be normalized in [0,1]") + "In --ensemble_custom_header mode, feature values in ensemble.tsv should be normalized in [0,1]" + ) raise Exception ensemble_data.append(features) n_vars += 1 if len(set(all_headers)) != 1: - raise(RuntimeError("inconsistent headers in {}".format(ensemble_tsvs))) + raise (RuntimeError("inconsistent headers in {}".format(ensemble_tsvs))) if n_vars > 0: ensemble_data = np.array(ensemble_data)[:, order_header] header = np.array(header_)[order_header].tolist() if not custom_header or is_extend: - cov_features = list(map(lambda x: x[0], filter(lambda x: x[1] in [ - "Consistent_Mates", "Inconsistent_Mates", "N_DP", - "nBAM_REF_NM", "nBAM_ALT_NM", "nBAM_REF_Concordant", "nBAM_REF_Discordant", "nBAM_ALT_Concordant", "nBAM_ALT_Discordant", - "N_REF_FOR", "N_REF_REV", "N_ALT_FOR", "N_ALT_REV", "nBAM_REF_Clipped_Reads", "nBAM_ALT_Clipped_Reads", "nBAM_MQ0", "nBAM_Other_Reads", "nBAM_Poor_Reads", - "nBAM_REF_InDel_3bp", "nBAM_REF_InDel_2bp", "nBAM_REF_InDel_1bp", "nBAM_ALT_InDel_3bp", "nBAM_ALT_InDel_2bp", - "nBAM_ALT_InDel_1bp", - "T_DP", "tBAM_REF_NM", "tBAM_ALT_NM", "tBAM_REF_Concordant", "tBAM_REF_Discordant", "tBAM_ALT_Concordant", "tBAM_ALT_Discordant", - "T_REF_FOR", "T_REF_REV", "T_ALT_FOR", "T_ALT_REV", - "tBAM_REF_Clipped_Reads", "tBAM_ALT_Clipped_Reads", - "tBAM_MQ0", "tBAM_Other_Reads", "tBAM_Poor_Reads", "tBAM_REF_InDel_3bp", "tBAM_REF_InDel_2bp", - "tBAM_REF_InDel_1bp", "tBAM_ALT_InDel_3bp", "tBAM_ALT_InDel_2bp", "tBAM_ALT_InDel_1bp", - ], enumerate(header)))) - mq_features = list(map(lambda x: x[0], filter(lambda x: x[1] in [ - "nBAM_REF_MQ", "nBAM_ALT_MQ", "tBAM_REF_MQ", "tBAM_ALT_MQ"], enumerate(header)))) - bq_features = list(map(lambda x: x[0], filter(lambda x: x[1] in [ - "nBAM_REF_BQ", "nBAM_ALT_BQ", "tBAM_REF_BQ", "tBAM_ALT_BQ"], enumerate(header)))) - nm_diff_features = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["nBAM_NM_Diff", "tBAM_NM_Diff"], 
enumerate(header)))) - ranksum_features = list(map(lambda x: x[0], filter(lambda x: x[1] in ["nBAM_Z_Ranksums_MQ", "nBAM_Z_Ranksums_BQ", - "nBAM_Z_Ranksums_EndPos", "tBAM_Z_Ranksums_BQ", "tBAM_Z_Ranksums_MQ", "tBAM_Z_Ranksums_EndPos", ], enumerate(header)))) - zero_to_one_features = list(map(lambda x: x[0], filter(lambda x: x[1] in ["if_MuTect", "if_VarScan2", "if_SomaticSniper", "if_VarDict", - "MuSE_Tier", "if_Strelka"] + ["nBAM_Concordance_FET", "nBAM_StrandBias_FET", "nBAM_Clipping_FET", - "tBAM_Concordance_FET", "tBAM_StrandBias_FET", "tBAM_Clipping_FET"] + ["if_dbsnp", "COMMON"] + ["M2_STR"], enumerate(header)))) - stralka_scor = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["Strelka_Score"], enumerate(header)))) - stralka_qss = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["Strelka_QSS"], enumerate(header)))) - stralka_tqss = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["Strelka_TQSS"], enumerate(header)))) - varscan2_score = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["VarScan2_Score"], enumerate(header)))) - vardict_score = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["VarDict_Score"], enumerate(header)))) - m2_lod = list(map(lambda x: x[0], filter(lambda x: x[1] in [ - "M2_NLOD", "M2_TLOD"], enumerate(header)))) - sniper_score = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["Sniper_Score"], enumerate(header)))) - m2_ecent = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["M2_ECNT"], enumerate(header)))) - sor = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["SOR"], enumerate(header)))) - msi = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["MSI"], enumerate(header)))) - msilen = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["MSILEN"], enumerate(header)))) - shift3 = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["SHIFT3"], enumerate(header)))) - MaxHomopolymer_Length = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["MaxHomopolymer_Length"], enumerate(header)))) - SiteHomopolymer_Length = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["SiteHomopolymer_Length"], enumerate(header)))) - InDel_Length = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["InDel_Length"], enumerate(header)))) - Seq_Complexity_ = list(map(lambda x: x[0], filter( - lambda x: x[1] in ["Seq_Complexity_Span", "Seq_Complexity_Adj"], enumerate(header)))) - - min_max_features = [[cov_features, 0, 2 * COV], - [mq_features, 0, 70], - [bq_features, 0, 41], - [nm_diff_features, -2 * COV, 2 * COV], - [zero_to_one_features, 0, 1], - [ranksum_features, -30, 30], - [stralka_scor, 0, 40], - [stralka_qss, 0, 200], - [stralka_tqss, 0, 4], - [varscan2_score, 0, 60], - [vardict_score, 0, 120], - [m2_lod, 0, 100], - [sniper_score, 0, 120], - [m2_ecent, 0, 40], - [sor, 0, 100], - [msi, 0, 100], - [msilen, 0, 10], - [shift3, 0, 100], - [MaxHomopolymer_Length, 0, 50], - [SiteHomopolymer_Length, 0, 50], - [InDel_Length, -30, 30], - ] + cov_features = list( + map( + lambda x: x[0], + filter( + lambda x: x[1] + in [ + "Consistent_Mates", + "Inconsistent_Mates", + "N_DP", + "nBAM_REF_NM", + "nBAM_ALT_NM", + "nBAM_REF_Concordant", + "nBAM_REF_Discordant", + "nBAM_ALT_Concordant", + "nBAM_ALT_Discordant", + "N_REF_FOR", + "N_REF_REV", + "N_ALT_FOR", + "N_ALT_REV", + "nBAM_REF_Clipped_Reads", + "nBAM_ALT_Clipped_Reads", + "nBAM_MQ0", + "nBAM_Other_Reads", + "nBAM_Poor_Reads", + "nBAM_REF_InDel_3bp", + "nBAM_REF_InDel_2bp", + "nBAM_REF_InDel_1bp", + "nBAM_ALT_InDel_3bp", + "nBAM_ALT_InDel_2bp", + "nBAM_ALT_InDel_1bp", + "T_DP", + 
"tBAM_REF_NM", + "tBAM_ALT_NM", + "tBAM_REF_Concordant", + "tBAM_REF_Discordant", + "tBAM_ALT_Concordant", + "tBAM_ALT_Discordant", + "T_REF_FOR", + "T_REF_REV", + "T_ALT_FOR", + "T_ALT_REV", + "tBAM_REF_Clipped_Reads", + "tBAM_ALT_Clipped_Reads", + "tBAM_MQ0", + "tBAM_Other_Reads", + "tBAM_Poor_Reads", + "tBAM_REF_InDel_3bp", + "tBAM_REF_InDel_2bp", + "tBAM_REF_InDel_1bp", + "tBAM_ALT_InDel_3bp", + "tBAM_ALT_InDel_2bp", + "tBAM_ALT_InDel_1bp", + ], + enumerate(header), + ), + ) + ) + mq_features = list( + map( + lambda x: x[0], + filter( + lambda x: x[1] + in ["nBAM_REF_MQ", "nBAM_ALT_MQ", "tBAM_REF_MQ", "tBAM_ALT_MQ"], + enumerate(header), + ), + ) + ) + bq_features = list( + map( + lambda x: x[0], + filter( + lambda x: x[1] + in ["nBAM_REF_BQ", "nBAM_ALT_BQ", "tBAM_REF_BQ", "tBAM_ALT_BQ"], + enumerate(header), + ), + ) + ) + nm_diff_features = list( + map( + lambda x: x[0], + filter( + lambda x: x[1] in ["nBAM_NM_Diff", "tBAM_NM_Diff"], + enumerate(header), + ), + ) + ) + ranksum_features = list( + map( + lambda x: x[0], + filter( + lambda x: x[1] + in [ + "nBAM_Z_Ranksums_MQ", + "nBAM_Z_Ranksums_BQ", + "nBAM_Z_Ranksums_EndPos", + "tBAM_Z_Ranksums_BQ", + "tBAM_Z_Ranksums_MQ", + "tBAM_Z_Ranksums_EndPos", + ], + enumerate(header), + ), + ) + ) + zero_to_one_features = list( + map( + lambda x: x[0], + filter( + lambda x: x[1] + in [ + "if_MuTect", + "if_VarScan2", + "if_SomaticSniper", + "if_VarDict", + "MuSE_Tier", + "if_Strelka", + ] + + [ + "nBAM_Concordance_FET", + "nBAM_StrandBias_FET", + "nBAM_Clipping_FET", + "tBAM_Concordance_FET", + "tBAM_StrandBias_FET", + "tBAM_Clipping_FET", + ] + + ["if_dbsnp", "COMMON"] + + ["M2_STR"], + enumerate(header), + ), + ) + ) + stralka_scor = list( + map( + lambda x: x[0], + filter(lambda x: x[1] in ["Strelka_Score"], enumerate(header)), + ) + ) + stralka_qss = list( + map( + lambda x: x[0], + filter(lambda x: x[1] in ["Strelka_QSS"], enumerate(header)), + ) + ) + stralka_tqss = list( + map( + lambda x: x[0], + filter(lambda x: x[1] in ["Strelka_TQSS"], enumerate(header)), + ) + ) + varscan2_score = list( + map( + lambda x: x[0], + filter(lambda x: x[1] in ["VarScan2_Score"], enumerate(header)), + ) + ) + vardict_score = list( + map( + lambda x: x[0], + filter(lambda x: x[1] in ["VarDict_Score"], enumerate(header)), + ) + ) + m2_lod = list( + map( + lambda x: x[0], + filter(lambda x: x[1] in ["M2_NLOD", "M2_TLOD"], enumerate(header)), + ) + ) + sniper_score = list( + map( + lambda x: x[0], + filter(lambda x: x[1] in ["Sniper_Score"], enumerate(header)), + ) + ) + m2_ecent = list( + map( + lambda x: x[0], filter(lambda x: x[1] in ["M2_ECNT"], enumerate(header)) + ) + ) + sor = list( + map(lambda x: x[0], filter(lambda x: x[1] in ["SOR"], enumerate(header))) + ) + msi = list( + map(lambda x: x[0], filter(lambda x: x[1] in ["MSI"], enumerate(header))) + ) + msilen = list( + map(lambda x: x[0], filter(lambda x: x[1] in ["MSILEN"], enumerate(header))) + ) + shift3 = list( + map(lambda x: x[0], filter(lambda x: x[1] in ["SHIFT3"], enumerate(header))) + ) + MaxHomopolymer_Length = list( + map( + lambda x: x[0], + filter(lambda x: x[1] in ["MaxHomopolymer_Length"], enumerate(header)), + ) + ) + SiteHomopolymer_Length = list( + map( + lambda x: x[0], + filter(lambda x: x[1] in ["SiteHomopolymer_Length"], enumerate(header)), + ) + ) + InDel_Length = list( + map( + lambda x: x[0], + filter(lambda x: x[1] in ["InDel_Length"], enumerate(header)), + ) + ) + Seq_Complexity_ = list( + map( + lambda x: x[0], + filter( + lambda x: x[1] in ["Seq_Complexity_Span", 
"Seq_Complexity_Adj"], + enumerate(header), + ), + ) + ) + + min_max_features = [ + [cov_features, 0, 2 * COV], + [mq_features, 0, 70], + [bq_features, 0, 41], + [nm_diff_features, -2 * COV, 2 * COV], + [zero_to_one_features, 0, 1], + [ranksum_features, -30, 30], + [stralka_scor, 0, 40], + [stralka_qss, 0, 200], + [stralka_tqss, 0, 4], + [varscan2_score, 0, 60], + [vardict_score, 0, 120], + [m2_lod, 0, 100], + [sniper_score, 0, 120], + [m2_ecent, 0, 40], + [sor, 0, 100], + [msi, 0, 100], + [msilen, 0, 10], + [shift3, 0, 100], + [MaxHomopolymer_Length, 0, 50], + [SiteHomopolymer_Length, 0, 50], + [InDel_Length, -30, 30], + ] if not no_seq_complexity: min_max_features.append([Seq_Complexity_, 0, 40]) @@ -1903,8 +2682,7 @@ def extract_ensemble(ensemble_tsvs, ensemble_bed, no_seq_complexity, enforce_hea ensemble_data[:, np.array(varscan2_score)] = 0 selected_features = sorted([i for f in min_max_features for i in f[0]]) - selected_features_tags = list( - map(lambda x: header[x], selected_features)) + selected_features_tags = list(map(lambda x: header[x], selected_features)) if n_vars > 0: for i_s, mn, mx in min_max_features: if i_s: @@ -1917,27 +2695,28 @@ def extract_ensemble(ensemble_tsvs, ensemble_bed, no_seq_complexity, enforce_hea else: ensemble_data = ensemble_data.tolist() selected_features_tags = header_ - with open(ensemble_bed, "w")as f_: - f_.write( - "#" + "\t".join(map(str, header_pos + selected_features_tags)) + "\n") + with open(ensemble_bed, "w") as f_: + f_.write("#" + "\t".join(map(str, header_pos + selected_features_tags)) + "\n") for i, s in enumerate(ensemble_data): f_.write("\t".join(map(str, ensemble_pos[i] + s)) + "\n") return ensemble_bed + def chunks(lst, n): for i in range(0, len(lst), n): - yield lst[i:i + n] - + yield lst[i : i + n] - -def parallel_generation(inputs): +def parallel_generation(inputs): map_args, matrix_base_pad, chrom_lengths, tumor_count_bed, normal_count_bed = inputs thread_logger = logging.getLogger( - "{} ({})".format(parallel_generation.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + parallel_generation.__name__, multiprocessing.current_process().name + ) + ) try: - chrom_pos={} + chrom_pos = {} for w in map_args: record = w[3] chrom = record[0] @@ -1947,61 +2726,90 @@ def parallel_generation(inputs): if chrom not in chrom_pos: chrom_pos[chrom] = [s_pos, e_pos] else: - chrom_pos[chrom] = [min(s_pos,chrom_pos[chrom][0]), max(e_pos,chrom_pos[chrom][1])] + chrom_pos[chrom] = [ + min(s_pos, chrom_pos[chrom][0]), + max(e_pos, chrom_pos[chrom][1]), + ] thread_logger.info(chrom_pos) # thread_logger.info("Gener-7") tb_tumor = pysam.TabixFile(tumor_count_bed, parser=pysam.asTuple()) tb_normal = pysam.TabixFile(normal_count_bed, parser=pysam.asTuple()) - tumor_tabix_records_dict={} - normal_tabix_records_dict={} + tumor_tabix_records_dict = {} + normal_tabix_records_dict = {} for chrom in chrom_pos: - t2=time.time() - tumor_tabix_records_dict[chrom]={} - normal_tabix_records_dict[chrom]={} + t2 = time.time() + tumor_tabix_records_dict[chrom] = {} + normal_tabix_records_dict[chrom] = {} try: - tumor_tabix_records = list(tb_tumor.fetch(chrom,chrom_pos[chrom][0]-1,chrom_pos[chrom][1])) + tumor_tabix_records = list( + tb_tumor.fetch(chrom, chrom_pos[chrom][0] - 1, chrom_pos[chrom][1]) + ) except: - thread_logger.warning("No count information at {} for {}:{}-{}".format(tumor_count_bed,chrom,chrom_pos[chrom][0]-1,chrom_pos[chrom][1])) + thread_logger.warning( + "No count information at {} for {}:{}-{}".format( + tumor_count_bed, + 
chrom, + chrom_pos[chrom][0] - 1, + chrom_pos[chrom][1], + ) + ) tumor_tabix_records = [] try: - normal_tabix_records = list(tb_normal.fetch(chrom,chrom_pos[chrom][0]-1,chrom_pos[chrom][1])) + normal_tabix_records = list( + tb_normal.fetch(chrom, chrom_pos[chrom][0] - 1, chrom_pos[chrom][1]) + ) except: - thread_logger.warning("No count information at {} for {}:{}-{}".format(normal_count_bed,chrom,chrom_pos[chrom][0]-1,chrom_pos[chrom][1])) + thread_logger.warning( + "No count information at {} for {}:{}-{}".format( + normal_count_bed, + chrom, + chrom_pos[chrom][0] - 1, + chrom_pos[chrom][1], + ) + ) normal_tabix_records = [] # thread_logger.info(["ffff-1",time.time()-t2]) - t2=time.time() + t2 = time.time() for x in tumor_tabix_records: pos = int(x[1]) - if pos not in tumor_tabix_records_dict[chrom]: - tumor_tabix_records_dict[chrom][pos]=[] + if pos not in tumor_tabix_records_dict[chrom]: + tumor_tabix_records_dict[chrom][pos] = [] tumor_tabix_records_dict[chrom][pos].append(list(x)) for x in normal_tabix_records: pos = int(x[1]) - if pos not in normal_tabix_records_dict[chrom]: - normal_tabix_records_dict[chrom][pos]=[] + if pos not in normal_tabix_records_dict[chrom]: + normal_tabix_records_dict[chrom][pos] = [] normal_tabix_records_dict[chrom][pos].append(list(x)) # thread_logger.info(["ffff-2",time.time()-t2]) - t2=time.time() + t2 = time.time() del tumor_tabix_records, normal_tabix_records # thread_logger.info(["ffff-3",time.time()-t2]) - thread_logger.info(["Gener-8",len(map_args)]) + thread_logger.info(["Gener-8", len(map_args)]) - records_done=[] + records_done = [] for w in map_args: record = w[3] chrom = record[0] pos = int(record[1]) s_pos = max(1, pos - matrix_base_pad) e_pos = min(pos + matrix_base_pad, chrom_lengths[chrom] - 2) - tumor_counts ={x_pos:tumor_tabix_records_dict[chrom][x_pos] for x_pos in range(s_pos,e_pos+1) if x_pos in tumor_tabix_records_dict[chrom]} - normal_counts ={x_pos:normal_tabix_records_dict[chrom][x_pos] for x_pos in range(s_pos,e_pos+1) if x_pos in normal_tabix_records_dict[chrom]} + tumor_counts = { + x_pos: tumor_tabix_records_dict[chrom][x_pos] + for x_pos in range(s_pos, e_pos + 1) + if x_pos in tumor_tabix_records_dict[chrom] + } + normal_counts = { + x_pos: normal_tabix_records_dict[chrom][x_pos] + for x_pos in range(s_pos, e_pos + 1) + if x_pos in normal_tabix_records_dict[chrom] + } w[1] = tumor_counts w[2] = normal_counts - o=prep_data_single_tabix(w) + o = prep_data_single_tabix(w) if o is None: aaaa records_done.append(o) @@ -2011,21 +2819,37 @@ def parallel_generation(inputs): thread_logger.error(ex) return None -def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_bed_file, tumor_count_bed, normal_count_bed, ref_file, - matrix_width, matrix_base_pad, min_ev_frac_per_col, min_cov, num_threads, ensemble_tsv, - ensemble_bed, - ensemble_custom_header, - no_seq_complexity, enforce_header, - zero_vscore, - matrix_dtype, - strict_labeling, - tsv_batch_size): + +def generate_dataset( + work, + truth_vcf_file, + mode, + tumor_pred_vcf_file, + region_bed_file, + tumor_count_bed, + normal_count_bed, + ref_file, + matrix_width, + matrix_base_pad, + min_ev_frac_per_col, + min_cov, + num_threads, + ensemble_tsv, + ensemble_bed, + ensemble_custom_header, + no_seq_complexity, + enforce_header, + zero_vscore, + matrix_dtype, + strict_labeling, + tsv_batch_size, +): logger = logging.getLogger(generate_dataset.__name__) logger.info("---------------------Generate Dataset----------------------") logger.info(tumor_count_bed) - 
t1=time.time() + t1 = time.time() # logger.info("Gener-0") if not os.path.exists(work): @@ -2038,26 +2862,30 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be tempfile.tempdir = bed_tempdir if mode == "train" and not truth_vcf_file: - raise(RuntimeError("--truth_vcf is needed for 'train' mode")) + raise (RuntimeError("--truth_vcf is needed for 'train' mode")) if mode == "call": truth_vcf_file = os.path.join(work, "empty.vcf") with open(truth_vcf_file, "w") as o_f: o_f.write("{}\n".format(VCF_HEADER)) - o_f.write( - "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") + o_f.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") split_batch_size = 10000 if ensemble_tsv and not ensemble_bed: ensemble_bed = os.path.join(work, "ensemble.bed") - extract_ensemble(ensemble_tsvs=[ensemble_tsv], ensemble_bed=ensemble_bed, - no_seq_complexity=no_seq_complexity, enforce_header=enforce_header, - custom_header=ensemble_custom_header, - zero_vscore=zero_vscore, - is_extend=False) + extract_ensemble( + ensemble_tsvs=[ensemble_tsv], + ensemble_bed=ensemble_bed, + no_seq_complexity=no_seq_complexity, + enforce_header=enforce_header, + custom_header=ensemble_custom_header, + zero_vscore=zero_vscore, + is_extend=False, + ) tmp_ = bedtools_intersect( - tumor_pred_vcf_file, region_bed_file, args=" -u", run_logger=logger) + tumor_pred_vcf_file, region_bed_file, args=" -u", run_logger=logger + ) len_candids = 0 with open(tmp_) as i_f: for line in skip_empty(i_f): @@ -2065,16 +2893,16 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be if ensemble_bed: tmp_ = bedtools_intersect( - ensemble_bed, region_bed_file, args=" -u", run_logger=logger) + ensemble_bed, region_bed_file, args=" -u", run_logger=logger + ) with open(tmp_) as i_f: for line in i_f: len_candids += 1 logger.info("len_candids: {}".format(len_candids)) num_splits = max(len_candids // split_batch_size, num_threads) split_region_files = split_region( - work, region_bed_file, num_splits, - shuffle_intervals=False - ) + work, region_bed_file, num_splits, shuffle_intervals=False + ) fasta_file = pysam.Fastafile(ref_file) chrom_lengths = dict(zip(fasta_file.references, fasta_file.lengths)) @@ -2089,15 +2917,26 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be x = i_f.readline().strip().split() if x: num_ens_features = len(x) - 5 - + # logger.info(["rrr-1",time.time()-t1]) t1 = time.time() # logger.info("Gener-1") map_args = [] for i, split_region_file in enumerate(split_region_files): - map_args.append((work, split_region_file, truth_vcf_file, - tumor_pred_vcf_file, ref_file, ensemble_bed, num_ens_features, strict_labeling, i)) + map_args.append( + ( + work, + split_region_file, + truth_vcf_file, + tumor_pred_vcf_file, + ref_file, + ensemble_bed, + num_ens_features, + strict_labeling, + i, + ) + ) if num_threads == 1: records_data = [] for w in map_args: @@ -2116,14 +2955,21 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be if o is None: raise Exception("find_records failed!") - none_vcf = "{}/none.vcf".format(work) var_vcf = "{}/var.vcf".format(work) if not os.path.exists(work): os.mkdir("{}".format(work)) total_ims = 0 - for records_r, none_records, vtype, record_len, record_center, chroms_order, anns in records_data: + for ( + records_r, + none_records, + vtype, + record_len, + record_center, + chroms_order, + anns, + ) in records_data: total_ims += len(records_r) + len(none_records) # 
logger.info("Gener-2") @@ -2141,33 +2987,78 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be else: is_end = total_ims candidates_tsv_file = "{}/candidates_{}.tsv".format(work, is_) - logger.info("Write {}/{} split to {} for cnts ({}..{})/{}".format( - is_ + 1, candidates_split, candidates_tsv_file, is_current, - is_end, total_ims)) + logger.info( + "Write {}/{} split to {} for cnts ({}..{})/{}".format( + is_ + 1, + candidates_split, + candidates_tsv_file, + is_current, + is_end, + total_ims, + ) + ) map_args_records = [] - for records_r, none_records, vtype, record_len, record_center, chroms_order, anns in records_data: + for ( + records_r, + none_records, + vtype, + record_len, + record_center, + chroms_order, + anns, + ) in records_data: if len(records_r) + len(none_records) + cnt < is_current: - cnt += len(records_r)+len(none_records) + cnt += len(records_r) + len(none_records) else: - for is_none, records in [["False", records_r],["True",none_records]]: + for is_none, records in [ + ["False", records_r], + ["True", none_records], + ]: for record in records: cnt += 1 if is_current <= cnt < is_end: - vartype = vtype[int(record[-1])] if not is_none else "NONE" + vartype = ( + vtype[int(record[-1])] if not is_none else "NONE" + ) rlen = record_len[int(record[-1])] if not is_none else 0 rcenter = record_center[int(record[-1])] ch_order = chroms_order[record[0]] - ann = list(anns[int(record[-1])] - ) if ensemble_bed else [] - + ann = ( + list(anns[int(record[-1])]) if ensemble_bed else [] + ) + chrom, pos = record[0:2] pos = int(pos) s_pos = max(1, pos - matrix_base_pad) - e_pos = min(pos + matrix_base_pad, chrom_lengths[chrom] - 2) - ref_seq = fasta_file.fetch(chrom, s_pos-1, e_pos).upper().replace("N", "-") - - map_args_records.append([ref_seq, None, None, record, vartype, rlen, rcenter, ch_order, - matrix_base_pad, matrix_width, min_ev_frac_per_col, min_cov, ann, chrom_lengths, matrix_dtype, is_none]) + e_pos = min( + pos + matrix_base_pad, chrom_lengths[chrom] - 2 + ) + ref_seq = ( + fasta_file.fetch(chrom, s_pos - 1, e_pos) + .upper() + .replace("N", "-") + ) + + map_args_records.append( + [ + ref_seq, + None, + None, + record, + vartype, + rlen, + rcenter, + ch_order, + matrix_base_pad, + matrix_width, + min_ev_frac_per_col, + min_cov, + ann, + chrom_lengths, + matrix_dtype, + is_none, + ] + ) if cnt >= is_end: break if cnt >= is_end: @@ -2179,17 +3070,37 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be len_records = len(map_args_records) records_done = [] - if len_records>0: + if len_records > 0: # logger.info("Gener-9") if num_threads == 1: - records_done_ = [parallel_generation([map_args_records, matrix_base_pad, chrom_lengths, tumor_count_bed, normal_count_bed])] - else: + records_done_ = [ + parallel_generation( + [ + map_args_records, + matrix_base_pad, + chrom_lengths, + tumor_count_bed, + normal_count_bed, + ] + ) + ] + else: pool = multiprocessing.Pool(num_threads) try: - split_len=max(1,len_records//num_threads) + split_len = max(1, len_records // num_threads) records_done_ = pool.map_async( - parallel_generation, [[map_args_records[i_split:i_split+(split_len)],matrix_base_pad, chrom_lengths, tumor_count_bed, normal_count_bed] - for i_split in range(0, len_records, split_len)]).get() + parallel_generation, + [ + [ + map_args_records[i_split : i_split + (split_len)], + matrix_base_pad, + chrom_lengths, + tumor_count_bed, + normal_count_bed, + ] + for i_split in range(0, len_records, split_len) + ], + ).get() 
pool.close() except Exception as inst: logger.error(inst) @@ -2211,14 +3122,49 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be if x: tag, compressed_candidate_mat, record, ann, is_none = x if not is_none: - vv.write("\t".join([record[0], str(record[1]), ".", record[2], record[ - 3], ".", ".", "TAG={};".format(tag), ".", "."]) + "\n") + vv.write( + "\t".join( + [ + record[0], + str(record[1]), + ".", + record[2], + record[3], + ".", + ".", + "TAG={};".format(tag), + ".", + ".", + ] + ) + + "\n" + ) else: - nv.write("\t".join([record[0], str(record[1]), ".", record[2], record[ - 3], ".", ".", "TAG={};".format(tag), ".", "."]) + "\n") + nv.write( + "\t".join( + [ + record[0], + str(record[1]), + ".", + record[2], + record[3], + ".", + ".", + "TAG={};".format(tag), + ".", + ".", + ] + ) + + "\n" + ) tsv_idx.append(b_o.tell()) - b_o.write("\t".join([str(cnt_ims), "1", tag, compressed_candidate_mat] + list(map( - lambda x: str(np.round(x, 4)), ann))) + "\n") + b_o.write( + "\t".join( + [str(cnt_ims), "1", tag, compressed_candidate_mat] + + list(map(lambda x: str(np.round(x, 4)), ann)) + ) + + "\n" + ) cnt_ims += 1 tsv_idx.append(b_o.tell()) pickle.dump(tsv_idx, open(candidates_tsv_file + ".idx", "wb")) @@ -2239,62 +3185,111 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be logger.info("Generating dataset is Done.") -if __name__ == '__main__': - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' + +if __name__ == "__main__": + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) parser = argparse.ArgumentParser( - description='Generate dataset for train/call candidate variants on CNN') - parser.add_argument('--mode', type=str, help='train/call mode', - choices=["train", "call"], required=True) - parser.add_argument('--truth_vcf', type=str, - help='truth vcf (required for train mode)', default=None) - parser.add_argument('--tumor_pred_vcf', type=str, - help='tumor candidate variants vcf file', required=True) - parser.add_argument('--region_bed', type=str, - help='region bed', required=True) - parser.add_argument('--tumor_count_bed', type=str, - help='tumor count bed.gz tabix file', required=True) - parser.add_argument('--normal_count_bed', type=str, - help='normal count bed.gz tabix file', required=True) - parser.add_argument('--reference', type=str, - help='reference fasta filename', required=True) - parser.add_argument('--work', type=str, - help='work directory', required=True) - parser.add_argument('--tsv_batch_size', type=int, - help='output files batch size', default=50000) - parser.add_argument('--matrix_window_size', type=int, - help='target window width', default=32) - parser.add_argument('--matrix_base_pad', type=int, - help='number of bases to pad around the candidate variant', default=7) - parser.add_argument('--min_ev_frac_per_col', type=float, - help='minimum frac cov per column to keep columm', default=0.06) - parser.add_argument('--min_cov', type=int, help='minimum cov', default=5) - parser.add_argument('--num_threads', type=int, - help='number of threads', default=1) - parser.add_argument('--ensemble_tsv', type=str, - help='Ensemble annotation tsv file (only for short read)', default=None) - parser.add_argument('--ensemble_bed', type=str, - help='Ensemble annotation bed file (only for short read)', default=None) - parser.add_argument('--ensemble_custom_header', - help='Allow ensemble tsv to have 
custom header fields', - action="store_true") - parser.add_argument('--no_seq_complexity', - help='Dont compute linguistic sequence complexity features', - action="store_true") - parser.add_argument('--enforce_header', - help='Enforce header match for ensemble_tsv', - action="store_true") - parser.add_argument('--zero_vscore', - help='set VarScan2_Score to zero', - action="store_true") - parser.add_argument('--matrix_dtype', type=str, - help='matrix_dtype to be used to store matrix', default="uint8", - choices=MAT_DTYPES) - parser.add_argument('--strict_labeling', - help='strict labeling in train mode', - action="store_true") + description="Generate dataset for train/call candidate variants on CNN" + ) + parser.add_argument( + "--mode", + type=str, + help="train/call mode", + choices=["train", "call"], + required=True, + ) + parser.add_argument( + "--truth_vcf", + type=str, + help="truth vcf (required for train mode)", + default=None, + ) + parser.add_argument( + "--tumor_pred_vcf", + type=str, + help="tumor candidate variants vcf file", + required=True, + ) + parser.add_argument("--region_bed", type=str, help="region bed", required=True) + parser.add_argument( + "--tumor_count_bed", + type=str, + help="tumor count bed.gz tabix file", + required=True, + ) + parser.add_argument( + "--normal_count_bed", + type=str, + help="normal count bed.gz tabix file", + required=True, + ) + parser.add_argument( + "--reference", type=str, help="reference fasta filename", required=True + ) + parser.add_argument("--work", type=str, help="work directory", required=True) + parser.add_argument( + "--tsv_batch_size", type=int, help="output files batch size", default=50000 + ) + parser.add_argument( + "--matrix_window_size", type=int, help="target window width", default=32 + ) + parser.add_argument( + "--matrix_base_pad", + type=int, + help="number of bases to pad around the candidate variant", + default=7, + ) + parser.add_argument( + "--min_ev_frac_per_col", + type=float, + help="minimum frac cov per column to keep column", + default=0.06, + ) + parser.add_argument("--min_cov", type=int, help="minimum cov", default=5) + parser.add_argument("--num_threads", type=int, help="number of threads", default=1) + parser.add_argument( + "--ensemble_tsv", + type=str, + help="Ensemble annotation tsv file (only for short read)", + default=None, + ) + parser.add_argument( + "--ensemble_bed", + type=str, + help="Ensemble annotation bed file (only for short read)", + default=None, + ) + parser.add_argument( + "--ensemble_custom_header", + help="Allow ensemble tsv to have custom header fields", + action="store_true", + ) + parser.add_argument( + "--no_seq_complexity", + help="Don't compute linguistic sequence complexity features", + action="store_true", + ) + parser.add_argument( + "--enforce_header", + help="Enforce header match for ensemble_tsv", + action="store_true", + ) + parser.add_argument( + "--zero_vscore", help="set VarScan2_Score to zero", action="store_true" + ) + parser.add_argument( + "--matrix_dtype", + type=str, + help="matrix_dtype to be used to store matrix", + default="uint8", + choices=MAT_DTYPES, + ) + parser.add_argument( + "--strict_labeling", help="strict labeling in train mode", action="store_true" + ) args = parser.parse_args() logger.info(args) @@ -2321,18 +3316,32 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be matrix_dtype = args.matrix_dtype strict_labeling = args.strict_labeling try: - generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, 
region_bed_file, tumor_count_bed, normal_count_bed, ref_file, - matrix_width, matrix_base_pad, min_ev_frac_per_col, min_cov, num_threads, ensemble_tsv, - ensemble_bed, - ensemble_custom_header, - no_seq_complexity, enforce_header, - zero_vscore, - matrix_dtype, - strict_labeling, - tsv_batch_size) + generate_dataset( + work, + truth_vcf_file, + mode, + tumor_pred_vcf_file, + region_bed_file, + tumor_count_bed, + normal_count_bed, + ref_file, + matrix_width, + matrix_base_pad, + min_ev_frac_per_col, + min_cov, + num_threads, + ensemble_tsv, + ensemble_bed, + ensemble_custom_header, + no_seq_complexity, + enforce_header, + zero_vscore, + matrix_dtype, + strict_labeling, + tsv_batch_size, + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "generate_dataset.py failure on arguments: {}".format(args)) + logger.error("generate_dataset.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/genomic_file_handlers.py b/neusomatic/python/genomic_file_handlers.py index 6e45a3a..86df6fe 100644 --- a/neusomatic/python/genomic_file_handlers.py +++ b/neusomatic/python/genomic_file_handlers.py @@ -10,55 +10,110 @@ # The regular expression pattern for "chrXX 1234567" in both VarScan2 # Output and VCF files: pattern_major_chr_position = re.compile( - r'^(?:chr)?(?:[1-9]|1[0-9]|2[0-2]|[XY]|MT?)\t[0-9]+\b') + r"^(?:chr)?(?:[1-9]|1[0-9]|2[0-2]|[XY]|MT?)\t[0-9]+\b" +) # More lenient pattern: -pattern_chr_position = re.compile(r'[^\t]+\t[0-9]+\b') -pattern_chrom = re.compile(r'(?:chr)?([1-9]|1[0-9]|2[0-2]|[XY]|MT?)\W') +pattern_chr_position = re.compile(r"[^\t]+\t[0-9]+\b") +pattern_chrom = re.compile(r"(?:chr)?([1-9]|1[0-9]|2[0-2]|[XY]|MT?)\W") # Valid Phred+33 quality strings: valid_q = set() [valid_q.add(chr(33 + i)) for i in range(42)] -nan = float('nan') -inf = float('inf') - -AA_3to1 = {"Ala": "A", "Arg": "R", "Asn": "N", "Asp": "D", "Cys": "C", "Glu": "E", "Gln": "Q", "Gly": "G", "His": "H", "Ile": "I", - "Leu": "L", "Lys": "K", "Met": "M", "Phe": "F", "Pro": "P", "Ser": "S", "Thr": "T", "Trp": "W", "Tyr": "Y", "Val": "V"} -AA_1to3 = {"A": "Ala", "R": "Arg", "N": "Asn", "D": "Asp", "C": "Cys", "E": "Glu", "Q": "Gln", "G": "Gly", "H": "His", "I": "Ile", - "L": "Leu", "K": "Lys", "M": "Met", "F": "Phe", "P": "Pro", "S": "Ser", "T": "Thr", "W": "Trp", "Y": "Tyr", "V": "Val"} +nan = float("nan") +inf = float("inf") + +AA_3to1 = { + "Ala": "A", + "Arg": "R", + "Asn": "N", + "Asp": "D", + "Cys": "C", + "Glu": "E", + "Gln": "Q", + "Gly": "G", + "His": "H", + "Ile": "I", + "Leu": "L", + "Lys": "K", + "Met": "M", + "Phe": "F", + "Pro": "P", + "Ser": "S", + "Thr": "T", + "Trp": "W", + "Tyr": "Y", + "Val": "V", +} +AA_1to3 = { + "A": "Ala", + "R": "Arg", + "N": "Asn", + "D": "Asp", + "C": "Cys", + "E": "Glu", + "Q": "Gln", + "G": "Gly", + "H": "His", + "I": "Ile", + "L": "Leu", + "K": "Lys", + "M": "Met", + "F": "Phe", + "P": "Pro", + "S": "Ser", + "T": "Thr", + "W": "Trp", + "Y": "Tyr", + "V": "Val", +} ### ### ### ### ### MAJOR CLASSES ### ### ### ### ### class Vcf_line: - '''Each instance of this object is a line from the vcf file (no header).''' + """Each instance of this object is a line from the vcf file (no header).""" def __init__(self, vcf_line): - '''Argument is a line in pileup file.''' - self.vcf_line = vcf_line.rstrip('\n') + """Argument is a line in pileup file.""" + self.vcf_line = vcf_line.rstrip("\n") try: - self.chromosome, self.position, self.identifier, self.refbase, self.altbase, self.qual, self.filters, self.info, 
* \ - self.has_samples = vcf_line.rstrip('\n').split('\t') + ( + self.chromosome, + self.position, + self.identifier, + self.refbase, + self.altbase, + self.qual, + self.filters, + self.info, + *self.has_samples, + ) = vcf_line.rstrip("\n").split("\t") self.position = int(self.position) try: self.field, *self.samples = self.has_samples except ValueError: - self.field = self.samples = '' + self.field = self.samples = "" except ValueError: - self.chromosome = self.identifier = self.refbase = self.altbase = self.qual = self.filters = self.info = self.field = self.samples = '' + self.chromosome = ( + self.identifier + ) = ( + self.refbase + ) = ( + self.altbase + ) = self.qual = self.filters = self.info = self.field = self.samples = "" self.position = None def get_info_items(self): - return self.info.split(';') + return self.info.split(";") def get_info_value(self, variable): - key_item = re.search( - r'\b{}=([^;\s]+)([;\W]|$)'.format(variable), self.vcf_line) + key_item = re.search(r"\b{}=([^;\s]+)([;\W]|$)".format(variable), self.vcf_line) # The key has a value attached to it, e.g., VAR=1,2,3 if key_item: @@ -66,24 +121,23 @@ def get_info_value(self, variable): # Perhaps it's simply a flag without "=" else: - key_item = self.info.split(';') + key_item = self.info.split(";") return True if variable in key_item else False def get_sample_variable(self): - return self.field.split(':') + return self.field.split(":") - def get_sample_item(self, idx=0, out_type='d'): - '''d to output a dictionary. l to output a tuple of lists''' + def get_sample_item(self, idx=0, out_type="d"): + """d to output a dictionary. l to output a tuple of lists""" - if out_type.lower() == 'd': - return dict(zip(self.get_sample_variable(), self.samples[idx].split(':'))) - elif out_type.lower() == 'l': - return (self.get_sample_variable(), self.samples[idx].split(':')) + if out_type.lower() == "d": + return dict(zip(self.get_sample_variable(), self.samples[idx].split(":"))) + elif out_type.lower() == "l": + return (self.get_sample_variable(), self.samples[idx].split(":")) def get_sample_value(self, variable, idx=0): - var2value = dict(zip(self.field.split( - ':'), self.samples[idx].split(':'))) + var2value = dict(zip(self.field.split(":"), self.samples[idx].split(":"))) try: return var2value[variable] @@ -92,10 +146,10 @@ def get_sample_value(self, variable, idx=0): class pysam_header: - ''' + """ Extract BAM header using pysam. Only sample name (SM) so far. 
- ''' + """ def __init__(self, bam_file): @@ -103,12 +157,12 @@ def __init__(self, bam_file): self.bam_header = bam.header def SM(self): - '''Sample Name''' + """Sample Name""" sample_name = set() - for header_i in self.bam_header['RG']: - sample_name.add(header_i['SM']) + for header_i in self.bam_header["RG"]: + sample_name.add(header_i["SM"]) sample_name = tuple(sample_name) return sample_name @@ -119,35 +173,36 @@ def SM(self): ### ### ### ### ### FUNCTIONS OF CONVENIENCE ### ### ### ### ### + def skip_vcf_header(opened_file): line_i = opened_file.readline().rstrip() - while line_i.startswith('#'): + while line_i.startswith("#"): line_i = opened_file.readline().rstrip() return line_i def faiordict2contigorder(file_name, file_format): - '''Takes either a .fai or .dict file, and return a contig order dictionary, i.e., chrom_seq['chr1'] == 0''' + """Takes either a .fai or .dict file, and returns a contig order dictionary, i.e., chrom_seq['chr1'] == 0""" - assert file_format in ('fai', 'dict') + assert file_format in ("fai", "dict") contig_sequence = [] with open(file_name) as gfile: for line_i in gfile: - if file_format == 'fai': - contig_match = re.match(r'([^\t]+)\t', line_i) + if file_format == "fai": + contig_match = re.match(r"([^\t]+)\t", line_i) - elif file_format == 'dict': - if line_i.startswith('@SQ'): - contig_match = re.match(r'@SQ\tSN:([^\t]+)\tLN:', line_i) + elif file_format == "dict": + if line_i.startswith("@SQ"): + contig_match = re.match(r"@SQ\tSN:([^\t]+)\tLN:", line_i) if contig_match: # some .fai files have space after the contig for descriptions. - contig_i = contig_match.groups()[0].split(' ')[0] + contig_i = contig_match.groups()[0].split(" ")[0] contig_sequence.append(contig_i) chrom_seq = {} @@ -160,8 +215,8 @@ def faiordict2contigorder(file_name, file_format): def open_textfile(file_name): # See if the input file is a .gz file: - if file_name.lower().endswith('.gz'): - return gzip.open(file_name, 'rt') + if file_name.lower().endswith(".gz"): + return gzip.open(file_name, "rt") else: return open(file_name) @@ -170,23 +225,23 @@ def open_textfile(file_name): def open_bam_file(file_name): try: - return AlignmentFile(file_name, 'rb') + return AlignmentFile(file_name, "rb") except ValueError: return open(file_name) def ascii2phred33(x): - '''Put in an ASCII string, return a Phred+33 score.''' + """Put in an ASCII string, return a Phred+33 score.""" return ord(x) - 33 def phred33toascii(x): - '''Put in a Phred33 score, return the character.''' + """Put in a Phred33 score, return the character.""" return chr(x + 33) def p2phred(p, max_phred=inf): - '''Convert p-value to Phred-scale quality score.''' + """Convert p-value to Phred-scale quality score.""" if p == 0: Q = max_phred @@ -209,27 +264,26 @@ def p2phred(p, max_phred=inf): def phred2p(phred): - '''Convert Phred-scale quality score to p-value.''' - return 10**(-phred / 10) + """Convert Phred-scale quality score to p-value.""" + return 10 ** (-phred / 10) def findall_index(mylist, tolookfor): - '''Find all instances in a list that matches exactly thestring.''' + """Find all instances in a list that match exactly the string.""" all_indices = [i for i, item in enumerate(mylist) if item == tolookfor] return all_indices def findall_index_regex(mylist, pattern): - '''Find all instances in a list that matches a regex pattern.''' - all_indices = [i for i, item in enumerate( - mylist) if re.search(pattern, item)] + """Find all instances in a list that match a regex pattern.""" + all_indices = [i for i, item in 
enumerate(mylist) if re.search(pattern, item)] return all_indices def count_repeating_bases(sequence): - '''For a string, count the number of characters that appears in a row. + """For a string, count the number of characters that appears in a row. E.g., for string "ABBCCCDDDDAAAAAAA", the function returns 1, 2, 3, 4, 7, because there is 1 A, 2 B's, 3 C's, 4 D's, and then 7 A's. - ''' + """ counters = [] previous_base = None @@ -260,9 +314,9 @@ def numeric_id(chr_i, pos_i, contig_seq): # Define which chromosome coordinate is ahead for the following function: chrom_sequence = [str(i) for i in range(1, 23)] -chrom_sequence.append('X') -chrom_sequence.append('Y') -chrom_sequence.append('M') +chrom_sequence.append("X") +chrom_sequence.append("Y") +chrom_sequence.append("M") chrom_seq = {} for n, contig_i in enumerate(chrom_sequence): @@ -270,17 +324,17 @@ def numeric_id(chr_i, pos_i, contig_seq): def whoisbehind(coord_0, coord_1, chrom_sequence): - ''' + """ coord_0 and coord_1 are two strings or two lists, specifying the chromosome, a (typically) tab, and then the location. Return the index where the coordinate is behind. Return 10 if they are the same position. - ''' + """ end_of_0 = end_of_1 = False - if coord_0 == '' or coord_0 == ['', ''] or coord_0 == ('', '') or not coord_0: + if coord_0 == "" or coord_0 == ["", ""] or coord_0 == ("", "") or not coord_0: end_of_0 = True - if coord_1 == '' or coord_1 == ['', ''] or coord_1 == ('', '') or not coord_1: + if coord_1 == "" or coord_1 == ["", ""] or coord_1 == ("", "") or not coord_1: end_of_1 = True if end_of_0 and end_of_1: @@ -312,10 +366,10 @@ def whoisbehind(coord_0, coord_1, chrom_sequence): chrom1_position = chrom_sequence.index(chrom1) if chrom0_position < chrom1_position: - return 0 # 1st coordinate is ahead + return 0 # 1st coordinate is ahead elif chrom0_position > chrom1_position: - return 1 # 1st coordinate is ahead + return 1 # 1st coordinate is ahead # Must be in the same chromosome else: @@ -334,9 +388,9 @@ def whoisbehind(coord_0, coord_1, chrom_sequence): return 10 -def vcf_header_modifier(infile_handle, addons=[], getlost=' '): - '''addons = A list of INFO, FORMAT, ID, or Filter lines you want to add. - getlost = a regex expression for the ID of INFO/FORMAT/FILTER that you want to get rid of.''' +def vcf_header_modifier(infile_handle, addons=[], getlost=" "): + """addons = A list of INFO, FORMAT, ID, or Filter lines you want to add. 
+ getlost = a regex expression for the ID of INFO/FORMAT/FILTER that you want to get rid of.""" line_i = infile_handle.readline().rstrip() @@ -347,17 +401,17 @@ def vcf_header_modifier(infile_handle, addons=[], getlost=' '): for additions in addons: vcfheader_info_format_filter.append(additions) - while line_i.startswith('##'): + while line_i.startswith("##"): - if re.match(r'##fileformat=', line_i): + if re.match(r"##fileformat=", line_i): vcffileformat = line_i - elif re.match(r'##(INFO|FORMAT|FILTER)', line_i): + elif re.match(r"##(INFO|FORMAT|FILTER)", line_i): - if not re.match(r'##(INFO|FORMAT|FILTER)= self.end_ra_pos: - self.realignments.append([int(region_start), int(region_end), int(start_idx), - int(end_idx), - int(del_start), int(del_end), int( - pos_start), int(pos_end), - new_cigar, int(excess_start), int(excess_end)]) + self.realignments.append( + [ + int(region_start), + int(region_end), + int(start_idx), + int(end_idx), + int(del_start), + int(del_end), + int(pos_start), + int(pos_end), + new_cigar, + int(excess_start), + int(excess_end), + ] + ) self.end_ra_pos = int(pos_end) def fix_record(self, record, ref_seq): logger = logging.getLogger(Realign_Read.fix_record.__name__) self.realignments = sorted(self.realignments, key=lambda x: x[0]) cigartuples = record.cigartuples - start_hc = cigartuples[0][1] if cigartuples[ - 0][0] == CIGAR_HARDCLIP else 0 + start_hc = cigartuples[0][1] if cigartuples[0][0] == CIGAR_HARDCLIP else 0 end_hc = cigartuples[-1][1] if cigartuples[-1][0] == CIGAR_HARDCLIP else 0 if start_hc: cigartuples = cigartuples[1:] @@ -120,14 +148,34 @@ def fix_record(self, record, ref_seq): try: assert self.realignments except: - logger.error("Realignments are empty for {} at {}:{}".format( - record, self.chrom, self.pos)) + logger.error( + "Realignments are empty for {} at {}:{}".format( + record, self.chrom, self.pos + ) + ) raise Exception bias = 0 - for region_start, region_end, start_idx, end_idx, del_start, del_end, \ - pos_start, pos_end, new_cigar, excess_start, excess_end in self.realignments: - c_array = np.array(list(map(lambda x: [x[0], x[1][1] if x[1][0] - != CIGAR_DEL else 0], enumerate(cigartuples)))) + for ( + region_start, + region_end, + start_idx, + end_idx, + del_start, + del_end, + pos_start, + pos_end, + new_cigar, + excess_start, + excess_end, + ) in self.realignments: + c_array = np.array( + list( + map( + lambda x: [x[0], x[1][1] if x[1][0] != CIGAR_DEL else 0], + enumerate(cigartuples), + ) + ) + ) c_map = np.repeat(c_array[:, 0], c_array[:, 1]) c_i = c_map[start_idx - bias] @@ -136,59 +184,90 @@ def fix_record(self, record, ref_seq): end_match = np.nonzero(c_map == c_e)[0][-1] if excess_start > 0: - if (c_i > 0 and cigartuples[c_i - 1][0] == CIGAR_DEL and - cigartuples[c_i - 1][1] >= (excess_start)): + if ( + c_i > 0 + and cigartuples[c_i - 1][0] == CIGAR_DEL + and cigartuples[c_i - 1][1] >= (excess_start) + ): if excess_start > del_start: del_start = excess_start else: return record if excess_end > 0: - if (c_e < (len(cigartuples) - 1) and cigartuples[c_e + 1][0] == CIGAR_DEL - and cigartuples[c_e + 1][1] >= (excess_end)): + if ( + c_e < (len(cigartuples) - 1) + and cigartuples[c_e + 1][0] == CIGAR_DEL + and cigartuples[c_e + 1][1] >= (excess_end) + ): if excess_end > del_end: del_end = excess_end - elif not(c_e == len(c_array) - 1 or cigartuples[c_e + 1][0] == CIGAR_SOFTCLIP): + elif not ( + c_e == len(c_array) - 1 or cigartuples[c_e + 1][0] == CIGAR_SOFTCLIP + ): return record left_cigartuple = cigartuples[:c_i] if del_start == 0: if 
begin_match < (start_idx - bias): left_cigartuple.append( - [cigartuples[c_i][0], start_idx - bias - begin_match]) + [cigartuples[c_i][0], start_idx - bias - begin_match] + ) else: try: assert cigartuples[c_i - 1][0] == CIGAR_DEL except: - logger.error("Expect DEL (c_i) in positon {} at cigartuples {}".format( - c_i - 1, cigartuples)) + logger.error( + "Expect DEL (c_i) in position {} at cigartuples {}".format( + c_i - 1, cigartuples + ) + ) raise Exception if cigartuples[c_i - 1][1] > del_start: - left_cigartuple = left_cigartuple[ - :-1] + [[cigartuples[c_i - 1][0], cigartuples[c_i - 1][1] - del_start]] + left_cigartuple = left_cigartuple[:-1] + [ + [cigartuples[c_i - 1][0], cigartuples[c_i - 1][1] - del_start] + ] else: left_cigartuple = left_cigartuple[:-1] new_cigartuples_list.append(left_cigartuple) new_cigartuples_list.append(list(cigarstring_to_tuple(new_cigar))) - right_cigartuple = cigartuples[c_e + 1:] + right_cigartuple = cigartuples[c_e + 1 :] if del_end == 0: if end_match > (end_idx - bias): right_cigartuple = [ - [cigartuples[c_e][0], end_match - (end_idx - bias)]] + right_cigartuple + [cigartuples[c_e][0], end_match - (end_idx - bias)] + ] + right_cigartuple else: try: assert cigartuples[c_e + 1][0] == CIGAR_DEL except: - logger.info("Expect DEL (c_e) in positon {} at cigartuples {}, {}".format( - c_e + 1, cigartuples, self.realignments)) + logger.info( + "Expect DEL (c_e) in position {} at cigartuples {}, {}".format( + c_e + 1, cigartuples, self.realignments + ) + ) logger.info(cigartuple_to_string(cigartuples)) - logger.info([region_start, region_end, start_idx, end_idx, del_start, del_end, - pos_start, pos_end, new_cigar, excess_start, excess_end]) + logger.info( + [ + region_start, + region_end, + start_idx, + end_idx, + del_start, + del_end, + pos_start, + pos_end, + new_cigar, + excess_start, + excess_end, + ] + ) raise Exception if cigartuples[c_e + 1][1] > del_end: - right_cigartuple = [[cigartuples[ - c_e + 1][0], cigartuples[c_e + 1][1] - del_end]] + right_cigartuple[1:] + right_cigartuple = [ + [cigartuples[c_e + 1][0], cigartuples[c_e + 1][1] - del_end] + ] + right_cigartuple[1:] else: right_cigartuple = right_cigartuple[1:] cigartuples = right_cigartuple @@ -199,32 +278,52 @@ def fix_record(self, record, ref_seq): new_cigartuples_list.append([[CIGAR_HARDCLIP, end_hc]]) new_cigartuples = functools.reduce( - lambda x, y: merge_cigartuples(x, y), new_cigartuples_list) + lambda x, y: merge_cigartuples(x, y), new_cigartuples_list + ) - if len(new_cigartuples) > 2 and new_cigartuples[-1][0] == CIGAR_SOFTCLIP \ - and new_cigartuples[-2][0] == CIGAR_DEL: + if ( + len(new_cigartuples) > 2 + and new_cigartuples[-1][0] == CIGAR_SOFTCLIP + and new_cigartuples[-2][0] == CIGAR_DEL + ): new_cigartuples = new_cigartuples[:-2] + [new_cigartuples[-1]] elif new_cigartuples[-1][0] == CIGAR_DEL: new_cigartuples = new_cigartuples[:-1] - if len(new_cigartuples) > 2 and new_cigartuples[0][0] == CIGAR_SOFTCLIP \ - and new_cigartuples[1][0] == CIGAR_DEL: + if ( + len(new_cigartuples) > 2 + and new_cigartuples[0][0] == CIGAR_SOFTCLIP + and new_cigartuples[1][0] == CIGAR_DEL + ): new_cigartuples = [new_cigartuples[0]] + new_cigartuples[2:] elif new_cigartuples[0][0] == CIGAR_DEL: new_cigartuples = new_cigartuples[1:] try: - assert(sum(map(lambda x: x[1] if x[0] != CIGAR_DEL else 0, new_cigartuples)) == - sum(map(lambda x: x[1] if x[0] != CIGAR_DEL else 0, record.cigartuples))) + assert sum( + map(lambda x: x[1] if x[0] != CIGAR_DEL else 0, new_cigartuples) + ) == sum( + map(lambda x: x[1] 
if x[0] != CIGAR_DEL else 0, record.cigartuples) + ) except: - logger.error("Old and new cigarstrings have different lengthes: {} vs {}".format( - sum(map(lambda x: x[1] if x[0] != - CIGAR_DEL else 0, new_cigartuples)), - sum(map(lambda x: x[1] if x[0] != CIGAR_DEL else 0, record.cigartuples)))) + logger.error( + "Old and new cigarstrings have different lengthes: {} vs {}".format( + sum( + map(lambda x: x[1] if x[0] != CIGAR_DEL else 0, new_cigartuples) + ), + sum( + map( + lambda x: x[1] if x[0] != CIGAR_DEL else 0, + record.cigartuples, + ) + ), + ) + ) raise Exception record.cigarstring = cigartuple_to_string(new_cigartuples) NM = find_NM(record, ref_seq) - record.tags = list(filter( - lambda x: x[0] != "NM", record.tags)) + [("NM", int(NM))] + record.tags = list(filter(lambda x: x[0] != "NM", record.tags)) + [ + ("NM", int(NM)) + ] return record @@ -240,12 +339,16 @@ def get_cigar_stat(cigartuple, keys=[]): def find_NM(record, ref_seq): logger = logging.getLogger(find_NM.__name__) - positions = np.array(list(map(lambda x: x if x else -1, - (record.get_reference_positions(full_length=True))))) - sc_start = (record.cigartuples[0][0] == - CIGAR_SOFTCLIP) * record.cigartuples[0][1] - sc_end = (record.cigartuples[-1][0] == - CIGAR_SOFTCLIP) * record.cigartuples[-1][1] + positions = np.array( + list( + map( + lambda x: x if x else -1, + (record.get_reference_positions(full_length=True)), + ) + ) + ) + sc_start = (record.cigartuples[0][0] == CIGAR_SOFTCLIP) * record.cigartuples[0][1] + sc_end = (record.cigartuples[-1][0] == CIGAR_SOFTCLIP) * record.cigartuples[-1][1] q_seq = record.seq q_seq = q_seq[sc_start:] positions = positions[sc_start:] @@ -257,21 +360,25 @@ def find_NM(record, ref_seq): mn, mx = min(non_ins_positions), max(non_ins_positions) refseq = ref_seq.get_seq(mn, mx + 1) ref_array = np.array(list(map(lambda x: NUC_to_NUM[x.upper()], list(refseq))))[ - non_ins_positions - mn] - q_array = np.array(list(map(lambda x: NUC_to_NUM[x.upper()], list(q_seq))))[ - non_ins] + non_ins_positions - mn + ] + q_array = np.array(list(map(lambda x: NUC_to_NUM[x.upper()], list(q_seq))))[non_ins] cigar_stat = get_cigar_stat(record.cigartuples, [CIGAR_DEL, CIGAR_INS]) assert ref_array.shape[0] == q_array.shape[0] - NM = sum(abs(ref_array - q_array) > 0) + \ - cigar_stat[CIGAR_DEL] + cigar_stat[CIGAR_INS] + NM = ( + sum(abs(ref_array - q_array) > 0) + + cigar_stat[CIGAR_DEL] + + cigar_stat[CIGAR_INS] + ) return NM def cigarstring_to_tuple(cigarstring): logger = logging.getLogger(cigarstring_to_tuple.__name__) - return tuple((_CIGAR_OP_DICT[op], - int(length)) for length, - op in _CIGAR_PATTERN.findall(cigarstring)) + return tuple( + (_CIGAR_OP_DICT[op], int(length)) + for length, op in _CIGAR_PATTERN.findall(cigarstring) + ) def cigartuple_to_string(cigartuples): @@ -279,23 +386,30 @@ def cigartuple_to_string(cigartuples): return "".join(map(lambda x: "%d%s" % (x[1], _CIGAR_OPS[x[0]]), cigartuples)) -def prepare_fasta(work, region, input_bam, ref_fasta_file, include_ref, split_i, ds, filter_duplicate): +def prepare_fasta( + work, region, input_bam, ref_fasta_file, include_ref, split_i, ds, filter_duplicate +): logger = logging.getLogger(prepare_fasta.__name__) in_fasta_file = os.path.join( - work, region.__str__() + "_split_{}".format(split_i) + "_0.fasta") - info_file = os.path.join(work, region.__str__() + - "_split_{}".format(split_i) + ".txt") + work, region.__str__() + "_split_{}".format(split_i) + "_0.fasta" + ) + info_file = os.path.join( + work, region.__str__() + "_split_{}".format(split_i) 
+ ".txt" + ) with pysam.Fastafile(ref_fasta_file) as ref_fasta: with open(in_fasta_file, "w") as in_fasta: with open(info_file, "w") as info_txt: if include_ref: ref_seq = ref_fasta.fetch( - region.chrom, region.start, region.end + 1).upper() + region.chrom, region.start, region.end + 1 + ).upper() in_fasta.write(">0\n") in_fasta.write("%s\n" % ref_seq.upper()) cnt = 1 with pysam.Samfile(input_bam, "rb") as samfile: - for record in samfile.fetch(region.chrom, region.start, region.end + 1): + for record in samfile.fetch( + region.chrom, region.start, region.end + 1 + ): if record.is_unmapped: continue if filter_duplicate and record.is_duplicate: @@ -303,81 +417,127 @@ def prepare_fasta(work, region, input_bam, ref_fasta_file, include_ref, split_i, if record.is_supplementary and "SA" in dict(record.tags): sas = dict(record.tags)["SA"].split(";") sas = list(filter(None, sas)) - sas_cigs = list( - map(lambda x: x.split(",")[3], sas)) + sas_cigs = list(map(lambda x: x.split(",")[3], sas)) if record.cigarstring in sas_cigs: continue - positions = np.array(list(map(lambda x: x if x else -1, - (record.get_reference_positions( - full_length=True))))) + positions = np.array( + list( + map( + lambda x: x if x else -1, + (record.get_reference_positions(full_length=True)), + ) + ) + ) if not record.cigartuples: continue if np.random.rand() > ds: continue - sc_start = (record.cigartuples[0][0] == - CIGAR_SOFTCLIP) * record.cigartuples[0][1] - sc_end = (record.cigartuples[-1][0] == - CIGAR_SOFTCLIP) * record.cigartuples[-1][1] + sc_start = ( + record.cigartuples[0][0] == CIGAR_SOFTCLIP + ) * record.cigartuples[0][1] + sc_end = ( + record.cigartuples[-1][0] == CIGAR_SOFTCLIP + ) * record.cigartuples[-1][1] positions = positions[sc_start:] if sc_end > 0: positions = positions[:-sc_end] rstart = max(region.start, record.pos) rend = min(record.aend - 1, region.end) pos_start = positions[positions >= rstart][0] - pos_end = positions[ - (positions <= rend) & (positions >= 0)][-1] + pos_end = positions[(positions <= rend) & (positions >= 0)][-1] del_start = pos_start - rstart del_end = rend - pos_end - start_idx = np.nonzero(positions == pos_start)[ - 0][0] + sc_start - end_idx = np.nonzero(positions == pos_end)[ - 0][0] + sc_start - - if max(end_idx - start_idx, pos_end - pos_start) >= (region.span() * 0.75): - c_array = np.array(list(map(lambda x: [x[0], x[1][1] if x[1][0] - != CIGAR_DEL else 0], - enumerate(record.cigartuples)))) + start_idx = np.nonzero(positions == pos_start)[0][0] + sc_start + end_idx = np.nonzero(positions == pos_end)[0][0] + sc_start + + if max(end_idx - start_idx, pos_end - pos_start) >= ( + region.span() * 0.75 + ): + c_array = np.array( + list( + map( + lambda x: [ + x[0], + x[1][1] if x[1][0] != CIGAR_DEL else 0, + ], + enumerate(record.cigartuples), + ) + ) + ) c_map = np.repeat(c_array[:, 0], c_array[:, 1]) c_i = c_map[start_idx] c_e = c_map[end_idx] begin_match = np.nonzero(c_map == c_i)[0][0] end_match = np.nonzero(c_map == c_e)[0][-1] - my_cigartuples = record.cigartuples[c_i:c_e + 1] + my_cigartuples = record.cigartuples[c_i : c_e + 1] positions_ = positions[ - (start_idx - sc_start):(end_idx - sc_start)] + (start_idx - sc_start) : (end_idx - sc_start) + ] non_ins = np.nonzero(positions_ >= 0) - refseq = ref_fasta.fetch(region.chrom, positions_[non_ins][0], - positions_[non_ins][-1] + 1).upper() - q_seq = record.seq[start_idx:end_idx + 1] + refseq = ref_fasta.fetch( + region.chrom, + positions_[non_ins][0], + positions_[non_ins][-1] + 1, + ).upper() + q_seq = 
record.seq[start_idx : end_idx + 1] non_ins_positions = positions_[non_ins] - mn, mx = min(non_ins_positions), max( - non_ins_positions) - ref_array = np.array(list(map(lambda x: - NUC_to_NUM[ - x.upper()], - list(refseq))))[non_ins_positions - mn] - q_array = np.array(list(map(lambda x: NUC_to_NUM[x.upper()], list(q_seq))))[ - non_ins] + mn, mx = min(non_ins_positions), max(non_ins_positions) + ref_array = np.array( + list(map(lambda x: NUC_to_NUM[x.upper()], list(refseq))) + )[non_ins_positions - mn] + q_array = np.array( + list(map(lambda x: NUC_to_NUM[x.upper()], list(q_seq))) + )[non_ins] cigar_stat = get_cigar_stat( - my_cigartuples, [CIGAR_DEL, CIGAR_INS]) + my_cigartuples, [CIGAR_DEL, CIGAR_INS] + ) assert ref_array.shape[0] == q_array.shape[0] NM_SNP = sum(abs(ref_array - q_array) > 0) - NM_INDEL = cigar_stat[ - CIGAR_DEL] + cigar_stat[CIGAR_INS] + del_start + del_end + NM_INDEL = ( + cigar_stat[CIGAR_DEL] + + cigar_stat[CIGAR_INS] + + del_start + + del_end + ) in_fasta.write(">%s\n" % cnt) in_fasta.write( - "%s\n" % record.seq[start_idx:end_idx + 1].upper()) - info_txt.write("\t".join(map(str, [cnt, record.query_name, record.pos, - record.cigarstring, start_idx, - end_idx, - del_start, del_end, pos_start, - pos_end, NM_SNP, NM_INDEL])) + "\n") + "%s\n" % record.seq[start_idx : end_idx + 1].upper() + ) + info_txt.write( + "\t".join( + map( + str, + [ + cnt, + record.query_name, + record.pos, + record.cigarstring, + start_idx, + end_idx, + del_start, + del_end, + pos_start, + pos_end, + NM_SNP, + NM_INDEL, + ], + ) + ) + + "\n" + ) cnt += 1 return in_fasta_file, info_file -def split_bam_to_chunks(work, region, input_bam, chunk_size=200, - chunk_scale=1.5, do_split=False, filter_duplicate=False): +def split_bam_to_chunks( + work, + region, + input_bam, + chunk_size=200, + chunk_scale=1.5, + do_split=False, + filter_duplicate=False, +): logger = logging.getLogger(split_bam_to_chunks.__name__) records = [] with pysam.Samfile(input_bam, "rb") as samfile: @@ -394,14 +554,22 @@ def split_bam_to_chunks(work, region, input_bam, chunk_size=200, continue positions = np.array( - list(map(lambda x: x if x else -1, (record.get_reference_positions(full_length=True))))) + list( + map( + lambda x: x if x else -1, + (record.get_reference_positions(full_length=True)), + ) + ) + ) if not record.cigartuples: continue - sc_start = (record.cigartuples[0][0] == - CIGAR_SOFTCLIP) * record.cigartuples[0][1] - sc_end = (record.cigartuples[-1][0] == - CIGAR_SOFTCLIP) * record.cigartuples[-1][1] + sc_start = ( + record.cigartuples[0][0] == CIGAR_SOFTCLIP + ) * record.cigartuples[0][1] + sc_end = (record.cigartuples[-1][0] == CIGAR_SOFTCLIP) * record.cigartuples[ + -1 + ][1] positions = positions[sc_start:] if sc_end > 0: positions = positions[:-sc_end] @@ -411,7 +579,7 @@ def split_bam_to_chunks(work, region, input_bam, chunk_size=200, pos_end = positions[(positions <= rend) & (positions >= 0)][-1] start_idx = np.nonzero(positions == pos_start)[0][0] + sc_start end_idx = np.nonzero(positions == pos_end)[0][0] + sc_start - q_seq = record.seq[start_idx:end_idx + 1] + q_seq = record.seq[start_idx : end_idx + 1] records.append([record, len(q_seq)]) @@ -427,24 +595,26 @@ def split_bam_to_chunks(work, region, input_bam, chunk_size=200, lens = [] ds = [] n_split = (len(records) // new_chunk_size) + 1 - if 0 < (len(records) - ((n_split - 1) * new_chunk_size) + new_chunk_size) \ - < new_chunk_size * chunk_scale: + if ( + 0 + < (len(records) - ((n_split - 1) * new_chunk_size) + new_chunk_size) + < new_chunk_size * 
chunk_scale + ): n_split -= 1 for i in range(n_split): i_start = i * new_chunk_size - i_end = (i + 1) * \ - new_chunk_size if i < (n_split - 1) else len(records) + i_end = (i + 1) * new_chunk_size if i < (n_split - 1) else len(records) split_input_bam = os.path.join( - work, region.__str__() + "_split_{}.bam".format(i)) + work, region.__str__() + "_split_{}.bam".format(i) + ) with pysam.AlignmentFile(input_bam, "rb") as samfile: - with pysam.AlignmentFile(split_input_bam, "wb", - template=samfile) as out_samfile: + with pysam.AlignmentFile( + split_input_bam, "wb", template=samfile + ) as out_samfile: for record in records[i_start:i_end]: out_samfile.write(record) - pysam.sort("-o", "{}.sorted.bam".format(split_input_bam), - split_input_bam) - shutil.move("{}.sorted.bam".format( - split_input_bam), split_input_bam) + pysam.sort("-o", "{}.sorted.bam".format(split_input_bam), split_input_bam) + shutil.move("{}.sorted.bam".format(split_input_bam), split_input_bam) pysam.index(split_input_bam) bams.append(split_input_bam) @@ -460,8 +630,8 @@ def split_bam_to_chunks(work, region, input_bam, chunk_size=200, def read_info(info_file): logger = logging.getLogger(read_info.__name__) info = {} - with open(info_file, 'r') as csvfile: - csvreader = csv.reader(csvfile, delimiter='\t', quotechar='|') + with open(info_file, "r") as csvfile: + csvreader = csv.reader(csvfile, delimiter="\t", quotechar="|") for row in csvreader: info[int(row[0])] = Read_Info(row[1:]) return info @@ -474,8 +644,11 @@ def find_cigar(alignment): event_pos = np.append([-1], np.nonzero(np.diff(augmented_alignment))) event_len = np.diff(event_pos) if sum(event_len) != alignment.shape[0]: - logger.error("event_len is different from length of alignment: {} vs {}".format( - sum(event_len), alignment.shape[0])) + logger.error( + "event_len is different from length of alignment: {} vs {}".format( + sum(event_len), alignment.shape[0] + ) + ) raise Exception event_type = augmented_alignment[:-1][event_pos[1:]] @@ -492,20 +665,37 @@ def extract_new_cigars(region, info_file, out_fasta_file): return {}, {}, {} if set(map(int, records.keys())) ^ set(range(len(records))): - logger.error("sequences are missing in the alignment {}".format( - set(map(int, records.keys())) ^ set(range(len(records))))) + logger.error( + "sequences are missing in the alignment {}".format( + set(map(int, records.keys())) ^ set(range(len(records))) + ) + ) raise Exception - alignment = list(map(lambda x: x[1], sorted(map(lambda x: [int(x[0]), list(map(lambda x: 0 if x == "-" - else 1, x[1].seq))], - records.items()), - key=lambda x: x[0]))) + alignment = list( + map( + lambda x: x[1], + sorted( + map( + lambda x: [ + int(x[0]), + list(map(lambda x: 0 if x == "-" else 1, x[1].seq)), + ], + records.items(), + ), + key=lambda x: x[0], + ), + ) + ) ref_seq = np.array(alignment[0]) pos_ref = np.cumsum(alignment[0]) - 1 alignment = np.array(alignment[1:]) - ref_seq - alignment = (alignment == 0) * (1 - ref_seq) * (-1) + (alignment == 0) * ref_seq * \ - CIGAR_MATCH + (alignment == 1) * CIGAR_INS + \ - (alignment == -1) * CIGAR_DEL + alignment = ( + (alignment == 0) * (1 - ref_seq) * (-1) + + (alignment == 0) * ref_seq * CIGAR_MATCH + + (alignment == 1) * CIGAR_INS + + (alignment == -1) * CIGAR_DEL + ) N = alignment.shape[0] new_cigars = {} excess_start = {} @@ -514,13 +704,16 @@ def extract_new_cigars(region, info_file, out_fasta_file): core_alignment = alignment[i, :][alignment[i, :] >= 0] new_cigars[i + 1] = find_cigar(core_alignment) if CIGAR_MATCH in alignment[i, :]: - 
excess_start[i + 1] = (info[i + 1].pos_start - region.start) - \ - (pos_ref[np.where((alignment[i, :] == CIGAR_MATCH))[0][0]]) + excess_start[i + 1] = (info[i + 1].pos_start - region.start) - ( + pos_ref[np.where((alignment[i, :] == CIGAR_MATCH))[0][0]] + ) excess_end[i + 1] = (region.end - info[i + 1].pos_end) - ( - max(pos_ref) - (pos_ref[np.where((alignment[i, :] == CIGAR_MATCH))[0][-1]])) + max(pos_ref) + - (pos_ref[np.where((alignment[i, :] == CIGAR_MATCH))[0][-1]]) + ) else: - excess_start[i + 1] = (info[i + 1].pos_start - region.start) - excess_end[i + 1] = (region.end - info[i + 1].pos_end) + excess_start[i + 1] = info[i + 1].pos_start - region.start + excess_end[i + 1] = region.end - info[i + 1].pos_end return new_cigars, excess_start, excess_end @@ -531,11 +724,13 @@ def extract_consensus(region, out_fasta_file): if len(records) <= 1: return {}, {} try: - assert(not set(map(int, records.keys())) - ^ set(range(1, len(records) + 1))) + assert not set(map(int, records.keys())) ^ set(range(1, len(records) + 1)) except: - logger.error("sequences are missing in the alignment {}".format( - set(map(int, records.keys())) ^ set(range(1, len(records) + 1)))) + logger.error( + "sequences are missing in the alignment {}".format( + set(map(int, records.keys())) ^ set(range(1, len(records) + 1)) + ) + ) raise Exception n = len(records) align_len = len(records["1"].seq) @@ -547,6 +742,7 @@ def nuc_to_num_convert(nuc): if nuc.upper() not in NUC_to_NUM.keys(): nuc = "-" return NUC_to_NUM[nuc.upper()] + for i, record in records.items(): ii = int(i) - 1 msa[ii] = list(map(lambda x: nuc_to_num_convert(x), record.seq)) @@ -573,8 +769,11 @@ def get_final_msa(region, msa_0, consensus, out_fasta_file_1, out_fasta_file_fin if len(records) <= 1: return False if set(map(int, records.keys())) ^ set(range(2)): - logger.error("sequences are missing in the alignment {}".format( - set(map(int, records.keys())) ^ set(range(1, len(records) + 1)))) + logger.error( + "sequences are missing in the alignment {}".format( + set(map(int, records.keys())) ^ set(range(1, len(records) + 1)) + ) + ) raise Exception align_len = len(records["0"].seq) msa_1 = [[] for i in range(2)] @@ -582,8 +781,7 @@ def get_final_msa(region, msa_0, consensus, out_fasta_file_1, out_fasta_file_fin ii = int(i) msa_1[ii] = list(map(lambda x: nuc_to_num_convert(x), record.seq)) msa_1 = np.array(msa_1, dtype=int) - consensus_array = np.array( - list(map(lambda x: nuc_to_num_convert(x), consensus))) + consensus_array = np.array(list(map(lambda x: nuc_to_num_convert(x), consensus))) consensus_cumsum = np.cumsum(consensus_array > 0) new_cols = np.where(msa_1[1, :] == 0)[0] new_cols -= np.arange(len(new_cols)) @@ -594,15 +792,13 @@ def get_final_msa(region, msa_0, consensus, out_fasta_file_1, out_fasta_file_fin else: inser_new_cols.append(0) msa = np.insert(msa_0, inser_new_cols, 0, axis=1) - new_consensus_array = np.insert( - consensus_array, inser_new_cols, 100, axis=0) + new_consensus_array = np.insert(consensus_array, inser_new_cols, 100, axis=0) msa = np.insert(msa, 0, 0, axis=0) msa[0, np.where(new_consensus_array > 0)[0]] = msa_1[0] with open(out_fasta_file_final, "w") as out_fasta: for i, seq in enumerate(msa.tolist()): out_fasta.write(">%s\n" % i) - out_fasta.write("%s\n" % "".join( - map(lambda y: NUM_to_NUC[y], seq))) + out_fasta.write("%s\n" % "".join(map(lambda y: NUM_to_NUC[y], seq))) return True @@ -614,18 +810,32 @@ def get_entries(region, info_file, new_cigars, excess_start, excess_end): if len(info) != N: logger.error( "number of 
items in info is different from length of new cigars: {} vs {}".format( - len(info), N)) + len(info), N + ) + ) raise Exception entries = [] for i in range(1, N + 1): - entries.append([region.chrom, region.start, region.end, info[i].query_name, info[i].pos, - info[i].start_idx, info[i].end_idx, - info[i].cigarstring, info[ - i].del_start, info[i].del_end, - info[i].pos_start, info[ - i].pos_end, new_cigars[i], excess_start[i], - excess_end[i]]) + entries.append( + [ + region.chrom, + region.start, + region.end, + info[i].query_name, + info[i].pos, + info[i].start_idx, + info[i].end_idx, + info[i].cigarstring, + info[i].del_start, + info[i].del_end, + info[i].pos_start, + info[i].pos_end, + new_cigars[i], + excess_start[i], + excess_end[i], + ] + ) return entries @@ -636,7 +846,9 @@ def merge_cigartuples(tuple1, tuple2): if not tuple2: return tuple1 if tuple1[-1][0] == tuple2[0][0]: - return tuple1[:-1] + [[tuple1[-1][0], tuple1[-1][1] + tuple2[0][1]]] + tuple2[1:] + return ( + tuple1[:-1] + [[tuple1[-1][0], tuple1[-1][1] + tuple2[0][1]]] + tuple2[1:] + ) return tuple1 + tuple2 @@ -656,25 +868,47 @@ def find_realign_dict(realign_bed_file, chrom): for line in skip_empty(r_f): interval = line.strip().split("\t") chrom, start, end, query_name = interval[0:4] - pos, start_idx, end_idx, cigarstring, del_start, del_end, pos_start, pos_end, new_cigar, \ - excess_start, excess_end = interval[6:] + ( + pos, + start_idx, + end_idx, + cigarstring, + del_start, + del_end, + pos_start, + pos_end, + new_cigar, + excess_start, + excess_end, + ) = interval[6:] q_key = "{}_{}_{}".format(query_name, pos, cigarstring) if q_key not in realign_dict: - realign_dict[q_key] = Realign_Read( - query_name, chrom, pos, cigarstring) - realign_dict[q_key].add_realignment(start, end, start_idx, end_idx, del_start, - del_end, pos_start, pos_end, new_cigar, excess_start, - excess_end) + realign_dict[q_key] = Realign_Read(query_name, chrom, pos, cigarstring) + realign_dict[q_key].add_realignment( + start, + end, + start_idx, + end_idx, + del_start, + del_end, + pos_start, + pos_end, + new_cigar, + excess_start, + excess_end, + ) chrom_regions.add("{}-{}".format(start, end)) - chrom_regions = sorted( - map(lambda x: list(map(int, x.split("-"))), chrom_regions)) + chrom_regions = sorted(map(lambda x: list(map(int, x.split("-"))), chrom_regions)) return realign_dict, chrom_regions def correct_bam_chrom(input_record): work, input_bam, realign_bed_file, ref_fasta_file, chrom = input_record thread_logger = logging.getLogger( - "{} ({})".format(correct_bam_chrom.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + correct_bam_chrom.__name__, multiprocessing.current_process().name + ) + ) try: fasta_file = pysam.Fastafile(ref_fasta_file) ref_seq = fasta_seq(fasta_file) @@ -682,8 +916,7 @@ def correct_bam_chrom(input_record): out_sam = os.path.join(work, "output_{}.sam".format(chrom)) with pysam.AlignmentFile(input_bam, "rb") as samfile: with pysam.AlignmentFile(out_sam, "w", template=samfile) as out_samfile: - realign_dict, chrom_regions = find_realign_dict( - realign_bed_file, chrom) + realign_dict, chrom_regions = find_realign_dict(realign_bed_file, chrom) if chrom_regions: region_cnt = 0 next_region = chrom_regions[0] @@ -694,7 +927,11 @@ def correct_bam_chrom(input_record): for record in samfile.fetch(chrom): if record.is_unmapped: continue - if not done_regions and in_active_region and record.pos > next_region[1]: + if ( + not done_regions + and in_active_region + and record.pos > next_region[1] + ): if 
region_cnt == len(chrom_regions) - 1: done_regions = True else: @@ -704,12 +941,12 @@ def correct_bam_chrom(input_record): out_samfile.write(record) continue q_key = "{}_{}_{}".format( - record.query_name, record.pos, record.cigarstring) + record.query_name, record.pos, record.cigarstring + ) if q_key not in realign_dict: out_samfile.write(record) continue - fixed_record = realign_dict[ - q_key].fix_record(record, ref_seq) + fixed_record = realign_dict[q_key].fix_record(record, ref_seq) out_samfile.write(fixed_record) in_active_region = True return out_sam @@ -727,8 +964,7 @@ def correct_bam_all(work, input_bam, output_bam, ref_fasta_file, realign_bed_fil ref_seq = fasta_seq(fasta_file) for chrom in samfile.references: ref_seq.set_chrom(chrom) - realign_dict, chrom_regions = find_realign_dict( - realign_bed_file, chrom) + realign_dict, chrom_regions = find_realign_dict(realign_bed_file, chrom) if chrom_regions: region_cnt = 0 next_region = chrom_regions[0] @@ -739,7 +975,11 @@ def correct_bam_all(work, input_bam, output_bam, ref_fasta_file, realign_bed_fil for record in samfile.fetch(chrom): if record.is_unmapped: continue - if not done_regions and in_active_region and record.pos > next_region[1]: + if ( + not done_regions + and in_active_region + and record.pos > next_region[1] + ): if region_cnt == len(chrom_regions) - 1: done_regions = True else: @@ -749,12 +989,12 @@ def correct_bam_all(work, input_bam, output_bam, ref_fasta_file, realign_bed_fil out_samfile.write(record) continue q_key = "{}_{}_{}".format( - record.query_name, record.pos, record.cigarstring) + record.query_name, record.pos, record.cigarstring + ) if q_key not in realign_dict: out_samfile.write(record) continue - fixed_record = realign_dict[ - q_key].fix_record(record, ref_seq) + fixed_record = realign_dict[q_key].fix_record(record, ref_seq) out_samfile.write(fixed_record) in_active_region = True if os.path.exists(output_bam): @@ -779,8 +1019,9 @@ def concatenate_sam_files(files, output, bam_header): return output -def parallel_correct_bam(work, input_bam, output_bam, ref_fasta_file, realign_bed_file, - num_threads): +def parallel_correct_bam( + work, input_bam, output_bam, ref_fasta_file, realign_bed_file, num_threads +): logger = logging.getLogger(parallel_correct_bam.__name__) if num_threads > 1: pool = multiprocessing.Pool(num_threads) @@ -792,7 +1033,8 @@ def parallel_correct_bam(work, input_bam, output_bam, ref_fasta_file, realign_be with pysam.AlignmentFile(input_bam, "rb") as samfile: for chrom in samfile.references: map_args.append( - (work, input_bam, realign_bed_file, ref_fasta_file, chrom)) + (work, input_bam, realign_bed_file, ref_fasta_file, chrom) + ) try: sams = pool.map_async(correct_bam_chrom, map_args).get() @@ -811,8 +1053,9 @@ def parallel_correct_bam(work, input_bam, output_bam, ref_fasta_file, realign_be concatenate_sam_files(sams, output_sam, bam_header) if os.path.exists(output_sam): with pysam.AlignmentFile(output_sam, "r") as samfile: - with pysam.AlignmentFile(output_bam, "wb", - template=samfile) as out_samfile: + with pysam.AlignmentFile( + output_bam, "wb", template=samfile + ) as out_samfile: for record in samfile.fetch(): out_samfile.write(record) pysam.index(output_bam) @@ -820,19 +1063,30 @@ def parallel_correct_bam(work, input_bam, output_bam, ref_fasta_file, realign_be for sam in [bam_header] + sams: os.remove(sam) else: - correct_bam_all(work, input_bam, output_bam, - ref_fasta_file, realign_bed_file) + correct_bam_all(work, input_bam, output_bam, ref_fasta_file, realign_bed_file) 
-def run_msa(in_fasta_file, match_score, mismatch_penalty, gap_open_penalty, gap_ext_penalty, - msa_binary): +def run_msa( + in_fasta_file, + match_score, + mismatch_penalty, + gap_open_penalty, + gap_ext_penalty, + msa_binary, +): logger = logging.getLogger(run_msa.__name__) if not os.path.exists(msa_binary): raise IOError("File not found: {}".format(msa_binary)) out_fasta_file = ".".join(in_fasta_file.split(".")[:-1]) + "_aligned.fasta" - cmd = "{} -A {} -B {} -O {} -E {} -i {} -o {}".format(msa_binary, match_score, mismatch_penalty, - gap_open_penalty, gap_ext_penalty, - in_fasta_file, out_fasta_file) + cmd = "{} -A {} -B {} -O {} -E {} -i {} -o {}".format( + msa_binary, + match_score, + mismatch_penalty, + gap_open_penalty, + gap_ext_penalty, + in_fasta_file, + out_fasta_file, + ) if not os.path.exists(out_fasta_file): run_shell_command(cmd, run_logger=logger) return out_fasta_file @@ -851,9 +1105,9 @@ def do_realign(region, info_file, max_realign_dp, thr_realign=0.0135): c += 1 eps = 0.0001 if (c < max_realign_dp) and ( - (sum_nm_snp + sum_nm_indel - ) / float(c + eps) / float(region.span() + eps) - > thr_realign): + (sum_nm_snp + sum_nm_indel) / float(c + eps) / float(region.span() + eps) + > thr_realign + ): return True return False @@ -865,12 +1119,29 @@ def find_var(out_fasta_file, snp_min_af, del_min_af, ins_min_af, scale_maf, simp logger = logging.getLogger(find_var.__name__) records = SeqIO.to_dict(SeqIO.parse(out_fasta_file, "fasta")) if set(map(int, records.keys())) ^ set(range(len(records))): - logger.error("sequences are missing in the alignment {}".format( - set(map(int, records.keys())) ^ set(range(len(records))))) + logger.error( + "sequences are missing in the alignment {}".format( + set(map(int, records.keys())) ^ set(range(len(records))) + ) + ) raise Exception - alignment = np.array(list(map(lambda x: x[1], sorted(map(lambda x: [int(x[0]), list(map( - lambda x: NUC_to_NUM[x.upper()], x[1].seq))], records.items()), - key=lambda x: x[0])))) + alignment = np.array( + list( + map( + lambda x: x[1], + sorted( + map( + lambda x: [ + int(x[0]), + list(map(lambda x: NUC_to_NUM[x.upper()], x[1].seq)), + ], + records.items(), + ), + key=lambda x: x[0], + ), + ) + ) + ) ref_seq = alignment[0, :] counts = np.zeros((5, alignment.shape[1])) for i in range(5): @@ -891,9 +1162,11 @@ def find_var(out_fasta_file, snp_min_af, del_min_af, ins_min_af, scale_maf, simp alt_base = sorted_idx[-2] alt_count = counts[alt_base, i] af = alt_count / (alt_count + ref_count + 0.0001) - if ((alt_base != '-' and ref_base == "-" and af > ins_min_af) or - (alt_base != '-' and ref_base != "-" and af > snp_min_af) or - (alt_base == '-' and ref_base != "-" and af > del_min_af)): + if ( + (alt_base != "-" and ref_base == "-" and af > ins_min_af) + or (alt_base != "-" and ref_base != "-" and af > snp_min_af) + or (alt_base == "-" and ref_base != "-" and af > del_min_af) + ): alt_seq.append(alt_base) afs.append(af) i_afs.append(i) @@ -907,10 +1180,8 @@ def find_var(out_fasta_file, snp_min_af, del_min_af, ins_min_af, scale_maf, simp # for ii in np.where(afs<=thr)[0]: alt_seq[i_afs[ii]] = ref_seq[i_afs[ii]] afs = np.array(afs) - ref_seq_ = "".join( - map(lambda x: NUM_to_NUC[x], filter(lambda x: x > 0, ref_seq))) - alt_seq_ = "".join( - map(lambda x: NUM_to_NUC[x], filter(lambda x: x > 0, alt_seq))) + ref_seq_ = "".join(map(lambda x: NUM_to_NUC[x], filter(lambda x: x > 0, ref_seq))) + alt_seq_ = "".join(map(lambda x: NUM_to_NUC[x], filter(lambda x: x > 0, alt_seq))) if not simplify: variants = [[0, ref_seq_, 
alt_seq_, afs]] else: @@ -941,12 +1212,19 @@ def find_var(out_fasta_file, snp_min_af, del_min_af, ins_min_af, scale_maf, simp done = True if done: if current_alt: - rr = "".join(map(lambda x: NUM_to_NUC[ - x], filter(lambda x: x > 0, current_ref))) - aa = "".join(map(lambda x: NUM_to_NUC[ - x], filter(lambda x: x > 0, current_alt))) - variants.append( - [current_bias, rr, aa, np.array(current_af)]) + rr = "".join( + map( + lambda x: NUM_to_NUC[x], + filter(lambda x: x > 0, current_ref), + ) + ) + aa = "".join( + map( + lambda x: NUM_to_NUC[x], + filter(lambda x: x > 0, current_alt), + ) + ) + variants.append([current_bias, rr, aa, np.array(current_af)]) done = False current_ref = [] current_alt = [] @@ -971,14 +1249,14 @@ def TrimREFALT(ref, alt, pos): logger = logging.getLogger(TrimREFALT.__name__) alte = len(alt) refe = len(ref) - while (alte > 1 and refe > 1 and alt[alte - 1] == ref[refe - 1]): + while alte > 1 and refe > 1 and alt[alte - 1] == ref[refe - 1]: alte -= 1 refe -= 1 alt = alt[0:alte] ref = ref[0:refe] s = 0 - while (s < (len(alt) - 1) and s < (len(ref) - 1) and alt[s] == ref[s]): + while s < (len(alt) - 1) and s < (len(ref) - 1) and alt[s] == ref[s]: s += 1 alt = alt[s:] @@ -988,30 +1266,55 @@ def TrimREFALT(ref, alt, pos): def run_realignment(input_record): - work, ref_fasta_file, target_region, pad, chunk_size, chunk_scale, \ - snp_min_af, del_min_af, ins_min_af, len_chr, input_bam, \ - match_score, mismatch_penalty, gap_open_penalty, gap_ext_penalty, \ - max_realign_dp, \ - filter_duplicate, \ - msa_binary, get_var, do_split = input_record + ( + work, + ref_fasta_file, + target_region, + pad, + chunk_size, + chunk_scale, + snp_min_af, + del_min_af, + ins_min_af, + len_chr, + input_bam, + match_score, + mismatch_penalty, + gap_open_penalty, + gap_ext_penalty, + max_realign_dp, + filter_duplicate, + msa_binary, + get_var, + do_split, + ) = input_record ref_fasta = pysam.Fastafile(ref_fasta_file) thread_logger = logging.getLogger( - "{} ({})".format(run_realignment.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + run_realignment.__name__, multiprocessing.current_process().name + ) + ) try: region = Region(target_region, pad, len_chr) not_realigned_region = None original_tempdir = tempfile.tempdir - bed_tempdir = os.path.join( - work, "bed_tmpdir_{}".format(region.__str__())) + bed_tempdir = os.path.join(work, "bed_tmpdir_{}".format(region.__str__())) if not os.path.exists(bed_tempdir): os.mkdir(bed_tempdir) tempfile.tempdir = bed_tempdir variants = [] all_entries = [] input_bam_splits, lens_splits, ds_splits = split_bam_to_chunks( - work, region, input_bam, chunk_size, chunk_scale, do_split or not get_var, filter_duplicate) + work, + region, + input_bam, + chunk_size, + chunk_scale, + do_split or not get_var, + filter_duplicate, + ) new_seqs = [] new_ref_seq = "" skipped = 0 @@ -1022,24 +1325,45 @@ def run_realignment(input_record): afss = [] for i, i_bam in enumerate(input_bam_splits): in_fasta_file, info_file = prepare_fasta( - work, region, i_bam, ref_fasta_file, True, i, ds_splits[i], filter_duplicate) + work, + region, + i_bam, + ref_fasta_file, + True, + i, + ds_splits[i], + filter_duplicate, + ) if do_realign(region, info_file, max_realign_dp): out_fasta_file_0 = run_msa( - in_fasta_file, match_score, mismatch_penalty, gap_open_penalty, - gap_ext_penalty, msa_binary) + in_fasta_file, + match_score, + mismatch_penalty, + gap_open_penalty, + gap_ext_penalty, + msa_binary, + ) if get_var: var = find_var( - out_fasta_file_0, snp_min_af, 
del_min_af, ins_min_af, scale_maf, False) - assert(len(var) == 1) + out_fasta_file_0, + snp_min_af, + del_min_af, + ins_min_af, + scale_maf, + False, + ) + assert len(var) == 1 _, ref_seq_, alt_seq_, afs = var[0] afss.append(afs) new_ref_seq = ref_seq_ new_seqs.append(alt_seq_) new_cigars, excess_start, excess_end = extract_new_cigars( - region, info_file, out_fasta_file_0) + region, info_file, out_fasta_file_0 + ) if new_cigars: entries = get_entries( - region, info_file, new_cigars, excess_start, excess_end) + region, info_file, new_cigars, excess_start, excess_end + ) all_entries.extend(entries) else: skipped += 1 @@ -1049,52 +1373,58 @@ def run_realignment(input_record): new_seqs = [new_ref_seq] + new_seqs new_seqs = [new_ref_seq] + new_seqs consensus_fasta = os.path.join( - work, region.__str__() + "_consensus.fasta") + work, region.__str__() + "_consensus.fasta" + ) with open(consensus_fasta, "w") as output_handle: for i, seq in enumerate(new_seqs): record = SeqRecord( - Seq(seq, DNAAlphabet.letters), id=str(i), description="") + Seq(seq, DNAAlphabet.letters), id=str(i), description="" + ) SeqIO.write(record, output_handle, "fasta") consensus_fasta_aligned = run_msa( - consensus_fasta, match_score, mismatch_penalty, gap_open_penalty, - gap_ext_penalty, msa_binary) + consensus_fasta, + match_score, + mismatch_penalty, + gap_open_penalty, + gap_ext_penalty, + msa_binary, + ) vars_ = find_var( - consensus_fasta_aligned, snp_min_af, del_min_af, ins_min_af, 1, True) + consensus_fasta_aligned, snp_min_af, del_min_af, ins_min_af, 1, True + ) for var in vars_: pos_, ref_seq, alt_seq, afs = var if ref_seq != alt_seq: - ref, alt, pos = ref_seq, alt_seq, int( - region.start) + 1 + pos_ + ref, alt, pos = ref_seq, alt_seq, int(region.start) + 1 + pos_ if pos > 1: num_add_before = min(40, pos - 1) before = ref_fasta.fetch( - region.chrom, pos - num_add_before, pos - 1).upper() + region.chrom, pos - num_add_before, pos - 1 + ).upper() pos -= num_add_before - 1 ref = before + ref alt = before + alt - ref, alt, pos = TrimREFALT( - ref, alt, pos) + ref, alt, pos = TrimREFALT(ref, alt, pos) a = int(np.ceil(np.max(afs) * len(afss))) - af = sum(sorted(map(lambda x: - np.max(x) if x.shape[0] > 0 else 0, - afss))[-a:]) / float(len(afss)) + af = sum( + sorted( + map(lambda x: np.max(x) if x.shape[0] > 0 else 0, afss) + )[-a:] + ) / float(len(afss)) dp = int(sum(lens_splits)) ao = int(af * dp) ro = dp - ao if ref == "" and pos > 1: pos -= 1 - r_ = ref_fasta.fetch( - region.chrom, pos - 1, pos).upper() + r_ = ref_fasta.fetch(region.chrom, pos - 1, pos).upper() ref = r_ + ref alt = r_ + alt if alt == "" and pos > 1: pos -= 1 - r_ = ref_fasta.fetch( - region.chrom, pos - 1, pos).upper() + r_ = ref_fasta.fetch(region.chrom, pos - 1, pos).upper() ref = r_ + ref alt = r_ + alt - variants.append( - [region.chrom, pos, ref, alt, dp, ro, ao]) + variants.append([region.chrom, pos, ref, alt, dp, ro, ao]) else: if skipped > 0: not_realigned_region = target_region @@ -1109,7 +1439,6 @@ def run_realignment(input_record): class fasta_seq: - def __init__(self, fasta_pysam): self.fasta_pysam = fasta_pysam self.chrom = "" @@ -1123,8 +1452,9 @@ def get_seq(self, start, end=[]): return self.fasta_pysam.fetch(self.chrom, start, end).upper() -def extend_regions_hp(region_bed_file, extended_region_bed_file, ref_fasta_file, - chrom_lengths, pad): +def extend_regions_hp( + region_bed_file, extended_region_bed_file, ref_fasta_file, chrom_lengths, pad +): # If boundaries of regions are in the middle of a homopolymer, this function 
extends the region # to fully include the homopolymer logger = logging.getLogger(extend_regions_hp.__name__) @@ -1135,10 +1465,8 @@ def extend_regions_hp(region_bed_file, extended_region_bed_file, ref_fasta_file, interval = line.strip().split("\t") chrom, start, end = interval[0:3] start, end = int(start), int(end) - s_base = ref_fasta.fetch( - chrom, start - pad, start - pad + 1).upper() - e_base = ref_fasta.fetch( - chrom, end + pad, end + pad + 1).upper() + s_base = ref_fasta.fetch(chrom, start - pad, start - pad + 1).upper() + e_base = ref_fasta.fetch(chrom, end + pad, end + pad + 1).upper() new_start = start i = start - pad - 1 while True: @@ -1161,22 +1489,27 @@ def extend_regions_hp(region_bed_file, extended_region_bed_file, ref_fasta_file, i += 1 if i >= chrom_lengths[chrom] - 3: break - if ref_fasta.fetch(chrom, new_end + pad, new_end + pad + 1 - ).upper() == ref_fasta.fetch(chrom, new_end - 1 + pad, - new_end + pad).upper(): + if ( + ref_fasta.fetch(chrom, new_end + pad, new_end + pad + 1).upper() + == ref_fasta.fetch(chrom, new_end - 1 + pad, new_end + pad).upper() + ): new_end += 1 - if ref_fasta.fetch(chrom, new_start - pad, new_start - pad + 1 - ).upper() == ref_fasta.fetch(chrom, new_start - pad + 1, - new_start - pad + 2).upper(): + if ( + ref_fasta.fetch(chrom, new_start - pad, new_start - pad + 1).upper() + == ref_fasta.fetch( + chrom, new_start - pad + 1, new_start - pad + 2 + ).upper() + ): new_start -= 1 - seq, new_seq = ref_fasta.fetch(chrom, start - pad, end + pad + 1).upper( - ), ref_fasta.fetch(chrom, new_start - pad, new_end + pad + 1).upper() + seq, new_seq = ( + ref_fasta.fetch(chrom, start - pad, end + pad + 1).upper(), + ref_fasta.fetch(chrom, new_start - pad, new_end + pad + 1).upper(), + ) intervals.append([chrom, new_start, new_end]) tmp_ = get_tmp_file() write_tsv_file(tmp_, intervals) - bedtools_sort(tmp_, output_fn=extended_region_bed_file, - run_logger=logger) + bedtools_sort(tmp_, output_fn=extended_region_bed_file, run_logger=logger) def check_rep(ref_seq, left_right, w): @@ -1184,16 +1517,17 @@ def check_rep(ref_seq, left_right, w): if len(ref_seq) < 2 * w: return False if left_right == "left": - return ref_seq[0:w] == ref_seq[w:2 * w] and len(set(ref_seq[0:2 * w])) > 1 + return ref_seq[0:w] == ref_seq[w : 2 * w] and len(set(ref_seq[0 : 2 * w])) > 1 elif left_right == "right": - return ref_seq[-w:] == ref_seq[-2 * w:-w] and len(set(ref_seq[-2 * w:])) > 1 + return ref_seq[-w:] == ref_seq[-2 * w : -w] and len(set(ref_seq[-2 * w :])) > 1 else: logger.error("Wrong left/right value: {}".format(left_right)) raise Exception -def extend_regions_repeat(region_bed_file, extended_region_bed_file, ref_fasta_file, - chrom_lengths, pad): +def extend_regions_repeat( + region_bed_file, extended_region_bed_file, ref_fasta_file, chrom_lengths, pad +): logger = logging.getLogger(extend_regions_repeat.__name__) with pysam.Fastafile(ref_fasta_file) as ref_fasta: intervals = [] @@ -1205,106 +1539,105 @@ def extend_regions_repeat(region_bed_file, extended_region_bed_file, ref_fasta_f w = 3 new_start = max(start - pad - w, 1) new_end = min(end + pad + w, chrom_lengths[chrom] - 2) - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_s = 0 while check_rep(ref_seq, "left", 2): new_start -= 2 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_s += 2 if cnt_s == 0: while check_rep(ref_seq, "left", 3): 
new_start -= 3 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_s += 3 if cnt_s == 0: while check_rep(ref_seq, "left", 4): new_start -= 4 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_s += 4 if cnt_s == 0: new_start += w - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() if cnt_s == 0: while check_rep(ref_seq, "left", 2): new_start -= 2 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_s += 2 if cnt_s == 0: while check_rep(ref_seq, "left", 3): new_start -= 3 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_s += 3 if cnt_s == 0: while check_rep(ref_seq, "left", 4): new_start -= 4 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_s += 4 cnt_e = 0 while check_rep(ref_seq, "right", 2): new_end += 2 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_e += 2 if cnt_e == 0: while check_rep(ref_seq, "right", 3): new_end += 3 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_e += 3 if cnt_e == 0: while check_rep(ref_seq, "right", 4): new_end += 4 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_e += 4 if cnt_e == 0: new_end -= w - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() if cnt_e == 0: while check_rep(ref_seq, "right", 2): new_end += 2 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_e += 2 if cnt_e == 0: while check_rep(ref_seq, "right", 3): new_end += 3 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_e += 3 if cnt_e == 0: while check_rep(ref_seq, "right", 4): new_end += 4 - ref_seq = ref_fasta.fetch( - chrom, new_start, new_end + 1).upper() + ref_seq = ref_fasta.fetch(chrom, new_start, new_end + 1).upper() cnt_e += 4 intervals.append([chrom, new_start + pad, new_end - pad]) tmp_ = get_tmp_file() write_tsv_file(tmp_, intervals, add_fields=[".", ".", "."]) - bedtools_sort(tmp_, output_fn=extended_region_bed_file, - run_logger=logger) - - -def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_realigned_bed, - region_bed_file, - ref_fasta_file, num_threads, pad, - chunk_size, chunk_scale, snp_min_af, del_min_af, ins_min_af, - match_score, mismatch_penalty, gap_open_penalty, gap_ext_penalty, - max_realign_dp, - do_split, - filter_duplicate, - msa_binary): + bedtools_sort(tmp_, output_fn=extended_region_bed_file, run_logger=logger) + + +def long_read_indelrealign( + work, + input_bam, + output_bam, + output_vcf, + output_not_realigned_bed, + region_bed_file, + ref_fasta_file, + num_threads, + pad, + chunk_size, + chunk_scale, + snp_min_af, + del_min_af, + ins_min_af, + match_score, + mismatch_penalty, 
+ gap_open_penalty, + gap_ext_penalty, + max_realign_dp, + do_split, + filter_duplicate, + msa_binary, +): logger = logging.getLogger(long_read_indelrealign.__name__) logger.info("-----------Resolve variants for INDELS (long-read)---------") @@ -1312,8 +1645,7 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_r os.mkdir(work) if not output_bam and not output_vcf: - logger.error( - "At least one of --output_bam or --output_vcf should be provided.") + logger.error("At least one of --output_bam or --output_vcf should be provided.") raise Exception chrom_lengths = {} @@ -1329,7 +1661,8 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_r extended_region_bed_file = os.path.join(work, "regions_extended.bed") extend_regions_repeat( - region_bed_file, extended_region_bed_file, ref_fasta_file, chrom_lengths, pad) + region_bed_file, extended_region_bed_file, ref_fasta_file, chrom_lengths, pad + ) region_bed_file = extended_region_bed_file region_bed_merged = region_bed_file @@ -1339,7 +1672,8 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_r len_merged += 1 while True: region_bed_merged_tmp = bedtools_merge( - region_bed_merged, args=" -d {}".format(pad * 2), run_logger=logger) + region_bed_merged, args=" -d {}".format(pad * 2), run_logger=logger + ) len_tmp = 0 with open(region_bed_merged_tmp) as r_b: for line in skip_empty(r_b): @@ -1348,23 +1682,39 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_r break region_bed_merged = region_bed_merged_tmp len_merged = len_tmp - shutil.copyfile(region_bed_merged, os.path.join( - work, "regions_merged.bed")) + shutil.copyfile(region_bed_merged, os.path.join(work, "regions_merged.bed")) target_regions = read_tsv_file(region_bed_merged, fields=range(3)) - target_regions = list( - map(lambda x: [x[0], int(x[1]), int(x[2])], target_regions)) + target_regions = list(map(lambda x: [x[0], int(x[1]), int(x[2])], target_regions)) get_var = True if output_vcf else False pool = multiprocessing.Pool(num_threads) map_args = [] for target_region in target_regions: - map_args.append((work, ref_fasta_file, target_region, pad, chunk_size, - chunk_scale, snp_min_af, del_min_af, ins_min_af, - chrom_lengths[target_region[0]], input_bam, - match_score, mismatch_penalty, gap_open_penalty, gap_ext_penalty, - max_realign_dp, filter_duplicate, - msa_binary, get_var, do_split)) + map_args.append( + ( + work, + ref_fasta_file, + target_region, + pad, + chunk_size, + chunk_scale, + snp_min_af, + del_min_af, + ins_min_af, + chrom_lengths[target_region[0]], + input_bam, + match_score, + mismatch_penalty, + gap_open_penalty, + gap_ext_penalty, + max_realign_dp, + filter_duplicate, + msa_binary, + get_var, + do_split, + ) + ) shuffle(map_args) try: @@ -1390,17 +1740,43 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_r if get_var: with open(output_vcf, "w") as o_f: - o_f.write("#" + "\t".join(["CHROM", "POS", "ID", "REF", - "ALT", "QUAL", "FILTER", "INFO", "FORMAT"]) + "\n") - realign_variants = sorted(realign_variants, key=lambda x: - [chroms_order[x[0]], x[1]]) + o_f.write( + "#" + + "\t".join( + [ + "CHROM", + "POS", + "ID", + "REF", + "ALT", + "QUAL", + "FILTER", + "INFO", + "FORMAT", + ] + ) + + "\n" + ) + realign_variants = sorted( + realign_variants, key=lambda x: [chroms_order[x[0]], x[1]] + ) for variant in realign_variants: if variant: chrom, pos, ref, alt, dp, ro, ao = variant - line = "\t".join([chrom, str(pos), ".", 
ref, alt, "100", ".", - "DP={};RO={};AO={}".format(dp, ro, ao), - "GT:DP:RO:AO", "0/1:{}:{}:{}".format( - dp, ro, ao), ]) + line = "\t".join( + [ + chrom, + str(pos), + ".", + ref, + alt, + "100", + ".", + "DP={};RO={};AO={}".format(dp, ro, ao), + "GT:DP:RO:AO", + "0/1:{}:{}:{}".format(dp, ro, ao), + ] + ) o_f.write(line + "\n") with open(output_not_realigned_bed, "w") as o_f: for x in not_realigned_regions: @@ -1421,8 +1797,9 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_r o_f.write("\t".join(map(str, x[0:4] + [".", "."] + x[4:])) + "\n") if output_bam: - parallel_correct_bam(work, input_bam, output_bam, ref_fasta_file, - realign_bed_file, num_threads) + parallel_correct_bam( + work, input_bam, output_bam, ref_fasta_file, realign_bed_file, num_threads + ) shutil.rmtree(bed_tempdir) tempfile.tempdir = original_tempdir @@ -1430,78 +1807,116 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_r logger.info("Done") -if __name__ == '__main__': +if __name__ == "__main__": - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) - parser = argparse.ArgumentParser(description='realign indels using MSA') - parser.add_argument('--input_bam', type=str, help='input bam') - parser.add_argument('--output_vcf', type=str, - help='output_vcf (needed for variant prediction)', default=None) - parser.add_argument('--output_not_realigned_bed', type=str, - help='output_not_realigned_bed', required=True) - parser.add_argument('--output_bam', type=str, - help='output_bam (needed for getting the realigned bam)', default=None) - parser.add_argument('--region_bed', type=str, - help='region_bed', required=True) - parser.add_argument('--reference', type=str, - help='reference fasta filename', required=True) - parser.add_argument('--work', type=str, - help='work directory', required=True) - parser.add_argument('--num_threads', type=int, - help='number of threads', default=1) - parser.add_argument('--pad', type=int, - help='#base padding to the regions', default=1) - parser.add_argument('--chunk_size', type=int, - help='chuck split size for high depth', default=600) - parser.add_argument('--chunk_scale', type=float, - help='chuck scale size for high depth', default=1.5) - parser.add_argument('--snp_min_af', type=float, - help='SNP min allele freq', default=0.05) - parser.add_argument('--ins_min_af', type=float, - help='INS min allele freq', default=0.05) - parser.add_argument('--del_min_af', type=float, - help='DEL min allele freq', default=0.05) - parser.add_argument('--match_score', type=int, - help='match score', default=10) - parser.add_argument('--mismatch_penalty', type=int, - help='penalty for having a mismatch', default=8) - parser.add_argument('--gap_open_penalty', type=int, - help='penalty for opening a gap', default=8) - parser.add_argument('--gap_ext_penalty', type=int, - help='penalty for extending a gap', default=6) - parser.add_argument('--max_realign_dp', type=int, - help='max coverage for realign region', default=1000) - parser.add_argument('--do_split', - help='Split bam for high coverage regions (in variant-calling mode).', - action="store_true") - parser.add_argument('--filter_duplicate', - help='filter duplicate reads in analysis', - action="store_true") - parser.add_argument('--msa_binary', type=str, - help='MSA binary', default="../bin/msa") + parser = 
argparse.ArgumentParser(description="realign indels using MSA") + parser.add_argument("--input_bam", type=str, help="input bam") + parser.add_argument( + "--output_vcf", + type=str, + help="output_vcf (needed for variant prediction)", + default=None, + ) + parser.add_argument( + "--output_not_realigned_bed", + type=str, + help="output_not_realigned_bed", + required=True, + ) + parser.add_argument( + "--output_bam", + type=str, + help="output_bam (needed for getting the realigned bam)", + default=None, + ) + parser.add_argument("--region_bed", type=str, help="region_bed", required=True) + parser.add_argument( + "--reference", type=str, help="reference fasta filename", required=True + ) + parser.add_argument("--work", type=str, help="work directory", required=True) + parser.add_argument("--num_threads", type=int, help="number of threads", default=1) + parser.add_argument( + "--pad", type=int, help="#base padding to the regions", default=1 + ) + parser.add_argument( + "--chunk_size", type=int, help="chuck split size for high depth", default=600 + ) + parser.add_argument( + "--chunk_scale", type=float, help="chuck scale size for high depth", default=1.5 + ) + parser.add_argument( + "--snp_min_af", type=float, help="SNP min allele freq", default=0.05 + ) + parser.add_argument( + "--ins_min_af", type=float, help="INS min allele freq", default=0.05 + ) + parser.add_argument( + "--del_min_af", type=float, help="DEL min allele freq", default=0.05 + ) + parser.add_argument("--match_score", type=int, help="match score", default=10) + parser.add_argument( + "--mismatch_penalty", type=int, help="penalty for having a mismatch", default=8 + ) + parser.add_argument( + "--gap_open_penalty", type=int, help="penalty for opening a gap", default=8 + ) + parser.add_argument( + "--gap_ext_penalty", type=int, help="penalty for extending a gap", default=6 + ) + parser.add_argument( + "--max_realign_dp", + type=int, + help="max coverage for realign region", + default=1000, + ) + parser.add_argument( + "--do_split", + help="Split bam for high coverage regions (in variant-calling mode).", + action="store_true", + ) + parser.add_argument( + "--filter_duplicate", + help="filter duplicate reads in analysis", + action="store_true", + ) + parser.add_argument( + "--msa_binary", type=str, help="MSA binary", default="../bin/msa" + ) args = parser.parse_args() logger.info(args) try: - processor = long_read_indelrealign(args.work, args.input_bam, args.output_bam, - args.output_vcf, args.output_not_realigned_bed, - args.region_bed, args.reference, - args.num_threads, args.pad, args.chunk_size, - args.chunk_scale, args.snp_min_af, args.del_min_af, - args.ins_min_af, args.match_score, - args.mismatch_penalty, args.gap_open_penalty, - args.gap_ext_penalty, - args.gap_ext_penalty, - args.max_realign_dp, - args.do_split, - args.filter_duplicate, - args.msa_binary) + processor = long_read_indelrealign( + args.work, + args.input_bam, + args.output_bam, + args.output_vcf, + args.output_not_realigned_bed, + args.region_bed, + args.reference, + args.num_threads, + args.pad, + args.chunk_size, + args.chunk_scale, + args.snp_min_af, + args.del_min_af, + args.ins_min_af, + args.match_score, + args.mismatch_penalty, + args.gap_open_penalty, + args.gap_ext_penalty, + args.gap_ext_penalty, + args.max_realign_dp, + args.do_split, + args.filter_duplicate, + args.msa_binary, + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "long_read_indelrealign.py failure on arguments: {}".format(args)) 
+ logger.error("long_read_indelrealign.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/merge_post_vcfs.py b/neusomatic/python/merge_post_vcfs.py index 111df49..ed245b7 100755 --- a/neusomatic/python/merge_post_vcfs.py +++ b/neusomatic/python/merge_post_vcfs.py @@ -1,8 +1,8 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # merge_post_vcfs.py # Merge resolved variants and other predicted variants and output the final NeuSomatic .vcf -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import argparse import traceback import fileinput @@ -14,8 +14,9 @@ from defaults import VCF_HEADER -def merge_post_vcfs(ref, resolved_vcf, no_resolve_vcf, out_vcf, - pass_threshold, lowqual_threshold): +def merge_post_vcfs( + ref, resolved_vcf, no_resolve_vcf, out_vcf, pass_threshold, lowqual_threshold +): logger = logging.getLogger(merge_post_vcfs.__name__) @@ -34,38 +35,48 @@ def merge_post_vcfs(ref, resolved_vcf, no_resolve_vcf, out_vcf, o_f.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") for record in sorted(good_records, key=lambda x: [chroms_order[x[0]], x[1]]): chrom, pos, ref, alt, gt, score = record - prob = np.round(1 - (10**(-float(score) / 10)), 4) + prob = np.round(1 - (10 ** (-float(score) / 10)), 4) filter_ = "REJECT" if prob >= pass_threshold: filter_ = "PASS" elif prob >= lowqual_threshold: filter_ = "LowQual" - o_f.write("\t".join([chrom, pos, ".", ref, alt, - "{:.4f}".format( - float(score)), filter_, "SCORE={:.4f}".format(prob), - "GT", "0/1"]) + "\n") + o_f.write( + "\t".join( + [ + chrom, + pos, + ".", + ref, + alt, + "{:.4f}".format(float(score)), + filter_, + "SCORE={:.4f}".format(prob), + "GT", + "0/1", + ] + ) + + "\n" + ) -if __name__ == '__main__': - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +if __name__ == "__main__": + + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) - parser = argparse.ArgumentParser(description='merge pred vcfs') - parser.add_argument( - '--ref', help='reference fasta filename', required=True) - parser.add_argument('--resolved_vcf', help='resolved_vcf', required=True) - parser.add_argument('--no_resolve_vcf', - help='no resolve vcf', required=True) - parser.add_argument('--out_vcf', help='output vcf', required=True) + parser = argparse.ArgumentParser(description="merge pred vcfs") + parser.add_argument("--ref", help="reference fasta filename", required=True) + parser.add_argument("--resolved_vcf", help="resolved_vcf", required=True) + parser.add_argument("--no_resolve_vcf", help="no resolve vcf", required=True) + parser.add_argument("--out_vcf", help="output vcf", required=True) args = parser.parse_args() logger.info(args) try: - merge_post_vcfs(args.ref, args.resolved_vcf, - args.no_resolve_vcf, args.out_vcf) + merge_post_vcfs(args.ref, args.resolved_vcf, args.no_resolve_vcf, args.out_vcf) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "merge_post_vcfs.py failure on arguments: {}".format(args)) + logger.error("merge_post_vcfs.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/merge_tsvs.py b/neusomatic/python/merge_tsvs.py index 96ca72c..cf7f4e5 100755 --- 
a/neusomatic/python/merge_tsvs.py +++ b/neusomatic/python/merge_tsvs.py @@ -1,8 +1,8 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # merge_tsvs.py # merge_tsvs generated by 'generate_dataset.py' to from larger size tsvs -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import os import shutil @@ -14,9 +14,15 @@ import numpy as np -def merge_tsvs(input_tsvs, out, - candidates_per_tsv, max_num_tsvs, overwrite_merged_tsvs, - keep_none_types, max_dp=1000000): +def merge_tsvs( + input_tsvs, + out, + candidates_per_tsv, + max_num_tsvs, + overwrite_merged_tsvs, + keep_none_types, + max_dp=1000000, +): logger = logging.getLogger(merge_tsvs.__name__) logger.info("----------------Merging Candidate tsvs-------------------") if not os.path.exists(out): @@ -28,19 +34,18 @@ def merge_tsvs(input_tsvs, out, else: i = 1 while os.path.exists(out_mreged_folder): - out_mreged_folder = os.path.join( - out, "merged_tsvs_{}".format(i)) + out_mreged_folder = os.path.join(out, "merged_tsvs_{}".format(i)) i += 1 os.mkdir(out_mreged_folder) n_var_file = 0 - var_file = os.path.join( - out_mreged_folder, "merged_var_{}.tsv".format(n_var_file)) + var_file = os.path.join(out_mreged_folder, "merged_var_{}.tsv".format(n_var_file)) var_f = open(var_file, "w") var_idx = [] n_none_file = 0 if not keep_none_types: none_file = os.path.join( - out_mreged_folder, "merged_none_{}.tsv".format(n_none_file)) + out_mreged_folder, "merged_none_{}.tsv".format(n_none_file) + ) none_f = open(none_file, "w") none_idx = [] merged_tsvs = [] @@ -54,8 +59,9 @@ def merge_tsvs(input_tsvs, out, for line in i_f: totla_L += 1 totla_L = max(0, totla_L) - candidates_per_tsv = max(candidates_per_tsv, np.ceil( - totla_L / float(max_num_tsvs)) + 1) + candidates_per_tsv = max( + candidates_per_tsv, np.ceil(totla_L / float(max_num_tsvs)) + 1 + ) for tsv in input_tsvs: logger.info("tsv:{}, merge_id: {}".format(tsv, len(merged_tsvs))) @@ -79,12 +85,12 @@ def merge_tsvs(input_tsvs, out, none_idx.append(none_f.tell()) pickle.dump(none_idx, open(none_file + ".idx", "wb")) none_f.close() - logger.info( - "Done with merge_id: {}".format(len(merged_tsvs))) + logger.info("Done with merge_id: {}".format(len(merged_tsvs))) merged_tsvs.append(none_file) n_none_file += 1 none_file = os.path.join( - out_mreged_folder, "merged_none_{}.tsv".format(n_none_file)) + out_mreged_folder, "merged_none_{}.tsv".format(n_none_file) + ) none_f = open(none_file, "w") none_idx = [] else: @@ -96,12 +102,12 @@ def merge_tsvs(input_tsvs, out, var_idx.append(var_f.tell()) pickle.dump(var_idx, open(var_file + ".idx", "wb")) var_f.close() - logger.info( - "Done with merge_id: {}".format(len(merged_tsvs))) + logger.info("Done with merge_id: {}".format(len(merged_tsvs))) merged_tsvs.append(var_file) n_var_file += 1 var_file = os.path.join( - out_mreged_folder, "merged_var_{}.tsv".format(n_var_file)) + out_mreged_folder, "merged_var_{}.tsv".format(n_var_file) + ) var_f = open(var_file, "w") var_idx = [] if not var_f.closed: @@ -120,39 +126,55 @@ def merge_tsvs(input_tsvs, out, logger.info("Merged input tsvs to: {}".format(merged_tsvs)) return merged_tsvs -if __name__ == '__main__': - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +if __name__ == "__main__": + + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" 
logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) - parser = argparse.ArgumentParser(description='Resolve ambigues variants') - parser.add_argument('--input_tsvs', nargs="*", - help=' input candidate tsv files', required=True) - parser.add_argument('--out', type=str, - help='output directory', required=True) - parser.add_argument('--candidates_per_tsv', type=int, - help='Maximum number of candidates in each merged tsv file ', - default=10000000) - parser.add_argument('--max_num_tsvs', type=int, - help='Maximum number of merged tsv files \ - (higher priority than candidates_per_tsv)', default=10) - parser.add_argument('--overwrite_merged_tsvs', - help='if OUT/merged_tsvs/ folder exists overwrite the merged tsvs', - action="store_true") - parser.add_argument('--keep_none_types', action="store_true", - help='Do not split none somatic candidates to seperate files') + parser = argparse.ArgumentParser(description="Merge candidate tsvs") + parser.add_argument( + "--input_tsvs", nargs="*", help="input candidate tsv files", required=True + ) + parser.add_argument("--out", type=str, help="output directory", required=True) + parser.add_argument( + "--candidates_per_tsv", + type=int, + help="Maximum number of candidates in each merged tsv file", + default=10000000, + ) + parser.add_argument( + "--max_num_tsvs", + type=int, + help="Maximum number of merged tsv files \ + (higher priority than candidates_per_tsv)", + default=10, + ) + parser.add_argument( + "--overwrite_merged_tsvs", + help="if OUT/merged_tsvs/ folder exists overwrite the merged tsvs", + action="store_true", + ) + parser.add_argument( + "--keep_none_types", + action="store_true", + help="Do not split none somatic candidates to separate files", + ) args = parser.parse_args() logger.info(args) try: - merged_tsvs = merge_tsvs(args.input_tsvs, args.out, - args.candidates_per_tsv, args.max_num_tsvs, - args.overwrite_merged_tsvs, - args.keep_none_types) + merged_tsvs = merge_tsvs( + args.input_tsvs, + args.out, + args.candidates_per_tsv, + args.max_num_tsvs, + args.overwrite_merged_tsvs, + args.keep_none_types, + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "merge_tsvs.py failure on arguments: {}".format(args)) + logger.error("merge_tsvs.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/network.py b/neusomatic/python/network.py index da685b7..6a145a1 100755 --- a/neusomatic/python/network.py +++ b/neusomatic/python/network.py @@ -1,7 +1,7 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # network.py # Defines the architecture of NeuSomatic network -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import logging import torch.nn as nn @@ -9,40 +9,42 @@ import numpy as np -FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) class NSBlock(nn.Module): - def __init__(self, dim, ks_1=3, ks_2=3, dl_1=1, dl_2=1, mp_ks=3, mp_st=1): super(NSBlock, self).__init__() self.dim = dim self.conv_r1 = nn.Conv2d( - dim, dim, kernel_size=ks_1, dilation=dl_1, padding=(dl_1 * (ks_1 - 1)) // 2) + dim, dim, kernel_size=ks_1, dilation=dl_1, padding=(dl_1 *
(ks_1 - 1)) // 2 + ) self.bn_r1 = nn.BatchNorm2d(dim) self.conv_r2 = nn.Conv2d( - dim, dim, kernel_size=ks_2, dilation=dl_2, padding=(dl_2 * (ks_2 - 1)) // 2) + dim, dim, kernel_size=ks_2, dilation=dl_2, padding=(dl_2 * (ks_2 - 1)) // 2 + ) self.bn_r2 = nn.BatchNorm2d(dim) - self.pool_r2 = nn.MaxPool2d((1, mp_ks), padding=( - 0, (mp_ks - 1) // 2), stride=(1, mp_st)) + self.pool_r2 = nn.MaxPool2d( + (1, mp_ks), padding=(0, (mp_ks - 1) // 2), stride=(1, mp_st) + ) def forward(self, x): - y1 = (F.relu(self.bn_r1(self.conv_r1(x)))) - y2 = (self.bn_r2(self.conv_r2(y1))) + y1 = F.relu(self.bn_r1(self.conv_r1(x))) + y2 = self.bn_r2(self.conv_r2(y1)) y3 = x + y2 z = self.pool_r2(y3) return z class NeuSomaticNet(nn.Module): - def __init__(self, num_channels): super(NeuSomaticNet, self).__init__() dim = 64 - self.conv1 = nn.Conv2d(num_channels, dim, kernel_size=( - 1, 3), padding=(0, 1), stride=1) + self.conv1 = nn.Conv2d( + num_channels, dim, kernel_size=(1, 3), padding=(0, 1), stride=1 + ) self.bn1 = nn.BatchNorm2d(dim) self.pool1 = nn.MaxPool2d((1, 3), padding=(0, 1), stride=(1, 1)) self.nsblocks = [ diff --git a/neusomatic/python/postprocess.py b/neusomatic/python/postprocess.py index f5ed9a7..29938d7 100755 --- a/neusomatic/python/postprocess.py +++ b/neusomatic/python/postprocess.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # postprocess.py # A wrapper that # 1- Extract variants that need postprocessing (call to 'extract_postprocess_targets.py') @@ -6,7 +6,7 @@ # 'resolve_variants.py' or 'long_read_indelrealign.py'/'resolve_scores.py') # 3- Merge resolved variants and other predicted variants and # output the final NeuSomatic .vcf (call to 'merge_post_vcfs.py') -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import argparse import os @@ -21,15 +21,29 @@ from extract_postprocess_targets import extract_postprocess_targets from merge_post_vcfs import merge_post_vcfs from resolve_variants import resolve_variants -from utils import concatenate_files, get_chromosomes_order, bedtools_window, bedtools_intersect, skip_empty +from utils import ( + concatenate_files, + get_chromosomes_order, + bedtools_window, + bedtools_intersect, + skip_empty, +) from long_read_indelrealign import long_read_indelrealign from resolve_scores import resolve_scores from _version import __version__ from defaults import VCF_HEADER -def add_vcf_info(work, reference, merged_vcf, candidates_vcf, ensemble_tsv, - output_vcf, pass_threshold, lowqual_threshold): +def add_vcf_info( + work, + reference, + merged_vcf, + candidates_vcf, + ensemble_tsv, + output_vcf, + pass_threshold, + lowqual_threshold, +): logger = logging.getLogger(add_vcf_info.__name__) ensemble_candids_vcf = None @@ -38,8 +52,7 @@ def add_vcf_info(work, reference, merged_vcf, candidates_vcf, ensemble_tsv, ensemble_candids_vcf = os.path.join(work, "ensemble_candids.vcf") with open(ensemble_tsv) as e_f, open(ensemble_candids_vcf, "w") as c_f: c_f.write("{}\n".format(VCF_HEADER)) - c_f.write( - "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") + c_f.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") ensemble_header_found = False for line in e_f: if "POS" in line: @@ -57,7 +70,13 @@ def add_vcf_info(work, reference, merged_vcf, candidates_vcf, ensemble_tsv, alt_rv_id = header.index("T_ALT_REV") 
use_ensemble_candids = True else: - dp_id, ref_fw_id, ref_rv_id, alt_fw_id, alt_rv_id = None, None, None, None, None + dp_id, ref_fw_id, ref_rv_id, alt_fw_id, alt_rv_id = ( + None, + None, + None, + None, + None, + ) continue assert ensemble_header_found fields = line.strip().split() @@ -75,21 +94,49 @@ def add_vcf_info(work, reference, merged_vcf, candidates_vcf, ensemble_tsv, ao = ao_fw + ao_rv af = np.round(ao / float(ao + ro + 0.0001), 4) c_f.write( - "\t".join(map(str, [chrom, pos, ".", ref, alt, ".", ".", ".", "GT:DP:RO:AO:AF", ":".join(map(str, ["0/1", dp, ro, ao, af]))])) + "\n") + "\t".join( + map( + str, + [ + chrom, + pos, + ".", + ref, + alt, + ".", + ".", + ".", + "GT:DP:RO:AO:AF", + ":".join(map(str, ["0/1", dp, ro, ao, af])), + ], + ) + ) + + "\n" + ) else: c_f.write( - "\t".join(map(str, [chrom, pos, ".", ref, alt, ".", ".", ".", ".", "."])) + "\n") - + "\t".join( + map( + str, + [chrom, pos, ".", ref, alt, ".", ".", ".", ".", "."], + ) + ) + + "\n" + ) in_candidates = bedtools_window( - merged_vcf, candidates_vcf, args=" -w 5", run_logger=logger) + merged_vcf, candidates_vcf, args=" -w 5", run_logger=logger + ) notin_candidates = bedtools_window( - merged_vcf, candidates_vcf, args=" -w 5 -v", run_logger=logger) + merged_vcf, candidates_vcf, args=" -w 5 -v", run_logger=logger + ) if ensemble_tsv and use_ensemble_candids: in_ensemble = bedtools_window( - merged_vcf, ensemble_candids_vcf, args=" -w 5", run_logger=logger) + merged_vcf, ensemble_candids_vcf, args=" -w 5", run_logger=logger + ) notin_any = bedtools_window( - notin_candidates, ensemble_candids_vcf, args=" -w 5 -v", run_logger=logger) + notin_candidates, ensemble_candids_vcf, args=" -w 5 -v", run_logger=logger + ) else: in_ensemble = None notin_any = notin_candidates @@ -114,14 +161,15 @@ def add_vcf_info(work, reference, merged_vcf, candidates_vcf, ensemble_tsv, af = float(info[4]) is_same = x[1] == x[11] and x[3] == x[13] and x[4] == x[14] is_same = 0 if is_same else 1 - is_same_type = np.sign( - len(x[3]) - len(x[13])) == np.sign(len(x[4]) - len(x[14])) + is_same_type = np.sign(len(x[3]) - len(x[13])) == np.sign( + len(x[4]) - len(x[14]) + ) is_same_type = 0 if is_same_type else 1 dist = abs(int(x[1]) - int(x[11])) - len_diff = abs( - (len(x[3]) - len(x[13])) - (len(x[4]) - len(x[14]))) + len_diff = abs((len(x[3]) - len(x[13])) - (len(x[4]) - len(x[14]))) tags_info[tag].append( - [is_same, is_same_type, dist, len_diff, s_e, dp, ro, ao, af]) + [is_same, is_same_type, dist, len_diff, s_e, dp, ro, ao, af] + ) fina_info_tag = {} for tag, hits in tags_info.items(): hits = sorted(hits, key=lambda x: x[0:5]) @@ -134,59 +182,112 @@ def add_vcf_info(work, reference, merged_vcf, candidates_vcf, ensemble_tsv, fina_info_tag[tag] = [0, 0, 0, 0] scores[tag] = [x[5], x[6], x[7], x[9]] - tags = sorted(fina_info_tag.keys(), key=lambda x: list(map(int, x.split("-")[0:2] - ))) + tags = sorted(fina_info_tag.keys(), key=lambda x: list(map(int, x.split("-")[0:2]))) with open(output_vcf, "w") as o_f: o_f.write("{}\n".format(VCF_HEADER)) o_f.write("##NeuSomatic Version={}\n".format(__version__)) o_f.write( - "##INFO=\n") + '##INFO=\n' + ) + o_f.write( + '##INFO=\n' + ) o_f.write( - "##INFO=\n") + '##INFO=\n' + ) o_f.write( - "##INFO=\n") + '##INFO=\n' + ) o_f.write( - "##INFO=\n") + '##INFO=\n' + ) o_f.write( - "##INFO=\n") - o_f.write("##FILTER=\n".format( - pass_threshold)) - o_f.write("##FILTER=\n".format( - lowqual_threshold)) - o_f.write("##FILTER=\n".format( - lowqual_threshold)) + '##FILTER=\n'.format( + pass_threshold + ) + ) 
o_f.write( - "##FORMAT=\n") + '##FILTER=\n'.format( + lowqual_threshold + ) + ) o_f.write( - "##FORMAT=\n") + '##FILTER=\n'.format( + lowqual_threshold + ) + ) + o_f.write('##FORMAT=\n') o_f.write( - "##FORMAT=\n") + '##FORMAT=\n' + ) o_f.write( - "##FORMAT=\n") + '##FORMAT=\n' + ) o_f.write( - "##FORMAT=\n") + '##FORMAT=\n' + ) + o_f.write( + '##FORMAT=\n' + ) o_f.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") for tag in tags: chrom_id, pos, ref, alt = tag.split("-") qual, filter_, score, gt = scores[tag] dp, ro, ao, af = fina_info_tag[tag] - info_field = "{};DP={};RO={};AO={};AF={}".format( - score, dp, ro, ao, af) + info_field = "{};DP={};RO={};AO={};AF={}".format(score, dp, ro, ao, af) gt_field = "{}:{}:{}:{}:{}".format(gt, dp, ro, ao, af) - o_f.write("\t".join(map(str, [chroms[int(chrom_id)], str( - pos), ".", ref, alt, qual, filter_, info_field, "GT:DP:RO:AO:AF", gt_field])) + "\n") - - -def postprocess(work, reference, pred_vcf_file, output_vcf, candidates_vcf, ensemble_tsv, - tumor_bam, min_len, - postprocess_max_dist, long_read, - lr_pad, lr_chunk_size, lr_chunk_scale, - lr_snp_min_af, lr_ins_min_af, lr_del_min_af, lr_match_score, lr_mismatch_penalty, - lr_gap_open_penalty, lr_gap_ext_penalty, lr_max_realign_dp, lr_do_split, - keep_duplicate, - pass_threshold, lowqual_threshold, - extend_repeats, - msa_binary, num_threads): + o_f.write( + "\t".join( + map( + str, + [ + chroms[int(chrom_id)], + str(pos), + ".", + ref, + alt, + qual, + filter_, + info_field, + "GT:DP:RO:AO:AF", + gt_field, + ], + ) + ) + + "\n" + ) + + +def postprocess( + work, + reference, + pred_vcf_file, + output_vcf, + candidates_vcf, + ensemble_tsv, + tumor_bam, + min_len, + postprocess_max_dist, + long_read, + lr_pad, + lr_chunk_size, + lr_chunk_scale, + lr_snp_min_af, + lr_ins_min_af, + lr_del_min_af, + lr_match_score, + lr_mismatch_penalty, + lr_gap_open_penalty, + lr_gap_ext_penalty, + lr_max_realign_dp, + lr_do_split, + keep_duplicate, + pass_threshold, + lowqual_threshold, + extend_repeats, + msa_binary, + num_threads, +): logger = logging.getLogger(postprocess.__name__) logger.info("----------------------Postprocessing-----------------------") @@ -194,7 +295,7 @@ def postprocess(work, reference, pred_vcf_file, output_vcf, candidates_vcf, ense os.mkdir(work) filter_duplicate = not keep_duplicate - + original_tempdir = tempfile.tempdir bed_tempdir = os.path.join(work, "bed_tempdir_postprocess") if not os.path.exists(bed_tempdir): @@ -205,14 +306,30 @@ def postprocess(work, reference, pred_vcf_file, output_vcf, candidates_vcf, ense ensembled_preds = os.path.join(work, "ensemble_preds.vcf") bedtools_window( - pred_vcf_file, candidates_vcf, args=" -w 5 -v", output_fn=ensembled_preds, run_logger=logger) + pred_vcf_file, + candidates_vcf, + args=" -w 5 -v", + output_fn=ensembled_preds, + run_logger=logger, + ) bedtools_window( - pred_vcf_file, candidates_vcf, args=" -w 5 -u", output_fn=candidates_preds, run_logger=logger) + pred_vcf_file, + candidates_vcf, + args=" -w 5 -u", + output_fn=candidates_preds, + run_logger=logger, + ) logger.info("Extract targets") postprocess_pad = 1 if not long_read else 10 extract_postprocess_targets( - reference, candidates_preds, min_len, postprocess_max_dist, extend_repeats, postprocess_pad) + reference, + candidates_preds, + min_len, + postprocess_max_dist, + extend_repeats, + postprocess_pad, + ) no_resolve = os.path.join(work, "candidates_preds.no_resolve.vcf") target_vcf = os.path.join(work, "candidates_preds.resolve_target.vcf") @@ -221,47 +338,85 
@@ def postprocess(work, reference, pred_vcf_file, output_vcf, candidates_vcf, ense
     logger.info("Resolve targets")
     if not long_read:
-        resolve_variants(tumor_bam, resolved_vcf,
-                         reference, target_vcf, target_bed, filter_duplicate,
-                         num_threads)
+        resolve_variants(
+            tumor_bam,
+            resolved_vcf,
+            reference,
+            target_vcf,
+            target_bed,
+            filter_duplicate,
+            num_threads,
+        )
         all_no_resolve = concatenate_files(
-            [no_resolve, ensembled_preds], os.path.join(work, "no_resolve.vcf"))
+            [no_resolve, ensembled_preds], os.path.join(work, "no_resolve.vcf")
+        )
     else:
         work_lr_indel_realign = os.path.join(work, "work_lr_indel_realign")
         if os.path.exists(work_lr_indel_realign):
             shutil.rmtree(work_lr_indel_realign)
         os.mkdir(work_lr_indel_realign)
-        ra_resolved_vcf = os.path.join(
-            work, "candidates_preds.ra_resolved.vcf")
-        not_resolved_bed = os.path.join(
-            work, "candidates_preds.not_ra_resolved.bed")
-        long_read_indelrealign(work_lr_indel_realign, tumor_bam, None, ra_resolved_vcf,
-                               not_resolved_bed, target_bed,
-                               reference, num_threads, lr_pad,
-                               lr_chunk_size, lr_chunk_scale, lr_snp_min_af,
-                               lr_del_min_af, lr_ins_min_af,
-                               lr_match_score, lr_mismatch_penalty, lr_gap_open_penalty,
-                               lr_gap_ext_penalty, lr_max_realign_dp, lr_do_split,
-                               filter_duplicate,
-                               msa_binary)
+        ra_resolved_vcf = os.path.join(work, "candidates_preds.ra_resolved.vcf")
+        not_resolved_bed = os.path.join(work, "candidates_preds.not_ra_resolved.bed")
+        long_read_indelrealign(
+            work_lr_indel_realign,
+            tumor_bam,
+            None,
+            ra_resolved_vcf,
+            not_resolved_bed,
+            target_bed,
+            reference,
+            num_threads,
+            lr_pad,
+            lr_chunk_size,
+            lr_chunk_scale,
+            lr_snp_min_af,
+            lr_del_min_af,
+            lr_ins_min_af,
+            lr_match_score,
+            lr_mismatch_penalty,
+            lr_gap_open_penalty,
+            lr_gap_ext_penalty,
+            lr_max_realign_dp,
+            lr_do_split,
+            filter_duplicate,
+            msa_binary,
+        )
         resolve_scores(tumor_bam, ra_resolved_vcf, target_vcf, resolved_vcf)
-        not_resolved_vcf = os.path.join(
-            work, "candidates_preds.not_ra_resolved.vcf")
-        bedtools_intersect(target_vcf, not_resolved_bed, args=" -u ",
-                           output_fn=not_resolved_vcf, run_logger=logger)
+        not_resolved_vcf = os.path.join(work, "candidates_preds.not_ra_resolved.vcf")
+        bedtools_intersect(
+            target_vcf,
+            not_resolved_bed,
+            args=" -u ",
+            output_fn=not_resolved_vcf,
+            run_logger=logger,
+        )
         all_no_resolve = concatenate_files(
-            [no_resolve, ensembled_preds, not_resolved_vcf], os.path.join(work, "no_resolve.vcf"))
+            [no_resolve, ensembled_preds, not_resolved_vcf],
+            os.path.join(work, "no_resolve.vcf"),
+        )
     logger.info("Merge vcfs")
     merged_vcf = os.path.join(work, "merged_preds.vcf")
-    merge_post_vcfs(reference, resolved_vcf,
-                    all_no_resolve, merged_vcf,
-                    pass_threshold, lowqual_threshold)
-    add_vcf_info(work, reference, merged_vcf,
-                 candidates_vcf, ensemble_tsv, output_vcf,
-                 pass_threshold, lowqual_threshold)
+    merge_post_vcfs(
+        reference,
+        resolved_vcf,
+        all_no_resolve,
+        merged_vcf,
+        pass_threshold,
+        lowqual_threshold,
+    )
+    add_vcf_info(
+        work,
+        reference,
+        merged_vcf,
+        candidates_vcf,
+        ensemble_tsv,
+        output_vcf,
+        pass_threshold,
+        lowqual_threshold,
+    )
     logger.info("Output NeuSomatic prediction at {}".format(output_vcf))
@@ -271,96 +426,180 @@ def postprocess(work, reference, pred_vcf_file, output_vcf, candidates_vcf, ense
     logger.info("Postprocessing is Done.")
     return output_vcf
-if __name__ == '__main__':
-    FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s'
+if __name__ == "__main__":
+
+    FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s"
     logging.basicConfig(level=logging.INFO, format=FORMAT)
     logger = logging.getLogger(__name__)
-    parser = argparse.ArgumentParser(
-        description='Preprocess predictions for call mode')
-    parser.add_argument('--reference', type=str,
-                        help='reference fasta filename', required=True)
-    parser.add_argument('--tumor_bam', type=str,
-                        help='tumor bam', required=True)
-    parser.add_argument('--pred_vcf', type=str,
-                        help='predicted vcf', required=True)
-    parser.add_argument('--output_vcf', type=str,
-                        help='output final vcf', required=True)
-    parser.add_argument('--candidates_vcf', type=str,
-                        help='filtered candidate vcf', required=True)
-    parser.add_argument('--ensemble_tsv', type=str,
-                        help='Ensemble annotation tsv file (only for short read)', default=None)
-    parser.add_argument('--min_len', type=int,
-                        help='minimum INDEL len to resolve', default=4)
-    parser.add_argument('--postprocess_max_dist', type=int,
-                        help='max distance to neighboring variant', default=5)
+    parser = argparse.ArgumentParser(description="Preprocess predictions for call mode")
+    parser.add_argument(
+        "--reference", type=str, help="reference fasta filename", required=True
+    )
+    parser.add_argument("--tumor_bam", type=str, help="tumor bam", required=True)
+    parser.add_argument("--pred_vcf", type=str, help="predicted vcf", required=True)
+    parser.add_argument(
+        "--output_vcf", type=str, help="output final vcf", required=True
+    )
+    parser.add_argument(
+        "--candidates_vcf", type=str, help="filtered candidate vcf", required=True
+    )
+    parser.add_argument(
+        "--ensemble_tsv",
+        type=str,
+        help="Ensemble annotation tsv file (only for short read)",
+        default=None,
+    )
+    parser.add_argument(
+        "--min_len", type=int, help="minimum INDEL len to resolve", default=4
+    )
+    parser.add_argument(
+        "--postprocess_max_dist",
+        type=int,
+        help="max distance to neighboring variant",
+        default=5,
+    )
+    parser.add_argument(
+        "--long_read",
+        help="Enable long_read (high error-rate sequence) indel realignment",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--lr_pad",
+        type=int,
+        help="long_read indel realign: #base padding to the regions",
+        default=1,
+    )
+    parser.add_argument(
+        "--lr_chunk_size",
+        type=int,
+        help="long_read indel realign: chuck split size for high depth",
+        default=600,
+    )
+    parser.add_argument(
+        "--lr_chunk_scale",
+        type=float,
+        help="long_read indel realign: chuck scale size for high depth",
+        default=1.5,
+    )
+    parser.add_argument(
+        "--lr_snp_min_af",
+        type=float,
+        help="long_read indel realign: SNP min allele freq",
+        default=0.05,
+    )
+    parser.add_argument(
+        "--lr_ins_min_af",
+        type=float,
+        help="long_read indel realign: INS min allele freq",
+        default=0.05,
+    )
+    parser.add_argument(
+        "--lr_del_min_af",
+        type=float,
+        help="long_read indel realign: DEL min allele freq",
+        default=0.05,
+    )
+    parser.add_argument(
+        "--lr_match_score",
+        type=int,
+        help="long_read indel realign: match score",
+        default=10,
+    )
+    parser.add_argument(
+        "--lr_mismatch_penalty",
+        type=int,
+        help="long_read indel realign: penalty for having a mismatch",
+        default=8,
+    )
+    parser.add_argument(
+        "--lr_gap_open_penalty",
+        type=int,
+        help="long_read indel realign: penalty for opening a gap",
+        default=8,
+    )
+    parser.add_argument(
+        "--lr_gap_ext_penalty",
+        type=int,
+        help="long_read indel realign: penalty for extending a gap",
+        default=6,
+    )
+    parser.add_argument(
+        "--lr_max_realign_dp",
+        type=int,
+        help="long read max coverage for realign region",
+        default=1000,
+    )
+    parser.add_argument(
+
"--lr_do_split", + help="long read split bam for high coverage regions (in variant-calling mode).", + action="store_true", + ) + parser.add_argument( + "--pass_threshold", + type=float, + help="SCORE for PASS (PASS for score => pass_threshold)", + default=0.7, + ) + parser.add_argument( + "--lowqual_threshold", + type=float, + help="SCORE for LowQual (PASS for lowqual_threshold <= score < pass_threshold)", + default=0.4, + ) + parser.add_argument( + "--keep_duplicate", + help="Dont filter duplicate reads in analysis", + action="store_true", + ) parser.add_argument( - '--long_read', help='Enable long_read (high error-rate sequence) indel realignment', action="store_true") + "--extend_repeats", + help="extend resolve regions to repeat boundaries", + action="store_true", + ) parser.add_argument( - '--lr_pad', type=int, help='long_read indel realign: #base padding to the regions', default=1) - parser.add_argument('--lr_chunk_size', type=int, - help='long_read indel realign: chuck split size for high depth', default=600) - parser.add_argument('--lr_chunk_scale', type=float, - help='long_read indel realign: chuck scale size for high depth', default=1.5) - parser.add_argument('--lr_snp_min_af', type=float, - help='long_read indel realign: SNP min allele freq', default=0.05) - parser.add_argument('--lr_ins_min_af', type=float, - help='long_read indel realign: INS min allele freq', default=0.05) - parser.add_argument('--lr_del_min_af', type=float, - help='long_read indel realign: DEL min allele freq', default=0.05) - parser.add_argument('--lr_match_score', type=int, - help='long_read indel realign: match score', default=10) - parser.add_argument('--lr_mismatch_penalty', type=int, - help='long_read indel realign: penalty for having a mismatch', default=8) - parser.add_argument('--lr_gap_open_penalty', type=int, - help='long_read indel realign: penalty for opening a gap', default=8) - parser.add_argument('--lr_gap_ext_penalty', type=int, - help='long_read indel realign: penalty for extending a gap', default=6) - parser.add_argument('--lr_max_realign_dp', type=int, - help='long read max coverage for realign region', default=1000) - parser.add_argument('--lr_do_split', - help='long read split bam for high coverage regions (in variant-calling mode).', - action="store_true") - parser.add_argument('--pass_threshold', type=float, - help='SCORE for PASS (PASS for score => pass_threshold)', default=0.7) - parser.add_argument('--lowqual_threshold', type=float, - help='SCORE for LowQual (PASS for lowqual_threshold <= score < pass_threshold)', - default=0.4) - parser.add_argument('--keep_duplicate', - help='Dont filter duplicate reads in analysis', - action="store_true") - parser.add_argument('--extend_repeats', - help='extend resolve regions to repeat boundaries', - action='store_true') - parser.add_argument('--msa_binary', type=str, - help='MSA binary', default="../bin/msa") - parser.add_argument('--num_threads', type=int, - help='number of threads', default=1) - parser.add_argument('--work', type=str, - help='work directory', required=True) + "--msa_binary", type=str, help="MSA binary", default="../bin/msa" + ) + parser.add_argument("--num_threads", type=int, help="number of threads", default=1) + parser.add_argument("--work", type=str, help="work directory", required=True) args = parser.parse_args() logger.info(args) try: - output_vcf = postprocess(args.work, args.reference, args.pred_vcf, args.output_vcf, - args.candidates_vcf, args.ensemble_tsv, - args.tumor_bam, args.min_len, - args.postprocess_max_dist, 
args.long_read, - args.lr_pad, args.lr_chunk_size, args.lr_chunk_scale, - args.lr_snp_min_af, args.lr_ins_min_af, args.lr_del_min_af, - args.lr_match_score, args.lr_mismatch_penalty, - args.lr_gap_open_penalty, - args.lr_gap_ext_penalty, args.lr_max_realign_dp, - args.lr_do_split, - args.keep_duplicate, - args.pass_threshold, args.lowqual_threshold, - args.extend_repeats, - args.msa_binary, args.num_threads) + output_vcf = postprocess( + args.work, + args.reference, + args.pred_vcf, + args.output_vcf, + args.candidates_vcf, + args.ensemble_tsv, + args.tumor_bam, + args.min_len, + args.postprocess_max_dist, + args.long_read, + args.lr_pad, + args.lr_chunk_size, + args.lr_chunk_scale, + args.lr_snp_min_af, + args.lr_ins_min_af, + args.lr_del_min_af, + args.lr_match_score, + args.lr_mismatch_penalty, + args.lr_gap_open_penalty, + args.lr_gap_ext_penalty, + args.lr_max_realign_dp, + args.lr_do_split, + args.keep_duplicate, + args.pass_threshold, + args.lowqual_threshold, + args.extend_repeats, + args.msa_binary, + args.num_threads, + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "postprocess.py failure on arguments: {}".format(args)) + logger.error("postprocess.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/preprocess.py b/neusomatic/python/preprocess.py index bc5aa9a..b6f107c 100755 --- a/neusomatic/python/preprocess.py +++ b/neusomatic/python/preprocess.py @@ -1,11 +1,11 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # preprocess.py # A wrapper that # 1- scans the tumor and normal alignments to extract features and raw candidates (call to 'scan_alignments.py') # 2- filters candidates to met cut-offs (call to 'filter_candidates.py') # 3- generate datasets for candidates to be used by network (call to 'generate_dataset.py') -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import multiprocessing import argparse @@ -21,46 +21,115 @@ from generate_dataset import generate_dataset, extract_ensemble from scan_alignments import scan_alignments from extend_features import extend_features -from utils import concatenate_vcfs, run_bedtools_cmd, bedtools_sort, bedtools_merge, bedtools_intersect, bedtools_slop, get_tmp_file, skip_empty, vcf_2_bed +from utils import ( + concatenate_vcfs, + run_bedtools_cmd, + bedtools_sort, + bedtools_merge, + bedtools_intersect, + bedtools_slop, + get_tmp_file, + skip_empty, + vcf_2_bed, +) from defaults import MAT_DTYPES -def process_split_region(tn, work, region, reference, mode, alignment_bam, - scan_window_size, scan_maf, min_mapq, - filtered_candidates_vcf, min_dp, max_dp, - filter_duplicate, - good_ao, min_ao, snp_min_af, snp_min_bq, snp_min_ao, - ins_min_af, del_min_af, del_merge_min_af, - ins_merge_min_af, merge_r, - merge_d_for_scan, - report_all_alleles, - report_count_for_all_positions, - scan_alignments_binary, restart, num_splits, num_threads, calc_qual, regions=[]): +def process_split_region( + tn, + work, + region, + reference, + mode, + alignment_bam, + scan_window_size, + scan_maf, + min_mapq, + filtered_candidates_vcf, + min_dp, + max_dp, + filter_duplicate, + good_ao, + min_ao, + snp_min_af, + snp_min_bq, + snp_min_ao, + ins_min_af, + del_min_af, + del_merge_min_af, + ins_merge_min_af, + merge_r, + merge_d_for_scan, + 
report_all_alleles, + report_count_for_all_positions, + scan_alignments_binary, + restart, + num_splits, + num_threads, + calc_qual, + regions=[], +): logger = logging.getLogger(process_split_region.__name__) logger.info("Scan bam.") - scan_outputs = scan_alignments(work, merge_d_for_scan, scan_alignments_binary, alignment_bam, - region, reference, num_splits, num_threads, scan_window_size, - snp_min_ao, - snp_min_af, scan_maf, scan_maf, - min_mapq, snp_min_bq, max_dp, min_dp, - report_all_alleles, report_count_for_all_positions, - filter_duplicate, restart=restart, split_region_files=regions, - calc_qual=calc_qual) + scan_outputs = scan_alignments( + work, + merge_d_for_scan, + scan_alignments_binary, + alignment_bam, + region, + reference, + num_splits, + num_threads, + scan_window_size, + snp_min_ao, + snp_min_af, + scan_maf, + scan_maf, + min_mapq, + snp_min_bq, + max_dp, + min_dp, + report_all_alleles, + report_count_for_all_positions, + filter_duplicate, + restart=restart, + split_region_files=regions, + calc_qual=calc_qual, + ) if filtered_candidates_vcf: logger.info("Filter candidates.") if restart or not os.path.exists(filtered_candidates_vcf): pool = multiprocessing.Pool(num_threads) map_args = [] for i, (raw_vcf, count_bed, split_region_bed) in enumerate(scan_outputs): - filtered_vcf = os.path.join(os.path.dirname( - os.path.realpath(raw_vcf)), "filtered_candidates.vcf") - map_args.append((raw_vcf, filtered_vcf, reference, min_dp, max_dp, good_ao, - min_ao, snp_min_af, snp_min_bq, snp_min_ao, ins_min_af, del_min_af, del_merge_min_af, - ins_merge_min_af, merge_r)) + filtered_vcf = os.path.join( + os.path.dirname(os.path.realpath(raw_vcf)), + "filtered_candidates.vcf", + ) + map_args.append( + ( + raw_vcf, + filtered_vcf, + reference, + min_dp, + max_dp, + good_ao, + min_ao, + snp_min_af, + snp_min_bq, + snp_min_ao, + ins_min_af, + del_min_af, + del_merge_min_af, + ins_merge_min_af, + merge_r, + ) + ) try: filtered_candidates_vcfs = pool.map_async( - filter_candidates, map_args).get() + filter_candidates, map_args + ).get() pool.close() except Exception as inst: logger.error(inst) @@ -72,53 +141,101 @@ def process_split_region(tn, work, region, reference, mode, alignment_bam, if o is None: raise Exception("filter_candidates failed!") - concatenate_vcfs(filtered_candidates_vcfs, - filtered_candidates_vcf, check_file_existence=True) + concatenate_vcfs( + filtered_candidates_vcfs, + filtered_candidates_vcf, + check_file_existence=True, + ) else: filtered_candidates_vcfs = [] for raw_vcf, _, _ in scan_outputs: - filtered_vcf = os.path.join(os.path.dirname( - os.path.realpath(raw_vcf)), "filtered_candidates.vcf") + filtered_vcf = os.path.join( + os.path.dirname(os.path.realpath(raw_vcf)), + "filtered_candidates.vcf", + ) filtered_candidates_vcfs.append(filtered_vcf) else: filtered_candidates_vcfs = None - return list(map(lambda x: x[1], scan_outputs)), list(map(lambda x: x[2], scan_outputs)), filtered_candidates_vcfs + return ( + list(map(lambda x: x[1], scan_outputs)), + list(map(lambda x: x[2], scan_outputs)), + filtered_candidates_vcfs, + ) -def generate_dataset_region(work, truth_vcf, mode, filtered_candidates_vcf, region, tumor_count_bed, normal_count_bed, reference, - matrix_width, matrix_base_pad, min_ev_frac_per_col, min_cov, num_threads, ensemble_bed, - ensemble_custom_header, - no_seq_complexity, - no_feature_recomp_for_ensemble, - zero_vscore, - matrix_dtype, - strict_labeling, - tsv_batch_size): +def generate_dataset_region( + work, + truth_vcf, + mode, + 
filtered_candidates_vcf, + region, + tumor_count_bed, + normal_count_bed, + reference, + matrix_width, + matrix_base_pad, + min_ev_frac_per_col, + min_cov, + num_threads, + ensemble_bed, + ensemble_custom_header, + no_seq_complexity, + no_feature_recomp_for_ensemble, + zero_vscore, + matrix_dtype, + strict_labeling, + tsv_batch_size, +): logger = logging.getLogger(generate_dataset_region.__name__) - generate_dataset(work, truth_vcf, mode, filtered_candidates_vcf, region, tumor_count_bed, normal_count_bed, reference, - matrix_width, matrix_base_pad, min_ev_frac_per_col, min_cov, num_threads, None, ensemble_bed, - ensemble_custom_header, - no_seq_complexity, - no_feature_recomp_for_ensemble, - zero_vscore, - matrix_dtype, - strict_labeling, - tsv_batch_size) + generate_dataset( + work, + truth_vcf, + mode, + filtered_candidates_vcf, + region, + tumor_count_bed, + normal_count_bed, + reference, + matrix_width, + matrix_base_pad, + min_ev_frac_per_col, + min_cov, + num_threads, + None, + ensemble_bed, + ensemble_custom_header, + no_seq_complexity, + no_feature_recomp_for_ensemble, + zero_vscore, + matrix_dtype, + strict_labeling, + tsv_batch_size, + ) return True def get_ensemble_region(record): reference, ensemble_bed, region, ensemble_bed_region_file, matrix_base_pad = record thread_logger = logging.getLogger( - "{} ({})".format(get_ensemble_region.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + get_ensemble_region.__name__, multiprocessing.current_process().name + ) + ) try: ensemble_bed_region_file_tmp = bedtools_slop( - region, reference + ".fai", args=" -b {}".format(matrix_base_pad + 3), - run_logger=thread_logger) + region, + reference + ".fai", + args=" -b {}".format(matrix_base_pad + 3), + run_logger=thread_logger, + ) bedtools_intersect( - ensemble_bed, ensemble_bed_region_file_tmp, args=" -u -header", - output_fn=ensemble_bed_region_file, run_logger=thread_logger) + ensemble_bed, + ensemble_bed_region_file_tmp, + args=" -u -header", + output_fn=ensemble_bed_region_file, + run_logger=thread_logger, + ) return ensemble_bed_region_file except Exception as ex: @@ -127,7 +244,9 @@ def get_ensemble_region(record): return None -def get_ensemble_beds(work, reference, ensemble_bed, split_regions, matrix_base_pad, num_threads): +def get_ensemble_beds( + work, reference, ensemble_bed, split_regions, matrix_base_pad, num_threads +): logger = logging.getLogger(get_ensemble_beds.__name__) work_ensemble = os.path.join(work, "ensemble_anns") @@ -137,10 +256,18 @@ def get_ensemble_beds(work, reference, ensemble_bed, split_regions, matrix_base_ ensemble_beds = [] for i, split_region_ in enumerate(split_regions): ensemble_bed_region_file = os.path.join( - work_ensemble, "ensemble_ann_{}.bed".format(i)) + work_ensemble, "ensemble_ann_{}.bed".format(i) + ) ensemble_beds.append(ensemble_bed_region_file) - map_args.append((reference, ensemble_bed, split_region_, - ensemble_bed_region_file, matrix_base_pad)) + map_args.append( + ( + reference, + ensemble_bed, + split_region_, + ensemble_bed_region_file, + matrix_base_pad, + ) + ) pool = multiprocessing.Pool(num_threads) try: outputs = pool.map_async(get_ensemble_region, map_args).get() @@ -157,15 +284,23 @@ def get_ensemble_beds(work, reference, ensemble_bed, split_regions, matrix_base_ def extract_candidate_split_regions( - work, filtered_candidates_vcfs, split_regions, ensemble_beds, - reference, matrix_base_pad, merge_d_for_short_read): + work, + filtered_candidates_vcfs, + split_regions, + ensemble_beds, + reference, + 
matrix_base_pad, + merge_d_for_short_read, +): logger = logging.getLogger(extract_candidate_split_regions.__name__) candidates_split_regions = [] - for i, (filtered_vcf, split_region_) in enumerate(zip(filtered_candidates_vcfs, - split_regions)): + for i, (filtered_vcf, split_region_) in enumerate( + zip(filtered_candidates_vcfs, split_regions) + ): candidates_region_file = os.path.join( - work, "candidates_region_{}.bed".format(i)) + work, "candidates_region_{}.bed".format(i) + ) is_empty = True with open(filtered_vcf) as f_: @@ -175,62 +310,98 @@ def extract_candidate_split_regions( logger.info([filtered_vcf, is_empty]) if not is_empty: candidates_bed = get_tmp_file() - vcf_2_bed(filtered_vcf, candidates_bed, - len_ref=True, keep_ref_alt=False) + vcf_2_bed(filtered_vcf, candidates_bed, len_ref=True, keep_ref_alt=False) candidates_bed = bedtools_sort(candidates_bed, run_logger=logger) candidates_bed = bedtools_slop( - candidates_bed, reference + ".fai", args=" -b {}".format(matrix_base_pad + 3), - run_logger=logger) + candidates_bed, + reference + ".fai", + args=" -b {}".format(matrix_base_pad + 3), + run_logger=logger, + ) candidates_bed = bedtools_merge( - candidates_bed, args=" -d {}".format(merge_d_for_short_read), run_logger=logger) + candidates_bed, + args=" -d {}".format(merge_d_for_short_read), + run_logger=logger, + ) else: candidates_bed = get_tmp_file() if ensemble_beds: - cmd = "cat {} {}".format( - candidates_bed, - ensemble_beds[i]) + cmd = "cat {} {}".format(candidates_bed, ensemble_beds[i]) candidates_bed = run_bedtools_cmd(cmd, run_logger=logger) - cmd = "cut -f 1,2,3 {}".format( - candidates_bed) + cmd = "cut -f 1,2,3 {}".format(candidates_bed) candidates_bed = run_bedtools_cmd(cmd, run_logger=logger) candidates_bed = bedtools_sort(candidates_bed, run_logger=logger) candidates_bed = bedtools_merge( - candidates_bed, args=" -d {}".format(merge_d_for_short_read), run_logger=logger) + candidates_bed, + args=" -d {}".format(merge_d_for_short_read), + run_logger=logger, + ) candidates_bed = bedtools_intersect( - candidates_bed, split_region_, run_logger=logger) - bedtools_sort(candidates_bed, - output_fn=candidates_region_file, run_logger=logger) + candidates_bed, split_region_, run_logger=logger + ) + bedtools_sort( + candidates_bed, output_fn=candidates_region_file, run_logger=logger + ) candidates_split_regions.append(candidates_region_file) return candidates_split_regions def generate_dataset_region_parallel(record): - work_dataset_split, truth_vcf, mode, filtered_vcf, \ - candidates_split_region, tumor_count, normal_count, reference, \ - matrix_width, matrix_base_pad, min_ev_frac_per_col, min_dp, \ - ensemble_bed_i, \ - ensemble_custom_header, \ - no_seq_complexity, no_feature_recomp_for_ensemble, \ - zero_vscore, \ - matrix_dtype, \ - strict_labeling, \ - tsv_batch_size = record + ( + work_dataset_split, + truth_vcf, + mode, + filtered_vcf, + candidates_split_region, + tumor_count, + normal_count, + reference, + matrix_width, + matrix_base_pad, + min_ev_frac_per_col, + min_dp, + ensemble_bed_i, + ensemble_custom_header, + no_seq_complexity, + no_feature_recomp_for_ensemble, + zero_vscore, + matrix_dtype, + strict_labeling, + tsv_batch_size, + ) = record thread_logger = logging.getLogger( - "{} ({})".format(generate_dataset_region_parallel.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + generate_dataset_region_parallel.__name__, + multiprocessing.current_process().name, + ) + ) try: - ret = generate_dataset_region(work_dataset_split, truth_vcf, 
mode, filtered_vcf, - candidates_split_region, tumor_count, normal_count, reference, - matrix_width, matrix_base_pad, min_ev_frac_per_col, min_dp, 1, - ensemble_bed_i, - ensemble_custom_header, - no_seq_complexity, no_feature_recomp_for_ensemble, - zero_vscore, - matrix_dtype, - strict_labeling, - tsv_batch_size) + ret = generate_dataset_region( + work_dataset_split, + truth_vcf, + mode, + filtered_vcf, + candidates_split_region, + tumor_count, + normal_count, + reference, + matrix_width, + matrix_base_pad, + min_ev_frac_per_col, + min_dp, + 1, + ensemble_bed_i, + ensemble_custom_header, + no_seq_complexity, + no_feature_recomp_for_ensemble, + zero_vscore, + matrix_dtype, + strict_labeling, + tsv_batch_size, + ) return ret except Exception as ex: @@ -239,28 +410,54 @@ def generate_dataset_region_parallel(record): return None -def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, - scan_window_size, scan_maf, min_mapq, - min_dp, max_dp, good_ao, min_ao, snp_min_af, snp_min_bq, snp_min_ao, - ins_min_af, del_min_af, del_merge_min_af, - ins_merge_min_af, merge_r, truth_vcf, tsv_batch_size, - matrix_width, matrix_base_pad, min_ev_frac_per_col, - ensemble_tsv, ensemble_custom_header, - long_read, restart, first_do_without_qual, - keep_duplicate, - add_extra_features, - no_seq_complexity, - no_feature_recomp_for_ensemble, - window_extend, - max_cluster_size, - merge_d_for_scan, - use_vscore, - num_splits, - matrix_dtype, - report_all_alleles, - strict_labeling, - num_threads, - scan_alignments_binary,): +def preprocess( + work, + mode, + reference, + region_bed, + tumor_bam, + normal_bam, + dbsnp, + scan_window_size, + scan_maf, + min_mapq, + min_dp, + max_dp, + good_ao, + min_ao, + snp_min_af, + snp_min_bq, + snp_min_ao, + ins_min_af, + del_min_af, + del_merge_min_af, + ins_merge_min_af, + merge_r, + truth_vcf, + tsv_batch_size, + matrix_width, + matrix_base_pad, + min_ev_frac_per_col, + ensemble_tsv, + ensemble_custom_header, + long_read, + restart, + first_do_without_qual, + keep_duplicate, + add_extra_features, + no_seq_complexity, + no_feature_recomp_for_ensemble, + window_extend, + max_cluster_size, + merge_d_for_scan, + use_vscore, + num_splits, + matrix_dtype, + report_all_alleles, + strict_labeling, + num_threads, + scan_alignments_binary, +): logger = logging.getLogger(preprocess.__name__) logger.info("----------------------Preprocessing------------------------") @@ -283,26 +480,29 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, raise Exception("No normal BAM file {}".format(normal_bam)) if not os.path.exists(tumor_bam + ".bai"): logger.error("Aborting!") - raise Exception( - "No tumor .bai index file {}".format(tumor_bam + ".bai")) + raise Exception("No tumor .bai index file {}".format(tumor_bam + ".bai")) if not os.path.exists(normal_bam + ".bai"): logger.error("Aborting!") - raise Exception( - "No normal .bai index file {}".format(normal_bam + ".bai")) + raise Exception("No normal .bai index file {}".format(normal_bam + ".bai")) if no_feature_recomp_for_ensemble and ensemble_custom_header: logger.error("Aborting!") raise Exception( - "--ensemble_custom_header and --no_feature_recomp_for_ensemble are incompatible") + "--ensemble_custom_header and --no_feature_recomp_for_ensemble are incompatible" + ) if dbsnp: if dbsnp[-6:] != "vcf.gz": logger.error("Aborting!") raise Exception( - "The dbSNP file should be a tabix indexed file with .vcf.gz format") + "The dbSNP file should be a tabix indexed file with .vcf.gz format" + ) if 
not os.path.exists(dbsnp + ".tbi"): logger.error("Aborting!") raise Exception( - "The dbSNP file should be a tabix indexed file with .vcf.gz format. No {}.tbi file exists.".format(dbsnp)) + "The dbSNP file should be a tabix indexed file with .vcf.gz format. No {}.tbi file exists.".format( + dbsnp + ) + ) zero_vscore = False if (not ensemble_tsv and add_extra_features) and not use_vscore: @@ -313,11 +513,15 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, ensemble_bed = os.path.join(work, "ensemble.bed") logger.info("Extract ensemble info.") if restart or not os.path.exists(ensemble_bed): - extract_ensemble(ensemble_tsvs=[ensemble_tsv], ensemble_bed=ensemble_bed, - no_seq_complexity=no_seq_complexity, enforce_header=no_feature_recomp_for_ensemble, - custom_header=ensemble_custom_header, - zero_vscore=zero_vscore, - is_extend=False) + extract_ensemble( + ensemble_tsvs=[ensemble_tsv], + ensemble_bed=ensemble_bed, + no_seq_complexity=no_seq_complexity, + enforce_header=no_feature_recomp_for_ensemble, + custom_header=ensemble_custom_header, + zero_vscore=zero_vscore, + is_extend=False, + ) merge_d_for_short_read = 100 candidates_split_regions = [] @@ -328,60 +532,123 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, if restart or not os.path.exists(work_tumor_without_q): os.mkdir(work_tumor_without_q) filtered_candidates_vcf_without_q = os.path.join( - work_tumor_without_q, "filtered_candidates.vcf") - - tumor_outputs_without_q = process_split_region("tumor", work_tumor_without_q, region_bed, reference, mode, - tumor_bam, scan_window_size, scan_maf, min_mapq, - filtered_candidates_vcf_without_q, min_dp, max_dp, - filter_duplicate, - good_ao, min_ao, - snp_min_af, -10000, snp_min_ao, - ins_min_af, del_min_af, del_merge_min_af, - ins_merge_min_af, merge_r, - merge_d_for_scan, - report_all_alleles, - False, - scan_alignments_binary, restart, num_splits, num_threads, - calc_qual=False) - tumor_counts_without_q, split_regions, filtered_candidates_vcfs_without_q = tumor_outputs_without_q + work_tumor_without_q, "filtered_candidates.vcf" + ) + + tumor_outputs_without_q = process_split_region( + "tumor", + work_tumor_without_q, + region_bed, + reference, + mode, + tumor_bam, + scan_window_size, + scan_maf, + min_mapq, + filtered_candidates_vcf_without_q, + min_dp, + max_dp, + filter_duplicate, + good_ao, + min_ao, + snp_min_af, + -10000, + snp_min_ao, + ins_min_af, + del_min_af, + del_merge_min_af, + ins_merge_min_af, + merge_r, + merge_d_for_scan, + report_all_alleles, + False, + scan_alignments_binary, + restart, + num_splits, + num_threads, + calc_qual=False, + ) + ( + tumor_counts_without_q, + split_regions, + filtered_candidates_vcfs_without_q, + ) = tumor_outputs_without_q if ensemble_tsv: ensemble_beds = get_ensemble_beds( - work, reference, ensemble_bed, split_regions, matrix_base_pad, num_threads) + work, + reference, + ensemble_bed, + split_regions, + matrix_base_pad, + num_threads, + ) candidates_split_regions = extract_candidate_split_regions( - work_tumor_without_q, filtered_candidates_vcfs_without_q, split_regions, ensemble_beds, - reference, matrix_base_pad, merge_d_for_short_read) + work_tumor_without_q, + filtered_candidates_vcfs_without_q, + split_regions, + ensemble_beds, + reference, + matrix_base_pad, + merge_d_for_short_read, + ) work_tumor = os.path.join(work, "work_tumor") if restart or not os.path.exists(work_tumor): os.mkdir(work_tumor) - filtered_candidates_vcf = os.path.join( - work_tumor, 
"filtered_candidates.vcf") + filtered_candidates_vcf = os.path.join(work_tumor, "filtered_candidates.vcf") logger.info("Scan tumor bam (and extracting quality scores).") - tumor_outputs = process_split_region("tumor", work_tumor, region_bed, reference, mode, - tumor_bam, scan_window_size, scan_maf, min_mapq, - filtered_candidates_vcf, min_dp, max_dp, - filter_duplicate, - good_ao, min_ao, - snp_min_af, snp_min_bq, snp_min_ao, - ins_min_af, del_min_af, del_merge_min_af, - ins_merge_min_af, merge_r, - merge_d_for_scan, - report_all_alleles, - False, - scan_alignments_binary, restart, num_splits, num_threads, - calc_qual=True, - regions=candidates_split_regions) + tumor_outputs = process_split_region( + "tumor", + work_tumor, + region_bed, + reference, + mode, + tumor_bam, + scan_window_size, + scan_maf, + min_mapq, + filtered_candidates_vcf, + min_dp, + max_dp, + filter_duplicate, + good_ao, + min_ao, + snp_min_af, + snp_min_bq, + snp_min_ao, + ins_min_af, + del_min_af, + del_merge_min_af, + ins_merge_min_af, + merge_r, + merge_d_for_scan, + report_all_alleles, + False, + scan_alignments_binary, + restart, + num_splits, + num_threads, + calc_qual=True, + regions=candidates_split_regions, + ) tumor_counts, split_regions, filtered_candidates_vcfs = tumor_outputs if ensemble_tsv and not ensemble_beds: ensemble_beds = get_ensemble_beds( - work, reference, ensemble_bed, split_regions, matrix_base_pad, num_threads) + work, reference, ensemble_bed, split_regions, matrix_base_pad, num_threads + ) - if (not long_read): + if not long_read: candidates_split_regions = extract_candidate_split_regions( - work_tumor, filtered_candidates_vcfs, split_regions, ensemble_beds, - reference, matrix_base_pad, merge_d_for_short_read) + work_tumor, + filtered_candidates_vcfs, + split_regions, + ensemble_beds, + reference, + matrix_base_pad, + merge_d_for_short_read, + ) if not candidates_split_regions: candidates_split_regions = split_regions @@ -389,26 +656,57 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, if restart or not os.path.exists(work_normal): os.mkdir(work_normal) logger.info("Scan normal bam (and extracting quality scores).") - normal_counts, _, _ = process_split_region("normal", work_normal, region_bed, reference, mode, normal_bam, - scan_window_size, 0.2, min_mapq, - None, 1, max_dp, - filter_duplicate, - good_ao, min_ao, snp_min_af, snp_min_bq, snp_min_ao, - ins_min_af, del_min_af, del_merge_min_af, - ins_merge_min_af, merge_r, - merge_d_for_scan, - report_all_alleles, - True, - scan_alignments_binary, restart, num_splits, num_threads, - calc_qual=True, - regions=candidates_split_regions) + normal_counts, _, _ = process_split_region( + "normal", + work_normal, + region_bed, + reference, + mode, + normal_bam, + scan_window_size, + 0.2, + min_mapq, + None, + 1, + max_dp, + filter_duplicate, + good_ao, + min_ao, + snp_min_af, + snp_min_bq, + snp_min_ao, + ins_min_af, + del_min_af, + del_merge_min_af, + ins_merge_min_af, + merge_r, + merge_d_for_scan, + report_all_alleles, + True, + scan_alignments_binary, + restart, + num_splits, + num_threads, + calc_qual=True, + regions=candidates_split_regions, + ) work_dataset = os.path.join(work, "dataset") if restart or not os.path.exists(work_dataset): os.mkdir(work_dataset) logger.info("Generate dataset.") map_args_gen = [] - for i, (tumor_count, normal_count, filtered_vcf, candidates_split_region) in enumerate(zip(tumor_counts, normal_counts, filtered_candidates_vcfs, candidates_split_regions)): + for ( + i, + (tumor_count, 
normal_count, filtered_vcf, candidates_split_region), + ) in enumerate( + zip( + tumor_counts, + normal_counts, + filtered_candidates_vcfs, + candidates_split_regions, + ) + ): logger.info("Dataset for region {}".format(candidates_split_region)) work_dataset_split = os.path.join(work_dataset, "work.{}".format(i)) if restart or not os.path.exists("{}/done.txt".format(work_dataset_split)): @@ -416,63 +714,86 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, shutil.rmtree(work_dataset_split) os.mkdir(work_dataset_split) ensemble_bed_i = ensemble_beds[i] if ensemble_tsv else None - if add_extra_features or (ensemble_tsv and not no_feature_recomp_for_ensemble): + if add_extra_features or ( + ensemble_tsv and not no_feature_recomp_for_ensemble + ): work_tumor_i = os.path.dirname(filtered_vcf) if add_extra_features: extra_features_tsv = os.path.join( - work_tumor_i, "extra_features.tsv") + work_tumor_i, "extra_features.tsv" + ) ex_tsvs = [extra_features_tsv] if not os.path.exists(extra_features_tsv) or restart: - extend_features(filtered_vcf, - ensemble_beds[ - i] if (ensemble_tsv and no_feature_recomp_for_ensemble) else None, - None, - extra_features_tsv, - reference, tumor_bam, normal_bam, - min_mapq, snp_min_bq, - dbsnp, None, - no_seq_complexity, - window_extend, - max_cluster_size, - num_threads) + extend_features( + filtered_vcf, + ensemble_beds[i] + if (ensemble_tsv and no_feature_recomp_for_ensemble) + else None, + None, + extra_features_tsv, + reference, + tumor_bam, + normal_bam, + min_mapq, + snp_min_bq, + dbsnp, + None, + no_seq_complexity, + window_extend, + max_cluster_size, + num_threads, + ) else: ex_tsvs = [] extra_features_tsv = None if ensemble_tsv and not no_feature_recomp_for_ensemble: extra_features_others_tsv = os.path.join( - work_tumor_i, "extra_features_others.tsv") + work_tumor_i, "extra_features_others.tsv" + ) ex_tsvs.append(extra_features_others_tsv) if not os.path.exists(extra_features_others_tsv) or restart: - extend_features(ensemble_beds[i], - extra_features_tsv, - None, - extra_features_others_tsv, - reference, tumor_bam, normal_bam, - min_mapq, snp_min_bq, - dbsnp, None, - no_seq_complexity, - window_extend, - max_cluster_size, - num_threads) + extend_features( + ensemble_beds[i], + extra_features_tsv, + None, + extra_features_others_tsv, + reference, + tumor_bam, + normal_bam, + min_mapq, + snp_min_bq, + dbsnp, + None, + no_seq_complexity, + window_extend, + max_cluster_size, + num_threads, + ) extra_features_bed = os.path.join( - work_dataset_split, "extra_features.bed") + work_dataset_split, "extra_features.bed" + ) if not os.path.exists(extra_features_bed) or restart: - extract_ensemble(ensemble_tsvs=ex_tsvs, - ensemble_bed=extra_features_bed, - no_seq_complexity=no_seq_complexity, - enforce_header=True, - custom_header=ensemble_custom_header, - zero_vscore=zero_vscore, - is_extend=True) + extract_ensemble( + ensemble_tsvs=ex_tsvs, + ensemble_bed=extra_features_bed, + no_seq_complexity=no_seq_complexity, + enforce_header=True, + custom_header=ensemble_custom_header, + zero_vscore=zero_vscore, + is_extend=True, + ) if ensemble_tsv: merged_features_bed = os.path.join( - work_dataset_split, "merged_features.bed") + work_dataset_split, "merged_features.bed" + ) if not os.path.exists(merged_features_bed) or restart: exclude_ens_variants = [] if no_feature_recomp_for_ensemble: header_line = "" - with open(merged_features_bed, "w") as o_f, open(ensemble_beds[i]) as i_f_1, open(extra_features_bed) as i_f_2: + with 
open(merged_features_bed, "w") as o_f, open( + ensemble_beds[i] + ) as i_f_1, open(extra_features_bed) as i_f_2: for line in skip_empty(i_f_1, skip_header=False): if line.startswith("#"): if not header_line: @@ -481,11 +802,13 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, else: if header_line != line: logger.error( - "{}!={}".format(header_line, line)) + "{}!={}".format(header_line, line) + ) raise Exception continue chrom, pos, _, ref, alt = line.strip().split("\t")[ - 0:5] + 0:5 + ] var_id = "-".join([chrom, pos, ref, alt]) exclude_ens_variants.append(var_id) o_f.write(line) @@ -493,11 +816,13 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, if line.startswith("#"): if header_line != line: logger.error( - "{}!={}".format(header_line, line)) + "{}!={}".format(header_line, line) + ) raise Exception continue chrom, pos, _, ref, alt = line.strip().split("\t")[ - 0:5] + 0:5 + ] var_id = "-".join([chrom, pos, ref, alt]) if var_id in exclude_ens_variants: continue @@ -505,11 +830,34 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, else: if not ensemble_custom_header: header_line = "" - callers_features = ["if_MuTect", "if_VarScan2", "if_JointSNVMix2", "if_SomaticSniper", "if_VarDict", "MuSE_Tier", - "if_LoFreq", "if_Scalpel", "if_Strelka", "if_TNscope", "Strelka_Score", "Strelka_QSS", - "Strelka_TQSS", "SNVMix2_Score", "Sniper_Score", "VarDict_Score", - "M2_NLOD", "M2_TLOD", "M2_STR", "M2_ECNT", "MSI", "MSILEN", "SHIFT3"] - with open(merged_features_bed, "w") as o_f, open(ensemble_beds[i]) as i_f_1, open(extra_features_bed) as i_f_2: + callers_features = [ + "if_MuTect", + "if_VarScan2", + "if_JointSNVMix2", + "if_SomaticSniper", + "if_VarDict", + "MuSE_Tier", + "if_LoFreq", + "if_Scalpel", + "if_Strelka", + "if_TNscope", + "Strelka_Score", + "Strelka_QSS", + "Strelka_TQSS", + "SNVMix2_Score", + "Sniper_Score", + "VarDict_Score", + "M2_NLOD", + "M2_TLOD", + "M2_STR", + "M2_ECNT", + "MSI", + "MSILEN", + "SHIFT3", + ] + with open(merged_features_bed, "w") as o_f, open( + ensemble_beds[i] + ) as i_f_1, open(extra_features_bed) as i_f_2: ens_variants_info = {} header_1_found = False header_2_found = False @@ -520,59 +868,75 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, else: if header_line != line: logger.error( - "{}!={}".format(header_line, line)) + "{}!={}".format( + header_line, line + ) + ) raise Exception header_ = line.strip().split()[5:] - header_caller = list(filter( - lambda x: x[1] in callers_features, enumerate(header_))) + header_caller = list( + filter( + lambda x: x[1] in callers_features, + enumerate(header_), + ) + ) header_caller_ = list( - map(lambda x: x[1], header_caller)) + map(lambda x: x[1], header_caller) + ) header_i = list( - map(lambda x: x[0], header_caller)) + map(lambda x: x[0], header_caller) + ) header_1_found = True continue assert header_1_found fields = line.strip().split("\t") chrom, pos, _, ref, alt = fields[0:5] - var_id = "-".join([chrom, - pos, ref, alt]) - ens_variants_info[var_id] = np.array(fields[5:])[ - header_i] + var_id = "-".join([chrom, pos, ref, alt]) + ens_variants_info[var_id] = np.array( + fields[5:] + )[header_i] for line in skip_empty(i_f_2, skip_header=False): if line.startswith("#"): if header_line != line: logger.error( - "{}!={}".format(header_line, line)) + "{}!={}".format(header_line, line) + ) if not header_2_found: - header_2 = line.strip().split()[ - 5:] + header_2 = line.strip().split()[5:] 
order_header = [] for f in header_caller_: if f not in header_2: logger.info( - "Missing header field {}".format(f)) + "Missing header field {}".format( + f + ) + ) raise Exception order_header.append( - header_2.index(f)) + header_2.index(f) + ) o_f.write(line) header_2_found = True assert header_2_found fields = line.strip().split("\t") chrom, pos, _, ref, alt = fields[0:5] - var_id = "-".join([chrom, - pos, ref, alt]) + var_id = "-".join([chrom, pos, ref, alt]) if var_id in ens_variants_info: fields_ = np.array(fields[5:]) fields_[order_header] = ens_variants_info[ - var_id] + var_id + ] fields[5:] = fields_.tolist() o_f.write( - "\t".join(list(map(str, fields))) + "\n") + "\t".join(list(map(str, fields))) + "\n" + ) else: header_line_1 = "" header_line_2 = "" - with open(merged_features_bed, "w") as o_f, open(ensemble_beds[i]) as i_f_1, open(extra_features_bed) as i_f_2: + with open(merged_features_bed, "w") as o_f, open( + ensemble_beds[i] + ) as i_f_1, open(extra_features_bed) as i_f_2: ens_variants_info = {} ex_variants_info = {} header_1_found = False @@ -584,7 +948,10 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, else: if header_line_1 != line: logger.error( - "{}!={}".format(header_line_1, line)) + "{}!={}".format( + header_line_1, line + ) + ) raise Exception header_1 = line.strip().split()[5:] header_1_found = True @@ -592,10 +959,8 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, assert header_1_found fields = line.strip().split("\t") chrom, pos, _, ref, alt = fields[0:5] - var_id = "-".join([chrom, - pos, ref, alt]) - ens_variants_info[ - var_id] = np.array(fields[5:]) + var_id = "-".join([chrom, pos, ref, alt]) + ens_variants_info[var_id] = np.array(fields[5:]) for line in skip_empty(i_f_2, skip_header=False): if line.startswith("#"): if not header_line_2: @@ -603,7 +968,10 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, else: if header_line_2 != line: logger.error( - "{}!={}".format(header_line_2, line)) + "{}!={}".format( + header_line_2, line + ) + ) raise Exception header_2 = line.strip().split()[5:] header_2_found = True @@ -611,46 +979,81 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, assert header_2_found fields = line.strip().split("\t") chrom, pos, _, ref, alt = fields[0:5] - var_id = "-".join([chrom, - pos, ref, alt]) - ex_variants_info[ - var_id] = np.array(fields[5:]) - header_mixed = [ - "#CHROM", "POS", "ID", "REF", "ALT"] + header_1 + header_2 + var_id = "-".join([chrom, pos, ref, alt]) + ex_variants_info[var_id] = np.array(fields[5:]) + header_mixed = ( + ["#CHROM", "POS", "ID", "REF", "ALT"] + + header_1 + + header_2 + ) o_f.write( - "\t".join(list(map(str, header_mixed))) + "\n") - for var_id in set(ens_variants_info.keys()) | set(ex_variants_info.keys()): - features = [0.0] * \ - (len(header_1) + len(header_2)) + "\t".join(list(map(str, header_mixed))) + "\n" + ) + for var_id in set(ens_variants_info.keys()) | set( + ex_variants_info.keys() + ): + features = [0.0] * ( + len(header_1) + len(header_2) + ) if var_id in ens_variants_info: - features[0:len(header_1)] = ens_variants_info[ - var_id] + features[ + 0 : len(header_1) + ] = ens_variants_info[var_id] if var_id in ex_variants_info: - features[len(header_1):] = ex_variants_info[ - var_id] - chrom = "-".join(var_id.split("-") - [:-3]) + features[ + len(header_1) : + ] = ex_variants_info[var_id] + chrom = "-".join(var_id.split("-")[:-3]) pos, ref, alt = 
var_id.split("-")[-3:] o_f.write( - "\t".join(list(map(str, [chrom, pos, int(pos) + len(ref), ref, alt] + features))) + "\n") + "\t".join( + list( + map( + str, + [ + chrom, + pos, + int(pos) + len(ref), + ref, + alt, + ] + + features, + ) + ) + ) + + "\n" + ) ensemble_bed_i = merged_features_bed else: ensemble_bed_i = extra_features_bed - map_args_gen.append([work_dataset_split, truth_vcf, mode, filtered_vcf, - candidates_split_region, tumor_count, normal_count, reference, - matrix_width, matrix_base_pad, min_ev_frac_per_col, min_dp, - ensemble_bed_i, - ensemble_custom_header, - no_seq_complexity, no_feature_recomp_for_ensemble, - zero_vscore, - matrix_dtype, - strict_labeling, - tsv_batch_size]) + map_args_gen.append( + [ + work_dataset_split, + truth_vcf, + mode, + filtered_vcf, + candidates_split_region, + tumor_count, + normal_count, + reference, + matrix_width, + matrix_base_pad, + min_ev_frac_per_col, + min_dp, + ensemble_bed_i, + ensemble_custom_header, + no_seq_complexity, + no_feature_recomp_for_ensemble, + zero_vscore, + matrix_dtype, + strict_labeling, + tsv_batch_size, + ] + ) pool = multiprocessing.Pool(num_threads) try: - done_gen = pool.map_async( - generate_dataset_region_parallel, map_args_gen).get() + done_gen = pool.map_async(generate_dataset_region_parallel, map_args_gen).get() pool.close() except Exception as inst: logger.error(inst) @@ -662,157 +1065,258 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, if o is None: raise Exception("Generate dataset failed!") - shutil.rmtree(bed_tempdir) tempfile.tempdir = original_tempdir logger.info("Preprocessing is Done.") -if __name__ == '__main__': - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +if __name__ == "__main__": + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) parser = argparse.ArgumentParser( - description='Preprocess input alignments for train/call') - parser.add_argument('--mode', type=str, help='train/call mode', - choices=["train", "call"], required=True) - parser.add_argument('--reference', type=str, - help='reference fasta filename', required=True) - parser.add_argument('--region_bed', type=str, - help='region bed', required=True) - parser.add_argument('--tumor_bam', type=str, - help='tumor bam', required=True) - parser.add_argument('--normal_bam', type=str, - help='normal bam', required=True) - parser.add_argument('--work', type=str, - help='work directory', required=True) - parser.add_argument('--dbsnp', type=str, - help='dbsnp vcf.gz', default=None) - parser.add_argument('--scan_window_size', type=int, - help='window size to scan the variants', default=2000) - parser.add_argument('--scan_maf', type=float, - help='minimum allele freq for scanning', default=0.01) - parser.add_argument('--min_mapq', type=int, - help='minimum mapping quality', default=1) - parser.add_argument('--min_dp', type=float, help='min depth', default=5) - parser.add_argument('--max_dp', type=float, - help='max depth', default=100000) - parser.add_argument('--good_ao', type=float, - help='good alternate count (ignores maf)', default=10) - parser.add_argument('--min_ao', type=float, - help='min alternate count', default=1) - parser.add_argument('--snp_min_af', type=float, - help='SNP min allele freq', default=0.05) - parser.add_argument('--snp_min_bq', type=float, - help='SNP min base quality', default=10) - parser.add_argument('--snp_min_ao', type=float, - help='SNP min alternate 
count for low AF candidates', default=3) - parser.add_argument('--ins_min_af', type=float, - help='INS min allele freq', default=0.05) - parser.add_argument('--del_min_af', type=float, - help='DEL min allele freq', default=0.05) - parser.add_argument('--del_merge_min_af', type=float, - help='min allele freq for merging DELs', default=0) - parser.add_argument('--ins_merge_min_af', type=float, - help='min allele freq for merging INSs', default=0) - parser.add_argument('--merge_r', type=float, - help='merge af ratio to the max af for merging adjacent variants', default=0.5) - parser.add_argument('--truth_vcf', type=str, - help='truth vcf (required for train mode)', default=None) - parser.add_argument('--tsv_batch_size', type=int, - help='output files batch size', default=50000) - parser.add_argument('--matrix_width', type=int, - help='target window width', default=32) - parser.add_argument('--matrix_base_pad', type=int, - help='number of bases to pad around the candidate variant', default=7) - parser.add_argument('--min_ev_frac_per_col', type=float, - help='minimum frac cov per column to keep columm', default=0.06) - parser.add_argument('--ensemble_tsv', type=str, - help='Ensemble annotation tsv file (only for short read)', default=None) - parser.add_argument('--ensemble_custom_header', - help='Allow ensemble tsv to have custom header fields. (Features should be\ - normalized between [0,1]', - action="store_true") - parser.add_argument('--long_read', - help='Enable long_read (high error-rate sequence) indel realignment', - action="store_true") - parser.add_argument('--restart', - help='Restart the process. (instead of continuing from where we left)', - action="store_true") - parser.add_argument('--first_do_without_qual', - help='Perform initial scan without calculating the quality stats', - action="store_true") - parser.add_argument('--keep_duplicate', - help='Don not filter duplicate reads when preparing pileup information', - action="store_true") - parser.add_argument('--add_extra_features', - help='add extra input features', - action="store_true") - parser.add_argument('--no_seq_complexity', - help='Dont compute linguistic sequence complexity features', - action="store_true") - parser.add_argument('--no_feature_recomp_for_ensemble', - help='Do not recompute features for ensemble_tsv', - action="store_true") - parser.add_argument('--window_extend', type=int, - help='window size for extending input features (should be in the order of readlength)', - default=1000) - parser.add_argument('--max_cluster_size', type=int, - help='max cluster size for extending input features (should be in the order of readlength)', - default=300) - parser.add_argument('--merge_d_for_scan', type=int, - help='-d used to merge regions before scan', - default=None) - parser.add_argument('--use_vscore', - help='don\'t set VarScan2_Score to zero', - action="store_true") - parser.add_argument('--num_splits', type=int, - help='number of region splits', default=None) - parser.add_argument('--matrix_dtype', type=str, - help='matrix_dtype to be used to store matrix', default="uint8", - choices=MAT_DTYPES) - parser.add_argument('--report_all_alleles', - help='report all alleles per position', - action="store_true") - parser.add_argument('--strict_labeling', - help='strict labeling in train mode', - action="store_true") - parser.add_argument('--num_threads', type=int, - help='number of threads', default=1) - parser.add_argument('--scan_alignments_binary', type=str, - help='binary for scanning alignment bam', 
default="../bin/scan_alignments") + description="Preprocess input alignments for train/call" + ) + parser.add_argument( + "--mode", + type=str, + help="train/call mode", + choices=["train", "call"], + required=True, + ) + parser.add_argument( + "--reference", type=str, help="reference fasta filename", required=True + ) + parser.add_argument("--region_bed", type=str, help="region bed", required=True) + parser.add_argument("--tumor_bam", type=str, help="tumor bam", required=True) + parser.add_argument("--normal_bam", type=str, help="normal bam", required=True) + parser.add_argument("--work", type=str, help="work directory", required=True) + parser.add_argument("--dbsnp", type=str, help="dbsnp vcf.gz", default=None) + parser.add_argument( + "--scan_window_size", + type=int, + help="window size to scan the variants", + default=2000, + ) + parser.add_argument( + "--scan_maf", type=float, help="minimum allele freq for scanning", default=0.01 + ) + parser.add_argument( + "--min_mapq", type=int, help="minimum mapping quality", default=1 + ) + parser.add_argument("--min_dp", type=float, help="min depth", default=5) + parser.add_argument("--max_dp", type=float, help="max depth", default=100000) + parser.add_argument( + "--good_ao", type=float, help="good alternate count (ignores maf)", default=10 + ) + parser.add_argument("--min_ao", type=float, help="min alternate count", default=1) + parser.add_argument( + "--snp_min_af", type=float, help="SNP min allele freq", default=0.05 + ) + parser.add_argument( + "--snp_min_bq", type=float, help="SNP min base quality", default=10 + ) + parser.add_argument( + "--snp_min_ao", + type=float, + help="SNP min alternate count for low AF candidates", + default=3, + ) + parser.add_argument( + "--ins_min_af", type=float, help="INS min allele freq", default=0.05 + ) + parser.add_argument( + "--del_min_af", type=float, help="DEL min allele freq", default=0.05 + ) + parser.add_argument( + "--del_merge_min_af", + type=float, + help="min allele freq for merging DELs", + default=0, + ) + parser.add_argument( + "--ins_merge_min_af", + type=float, + help="min allele freq for merging INSs", + default=0, + ) + parser.add_argument( + "--merge_r", + type=float, + help="merge af ratio to the max af for merging adjacent variants", + default=0.5, + ) + parser.add_argument( + "--truth_vcf", + type=str, + help="truth vcf (required for train mode)", + default=None, + ) + parser.add_argument( + "--tsv_batch_size", type=int, help="output files batch size", default=50000 + ) + parser.add_argument( + "--matrix_width", type=int, help="target window width", default=32 + ) + parser.add_argument( + "--matrix_base_pad", + type=int, + help="number of bases to pad around the candidate variant", + default=7, + ) + parser.add_argument( + "--min_ev_frac_per_col", + type=float, + help="minimum frac cov per column to keep columm", + default=0.06, + ) + parser.add_argument( + "--ensemble_tsv", + type=str, + help="Ensemble annotation tsv file (only for short read)", + default=None, + ) + parser.add_argument( + "--ensemble_custom_header", + help="Allow ensemble tsv to have custom header fields. (Features should be\ + normalized between [0,1]", + action="store_true", + ) + parser.add_argument( + "--long_read", + help="Enable long_read (high error-rate sequence) indel realignment", + action="store_true", + ) + parser.add_argument( + "--restart", + help="Restart the process. 
(instead of continuing from where we left)", + action="store_true", + ) + parser.add_argument( + "--first_do_without_qual", + help="Perform initial scan without calculating the quality stats", + action="store_true", + ) + parser.add_argument( + "--keep_duplicate", + help="Don not filter duplicate reads when preparing pileup information", + action="store_true", + ) + parser.add_argument( + "--add_extra_features", help="add extra input features", action="store_true" + ) + parser.add_argument( + "--no_seq_complexity", + help="Dont compute linguistic sequence complexity features", + action="store_true", + ) + parser.add_argument( + "--no_feature_recomp_for_ensemble", + help="Do not recompute features for ensemble_tsv", + action="store_true", + ) + parser.add_argument( + "--window_extend", + type=int, + help="window size for extending input features (should be in the order of readlength)", + default=1000, + ) + parser.add_argument( + "--max_cluster_size", + type=int, + help="max cluster size for extending input features (should be in the order of readlength)", + default=300, + ) + parser.add_argument( + "--merge_d_for_scan", + type=int, + help="-d used to merge regions before scan", + default=None, + ) + parser.add_argument( + "--use_vscore", help="don't set VarScan2_Score to zero", action="store_true" + ) + parser.add_argument( + "--num_splits", type=int, help="number of region splits", default=None + ) + parser.add_argument( + "--matrix_dtype", + type=str, + help="matrix_dtype to be used to store matrix", + default="uint8", + choices=MAT_DTYPES, + ) + parser.add_argument( + "--report_all_alleles", + help="report all alleles per position", + action="store_true", + ) + parser.add_argument( + "--strict_labeling", help="strict labeling in train mode", action="store_true" + ) + parser.add_argument("--num_threads", type=int, help="number of threads", default=1) + parser.add_argument( + "--scan_alignments_binary", + type=str, + help="binary for scanning alignment bam", + default="../bin/scan_alignments", + ) args = parser.parse_args() logger.info(args) try: - preprocess(args.work, args.mode, args.reference, args.region_bed, args.tumor_bam, args.normal_bam, - args.dbsnp, - args.scan_window_size, args.scan_maf, args.min_mapq, - args.min_dp, args.max_dp, args.good_ao, args.min_ao, args.snp_min_af, args.snp_min_bq, args.snp_min_ao, - args.ins_min_af, args.del_min_af, args.del_merge_min_af, - args.ins_merge_min_af, args.merge_r, - args.truth_vcf, args.tsv_batch_size, args.matrix_width, args.matrix_base_pad, args.min_ev_frac_per_col, - args.ensemble_tsv, args.ensemble_custom_header, - args.long_read, args.restart, args.first_do_without_qual, - args.keep_duplicate, - args.add_extra_features, - args.no_seq_complexity, - args.no_feature_recomp_for_ensemble, - args.window_extend, - args.max_cluster_size, - args.merge_d_for_scan, - args.use_vscore, - args.num_splits, - args.matrix_dtype, - args.report_all_alleles, - args.strict_labeling, - args.num_threads, - args.scan_alignments_binary) + preprocess( + args.work, + args.mode, + args.reference, + args.region_bed, + args.tumor_bam, + args.normal_bam, + args.dbsnp, + args.scan_window_size, + args.scan_maf, + args.min_mapq, + args.min_dp, + args.max_dp, + args.good_ao, + args.min_ao, + args.snp_min_af, + args.snp_min_bq, + args.snp_min_ao, + args.ins_min_af, + args.del_min_af, + args.del_merge_min_af, + args.ins_merge_min_af, + args.merge_r, + args.truth_vcf, + args.tsv_batch_size, + args.matrix_width, + args.matrix_base_pad, + args.min_ev_frac_per_col, + 
args.ensemble_tsv, + args.ensemble_custom_header, + args.long_read, + args.restart, + args.first_do_without_qual, + args.keep_duplicate, + args.add_extra_features, + args.no_seq_complexity, + args.no_feature_recomp_for_ensemble, + args.window_extend, + args.max_cluster_size, + args.merge_d_for_scan, + args.use_vscore, + args.num_splits, + args.matrix_dtype, + args.report_all_alleles, + args.strict_labeling, + args.num_threads, + args.scan_alignments_binary, + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "preprocess.py failure on arguments: {}".format(args)) + logger.error("preprocess.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/read_callers_vcf.py b/neusomatic/python/read_callers_vcf.py index 42c1428..fb693c3 100755 --- a/neusomatic/python/read_callers_vcf.py +++ b/neusomatic/python/read_callers_vcf.py @@ -1,8 +1,8 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # read_callers_vcf.py # read callers vcf files and generate ensemble tsv -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import argparse import traceback import logging @@ -18,13 +18,12 @@ # Normal/Tumor index in the Merged VCF file, or any other VCF file that # puts NORMAL first. idxN, idxT = 0, 1 -nan = float('nan') +nan = float("nan") def get_info_value(info_field, variable, ith_alt=None): logger = logging.getLogger(get_info_value.__name__) - key_item = re.search( - r'\b{}=([^;\s]+)([;\W]|$)'.format(variable), info_field) + key_item = re.search(r"\b{}=([^;\s]+)([;\W]|$)".format(variable), info_field) # The key has a value attached to it, e.g., VAR=1,2,3 if key_item: @@ -35,13 +34,13 @@ def get_info_value(info_field, variable, ith_alt=None): # Perhaps it's simply a flag without "=" else: - key_item = info_field.split(';') + key_item = info_field.split(";") return True if variable in key_item else False def get_sample_value(fields, samples, variable, idx=0): - var2value = dict(zip(fields.split(':'), samples[idx].split(':'))) + var2value = dict(zip(fields.split(":"), samples[idx].split(":"))) try: return var2value[variable] except KeyError: @@ -50,19 +49,20 @@ def get_sample_value(fields, samples, variable, idx=0): def get_mutect2_info(filters, info, ith_alt): - mutect_classification = 1 if (get_info_value(info, - 'SOMATIC') or 'PASS' in filters) else 0 + mutect_classification = ( + 1 if (get_info_value(info, "SOMATIC") or "PASS" in filters) else 0 + ) # MuTect2 has some useful information: - nlod = get_info_value(info, 'NLOD', ith_alt) + nlod = get_info_value(info, "NLOD", ith_alt) nlod = float(nlod) if nlod else nan - tlod = get_info_value(info, 'TLOD', ith_alt) + tlod = get_info_value(info, "TLOD", ith_alt) tlod = float(tlod) if tlod else nan - tandem = 1 if get_info_value(info, 'STR') else 0 + tandem = 1 if get_info_value(info, "STR") else 0 - ecnt = get_info_value(info, 'ECNT') + ecnt = get_info_value(info, "ECNT") if ecnt: try: ecnt = int(ecnt) @@ -74,19 +74,17 @@ def get_mutect2_info(filters, info, ith_alt): def get_varscan2_info(info): - varscan_classification = 1 if get_info_value(info, - 'SOMATIC') else 0 + varscan_classification = 1 if get_info_value(info, "SOMATIC") else 0 return varscan_classification def get_somaticsniper_info(fields, samples, idxT): - somaticsniper_classification = 1 if 
get_sample_value(fields, samples, - 'SS', idxT) == '2' else 0 + somaticsniper_classification = ( + 1 if get_sample_value(fields, samples, "SS", idxT) == "2" else 0 + ) if somaticsniper_classification == 1: - score_somaticsniper = get_sample_value(fields, samples, - 'SSC', idxT) - score_somaticsniper = int( - score_somaticsniper) if score_somaticsniper else nan + score_somaticsniper = get_sample_value(fields, samples, "SSC", idxT) + score_somaticsniper = int(score_somaticsniper) if score_somaticsniper else nan else: score_somaticsniper = nan @@ -95,23 +93,29 @@ def get_somaticsniper_info(fields, samples, idxT): def get_vardict_info(filters, info, fields, samples): - if (filters == 'PASS') and ('Somatic' in info): + if (filters == "PASS") and ("Somatic" in info): vardict_classification = 1 - elif 'Somatic' in info: - vardict_filters = filters.split(';') - - disqualifying_filters = \ - ('d7' in vardict_filters or 'd5' in vardict_filters) or \ - ('DIFF0.2' in vardict_filters) or \ - ('LongAT' in vardict_filters) or \ - ('MAF0.05' in vardict_filters) or \ - ('MSI6' in vardict_filters) or \ - ('NM4' in vardict_filters or 'NM4.25' in vardict_filters) or \ - ('pSTD' in vardict_filters) or \ - ('SN1.5' in vardict_filters) or \ - ( 'P0.05' in vardict_filters and float(get_info_value(info, 'SSF') ) >= 0.15 ) or \ - (('v3' in vardict_filters or 'v4' in vardict_filters) - and int(get_sample_value(fields, samples, 'VD', 1)) < 3) + elif "Somatic" in info: + vardict_filters = filters.split(";") + + disqualifying_filters = ( + ("d7" in vardict_filters or "d5" in vardict_filters) + or ("DIFF0.2" in vardict_filters) + or ("LongAT" in vardict_filters) + or ("MAF0.05" in vardict_filters) + or ("MSI6" in vardict_filters) + or ("NM4" in vardict_filters or "NM4.25" in vardict_filters) + or ("pSTD" in vardict_filters) + or ("SN1.5" in vardict_filters) + or ( + "P0.05" in vardict_filters + and float(get_info_value(info, "SSF")) >= 0.15 + ) + or ( + ("v3" in vardict_filters or "v4" in vardict_filters) + and int(get_sample_value(fields, samples, "VD", 1)) < 3 + ) + ) no_bad_filter = not disqualifying_filters filter_fail_times = len(vardict_filters) @@ -125,26 +129,26 @@ def get_vardict_info(filters, info, fields, samples): vardict_classification = 0 # Somatic Score: - score_vardict = get_info_value(info, 'SSF') + score_vardict = get_info_value(info, "SSF") if score_vardict: score_vardict = float(score_vardict) score_vardict = genome.p2phred(score_vardict, max_phred=100) - score_vardict = rescale(score_vardict, 'phred', None, 1001) + score_vardict = rescale(score_vardict, "phred", None, 1001) else: score_vardict = nan # MSI, MSILEN, and SHIFT3: - msi = get_info_value(info, 'MSI') + msi = get_info_value(info, "MSI") if msi: msi = float(msi) else: msi = nan - msilen = get_info_value(info, 'MSILEN') + msilen = get_info_value(info, "MSILEN") if msilen: msilen = float(msilen) else: msilen = nan - shift3 = get_info_value(info, 'SHIFT3') + shift3 = get_info_value(info, "SHIFT3") if shift3: shift3 = float(shift3) else: @@ -154,17 +158,17 @@ def get_vardict_info(filters, info, fields, samples): def get_muse_info(filters): - if filters == 'PASS': + if filters == "PASS": muse_classification = 1 - elif filters == 'Tier1': + elif filters == "Tier1": muse_classification = 0.9 - elif filters == 'Tier2': + elif filters == "Tier2": muse_classification = 0.8 - elif filters == 'Tier3': + elif filters == "Tier3": muse_classification = 0.7 - elif filters == 'Tier4': + elif filters == "Tier4": muse_classification = 0.6 - elif filters == 
'Tier5': + elif filters == "Tier5": muse_classification = 0.5 else: muse_classification = 0 @@ -172,37 +176,38 @@ def get_muse_info(filters): def get_strelka2_info(filters, info): - strelka_classification = 1 if 'PASS' in filters else 0 - somatic_evs = get_info_value(info, 'SomaticEVS') - qss = get_info_value(info, 'QSS') - tqss = get_info_value(info, 'TQSS') + strelka_classification = 1 if "PASS" in filters else 0 + somatic_evs = get_info_value(info, "SomaticEVS") + qss = get_info_value(info, "QSS") + tqss = get_info_value(info, "TQSS") return strelka_classification, somatic_evs, qss, tqss def open_textfile(file_name): # See if the input file is a .gz file: - if file_name.lower().endswith('.gz'): - return gzip.open(file_name, 'rt') + if file_name.lower().endswith(".gz"): + return gzip.open(file_name, "rt") else: return open(file_name) -def read_callers_vcf(reference, - output_tsv, - mutect2_vcfs, - strelka2_vcfs, - varscan2_vcfs, - muse_vcfs, - vardict_vcfs, - somaticsniper_vcfs, - min_caller): +def read_callers_vcf( + reference, + output_tsv, + mutect2_vcfs, + strelka2_vcfs, + varscan2_vcfs, + muse_vcfs, + vardict_vcfs, + somaticsniper_vcfs, + min_caller, +): logger = logging.getLogger(read_callers_vcf.__name__) - logger.info( - "----------------------Read Callers VCF------------------------") + logger.info("----------------------Read Callers VCF------------------------") mutect2_info = {} if mutect2_vcfs: @@ -214,11 +219,21 @@ def read_callers_vcf(reference, for ith_alt, alt in enumerate(alts.split(",")): if ref != alt: if len(ref) == 1 or len(alt) == 1: - mutect_classification, nlod, tlod, tandem, ecnt = get_mutect2_info( - filters, info, ith_alt) + ( + mutect_classification, + nlod, + tlod, + tandem, + ecnt, + ) = get_mutect2_info(filters, info, ith_alt) var_id = "-".join([chrom, pos, ref, alt]) mutect2_info[var_id] = [ - mutect_classification, nlod, tlod, tandem, ecnt] + mutect_classification, + nlod, + tlod, + tandem, + ecnt, + ] i_f.close() strelka2_info = {} if strelka2_vcfs: @@ -228,12 +243,17 @@ def read_callers_vcf(reference, x = line.strip().split() chrom, pos, _, ref, alts, _, filters, info = x[0:8] strelka_classification, somatic_evs, qss, tqss = get_strelka2_info( - filters, info) + filters, info + ) for alt in alts.split(","): if ref != alt: var_id = "-".join([chrom, pos, ref, alt]) strelka2_info[var_id] = [ - strelka_classification, somatic_evs, qss, tqss] + strelka_classification, + somatic_evs, + qss, + tqss, + ] i_f.close() vardict_info = {} if vardict_vcfs: @@ -246,24 +266,44 @@ def read_callers_vcf(reference, # In the REF/ALT field, non-GCTA characters should be # changed to N to fit the VCF standard: - ref = re.sub(r'[^GCTA]', 'N', ref, flags=re.I) - alts = re.sub(r'[^GCTA]', 'N', alts, flags=re.I) - - vardict_classification, msi, msilen, shift3, score_vardict = get_vardict_info( - filters, info, fields, samples) + ref = re.sub(r"[^GCTA]", "N", ref, flags=re.I) + alts = re.sub(r"[^GCTA]", "N", alts, flags=re.I) + + ( + vardict_classification, + msi, + msilen, + shift3, + score_vardict, + ) = get_vardict_info(filters, info, fields, samples) for alt in alts.split(","): if ref != alt: - if 'TYPE=SNV' in info or 'TYPE=Deletion' in info or 'TYPE=Insertion' in info: + if ( + "TYPE=SNV" in info + or "TYPE=Deletion" in info + or "TYPE=Insertion" in info + ): var_id = "-".join([chrom, pos, ref, alt]) vardict_info[var_id] = [ - vardict_classification, msi, msilen, shift3, score_vardict] - elif 'TYPE=Complex' in info and (len(ref) == len(alt)): + vardict_classification, + msi, 
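# Aside, a standalone sketch (not the function itself): the caller readers in
# read_callers_vcf.py key everything on a "chrom-pos-ref-alt" string and pull values out of
# the INFO column with the same regex idiom as get_info_value(). The INFO string below is invented.
import re

def info_value(info_field, key, ith_alt=None):
    # KEY=value pairs first; bare flags (e.g. SOMATIC) as a fallback.
    m = re.search(r"\b{}=([^;\s]+)([;\W]|$)".format(key), info_field)
    if m:
        value = m.groups()[0]
        return value.split(",")[ith_alt] if ith_alt is not None else value
    return key in info_field.split(";")

info = "SOMATIC;NLOD=12.3;TLOD=7.8,9.1;ECNT=2"
print(info_value(info, "SOMATIC"))            # True (flag without '=')
print(info_value(info, "TLOD", 1))            # '9.1', the value for the second ALT
print("-".join(["chr1", "12345", "A", "T"]))  # var_id used to merge evidence across callers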
+ msilen, + shift3, + score_vardict, + ] + elif "TYPE=Complex" in info and (len(ref) == len(alt)): for i, (ref_i, alt_i) in enumerate(zip(ref, alt)): if ref_i != alt_i: - var_id = "-".join([chrom, - str(int(pos) + i), ref_i, alt_i]) + var_id = "-".join( + [chrom, str(int(pos) + i), ref_i, alt_i] + ) vardict_info[var_id] = [ - vardict_classification, msi, msilen, shift3, score_vardict] + vardict_classification, + msi, + msilen, + shift3, + score_vardict, + ] i_f.close() varscan2_info = {} if varscan2_vcfs: @@ -276,20 +316,20 @@ def read_callers_vcf(reference, # Replace the wrong "G/A" with the correct "G,A" in ALT # column: - alts = alts.replace('/', ',') + alts = alts.replace("/", ",") # multiple sequences in the REF, as is the case in # VarScan2's indel output: - ref = re.sub(r'[^\w].*$', '', ref) + ref = re.sub(r"[^\w].*$", "", ref) # Get rid of non-compliant characters in the ALT column: - alts = re.sub(r'[^\w,.]', '', alts) + alts = re.sub(r"[^\w,.]", "", alts) # Eliminate dupliate entries in ALT: - alts = re.sub(r'(\w+),\1', r'\1', alts) + alts = re.sub(r"(\w+),\1", r"\1", alts) # VarScan2 output a line with REF allele as "M" - if re.search(r'[^GCTAU]', ref, re.I): + if re.search(r"[^GCTAU]", ref, re.I): continue for alt in alts.split(","): @@ -320,31 +360,39 @@ def read_callers_vcf(reference, x = line.strip().split() chrom, pos, _, ref, alts, _, filters, info, fields = x[0:9] samples = x[9:] - ref = re.sub(r'[^GCTA]', 'N', ref, flags=re.I) - somaticsniper_classification, score_somaticsniper = get_somaticsniper_info( - fields, samples, idxT) + ref = re.sub(r"[^GCTA]", "N", ref, flags=re.I) + ( + somaticsniper_classification, + score_somaticsniper, + ) = get_somaticsniper_info(fields, samples, idxT) for alt in alts.split(","): if ref != alt: var_id = "-".join([chrom, pos, ref, alt]) somaticsniper_info[var_id] = [ - somaticsniper_classification, score_somaticsniper] + somaticsniper_classification, + score_somaticsniper, + ] i_f.close() features = {} - for var_id in (set(mutect2_info.keys()) | set(strelka2_info.keys()) | set(vardict_info.keys()) | - set(varscan2_info.keys()) | set(somaticsniper_info.keys()) | set(muse_info.keys())): + for var_id in ( + set(mutect2_info.keys()) + | set(strelka2_info.keys()) + | set(vardict_info.keys()) + | set(varscan2_info.keys()) + | set(somaticsniper_info.keys()) + | set(muse_info.keys()) + ): num_callers = 0 if var_id in mutect2_info: - mutect_classification, nlod, tlod, tandem, ecnt = mutect2_info[ - var_id] + mutect_classification, nlod, tlod, tandem, ecnt = mutect2_info[var_id] num_callers += mutect_classification else: mutect_classification = 0 nlod = tlod = tandem = ecnt = nan if var_id in strelka2_info: - strelka_classification, somatic_evs, qss, tqss = strelka2_info[ - var_id] + strelka_classification, somatic_evs, qss, tqss = strelka2_info[var_id] num_callers += strelka_classification else: strelka_classification = 0 @@ -352,7 +400,8 @@ def read_callers_vcf(reference, if var_id in vardict_info: vardict_classification, msi, msilen, shift3, score_vardict = vardict_info[ - var_id] + var_id + ] num_callers += vardict_classification else: vardict_classification = 0 @@ -372,108 +421,188 @@ def read_callers_vcf(reference, if var_id in somaticsniper_info: somaticsniper_classification, score_somaticsniper = somaticsniper_info[ - var_id] + var_id + ] num_callers += somaticsniper_classification else: somaticsniper_classification = 0 score_somaticsniper = nan if num_callers >= min_caller: - features[var_id] = [mutect_classification, nlod, tlod, tandem, 
ecnt, - strelka_classification, somatic_evs, qss, tqss, - vardict_classification, msi, msilen, shift3, score_vardict, - varscan_classification, - muse_classification, - somaticsniper_classification, score_somaticsniper] + features[var_id] = [ + mutect_classification, + nlod, + tlod, + tandem, + ecnt, + strelka_classification, + somatic_evs, + qss, + tqss, + vardict_classification, + msi, + msilen, + shift3, + score_vardict, + varscan_classification, + muse_classification, + somaticsniper_classification, + score_somaticsniper, + ] chrom_order = get_chromosomes_order(reference) - ordered_vars = sorted(features.keys(), key=lambda x: [ - chrom_order["-".join(x.split("-")[:-3])], int(x.split("-")[1])]) + ordered_vars = sorted( + features.keys(), + key=lambda x: [chrom_order["-".join(x.split("-")[:-3])], int(x.split("-")[1])], + ) n_variants = len(ordered_vars) logger.info("Number of variants: {}".format(n_variants)) - header = ["CHROM", "POS", "ID", "REF", "ALT", "if_MuTect", "if_VarScan2", "if_SomaticSniper", "if_VarDict", "MuSE_Tier", - "if_Strelka", "Strelka_Score", "Strelka_QSS", - "Strelka_TQSS", "Sniper_Score", "VarDict_Score", - "M2_NLOD", "M2_TLOD", "M2_STR", "M2_ECNT", "MSI", "MSILEN", "SHIFT3"] + header = [ + "CHROM", + "POS", + "ID", + "REF", + "ALT", + "if_MuTect", + "if_VarScan2", + "if_SomaticSniper", + "if_VarDict", + "MuSE_Tier", + "if_Strelka", + "Strelka_Score", + "Strelka_QSS", + "Strelka_TQSS", + "Sniper_Score", + "VarDict_Score", + "M2_NLOD", + "M2_TLOD", + "M2_STR", + "M2_ECNT", + "MSI", + "MSILEN", + "SHIFT3", + ] with open(output_tsv, "w") as o_f: o_f.write("\t".join(header) + "\n") for var_id in ordered_vars: - mutect_classification, nlod, tlod, tandem, ecnt, \ - strelka_classification, somatic_evs, qss, tqss, \ - vardict_classification, msi, msilen, shift3, score_vardict, \ - varscan_classification, \ - muse_classification, \ - somaticsniper_classification, score_somaticsniper = features[ - var_id] - - f = [mutect_classification, varscan_classification, somaticsniper_classification, - vardict_classification, muse_classification, strelka_classification, - somatic_evs, qss, tqss, - score_somaticsniper, score_vardict, - nlod, tlod, tandem, ecnt, - msi, msilen, shift3] + ( + mutect_classification, + nlod, + tlod, + tandem, + ecnt, + strelka_classification, + somatic_evs, + qss, + tqss, + vardict_classification, + msi, + msilen, + shift3, + score_vardict, + varscan_classification, + muse_classification, + somaticsniper_classification, + score_somaticsniper, + ) = features[var_id] + + f = [ + mutect_classification, + varscan_classification, + somaticsniper_classification, + vardict_classification, + muse_classification, + strelka_classification, + somatic_evs, + qss, + tqss, + score_somaticsniper, + score_vardict, + nlod, + tlod, + tandem, + ecnt, + msi, + msilen, + shift3, + ] chrom = "-".join(var_id.split("-")[:-3]) pos, ref, alt = var_id.split("-")[-3:] ref, alt = ref.upper(), alt.upper() o_f.write( - "\t".join([chrom, pos, ".", ref, alt] + list(map(lambda x: str(x).replace("nan", "0"), f))) + "\n") + "\t".join( + [chrom, pos, ".", ref, alt] + + list(map(lambda x: str(x).replace("nan", "0"), f)) + ) + + "\n" + ) logger.info("Done Reading Callers' Features.") return output_tsv -if __name__ == '__main__': - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +if __name__ == "__main__": + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) parser = 
argparse.ArgumentParser( - description='extract extra features for standalone mode') - parser.add_argument('--reference', type=str, help='reference fasta filename', - required=True) - parser.add_argument('--output_tsv', type=str, help='output features tsv', - required=True) - parser.add_argument('--mutect2_vcfs', type=str, nargs="*", - help='MuTect2 VCFs', - default=None) - parser.add_argument('--strelka2_vcfs', type=str, nargs="*", - help='Strelka2 VCFs', - default=None) - parser.add_argument('--varscan2_vcfs', type=str, nargs="*", - help='VarScan2 VCFs', - default=None) - parser.add_argument('--muse_vcfs', type=str, nargs="*", - help='MuSE VCFs', - default=None) - parser.add_argument('--vardict_vcfs', type=str, nargs="*", - help='VarDict VCFs', - default=None) - parser.add_argument('--somaticsniper_vcfs', type=str, nargs="*", - help='SomaticSniper VCFs', - default=None) - parser.add_argument('--min_caller', type=float, - help='Number of minimum callers support a call', - default=0.5) + description="extract extra features for standalone mode" + ) + parser.add_argument( + "--reference", type=str, help="reference fasta filename", required=True + ) + parser.add_argument( + "--output_tsv", type=str, help="output features tsv", required=True + ) + parser.add_argument( + "--mutect2_vcfs", type=str, nargs="*", help="MuTect2 VCFs", default=None + ) + parser.add_argument( + "--strelka2_vcfs", type=str, nargs="*", help="Strelka2 VCFs", default=None + ) + parser.add_argument( + "--varscan2_vcfs", type=str, nargs="*", help="VarScan2 VCFs", default=None + ) + parser.add_argument( + "--muse_vcfs", type=str, nargs="*", help="MuSE VCFs", default=None + ) + parser.add_argument( + "--vardict_vcfs", type=str, nargs="*", help="VarDict VCFs", default=None + ) + parser.add_argument( + "--somaticsniper_vcfs", + type=str, + nargs="*", + help="SomaticSniper VCFs", + default=None, + ) + parser.add_argument( + "--min_caller", + type=float, + help="Number of minimum callers support a call", + default=0.5, + ) args = parser.parse_args() logger.info(args) try: - output = read_callers_vcf(args.reference, - args.output_tsv, - args.mutect2_vcfs, - args.strelka2_vcfs, - args.varscan2_vcfs, - args.muse_vcfs, - args.vardict_vcfs, - args.somaticsniper_vcfs, - args.min_caller, - ) + output = read_callers_vcf( + args.reference, + args.output_tsv, + args.mutect2_vcfs, + args.strelka2_vcfs, + args.varscan2_vcfs, + args.muse_vcfs, + args.vardict_vcfs, + args.somaticsniper_vcfs, + args.min_caller, + ) if output is None: raise Exception("read_callers_vcf failed!") except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "read_callers_vcf.py failure on arguments: {}".format(args)) + logger.error("read_callers_vcf.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/read_info_extractor.py b/neusomatic/python/read_info_extractor.py index b9ae3e7..8260216 100644 --- a/neusomatic/python/read_info_extractor.py +++ b/neusomatic/python/read_info_extractor.py @@ -4,7 +4,7 @@ import logging import numpy as np -FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) @@ -18,16 +18,18 @@ cigar_seq_match = 7 cigar_seq_mismatch = 8 -nan = float('nan') -inf = float('inf') +nan = float("nan") +inf = float("inf") # Define functions: ### PYSAM ### -def position_of_aligned_read(aligned_pairs, read_pos_for_ref_pos, 
target_position, win_size=3): - ''' +def position_of_aligned_read( + aligned_pairs, read_pos_for_ref_pos, target_position, win_size=3 +): + """ Return the base call of the target position, and if it's a start of insertion/deletion. This target position follows pysam convension, i.e., 0-based. In VCF files, deletions/insertions occur AFTER the position. @@ -38,7 +40,7 @@ def position_of_aligned_read(aligned_pairs, read_pos_for_ref_pos, target_positio 2: Deletion after the target position 3: Insertion after the target position 0: The target position does not match to reference, and may be discarded for "reference/alternate" read count purposes, but can be kept for "inconsistent read" metrics. - ''' + """ flanking_deletion, flanking_insertion = nan, nan # get_read_pos_for_ref_pos(read_i, target_position) @@ -55,17 +57,23 @@ def position_of_aligned_read(aligned_pairs, read_pos_for_ref_pos, target_positio indel_length = 0 # If the next alignment is the next sequenced base, then the # target is either a reference read of a SNP/SNV: - if aligned_pairs[idx_aligned_pair + 1][0] == seq_i + 1 and aligned_pairs[idx_aligned_pair + 1][1] == target_position + 1: + if ( + aligned_pairs[idx_aligned_pair + 1][0] == seq_i + 1 + and aligned_pairs[idx_aligned_pair + 1][1] == target_position + 1 + ): code = 1 # Reference read for mismatch # If the next reference position has no read position to it, it # is DELETED in this read: - elif aligned_pairs[idx_aligned_pair + 1][0] == None and aligned_pairs[idx_aligned_pair + 1][1] == target_position + 1: + elif ( + aligned_pairs[idx_aligned_pair + 1][0] == None + and aligned_pairs[idx_aligned_pair + 1][1] == target_position + 1 + ): code = 2 # Deletion - for align_j in aligned_pairs[idx_aligned_pair + 1::]: + for align_j in aligned_pairs[idx_aligned_pair + 1 : :]: if align_j[0] == None: indel_length -= 1 else: @@ -76,11 +84,14 @@ def position_of_aligned_read(aligned_pairs, read_pos_for_ref_pos, target_positio # the inserted sequence is "too long" to align on a single # read. In this case, the inserted length derived here is but a # lower limit of the real inserted length. - elif aligned_pairs[idx_aligned_pair + 1][0] == seq_i + 1 and aligned_pairs[idx_aligned_pair + 1][1] == None: + elif ( + aligned_pairs[idx_aligned_pair + 1][0] == seq_i + 1 + and aligned_pairs[idx_aligned_pair + 1][1] == None + ): code = 3 # Insertion or soft-clipping - for align_j in aligned_pairs[idx_aligned_pair + 1::]: + for align_j in aligned_pairs[idx_aligned_pair + 1 : :]: if align_j[1] == None: indel_length += 1 else: @@ -88,8 +99,10 @@ def position_of_aligned_read(aligned_pairs, read_pos_for_ref_pos, target_positio # If "i" is the final alignment, cannt exam for indel: else: - code = 1 # Assuming no indel - indel_length = nan # Would be zero if certain no indel, but uncertain here + code = 1 # Assuming no indel + indel_length = ( + nan # Would be zero if certain no indel, but uncertain here + ) # If the target position is deleted from the sequencing read (i.e., the # deletion in this read occurs before the target position): @@ -105,21 +118,23 @@ def position_of_aligned_read(aligned_pairs, read_pos_for_ref_pos, target_positio left_side_start = idx_aligned_pair - 1 right_side_start = idx_aligned_pair + abs(indel_length) + 1 - #(i, None) = Insertion (or Soft-clips), i.e., means the i_th base in the query is not aligned to a reference - #(None, coordinate) = Deletion, i.e., there is no base in it that aligns to this coordinate. 
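# Aside, toy values only: the docstring and comments above rely on pysam's aligned_pairs
# convention, where each entry is (query_pos, ref_pos) and a None on either side marks an
# inserted/soft-clipped base or a deleted reference base.
aligned_pairs = [
    (0, 100), (1, 101), (2, 102),   # aligned bases
    (3, None), (4, None),           # (i, None): inserted or soft-clipped bases
    (5, 103), (6, 104),
    (None, 105),                    # (None, ref_pos): this reference base is deleted in the read
    (7, 106),
]

idx = aligned_pairs.index((6, 104))          # last aligned base before the deletion
del_len = 0
for q, r in aligned_pairs[idx + 1:]:
    if q is None:
        del_len += 1
    else:
        break
print(del_len)  # 1; position_of_aligned_read() would report code 2 (deletion) with
                # indel_length = -1 for target position 104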
+ # (i, None) = Insertion (or Soft-clips), i.e., means the i_th base in the query is not aligned to a reference + # (None, coordinate) = Deletion, i.e., there is no base in it that aligns to this coordinate. # If those two scenarios occur right after an aligned base, that # base position is counted as an indel. - for step_right_i in range(min(win_size, len(aligned_pairs) - right_side_start - 1)): + for step_right_i in range( + min(win_size, len(aligned_pairs) - right_side_start - 1) + ): j = right_side_start + step_right_i - if (aligned_pairs[j + 1][1] == None or aligned_pairs[j + 1][0] == None): + if aligned_pairs[j + 1][1] == None or aligned_pairs[j + 1][0] == None: right_indel_flanks = step_right_i + 1 break for step_left_i in range(min(win_size, left_side_start)): j = left_side_start - step_left_i - if (aligned_pairs[j][1] == None or aligned_pairs[j][0] == None): + if aligned_pairs[j][1] == None or aligned_pairs[j][0] == None: left_indel_flanks = step_left_i + 1 break flanking_indel = min(left_indel_flanks, right_indel_flanks) @@ -136,10 +151,10 @@ def position_of_aligned_read(aligned_pairs, read_pos_for_ref_pos, target_positio # Dedup test for BAM file def dedup_test(read_i, remove_dup_or_not=True): - ''' + """ Return False (i.e., remove the read) if the read is a duplicate and if the user specify that duplicates should be removed. Else return True (i.e, keep the read) - ''' + """ if read_i.is_duplicate and remove_dup_or_not: return False else: @@ -162,8 +177,8 @@ def mean(stuff): # Extract Indel DP4 info from pileup files: def pileup_indel_DP4(pileup_object, indel_pattern): if pileup_object.reads: - ref_for = pileup_object.reads.count('.') - ref_rev = pileup_object.reads.count(',') + ref_for = pileup_object.reads.count(".") + ref_rev = pileup_object.reads.count(",") alt_for = pileup_object.reads.count(indel_pattern.upper()) alt_rev = pileup_object.reads.count(indel_pattern.lower()) @@ -184,24 +199,36 @@ def pileup_DP4(pileup_object, ref_base, variant_call): # SNV if len(variant_call) == len(ref_base): - ref_for, ref_rev, alt_for, alt_rev = base_calls[0], base_calls[1], base_calls[ - 2].count(variant_call.upper()), base_calls[3].count(variant_call.lower()) + ref_for, ref_rev, alt_for, alt_rev = ( + base_calls[0], + base_calls[1], + base_calls[2].count(variant_call.upper()), + base_calls[3].count(variant_call.lower()), + ) # Insertion: elif len(variant_call) > len(ref_base): - inserted_sequence = variant_call[len(ref_base)::] + inserted_sequence = variant_call[len(ref_base) : :] - ref_for, ref_rev, alt_for, alt_rev = base_calls[0], base_calls[1], base_calls[ - 6].count(inserted_sequence.upper()), base_calls[7].count(inserted_sequence.lower()) + ref_for, ref_rev, alt_for, alt_rev = ( + base_calls[0], + base_calls[1], + base_calls[6].count(inserted_sequence.upper()), + base_calls[7].count(inserted_sequence.lower()), + ) # Deletion: elif len(variant_call) < len(ref_base): - deleted_sequence = ref_base[len(variant_call)::] + deleted_sequence = ref_base[len(variant_call) : :] - ref_for, ref_rev, alt_for, alt_rev = base_calls[0], base_calls[1], base_calls[ - 4].count(deleted_sequence.upper()), base_calls[5].count(deleted_sequence.lower()) + ref_for, ref_rev, alt_for, alt_rev = ( + base_calls[0], + base_calls[1], + base_calls[4].count(deleted_sequence.upper()), + base_calls[5].count(deleted_sequence.lower()), + ) else: ref_for = ref_rev = alt_for = alt_rev = 0 @@ -209,17 +236,17 @@ def pileup_DP4(pileup_object, ref_base, variant_call): return ref_for, ref_rev, alt_for, alt_rev -def rescale(x, 
original='fraction', rescale_to=None, max_phred=1001): +def rescale(x, original="fraction", rescale_to=None, max_phred=1001): if (rescale_to == None) or (original.lower() == rescale_to.lower()): - y = x if isinstance(x, int) else '%.2f' % x + y = x if isinstance(x, int) else "%.2f" % x - elif original.lower() == 'fraction' and rescale_to == 'phred': + elif original.lower() == "fraction" and rescale_to == "phred": y = genome.p2phred(x, max_phred=max_phred) - y = '%.2f' % y + y = "%.2f" % y - elif original.lower() == 'phred' and rescale_to == 'fraction': + elif original.lower() == "phred" and rescale_to == "fraction": y = genome.phred2p(x) - y = '%.2f' % y + y = "%.2f" % y return y diff --git a/neusomatic/python/resolve_scores.py b/neusomatic/python/resolve_scores.py index f54c57c..152720e 100755 --- a/neusomatic/python/resolve_scores.py +++ b/neusomatic/python/resolve_scores.py @@ -1,8 +1,8 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # resolve_score.py # resolve prediction scores for realigned variants -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import argparse import logging @@ -19,22 +19,21 @@ def resolve_scores(input_bam, ra_vcf, target_vcf, output_vcf): logger.info("-----Resolve Prediction Scores for Realigned Variants------") - tmp_ = bedtools_window( - ra_vcf, target_vcf, args=" -w 5 -v", run_logger=logger) + tmp_ = bedtools_window(ra_vcf, target_vcf, args=" -w 5 -v", run_logger=logger) final_intervals = read_tsv_file(tmp_) for x in final_intervals: x[5] = str(np.round(-10 * np.log10(0.25), 4)) - tmp_ = bedtools_window( - ra_vcf, target_vcf, args=" -w 5", run_logger=logger) + tmp_ = bedtools_window(ra_vcf, target_vcf, args=" -w 5", run_logger=logger) intervals_dict = {} with open(tmp_) as i_f: for line in skip_empty(i_f): interval = line.strip().split("\t") - id_ = "{}-{}-{}-{}".format(interval[0], - interval[1], interval[3], interval[4]) + id_ = "{}-{}-{}-{}".format( + interval[0], interval[1], interval[3], interval[4] + ) if id_ not in intervals_dict: intervals_dict[id_] = [] intervals_dict[id_].append(interval) @@ -45,31 +44,41 @@ def resolve_scores(input_bam, ra_vcf, target_vcf, output_vcf): interval = intervals[0][:10] interval[5] = score interval[7] = "SCORE={:.4f}".format( - np.round(1 - (10**(-float(score) / 10)), 4)) + np.round(1 - (10 ** (-float(score) / 10)), 4) + ) else: - len_ = (len(intervals[0][4]) - len(intervals[0][3])) + len_ = len(intervals[0][4]) - len(intervals[0][3]) pos_ = int(intervals[0][1]) - len_diff = list(map(lambda x: abs( - (len(x[14]) - len(x[13])) - len_), intervals)) + len_diff = list( + map(lambda x: abs((len(x[14]) - len(x[13])) - len_), intervals) + ) min_len_diff = min(len_diff) - intervals = list(filter(lambda x: abs( - (len(x[14]) - len(x[13])) - len_) == min_len_diff, intervals)) + intervals = list( + filter( + lambda x: abs((len(x[14]) - len(x[13])) - len_) == min_len_diff, + intervals, + ) + ) pos_diff = list(map(lambda x: abs(int(x[11]) - pos_), intervals)) min_pos_diff = min(pos_diff) - intervals = list(filter(lambda x: abs( - int(x[11]) - pos_) == min_pos_diff, intervals)) + intervals = list( + filter(lambda x: abs(int(x[11]) - pos_) == min_pos_diff, intervals) + ) score = "{:.4f}".format( - np.round(max(map(lambda x: float(x[15]), intervals)), 4)) + np.round(max(map(lambda x: float(x[15]), 
intervals)), 4) + ) interval = intervals[0][:10] interval[5] = score interval[7] = "SCORE={:.4f}".format( - np.round(1 - (10**(-float(score) / 10)), 4)) + np.round(1 - (10 ** (-float(score) / 10)), 4) + ) final_intervals.append(interval) chroms_order = get_chromosomes_order(bam=input_bam) - out_variants = sorted(final_intervals, key=lambda x: [ - chroms_order[x[0]], int(x[1])]) + out_variants = sorted( + final_intervals, key=lambda x: [chroms_order[x[0]], int(x[1])] + ) with open(output_vcf, "w") as o_f: o_f.write("{}\n".format(VCF_HEADER)) o_f.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") @@ -77,29 +86,23 @@ def resolve_scores(input_bam, ra_vcf, target_vcf, output_vcf): o_f.write("\t".join(var) + "\n") -if __name__ == '__main__': +if __name__ == "__main__": - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) - parser = argparse.ArgumentParser(description='Resolve scores') - parser.add_argument('--input_bam', type=str, - help='input bam', required=True) - parser.add_argument('--ra_vcf', type=str, - help='realigned vcf', required=True) - parser.add_argument('--target_vcf', type=str, - help='target vcf', required=True) - parser.add_argument('--output_vcf', type=str, - help='output_vcf', required=True) + parser = argparse.ArgumentParser(description="Resolve scores") + parser.add_argument("--input_bam", type=str, help="input bam", required=True) + parser.add_argument("--ra_vcf", type=str, help="realigned vcf", required=True) + parser.add_argument("--target_vcf", type=str, help="target vcf", required=True) + parser.add_argument("--output_vcf", type=str, help="output_vcf", required=True) args = parser.parse_args() try: - resolve_scores(args.input_bam, args.ra_vcf, - args.target_vcf, args.output_vcf) + resolve_scores(args.input_bam, args.ra_vcf, args.target_vcf, args.output_vcf) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "resolve_scores.py failure on arguments: {}".format(args)) + logger.error("resolve_scores.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/resolve_variants.py b/neusomatic/python/resolve_variants.py index 4e7c009..ea60bad 100755 --- a/neusomatic/python/resolve_variants.py +++ b/neusomatic/python/resolve_variants.py @@ -1,9 +1,9 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # resolve_variants.py # Resolve variants (e.g. exact INDEL sequences) for target variants # identified by 'extract_postprocess_targets.py'. 
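# Aside, example numbers only: the QUAL/SCORE handling in resolve_scores.py (and again when
# resolve_variants.py writes its VCF further down) is a plain phred/probability round trip.
import numpy as np

phred = np.round(-10 * np.log10(0.25), 4)       # default QUAL assigned to unmatched records
print(phred)                                    # 6.0206
print(np.round(1 - 10 ** (-phred / 10), 4))     # 0.75, the SCORE written to INFO

print(np.round(-10 * np.log10(1 - 0.999), 4))   # 30.0, phred QUAL for a prediction prob of 0.999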
-#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import multiprocessing import argparse @@ -14,7 +14,14 @@ import pysam import numpy as np -from utils import get_chromosomes_order, read_tsv_file, bedtools_sort, bedtools_merge, get_tmp_file, skip_empty +from utils import ( + get_chromosomes_order, + read_tsv_file, + bedtools_sort, + bedtools_merge, + get_tmp_file, + skip_empty, +) from defaults import VCF_HEADER CIGAR_MATCH = 0 @@ -26,7 +33,7 @@ CIGAR_DIFF = 8 _CIGAR_OPS = "MIDNSHP=X" -_CIGAR_PATTERN = re.compile(r'([0-9]+)([MIDNSHPX=])') +_CIGAR_PATTERN = re.compile(r"([0-9]+)([MIDNSHPX=])") _CIGAR_OP_DICT = {op: index for index, op in enumerate(_CIGAR_OPS)} _CIGAR_REFERENCE_OPS = [CIGAR_MATCH, CIGAR_DEL, CIGAR_EQUAL, CIGAR_DIFF] _CIGAR_READ_ALN_OPS = [CIGAR_MATCH, CIGAR_INS, CIGAR_EQUAL, CIGAR_DIFF] @@ -41,10 +48,8 @@ def extract_del(record): dels = [] pos = record.pos cigartuples = record.cigartuples - first_sc = 1 if cigartuples[0][0] in [ - CIGAR_SOFTCLIP, CIGAR_HARDCLIP] else 0 - last_sc = 1 if cigartuples[-1][0] in [CIGAR_SOFTCLIP, - CIGAR_HARDCLIP] else 0 + first_sc = 1 if cigartuples[0][0] in [CIGAR_SOFTCLIP, CIGAR_HARDCLIP] else 0 + last_sc = 1 if cigartuples[-1][0] in [CIGAR_SOFTCLIP, CIGAR_HARDCLIP] else 0 for i, (C, L) in enumerate(cigartuples): if C in [CIGAR_SOFTCLIP, CIGAR_HARDCLIP, CIGAR_INS]: continue @@ -63,10 +68,8 @@ def extract_ins(record): pos = record.pos seq_pos = 0 cigartuples = record.cigartuples - first_sc = 1 if cigartuples[0][0] in [ - CIGAR_SOFTCLIP, CIGAR_HARDCLIP] else 0 - last_sc = 1 if cigartuples[-1][0] in [CIGAR_SOFTCLIP, - CIGAR_HARDCLIP] else 0 + first_sc = 1 if cigartuples[0][0] in [CIGAR_SOFTCLIP, CIGAR_HARDCLIP] else 0 + last_sc = 1 if cigartuples[-1][0] in [CIGAR_SOFTCLIP, CIGAR_HARDCLIP] else 0 for i, (C, L) in enumerate(cigartuples): if C == CIGAR_SOFTCLIP: seq_pos += L @@ -74,13 +77,26 @@ def extract_ins(record): elif C == CIGAR_HARDCLIP: continue if C == CIGAR_INS: - if not record.seq[seq_pos:seq_pos + L]: - logger.info([str(record).split("\t"), seq_pos, - L, len(record.seq), len(record.seq)]) + if not record.seq[seq_pos : seq_pos + L]: + logger.info( + [ + str(record).split("\t"), + seq_pos, + L, + len(record.seq), + len(record.seq), + ] + ) if i > first_sc and i < len(cigartuples) - 1 - last_sc: L_ = min(L, max_indel) - inss.append([record.reference_name, pos, pos + 1, - record.seq[seq_pos:seq_pos + L_]]) + inss.append( + [ + record.reference_name, + pos, + pos + 1, + record.seq[seq_pos : seq_pos + L_], + ] + ) seq_pos += L else: if C != CIGAR_DEL: @@ -115,7 +131,6 @@ def push_left_var(ref_fasta, chrom, pos, ref, alt): class Variant: - def __init__(self, chrom, pos, ref, alt, gt, score, cnt, vtype): self.chrom = chrom self.pos = int(pos) @@ -130,20 +145,38 @@ def __init__(self, chrom, pos, ref, alt, gt, score, cnt, vtype): def push_left(self, ref_fasta): _, self.pos, self.ref, self.alt = push_left_var( - ref_fasta, self.chrom, self.pos, self.ref, self.alt) + ref_fasta, self.chrom, self.pos, self.ref, self.alt + ) def var_str(self): - return "-".join(map(str, [self.chrom, self.pos, self.ref, self.alt, self.vtype])) + return "-".join( + map(str, [self.chrom, self.pos, self.ref, self.alt, self.vtype]) + ) def var_pos_vt_str(self): return "-".join(map(str, [self.chrom, self.pos, self.vtype])) def var_gt_str(self): - return "-".join(map(str, [self.chrom, self.pos, self.ref, self.alt, self.gt, self.vtype])) + return "-".join( + map(str, 
[self.chrom, self.pos, self.ref, self.alt, self.gt, self.vtype]) + ) def __str__(self): - return "-".join(map(str, [self.chrom, self.pos, self.ref, self.alt, self.gt, - self.score, self.cnt, self.vtype])) + return "-".join( + map( + str, + [ + self.chrom, + self.pos, + self.ref, + self.alt, + self.gt, + self.score, + self.cnt, + self.vtype, + ], + ) + ) def resolve_group(ref_fasta, variants, vars_count): @@ -153,8 +186,7 @@ def resolve_group(ref_fasta, variants, vars_count): for var_str in vars_count: pos, ref, alt, vtype = var_str.split("-")[-4:] pos = int(pos) - v = Variant(chrom, pos, ref, alt, "0/0", - 0, vars_count[var_str], vtype) + v = Variant(chrom, pos, ref, alt, "0/0", 0, vars_count[var_str], vtype) v.push_left(ref_fasta) s = v.var_str() if s not in vars_count_: @@ -181,8 +213,7 @@ def resolve_group(ref_fasta, variants, vars_count): pos = int(pos) if pos not in group_vars: group_vars[pos] = [] - v = Variant(chrom, pos, ref, alt, "0/0", - 0, vars_count[var_str], vtype) + v = Variant(chrom, pos, ref, alt, "0/0", 0, vars_count[var_str], vtype) group_vars[pos].append(v) for pos in group_vars: var_ = {} @@ -193,12 +224,17 @@ def resolve_group(ref_fasta, variants, vars_count): var_[var_id].append(v) group_vars[pos] = [] for var_id in var_: - group_vars[pos].append(sorted(var_[var_id], key=lambda x: x.score, reverse=True - )[0]) + group_vars[pos].append( + sorted(var_[var_id], key=lambda x: x.score, reverse=True)[0] + ) out_variants_ = [] - max_target = [v.cnt for pos in group_vars for v in group_vars[ - pos] if v.score > 0 or v.len >= 3] + max_target = [ + v.cnt + for pos in group_vars + for v in group_vars[pos] + if v.score > 0 or v.len >= 3 + ] if len(max_target) == 0: # logger.info( # "No non-zero COUNT with non-zero SCORE: {}".format(list(str(x) for x in group_vars[pos]))) @@ -209,13 +245,11 @@ def resolve_group(ref_fasta, variants, vars_count): max_count = max(max_target) for pos in group_vars.keys(): - if max(map(lambda x: x.cnt, group_vars[pos]) - ) < 0.2 * max_count: + if max(map(lambda x: x.cnt, group_vars[pos])) < 0.2 * max_count: continue mx = max(map(lambda x: x.cnt, group_vars[pos])) gts = [x.gt for x in group_vars[pos]] - gts = set([x.gt for x in group_vars[pos] - if x.gt != "0/0" or x.len >= 3]) + gts = set([x.gt for x in group_vars[pos] if x.gt != "0/0" or x.len >= 3]) if len(gts) == 0: continue if len(gts) > 1: @@ -230,41 +264,54 @@ def resolve_group(ref_fasta, variants, vars_count): if nz == 0: continue priority = {"0/1": 2, "0/0": 1} - sorted_gts = sorted(gts_count.keys(), key=lambda x: [ - gts_count[x], gts_score[x], - priority[x]], reverse=True) + sorted_gts = sorted( + gts_count.keys(), + key=lambda x: [gts_count[x], gts_score[x], priority[x]], + reverse=True, + ) gt = sorted_gts[0] else: gt = list(gts)[0] - all_vars = sorted(group_vars[pos], key=lambda x: [ - x.cnt, x.score, x.gt != "0/0"], reverse=True) - vtypes = set([x.vtype for x in group_vars[pos] - if (x.gt != "0/0" or x.len >= 3) and x.cnt >= 0.4 * mx]) + all_vars = sorted( + group_vars[pos], key=lambda x: [x.cnt, x.score, x.gt != "0/0"], reverse=True + ) + vtypes = set( + [ + x.vtype + for x in group_vars[pos] + if (x.gt != "0/0" or x.len >= 3) and x.cnt >= 0.4 * mx + ] + ) if not vtypes: - vtypes = set([x.vtype for x in group_vars[pos] - if (x.gt != "0/0" or x.len >= 3)]) - all_vars = list( - filter(lambda x: x.vtype in vtypes, all_vars)) + vtypes = set( + [x.vtype for x in group_vars[pos] if (x.gt != "0/0" or x.len >= 3)] + ) + all_vars = list(filter(lambda x: x.vtype in vtypes, all_vars)) if not 
all_vars: + logger.info("No vars: {}".format(list(str(x) for x in group_vars[pos]))) logger.info( - "No vars: {}".format(list(str(x) for x in group_vars[pos]))) - logger.info( - "No vars: {}".format([[list(str(x) for x in group_vars[pos_])]for pos_ in group_vars])) + "No vars: {}".format( + [[list(str(x) for x in group_vars[pos_])] for pos_ in group_vars] + ) + ) raise Exception score = max([v.score for v in all_vars]) if gt == "0/0": - nz_vars = [x for x in all_vars if x.gt != - "0/0" and x.vtype == all_vars[0].vtype] + nz_vars = [ + x for x in all_vars if x.gt != "0/0" and x.vtype == all_vars[0].vtype + ] if nz_vars: - nz_vars = sorted(nz_vars, key=lambda x: [ - x.score], reverse=True)[0] + nz_vars = sorted(nz_vars, key=lambda x: [x.score], reverse=True)[0] gt = nz_vars.gt v = all_vars[0] - out_variants_.append( - [v.chrom, v.pos, v.ref, v.alt, gt, score, v.cnt]) + out_variants_.append([v.chrom, v.pos, v.ref, v.alt, gt, score, v.cnt]) - if len(out_variants_) == 1 and out_variants_[0][4] == "0/0" and abs(len(out_variants_[0][2]) - len(out_variants_[0][3])) >= 3: + if ( + len(out_variants_) == 1 + and out_variants_[0][4] == "0/0" + and abs(len(out_variants_[0][2]) - len(out_variants_[0][3])) >= 3 + ): chrom_, pos_, ref_, alt_, gt_, score_, cnt_ = out_variants_[0] vtype = find_vtype(ref_, alt_) resolve_candids = [] @@ -273,10 +320,20 @@ def resolve_group(ref_fasta, variants, vars_count): if y.vtype == vtype and y.gt != "0/0": resolve_candids.append(y) if resolve_candids: - resolve_candids = sorted(resolve_candids, key=lambda x: [ - x.score], reverse=True)[0] - out_variants_ = [[chrom_, pos_, ref_, alt_, - resolve_candids.gt, resolve_candids.score, cnt_]] + resolve_candids = sorted( + resolve_candids, key=lambda x: [x.score], reverse=True + )[0] + out_variants_ = [ + [ + chrom_, + pos_, + ref_, + alt_, + resolve_candids.gt, + resolve_candids.score, + cnt_, + ] + ] if len(out_variants_) > 1 and "0/0" in [x[4] for x in out_variants_]: nz_vars = [x for x in out_variants_ if x[4] != "0/0"] @@ -296,10 +353,11 @@ def resolve_group(ref_fasta, variants, vars_count): for chrom_, pos_, ref_, alt_, gt_, score_, cnt_ in out_variants_: if gt_ not in vars_gt: vars_gt[gt_] = [] - vars_gt[gt_].append( - Variant(chrom_, pos_, ref_, alt_, gt_, score_, cnt_, "")) - vars_gt = {gt_: sorted(vars_gt[gt_], key=lambda x: [ - x.cnt, x.score], reverse=True) for gt_ in vars_gt} + vars_gt[gt_].append(Variant(chrom_, pos_, ref_, alt_, gt_, score_, cnt_, "")) + vars_gt = { + gt_: sorted(vars_gt[gt_], key=lambda x: [x.cnt, x.score], reverse=True) + for gt_ in vars_gt + } out_variants_ = [] for gt_ in vars_gt: v0 = vars_gt[gt_][0] @@ -307,14 +365,15 @@ def resolve_group(ref_fasta, variants, vars_count): for v in vars_gt[gt_][1:]: keep = True for g_v in good_vs: - if min(v.pos + len(v.ref), g_v.pos + len(g_v.ref)) > max(v.pos, g_v.pos): + if min(v.pos + len(v.ref), g_v.pos + len(g_v.ref)) > max( + v.pos, g_v.pos + ): keep = False break if keep: good_vs.append(v) for v in good_vs: - out_variants_.append( - [v.chrom, v.pos, v.ref, v.alt, v.gt, v.score]) + out_variants_.append([v.chrom, v.pos, v.ref, v.alt, v.gt, v.score]) out_variants_ = [x for x in out_variants_ if x[4] != "0/0"] return out_variants_ @@ -322,7 +381,10 @@ def resolve_group(ref_fasta, variants, vars_count): def find_resolved_variants(input_record): chrom, start, end, variants, input_bam, filter_duplicate, reference = input_record thread_logger = logging.getLogger( - "{} ({})".format(find_resolved_variants.__name__, multiprocessing.current_process().name)) + "{} 
({})".format( + find_resolved_variants.__name__, multiprocessing.current_process().name + ) + ) try: ref_fasta = pysam.FastaFile(reference) variants_ = [] @@ -362,28 +424,31 @@ def find_resolved_variants(input_record): if record.cigarstring and "I" in record.cigarstring: inss_.extend(extract_ins(record)) aligned_pairs = np.array( - record.get_aligned_pairs(matches_only=True)) + record.get_aligned_pairs(matches_only=True) + ) if len(aligned_pairs) == 0: continue - near_pos = np.where((start <= aligned_pairs[:, 1]) & ( - aligned_pairs[:, 1] <= end))[0] + near_pos = np.where( + (start <= aligned_pairs[:, 1]) & (aligned_pairs[:, 1] <= end) + )[0] if len(near_pos) != 0: for pos_i in near_pos: seq_pos, ref_pos = aligned_pairs[pos_i, :] if seq_pos is not None: ref_snp = ref_fasta.fetch( - chrom, ref_pos, ref_pos + 1).upper() + chrom, ref_pos, ref_pos + 1 + ).upper() alt_snp = record.seq[seq_pos] if alt_snp != ref_snp: - snps_.append( - [chrom, ref_pos + 1, ref_snp, alt_snp]) + snps_.append([chrom, ref_pos + 1, ref_snp, alt_snp]) dels.extend([x + [1.0 / (cov)] for x in dels_]) inss.extend([x + [1.0 / (cov)] for x in inss_]) snps.extend([x + [1.0 / (cov)] for x in snps_]) - dels = list(filter(lambda x: ( - start <= x[1] <= end) or start <= x[2] <= end, dels)) + dels = list( + filter(lambda x: (start <= x[1] <= end) or start <= x[2] <= end, dels) + ) if dels: del_strs = [] cnt_ = {} @@ -398,11 +463,24 @@ def find_resolved_variants(input_record): uniq_dels = list(set(del_strs)) for del_ in uniq_dels: st, en = map(int, del_.split("---")[1:3]) - del_str = "-".join(list(map(str, [chrom, int(st), ref_fasta.fetch(chrom, st - 1, en).upper(), - ref_fasta.fetch(chrom, st - 1, st).upper(), "DEL"]))) + del_str = "-".join( + list( + map( + str, + [ + chrom, + int(st), + ref_fasta.fetch(chrom, st - 1, en).upper(), + ref_fasta.fetch(chrom, st - 1, st).upper(), + "DEL", + ], + ) + ) + ) vars_count[del_str] = np.round(cnt_[del_], 4) - inss = list(filter(lambda x: ( - start <= x[1] <= end) or start <= x[2] <= end, inss)) + inss = list( + filter(lambda x: (start <= x[1] <= end) or start <= x[2] <= end, inss) + ) if inss: cnt_ = {} ins_strs = [] @@ -417,8 +495,20 @@ def find_resolved_variants(input_record): for ins_ in uniq_inss: st, en, bases = ins_.split("---")[1:4] st, en = map(int, [st, en]) - ins_str = "-".join(list(map(str, [chrom, int(st), ref_fasta.fetch(chrom, st - 1, st).upper(), - ref_fasta.fetch(chrom, st - 1, st).upper() + bases, "INS"]))) + ins_str = "-".join( + list( + map( + str, + [ + chrom, + int(st), + ref_fasta.fetch(chrom, st - 1, st).upper(), + ref_fasta.fetch(chrom, st - 1, st).upper() + bases, + "INS", + ], + ) + ) + ) vars_count[ins_str] = np.round(cnt_[ins_], 4) if snps: @@ -434,8 +524,7 @@ def find_resolved_variants(input_record): uniq_snps = list(set(snp_strs)) for snp_ in uniq_snps: st, ref_, alt_ = snp_.split("---")[1:4] - snp_str = "-".join(list(map(str, [chrom, st, ref_, - alt_, "SNP"]))) + snp_str = "-".join(list(map(str, [chrom, st, ref_, alt_, "SNP"]))) vars_count[snp_str] = np.round(cnt_[snp_], 4) out_variants_ = resolve_group(ref_fasta, variants, vars_count) @@ -448,8 +537,15 @@ def find_resolved_variants(input_record): raise Exception -def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file, - target_bed_file, filter_duplicate, num_threads): +def resolve_variants( + input_bam, + resolved_vcf, + reference, + target_vcf_file, + target_bed_file, + filter_duplicate, + num_threads, +): logger = logging.getLogger(resolve_variants.__name__) logger.info("-------Resolve 
variants (e.g. exact INDEL sequences)-------") @@ -469,8 +565,17 @@ def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file, tb = line.strip().split("\t") chrom, start, end, id_ = tb[0:4] id_ = int(id_) - map_args.append([chrom, start, end, variants[id_], - input_bam, filter_duplicate, reference]) + map_args.append( + [ + chrom, + start, + end, + variants[id_], + input_bam, + filter_duplicate, + reference, + ] + ) if num_threads > 1: try: @@ -481,8 +586,11 @@ def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file, pool = multiprocessing.Pool(num_threads) batch_i_s = i batch_i_e = min(i + n_per_bacth, len(map_args)) - out_variants_list.extend(pool.map_async( - find_resolved_variants, map_args[batch_i_s:batch_i_e]).get()) + out_variants_list.extend( + pool.map_async( + find_resolved_variants, map_args[batch_i_s:batch_i_e] + ).get() + ) i = batch_i_e pool.close() except Exception as inst: @@ -496,8 +604,7 @@ def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file, out_variants = [x for xs in out_variants_list for x in xs] chroms_order = get_chromosomes_order(bam=input_bam) - out_variants = sorted(out_variants, key=lambda x: [ - chroms_order[x[0]], int(x[1])]) + out_variants = sorted(out_variants, key=lambda x: [chroms_order[x[0]], int(x[1])]) with open(resolved_vcf, "w") as o_f: o_f.write("{}\n".format(VCF_HEADER)) o_f.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n") @@ -509,46 +616,66 @@ def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file, continue done_id.add(id_) phred_score = float(phred_score) - prob = np.round(1 - (10**(-phred_score / 10)), 4) - o_f.write("\t".join([chrom, str(pos), ".", ref, - alt, "{:.4f}".format( - np.round(phred_score, 4)), - ".", "SCORE={:.4f}".format(prob), "GT", gt]) + "\n") - - -if __name__ == '__main__': - - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' + prob = np.round(1 - (10 ** (-phred_score / 10)), 4) + o_f.write( + "\t".join( + [ + chrom, + str(pos), + ".", + ref, + alt, + "{:.4f}".format(np.round(phred_score, 4)), + ".", + "SCORE={:.4f}".format(prob), + "GT", + gt, + ] + ) + + "\n" + ) + + +if __name__ == "__main__": + + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) parser = argparse.ArgumentParser( - description='Resolve ambigues variants for high quality reads') - parser.add_argument('--input_bam', type=str, - help='input bam', required=True) - parser.add_argument('--resolved_vcf', type=str, - help='resolved_vcf', required=True) - parser.add_argument('--target_vcf', type=str, - help='resolve target vcf', required=True) - parser.add_argument('--target_bed', type=str, - help='resolve target bed', required=True) - parser.add_argument('--reference', type=str, - help='reference fasta filename', required=True) - parser.add_argument('--filter_duplicate', - help='filter duplicate reads in analysis', - action="store_true") - parser.add_argument('--num_threads', type=int, - help='number of threads', default=1) + description="Resolve ambigues variants for high quality reads" + ) + parser.add_argument("--input_bam", type=str, help="input bam", required=True) + parser.add_argument("--resolved_vcf", type=str, help="resolved_vcf", required=True) + parser.add_argument( + "--target_vcf", type=str, help="resolve target vcf", required=True + ) + parser.add_argument( + "--target_bed", type=str, help="resolve target bed", required=True + ) + 
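# Aside, hypothetical usage: a programmatic call that mirrors the CLI wiring in this hunk.
# It assumes neusomatic/python is importable; every path below is a placeholder.
from resolve_variants import resolve_variants

resolve_variants(
    input_bam="tumor.realigned.bam",
    resolved_vcf="work/resolved.vcf",
    reference="GRCh38.fa",
    target_vcf_file="work/targets.vcf",
    target_bed_file="work/targets.bed",
    filter_duplicate=True,
    num_threads=4,
)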
parser.add_argument( + "--reference", type=str, help="reference fasta filename", required=True + ) + parser.add_argument( + "--filter_duplicate", + help="filter duplicate reads in analysis", + action="store_true", + ) + parser.add_argument("--num_threads", type=int, help="number of threads", default=1) args = parser.parse_args() try: - resolve_variants(args.input_bam, args.resolved_vcf, - args.reference, args.target_vcf, - args.target_bed, args.filter_duplicate, - args.num_threads) + resolve_variants( + args.input_bam, + args.resolved_vcf, + args.reference, + args.target_vcf, + args.target_bed, + args.filter_duplicate, + args.num_threads, + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "resolve_variants.py failure on arguments: {}".format(args)) + logger.error("resolve_variants.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/scan_alignments.py b/neusomatic/python/scan_alignments.py index 3e7b56a..bfe0cd7 100755 --- a/neusomatic/python/scan_alignments.py +++ b/neusomatic/python/scan_alignments.py @@ -1,12 +1,12 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # scan_alignments.py # Scan the alignment .bam file, extract A/C/G/T/- counts on augmented alignment, # as well as different alignment feature matrices such as base quality, mapping # quality, strandness, clipping, alignment score, ... # It also outputs a raw .vcf files of potential candidates. This .vcf file should be processed # by 'filter_candidates.py' before it can be used. -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import os import multiprocessing @@ -19,17 +19,39 @@ import pysam import numpy as np -from utils import concatenate_files, run_shell_command, bedtools_sort, bedtools_merge, get_tmp_file, skip_empty +from utils import ( + concatenate_files, + run_shell_command, + bedtools_sort, + bedtools_merge, + get_tmp_file, + skip_empty, +) from split_bed import split_region def run_scan_alignments(record): - work, reference, merge_d_for_scan, scan_alignments_binary, split_region_file, \ - input_bam, window_size, \ - snp_min_ao, \ - snp_min_af, ins_min_af, del_min_af, \ - min_mapq, snp_min_bq, max_dp, min_dp, \ - report_all_alleles, report_count_for_all_positions, filter_duplicate, calc_qual = record + ( + work, + reference, + merge_d_for_scan, + scan_alignments_binary, + split_region_file, + input_bam, + window_size, + snp_min_ao, + snp_min_af, + ins_min_af, + del_min_af, + min_mapq, + snp_min_bq, + max_dp, + min_dp, + report_all_alleles, + report_count_for_all_positions, + filter_duplicate, + calc_qual, + ) = record if filter_duplicate: filter_duplicate_str = "--filter_duplicate" @@ -44,7 +66,10 @@ def run_scan_alignments(record): else: report_count_for_all_positions_str = "" thread_logger = logging.getLogger( - "{} ({})".format(run_scan_alignments.__name__, multiprocessing.current_process().name)) + "{} ({})".format( + run_scan_alignments.__name__, multiprocessing.current_process().name + ) + ) try: if not os.path.exists(scan_alignments_binary): @@ -56,7 +81,11 @@ def run_scan_alignments(record): split_region_file_ = os.path.join(work, "merged_region.bed") tmp_ = bedtools_sort(split_region_file, run_logger=thread_logger) bedtools_merge( - tmp_, output_fn=split_region_file_, args=" -d 
{}".format(merge_d_for_scan), run_logger=thread_logger) + tmp_, + output_fn=split_region_file_, + args=" -d {}".format(merge_d_for_scan), + run_logger=thread_logger, + ) else: split_region_file_ = split_region_file @@ -67,17 +96,33 @@ def run_scan_alignments(record): --snp_min_af {} --ins_min_af {} --del_min_af {} \ --min_mapq {} --snp_min_bq {} --max_depth {} --min_depth {} \ {} {} {}".format( - scan_alignments_binary, reference, input_bam, split_region_file_, - work, work, window_size, + scan_alignments_binary, + reference, + input_bam, + split_region_file_, + work, + work, + window_size, snp_min_ao, - snp_min_af, ins_min_af, del_min_af, - min_mapq, snp_min_bq, max_dp * window_size / 100.0, min_dp, - report_all_alleles_str, report_count_for_all_positions_str, filter_duplicate_str) + snp_min_af, + ins_min_af, + del_min_af, + min_mapq, + snp_min_bq, + max_dp * window_size / 100.0, + min_dp, + report_all_alleles_str, + report_count_for_all_positions_str, + filter_duplicate_str, + ) if calc_qual: cmd += " --calculate_qual_stat" - run_shell_command(cmd, stdout=os.path.join(work, "scan.out"), - stderr=os.path.join(work, "scan.err"), - run_logger=thread_logger) + run_shell_command( + cmd, + stdout=os.path.join(work, "scan.out"), + stderr=os.path.join(work, "scan.err"), + run_logger=thread_logger, + ) else: with open(os.path.join(work, "candidates.vcf"), "w") as o_f: pass @@ -85,37 +130,69 @@ def run_scan_alignments(record): pass pysam.tabix_index(os.path.join(work, "count.bed"), preset="bed") - concatenate_files([split_region_file], - os.path.join(work, "region.bed")) - return os.path.join(work, "candidates.vcf"), os.path.join(work, "count.bed.gz"), os.path.join(work, "region.bed") + concatenate_files([split_region_file], os.path.join(work, "region.bed")) + return ( + os.path.join(work, "candidates.vcf"), + os.path.join(work, "count.bed.gz"), + os.path.join(work, "region.bed"), + ) except Exception as ex: thread_logger.error(traceback.format_exc()) thread_logger.error(ex) stderr_file = os.path.join(work, "scan.err") if os.path.exists(stderr_file) and os.path.getsize(stderr_file): - thread_logger.error( - "Please check error log at {}".format(stderr_file)) + thread_logger.error("Please check error log at {}".format(stderr_file)) return None - outputs = scan_alignments(args.work, args.merge_d_for_scan, args.scan_alignments_binary, args.input_bam, - args.regions_bed_file, args.reference, args.num_splits, - args.num_threads, args.window_size, - args.snp_min_ao, - args.snp_min_af, args.ins_min_af, args.del_min_af, - args.min_mapq, args.snp_min_bq, args.max_dp, args.min_dp, - args.report_all_alleles, args.report_count_for_all_positions, - args.filter_duplicate) + outputs = scan_alignments( + args.work, + args.merge_d_for_scan, + args.scan_alignments_binary, + args.input_bam, + args.regions_bed_file, + args.reference, + args.num_splits, + args.num_threads, + args.window_size, + args.snp_min_ao, + args.snp_min_af, + args.ins_min_af, + args.del_min_af, + args.min_mapq, + args.snp_min_bq, + args.max_dp, + args.min_dp, + args.report_all_alleles, + args.report_count_for_all_positions, + args.filter_duplicate, + ) -def scan_alignments(work, merge_d_for_scan, scan_alignments_binary, input_bam, - regions_bed_file, reference, num_splits, - num_threads, window_size, - snp_min_ao, - snp_min_af, ins_min_af, del_min_af, - min_mapq, snp_min_bq, max_dp, min_dp, - report_all_alleles, - report_count_for_all_positions, filter_duplicate, restart=True, - split_region_files=[], calc_qual=True): +def scan_alignments( + 
work, + merge_d_for_scan, + scan_alignments_binary, + input_bam, + regions_bed_file, + reference, + num_splits, + num_threads, + window_size, + snp_min_ao, + snp_min_af, + ins_min_af, + del_min_af, + min_mapq, + snp_min_bq, + max_dp, + min_dp, + report_all_alleles, + report_count_for_all_positions, + filter_duplicate, + restart=True, + split_region_files=[], + calc_qual=True, +): logger = logging.getLogger(scan_alignments.__name__) @@ -130,15 +207,15 @@ def scan_alignments(work, merge_d_for_scan, scan_alignments_binary, input_bam, chrom, st, en = line.strip().split()[0:3] o_f.write("\t".join([chrom, st, en]) + "\n") regions_bed = bedtools_sort(regions_bed, run_logger=logger) - regions_bed = bedtools_merge( - regions_bed, args=" -d 0", run_logger=logger) + regions_bed = bedtools_merge(regions_bed, args=" -d 0", run_logger=logger) else: regions_bed = get_tmp_file() with pysam.AlignmentFile(input_bam, "rb") as samfile: with open(regions_bed, "w") as tmpfile: for chrom, length in zip(samfile.references, samfile.lengths): - tmpfile.write("{}\t{}\t{}\t.\t.\t.\n".format( - chrom, 1, length - 1)) + tmpfile.write( + "{}\t{}\t{}\t.\t.\t.\n".format(chrom, 1, length - 1) + ) if not os.path.exists(work): os.mkdir(work) total_len = 0 @@ -155,10 +232,12 @@ def scan_alignments(work, merge_d_for_scan, scan_alignments_binary, input_bam, chrom, st, en = line.strip().split("\t")[0:3] spilt_total_len += int(en) - int(st) if spilt_total_len >= split_len_ratio * total_len: - split_region_files = sorted(split_region_files, - key=lambda x: int( - os.path.basename(x).split(".bed")[0].split( - "_")[1])) + split_region_files = sorted( + split_region_files, + key=lambda x: int( + os.path.basename(x).split(".bed")[0].split("_")[1] + ), + ) if not split_region_files: regions_bed_file = os.path.join(work, "all_regions.bed") shutil.move(regions_bed, regions_bed_file) @@ -166,37 +245,71 @@ def scan_alignments(work, merge_d_for_scan, scan_alignments_binary, input_bam, if num_splits is not None: num_split = num_splits else: - num_split = max(int(np.ceil((total_len // 10000000) // - num_threads) * num_threads), num_threads) - split_region_files = split_region(work, regions_bed_file, num_split, - min_region=window_size, max_region=1e20) + num_split = max( + int(np.ceil((total_len // 10000000) // num_threads) * num_threads), + num_threads, + ) + split_region_files = split_region( + work, + regions_bed_file, + num_split, + min_region=window_size, + max_region=1e20, + ) else: - logger.info("split_regions to be used (will ignore region_bed): {}".format( - " ".join(split_region_files))) + logger.info( + "split_regions to be used (will ignore region_bed): {}".format( + " ".join(split_region_files) + ) + ) map_args = [] all_outputs = [[]] * len(split_region_files) not_done = [] for i, split_region_file in enumerate(split_region_files): - if restart or not os.path.exists(os.path.join(work, "work.{}".format(i), "region.bed")) \ - or not os.path.exists(os.path.join(work, "work.{}".format(i), "candidates.vcf")) \ - or not os.path.exists(os.path.join(work, "work.{}".format(i), "count.bed.gz")): + if ( + restart + or not os.path.exists(os.path.join(work, "work.{}".format(i), "region.bed")) + or not os.path.exists( + os.path.join(work, "work.{}".format(i), "candidates.vcf") + ) + or not os.path.exists( + os.path.join(work, "work.{}".format(i), "count.bed.gz") + ) + ): work_ = os.path.join(work, "work.{}".format(i)) if os.path.exists(work_): shutil.rmtree(work_) - map_args.append((os.path.join(work, "work.{}".format(i)), - reference, 
merge_d_for_scan, scan_alignments_binary, split_region_file, - input_bam, window_size, - snp_min_ao, - snp_min_af, ins_min_af, del_min_af, - min_mapq, snp_min_bq, max_dp, min_dp, - report_all_alleles, report_count_for_all_positions, filter_duplicate, calc_qual)) + map_args.append( + ( + os.path.join(work, "work.{}".format(i)), + reference, + merge_d_for_scan, + scan_alignments_binary, + split_region_file, + input_bam, + window_size, + snp_min_ao, + snp_min_af, + ins_min_af, + del_min_af, + min_mapq, + snp_min_bq, + max_dp, + min_dp, + report_all_alleles, + report_count_for_all_positions, + filter_duplicate, + calc_qual, + ) + ) not_done.append(i) else: - all_outputs[i] = [os.path.join(work, "work.{}".format(i), "candidates.vcf"), - os.path.join(work, "work.{}".format( - i), "count.bed.gz"), - os.path.join(work, "work.{}".format(i), "region.bed")] + all_outputs[i] = [ + os.path.join(work, "work.{}".format(i), "candidates.vcf"), + os.path.join(work, "work.{}".format(i), "count.bed.gz"), + os.path.join(work, "work.{}".format(i), "region.bed"), + ] pool = multiprocessing.Pool(num_threads) try: @@ -216,71 +329,106 @@ def scan_alignments(work, merge_d_for_scan, scan_alignments_binary, input_bam, all_outputs[i] = output return all_outputs -if __name__ == '__main__': - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' + +if __name__ == "__main__": + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) - parser = argparse.ArgumentParser( - description='simple call variants from bam') - parser.add_argument('--input_bam', type=str, - help='input bam', required=True) - parser.add_argument('--reference', type=str, - help='reference fasta filename', required=True) - parser.add_argument('--work', type=str, - help='work directory', required=True) - parser.add_argument('--regions_bed_file', type=str, - help='regions bed file', default="") - parser.add_argument('--scan_alignments_binary', type=str, - help='binary for scanning alignment bam', default="../bin/scan_alignments") - parser.add_argument('--window_size', type=int, help='window size to scan the variants', - default=2000) - parser.add_argument('--snp_min_ao', type=float, - help='SNP min alternate count for low AF candidates', default=3) - parser.add_argument('--snp_min_af', type=float, - help='SNP min allele freq', default=0.05) - parser.add_argument('--ins_min_af', type=float, - help='INS min allele freq', default=0.01) - parser.add_argument('--del_min_af', type=float, - help='DEL min allele freq', default=0.01) - parser.add_argument('--min_mapq', type=int, - help='minimum mapping quality', default=1) - parser.add_argument('--snp_min_bq', type=float, - help='SNP min base quality', default=10) - parser.add_argument('--max_dp', type=float, - help='max depth', default=100000) - parser.add_argument('--min_dp', type=float, help='min depth', default=1) - parser.add_argument('--filter_duplicate', - help='filter duplicate reads when preparing pileup information', - action="store_true") - parser.add_argument('--merge_d_for_scan', type=int, - help='-d used to merge regions before scan', - default=None) - parser.add_argument('--report_all_alleles', - help='report all alleles per position', - action="store_true") - parser.add_argument('--report_count_for_all_positions', - help='report_count_for_all_positions', - action="store_true") - parser.add_argument('--num_splits', type=int, - help='number of region splits', default=None) - 
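Illustrative aside (not part of the patch): the map_args entries assembled above are single packed tuples because multiprocessing.Pool hands each worker exactly one argument; run_scan_alignments then unpacks that tuple on its first lines. A minimal, self-contained sketch of the same pattern; toy_worker and its three fields are made up for illustration.

    import multiprocessing

    def toy_worker(record):
        # unpack the packed arguments, mirroring the style used in run_scan_alignments
        work, window_size, min_mapq = record
        return "{} scanned with window={} min_mapq={}".format(work, window_size, min_mapq)

    if __name__ == "__main__":
        map_args = [("work.0", 2000, 1), ("work.1", 2000, 1)]
        with multiprocessing.Pool(2) as pool:
            print(pool.map(toy_worker, map_args))
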
parser.add_argument('--num_threads', type=int, - help='number of threads', default=1) + parser = argparse.ArgumentParser(description="simple call variants from bam") + parser.add_argument("--input_bam", type=str, help="input bam", required=True) + parser.add_argument( + "--reference", type=str, help="reference fasta filename", required=True + ) + parser.add_argument("--work", type=str, help="work directory", required=True) + parser.add_argument( + "--regions_bed_file", type=str, help="regions bed file", default="" + ) + parser.add_argument( + "--scan_alignments_binary", + type=str, + help="binary for scanning alignment bam", + default="../bin/scan_alignments", + ) + parser.add_argument( + "--window_size", type=int, help="window size to scan the variants", default=2000 + ) + parser.add_argument( + "--snp_min_ao", + type=float, + help="SNP min alternate count for low AF candidates", + default=3, + ) + parser.add_argument( + "--snp_min_af", type=float, help="SNP min allele freq", default=0.05 + ) + parser.add_argument( + "--ins_min_af", type=float, help="INS min allele freq", default=0.01 + ) + parser.add_argument( + "--del_min_af", type=float, help="DEL min allele freq", default=0.01 + ) + parser.add_argument( + "--min_mapq", type=int, help="minimum mapping quality", default=1 + ) + parser.add_argument( + "--snp_min_bq", type=float, help="SNP min base quality", default=10 + ) + parser.add_argument("--max_dp", type=float, help="max depth", default=100000) + parser.add_argument("--min_dp", type=float, help="min depth", default=1) + parser.add_argument( + "--filter_duplicate", + help="filter duplicate reads when preparing pileup information", + action="store_true", + ) + parser.add_argument( + "--merge_d_for_scan", + type=int, + help="-d used to merge regions before scan", + default=None, + ) + parser.add_argument( + "--report_all_alleles", + help="report all alleles per position", + action="store_true", + ) + parser.add_argument( + "--report_count_for_all_positions", + help="report_count_for_all_positions", + action="store_true", + ) + parser.add_argument( + "--num_splits", type=int, help="number of region splits", default=None + ) + parser.add_argument("--num_threads", type=int, help="number of threads", default=1) args = parser.parse_args() logger.info(args) try: - outputs = scan_alignments(args.work, args.merge_d_for_scan, args.scan_alignments_binary, args.input_bam, - args.regions_bed_file, args.reference, args.num_splits, - args.num_threads, args.window_size, - args.snp_min_ao, - args.snp_min_af, args.ins_min_af, args.del_min_af, - args.min_mapq, args.snp_min_bq, args.max_dp, args.min_dp, - args.report_all_alleles, args.report_count_for_all_positions, - args.filter_duplicate) + outputs = scan_alignments( + args.work, + args.merge_d_for_scan, + args.scan_alignments_binary, + args.input_bam, + args.regions_bed_file, + args.reference, + args.num_splits, + args.num_threads, + args.window_size, + args.snp_min_ao, + args.snp_min_af, + args.ins_min_af, + args.del_min_af, + args.min_mapq, + args.snp_min_bq, + args.max_dp, + args.min_dp, + args.report_all_alleles, + args.report_count_for_all_positions, + args.filter_duplicate, + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "scan_alignments.py failure on arguments: {}".format(args)) + logger.error("scan_alignments.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/sequencing_features.py b/neusomatic/python/sequencing_features.py index 364f76b..9014416 100644 
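Illustrative aside (not part of the patch), before the sequencing_features.py changes: the reformatted split-count expression in scan_alignments rounds the number of ~10 Mb chunks down to a multiple of num_threads, with num_threads as a floor; since the inner // already floors, the np.ceil wrapper appears to have no numerical effect. A standalone check of the same expression; compute_num_split is a made-up name.

    import numpy as np

    def compute_num_split(total_len, num_threads):
        return max(
            int(np.ceil((total_len // 10000000) // num_threads) * num_threads),
            num_threads,
        )

    assert compute_num_split(250000000, 4) == 24  # 25 chunks -> 6 per thread -> 24 splits
    assert compute_num_split(5000000, 4) == 4     # short target: floor at num_threads
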
--- a/neusomatic/python/sequencing_features.py +++ b/neusomatic/python/sequencing_features.py @@ -11,11 +11,11 @@ import fisher import logging -FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) -nan = float('nan') +nan = float("nan") def fisher_exact_test(mat, alternative="two-sided"): @@ -45,23 +45,26 @@ def get_read_pos_for_ref_pos(read, ref_pos_s): return output cigar_aligned = [cigar_aln_match, cigar_seq_match, cigar_seq_mismatch] cigar_s = 1 if cigartuples[0][0] == cigar_soft_clip else 0 - cigar_e = (len(cigartuples) - - 1) if cigartuples[-1][0] == cigar_soft_clip else len(cigartuples) + cigar_e = ( + (len(cigartuples) - 1) + if cigartuples[-1][0] == cigar_soft_clip + else len(cigartuples) + ) count = pos_q = cigartuples[0][1] if cigar_s == 1 else 0 cigar_index = cigar_s - for op, length in cigartuples[cigar_s: cigar_e]: + for op, length in cigartuples[cigar_s:cigar_e]: is_aligned = op == 0 or op >= 7 delta_r = length if (is_aligned or op == cigar_deletion) else 0 delta_q = length if (is_aligned or op == cigar_insertion) else 0 while current_i < len(ref_pos_s): diff = ref_pos_s[current_i] - pos_r if diff < delta_r: - output[ref_pos_s[current_i]] = [count + diff, (pos_q + diff) if delta_q else None, - read.seq[ - (pos_q + diff)] if delta_q else None, - read.query_qualities[ - (pos_q + diff)] if delta_q else None, - ] + output[ref_pos_s[current_i]] = [ + count + diff, + (pos_q + diff) if delta_q else None, + read.seq[(pos_q + diff)] if delta_q else None, + read.query_qualities[(pos_q + diff)] if delta_q else None, + ] current_i += 1 else: break @@ -78,7 +81,6 @@ def get_read_pos_for_ref_pos(read, ref_pos_s): class AugmentedAlignedRead: - def __init__(self, read, vars_pos): self.qname = read.qname self.vars_pos = vars_pos @@ -86,22 +88,34 @@ def __init__(self, read, vars_pos): self.pos_of_aligned_read = {} aligned_pairs = read.get_aligned_pairs() for pos in vars_pos: - code_i, ith_base, base_call_i, indel_length_i, flanking_indel_i = position_of_aligned_read( - aligned_pairs, self.read_pos_for_ref_pos[pos], pos) + ( + code_i, + ith_base, + base_call_i, + indel_length_i, + flanking_indel_i, + ) = position_of_aligned_read( + aligned_pairs, self.read_pos_for_ref_pos[pos], pos + ) self.pos_of_aligned_read[pos] = [ - code_i, ith_base, base_call_i, indel_length_i, flanking_indel_i] + code_i, + ith_base, + base_call_i, + indel_length_i, + flanking_indel_i, + ] self.mapping_quality = read.mapping_quality self.mean_query_qualities = mean(read.query_qualities) self.is_proper_pair = read.is_proper_pair self.is_reverse = read.is_reverse - self.NM = read.get_tag('NM') + self.NM = read.get_tag("NM") self.query_length = read.query_length - self.is_soft_clipped = read.cigar[0][ - 0] == cigar_soft_clip or read.cigar[-1][0] == cigar_soft_clip + self.is_soft_clipped = ( + read.cigar[0][0] == cigar_soft_clip or read.cigar[-1][0] == cigar_soft_clip + ) class ClusterReads: - def __init__(self, bam, variants): self.variants = variants self.chrom = variants[0][0] @@ -129,21 +143,23 @@ def __init__(self, bam, variants): self.reads.append(AugmentedAlignedRead(read_i, vars_pos)) i += 1 - def get_alignment_features(self, var_index, ref_base, first_alt, min_mq=1, min_bq=10): - ''' + def get_alignment_features( + self, var_index, ref_base, first_alt, min_mq=1, min_bq=10 + ): + """ bam is the opened file handle of bam file my_coordinate is a list or 
tuple of 0-based (contig, position) - ''' + """ my_coordinate = self.variants[var_index][0:2] reads = [self.reads[i] for i in self.var_reads[var_index]] bamfeatures = AlignmentFeatures( - reads, my_coordinate, ref_base, first_alt, min_mq, min_bq) + reads, my_coordinate, ref_base, first_alt, min_mq, min_bq + ) return bamfeatures class AlignmentFeatures: - def __init__(self, reads, my_coordinate, ref_base, first_alt, min_mq=1, min_bq=10): indel_length = len(first_alt) - len(ref_base) @@ -166,10 +182,14 @@ def __init__(self, reads, my_coordinate, ref_base, first_alt, min_mq=1, min_bq=1 for read_i in reads: dp += 1 - read_pos_for_ref_pos = read_i.read_pos_for_ref_pos[ - my_coordinate[1] - 1] - code_i, ith_base, base_call_i, indel_length_i, flanking_indel_i = read_i.pos_of_aligned_read[ - my_coordinate[1] - 1] + read_pos_for_ref_pos = read_i.read_pos_for_ref_pos[my_coordinate[1] - 1] + ( + code_i, + ith_base, + base_call_i, + indel_length_i, + flanking_indel_i, + ) = read_i.pos_of_aligned_read[my_coordinate[1] - 1] read_i_qual_ith_base = read_pos_for_ref_pos[3] if read_i.mapping_quality < min_mq and read_i.mean_query_qualities < min_bq: @@ -179,9 +199,11 @@ def __init__(self, reads, my_coordinate, ref_base, first_alt, min_mq=1, min_bq=1 MQ0 += 1 is_ref_call = code_i == 1 and base_call_i == ref_base[0] - is_alt_call = (indel_length == 0 and code_i == 1 and base_call_i == first_alt) or ( - indel_length < 0 and code_i == 2 and indel_length == indel_length_i) or ( - indel_length > 0 and code_i == 3) + is_alt_call = ( + (indel_length == 0 and code_i == 1 and base_call_i == first_alt) + or (indel_length < 0 and code_i == 2 and indel_length == indel_length_i) + or (indel_length > 0 and code_i == 3) + ) # inconsistent read or second alternate calls if not (is_ref_call or is_alt_call): @@ -202,8 +224,7 @@ def __init__(self, reads, my_coordinate, ref_base, first_alt, min_mq=1, min_bq=1 pass if read_i.mapping_quality >= min_mq and read_i_qual_ith_base >= min_bq: - concordance_counts[ - 0 if read_i.is_proper_pair else 1][index] += 1 + concordance_counts[0 if read_i.is_proper_pair else 1][index] += 1 orientation_counts[1 if read_i.is_reverse else 0][index] += 1 soft_clip_counts[1 if read_i.is_soft_clipped else 0][index] += 1 @@ -211,16 +232,15 @@ def __init__(self, reads, my_coordinate, ref_base, first_alt, min_mq=1, min_bq=1 # Distance from the end of the read: if ith_base is not None: pos_from_end[index].append( - min(ith_base, read_i.query_length - ith_base)) + min(ith_base, read_i.query_length - ith_base) + ) flanking_indel[index].append(flanking_indel_i) # unpack to get the ref and alt values ref_pos_from_end, alt_pos_from_end = pos_from_end - self.ref_concordant_reads, self.alt_concordant_reads = concordance_counts[ - 0] - self.ref_discordant_reads, self.alt_discordant_reads = concordance_counts[ - 1] + self.ref_concordant_reads, self.alt_concordant_reads = concordance_counts[0] + self.ref_discordant_reads, self.alt_discordant_reads = concordance_counts[1] self.ref_for, self.alt_for = orientation_counts[0] self.ref_rev, self.alt_rev = orientation_counts[1] self.ref_notSC_reads, self.alt_notSC_reads = soft_clip_counts[0] @@ -240,16 +260,14 @@ def __init__(self, reads, my_coordinate, ref_base, first_alt, min_mq=1, min_bq=1 ref_edit_distance, alt_edit_distance = edit_distance self.ref_NM = mean(ref_edit_distance) self.alt_NM = mean(alt_edit_distance) - self.z_ranksums_NM = stats.ranksums( - alt_edit_distance, ref_edit_distance)[0] + self.z_ranksums_NM = stats.ranksums(alt_edit_distance, 
ref_edit_distance)[0] self.NM_Diff = self.alt_NM - self.ref_NM - abs(indel_length) self.concordance_fet = fisher_exact_test(concordance_counts) self.strandbias_fet = fisher_exact_test(orientation_counts) self.clipping_fet = fisher_exact_test(soft_clip_counts) - self.z_ranksums_endpos = stats.ranksums( - alt_pos_from_end, ref_pos_from_end)[0] + self.z_ranksums_endpos = stats.ranksums(alt_pos_from_end, ref_pos_from_end)[0] ref_flanking_indel, alt_flanking_indel = flanking_indel self.ref_indel_1bp = ref_flanking_indel.count(1) @@ -260,7 +278,10 @@ def __init__(self, reads, my_coordinate, ref_base, first_alt, min_mq=1, min_bq=1 self.alt_indel_3bp = alt_flanking_indel.count(3) + self.alt_indel_2bp self.consistent_mates = self.inconsistent_mates = 0 - for one_count in map(lambda x: x.count(1), filter(lambda y: len(y) == 2, qname_collector.values())): + for one_count in map( + lambda x: x.count(1), + filter(lambda y: len(y) == 2, qname_collector.values()), + ): # Both are alternative calls: if one_count == 2: self.consistent_mates += 1 @@ -278,18 +299,22 @@ def __init__(self, reads, my_coordinate, ref_base, first_alt, min_mq=1, min_bq=1 def from_genome_reference(ref_fa, my_coordinate, ref_base, first_alt): - ''' + """ ref_fa is the opened reference fasta file handle my_coordinate is a list or tuple of 0-based (contig, position) - ''' + """ # Homopolymer eval (Make sure to modify for INDEL): # The min and max is to prevent the +/- 20 bases from exceeding the ends # of the reference sequence - lseq = ref_fa.fetch(my_coordinate[0], max( - 0, my_coordinate[1] - 20), my_coordinate[1]) - rseq = ref_fa.fetch(my_coordinate[0], my_coordinate[ - 1] + 1, min(ref_fa.get_reference_length(my_coordinate[0]) + 1, my_coordinate[1] + 21)) + lseq = ref_fa.fetch( + my_coordinate[0], max(0, my_coordinate[1] - 20), my_coordinate[1] + ) + rseq = ref_fa.fetch( + my_coordinate[0], + my_coordinate[1] + 1, + min(ref_fa.get_reference_length(my_coordinate[0]) + 1, my_coordinate[1] + 21), + ) # This is to get around buy in old version of pysam that reads the # reference sequence in bytes instead of strings @@ -361,12 +386,15 @@ def max_sub_vocabularies(seq_length, max_subseq_length): k = 1 while k <= max_subseq_length: - if 4**k < (seq_length - k + 1): - counts = counts + 4**k + if 4 ** k < (seq_length - k + 1): + counts = counts + 4 ** k else: - counts = counts + \ - (2 * seq_length - k - max_subseq_length + 2) * \ - (max_subseq_length - k + 1) / 2 + counts = ( + counts + + (2 * seq_length - k - max_subseq_length + 2) + * (max_subseq_length - k + 1) + / 2 + ) break k += 1 @@ -380,22 +408,24 @@ def subLC(sequence, max_substring_length=20): # Cut off substring at a fixed length sequence = sequence.upper() - if not 'N' in sequence: + if not "N" in sequence: number_of_subseqs = 0 seq_length = len(sequence) max_number_of_subseqs = max_sub_vocabularies( - seq_length, min(seq_length, max_substring_length)) + seq_length, min(seq_length, max_substring_length) + ) set_of_seq_n = set() for i in range(1, min(max_substring_length + 1, seq_length + 1)): - set_of_seq_n.update((sequence[n: n + i] - for n in range(len(sequence) - i + 1))) + set_of_seq_n.update( + (sequence[n : n + i] for n in range(len(sequence) - i + 1)) + ) number_of_subseqs = len(set_of_seq_n) lc = number_of_subseqs / max_number_of_subseqs else: - lc = float('nan') + lc = float("nan") return lc diff --git a/neusomatic/python/split_bed.py b/neusomatic/python/split_bed.py index e1be65f..61b5ef3 100755 --- a/neusomatic/python/split_bed.py +++ b/neusomatic/python/split_bed.py 
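Illustrative aside (not part of the patch): in the sequencing_features.py hunks above, subLC scores sequence complexity as the number of distinct substrings (up to length 20) divided by the maximum possible count from max_sub_vocabularies. A tiny standalone version of just the counting step on short strings; distinct_substrings is a made-up name and omits the normalisation.

    def distinct_substrings(sequence, max_len):
        seen = set()
        for i in range(1, min(max_len, len(sequence)) + 1):
            seen.update(sequence[n:n + i] for n in range(len(sequence) - i + 1))
        return len(seen)

    print(distinct_substrings("AAAAAA", 3))  # 3  ("A", "AA", "AAA"): low complexity
    print(distinct_substrings("ACGTAC", 3))  # 12 distinct substrings: higher complexity
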
@@ -1,8 +1,8 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # split_bed.py # split bed file to multiple sub-regions -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import os import argparse import traceback @@ -12,15 +12,21 @@ from utils import write_tsv_file, bedtools_sort, bedtools_merge, skip_empty -def split_region(work, region_bed_file, num_splits, max_region=1000000, min_region=20, shuffle_intervals=False): +def split_region( + work, + region_bed_file, + num_splits, + max_region=1000000, + min_region=20, + shuffle_intervals=False, +): logger = logging.getLogger(split_region.__name__) logger.info("------------------------Split region-----------------------") regions_bed = bedtools_sort(region_bed_file, run_logger=logger) - regions_bed = bedtools_merge( - regions_bed, args=" -d 0", run_logger=logger) + regions_bed = bedtools_merge(regions_bed, args=" -d 0", run_logger=logger) intervals = [] with open(regions_bed) as r_f: @@ -29,8 +35,7 @@ def split_region(work, region_bed_file, num_splits, max_region=1000000, min_regi start, end = int(start), int(end) if end - start + 1 > max_region: for i in range(start, end + 1, max_region): - intervals.append( - [chrom, i, min(end, i + max_region - 1)]) + intervals.append([chrom, i, min(end, i + max_region - 1)]) else: intervals.append([chrom, start, end]) if shuffle_intervals: @@ -48,7 +53,7 @@ def split_region(work, region_bed_file, num_splits, max_region=1000000, min_regi start, end = int(start), int(end) s = start e = -1 - while(current_len < split_len): + while current_len < split_len: s = max(s, e + 1) e = min(s + split_len - current_len - 1, end) if (e - s + 1) < 2 * min_region: @@ -57,7 +62,7 @@ def split_region(work, region_bed_file, num_splits, max_region=1000000, min_regi e = end current_regions.append([chrom, s, e]) current_len += e - s + 1 - if (current_len >= split_len): + if current_len >= split_len: sofar_len += current_len split_lens.append(current_len) current_len = 0 @@ -74,8 +79,7 @@ def split_region(work, region_bed_file, num_splits, max_region=1000000, min_regi for i, split_region_ in enumerate(split_regions): split_region_file = os.path.join(work, "region_{}.bed".format(i)) logger.info(split_region_file) - write_tsv_file(split_region_file, split_region_, - add_fields=[".", ".", "."]) + write_tsv_file(split_region_file, split_region_, add_fields=[".", ".", "."]) logger.info("Split {}: {}".format(i, split_lens[i])) split_region_files.append(split_region_file) @@ -83,24 +87,28 @@ def split_region(work, region_bed_file, num_splits, max_region=1000000, min_regi logger.info("Total splitted length: {}".format(sum_len)) return split_region_files -if __name__ == '__main__': - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +if __name__ == "__main__": + + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) - parser = argparse.ArgumentParser( - description='Split bedfile to multiple beds') - parser.add_argument('--region_bed', type=str, - help='region bed', required=True) - parser.add_argument('--num_splits', type=int, - help='number of splits', required=True) - parser.add_argument('--output', type=str, - help='output directory', required=True) - parser.add_argument('--max_region', 
type=int, - help='max region size in the bed (for shuffling purpose)', default=1000000) - parser.add_argument('--min_region', type=int, - help='min region size for spliting', default=20) + parser = argparse.ArgumentParser(description="Split bedfile to multiple beds") + parser.add_argument("--region_bed", type=str, help="region bed", required=True) + parser.add_argument( + "--num_splits", type=int, help="number of splits", required=True + ) + parser.add_argument("--output", type=str, help="output directory", required=True) + parser.add_argument( + "--max_region", + type=int, + help="max region size in the bed (for shuffling purpose)", + default=1000000, + ) + parser.add_argument( + "--min_region", type=int, help="min region size for spliting", default=20 + ) args = parser.parse_args() if not os.path.exists(args.output): @@ -111,10 +119,10 @@ def split_region(work, region_bed_file, num_splits, max_region=1000000, min_regi try: split_region_files = split_region( - work, args.region_bed, args.num_splits, args.max_region, args.min_region) + work, args.region_bed, args.num_splits, args.max_region, args.min_region + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "split_bed.py failure on arguments: {}".format(args)) + logger.error("split_bed.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/train.py b/neusomatic/python/train.py index d90184f..2b88025 100755 --- a/neusomatic/python/train.py +++ b/neusomatic/python/train.py @@ -1,7 +1,7 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # train.py # Train NeuSomatic network -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import os import traceback @@ -24,23 +24,34 @@ from network import NeuSomaticNet from dataloader import NeuSomaticDataset, matrix_transform from merge_tsvs import merge_tsvs -from defaults import TYPE_CLASS_DICT, VARTYPE_CLASSES, NUM_ENS_FEATURES, NUM_ST_FEATURES, MAT_DTYPES +from defaults import ( + TYPE_CLASS_DICT, + VARTYPE_CLASSES, + NUM_ENS_FEATURES, + NUM_ST_FEATURES, + MAT_DTYPES, +) import torch._utils + try: torch._utils._rebuild_tensor_v2 except AttributeError: - def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): - tensor = torch._utils._rebuild_tensor( - storage, storage_offset, size, stride) + + def _rebuild_tensor_v2( + storage, storage_offset, size, stride, requires_grad, backward_hooks + ): + tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) tensor.requires_grad = requires_grad tensor._backward_hooks = backward_hooks return tensor + torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 -def make_weights_for_balanced_classes(count_class_t, count_class_l, nclasses_t, nclasses_l, - none_count=None): +def make_weights_for_balanced_classes( + count_class_t, count_class_l, nclasses_t, nclasses_l, none_count=None +): logger = logging.getLogger(make_weights_for_balanced_classes.__name__) w_t = [0] * nclasses_t @@ -52,23 +63,23 @@ def make_weights_for_balanced_classes(count_class_t, count_class_l, nclasses_t, count_class_t[TYPE_CLASS_DICT["NONE"]] = none_count count_class_l[0] = none_count - logger.info("count type classes: {}".format( - list(zip(VARTYPE_CLASSES, count_class_t)))) + logger.info( + "count type classes: {}".format(list(zip(VARTYPE_CLASSES, 
count_class_t))) + ) N = float(sum(count_class_t)) for i in range(nclasses_t): w_t[i] = (1 - (float(count_class_t[i]) / float(N))) / float(nclasses_t) w_t = np.array(w_t) - logger.info("weight type classes: {}".format( - list(zip(VARTYPE_CLASSES, w_t)))) + logger.info("weight type classes: {}".format(list(zip(VARTYPE_CLASSES, w_t)))) - logger.info("count length classes: {}".format(list( - zip(range(nclasses_l), count_class_l)))) + logger.info( + "count length classes: {}".format(list(zip(range(nclasses_l), count_class_l))) + ) N = float(sum(count_class_l)) for i in range(nclasses_l): w_l[i] = (1 - (float(count_class_l[i]) / float(N))) / float(nclasses_l) w_l = np.array(w_l) - logger.info("weight length classes: {}".format(list( - zip(range(nclasses_l), w_l)))) + logger.info("weight length classes: {}".format(list(zip(range(nclasses_l), w_l)))) return w_t, w_l @@ -76,13 +87,13 @@ def test(net, epoch, validation_loader, use_cuda): logger = logging.getLogger(test.__name__) net.eval() nclasses = len(VARTYPE_CLASSES) - class_correct = list(0. for i in range(nclasses)) - class_total = list(0. for i in range(nclasses)) - class_p_total = list(0. for i in range(nclasses)) + class_correct = list(0.0 for i in range(nclasses)) + class_total = list(0.0 for i in range(nclasses)) + class_p_total = list(0.0 for i in range(nclasses)) - len_class_correct = list(0. for i in range(4)) - len_class_total = list(0. for i in range(4)) - len_class_p_total = list(0. for i in range(4)) + len_class_correct = list(0.0 for i in range(4)) + len_class_total = list(0.0 for i in range(4)) + len_class_p_total = list(0.0 for i in range(4)) falses = [] for data in validation_loader: @@ -109,15 +120,24 @@ def test(net, epoch, validation_loader, use_cuda): if labels.size()[0] > 1: compare_labels = (predicted == labels).squeeze() else: - compare_labels = (predicted == labels) + compare_labels = predicted == labels false_preds = np.where(compare_labels.numpy() == 0)[0] if len(false_preds) > 0: for i in false_preds: - falses.append([paths[0][i], VARTYPE_CLASSES[predicted[i]], pos_pred[i], len_pred[i], - list( - np.round(F.softmax(outputs1[i, :], 0).data.cpu().numpy(), 4)), - list( - np.round(F.softmax(outputs3[i, :], 0).data.cpu().numpy(), 4))]) + falses.append( + [ + paths[0][i], + VARTYPE_CLASSES[predicted[i]], + pos_pred[i], + len_pred[i], + list( + np.round(F.softmax(outputs1[i, :], 0).data.cpu().numpy(), 4) + ), + list( + np.round(F.softmax(outputs3[i, :], 0).data.cpu().numpy(), 4) + ), + ] + ) for i in range(len(labels)): label = labels[i] @@ -130,7 +150,7 @@ def test(net, epoch, validation_loader, use_cuda): if var_len_s.size()[0] > 1: compare_len = (len_pred == var_len_s).squeeze() else: - compare_len = (len_pred == var_len_s) + compare_len = len_pred == var_len_s for i in range(len(var_len_s)): len_ = var_len_s[i] @@ -144,30 +164,40 @@ def test(net, epoch, validation_loader, use_cuda): SN = 100 * class_correct[i] / (class_total[i] + 0.0001) PR = 100 * class_correct[i] / (class_p_total[i] + 0.0001) F1 = 2 * PR * SN / (PR + SN + 0.0001) - logger.info('Epoch {}: Type Accuracy of {:>5} ({}) : {:.2f} {:.2f} {:.2f}'.format( + logger.info( + "Epoch {}: Type Accuracy of {:>5} ({}) : {:.2f} {:.2f} {:.2f}".format( + epoch, VARTYPE_CLASSES[i], class_total[i], SN, PR, F1 + ) + ) + logger.info( + "Epoch {}: Type Accuracy of the network on the {} test candidates: {:.4f} %".format( epoch, - VARTYPE_CLASSES[i], class_total[i], - SN, PR, F1)) - logger.info('Epoch {}: Type Accuracy of the network on the {} test candidates: {:.4f} 
%'.format( - epoch, sum(class_total), ( - 100 * sum(class_correct) / float(sum(class_total))))) + sum(class_total), + (100 * sum(class_correct) / float(sum(class_total))), + ) + ) for i in range(4): SN = 100 * len_class_correct[i] / (len_class_total[i] + 0.0001) PR = 100 * len_class_correct[i] / (len_class_p_total[i] + 0.0001) F1 = 2 * PR * SN / (PR + SN + 0.0001) - logger.info('Epoch {}: Length Accuracy of {:>5} ({}) : {:.2f} {:.2f} {:.2f}'.format( - epoch, i, len_class_total[i], - SN, PR, F1)) - logger.info('Epoch {}: Length Accuracy of the network on the {} test candidates: {:.4f} %'.format( - epoch, sum(len_class_total), ( - 100 * sum(len_class_correct) / float(sum(len_class_total))))) + logger.info( + "Epoch {}: Length Accuracy of {:>5} ({}) : {:.2f} {:.2f} {:.2f}".format( + epoch, i, len_class_total[i], SN, PR, F1 + ) + ) + logger.info( + "Epoch {}: Length Accuracy of the network on the {} test candidates: {:.4f} %".format( + epoch, + sum(len_class_total), + (100 * sum(len_class_correct) / float(sum(len_class_total))), + ) + ) net.train() class SubsetNoneSampler(torch.utils.data.sampler.Sampler): - def __init__(self, none_indices, var_indices, none_count): self.none_indices = none_indices self.var_indices = var_indices @@ -177,38 +207,60 @@ def __init__(self, none_indices, var_indices, none_count): def __iter__(self): logger = logging.getLogger(SubsetNoneSampler.__iter__.__name__) if self.current_none_id > (len(self.none_indices) - self.none_count): - this_round_nones = self.none_indices[self.current_none_id:] - self.none_indices = list(map(lambda i: self.none_indices[i], - torch.randperm(len(self.none_indices)).tolist())) + this_round_nones = self.none_indices[self.current_none_id :] + self.none_indices = list( + map( + lambda i: self.none_indices[i], + torch.randperm(len(self.none_indices)).tolist(), + ) + ) self.current_none_id = self.none_count - len(this_round_nones) - this_round_nones += self.none_indices[0:self.current_none_id] + this_round_nones += self.none_indices[0 : self.current_none_id] else: this_round_nones = self.none_indices[ - self.current_none_id:self.current_none_id + self.none_count] + self.current_none_id : self.current_none_id + self.none_count + ] self.current_none_id += self.none_count current_indices = this_round_nones + self.var_indices - ret = iter(map(lambda i: current_indices[i], - torch.randperm(len(current_indices)))) + ret = iter( + map(lambda i: current_indices[i], torch.randperm(len(current_indices))) + ) return ret def __len__(self): return len(self.var_indices) + self.none_count -def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpoint, - num_threads, batch_size, max_epochs, learning_rate, lr_drop_epochs, - lr_drop_ratio, momentum, boost_none, none_count_scale, - max_load_candidates, coverage_thr, save_freq, - merged_candidates_per_tsv, merged_max_num_tsvs, overwrite_merged_tsvs, - train_split_len, - normalize_channels, - no_seq_complexity, - zero_ann_cols, - force_zero_ann_cols, - ensemble_custom_header, - matrix_dtype, - use_cuda): +def train_neusomatic( + candidates_tsv, + validation_candidates_tsv, + out_dir, + checkpoint, + num_threads, + batch_size, + max_epochs, + learning_rate, + lr_drop_epochs, + lr_drop_ratio, + momentum, + boost_none, + none_count_scale, + max_load_candidates, + coverage_thr, + save_freq, + merged_candidates_per_tsv, + merged_max_num_tsvs, + overwrite_merged_tsvs, + train_split_len, + normalize_channels, + no_seq_complexity, + zero_ann_cols, + force_zero_ann_cols, + ensemble_custom_header, + 
matrix_dtype, + use_cuda, +): logger = logging.getLogger(train_neusomatic.__name__) logger.info("----------------Train NeuSomatic Network-------------------") @@ -224,37 +276,46 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo data_transform = matrix_transform((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) if checkpoint: - logger.info( - "Load pretrained model from checkpoint {}".format(checkpoint)) + logger.info("Load pretrained model from checkpoint {}".format(checkpoint)) pretrained_dict = torch.load( - checkpoint, map_location=lambda storage, loc: storage) + checkpoint, map_location=lambda storage, loc: storage + ) pretrained_state_dict = pretrained_dict["state_dict"] tag = pretrained_dict["tag"] sofar_epochs = pretrained_dict["epoch"] - logger.info( - "sofar_epochs from pretrained checkpoint: {}".format(sofar_epochs)) + logger.info("sofar_epochs from pretrained checkpoint: {}".format(sofar_epochs)) coverage_thr = pretrained_dict["coverage_thr"] logger.info( - "Override coverage_thr from pretrained checkpoint: {}".format(coverage_thr)) + "Override coverage_thr from pretrained checkpoint: {}".format(coverage_thr) + ) if "normalize_channels" in pretrained_dict: normalize_channels = pretrained_dict["normalize_channels"] else: normalize_channels = False logger.info( - "Override normalize_channels from pretrained checkpoint: {}".format(normalize_channels)) + "Override normalize_channels from pretrained checkpoint: {}".format( + normalize_channels + ) + ) if "no_seq_complexity" in pretrained_dict: no_seq_complexity = pretrained_dict["no_seq_complexity"] else: no_seq_complexity = True logger.info( - "Override no_seq_complexity from pretrained checkpoint: {}".format(no_seq_complexity)) + "Override no_seq_complexity from pretrained checkpoint: {}".format( + no_seq_complexity + ) + ) if "zero_ann_cols" in pretrained_dict: zero_ann_cols = pretrained_dict["zero_ann_cols"] else: zero_ann_cols = [] if not force_zero_ann_cols: logger.info( - "Override zero_ann_cols from pretrained checkpoint: {}".format(zero_ann_cols)) + "Override zero_ann_cols from pretrained checkpoint: {}".format( + zero_ann_cols + ) + ) if "ensemble_custom_header" in pretrained_dict: ensemble_custom_header = pretrained_dict["ensemble_custom_header"] else: @@ -273,7 +334,10 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo if force_zero_ann_cols: zero_ann_cols = force_zero_ann_cols logger.info( - "Override zero_ann_cols from force_zero_ann_cols: {}".format(force_zero_ann_cols)) + "Override zero_ann_cols from force_zero_ann_cols: {}".format( + force_zero_ann_cols + ) + ) if not ensemble_custom_header: expected_ens_fields = NUM_ENS_FEATURES @@ -298,10 +362,12 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo break else: raise Exception( - "Wrong number of fields in {}: {}".format(tsv, len(x))) + "Wrong number of fields in {}: {}".format(tsv, len(x)) + ) - num_channels = expected_ens_fields + \ - NUM_ST_FEATURES if ensemble else NUM_ST_FEATURES + num_channels = ( + expected_ens_fields + NUM_ST_FEATURES if ensemble else NUM_ST_FEATURES + ) else: num_channels = 0 for tsv in candidates_tsv: @@ -330,15 +396,28 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo # 1. filter out unnecessary keys # pretrained_state_dict = { # k: v for k, v in pretrained_state_dict.items() if k in model_dict} - if "module." in list(pretrained_state_dict.keys())[0] and "module." 
not in list(model_dict.keys())[0]: - pretrained_state_dict = {k.split("module.")[1]: v for k, v in pretrained_state_dict.items( - ) if k.split("module.")[1] in model_dict} - elif "module." not in list(pretrained_state_dict.keys())[0] and "module." in list(model_dict.keys())[0]: + if ( + "module." in list(pretrained_state_dict.keys())[0] + and "module." not in list(model_dict.keys())[0] + ): pretrained_state_dict = { - ("module." + k): v for k, v in pretrained_state_dict.items() if ("module." + k) in model_dict} + k.split("module.")[1]: v + for k, v in pretrained_state_dict.items() + if k.split("module.")[1] in model_dict + } + elif ( + "module." not in list(pretrained_state_dict.keys())[0] + and "module." in list(model_dict.keys())[0] + ): + pretrained_state_dict = { + ("module." + k): v + for k, v in pretrained_state_dict.items() + if ("module." + k) in model_dict + } else: - pretrained_state_dict = {k: v for k, - v in pretrained_state_dict.items() if k in model_dict} + pretrained_state_dict = { + k: v for k, v in pretrained_state_dict.items() if k in model_dict + } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_state_dict) # 3. load the new state dict @@ -347,19 +426,23 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo shuffle(candidates_tsv) if len(candidates_tsv) > merged_max_num_tsvs: - candidates_tsv = merge_tsvs(input_tsvs=candidates_tsv, out=out_dir, - candidates_per_tsv=merged_candidates_per_tsv, - max_num_tsvs=merged_max_num_tsvs, - overwrite_merged_tsvs=overwrite_merged_tsvs, - keep_none_types=True) + candidates_tsv = merge_tsvs( + input_tsvs=candidates_tsv, + out=out_dir, + candidates_per_tsv=merged_candidates_per_tsv, + max_num_tsvs=merged_max_num_tsvs, + overwrite_merged_tsvs=overwrite_merged_tsvs, + keep_none_types=True, + ) Ls = [] for tsv in candidates_tsv: idx = pickle.load(open(tsv + ".idx", "rb")) Ls.append(len(idx) - 1) - Ls, candidates_tsv = list(zip( - *sorted(zip(Ls, candidates_tsv), key=lambda x: x[0], reverse=True))) + Ls, candidates_tsv = list( + zip(*sorted(zip(Ls, candidates_tsv), key=lambda x: x[0], reverse=True)) + ) train_split_tsvs = [] current_L = 0 @@ -367,9 +450,12 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo for i, (L, tsv) in enumerate(zip(Ls, candidates_tsv)): current_L += L current_split_tsvs.append(tsv) - if current_L >= train_split_len or (i == len(candidates_tsv) - 1 and current_L > 0): - logger.info("tsvs in split {}: {}".format( - len(train_split_tsvs), current_split_tsvs)) + if current_L >= train_split_len or ( + i == len(candidates_tsv) - 1 and current_L > 0 + ): + logger.info( + "tsvs in split {}: {}".format(len(train_split_tsvs), current_split_tsvs) + ) train_split_tsvs.append(current_split_tsvs) current_L = 0 current_split_tsvs = [] @@ -382,31 +468,45 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo var_indices_ = [] samplers = [] for split_i, tsvs in enumerate(train_split_tsvs): - train_set = NeuSomaticDataset(roots=tsvs, - max_load_candidates=int( - max_load_candidates * len(tsvs) / float(len(candidates_tsv))), - transform=data_transform, is_test=False, - num_threads=num_threads, coverage_thr=coverage_thr, - normalize_channels=normalize_channels, - zero_ann_cols=zero_ann_cols, - matrix_dtype=matrix_dtype) + train_set = NeuSomaticDataset( + roots=tsvs, + max_load_candidates=int( + max_load_candidates * len(tsvs) / float(len(candidates_tsv)) + ), + transform=data_transform, + is_test=False, + 
num_threads=num_threads, + coverage_thr=coverage_thr, + normalize_channels=normalize_channels, + zero_ann_cols=zero_ann_cols, + matrix_dtype=matrix_dtype, + ) train_sets.append(train_set) none_indices = train_set.get_none_indices() var_indices = train_set.get_var_indices() if none_indices: - none_indices = list(map(lambda i: none_indices[i], - torch.randperm(len(none_indices)).tolist())) + none_indices = list( + map( + lambda i: none_indices[i], + torch.randperm(len(none_indices)).tolist(), + ) + ) logger.info( - "Non-somatic candidates in split {}: {}".format(split_i, len(none_indices))) + "Non-somatic candidates in split {}: {}".format(split_i, len(none_indices)) + ) if var_indices: - var_indices = list(map(lambda i: var_indices[i], - torch.randperm(len(var_indices)).tolist())) - logger.info("Somatic candidates in split {}: {}".format( - split_i, len(var_indices))) - none_count = max(min(len(none_indices), len( - var_indices) * none_count_scale), 1) + var_indices = list( + map(lambda i: var_indices[i], torch.randperm(len(var_indices)).tolist()) + ) + logger.info( + "Somatic candidates in split {}: {}".format(split_i, len(var_indices)) + ) + none_count = max(min(len(none_indices), len(var_indices) * none_count_scale), 1) logger.info( - "Non-somatic considered in each epoch of split {}: {}".format(split_i, none_count)) + "Non-somatic considered in each epoch of split {}: {}".format( + split_i, none_count + ) + ) sampler = SubsetNoneSampler(none_indices, var_indices, none_count) samplers.append(sampler) @@ -414,20 +514,29 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo var_counts.append(len(var_indices)) var_indices_.append(var_indices) none_indices_.append(none_indices) - logger.info("# Total Train cadidates: {}".format( - sum(map(lambda x: len(x), train_sets)))) + logger.info( + "# Total Train cadidates: {}".format(sum(map(lambda x: len(x), train_sets))) + ) if validation_candidates_tsv: - validation_set = NeuSomaticDataset(roots=validation_candidates_tsv, - max_load_candidates=max_load_candidates, - transform=data_transform, is_test=True, - num_threads=num_threads, coverage_thr=coverage_thr, - normalize_channels=normalize_channels, - zero_ann_cols=zero_ann_cols, - matrix_dtype=matrix_dtype) - validation_loader = torch.utils.data.DataLoader(validation_set, - batch_size=batch_size, shuffle=True, - num_workers=num_threads, pin_memory=True) + validation_set = NeuSomaticDataset( + roots=validation_candidates_tsv, + max_load_candidates=max_load_candidates, + transform=data_transform, + is_test=True, + num_threads=num_threads, + coverage_thr=coverage_thr, + normalize_channels=normalize_channels, + zero_ann_cols=zero_ann_cols, + matrix_dtype=matrix_dtype, + ) + validation_loader = torch.utils.data.DataLoader( + validation_set, + batch_size=batch_size, + shuffle=True, + num_workers=num_threads, + pin_memory=True, + ) logger.info("#Validation candidates: {}".format(len(validation_set))) count_class_t = [0] * 4 @@ -438,13 +547,15 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo count_class_l[i] += train_set.count_class_l[i] weights_type, weights_length = make_weights_for_balanced_classes( - count_class_t, count_class_l, 4, 4, sum(none_counts)) + count_class_t, count_class_l, 4, 4, sum(none_counts) + ) weights_type[2] *= boost_none weights_length[0] *= boost_none - logger.info("weights_type:{}, weights_length:{}".format( - weights_type, weights_length)) + logger.info( + "weights_type:{}, weights_length:{}".format(weights_type, 
weights_length) + ) loss_s = [] gradients = torch.FloatTensor(weights_type) @@ -455,31 +566,37 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo criterion_crossentropy = nn.CrossEntropyLoss(gradients) criterion_crossentropy2 = nn.CrossEntropyLoss(gradients2) criterion_smoothl1 = nn.SmoothL1Loss() - optimizer = optim.SGD( - net.parameters(), lr=learning_rate, momentum=momentum) + optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum) net.train() len_train_set = sum(none_counts) + sum(var_counts) logger.info("Number of candidates per epoch: {}".format(len_train_set)) print_freq = max(1, int(len_train_set / float(batch_size) / 4.0)) curr_epoch = prev_epochs - torch.save({"state_dict": net.state_dict(), - "tag": tag, - "epoch": curr_epoch, - "coverage_thr": coverage_thr, - "normalize_channels": normalize_channels, - "no_seq_complexity": no_seq_complexity, - "zero_ann_cols": zero_ann_cols, - "ensemble_custom_header": ensemble_custom_header, - "matrix_dtype": matrix_dtype, - }, '{}/models/checkpoint_{}_epoch{}_.pth'.format(out_dir, tag, curr_epoch)) + torch.save( + { + "state_dict": net.state_dict(), + "tag": tag, + "epoch": curr_epoch, + "coverage_thr": coverage_thr, + "normalize_channels": normalize_channels, + "no_seq_complexity": no_seq_complexity, + "zero_ann_cols": zero_ann_cols, + "ensemble_custom_header": ensemble_custom_header, + "matrix_dtype": matrix_dtype, + }, + "{}/models/checkpoint_{}_epoch{}_.pth".format(out_dir, tag, curr_epoch), + ) if len(train_sets) == 1: train_sets[0].open_candidate_tsvs() - train_loader = torch.utils.data.DataLoader(train_sets[0], - batch_size=batch_size, - num_workers=num_threads, pin_memory=True, - sampler=samplers[0]) + train_loader = torch.utils.data.DataLoader( + train_sets[0], + batch_size=batch_size, + num_workers=num_threads, + pin_memory=True, + sampler=samplers[0], + ) # loop over the dataset multiple times n_epoch = 0 for epoch in range(max_epochs - prev_epochs): @@ -489,20 +606,31 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo for split_i, train_set in enumerate(train_sets): if len(train_sets) > 1: train_set.open_candidate_tsvs() - train_loader = torch.utils.data.DataLoader(train_set, - batch_size=batch_size, - num_workers=num_threads, pin_memory=True, - sampler=samplers[split_i]) + train_loader = torch.utils.data.DataLoader( + train_set, + batch_size=batch_size, + num_workers=num_threads, + pin_memory=True, + sampler=samplers[split_i], + ) for i, data in enumerate(train_loader, 0): i_ += 1 # get the inputs (inputs, labels, var_pos_s, var_len_s, _), _ = data # wrap them in Variable - inputs, labels, var_pos_s, var_len_s = Variable(inputs), Variable( - labels), Variable(var_pos_s), Variable(var_len_s) + inputs, labels, var_pos_s, var_len_s = ( + Variable(inputs), + Variable(labels), + Variable(var_pos_s), + Variable(var_len_s), + ) if use_cuda: - inputs, labels, var_pos_s, var_len_s = inputs.cuda( - ), labels.cuda(), var_pos_s.cuda(), var_len_s.cuda() + inputs, labels, var_pos_s, var_len_s = ( + inputs.cuda(), + labels.cuda(), + var_pos_s.cuda(), + var_len_s.cuda(), + ) # zero the parameter gradients optimizer.zero_grad() @@ -510,12 +638,15 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo outputs, _ = net(inputs) [outputs_classification, outputs_pos, outputs_len] = outputs var_len_labels = Variable( - torch.LongTensor(var_len_s.cpu().data.numpy())) + torch.LongTensor(var_len_s.cpu().data.numpy()) + ) if use_cuda: 
var_len_labels = var_len_labels.cuda() - loss = criterion_crossentropy(outputs_classification, labels) + 1 * criterion_smoothl1( - outputs_pos.squeeze(1), var_pos_s[:, 1] - ) + 1 * criterion_crossentropy2(outputs_len, var_len_labels) + loss = ( + criterion_crossentropy(outputs_classification, labels) + + 1 * criterion_smoothl1(outputs_pos.squeeze(1), var_pos_s[:, 1]) + + 1 * criterion_crossentropy2(outputs_len, var_len_labels) + ) loss.backward() optimizer.step() @@ -523,48 +654,59 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo running_loss += loss.data if i_ % print_freq == print_freq - 1: - logger.info('epoch: {}, iter: {:>7}, lr: {}, loss: {:.5f}'.format( - n_epoch + prev_epochs, len(loss_s), - learning_rate, running_loss / print_freq)) + logger.info( + "epoch: {}, iter: {:>7}, lr: {}, loss: {:.5f}".format( + n_epoch + prev_epochs, + len(loss_s), + learning_rate, + running_loss / print_freq, + ) + ) running_loss = 0.0 if len(train_sets) > 1: train_set.close_candidate_tsvs() curr_epoch = n_epoch + prev_epochs if curr_epoch % save_freq == 0: - torch.save({"state_dict": net.state_dict(), - "tag": tag, - "epoch": curr_epoch, - "coverage_thr": coverage_thr, - "normalize_channels": normalize_channels, - "no_seq_complexity": no_seq_complexity, - "zero_ann_cols": zero_ann_cols, - "ensemble_custom_header": ensemble_custom_header, - "matrix_dtype": matrix_dtype, - }, '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch)) + torch.save( + { + "state_dict": net.state_dict(), + "tag": tag, + "epoch": curr_epoch, + "coverage_thr": coverage_thr, + "normalize_channels": normalize_channels, + "no_seq_complexity": no_seq_complexity, + "zero_ann_cols": zero_ann_cols, + "ensemble_custom_header": ensemble_custom_header, + "matrix_dtype": matrix_dtype, + }, + "{}/models/checkpoint_{}_epoch{}.pth".format(out_dir, tag, curr_epoch), + ) if validation_candidates_tsv: test(net, curr_epoch, validation_loader, use_cuda) if curr_epoch % lr_drop_epochs == 0: learning_rate *= lr_drop_ratio - optimizer = optim.SGD( - net.parameters(), lr=learning_rate, momentum=momentum) - logger.info('Finished Training') + optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum) + logger.info("Finished Training") if len(train_sets) == 1: train_sets[0].close_candidate_tsvs() curr_epoch = n_epoch + prev_epochs - torch.save({"state_dict": net.state_dict(), - "tag": tag, - "epoch": curr_epoch, - "coverage_thr": coverage_thr, - "normalize_channels": normalize_channels, - "no_seq_complexity": no_seq_complexity, - "zero_ann_cols": zero_ann_cols, - "ensemble_custom_header": ensemble_custom_header, - "matrix_dtype": matrix_dtype, - }, '{}/models/checkpoint_{}_epoch{}.pth'.format( - out_dir, tag, curr_epoch)) + torch.save( + { + "state_dict": net.state_dict(), + "tag": tag, + "epoch": curr_epoch, + "coverage_thr": coverage_thr, + "normalize_channels": normalize_channels, + "no_seq_complexity": no_seq_complexity, + "zero_ann_cols": zero_ann_cols, + "ensemble_custom_header": ensemble_custom_header, + "matrix_dtype": matrix_dtype, + }, + "{}/models/checkpoint_{}_epoch{}.pth".format(out_dir, tag, curr_epoch), + ) if validation_candidates_tsv: test(net, curr_epoch, validation_loader, use_cuda) logger.info("Total Epochs: {}".format(curr_epoch)) @@ -572,90 +714,151 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo logger.info("Training is Done.") - return '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch) + return 
"{}/models/checkpoint_{}_epoch{}.pth".format(out_dir, tag, curr_epoch) + -if __name__ == '__main__': +if __name__ == "__main__": - FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' + FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) - parser = argparse.ArgumentParser( - description='simple call variants from bam') - parser.add_argument('--candidates_tsv', nargs="*", - help=' train candidate tsv files', required=True) - parser.add_argument('--out', type=str, - help='output directory', required=True) - parser.add_argument('--checkpoint', type=str, - help='pretrained network model checkpoint path', default=None) - parser.add_argument('--validation_candidates_tsv', nargs="*", - help=' validation candidate tsv files', default=[]) - parser.add_argument('--num_threads', type=int, - help='number of threads', default=1) - parser.add_argument('--ensemble', - help='Enable training for ensemble mode', - action="store_true") - parser.add_argument('--batch_size', type=int, - help='batch size', default=1000) - parser.add_argument('--max_epochs', type=int, - help='maximum number of training epochs', default=1000) - parser.add_argument('--lr', type=float, help='learning_rate', default=0.01) - parser.add_argument('--lr_drop_epochs', type=int, - help='number of epochs to drop learning rate', default=400) - parser.add_argument('--lr_drop_ratio', type=float, - help='learning rate drop scale', default=0.1) - parser.add_argument('--momentum', type=float, - help='SGD momentum', default=0.9) - parser.add_argument('--boost_none', type=float, - help='the amount to boost none-somatic classification weight', default=100) - parser.add_argument('--none_count_scale', type=float, - help='ratio of the none/somatic canidates to use in each training epoch \ - (the none candidate will be subsampled in each epoch based on this ratio', - default=2) - parser.add_argument('--max_load_candidates', type=int, - help='maximum candidates to load in memory', default=1000000) - parser.add_argument('--save_freq', type=int, - help='the frequency of saving checkpoints in terms of # epochs', default=50) - parser.add_argument('--merged_candidates_per_tsv', type=int, - help='Maximum number of candidates in each merged tsv file ', default=10000000) - parser.add_argument('--merged_max_num_tsvs', type=int, - help='Maximum number of merged tsv files \ - (higher priority than merged_candidates_per_tsv)', default=10) - parser.add_argument('--overwrite_merged_tsvs', - help='if OUT/merged_tsvs/ folder exists overwrite the merged tsvs', - action="store_true") - parser.add_argument('--train_split_len', type=int, - help='Maximum number of candidates used in each split of training (>=merged_candidates_per_tsv)', - default=10000000) - parser.add_argument('--coverage_thr', type=int, - help='maximum coverage threshold to be used for network input \ + parser = argparse.ArgumentParser(description="simple call variants from bam") + parser.add_argument( + "--candidates_tsv", nargs="*", help=" train candidate tsv files", required=True + ) + parser.add_argument("--out", type=str, help="output directory", required=True) + parser.add_argument( + "--checkpoint", + type=str, + help="pretrained network model checkpoint path", + default=None, + ) + parser.add_argument( + "--validation_candidates_tsv", + nargs="*", + help=" validation candidate tsv files", + default=[], + ) + parser.add_argument("--num_threads", type=int, help="number of threads", 
default=1) + parser.add_argument( + "--ensemble", help="Enable training for ensemble mode", action="store_true" + ) + parser.add_argument("--batch_size", type=int, help="batch size", default=1000) + parser.add_argument( + "--max_epochs", type=int, help="maximum number of training epochs", default=1000 + ) + parser.add_argument("--lr", type=float, help="learning_rate", default=0.01) + parser.add_argument( + "--lr_drop_epochs", + type=int, + help="number of epochs to drop learning rate", + default=400, + ) + parser.add_argument( + "--lr_drop_ratio", type=float, help="learning rate drop scale", default=0.1 + ) + parser.add_argument("--momentum", type=float, help="SGD momentum", default=0.9) + parser.add_argument( + "--boost_none", + type=float, + help="the amount to boost none-somatic classification weight", + default=100, + ) + parser.add_argument( + "--none_count_scale", + type=float, + help="ratio of the none/somatic canidates to use in each training epoch \ + (the none candidate will be subsampled in each epoch based on this ratio", + default=2, + ) + parser.add_argument( + "--max_load_candidates", + type=int, + help="maximum candidates to load in memory", + default=1000000, + ) + parser.add_argument( + "--save_freq", + type=int, + help="the frequency of saving checkpoints in terms of # epochs", + default=50, + ) + parser.add_argument( + "--merged_candidates_per_tsv", + type=int, + help="Maximum number of candidates in each merged tsv file ", + default=10000000, + ) + parser.add_argument( + "--merged_max_num_tsvs", + type=int, + help="Maximum number of merged tsv files \ + (higher priority than merged_candidates_per_tsv)", + default=10, + ) + parser.add_argument( + "--overwrite_merged_tsvs", + help="if OUT/merged_tsvs/ folder exists overwrite the merged tsvs", + action="store_true", + ) + parser.add_argument( + "--train_split_len", + type=int, + help="Maximum number of candidates used in each split of training (>=merged_candidates_per_tsv)", + default=10000000, + ) + parser.add_argument( + "--coverage_thr", + type=int, + help="maximum coverage threshold to be used for network input \ normalization. \ Will be overridden if pretrained model is provided\ For ~50x WGS, coverage_thr=100 should work. \ - For higher coverage WES, coverage_thr=300 should work.', default=100) - parser.add_argument('--normalize_channels', - help='normalize BQ, MQ, and other bam-info channels by frequency of observed alleles. \ - Will be overridden if pretrained model is provided', - action="store_true") - parser.add_argument('--no_seq_complexity', - help='Dont compute linguistic sequence complexity features', - action="store_true") - parser.add_argument('--zero_ann_cols', nargs="*", type=int, - help='columns to be set to zero in the annotations \ - idx starts from 5th column in candidate.tsv file', - default=[]) - parser.add_argument('--force_zero_ann_cols', nargs="*", type=int, - help='force columns to be set to zero in the annotations. Higher priority than \ + For higher coverage WES, coverage_thr=300 should work.", + default=100, + ) + parser.add_argument( + "--normalize_channels", + help="normalize BQ, MQ, and other bam-info channels by frequency of observed alleles. 
\ + Will be overridden if pretrained model is provided", + action="store_true", + ) + parser.add_argument( + "--no_seq_complexity", + help="Dont compute linguistic sequence complexity features", + action="store_true", + ) + parser.add_argument( + "--zero_ann_cols", + nargs="*", + type=int, + help="columns to be set to zero in the annotations \ + idx starts from 5th column in candidate.tsv file", + default=[], + ) + parser.add_argument( + "--force_zero_ann_cols", + nargs="*", + type=int, + help="force columns to be set to zero in the annotations. Higher priority than \ --zero_ann_cols and pretrained setting \ - idx starts from 5th column in candidate.tsv file', - default=[]) - parser.add_argument('--ensemble_custom_header', - help='Allow ensemble tsv to have custom header fields. (Features should be\ - normalized between [0,1]', - action="store_true") - parser.add_argument('--matrix_dtype', type=str, - help='matrix_dtype to be used to store matrix', default="uint8", - choices=MAT_DTYPES) + idx starts from 5th column in candidate.tsv file", + default=[], + ) + parser.add_argument( + "--ensemble_custom_header", + help="Allow ensemble tsv to have custom header fields. (Features should be\ + normalized between [0,1]", + action="store_true", + ) + parser.add_argument( + "--matrix_dtype", + type=str, + help="matrix_dtype to be used to store matrix", + default="uint8", + choices=MAT_DTYPES, + ) args = parser.parse_args() logger.info(args) @@ -664,24 +867,37 @@ def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpo logger.info("use_cuda: {}".format(use_cuda)) try: - checkpoint = train_neusomatic(args.candidates_tsv, args.validation_candidates_tsv, - args.out, args.checkpoint, args.num_threads, args.batch_size, - args.max_epochs, - args.lr, args.lr_drop_epochs, args.lr_drop_ratio, args.momentum, - args.boost_none, args.none_count_scale, - args.max_load_candidates, args.coverage_thr, args.save_freq, - args.merged_candidates_per_tsv, args.merged_max_num_tsvs, - args.overwrite_merged_tsvs, args.train_split_len, - args.normalize_channels, - args.no_seq_complexity, - args.zero_ann_cols, - args.force_zero_ann_cols, - args.ensemble_custom_header, - args.matrix_dtype, - use_cuda) + checkpoint = train_neusomatic( + args.candidates_tsv, + args.validation_candidates_tsv, + args.out, + args.checkpoint, + args.num_threads, + args.batch_size, + args.max_epochs, + args.lr, + args.lr_drop_epochs, + args.lr_drop_ratio, + args.momentum, + args.boost_none, + args.none_count_scale, + args.max_load_candidates, + args.coverage_thr, + args.save_freq, + args.merged_candidates_per_tsv, + args.merged_max_num_tsvs, + args.overwrite_merged_tsvs, + args.train_split_len, + args.normalize_channels, + args.no_seq_complexity, + args.zero_ann_cols, + args.force_zero_ann_cols, + args.ensemble_custom_header, + args.matrix_dtype, + use_cuda, + ) except Exception as e: logger.error(traceback.format_exc()) logger.error("Aborting!") - logger.error( - "train.py failure on arguments: {}".format(args)) + logger.error("train.py failure on arguments: {}".format(args)) raise e diff --git a/neusomatic/python/utils.py b/neusomatic/python/utils.py index 67ba79c..44edb60 100755 --- a/neusomatic/python/utils.py +++ b/neusomatic/python/utils.py @@ -1,7 +1,7 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # utils.py # Utility functions -#------------------------------------------------------------------------- +# 
------------------------------------------------------------------------- import os import shutil import shlex @@ -14,12 +14,14 @@ import numpy as np -FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s' +FORMAT = "%(levelname)s %(asctime)-15s %(name)-20s %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) -def run_shell_command(command, stdout=None, stderr=None, run_logger=None, no_print=False): +def run_shell_command( + command, stdout=None, stderr=None, run_logger=None, no_print=False +): stdout_fd = open(stdout, "w") if stdout else None stderr_fd = open(stderr, "w") if stderr else None my_logger = logger @@ -30,7 +32,8 @@ def run_shell_command(command, stdout=None, stderr=None, run_logger=None, no_pri if not no_print: my_logger.info("Running command: {}".format(fixed_command)) returncode = subprocess.check_call( - fixed_command, stdout=stdout_fd, stderr=stderr_fd) + fixed_command, stdout=stdout_fd, stderr=stderr_fd + ) if stdout_fd: stdout_fd.close() if stderr_fd: @@ -51,8 +54,9 @@ def concatenate_files(infiles, outfile, check_file_existence=True): def concatenate_vcfs(infiles, outfile, check_file_existence=True, header_string="#"): with open(outfile, "w") as out_fd: # Only keep files which exist - files_to_process = filter(lambda f: f and ( - not check_file_existence or os.path.isfile(f)), infiles) + files_to_process = filter( + lambda f: f and (not check_file_existence or os.path.isfile(f)), infiles + ) for index, infile in enumerate(files_to_process): with open(infile) as in_fd: @@ -69,13 +73,15 @@ def get_chromosomes_order(reference=None, bam=None): chroms_order = {} if reference: with pysam.FastaFile(reference) as fd: - chroms_order = {chrom: chrom_index for chrom_index, - chrom in enumerate(fd.references)} + chroms_order = { + chrom: chrom_index for chrom_index, chrom in enumerate(fd.references) + } if bam: with pysam.AlignmentFile(bam, "rb") as fd: - chroms_order = {chrom: chrom_index for chrom_index, - chrom in enumerate(fd.references)} + chroms_order = { + chrom: chrom_index for chrom_index, chrom in enumerate(fd.references) + } return chroms_order @@ -87,14 +93,20 @@ def safe_read_info_dict(d, field, t=str, default_val=""): def run_bedtools_cmd(command, output_fn=None, run_logger=None): if output_fn is None: tmpfn = tempfile.NamedTemporaryFile( - prefix="tmpbed_", suffix=".bed", delete=False) + prefix="tmpbed_", suffix=".bed", delete=False + ) output_fn = tmpfn.name stderr_file = output_fn + ".stderr" if run_logger is None: run_logger = logger try: - returncode = run_shell_command(command, stdout=output_fn, stderr=stderr_file, - run_logger=run_logger, no_print=True) + returncode = run_shell_command( + command, + stdout=output_fn, + stderr=stderr_file, + run_logger=run_logger, + no_print=True, + ) os.remove(stderr_file) return output_fn except Exception as ex: @@ -107,7 +119,7 @@ def run_bedtools_cmd(command, output_fn=None, run_logger=None): def prob2phred(p, max_phred=100): - '''Convert prob to Phred-scale quality score.''' + """Convert prob to Phred-scale quality score.""" assert 0 <= p <= 1 if p == 1: Q = max_phred @@ -120,7 +132,7 @@ def prob2phred(p, max_phred=100): return Q -def write_tsv_file(tsv_file, records, sep='\t', add_fields=[]): +def write_tsv_file(tsv_file, records, sep="\t", add_fields=[]): with open(tsv_file, "w") as f_o: for x in records: f_o.write(sep.join(map(str, x + add_fields)) + "\n") @@ -135,7 +147,7 @@ def skip_empty(fh, skip_header=True): yield line -def read_tsv_file(tsv_file, 
sep='\t', fields=None): +def read_tsv_file(tsv_file, sep="\t", fields=None): records = [] with open(tsv_file) as i_f: for line in skip_empty(i_f): @@ -153,10 +165,22 @@ def vcf_2_bed(vcf_file, bed_file, add_fields=[], len_ref=False, keep_ref_alt=Tru len_ = 1 if not len_ref else len(x[3]) if keep_ref_alt: f_o.write( - "\t".join(map(str, [x[0], int(x[1]), int(x[1]) + len_, x[3], x[4]] + add_fields)) + "\n") + "\t".join( + map( + str, + [x[0], int(x[1]), int(x[1]) + len_, x[3], x[4]] + + add_fields, + ) + ) + + "\n" + ) else: f_o.write( - "\t".join(map(str, [x[0], int(x[1]), int(x[1]) + len_] + add_fields)) + "\n") + "\t".join( + map(str, [x[0], int(x[1]), int(x[1]) + len_] + add_fields) + ) + + "\n" + ) def bedtools_sort(bed_file, args="", output_fn=None, run_logger=None): @@ -186,9 +210,10 @@ def bedtools_window(a_bed_file, b_bed_file, args="", output_fn=None, run_logger= return output_fn -def bedtools_intersect(a_bed_file, b_bed_file, args="", output_fn=None, run_logger=None): - cmd = "bedtools intersect -a {} -b {} {}".format( - a_bed_file, b_bed_file, args) +def bedtools_intersect( + a_bed_file, b_bed_file, args="", output_fn=None, run_logger=None +): + cmd = "bedtools intersect -a {} -b {} {}".format(a_bed_file, b_bed_file, args) if output_fn is None: output_fn = run_bedtools_cmd(cmd, run_logger=run_logger) else: @@ -206,7 +231,6 @@ def bedtools_slop(bed_file, genome, args="", output_fn=None, run_logger=None): def get_tmp_file(prefix="tmpbed_", suffix=".bed", delete=False): - myfile = tempfile.NamedTemporaryFile( - prefix=prefix, suffix=suffix, delete=delete) + myfile = tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix, delete=delete) myfile = myfile.name return myfile
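
Note on the checkpoint format written by train.py: both torch.save calls in the training diff above persist the network weights together with the preprocessing settings that were in effect at train time (coverage_thr, normalize_channels, no_seq_complexity, zero_ann_cols, ensemble_custom_header, matrix_dtype). Below is a minimal sketch of inspecting such a checkpoint; the file path, tag, and epoch number are illustrative placeholders, and restoring the weights is assumed to follow the usual torch.nn.Module.load_state_dict pattern rather than anything specific to this patch.

import torch

# Path follows the "{out_dir}/models/checkpoint_{tag}_epoch{epoch}.pth"
# pattern used by train.py; the concrete values here are made up.
ckpt = torch.load(
    "out/models/checkpoint_neusomatic_epoch50.pth", map_location="cpu"
)

# Keys written by the torch.save calls in train.py.
print("tag:", ckpt["tag"])
print("epoch:", ckpt["epoch"])
print("coverage_thr:", ckpt["coverage_thr"])
print("normalize_channels:", ckpt["normalize_channels"])
print("no_seq_complexity:", ckpt["no_seq_complexity"])
print("zero_ann_cols:", ckpt["zero_ann_cols"])
print("ensemble_custom_header:", ckpt["ensemble_custom_header"])
print("matrix_dtype:", ckpt["matrix_dtype"])

# The weights themselves would then be restored with something like
# net.load_state_dict(ckpt["state_dict"]), where net is an instance of
# the NeuSomatic network (not shown in this diff).

These saved fields mirror the command-line flags defined in the argparse section above (--coverage_thr, --normalize_channels, --no_seq_complexity, --zero_ann_cols, --ensemble_custom_header, --matrix_dtype), so a checkpoint records how its training matrices were built and can override those settings when a pretrained model is supplied.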