Skip to content

Commit

Permalink
read/unmarshal: avoid exceptions, use Error instead
Browse files Browse the repository at this point in the history
  • Loading branch information
robur-team committed Jan 3, 2024
1 parent 58c89b1 commit cb76332
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 94 deletions.
211 changes: 121 additions & 90 deletions lib/tar.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*)

type error = [`Eof | `Checksum_mismatch | `Corrupt_pax_header | `Zero_block]
type error = [ `Eof | `Checksum_mismatch | `Corrupt_pax_header | `Zero_block | `Unmarshal of string ]

let pp_error ppf = function
| `Eof -> Format.fprintf ppf "end of file"
| `Checksum_mismatch -> Format.fprintf ppf "checksum mismatch"
| `Corrupt_pax_header -> Format.fprintf ppf "corrupt PAX header"
| `Zero_block -> Format.fprintf ppf "zero block"
| `Unmarshal e -> Format.fprintf ppf "unmarshal %s" e

let ( let* ) = Result.bind

(** Process and create tar file headers *)
module Header = struct
Expand All @@ -32,38 +35,41 @@ module Header = struct
String.(trim (map (function '\000' -> ' ' | x -> x) s))

(** Unmarshal an integer field (stored as 0-padded octal) *)
let unmarshal_int (x: string) : int =
let unmarshal_int x =
let tmp = "0o0" ^ (trim_numerical x) in
try
int_of_string tmp
Ok (int_of_string tmp)
with Failure msg ->
failwith (Printf.sprintf "%s: failed to parse integer %S" msg tmp)
Error (`Unmarshal (Printf.sprintf "%s: failed to parse integer %S" msg tmp))

(** Unmarshal an int64 field (stored as 0-padded octal) *)
let unmarshal_int64 (x: string) : int64 =
let unmarshal_int64 x =
let tmp = "0o0" ^ (trim_numerical x) in
Int64.of_string tmp
try
Ok (Int64.of_string tmp)
with Failure msg ->
Error (`Unmarshal (Printf.sprintf "%s: failed to parse int64 %S" msg tmp))

(** Unmarshal a string *)
let unmarshal_string (x: string) : string =
let unmarshal_string x =
try
let first_0 = String.index x '\000' in
String.sub x 0 first_0
with
Not_found -> x (* TODO should error *)
Ok (String.sub x 0 first_0)
with Not_found ->
Ok x

(** Marshal an integer field of size 'n' *)
let marshal_int (x: int) (n: int) =
let marshal_int x n =
let octal = Printf.sprintf "%0*o" (n - 1) x in
octal ^ "\000" (* space or NULL allowed *)

(** Marshal an int64 field of size 'n' *)
let marshal_int64 (x: int64) (n: int) =
let marshal_int64 x n =
let octal = Printf.sprintf "%0*Lo" (n - 1) x in
octal ^ "\000" (* space or NULL allowed *)

(** Marshal an string field of size 'n' *)
let marshal_string (x: string) (n: int) =
let marshal_string x n =
if String.length x < n then
let bytes = Bytes.make n '\000' in
Bytes.blit_string x 0 bytes 0 (String.length x);
Expand All @@ -74,11 +80,14 @@ module Header = struct
(** Unmarshal a pax Extended Header File time
It can contain a <period> ( '.' ) for sub-second granularity, that we ignore.
https://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html#tag_20_92_13_05 *)
let unmarshal_pax_time (x:string) : int64 =
match String.split_on_char '.' x with
| [seconds] -> Int64.of_string seconds
| [seconds; _subseconds] -> Int64.of_string seconds
| _ -> raise (Failure "Wrong pax Extended Header File Times format")
let unmarshal_pax_time x =
try
match String.split_on_char '.' x with
| [seconds] -> Ok (Int64.of_string seconds)
| [seconds; _subseconds] -> Ok (Int64.of_string seconds)
| _ -> raise (Failure "Wrong pax Extended Header File time format (at most one . allowed)")
with Failure msg ->
Error (`Unmarshal (Printf.sprintf "Failed to parse pax time %S (%s)" x msg))

let hdr_file_name_off = 0
let sizeof_hdr_file_name = 100
Expand Down Expand Up @@ -387,7 +396,19 @@ module Header = struct
user_id; uname }
| None -> extended

let unmarshal ~(global: t option) (c: Cstruct.t) : t =
let decode_int x =
try
Ok (int_of_string x)
with Failure msg ->
Error (`Unmarshal (Printf.sprintf "%s: failed to parse integer %S" msg x))

let decode_int64 x =
try
Ok (Int64.of_string x)
with Failure msg ->
Error (`Unmarshal (Printf.sprintf "%s: failed to parse integer %S" msg x))

let unmarshal ~(global: t option) c =
(* "%d %s=%s\n", <length>, <keyword>, <value> with constraints that
- the <keyword> cannot contain an equals sign
- the <length> is the number of octets of the record, including \n
Expand All @@ -398,48 +419,59 @@ module Header = struct
then None
else if Cstruct.get_char buffer i = char
then Some i
else loop (i + 1) in
loop 0 in
else loop (i + 1)
in
loop 0
in
let rec loop remaining =
if Cstruct.length remaining = 0
then []
then Ok []
else begin
(* Find the space, then decode the length *)
match find remaining ' ' with
| None -> failwith "Failed to decode pax extended header record"
| None -> Error (`Unmarshal "Failed to decode pax extended header record")
| Some i ->
let length = int_of_string @@ Cstruct.to_string @@ Cstruct.sub remaining 0 i in
let record = Cstruct.sub remaining 0 length in
let remaining = Cstruct.shift remaining length in
begin match find record '=' with
| None -> failwith "Failed to decode pax extended header record"
| None -> Error (`Unmarshal "Failed to decode pax extended header record")
| Some j ->
let keyword = Cstruct.to_string @@ Cstruct.sub record (i + 1) (j - i - 1) in
let v = Cstruct.to_string @@ Cstruct.sub record (j + 1) (Cstruct.length record - j - 2) in
(keyword, v) :: (loop remaining)
let* rem = loop remaining in
Ok ((keyword, v) :: rem)
end
end in
let pairs = loop c in
end
in
let* pairs = loop c in
let option name f =
if List.mem_assoc name pairs
then Some (f (List.assoc name pairs))
else None in
then
let* v = f (List.assoc name pairs) in
Ok (Some v)
else
Ok None
in
(* integers are stored as decimal, not octal here *)
let access_time = option "atime" unmarshal_pax_time in
let charset = option "charset" unmarshal_string in
let comment = option "comment" unmarshal_string in
let group_id = option "gid" int_of_string in
let gname = option "group_name" unmarshal_string in
let header_charset = option "hdrcharset" unmarshal_string in
let link_path = option "linkpath" unmarshal_string in
let mod_time = option "mtime" unmarshal_pax_time in
let path = option "path" unmarshal_string in
let file_size = option "size" Int64.of_string in
let user_id = option "uid" int_of_string in
let uname = option "uname" unmarshal_string in
{ access_time; charset; comment; group_id; gname;
header_charset; link_path; mod_time; path; file_size;
user_id; uname } |> merge global
let* access_time = option "atime" unmarshal_pax_time in
let* charset = option "charset" unmarshal_string in
let* comment = option "comment" unmarshal_string in
let* group_id = option "gid" decode_int in
let* gname = option "group_name" unmarshal_string in
let* header_charset = option "hdrcharset" unmarshal_string in
let* link_path = option "linkpath" unmarshal_string in
let* mod_time = option "mtime" unmarshal_pax_time in
let* path = option "path" unmarshal_string in
let* file_size = option "size" decode_int64 in
let* user_id = option "uid" decode_int in
let* uname = option "uname" unmarshal_string in
let g =
{ access_time; charset; comment; group_id; gname;
header_charset; link_path; mod_time; path; file_size;
user_id; uname }
in
Ok (merge global g)

end

Expand Down Expand Up @@ -490,12 +522,6 @@ module Header = struct
(** A blank header block (two of these in series mark the end of the tar) *)
let zero_block = Cstruct.create length

(** [allzeroes buf] is true if [buf] contains only zero bytes *)
let allzeroes buf =
let rec loop i =
(i >= Cstruct.length buf) || (Cstruct.get_uint8 buf i = 0 && (loop (i + 1))) in
loop 0

(** Pretty-print the header record *)
let to_detailed_string (x: t) =
let table = [ "file_name", x.file_name;
Expand Down Expand Up @@ -530,47 +556,48 @@ module Header = struct
(** Unmarshal a header block, returning None if it's all zeroes *)
let unmarshal ?(extended = Extended.make ()) (c: Cstruct.t)
: (t, [>`Zero_block | `Checksum_mismatch]) result =
if allzeroes c then Error `Zero_block
if Cstruct.length c <> length then Error (`Unmarshal "buffer is not of block size")
else if Cstruct.equal zero_block c then Error `Zero_block
else
let chksum = get_hdr_chksum c in
let* chksum = get_hdr_chksum c in
if checksum c <> chksum then Error `Checksum_mismatch
else let ustar =
let magic = get_hdr_magic c in
else let* ustar =
let* magic = get_hdr_magic c in
(* GNU tar and Posix differ in interpretation of the character following ustar. For Posix, it should be '\0' but GNU tar uses ' ' *)
String.length magic >= 5 && (String.sub magic 0 5 = "ustar") in
let prefix = if ustar then get_hdr_prefix c else "" in
let file_name = match extended.Extended.path with
| Some path -> path
Ok (String.length magic >= 5 && (String.sub magic 0 5 = "ustar")) in
let* prefix = if ustar then get_hdr_prefix c else Ok "" in
let* file_name = match extended.Extended.path with
| Some path -> Ok path
| None ->
let file_name = get_hdr_file_name c in
if file_name = "" then prefix
else if prefix = "" then file_name
else Filename.concat prefix file_name in
let file_mode = get_hdr_file_mode c in
let user_id = match extended.Extended.user_id with
let* file_name = get_hdr_file_name c in
if file_name = "" then Ok prefix
else if prefix = "" then Ok file_name
else Ok (Filename.concat prefix file_name) in
let* file_mode = get_hdr_file_mode c in
let* user_id = match extended.Extended.user_id with
| None -> get_hdr_user_id c
| Some x -> x in
let group_id = match extended.Extended.group_id with
| Some x -> Ok x in
let* group_id = match extended.Extended.group_id with
| None -> get_hdr_group_id c
| Some x -> x in
let file_size = match extended.Extended.file_size with
| Some x -> Ok x in
let* file_size = match extended.Extended.file_size with
| None -> get_hdr_file_size c
| Some x -> x in
let mod_time = match extended.Extended.mod_time with
| Some x -> Ok x in
let* mod_time = match extended.Extended.mod_time with
| None -> get_hdr_mod_time c
| Some x -> x in
| Some x -> Ok x in
let link_indicator = Link.of_char (get_hdr_link_indicator c) in
let uname = match extended.Extended.uname with
| None -> if ustar then get_hdr_uname c else ""
| Some x -> x in
let gname = match extended.Extended.gname with
| None -> if ustar then get_hdr_gname c else ""
| Some x -> x in
let devmajor = if ustar then get_hdr_devmajor c else 0 in
let devminor = if ustar then get_hdr_devminor c else 0 in

let link_name = match extended.Extended.link_path with
| Some link_path -> link_path
let* uname = match extended.Extended.uname with
| None -> if ustar then get_hdr_uname c else Ok ""
| Some x -> Ok x in
let* gname = match extended.Extended.gname with
| None -> if ustar then get_hdr_gname c else Ok ""
| Some x -> Ok x in
let* devmajor = if ustar then get_hdr_devmajor c else Ok 0 in
let* devminor = if ustar then get_hdr_devminor c else Ok 0 in

let* link_name = match extended.Extended.link_path with
| Some link_path -> Ok link_path
| None -> get_hdr_link_name c in
Ok (make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator
~link_name ~uname ~gname ~devmajor ~devminor file_name file_size)
Expand Down Expand Up @@ -667,7 +694,7 @@ module type HEADERREADER = sig
type in_channel
type 'a io
val read : global:Header.Extended.t option -> in_channel ->
(Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header ]) result io
(Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ]) result io
end

module type HEADERWRITER = sig
Expand All @@ -684,7 +711,12 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) =
open Reader

type in_channel = Reader.in_channel
type 'a io = 'a Async.t
type 'a io = 'a t

let ( let* ) x f =
match x with
| Ok x -> f x
| Error y -> return (Error y)

let fix_link_indicator x =
(* For backward compatibility we treat normal files ending in slash as
Expand All @@ -700,7 +732,7 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) =
else
x

let read ~global (ifd: Reader.in_channel) : (Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header ]) result t =
let read ~global (ifd: Reader.in_channel) : (Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ]) result t =
(* We might need to read 2 headers at once if we encounter a Pax header *)
let buffer = Cstruct.create Header.length in
let real_header_buf = Cstruct.create Header.length in
Expand All @@ -722,20 +754,19 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) =
>>= fun () ->
(* unmarshal merges the previous global (if any) with the
discovered global (if any) and returns the new global. *)
let global = Header.Extended.unmarshal ~global extra_header_buf in
let* global = Header.Extended.unmarshal ~global extra_header_buf in
get_hdr ~next_longname ~next_longlink (Some global) ()
| Ok x when x.Header.link_indicator = Header.Link.PerFileExtendedHeader ->
let extra_header_buf = Cstruct.create (Int64.to_int x.Header.file_size) in
really_read ifd extra_header_buf
>>= fun () ->
skip ifd (Header.compute_zero_padding_length x)
>>= fun () ->
let extended = Header.Extended.unmarshal ~global extra_header_buf in
let* extended = Header.Extended.unmarshal ~global extra_header_buf in
really_read ifd real_header_buf
>>= fun () ->
begin match Header.unmarshal ~extended real_header_buf with
| Error _ ->
(* FIXME: Corrupt pax headers *)
return (Error `Corrupt_pax_header)
| Ok x ->
let x = fix_link_indicator x in
Expand Down Expand Up @@ -769,9 +800,9 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) =
>>= function
| Ok x -> return (Ok (x, global))
| Error `Zero_block -> return (Error `Eof)
| Error `Checksum_mismatch as e -> return e
| Error (`Checksum_mismatch | `Unmarshal _) as e -> return e
end
| Error `Checksum_mismatch as e ->
| Error (`Checksum_mismatch | `Unmarshal _) as e ->
return e
in

Expand Down
8 changes: 4 additions & 4 deletions lib/tar.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
{e %%VERSION%% - {{:%%PKG_HOMEPAGE%% }homepage}} *)

