Skip to content

Commit

Permalink
File Stream - Support multi-writers (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
Or-Geva authored Mar 18, 2024
1 parent b0ac460 commit 63a831b
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 12 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/ulikunitz/xz v0.5.11 // indirect
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect
golang.org/x/sync v0.6.0
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMx
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
19 changes: 12 additions & 7 deletions http/filestream/filestream.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ const (
)

// FileWriterFunc is the type of function that should be provided to ReadFilesFromStream. It returns the writers that should each receive the content of the given file.
type FileWriterFunc func(fileName string) (writer io.WriteCloser, err error)
type FileWriterFunc func(fileName string) (writers []io.WriteCloser, err error)

func ReadFilesFromStream(multipartReader *multipart.Reader, fileWriterFunc FileWriterFunc) error {
func ReadFilesFromStream(multipartReader *multipart.Reader, fileWritersFunc FileWriterFunc) error {
for {
// Read the next file streamed from client
fileReader, err := multipartReader.NextPart()
Expand All @@ -27,8 +27,7 @@ func ReadFilesFromStream(multipartReader *multipart.Reader, fileWriterFunc FileW
}
return fmt.Errorf("failed to read file: %w", err)
}
err = readFile(fileReader, fileWriterFunc)
if err != nil {
if err = readFile(fileReader, fileWritersFunc); err != nil {
return err
}

Expand All @@ -42,11 +41,17 @@ func readFile(fileReader *multipart.Part, fileWriterFunc FileWriterFunc) (err er
if err != nil {
return err
}
defer ioutils.Close(fileWriter, &err)
if _, err = io.Copy(fileWriter, fileReader); err != nil {
var writers []io.Writer
for _, writer := range fileWriter {
defer ioutils.Close(writer, &err)
// Collect every writer so the file can be streamed to all of them through a single multi writer.
// The multipart part can only be read once, so it cannot be re-read separately for each writer.
writers = append(writers, writer)
}
if _, err = io.Copy(ioutils.AsyncMultiWriter(10, writers...), fileReader); err != nil {
return fmt.Errorf("failed writing '%s' file: %w", fileName, err)
}
return err
return nil
}

type FileInfo struct {
Expand Down
11 changes: 8 additions & 3 deletions http/filestream/filestream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package filestream

import (
"bytes"
"github.com/stretchr/testify/assert"
"io"
"mime/multipart"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

var targetDir string
Expand Down Expand Up @@ -45,6 +46,10 @@ func TestWriteFilesToStreamAndReadFilesFromStream(t *testing.T) {
assert.Equal(t, file2Content, content)
}

func simpleFileWriter(fileName string) (fileWriter io.WriteCloser, err error) {
return os.Create(filepath.Join(targetDir, fileName))
// simpleFileWriter is the FileWriterFunc used by the tests: it creates the file
// named fileName under targetDir with os.Create and returns it as the only
// element of the writers slice. If the file cannot be created, it returns a nil
// slice together with the error from os.Create.
func simpleFileWriter(fileName string) (fileWriter []io.WriteCloser, err error) {
writer, err := os.Create(filepath.Join(targetDir, fileName))
if err != nil {
return nil, err
}
return []io.WriteCloser{writer}, nil
}
45 changes: 45 additions & 0 deletions io/multiwriter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io

import (
"errors"
"io"

"golang.org/x/sync/errgroup"
)

// ErrShortWrite is returned by the asynchronous multi-writer's Write when one of
// the underlying writers reports no error but accepts fewer bytes than it was given.
var ErrShortWrite = errors.New("The number of bytes written is less than the length of the input")

// asyncMultiWriter is an io.Writer that fans every Write call out to a fixed
// set of underlying writers.
type asyncMultiWriter struct {
// writers are the destinations; each one is handed every buffer passed to Write.
writers []io.Writer
// limit bounds how many of the writers are driven concurrently by a single
// Write call (it is the value given to the errgroup's SetLimit).
limit int
}

// AsyncMultiWriter returns an io.Writer that forwards each Write to every one
// of the provided writers asynchronously. limit is stored and later used to
// bound the number of writers that are written to concurrently.
//
// The writers slice is copied, so later changes the caller makes to its own
// slice do not affect the returned writer.
func AsyncMultiWriter(limit int, writers ...io.Writer) io.Writer {
	targets := make([]io.Writer, 0, len(writers))
	targets = append(targets, writers...)
	return &asyncMultiWriter{limit: limit, writers: targets}
}

// Write hands p to every underlying writer, each in its own goroutine, and
// waits for all of them to finish before returning.
//
// When amw.limit is positive, at most that many writers run concurrently. A
// zero or negative limit means "no limit": errgroup interprets a limit of zero
// as "never start a new goroutine", which would make this call block forever.
//
// On success it returns len(p) and a nil error. If any writer returns an error,
// or accepts fewer than len(p) bytes (reported as ErrShortWrite), Write returns
// 0 and the first such error, because the data cannot be assumed to have
// reached every destination. All writers are still attempted; a failure in one
// does not cancel the others.
func (amw *asyncMultiWriter) Write(p []byte) (int, error) {
	var eg errgroup.Group
	if amw.limit > 0 {
		eg.SetLimit(amw.limit)
	}
	for _, w := range amw.writers {
		// Per-iteration copy for the closure (required before Go 1.22).
		currentWriter := w
		eg.Go(func() error {
			n, err := currentWriter.Write(p)
			if err != nil {
				return err
			}
			if n != len(p) {
				return ErrShortWrite
			}
			return nil
		})
	}

	if err := eg.Wait(); err != nil {
		return 0, err
	}
	return len(p), nil
}
45 changes: 45 additions & 0 deletions io/multiwriter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io

import (
"bytes"
"errors"
"testing"

"github.com/stretchr/testify/assert"
)

// TestAsyncMultiWriter checks, for concurrency limits 1 and 2, that a single
// Write reports the full payload length and delivers identical bytes to both
// destination buffers.
func TestAsyncMultiWriter(t *testing.T) {
	payload := []byte("test data")
	for _, limit := range []int{1, 2} {
		var first, second bytes.Buffer
		w := AsyncMultiWriter(limit, &first, &second)

		written, err := w.Write(payload)
		assert.NoError(t, err)
		assert.Equal(t, len(payload), written)

		// Both destinations must hold exactly what was written.
		for _, got := range []string{first.String(), second.String()} {
			assert.Equal(t, string(payload), got)
		}
	}
}

// TestAsyncMultiWriter_Error verifies that an error returned by an underlying
// writer is surfaced unchanged by the multi-writer's Write.
func TestAsyncMultiWriter_Error(t *testing.T) {
	expectedErr := errors.New("write error")

	// A single destination whose Write always fails with expectedErr.
	failing := &mockWriter{writeErr: expectedErr}

	_, err := AsyncMultiWriter(2, failing).Write([]byte("test data"))
	assert.Equal(t, expectedErr, err)
}

// mockWriter is an io.Writer test double used to simulate Write failures.
type mockWriter struct {
	// writeErr is the error returned by every call to Write.
	writeErr error
}

// Write ignores p, reports that zero bytes were written, and returns the
// configured writeErr (which may be nil).
func (mw *mockWriter) Write(p []byte) (int, error) {
	const written = 0
	return written, mw.writeErr
}
4 changes: 2 additions & 2 deletions unarchive/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ func (u *Unarchiver) byExtension(filename string) (interface{}, error) {

// Make sure the archive is free from Zip Slip and Zip symlinks attacks
func inspectArchive(archive interface{}, localArchivePath, destinationDir string) error {
// If the destination directory ends with a slash, delete it.
// This is necessary to handle a situation where the entry path might be at the root of the destination directory,
// If the destination directory ends with a slash, delete it.
// This is necessary to handle a situation where the entry path might be at the root of the destination directory,
// but in such a case "<destination-dir>/" is not a prefix of "<destination-dir>".
destinationDir = strings.TrimSuffix(destinationDir, string(os.PathSeparator))
walker, ok := archive.(archiver.Walker)
Expand Down

0 comments on commit 63a831b

Please sign in to comment.