Skip to content

Commit

Permalink
tsize and metrics in packethook; fixes msoulier#103
Browse files Browse the repository at this point in the history
now shows remaining amount to transfer in progress messages
if the client was started with --tsize

also fixes partftpy_client so it sends tsize on upload
  • Loading branch information
9001 committed Jun 16, 2024
1 parent 20cc6a9 commit f2e10d3
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 22 deletions.
16 changes: 10 additions & 6 deletions partftpy/TftpClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def download(
"""This method initiates a tftp download from the configured remote
host, requesting the filename passed. It writes the file to output,
which can be a file-like object or a path to a local file. If a
packethook is provided, it must be a function that takes a single
parameter, which will be a copy of each DAT packet received in the
form of a TftpPacketDAT object. The timeout parameter may be used to
packethook is provided, it must be a function that takes two
parameters, the first being a copy of each packet received in the
form of a TftpPacket object, and the second being the TftpContext
for this transfer, which can be inspected for more accurate statistics,
progress estimates and such. The timeout parameter may be used to
override the default SOCK_TIMEOUT setting, which is the amount of time
that the client will wait for a receive packet to arrive.
The retires parameter may be used to override the default DEF_TIMEOUT_RETRIES
Expand Down Expand Up @@ -108,9 +110,11 @@ def upload(
"""This method initiates a tftp upload to the configured remote host,
uploading the filename passed. It reads the file from input, which
can be a file-like object or a path to a local file. If a packethook
is provided, it must be a function that takes a single parameter,
which will be a copy of each DAT packet sent in the form of a
TftpPacketDAT object. The timeout parameter may be used to override
is provided, it must be a function that takes two parameters,
the first being a copy of each packet received in the form of
a TftpPacket object, and the second being the TftpContext for
this transfer, which can be inspected for more accurate statistics,
progress estimates, etc. The timeout parameter may be used to override
the default SOCK_TIMEOUT setting, which is the amount of time that
the client will wait for a DAT packet to be ACKd by the server.
The retires parameter may be used to override the default DEF_TIMEOUT_RETRIES
Expand Down
18 changes: 14 additions & 4 deletions partftpy/TftpContexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from .TftpShared import *
from .TftpStates import *

if TYPE_CHECKING:
from typing import Optional

log = logging.getLogger("partftpy.TftpContext")

###############################################################################
Expand All @@ -35,10 +38,13 @@ class TftpMetrics(object):
"""A class representing metrics of the transfer."""

def __init__(self):
# Bytes transferred
# Set by context if available
self.tsize = 0
# Transfer counters
self.bytes = 0
# Bytes re-sent
self.packets = 0
self.resent_bytes = 0
self.resent_packets = 0
# Duplicate packets received
self.dups = {}
self.dupcount = 0
Expand Down Expand Up @@ -134,7 +140,7 @@ def __init__(
# FIXME: does this belong in metrics?
self.last_update = 0
# The last packet we sent, if applicable, to make resending easy.
self.last_pkt = None
self.last_pkt = None # type: Optional[TftpPacket]
# Count the number of retry attempts.
self.retry_count = 0
# Flag to signal timeout error when waiting for ACK of the current block
Expand Down Expand Up @@ -254,7 +260,7 @@ def cycle(self):
# kinds of packets. This way, the client is privy to things like
# negotiated options.
if self.packethook:
self.packethook(recvpkt)
self.packethook(recvpkt, self)

# And handle it, possibly changing state.
self.state = self.state.handle(recvpkt, raddress, rport)
Expand Down Expand Up @@ -367,6 +373,10 @@ def start(self):
log.info(" filename -> %s", self.file_to_transfer)
log.info(" options -> %s", self.options)

tsize = self.options.get("tsize")
if tsize:
self.metrics.tsize = tsize

self.metrics.start_time = time.time()
log.debug("Set metrics.start_time to %s", self.metrics.start_time)

Expand Down
18 changes: 15 additions & 3 deletions partftpy/TftpStates.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from .TftpPacketTypes import *
from .TftpShared import *

if TYPE_CHECKING:
from .TftpContexts import TftpContext


log = logging.getLogger("partftpy.TftpStates")

###############################################################################
Expand All @@ -35,7 +39,7 @@ def __init__(self, context):
"""Constructor for setting up common instance variables. The involved
file object is required, since in tftp there's always a file
involved."""
self.context = context
self.context = context # type: TftpContext

def handle(self, pkt, raddress, rport):
"""An abstract method for handling a packet. It is expected to return
Expand All @@ -50,6 +54,9 @@ def handleOACK(self, pkt):
log.info("Successful negotiation of options")
# Set options to OACK options
self.context.options = pkt.options
tsize = pkt.options.get("tsize")
if tsize:
self.context.metrics.tsize = tsize
for k, v in self.context.options.items():
log.info(" %s = %s", k, v)
else:
Expand Down Expand Up @@ -112,6 +119,7 @@ def sendDAT(self):
dat.data = buffer
dat.blocknumber = blocknumber
self.context.metrics.bytes += len(dat.data)
self.context.metrics.packets += 1
# Testing hook
if NETWORK_UNRELIABILITY > 0 and random.randrange(NETWORK_UNRELIABILITY) == 0:
log.warning("Skipping DAT packet %d for testing", dat.blocknumber)
Expand All @@ -122,7 +130,7 @@ def sendDAT(self):
)
self.context.metrics.last_dat_time = time.time()
if self.context.packethook:
self.context.packethook(dat)
self.context.packethook(dat, self.context)
self.context.last_pkt = dat
return finished

Expand Down Expand Up @@ -175,6 +183,7 @@ def resendLast(self):
assert self.context.last_pkt is not None
log.warning("Resending packet %s on sessions %s", self.context.last_pkt, self)
self.context.metrics.resent_bytes += len(self.context.last_pkt.buffer)
self.context.metrics.resent_packets += 1
self.context.metrics.add_dup(self.context.last_pkt)
sendto_port = self.context.tidport
if not sendto_port:
Expand All @@ -186,9 +195,10 @@ def resendLast(self):
self.context.last_pkt.encode().buffer, (self.context.host, sendto_port)
)
if self.context.packethook:
self.context.packethook(self.context.last_pkt)
self.context.packethook(self.context.last_pkt, self.context)

def handleDat(self, pkt):
# type: (TftpPacket) -> TftpState
"""This method handles a DAT packet during a client download, or a
server upload."""
log.debug("Handling DAT packet - block %d", pkt.blocknumber)
Expand All @@ -202,6 +212,7 @@ def handleDat(self, pkt):
log.debug("Writing %d bytes to output file", len(pkt.data))
self.context.fileobj.write(pkt.data)
self.context.metrics.bytes += len(pkt.data)
self.context.metrics.packets += 1
# Check for end-of-file, any less than full data packet.
if len(pkt.data) < self.context.options["blksize"]:
log.info("End of file detected")
Expand Down Expand Up @@ -354,6 +365,7 @@ def handle(self, pkt, raddress, rport):
tsize = str(self.context.fileobj.tell())
self.context.fileobj.seek(0, 0)
self.context.options["tsize"] = tsize
self.context.metrics.tsize = tsize

if sendoack:
# Note, next_block is 0 here since that's the proper
Expand Down
42 changes: 35 additions & 7 deletions partftpy/bin/partftpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import os
import socket
import sys
import threading
import time
from optparse import OptionParser

import partftpy.TftpPacketTypes
from partftpy.TftpClient import TftpClient
from partftpy.TftpShared import TftpException
from partftpy.TftpContexts import TftpContext

log = logging.getLogger("partftpy")
log.setLevel(logging.INFO)
Expand Down Expand Up @@ -112,18 +115,40 @@ def main():

class Progress(object):
def __init__(self, out):
self.progress = 0
self.pkts = 0
self.out = out
self.metrics = None
self.thr = threading.Thread(target=self._print_progress)
self.thr.daemon = True
self.thr.start()

def progresshook(self, pkt):
def progresshook(self, pkt, ctx):
# type: (bytes, TftpContext) -> None
if isinstance(pkt, partftpy.TftpPacketTypes.TftpPacketDAT):
self.pkts += 1
self.progress += len(pkt.data)
self.out("Transferred %d bytes, %d pkts" % (self.progress, self.pkts))
self.metrics = ctx.metrics
elif isinstance(pkt, partftpy.TftpPacketTypes.TftpPacketOACK):
self.out("Received OACK, options are: %s" % pkt.options)

def _print_progress(self):
while True:
time.sleep(0.5)
if not self.metrics:
continue
metrics = self.metrics
self.metrics = None

pkts = metrics.packets
nbytes = metrics.bytes
left = metrics.tsize - nbytes
if left < 0:
self.out("Transferred %d pkts, %d bytes", pkts, nbytes)
else:
self.out(
"Transferred %d pkts, %d bytes, %d bytes left",
pkts,
nbytes,
left,
)

if options.debug:
log.setLevel(logging.DEBUG)
# increase the verbosity of the formatter
Expand All @@ -139,8 +164,11 @@ def progresshook(self, pkt):
tftp_options = {}
if options.blksize:
tftp_options["blksize"] = int(options.blksize)
if options.tsize:
if options.tsize and options.download:
tftp_options["tsize"] = 0
if options.tsize and options.upload and options.input != "-":
fn = options.input or options.upload
tftp_options["tsize"] = os.path.getsize(fn)

fam = socket.AF_INET6 if ":" in options.host else socket.AF_INET

Expand Down
4 changes: 2 additions & 2 deletions t/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def testServerDownloadWithStopNow(self, output="/tmp/out"):
stopped_early = False
time.sleep(1)

def delay_hook(pkt):
def delay_hook(pkt, ctx):
time.sleep(0.005) # 5ms

client.download("640KBFILE", output, delay_hook)
Expand Down Expand Up @@ -539,7 +539,7 @@ def testServerDownloadWithStopNotNow(self, output="/tmp/out"):
# parent - let the server start
time.sleep(1)

def delay_hook(pkt):
def delay_hook(pkt, ctx):
time.sleep(0.005) # 5ms

client.download("640KBFILE", output, delay_hook)
Expand Down

0 comments on commit f2e10d3

Please sign in to comment.