(** The type of errors that may occur. *)
type error = [`Eof | `Checksum_mismatch | `Corrupt_pax_header | `Zero_block]
type error = [`Eof | `Checksum_mismatch | `Corrupt_pax_header | `Zero_block | `Unmarshal of string]

(** [pp_error ppf e] pretty prints the error [e] on the formatter [ppf]. *)
val pp_error : Format.formatter -> [< error] -> unit
Expand Down Expand Up @@ -82,7 +82,7 @@ module Header : sig
(** Unmarshal a pax Extended Header block. This header block may
be preceded by [global] blocks which will override some
fields. *)
val unmarshal : global:t option -> Cstruct.t -> t
val unmarshal : global:t option -> Cstruct.t -> (t, [> error ]) result
end

(** Represents a standard archive (note checksum not stored). *)
Expand Down Expand Up @@ -123,7 +123,7 @@ module Header : sig
(** Unmarshal a header block, returning [None] if it's all zeroes.
This header block may be preceded by an [?extended] block which
will override some fields. *)
val unmarshal : ?extended:Extended.t -> Cstruct.t -> (t, [`Zero_block | `Checksum_mismatch]) result
val unmarshal : ?extended:Extended.t -> Cstruct.t -> (t, [`Zero_block | `Checksum_mismatch | `Unmarshal of string]) result

(** Marshal a header block, computing and inserting the checksum. *)
val marshal : ?level:compatibility -> Cstruct.t -> t -> unit
Expand Down Expand Up @@ -168,7 +168,7 @@ module type HEADERREADER = sig
@param global Holds the current global pax extended header, if
any. Needs to be given to the next call to [read]. *)
val read : global:Header.Extended.t option -> in_channel ->
(Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header ]) result io
(Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ]) result io
end

module type HEADERWRITER = sig
Expand Down

0 comments on commit cb76332

Please sign in to comment.