diff --git a/internal/data/fullreader.go b/internal/data/fullreader.go
deleted file mode 100644
index 8e3a79f..0000000
--- a/internal/data/fullreader.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package data
-
-type FullReader struct{
-
-}
\ No newline at end of file
diff --git a/squashfs/base.go b/squashfs/base.go
index bd5f432..10351a1 100644
--- a/squashfs/base.go
+++ b/squashfs/base.go
@@ -4,9 +4,10 @@ import (
 	"errors"
 	"io"
 
-	"github.com/CalebQ42/squashfs/internal/data"
+	"github.com/CalebQ42/squashfs/internal/decompress"
 	"github.com/CalebQ42/squashfs/internal/metadata"
 	"github.com/CalebQ42/squashfs/internal/toreader"
+	"github.com/CalebQ42/squashfs/squashfs/data"
 	"github.com/CalebQ42/squashfs/squashfs/directory"
 	"github.com/CalebQ42/squashfs/squashfs/inode"
 )
@@ -50,9 +51,9 @@ func (b *Base) IsRegular() bool {
 	return b.Inode.Type == inode.Fil || b.Inode.Type == inode.EFil
 }
 
-func (b *Base) GetRegFileReaders(r *Reader) (io.ReadCloser, error) {
+func (b *Base) GetRegFileReaders(r *Reader) (*data.Reader, *data.FullReader, error) {
 	if !b.IsRegular() {
-		return nil, errors.New("not a regular file")
+		return nil, nil, errors.New("not a regular file")
 	}
 	var blockStart uint64
 	var fragIndex uint32
@@ -69,19 +70,29 @@ func (b *Base) GetRegFileReaders(r *Reader) (*data.Reader, *data.FullReader, error) {
 		fragOffset = b.Inode.Data.(inode.EFile).FragOffset
 		sizes = b.Inode.Data.(inode.EFile).BlockSizes
 	}
-	var frag *data.Reader
-	if fragIndex != 0xFFFFFFFF {
-		ent, err := r.fragEntry(fragIndex)
-		if err != nil {
-			return nil, err
-		}
-		frag = data.NewReader(toreader.NewReader(r.r, int64(ent.start)), r.d, []uint32{ent.size})
-		frag.Read(make([]byte, fragOffset))
-	}
-	out := data.NewReader(toreader.NewReader(r.r, int64(blockStart)), r.d, sizes)
-	if frag != nil {
-		out.AddFrag(out)
-	}
-	//TODO: implement and add full reader
-	return out, nil
+	// frag lazily creates a reader for the file's tail-end, which is stored
+	// inside a (possibly shared) fragment block.
+	frag := func(rdr io.ReaderAt, d decompress.Decompressor) (*data.Reader, error) {
+		ent, err := r.fragEntry(fragIndex)
+		if err != nil {
+			return nil, err
+		}
+		fragRdr := data.NewReader(toreader.NewReader(rdr, int64(ent.start)), d, []uint32{ent.size})
+		// Skip over the portion of the fragment block belonging to other files.
+		if _, err = io.CopyN(io.Discard, fragRdr, int64(fragOffset)); err != nil {
+			return nil, err
+		}
+		return fragRdr, nil
+	}
+	outRdr := data.NewReader(toreader.NewReader(r.r, int64(blockStart)), r.d, sizes)
+	outFull := data.NewFullReader(r.r, int64(blockStart), r.d, sizes)
+	if fragIndex != 0xffffffff {
+		f, err := frag(r.r, r.d)
+		if err != nil {
+			return nil, nil, err
+		}
+		outRdr.AddFrag(f)
+		outFull.AddFrag(frag)
+	}
+	return outRdr, outFull, nil
 }
diff --git a/squashfs/data/fullreader.go b/squashfs/data/fullreader.go
new file mode 100644
index 0000000..8c1d48b
--- /dev/null
+++ b/squashfs/data/fullreader.go
@@ -0,0 +1,162 @@
+package data
+
+import (
+	"errors"
+	"io"
+	"sync"
+
+	"github.com/CalebQ42/squashfs/internal/decompress"
+)
+
+// FragReaderConstructor lazily creates a Reader for a file's tail-end that
+// lives in a (possibly shared) fragment block.
+type FragReaderConstructor func(io.ReaderAt, decompress.Decompressor) (*Reader, error)
+
+// FullReader decompresses a file's data blocks concurrently and writes the
+// result, in order, to an io.Writer.
+type FullReader struct {
+	r              io.ReaderAt
+	d              decompress.Decompressor
+	frag           FragReaderConstructor
+	retPool        *sync.Pool
+	sizes          []uint32
+	initialOffset  int64
+	goroutineLimit uint16
+}
+
+func NewFullReader(r io.ReaderAt, initialOffset int64, d decompress.Decompressor, sizes []uint32) *FullReader {
+	return &FullReader{
+		r:              r,
+		d:              d,
+		sizes:          sizes,
+		initialOffset:  initialOffset,
+		goroutineLimit: 10,
+		retPool: &sync.Pool{
+			New: func() any {
+				return &retValue{}
+			},
+		},
+	}
+}
+
+// AddFrag sets the constructor used to read the file's fragment once all
+// full-size blocks have been written.
+func (r *FullReader) AddFrag(frag FragReaderConstructor) {
+	r.frag = frag
+}
+
+// SetGoroutineLimit caps how many blocks are processed concurrently.
+func (r *FullReader) SetGoroutineLimit(limit uint16) {
+	r.goroutineLimit = limit
+}
+
+type retValue struct {
+	err   error
+	data  []byte
+	index uint64
+}
+
+// process reads block index at fileOffset and decompresses it if needed,
+// sending the result on retChan.
+func (r *FullReader) process(index uint64, fileOffset uint64, retChan chan *retValue) {
+	ret := r.retPool.Get().(*retValue)
+	ret.index = index
+	// Bit 24 of a block size being set means the block is stored uncompressed.
+	realSize := r.sizes[index] &^ (1 << 24)
+	ret.data = make([]byte, realSize)
+	_, ret.err = r.r.ReadAt(ret.data, r.initialOffset+int64(fileOffset))
+	// Only decompress on a successful read so a read error isn't clobbered.
+	if ret.err == nil && r.sizes[index] == realSize {
+		ret.data, ret.err = r.d.Decompress(ret.data)
+	}
+	retChan <- ret
+}
+
+// WriteTo writes the file's full contents to w, processing up to
+// goroutineLimit blocks concurrently while preserving block order.
+func (r *FullReader) WriteTo(w io.Writer) (int64, error) {
+	var curIndex uint64
+	var curOffset uint64
+	var toProcess uint16
+	var wrote int64
+	cache := make(map[uint64]*retValue)
+	var errCache []error
+	retChan := make(chan *retValue, r.goroutineLimit)
+	// Round up so the final partial batch (or a file with fewer blocks than
+	// goroutineLimit) is still processed.
+	batches := (uint64(len(r.sizes)) + uint64(r.goroutineLimit) - 1) / uint64(r.goroutineLimit)
+	for i := uint64(0); i < batches; i++ {
+		// Clamp in uint64 before narrowing; files can exceed 65535 blocks.
+		remaining := uint64(len(r.sizes)) - (i * uint64(r.goroutineLimit))
+		toProcess = r.goroutineLimit
+		if remaining < uint64(r.goroutineLimit) {
+			toProcess = uint16(remaining)
+		}
+		// Start all the goroutines
+		for j := uint16(0); j < toProcess; j++ {
+			go r.process((i*uint64(r.goroutineLimit))+uint64(j), curOffset, retChan)
+			curOffset += uint64(r.sizes[(i*uint64(r.goroutineLimit))+uint64(j)]) &^ (1 << 24)
+		}
+		// Then consume the results on retChan
+		for j := uint16(0); j < toProcess; j++ {
+			res := <-retChan
+			// If there's an error, we don't care about the results.
+			if res.err != nil {
+				errCache = append(errCache, res.err)
+				clear(cache)
+				continue
+			}
+			// If there has been an error previously, we don't care about the
+			// results. We still drain retChan so every goroutine finishes.
+			if len(errCache) > 0 {
+				continue
+			}
+			// If we don't need the data yet, cache it and move on.
+			if res.index != curIndex {
+				cache[res.index] = res
+				continue
+			}
+			// If we do need the data, write it out.
+			wr, err := w.Write(res.data)
+			wrote += int64(wr)
+			if err != nil {
+				errCache = append(errCache, err)
+				clear(cache)
+				continue
+			}
+			r.retPool.Put(res)
+			curIndex++
+			// Drain any directly-following blocks from the cache.
+			for len(cache) > 0 {
+				res, ok := cache[curIndex]
+				if !ok {
+					break
+				}
+				wr, err := w.Write(res.data)
+				wrote += int64(wr)
+				if err != nil {
+					errCache = append(errCache, err)
+					clear(cache)
+					break
+				}
+				delete(cache, curIndex)
+				r.retPool.Put(res)
+				curIndex++
+			}
+		}
+		if len(errCache) > 0 {
+			return wrote, errors.Join(errCache...)
+		}
+	}
+	if r.frag != nil {
+		rdr, err := r.frag(r.r, r.d)
+		if err != nil {
+			return wrote, err
+		}
+		wr, err := io.Copy(w, rdr)
+		wrote += wr
+		if err != nil {
+			return wrote, err
+		}
+	}
+	return wrote, nil
+}
diff --git a/internal/data/reader.go b/squashfs/data/reader.go
similarity index 100%
rename from internal/data/reader.go
rename to squashfs/data/reader.go