diff --git a/mem/file.go b/mem/file.go index 699f1fb0..07b2e12a 100644 --- a/mem/file.go +++ b/mem/file.go @@ -225,11 +225,11 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { return 0, ErrFileClosed } switch whence { - case 0: + case io.SeekStart: atomic.StoreInt64(&f.at, offset) - case 1: - atomic.AddInt64(&f.at, int64(offset)) - case 2: + case io.SeekCurrent: + atomic.AddInt64(&f.at, offset) + case io.SeekEnd: atomic.StoreInt64(&f.at, int64(len(f.fileData.data))+offset) } return f.at, nil @@ -260,7 +260,7 @@ func (f *File) Write(b []byte) (n int, err error) { } setModTime(f.fileData, time.Now()) - atomic.StoreInt64(&f.at, int64(len(f.fileData.data))) + atomic.AddInt64(&f.at, int64(n)) return } diff --git a/mem/file_test.go b/mem/file_test.go index bb54db65..22af9707 100644 --- a/mem/file_test.go +++ b/mem/file_test.go @@ -1,6 +1,7 @@ package mem import ( + "bytes" "io" "testing" "time" @@ -205,3 +206,42 @@ func TestFileReadAtSeekOffset(t *testing.T) { t.Fatal(err) } } + +func TestFileWriteAndSeek(t *testing.T) { + fd := CreateFile("foo") + f := NewFileHandle(fd) + + assert := func(expected bool, v ...interface{}) { + if !expected { + t.Helper() + t.Fatal(v...) + } + } + + data4 := []byte{0, 1, 2, 3} + data20 := bytes.Repeat(data4, 5) + var off int64 + + for i := 0; i < 100; i++ { + // write 20 bytes + n, err := f.Write(data20) + assert(err == nil, err) + off += int64(n) + assert(n == len(data20), n) + assert(off == int64((i+1)*len(data20)), off) + + // rewind to start and write 4 bytes there + cur, err := f.Seek(-off, io.SeekCurrent) + assert(err == nil, err) + assert(cur == 0, cur) + + n, err = f.Write(data4) + assert(err == nil, err) + assert(n == len(data4), n) + + // back at the end + cur, err = f.Seek(off-int64(n), io.SeekCurrent) + assert(err == nil, err) + assert(cur == off, cur, off) + } +}