adding cli, other fixes
maximilianmordig committed Oct 31, 2023
1 parent 51be8d9 commit eec3176
Showing 18 changed files with 525 additions and 72 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -7,7 +7,7 @@

This repository provides a simulator for an ONT device that is controlled via the ReadUntil API, either directly or through gRPC, and that can be accelerated (e.g. by a factor of 10 with 512 channels). It takes full-length reads as input, plays them back with suitable gaps in between, and responds to ReadUntil actions.
The code is well-tested with `pytest` and an example usecase combining the simulator with ReadFish and NanoSim is provided.
Access the documentation [here](https://ratschlab.github.io/ont_project/).
Access the documentation [here](https://ratschlab.github.io/sim_read_until/).

See below for a [quick start](#quick-start).

8 changes: 7 additions & 1 deletion pyproject.toml
@@ -72,10 +72,16 @@ readfish = [
]

[project.scripts]
plot_sim_seqsum = "simreaduntil.seqsum_tools.seqsum_plotting:main"
plot_seqsum = "simreaduntil.seqsum_tools.seqsum_plotting:main"
simfasta_to_seqsum = "simreaduntil.simulator.simfasta_to_seqsum:main"
simulator_with_readfish = "simreaduntil.usecase_helpers.simulator_with_readfish:main"
normalize_fasta = "simreaduntil.usecase_helpers.utils:normalize_fasta_cmd"
# helpers for the usecase
usecase_generate_random_reads = "simreaduntil.usecase_helpers.cli_usecase.generate_random_reads:main"
simulator_server_cli = "simreaduntil.usecase_helpers.cli_usecase.simulator_server_cli:main"
usecase_simulator_client_cli = "simreaduntil.usecase_helpers.cli_usecase.simulator_client_cli:main"
sim_plots_cli = "simreaduntil.usecase_helpers.cli_usecase.sim_plots_cli:main"


[project.urls]
"Homepage" = "https://github.com/ratschlab/sim_read_until"
6 changes: 3 additions & 3 deletions src/simreaduntil/seqsum_tools/seqsum_plotting.py
@@ -996,7 +996,7 @@ def main():
"""
CLI entrypoint to create plots from a sequencing summary file
"""
add_comprehensive_stream_handler_to_logger(logging.DEBUG)
add_comprehensive_stream_handler_to_logger(None, logging.DEBUG)

if is_test_mode():
args = argparse.Namespace()
@@ -1009,10 +1009,10 @@
args.paf_file = None
else:
parser = argparse.ArgumentParser(description="Plotting script for sequencing summary file from the simulator")
parser.add_argument("ref_genome_path", type=Path, help="Path to the reference genome")
parser.add_argument("seqsummary_filename", type=Path, help="Path to the sequencing summary file")
parser.add_argument("--nrows", type=int, default=None, help="Number of rows to read from the sequencing summary file")
parser.add_argument("--save_dir", type=Path, default=None, help="Directory to save plots to")
parser.add_argument("--ref_genome_path", type=Path, help="Path to the reference genome", default=None)
parser.add_argument("--save_dir", type=Path, default=None, help="Directory to save plots to; display if None")
parser.add_argument("--cov_thresholds", type=str, default="1,2,3,4,5,6", help="Comma-separated list of target coverages to plot; set to '' if non-NanoSim reads")
parser.add_argument("--targets", type=str, default=None, help="if provided, a comma-separated list, e.g. 'chr1,chr2'. Creates two groups on the plot, opposing targets to the rest")
parser.add_argument("--cov_every", type=int, default=1, help="Compute coverage every x reads")
3 changes: 3 additions & 0 deletions src/simreaduntil/simulator/README.md
@@ -40,6 +40,9 @@ mkdir protos_generated
# replace imports so they work with the project structure (the "simreaduntil" package must be on the python path)
python -m grpc_tools.protoc -Iprotos/ --python_out=protos_generated/ --pyi_out=protos_generated/ --grpc_python_out=protos_generated/ protos/ont_device.proto && \
sed -i -E "s%import (.*)_pb2 as%import simreaduntil.simulator.protos_generated.\1_pb2 as%g" protos_generated/ont_device_pb2_grpc.py
# todo: check
cd src && python -m grpc_tools.protoc -Isimreaduntil/simulator/protos/ --python_out=simreaduntil/simulator/protos_generated/ --pyi_out=simreaduntil/simulator/protos_generated/ --grpc_python_out=simreaduntil/simulator/protos_generated/ simreaduntil/simulator/protos/ont_device.proto
```
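Whichever invocation is used, the regenerated stubs should import under the rewritten package path, which is exactly what the `sed` line enforces. A small smoke test, assuming the `simreaduntil` package itself is importable:

```python
# Run after the protoc/sed commands above.
from simreaduntil.simulator.protos_generated import ont_device_pb2, ont_device_pb2_grpc

req = ont_device_pb2.StartRequest(acceleration_factor=10.0, update_method="realtime",
                                  log_interval=10, stop_if_no_reads=True)
print(req)
print(ont_device_pb2_grpc.__name__)
```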


4 changes: 2 additions & 2 deletions src/simreaduntil/simulator/protos/ont_device.proto
@@ -56,7 +56,7 @@ message ReadActionsRequest {
message StopReceivingAction {}

message UnblockAction {
double unblock_duration = 1;
double unblock_duration = 1; // set to <0 to use the device default (an unset proto3 field arrives as 0)
}

uint32 channel = 1;
@@ -86,7 +86,7 @@ message ActionResultResponse {
}

message StartRequest {
double acceleration_factor = 1;
double acceleration_factor = 1; // <= 0 interpreted as 1.0
string update_method = 2;
uint32 log_interval = 3;
bool stop_if_no_reads = 4;
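Both comments encode the same workaround: proto3 scalars have no "unset" state, so a negative `unblock_duration` means "use the device default" and a non-positive `acceleration_factor` (including the implicit 0 of an unset field) means 1.0. A hedged construction example with the generated message classes:

```python
from simreaduntil.simulator.protos_generated import ont_device_pb2

# unblock_duration < 0 is decoded server-side as "use the device default"
default_unblock = ont_device_pb2.ReadActionsRequest.Action.UnblockAction(unblock_duration=-1)

# leaving acceleration_factor at the proto3 default of 0 is interpreted as 1.0 by the server
start_req = ont_device_pb2.StartRequest(update_method="realtime", log_interval=10,
                                        stop_if_no_reads=True)
assert start_req.acceleration_factor == 0.0
```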
10 changes: 6 additions & 4 deletions src/simreaduntil/simulator/simulator.py
@@ -100,7 +100,7 @@ def _check_channels_available(self, channel_subset: List[int]=None):
if not set(channel_subset).issubset(self._onebased_available_channels()):
raise InexistentChannelsException(message=f"Tried to use {channel_subset}, but only channels {self._onebased_available_channels()} are available (channels are 1-based!)")

def start(self):
def start(self, **kwargs):
"""
Start the sequencing.
"""
@@ -141,7 +141,7 @@ def get_basecalled_read_chunks(self, batch_size=None, channel_subset=None) -> Li
"""
raise NotImplementedError()

def get_action_results(self) -> List[Tuple[Any, float, int, str, Any]]:
def get_action_results(self, **kwargs) -> List[Tuple[Any, float, int, str, Any]]:
"""
Get new results of actions that were performed with unblock and stop_receiving (mux scans etc not included)
@@ -409,7 +409,7 @@ def start(self, *args, **kwargs):

return True

def _forward_sim_loop(self, acceleration_factor=1, update_method="realtime", log_interval: int=10, stop_if_no_reads=True, **kwargs):
def _forward_sim_loop(self, acceleration_factor=1.0, update_method="realtime", log_interval: int=10, stop_if_no_reads=True, **kwargs):
"""
Helper method launched by .start() to forward the simulation.
@@ -444,7 +444,7 @@ def _log():
logger.debug(f"Simulation has been running for real {cur_ns_time() - t_real_start:.2f} seconds with acceleration factor {acceleration_factor:.2f} (t_sim={t_sim:.2f}, i={i})")

combined_channel_statuses = self.get_channel_stats(combined=True)
logger.debug("\nCombined channel status: " + str(combined_channel_statuses))
logger.info("\nCombined channel status: " + str(combined_channel_statuses))

if self._channel_status_filename is not None:
print(combined_channel_statuses.get_table_line(), file=self._channel_status_fh)
@@ -618,13 +618,15 @@ def unblock_read(self, read_channel, read_id, unblock_duration=None) -> Optional
self._check_channels_available([read_channel])
action_res = self._channels[read_channel-1].unblock(unblock_duration=unblock_duration, read_id=read_id)
self._action_results.append((read_id, self._channels[read_channel-1].t, read_channel, ActionType.Unblock, action_res))
logger.info(f"Unblocking read {read_id} on channel {read_channel}, result: {action_res.to_str()}")
return action_res

def stop_receiving_read(self, read_channel, read_id) -> Optional[StoppedReceivingResponse]:
"""Stop receiving from read"""
self._check_channels_available([read_channel])
action_res = self._channels[read_channel-1].stop_receiving(read_id=read_id)
self._action_results.append((read_id, self._channels[read_channel-1].t, read_channel, ActionType.StopReceiving, action_res))
logger.info(f"Stopping receiving from read {read_id} on channel {read_channel}, result: {action_res.to_str()}")
return action_res

def run_mux_scan(self, t_duration: float) -> int:
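The tuples appended to `_action_results` in `unblock_read` and `stop_receiving_read` are what `get_action_results()` later hands back. A small consumption sketch; the `device` object and the import location of `ActionType` are assumptions based on this file:

```python
from simreaduntil.simulator.simulator import ActionType  # assumed import location

def print_action_results(device):
    # Each entry mirrors the appends above:
    # (read_id, channel time, channel number, ActionType, per-action response)
    for read_id, t, channel, action_type, response in device.get_action_results():
        kind = "unblock" if action_type == ActionType.Unblock else "stop_receiving"
        print(f"t={t:.2f} ch={channel} {kind} read={read_id} -> {response}")
```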
40 changes: 31 additions & 9 deletions src/simreaduntil/simulator/simulator_client.py
@@ -91,12 +91,12 @@ def unique_id(self) -> str:
"""
return self._stub.GetServerInfo(ont_device_pb2.EmptyRequest()).unique_id

def start(self, acceleration_factor: float = 1.0):
def start(self, acceleration_factor: float = 1.0, update_method: str ="realtime", log_interval: int=10, stop_if_no_reads: bool = True):
"""
Start the sequencing.
"""
self._check_connected()
return self._stub.StartSim(ont_device_pb2.StartRequest(acceleration_factor=acceleration_factor)).value
return self._stub.StartSim(ont_device_pb2.StartRequest(acceleration_factor=acceleration_factor, update_method=update_method, log_interval=log_interval, stop_if_no_reads=stop_if_no_reads)).value

def stop(self):
"""
@@ -110,7 +110,7 @@ def run_mux_scan(self, t_duration: float):
Run mux scan
"""
self._check_connected()
return self._stub.RunMuxScan(ont_device_pb2.RunMuxScanRequest()).nb_reads_rejected
return self._stub.RunMuxScan(ont_device_pb2.RunMuxScanRequest(t_duration=t_duration)).nb_reads_rejected

@property
def is_running(self):
@@ -129,6 +129,7 @@ def get_basecalled_read_chunks(self, batch_size=None, channel_subset=None):
channels = None
else:
channels = ont_device_pb2.BasecalledChunksRequest.Channels(value=channel_subset)
# batch_size = None does not set it, so proto3 assigns it a value of 0
for chunk in self._stub.GetBasecalledChunks(ont_device_pb2.BasecalledChunksRequest(batch_size=batch_size, channels=channels)):
yield (chunk.channel, chunk.read_id, chunk.seq, chunk.quality_seq, chunk.estimated_ref_len_so_far)

@@ -145,19 +145,40 @@ def unblock_read(self, read_channel, read_id, unblock_duration=None):
"""
Unblock read_id on channel; returns whether the action was performed (not performed if the read was already over)
"""
self._check_connected()
return self._stub.PerformActions(ont_device_pb2.ReadActionsRequest(actions=[
ont_device_pb2.ReadActionsRequest.Action(channel=read_channel, read_id=read_id, unblock=ont_device_pb2.ReadActionsRequest.Action.UnblockAction(unblock_duration=unblock_duration))
])).succeeded[0]
# self._check_connected()
# return self._stub.PerformActions(ont_device_pb2.ReadActionsRequest(actions=[
# ont_device_pb2.ReadActionsRequest.Action(channel=read_channel, read_id=read_id, unblock=ont_device_pb2.ReadActionsRequest.Action.UnblockAction(unblock_duration=unblock_duration if unblock_duration is not None else -1))
# ])).succeeded[0]
return self.unblock_read_batch([(read_channel, read_id)], unblock_duration=unblock_duration)[0]

def stop_receiving_read(self, read_channel, read_id):
"""
Stop receiving read_id on channel; returns whether the action was performed (not performed if the read was already over)
"""
# self._check_connected()
# return self._stub.PerformActions(ont_device_pb2.ReadActionsRequest(actions=[
# ont_device_pb2.ReadActionsRequest.Action(channel=read_channel, read_id=read_id, stop_further_data=ont_device_pb2.ReadActionsRequest.Action.StopReceivingAction()),
# ])).succeeded[0]
return self.stop_receiving_read_batch([(read_channel, read_id)])[0]

# batch methods
def unblock_read_batch(self, channel_and_ids, unblock_duration=None):
"""
Unblock a batch of reads on channel; returns whether the actions were performed (not performed if the read was already over)
"""
self._check_connected()
return self._stub.PerformActions(ont_device_pb2.ReadActionsRequest(actions=[
ont_device_pb2.ReadActionsRequest.Action(channel=read_channel, read_id=read_id, unblock=ont_device_pb2.ReadActionsRequest.Action.UnblockAction(unblock_duration=unblock_duration if unblock_duration is not None else -1))
for (read_channel, read_id) in channel_and_ids])).succeeded

def stop_receiving_read_batch(self, channel_and_ids):
"""
Stop receiving a batch of reads on channel; returns whether the actions were performed (not performed if the read was already over)
"""
self._check_connected()
return self._stub.PerformActions(ont_device_pb2.ReadActionsRequest(actions=[
ont_device_pb2.ReadActionsRequest.Action(channel=read_channel, read_id=read_id, stop_further_data=ont_device_pb2.ReadActionsRequest.Action.StopReceivingAction()),
])).succeeded[0]
ont_device_pb2.ReadActionsRequest.Action(channel=read_channel, read_id=read_id, stop_further_data=ont_device_pb2.ReadActionsRequest.Action.StopReceivingAction())
for (read_channel, read_id) in channel_and_ids])).succeeded

@property
def mk_run_dir(self):
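Together, the extended `start()` signature, the chunk stream and the new batch calls are enough for a ReadUntil-style control loop. A rough sketch relying only on the methods shown in this file; constructing and connecting the client is not part of this hunk, so it is left to the caller:

```python
import time

def reject_all_reads(client, batch_size=50):
    """Toy control loop: unblock every read chunk the simulator produces.

    `client` is assumed to be an already-connected instance of the gRPC client in this file.
    """
    client.start(acceleration_factor=10.0, update_method="realtime",
                 log_interval=10, stop_if_no_reads=True)
    try:
        while client.is_running:
            chunks = list(client.get_basecalled_read_chunks(batch_size=batch_size))
            if not chunks:
                time.sleep(0.1)
                continue
            # one gRPC call for the whole batch instead of one call per read
            client.unblock_read_batch([(channel, read_id)
                                       for (channel, read_id, _seq, _qual, _ref_len) in chunks])
    finally:
        client.stop()
```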
10 changes: 7 additions & 3 deletions src/simreaduntil/simulator/simulator_server.py
@@ -98,7 +98,8 @@ def PerformActions(self, request, context):
for action_desc in request.actions:
channel, read_id = action_desc.channel, action_desc.read_id
if action_desc.WhichOneof("action") == "unblock":
res.append(self.device.unblock_read(channel, read_id=read_id, unblock_duration=action_desc.unblock.unblock_duration))
unblock_duration = None if action_desc.unblock.unblock_duration < 0 else action_desc.unblock.unblock_duration
res.append(self.device.unblock_read(channel, read_id=read_id, unblock_duration=unblock_duration))
else:
res.append(self.device.stop_receiving_read(channel, read_id=read_id)) #todo2: current conversion from enum 0,1,2 to bool is not ideal
return ont_device_pb2.ActionResultImmediateResponse(succeeded=res)
@@ -112,7 +113,8 @@ def GetActionResults(self, request, context):
def GetBasecalledChunks(self, request, context):
# channel_subset=None on request side means that field was not set
channel_subset = request.channels.value if request.HasField("channels") else None
for (channel, read_id, seq, quality_seq, estimated_ref_len_so_far) in self.device.get_basecalled_read_chunks(batch_size=request.batch_size, channel_subset=channel_subset):
batch_size = request.batch_size if request.batch_size > 0 else None
for (channel, read_id, seq, quality_seq, estimated_ref_len_so_far) in self.device.get_basecalled_read_chunks(batch_size=batch_size, channel_subset=channel_subset):
yield ont_device_pb2.BasecalledReadChunkResponse(channel=channel, read_id=read_id, seq=seq, quality_seq=quality_seq, estimated_ref_len_so_far=estimated_ref_len_so_far)

@print_nongen_exceptions
@@ -122,7 +124,8 @@ def StartSim(self, request, context):
Returns: whether it succeeded (i.e. if simulation was not running)
"""
return ont_device_pb2.BoolResponse(value=self.device.start(request.acceleration_factor))
acceleration_factor = request.acceleration_factor if request.acceleration_factor > 0 else 1.0
return ont_device_pb2.BoolResponse(value=self.device.start(acceleration_factor=acceleration_factor, update_method=request.update_method, log_interval=request.log_interval, stop_if_no_reads=request.stop_if_no_reads))

@print_nongen_exceptions
# stop simulation, returns whether it succeeded (i.e. if simulation was running)
@@ -131,6 +134,7 @@ def StopSim(self, request, context):

@print_nongen_exceptions
def RunMuxScan(self, request, context):
assert request.HasField("t_duration"), "t_duration must be set"
return ont_device_pb2.MuxScanStartedInfo(value=self.device.run_mux_scan(t_duration=request.t_duration))

@print_nongen_exceptions
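The decoding above is the server half of a simple sentinel convention: the client encodes `None` as `-1`, an unset proto3 scalar arrives as `0`, and the server maps such out-of-range values back to `None`. Spelled out as a standalone sketch mirroring `unblock_read_batch` (client) and `PerformActions` (server):

```python
def encode_unblock_duration(unblock_duration):
    # client side: None -> -1 so "use the device default" survives proto3 serialization
    return unblock_duration if unblock_duration is not None else -1

def decode_unblock_duration(wire_value):
    # server side: anything < 0 -> None, i.e. fall back to the device default
    return None if wire_value < 0 else wire_value

assert decode_unblock_duration(encode_unblock_duration(None)) is None
assert decode_unblock_duration(encode_unblock_duration(2.5)) == 2.5
```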
3 changes: 3 additions & 0 deletions src/simreaduntil/usecase_helpers/cli_usecase/__init__.py
@@ -0,0 +1,3 @@
"""
Scripts to support the CLI usecase
"""
70 changes: 70 additions & 0 deletions src/simreaduntil/usecase_helpers/cli_usecase/generate_random_reads.py
@@ -0,0 +1,70 @@

import argparse
import itertools
import logging
import os
import sys
from Bio.Seq import Seq
from Bio import SeqIO
from pathlib import Path
from tqdm import tqdm
from simreaduntil.shared_utils.debugging_helpers import is_test_mode
from simreaduntil.shared_utils.logging_utils import add_comprehensive_stream_handler_to_logger, print_logging_levels, setup_logger_simple
from simreaduntil.shared_utils.utils import print_args
from simreaduntil.simulator.readswriter import SingleFileReadsWriter
from simreaduntil.usecase_helpers.utils import random_nanosim_reads_gen

logger = setup_logger_simple(__name__)
"""module logger"""


def parse_args(args=None):
parser = argparse.ArgumentParser(description="Generate dummy reads, only for testing purposes")
parser.add_argument("reads_file", type=Path, help="Path to write reads to")
parser.add_argument("--num_reads", type=int, help="Number of reads to generate", default=10_000)
parser.add_argument("--length_range", type=str, help="Length range of reads", default="5_000,10_000")
parser.add_argument("--overwrite", action="store_true", help="Overwrite existing reads file")

args = parser.parse_args(args)
print_args(args, logger=logger)

return args

def main():
log_level = logging.DEBUG
logging.getLogger(__name__).setLevel(log_level)
add_comprehensive_stream_handler_to_logger(None, level=log_level)
# print_logging_levels()

if is_test_mode():
os.chdir("server_client_cli_example")
args = parse_args(args=["test.fasta", "--overwrite"]) #todo
else:
args = parse_args()

reads_file = args.reads_file
overwrite = args.overwrite
num_reads = args.num_reads
assert num_reads > 0, f"num_reads {num_reads} must be > 0"
length_range = args.length_range
length_range = tuple(map(int, length_range.split(",")))
assert len(length_range) == 2, f"length_range {length_range} must have length 2"
assert length_range[0] < length_range[1], f"length_range {length_range} must be increasing"

assert reads_file.suffix == ".fasta", f"reads_file '{reads_file}' must end with '.fasta'"
if reads_file.exists():
if overwrite:
# logger.warning(f"Overwriting existing reads file '{reads_file}'")
reads_file.unlink()
else:
raise FileExistsError(f"reads_file '{reads_file}' already exists, use --overwrite to overwrite")

reads_gen = random_nanosim_reads_gen(length_range=length_range)
with open(reads_file, "w") as fh:
writer = SingleFileReadsWriter(fh)
for (read_id, seq) in tqdm(itertools.islice(reads_gen, num_reads), total=num_reads, desc="Writing read: "):
writer.write_read(SeqIO.SeqRecord(id=str(read_id), seq=Seq(seq)))
logger.info(f"Done writing reads to file '{reads_file}'")

if __name__ == "__main__":
main()
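This script backs the `usecase_generate_random_reads` entry point declared in pyproject.toml; outside test mode it simply parses `sys.argv`, so it can also be driven from Python. A hedged example with placeholder arguments:

```python
import sys
from simreaduntil.usecase_helpers.cli_usecase.generate_random_reads import main

# Equivalent to:
#   usecase_generate_random_reads random_reads.fasta --num_reads 1000 --length_range 4000,8000 --overwrite
# File name and sizes are placeholders.
sys.argv = ["usecase_generate_random_reads", "random_reads.fasta",
            "--num_reads", "1000", "--length_range", "4000,8000", "--overwrite"]
main()
```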