Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(util): file downloader with verify sha256 hash #1422

Merged
merged 11 commits into from
Jul 17, 2024
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ todo
# VIM artifacts
.swp
.*.sw*

util/downloader/testdata/
b00f marked this conversation as resolved.
Show resolved Hide resolved
284 changes: 284 additions & 0 deletions util/downloader/downloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
package downloader

import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"hash"
"io"
"net/http"
"os"
"path/filepath"
)

var (
ErrHeaderRequest = errors.New("request header error")
ErrSHA256Mismatch = errors.New("sha256 mismatch")
ErrCreateDir = errors.New("create dir error")
ErrInvalidFilePath = errors.New("file path is a directory, not a file")
ErrGetFileInfo = errors.New("get file info error")
ErrCopyExistsFileData = errors.New("error copying existing file data")
ErrDoRequest = errors.New("error doing request")
ErrFileWriting = errors.New("error writing file")
ErrNewRequest = errors.New("error creating request")
ErrOpenFileExists = errors.New("error opening existing file")
)

type Downloader struct {
client *http.Client
url string
filePath string
sha256Sum string
fileType string
fileName string
statsCh chan Stats
errCh chan error
}

type Stats struct {
Downloaded int64
TotalSize int64
Percent float64
Completed bool
}

func New(url, filePath, sha256Sum string, opts ...Option) *Downloader {
opt := defaultOptions()

for _, o := range opts {
o(opt)
}

return &Downloader{
client: opt.client,
url: url,
filePath: filePath,
sha256Sum: sha256Sum,
statsCh: make(chan Stats),
errCh: make(chan error, 1),
}
}

func (d *Downloader) Start(ctx context.Context) {
go d.download(ctx)
}

func (d *Downloader) Stats() <-chan Stats {
return d.statsCh
}

func (d *Downloader) FileType() string {
return d.fileType
}

func (d *Downloader) FileName() string {
return d.fileName
}

func (d *Downloader) Errors() <-chan error {
return d.errCh
}

func (d *Downloader) download(ctx context.Context) {
stats, err := d.getHeader(ctx)
if err != nil {
d.handleError(err)

return
}

d.fileName = filepath.Base(d.filePath)
if err := d.createDir(); err != nil {
d.handleError(err)

return
}

out, err := d.openFile()
if err != nil {
d.handleError(err)

return
}
defer func() {
_ = out.Close()
}()

if err := d.validateExistingFile(out, &stats); err != nil {
d.handleError(err)

return
}

if err := d.downloadFile(ctx, out, &stats); err != nil {
d.handleError(err)
}
}

func (d *Downloader) getHeader(ctx context.Context) (Stats, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodHead, d.url, http.NoBody)
if err != nil {
return Stats{}, ErrHeaderRequest
}

resp, err := d.client.Do(req)
if err != nil {
return Stats{}, ErrHeaderRequest
}

defer func() {
_ = resp.Body.Close()
}()

d.fileType = resp.Header.Get("Content-Type")

return Stats{
TotalSize: resp.ContentLength,
}, nil
}

func (d *Downloader) createDir() error {
dir := filepath.Dir(d.filePath)
if err := os.MkdirAll(dir, 0o750); err != nil {
return ErrCreateDir
}

return nil
}

func (d *Downloader) openFile() (*os.File, error) {
fileInfo, err := os.Stat(d.filePath)
if err == nil && fileInfo.IsDir() {
return nil, ErrInvalidFilePath
}

return os.OpenFile(d.filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600)
}

func (*Downloader) validateExistingFile(out *os.File, stats *Stats) error {
fileInfo, err := out.Stat()
if err != nil {
return ErrGetFileInfo
}
stats.Downloaded = fileInfo.Size()

return nil
}

func (d *Downloader) downloadFile(ctx context.Context, out *os.File, stats *Stats) error {
req, err := d.createRequest(ctx, stats.Downloaded)
if err != nil {
return err
}

resp, err := d.client.Do(req)
if err != nil {
return ErrDoRequest
}

defer func() {
_ = resp.Body.Close()
}()

buffer := make([]byte, 32*1024)
hasher := sha256.New()

if err := d.updateHasherWithExistingData(stats.Downloaded, hasher); err != nil {
return err
}

return d.writeToFile(ctx, resp, out, buffer, hasher, stats)
}

