Skip to content

Commit

Permalink
fix: add context to luks commands
Browse files Browse the repository at this point in the history
Make sure commands can be aborted on timeout.

Signed-off-by: Andrey Smirnov <[email protected]>
  • Loading branch information
smira committed Aug 30, 2024
1 parent 2021ab8 commit d39fa20
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 59 deletions.
52 changes: 26 additions & 26 deletions encryption/luks/luks.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,14 @@ func New(cipher Cipher, options ...Option) *LUKS {
}

// Open runs luksOpen on a device and returns mapped device path.
func (l *LUKS) Open(deviceName, mappedName string, key *encryption.Key) (string, error) {
func (l *LUKS) Open(ctx context.Context, deviceName, mappedName string, key *encryption.Key) (string, error) {
args := slices.Concat(
[]string{"luksOpen", deviceName, mappedName, "--key-file=-"},
keyslotArgs(key),
l.perfArgs(),
)

_, err := l.runCommand(args, key.Value)
_, err := l.runCommand(ctx, args, key.Value)
if err != nil {
return "", err
}
Expand All @@ -149,7 +149,7 @@ func (l *LUKS) Open(deviceName, mappedName string, key *encryption.Key) (string,
}

// Encrypt implements encryption.Provider.
func (l *LUKS) Encrypt(deviceName string, key *encryption.Key) error {
func (l *LUKS) Encrypt(ctx context.Context, deviceName string, key *encryption.Key) error {
cipher, err := l.cipher.String()
if err != nil {
return err
Expand All @@ -166,29 +166,29 @@ func (l *LUKS) Encrypt(deviceName string, key *encryption.Key) error {
args = append(args, fmt.Sprintf("--sector-size=%d", l.blockSize))
}

_, err = l.runCommand(args, key.Value)
_, err = l.runCommand(ctx, args, key.Value)

return err
}

// Resize implements encryption.Provider.
func (l *LUKS) Resize(devname string, key *encryption.Key) error {
func (l *LUKS) Resize(ctx context.Context, devname string, key *encryption.Key) error {
args := []string{"resize", devname, "--key-file=-"}

_, err := l.runCommand(args, key.Value)
_, err := l.runCommand(ctx, args, key.Value)

return err
}

// Close implements encryption.Provider.
func (l *LUKS) Close(devname string) error {
_, err := l.runCommand([]string{"luksClose", devname}, nil)
func (l *LUKS) Close(ctx context.Context, devname string) error {
_, err := l.runCommand(ctx, []string{"luksClose", devname}, nil)

return err
}

// AddKey adds a new key at the LUKS encryption slot.
func (l *LUKS) AddKey(devname string, key, newKey *encryption.Key) error {
func (l *LUKS) AddKey(ctx context.Context, devname string, key, newKey *encryption.Key) error {
var buffer bytes.Buffer

keyfileLen, _ := buffer.Write(key.Value)
Expand All @@ -206,13 +206,13 @@ func (l *LUKS) AddKey(devname string, key, newKey *encryption.Key) error {
keyslotArgs(newKey),
)

_, err := l.runCommand(args, buffer.Bytes())
_, err := l.runCommand(ctx, args, buffer.Bytes())

return err
}

// SetKey sets new key value at the LUKS encryption slot.
func (l *LUKS) SetKey(devname string, oldKey, newKey *encryption.Key) error {
func (l *LUKS) SetKey(ctx context.Context, devname string, oldKey, newKey *encryption.Key) error {
if oldKey.Slot != newKey.Slot {
return fmt.Errorf("old and new key slots must match")
}
Expand All @@ -234,19 +234,19 @@ func (l *LUKS) SetKey(devname string, oldKey, newKey *encryption.Key) error {
l.perfArgs(),
)

_, err := l.runCommand(args, buffer.Bytes())
_, err := l.runCommand(ctx, args, buffer.Bytes())

return err
}

// CheckKey checks if the key is valid.
func (l *LUKS) CheckKey(devname string, key *encryption.Key) (bool, error) {
func (l *LUKS) CheckKey(ctx context.Context, devname string, key *encryption.Key) (bool, error) {
args := slices.Concat(
[]string{"luksOpen", "--test-passphrase", devname, "--key-file=-"},
keyslotArgs(key),
)

_, err := l.runCommand(args, key.Value)
_, err := l.runCommand(ctx, args, key.Value)
if err != nil {
if err == encryption.ErrEncryptionKeyRejected { //nolint:errorlint
return false, nil
Expand All @@ -259,13 +259,13 @@ func (l *LUKS) CheckKey(devname string, key *encryption.Key) (bool, error) {
}

// RemoveKey removes a key at the specified LUKS encryption slot.
func (l *LUKS) RemoveKey(devname string, slot int, key *encryption.Key) error {
_, err := l.runCommand([]string{"luksKillSlot", devname, strconv.Itoa(slot), "--key-file=-"}, key.Value)
func (l *LUKS) RemoveKey(ctx context.Context, devname string, slot int, key *encryption.Key) error {
_, err := l.runCommand(ctx, []string{"luksKillSlot", devname, strconv.Itoa(slot), "--key-file=-"}, key.Value)
if err != nil {
return err
}

if err = l.RemoveToken(devname, slot); err != nil && !errors.Is(err, encryption.ErrTokenNotFound) {
if err = l.RemoveToken(ctx, devname, slot); err != nil && !errors.Is(err, encryption.ErrTokenNotFound) {
return err
}

Expand Down Expand Up @@ -306,22 +306,22 @@ func (l *LUKS) ReadKeyslots(deviceName string) (*encryption.Keyslots, error) {

// SetToken adds arbitrary token to the key slot.
// Token id == slot id: only one token per key slot is supported.
func (l *LUKS) SetToken(devname string, slot int, token token.Token) error {
func (l *LUKS) SetToken(ctx context.Context, devname string, slot int, token token.Token) error {
data, err := token.Bytes()
if err != nil {
return err
}

id := strconv.Itoa(slot)

_, err = l.runCommand([]string{"token", "import", "-q", devname, "--token-id", id, "--json-file=-", "--token-replace"}, data)
_, err = l.runCommand(ctx, []string{"token", "import", "-q", devname, "--token-id", id, "--json-file=-", "--token-replace"}, data)

return err
}

// ReadToken reads arbitrary token from the luks metadata.
func (l *LUKS) ReadToken(devname string, slot int, token token.Token) error {
stdout, err := l.runCommand([]string{"token", "export", "-q", devname, "--token-id", strconv.Itoa(slot), "--json-file=-"}, nil)
func (l *LUKS) ReadToken(ctx context.Context, devname string, slot int, token token.Token) error {
stdout, err := l.runCommand(ctx, []string{"token", "export", "-q", devname, "--token-id", strconv.Itoa(slot), "--json-file=-"}, nil)
if err != nil {
return err
}
Expand All @@ -330,19 +330,19 @@ func (l *LUKS) ReadToken(devname string, slot int, token token.Token) error {
}

// RemoveToken removes token from the luks metadata.
func (l *LUKS) RemoveToken(devname string, slot int) error {
_, err := l.runCommand([]string{"token", "remove", "--token-id", strconv.Itoa(slot), devname}, nil)
func (l *LUKS) RemoveToken(ctx context.Context, devname string, slot int) error {
_, err := l.runCommand(ctx, []string{"token", "remove", "--token-id", strconv.Itoa(slot), devname}, nil)

return err
}

var notFoundMatcher = regexp.MustCompile("(is not in use|Failed to get token)")

// runCommand executes cryptsetup with arguments.
func (l *LUKS) runCommand(args []string, stdin []byte) (string, error) {
func (l *LUKS) runCommand(ctx context.Context, args []string, stdin []byte) (string, error) {
stdout, err := cmd.RunContext(cmd.WithStdin(
context.Background(),
bytes.NewBuffer(stdin)), "cryptsetup", args...)
ctx,
bytes.NewReader(stdin)), "cryptsetup", args...)
if err != nil {
var exitError *cmd.ExitError

Expand Down
50 changes: 27 additions & 23 deletions encryption/luks/luks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package luks_test

import (
"context"
"errors"
randv2 "math/rand/v2"
"os"
Expand All @@ -31,6 +32,9 @@ const (
)

func testEncrypt(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
t.Cleanup(cancel)

tmpDir := t.TempDir()

rawImage := filepath.Join(tmpDir, "image.raw")
Expand Down Expand Up @@ -97,21 +101,21 @@ func testEncrypt(t *testing.T) {

t.Logf("unencrypted partition path %s", path)

require.NoError(t, provider.Encrypt(path, key))
require.NoError(t, provider.Encrypt(ctx, path, key))

encryptedPath, err := provider.Open(path, mappedName, key)
encryptedPath, err := provider.Open(ctx, path, mappedName, key)
require.NoError(t, err)

require.NoError(t, provider.Resize(encryptedPath, key))
require.NoError(t, provider.Resize(ctx, encryptedPath, key))

require.NoError(t, provider.AddKey(path, key, keyExtra))
require.NoError(t, provider.SetKey(path, keyExtra, keyExtra))
require.NoError(t, provider.AddKey(ctx, path, key, keyExtra))
require.NoError(t, provider.SetKey(ctx, path, keyExtra, keyExtra))

valid, err := provider.CheckKey(path, keyExtra)
valid, err := provider.CheckKey(ctx, path, keyExtra)
require.NoError(t, err)
require.True(t, valid)

valid, err = provider.CheckKey(path, encryption.NewKey(1, []byte("nope")))
valid, err = provider.CheckKey(ctx, path, encryption.NewKey(1, []byte("nope")))
require.NoError(t, err)
require.False(t, valid)

Expand All @@ -131,36 +135,36 @@ func testEncrypt(t *testing.T) {
Type: "sealedkey",
}

err = provider.SetToken(path, 0, token)
err = provider.SetToken(ctx, path, 0, token)
require.NoError(t, err)

err = provider.ReadToken(path, 0, token)
err = provider.ReadToken(ctx, path, 0, token)
require.NoError(t, err)

require.Equal(t, token.UserData.SealedKey, "aaaa")

require.NoError(t, provider.RemoveToken(path, 0))
require.Error(t, provider.ReadToken(path, 0, token))
require.NoError(t, provider.RemoveToken(ctx, path, 0))
require.Error(t, provider.ReadToken(ctx, path, 0, token))

// create and replace token
err = provider.SetToken(path, 0, token)
err = provider.SetToken(ctx, path, 0, token)
require.NoError(t, err)

token.UserData.SealedKey = "bbbb"

err = provider.SetToken(path, 0, token)
err = provider.SetToken(ctx, path, 0, token)
require.NoError(t, err)

require.NoError(t, unix.Mount(encryptedPath, mountPath, "vfat", 0, ""))
require.NoError(t, unix.Unmount(mountPath, 0))

require.NoError(t, provider.Close(encryptedPath))
require.Error(t, provider.Close(encryptedPath))
require.NoError(t, provider.Close(ctx, encryptedPath))
require.Error(t, provider.Close(ctx, encryptedPath))

// second key slot
encryptedPath, err = provider.Open(path, mappedName, keyExtra)
encryptedPath, err = provider.Open(ctx, path, mappedName, keyExtra)
require.NoError(t, err)
require.NoError(t, provider.Close(encryptedPath))
require.NoError(t, provider.Close(ctx, encryptedPath))

// check keyslots list
keyslots, err := provider.ReadKeyslots(path)
Expand All @@ -172,23 +176,23 @@ func testEncrypt(t *testing.T) {
require.True(t, ok)

// remove key slot
err = provider.RemoveKey(path, 1, key)
err = provider.RemoveKey(ctx, path, 1, key)
require.NoError(t, err)
_, err = provider.Open(path, mappedName, keyExtra)
_, err = provider.Open(ctx, path, mappedName, keyExtra)
require.Equal(t, err, encryption.ErrEncryptionKeyRejected)

valid, err = provider.CheckKey(path, key)
valid, err = provider.CheckKey(ctx, path, key)
require.NoError(t, err)
require.True(t, valid)

// unhappy cases
_, err = provider.Open(path, mappedName, encryption.NewKey(0, []byte("エクスプロシオン")))
_, err = provider.Open(ctx, path, mappedName, encryption.NewKey(0, []byte("エクスプロシオン")))
require.Equal(t, err, encryption.ErrEncryptionKeyRejected)

_, err = provider.Open("/dev/nosuchdevice", mappedName, encryption.NewKey(0, []byte("エクスプロシオン")))
_, err = provider.Open(ctx, "/dev/nosuchdevice", mappedName, encryption.NewKey(0, []byte("エクスプロシオン")))
require.Error(t, err)

_, err = provider.Open(loDev.Path(), mappedName, key)
_, err = provider.Open(ctx, loDev.Path(), mappedName, key)
require.Error(t, err)
}

Expand Down
21 changes: 11 additions & 10 deletions encryption/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package encryption

import (
"context"
"fmt"

"github.com/siderolabs/go-blockdevice/v2/encryption/token"
Expand All @@ -20,21 +21,21 @@ const (
// Provider represents encryption utility methods.
type Provider interface {
TokenProvider
Encrypt(devname string, key *Key) error
Open(devname, mappedName string, key *Key) (string, error)
Close(devname string) error
AddKey(devname string, key, newKey *Key) error
SetKey(devname string, key, newKey *Key) error
CheckKey(devname string, key *Key) (bool, error)
RemoveKey(devname string, slot int, key *Key) error
Encrypt(ctx context.Context, devname string, key *Key) error
Open(ctx context.Context, devname, mappedName string, key *Key) (string, error)
Close(ctx context.Context, devname string) error
AddKey(ctx context.Context, devname string, key, newKey *Key) error
SetKey(ctx context.Context, devname string, key, newKey *Key) error
CheckKey(ctx context.Context, devname string, key *Key) (bool, error)
RemoveKey(ctx context.Context, devname string, slot int, key *Key) error
ReadKeyslots(deviceName string) (*Keyslots, error)
}

// TokenProvider represents token management methods.
type TokenProvider interface {
SetToken(devname string, slot int, token token.Token) error
ReadToken(devname string, slot int, token token.Token) error
RemoveToken(devname string, slot int) error
SetToken(ctx context.Context, devname string, slot int, token token.Token) error
ReadToken(ctx context.Context, devname string, slot int, token token.Token) error
RemoveToken(ctx context.Context, devname string, slot int) error
}

var (
Expand Down

0 comments on commit d39fa20

Please sign in to comment.