Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Niraj Yadav <[email protected]>
  • Loading branch information
black-dragon74 committed Jun 27, 2024
1 parent 448c870 commit bce2d76
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 62 deletions.
6 changes: 3 additions & 3 deletions internal/csi-addons/rbd/encryptionkeyrotation.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ package rbd
import (
"context"

"github.com/ceph/ceph-csi/internal/rbd"
"github.com/ceph/ceph-csi/internal/util"
ekr "github.com/csi-addons/spec/lib/go/encryptionkeyrotation"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/ceph/ceph-csi/internal/rbd"
"github.com/ceph/ceph-csi/internal/util"
)

type EncryptionKeyRotationServer struct {
Expand All @@ -43,7 +44,6 @@ func (ekrs *EncryptionKeyRotationServer) RotateEncryptionKey(
ctx context.Context,
req *ekr.EncryptionKeyRotateRequest,
) (*ekr.EncryptionKeyRotateResponse, error) {

// Get the volume ID from the request
volID := req.GetVolumeId()
if volID == "" {
Expand Down
3 changes: 2 additions & 1 deletion internal/kms/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ import (
"strconv"
"strings"

"github.com/ceph/ceph-csi/internal/util/file"
"github.com/hashicorp/vault/api"
loss "github.com/libopenstorage/secrets"
"github.com/libopenstorage/secrets/vault"

"github.com/ceph/ceph-csi/internal/util/file"
)

const (
Expand Down
19 changes: 0 additions & 19 deletions internal/kms/vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ package kms

import (
"errors"
"os"
"testing"

"github.com/ceph/ceph-csi/internal/util/file"
loss "github.com/libopenstorage/secrets"
"github.com/stretchr/testify/require"
)
Expand All @@ -45,23 +43,6 @@ func TestDetectAuthMountPath(t *testing.T) {
}
}

func TestCreateTempFile(t *testing.T) {
t.Parallel()
data := "Hello World!"
tmpfile, err := file.CreateTempFile("my-file", data)
if err != nil {
t.Errorf("createTempFile() failed: %s", err)
}
if tmpfile.Name() == "" {
t.Errorf("createTempFile() returned an empty filename")
}

err = os.Remove(tmpfile.Name())
if err != nil {
t.Errorf("failed to remove tmpfile (%s): %s", tmpfile, err)
}
}

func TestSetConfigString(t *testing.T) {
t.Parallel()
const defaultValue = "default-value"
Expand Down
12 changes: 6 additions & 6 deletions internal/kms/vault_tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ func (vtc *vaultTenantConnection) initCertificates(config map[string]interface{}
return fmt.Errorf("failed to get CA certificate from secret %s: %w", vaultCAFromSecret, cErr)
}
}
cer, err := file.CreateTempFile("vault-ca-cert", cert)
if err != nil {
return fmt.Errorf("failed to create temporary file for Vault CA: %w", err)
cer, ferr := file.CreateTempFile("vault-ca-cert", cert)
if ferr != nil {
return fmt.Errorf("failed to create temporary file for Vault CA: %w", ferr)
}
vaultConfig[api.EnvVaultCACert] = cer.Name()
}
Expand All @@ -405,9 +405,9 @@ func (vtc *vaultTenantConnection) initCertificates(config map[string]interface{}
return fmt.Errorf("failed to get client certificate from secret %s: %w", vaultCAFromSecret, cErr)
}
}
cer, err := file.CreateTempFile("vault-ca-cert", cert)
if err != nil {
return fmt.Errorf("failed to create temporary file for Vault client certificate: %w", err)
cer, ferr := file.CreateTempFile("vault-ca-cert", cert)
if ferr != nil {
return fmt.Errorf("failed to create temporary file for Vault client certificate: %w", ferr)
}
vaultConfig[api.EnvVaultClientCert] = cer.Name()
}
Expand Down
23 changes: 12 additions & 11 deletions internal/rbd/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,69 +438,70 @@ func (ri *rbdImage) RemoveDEK(ctx context.Context, volumeID string) error {
return nil
}

// GetEncryptionPassphraseSize returns the value of `encryptionPassphraseSize`
// GetEncryptionPassphraseSize returns the value of `encryptionPassphraseSize`.
func GetEncryptionPassphraseSize() int {
return encryptionPassphraseSize
}

// RotateKey processes the key rotation for the RBD Volume
// RotateKey processes the key rotation for the RBD Volume.
func (rv *rbdVolume) RotateEncryptionKey(ctx context.Context) error {
if !rv.isBlockEncrypted() {
return fmt.Errorf("key rotation not supported for the encryption type")
return errors.New("key rotation not supported for the encryption type")
}

// Verify that the underlying device has been setup for encryption
currState, err := rv.checkRbdImageEncrypted(ctx)
if err != nil {
return fmt.Errorf("error: %v while checking encrpytion state", err)
return fmt.Errorf("error while checking encryption state: %w", err)
}

if currState != rbdImageEncrypted {
return fmt.Errorf("key rotation not supported for unencrypted device")
return errors.New("key rotation not supported for unencrypted device")
}

// Get the device path for the underlying image
useNbd := rv.Mounter == rbdNbdMounter && hasNBD
devicePath, found := waitForPath(ctx, rv.Pool, rv.RadosNamespace, rv.RbdImageName, 1, useNbd)
if !found {
return fmt.Errorf("unable to get the device path for the image")
return errors.New("unable to get the device path for the image")
}

// Step 1: Get the current passphrase
oldPassphrase, err := rv.blockEncryption.GetCryptoPassphrase(ctx, rv.VolID)
if err != nil {
return fmt.Errorf("error in fetching the current passphrase: %v", err)
return fmt.Errorf("error in fetching the current passphrase: %w", err)
}

// Step 2: Add current key to slot 1
err = util.LuksAddKey(devicePath, oldPassphrase, oldPassphrase, "1")
if err != nil {
return fmt.Errorf("error in adding curr key to slot 1: %v", err)
return fmt.Errorf("error in adding curr key to slot 1: %w", err)
}

// Step 3: Generate new key and add it to slot 0
newPassphrase, err := rv.blockEncryption.GetNewCryptoPassphrase(
GetEncryptionPassphraseSize())
if err != nil {
return fmt.Errorf("error in generating a new passphrase: %v", err)
return fmt.Errorf("error in generating a new passphrase: %w", err)
}

err = util.LuksAddKey(devicePath, oldPassphrase, newPassphrase, "0")
if err != nil {
return fmt.Errorf("error in adding the new key to slot 0: %v", err)
return fmt.Errorf("error in adding the new key to slot 0: %w", err)
}

// Step 4: Add the new key to KMS
err = rv.blockEncryption.StoreCryptoPassphrase(ctx, rv.VolID, newPassphrase)
if err != nil {
return fmt.Errorf("failed to update the new key into the KMS: %v", err)
return fmt.Errorf("failed to update the new key into the KMS: %w", err)
}

// Step 5: Remove the old key from slot 1
// We use the newPassphrase to authenticate LUKS here
err = util.LuksRemoveKey(devicePath, newPassphrase, "1")
if err != nil {
// FIXME: Discuss if we should return an error here
//nolint:nilerr // Intentional
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/util/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func (ve *VolumeEncryption) GetCryptoPassphrase(ctx context.Context, volumeID st
return ve.KMS.DecryptDEK(ctx, volumeID, passphrase)
}

// GetNewCryptoPassphrase returns a random passphrase of given length
// GetNewCryptoPassphrase returns a random passphrase of given length.
func (ve *VolumeEncryption) GetNewCryptoPassphrase(length int) (string, error) {
return generateNewEncryptionPassphrase(length)
}
Expand Down
40 changes: 19 additions & 21 deletions internal/util/cryptsetup.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func LuksStatus(mapperFile string) (string, string, error) {
return execCryptsetupCommand(nil, "status", mapperFile)
}

// LuksAddKey adds a new key to the specified slot
// LuksAddKey adds a new key to the specified slot.
func LuksAddKey(devicePath, passphrase, newPassphrase, slot string) error {
passFile, err := file.CreateTempFile("luks-", passphrase)
if err != nil {
Expand All @@ -86,8 +86,8 @@ func LuksAddKey(devicePath, passphrase, newPassphrase, slot string) error {
_, stderr, err := execCryptsetupCommand(
nil,
"--verbose",
fmt.Sprintf("--key-file=%s", passFile.Name()),
fmt.Sprintf("--key-slot=%s", slot),
"--key-file="+passFile.Name(),
"--key-slot="+slot,
"luksAddKey",
devicePath,
newPassFile.Name(),
Expand All @@ -106,9 +106,9 @@ func LuksAddKey(devicePath, passphrase, newPassphrase, slot string) error {
if strings.Contains(stderr, fmt.Sprintf("Key slot %s is full", slot)) {
// The given slot already has a key
// Check if it is the one that we want to update with
exists, err := LuksVerifyKey(devicePath, newPassphrase, slot)
if err != nil {
return err
exists, fErr := LuksVerifyKey(devicePath, newPassphrase, slot)
if fErr != nil {
return fErr
}

// Verification passed, return early
Expand All @@ -119,15 +119,15 @@ func LuksAddKey(devicePath, passphrase, newPassphrase, slot string) error {
// Else, we remove the key from the given slot and add the new one
// Note: we use existing passphrase here as we are not yet sure if
// the newPassphrase is present in the headers
err = LuksRemoveKey(devicePath, passphrase, slot)
if err != nil {
return err
fErr = LuksRemoveKey(devicePath, passphrase, slot)
if fErr != nil {
return fErr
}

// Now the slot is free, add the new key to it
err = LuksAddKey(devicePath, passphrase, newPassphrase, slot)
if err != nil {
return err
fErr = LuksAddKey(devicePath, passphrase, newPassphrase, slot)
if fErr != nil {
return fErr
}

// No errors, we good.
Expand All @@ -138,7 +138,7 @@ func LuksAddKey(devicePath, passphrase, newPassphrase, slot string) error {
return err
}

// LuksRemoveKey removes the key by killing the specified slot
// LuksRemoveKey removes the key by killing the specified slot.
func LuksRemoveKey(devicePath, passphrase, slot string) error {
keyFile, err := file.CreateTempFile("luks-", passphrase)
if err != nil {
Expand All @@ -149,23 +149,22 @@ func LuksRemoveKey(devicePath, passphrase, slot string) error {
_, stderr, err := execCryptsetupCommand(
nil,
"--verbose",
fmt.Sprintf("--key-file=%s", keyFile.Name()),
"--key-file="+keyFile.Name(),
"luksKillSlot",
devicePath,
slot,
)

if err != nil {
// If a slot is not active, don't treat that as an error
if !strings.Contains(stderr, fmt.Sprintf("Keyslot %s is not active.", slot)) {
return fmt.Errorf("failed to kill slot %s for device %s: %v", slot, devicePath, err)
return fmt.Errorf("failed to kill slot %s for device %s: %w", slot, devicePath, err)
}
}

return nil
}

// LuksVerifyKey verifies that a key exists in a given slot
// LuksVerifyKey verifies that a key exists in a given slot.
func LuksVerifyKey(devicePath, passphrase, slot string) (bool, error) {
// Create a temp file that we will use to open the device
keyFile, err := file.CreateTempFile("luks-", passphrase)
Expand All @@ -178,13 +177,12 @@ func LuksVerifyKey(devicePath, passphrase, slot string) (bool, error) {
_, stderr, err := execCryptsetupCommand(
nil,
"--verbose",
fmt.Sprintf("--key-file=%s", keyFile.Name()),
fmt.Sprintf("--key-slot=%s", slot),
"--key-file="+keyFile.Name(),
"--key-slot="+slot,
"luksChangeKey",
devicePath,
keyFile.Name(),
)

if err != nil {
// If the passphrase doesn't match the key in given slot
if strings.Contains(stderr, "No key available with this passphrase.") {
Expand All @@ -193,7 +191,7 @@ func LuksVerifyKey(devicePath, passphrase, slot string) (bool, error) {
}

// Otherwise it was something else, return the wrapped error
return false, fmt.Errorf("failed to verify key in slot %s for device %s: %v", slot, devicePath, err)
return false, fmt.Errorf("failed to verify key in slot %s for device %s: %w", slot, devicePath, err)
}

return true, nil
Expand Down
6 changes: 6 additions & 0 deletions internal/util/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
)

func TestCreateTempFile_WithValidContent(t *testing.T) {
t.Parallel()

content := "Valid Content"

file, err := CreateTempFile("test-", content)
Expand All @@ -45,6 +47,8 @@ func TestCreateTempFile_WithValidContent(t *testing.T) {
}

func TestCreateTempFile_WithEmptyContent(t *testing.T) {
t.Parallel()

content := ""

file, err := CreateTempFile("test-", content)
Expand All @@ -68,6 +72,8 @@ func TestCreateTempFile_WithEmptyContent(t *testing.T) {
}

func TestCreateTempFile_WithLargeContent(t *testing.T) {
t.Parallel()

content := string(make([]byte, 1<<20))

file, err := CreateTempFile("test-", content)
Expand Down

0 comments on commit bce2d76

Please sign in to comment.