func (d *Downloader) createRequest(ctx context.Context, downloaded int64) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, http.NoBody)
if err != nil {
return nil, ErrNewRequest
}
if downloaded > 0 {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", downloaded))
}

return req, nil
}

func (d *Downloader) updateHasherWithExistingData(downloaded int64, hasher io.Writer) error {
if downloaded > 0 {
existingFile, err := os.Open(d.filePath)
if err != nil {
return ErrOpenFileExists
}
defer func() {
_ = existingFile.Close()
}()

if _, err := io.CopyN(hasher, existingFile, downloaded); err != nil {
return ErrCopyExistsFileData
}
}

return nil
}

func (d *Downloader) writeToFile(ctx context.Context, resp *http.Response, out *os.File, buffer []byte,
hasher hash.Hash, stats *Stats,
) error {
for {
select {
case <-ctx.Done():
d.stop()

return ctx.Err()
default:
n, err := resp.Body.Read(buffer)
if n > 0 {
if _, err := out.Write(buffer[:n]); err != nil {
return ErrFileWriting
}

if _, err := hasher.Write(buffer[:n]); err != nil {
return ErrFileWriting
}

stats.Downloaded += int64(n)
stats.Percent = float64(stats.Downloaded) / float64(stats.TotalSize) * 100
d.statsCh <- *stats
}
if err != nil {
if err == io.EOF {
return d.finalizeDownload(hasher, stats)
}

return fmt.Errorf("error reading response body: %w", err)
}
}
}
}

func (d *Downloader) finalizeDownload(hasher hash.Hash, stats *Stats) error {
stats.Completed = true
sum := hex.EncodeToString(hasher.Sum(nil))
if sum != d.sha256Sum {
return ErrSHA256Mismatch
}
d.statsCh <- *stats

d.stop()

return nil
}

func (d *Downloader) stop() {
close(d.statsCh)
close(d.errCh)
}

func (d *Downloader) handleError(err error) {
select {
case d.errCh <- err:
default:
d.stop()
}
}
89 changes: 89 additions & 0 deletions util/downloader/downloader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package downloader

import (
"context"
"crypto/sha256"
"encoding/hex"
"log"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestDownloader(t *testing.T) {
fileContent := []byte("This is a test file content")
fileURL := "/testfile"
expectedSHA256 := sha256.Sum256(fileContent)
expectedSHA256Hex := hex.EncodeToString(expectedSHA256[:])

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == fileURL {
_, err := w.Write(fileContent)
assert.NoError(t, err)
} else {
http.NotFound(w, r)
}
}))
defer server.Close()

filePath := "./testdata/example_testfile.txt"
Ja7ad marked this conversation as resolved.
Show resolved Hide resolved

defer func() {
assert.NoError(t, os.RemoveAll("./testdata"))
}()

dl := New(server.URL+fileURL, filePath, expectedSHA256Hex, WithCustomClient(server.Client()))

assrt := assert.New(t)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

go func() {
dl.Start(ctx)
}()

done := make(chan bool)

go func() {
for stat := range dl.Stats() {
log.Printf("Downloaded: %d / %d (%.2f%%)\n", stat.Downloaded, stat.TotalSize, stat.Percent)
assrt.True(stat.Downloaded <= stat.TotalSize, "Downloaded size should not exceed total size")
assrt.True(stat.Percent <= 100, "Download percentage should not exceed 100")

if stat.Completed {
log.Println("Download completed successfully")
assrt.Equal(float64(100), stat.Percent, "Download should be 100% complete")
done <- true

return
}
}
}()

go func() {
for err := range dl.Errors() {
assrt.Fail("Download encountered an error", err)
done <- true

return
}
}()

select {
case <-done:
case <-time.After(2 * time.Minute):
cancel()
assrt.Fail("Download test timed out")
}

t.Log(dl.FileName())
t.Log(dl.FileType())

downloadedContent, err := os.ReadFile(filePath)
assrt.NoError(err, "Failed to read the downloaded file")
assrt.Equal(fileContent, downloadedContent, "Downloaded file content does not match expected content")
}
Loading
Loading