Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

File Stream - Support multi-writers #58

Merged
merged 4 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions http/filestream/filestream.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ const (
)

// The expected type of function that should be provided to the ReadFilesFromStream func, that returns the writer that should handle each file
type FileWriterFunc func(fileName string) (writer io.WriteCloser, err error)
type FileWriterFunc func(fileName string) (writer []io.WriteCloser, err error)

Or-Geva marked this conversation as resolved.
Show resolved Hide resolved
func ReadFilesFromStream(multipartReader *multipart.Reader, fileWriterFunc FileWriterFunc) error {
func ReadFilesFromStream(multipartReader *multipart.Reader, fileWritersFunc FileWriterFunc) error {
for {
// Read the next file streamed from client
fileReader, err := multipartReader.NextPart()
Expand All @@ -27,7 +27,7 @@ func ReadFilesFromStream(multipartReader *multipart.Reader, fileWriterFunc FileW
}
return fmt.Errorf("failed to read file: %w", err)
}
err = readFile(fileReader, fileWriterFunc)
err = readFile(fileReader, fileWritersFunc)
if err != nil {
return err
}
Or-Geva marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -42,11 +42,17 @@ func readFile(fileReader *multipart.Part, fileWriterFunc FileWriterFunc) (err er
if err != nil {
return err
}
defer ioutils.Close(fileWriter, &err)
if _, err = io.Copy(fileWriter, fileReader); err != nil {
var writers []io.Writer
for _, writer := range fileWriter {
defer ioutils.Close(writer, &err)
// Create a multi writer that will write the file to all the provided writers
// We read multipart once and write to multiple writers, so we can't use the same multipart writer multiple times
writers = append(writers, writer)
omerzi marked this conversation as resolved.
Show resolved Hide resolved
}
if _, err = io.Copy(ioutils.AsyncMultiWriter(writers...), fileReader); err != nil {
return fmt.Errorf("failed writing '%s' file: %w", fileName, err)
}
return err
return nil
}

type FileInfo struct {
Expand Down
11 changes: 8 additions & 3 deletions http/filestream/filestream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package filestream

import (
"bytes"
"github.com/stretchr/testify/assert"
"io"
"mime/multipart"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

var targetDir string
Expand Down Expand Up @@ -45,6 +46,10 @@ func TestWriteFilesToStreamAndReadFilesFromStream(t *testing.T) {
assert.Equal(t, file2Content, content)
}

func simpleFileWriter(fileName string) (fileWriter io.WriteCloser, err error) {
return os.Create(filepath.Join(targetDir, fileName))
func simpleFileWriter(fileName string) (fileWriter []io.WriteCloser, err error) {
writer, err := os.Create(filepath.Join(targetDir, fileName))
if err != nil {
return nil, err
}
return []io.WriteCloser{writer}, nil
}
56 changes: 56 additions & 0 deletions io/multiwriter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io

import (
"errors"
"io"
"sync"
)

var ErrShortWrite = errors.New("short write")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is short write?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to
"The number of bytes written is less than the length of the input"


type asyncMultiWriter struct {
writers []io.Writer
}

// AsyncMultiWriter creates a writer that duplicates its writes to all the
// provided writers asynchronous
func AsyncMultiWriter(writers ...io.Writer) io.Writer {
w := make([]io.Writer, len(writers))
copy(w, writers)
return &asyncMultiWriter{w}
}

// Writes data asynchronously to each writer and waits for all of them to complete.
// In case of an error, the writing will not complete.
func (t *asyncMultiWriter) Write(p []byte) (int, error) {
Or-Geva marked this conversation as resolved.
Show resolved Hide resolved
var wg sync.WaitGroup
wg.Add(len(t.writers))
errChannel := make(chan error)
Or-Geva marked this conversation as resolved.
Show resolved Hide resolved
finished := make(chan bool, 1)
Or-Geva marked this conversation as resolved.
Show resolved Hide resolved
for _, w := range t.writers {
go writeData(p, w, &wg, errChannel)
}
go func() {
wg.Wait()
close(finished)
}()
// This select will block until one of the two channels returns a value.
select {
case <-finished:
case err := <-errChannel:
if err != nil {
return 0, err
}
}
return len(p), nil
}
func writeData(p []byte, w io.Writer, wg *sync.WaitGroup, errChan chan error) {
n, err := w.Write(p)
if err != nil {
errChan <- err
}
if n != len(p) {
errChan <- ErrShortWrite
}
wg.Done()
}
Loading