Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kms: refactor functions to accept a context parameter #4477

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/cephfs/controllerserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (cs *ControllerServer) createBackingVolume(
&volOptions.SubVolume, volOptions.ClusterID, cs.ClusterName, cs.SetMetadata)

if sID != nil {
err = parentVolOpt.CopyEncryptionConfig(volOptions, sID.SnapshotID, vID.VolumeID)
err = parentVolOpt.CopyEncryptionConfig(ctx, volOptions, sID.SnapshotID, vID.VolumeID)
if err != nil {
return status.Error(codes.Internal, err.Error())
}
Expand All @@ -86,7 +86,7 @@ func (cs *ControllerServer) createBackingVolume(
}

if parentVolOpt != nil {
err = parentVolOpt.CopyEncryptionConfig(volOptions, pvID.VolumeID, vID.VolumeID)
err = parentVolOpt.CopyEncryptionConfig(ctx, volOptions, pvID.VolumeID, vID.VolumeID)
if err != nil {
return status.Error(codes.Internal, err.Error())
}
Expand Down Expand Up @@ -596,7 +596,7 @@ func (cs *ControllerServer) cleanUpBackingVolume(
// GetSecret enabled KMS the DEKs are stored by
// fscrypt on the volume that is going to be deleted anyway.
log.DebugLog(ctx, "going to remove DEK for integrated store %q (fscrypt)", volOptions.Encryption.GetID())
if err := volOptions.Encryption.RemoveDEK(volID.VolumeID); err != nil {
if err := volOptions.Encryption.RemoveDEK(ctx, volID.VolumeID); err != nil {
log.WarningLog(ctx, "failed to clean the passphrase for volume %q (file encryption): %s",
volOptions.VolID, err)
}
Expand Down Expand Up @@ -907,7 +907,7 @@ func (cs *ControllerServer) CreateSnapshot(
// Use same encryption KMS than source volume and copy the passphrase. The passphrase becomes
// available under the snapshot id for CreateVolume to use this snap as a backing volume
snapVolOptions := store.VolumeOptions{}
err = parentVolOptions.CopyEncryptionConfig(&snapVolOptions, sourceVolID, sID.SnapshotID)
err = parentVolOptions.CopyEncryptionConfig(ctx, &snapVolOptions, sourceVolID, sID.SnapshotID)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
Expand Down
10 changes: 5 additions & 5 deletions internal/cephfs/store/volumeoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ func IsEncrypted(ctx context.Context, volOptions map[string]string) (bool, error

// CopyEncryptionConfig copies passphrases and initializes a fresh
// Encryption struct if necessary from (vo, vID) to (cp, cpVID).
func (vo *VolumeOptions) CopyEncryptionConfig(cp *VolumeOptions, vID, cpVID string) error {
func (vo *VolumeOptions) CopyEncryptionConfig(ctx context.Context, cp *VolumeOptions, vID, cpVID string) error {
var err error

if !vo.IsEncrypted() {
Expand All @@ -916,21 +916,21 @@ func (vo *VolumeOptions) CopyEncryptionConfig(cp *VolumeOptions, vID, cpVID stri
if cp.Encryption == nil {
cp.Encryption, err = util.NewVolumeEncryption(vo.Encryption.GetID(), vo.Encryption.KMS)
if errors.Is(err, util.ErrDEKStoreNeeded) {
_, err := vo.Encryption.KMS.GetSecret("")
_, err := vo.Encryption.KMS.GetSecret(ctx, "")
if errors.Is(err, kmsapi.ErrGetSecretUnsupported) {
return err
}
}
}

if vo.Encryption.KMS.RequiresDEKStore() == kmsapi.DEKStoreIntegrated {
passphrase, err := vo.Encryption.GetCryptoPassphrase(vID)
passphrase, err := vo.Encryption.GetCryptoPassphrase(ctx, vID)
if err != nil {
return fmt.Errorf("failed to fetch passphrase for %q (%+v): %w",
vID, vo, err)
}

err = cp.Encryption.StoreCryptoPassphrase(cpVID, passphrase)
err = cp.Encryption.StoreCryptoPassphrase(ctx, cpVID, passphrase)
if err != nil {
return fmt.Errorf("failed to store passphrase for %q (%+v): %w",
cpVID, cp, err)
Expand Down Expand Up @@ -962,7 +962,7 @@ func (vo *VolumeOptions) ConfigureEncryption(
// store. Since not all "metadata" KMS support
// GetSecret, test for support here. Postpone any
// other error handling
_, err := vo.Encryption.KMS.GetSecret("")
_, err := vo.Encryption.KMS.GetSecret(ctx, "")
if errors.Is(err, kmsapi.ErrGetSecretUnsupported) {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions internal/kms/aws_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (kms *awsMetadataKMS) getService() (*awsKMS.KMS, error) {
}

// EncryptDEK uses the Amazon KMS and the configured CMK to encrypt the DEK.
func (kms *awsMetadataKMS) EncryptDEK(volumeID, plainDEK string) (string, error) {
func (kms *awsMetadataKMS) EncryptDEK(ctx context.Context, volumeID, plainDEK string) (string, error) {
svc, err := kms.getService()
if err != nil {
return "", fmt.Errorf("could not get KMS service: %w", err)
Expand All @@ -205,7 +205,7 @@ func (kms *awsMetadataKMS) EncryptDEK(volumeID, plainDEK string) (string, error)
}

// DecryptDEK uses the Amazon KMS and the configured CMK to decrypt the DEK.
func (kms *awsMetadataKMS) DecryptDEK(volumeID, encryptedDEK string) (string, error) {
func (kms *awsMetadataKMS) DecryptDEK(ctx context.Context, volumeID, encryptedDEK string) (string, error) {
svc, err := kms.getService()
if err != nil {
return "", fmt.Errorf("could not get KMS service: %w", err)
Expand All @@ -227,6 +227,6 @@ func (kms *awsMetadataKMS) DecryptDEK(volumeID, encryptedDEK string) (string, er
return string(result.Plaintext), nil
}

func (kms *awsMetadataKMS) GetSecret(volumeID string) (string, error) {
func (kms *awsMetadataKMS) GetSecret(ctx context.Context, volumeID string) (string, error) {
return "", ErrGetSecretUnsupported
}
4 changes: 2 additions & 2 deletions internal/kms/aws_sts_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func (as *awsSTSMetadataKMS) getServiceWithSTS() (*awsKMS.KMS, error) {
}

// EncryptDEK uses the Amazon KMS and the configured CMK to encrypt the DEK.
func (as *awsSTSMetadataKMS) EncryptDEK(_, plainDEK string) (string, error) {
func (as *awsSTSMetadataKMS) EncryptDEK(ctx context.Context, _, plainDEK string) (string, error) {
svc, err := as.getServiceWithSTS()
if err != nil {
return "", fmt.Errorf("failed to get KMS service: %w", err)
Expand All @@ -213,7 +213,7 @@ func (as *awsSTSMetadataKMS) EncryptDEK(_, plainDEK string) (string, error) {
}

// DecryptDEK uses the Amazon KMS and the configured CMK to decrypt the DEK.
func (as *awsSTSMetadataKMS) DecryptDEK(_, encryptedDEK string) (string, error) {
func (as *awsSTSMetadataKMS) DecryptDEK(ctx context.Context, _, encryptedDEK string) (string, error) {
svc, err := as.getServiceWithSTS()
if err != nil {
return "", fmt.Errorf("failed to get KMS service: %w", err)
Expand Down
10 changes: 5 additions & 5 deletions internal/kms/keyprotect.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,14 @@ func (kms *keyProtectKMS) getService() error {
}

// EncryptDEK uses the KeyProtect KMS and the configured CRK to encrypt the DEK.
func (kms *keyProtectKMS) EncryptDEK(volumeID, plainDEK string) (string, error) {
func (kms *keyProtectKMS) EncryptDEK(ctx context.Context, volumeID, plainDEK string) (string, error) {
if err := kms.getService(); err != nil {
return "", fmt.Errorf("could not get KMS service: %w", err)
}

dekByteSlice := []byte(plainDEK)
aadVolID := []string{volumeID}
result, err := kms.client.Wrap(context.TODO(), kms.customerRootKey, dekByteSlice, &aadVolID)
result, err := kms.client.Wrap(ctx, kms.customerRootKey, dekByteSlice, &aadVolID)
if err != nil {
return "", fmt.Errorf("failed to wrap the DEK: %w", err)
}
Expand All @@ -223,7 +223,7 @@ func (kms *keyProtectKMS) EncryptDEK(volumeID, plainDEK string) (string, error)
}

// DecryptDEK uses the Key protect KMS and the configured CRK to decrypt the DEK.
func (kms *keyProtectKMS) DecryptDEK(volumeID, encryptedDEK string) (string, error) {
func (kms *keyProtectKMS) DecryptDEK(ctx context.Context, volumeID, encryptedDEK string) (string, error) {
if err := kms.getService(); err != nil {
return "", fmt.Errorf("could not get KMS service: %w", err)
}
Expand All @@ -235,14 +235,14 @@ func (kms *keyProtectKMS) DecryptDEK(volumeID, encryptedDEK string) (string, err
}

aadVolID := []string{volumeID}
result, err := kms.client.Unwrap(context.TODO(), kms.customerRootKey, ciphertextBlob, &aadVolID)
result, err := kms.client.Unwrap(ctx, kms.customerRootKey, ciphertextBlob, &aadVolID)
if err != nil {
return "", fmt.Errorf("failed to unwrap the DEK: %w", err)
}

return string(result), nil
}

func (kms *keyProtectKMS) GetSecret(volumeID string) (string, error) {
func (kms *keyProtectKMS) GetSecret(ctx context.Context, volumeID string) (string, error) {
return "", ErrGetSecretUnsupported
}
6 changes: 3 additions & 3 deletions internal/kms/kmip.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func initKMIPKMS(args ProviderInitArgs) (EncryptionKMS, error) {
}

// EncryptDEK uses the KMIP encrypt operation to encrypt the DEK.
func (kms *kmipKMS) EncryptDEK(_, plainDEK string) (string, error) {
func (kms *kmipKMS) EncryptDEK(ctx context.Context, _, plainDEK string) (string, error) {
conn, err := kms.connect()
if err != nil {
return "", err
Expand Down Expand Up @@ -236,7 +236,7 @@ func (kms *kmipKMS) EncryptDEK(_, plainDEK string) (string, error) {
}

// DecryptDEK uses the KMIP decrypt operation to decrypt the DEK.
func (kms *kmipKMS) DecryptDEK(_, encryptedDEK string) (string, error) {
func (kms *kmipKMS) DecryptDEK(ctx context.Context, _, encryptedDEK string) (string, error) {
conn, err := kms.connect()
if err != nil {
return "", err
Expand Down Expand Up @@ -500,7 +500,7 @@ func (kms *kmipKMS) verifyResponse(
return &batchItem, nil
}

func (kms *kmipKMS) GetSecret(volumeID string) (string, error) {
func (kms *kmipKMS) GetSecret(ctx context.Context, volumeID string) (string, error) {
return "", ErrGetSecretUnsupported
}

Expand Down
18 changes: 9 additions & 9 deletions internal/kms/kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,18 +331,18 @@ type EncryptionKMS interface {
// EncryptDEK provides a way for a KMS to encrypt a DEK. In case the
// encryption is done transparently inside the KMS service, the
// function can return an unencrypted value.
EncryptDEK(volumeID, plainDEK string) (string, error)
EncryptDEK(ctx context.Context, volumeID, plainDEK string) (string, error)

// DecryptDEK provides a way for a KMS to decrypt a DEK. In case the
// encryption is done transparently inside the KMS service, the
// function does not need to do anything except return the encyptedDEK
// as it was received.
DecryptDEK(volumeID, encyptedDEK string) (string, error)
DecryptDEK(ctx context.Context, volumeID, encyptedDEK string) (string, error)

// GetSecret allows external key management systems to
// retrieve keys used in EncryptDEK / DecryptDEK to use them
// directly. Example: fscrypt uses this to unlock raw protectors
GetSecret(volumeID string) (string, error)
GetSecret(ctx context.Context, volumeID string) (string, error)
}

// DEKStoreType describes what DEKStore needs to be configured when using a
Expand All @@ -364,11 +364,11 @@ const (
// the KMS can not store passphrases for volumes.
type DEKStore interface {
// StoreDEK saves the DEK in the configured store.
StoreDEK(volumeID string, dek string) error
StoreDEK(ctx context.Context, volumeID string, dek string) error
// FetchDEK reads the DEK from the configured store and returns it.
FetchDEK(volumeID string) (string, error)
FetchDEK(ctx context.Context, volumeID string) (string, error)
// RemoveDEK deletes the DEK from the configured store.
RemoveDEK(volumeID string) error
RemoveDEK(ctx context.Context, volumeID string) error
}

// integratedDEK is a DEKStore that can not be configured. Either the KMS does
Expand All @@ -380,15 +380,15 @@ func (i integratedDEK) RequiresDEKStore() DEKStoreType {
return DEKStoreIntegrated
}

func (i integratedDEK) EncryptDEK(volumeID, plainDEK string) (string, error) {
func (i integratedDEK) EncryptDEK(ctx context.Context, volumeID, plainDEK string) (string, error) {
return plainDEK, nil
}

func (i integratedDEK) DecryptDEK(volumeID, encyptedDEK string) (string, error) {
func (i integratedDEK) DecryptDEK(ctx context.Context, volumeID, encyptedDEK string) (string, error) {
return encyptedDEK, nil
}

func (i integratedDEK) GetSecret(volumeID string) (string, error) {
func (i integratedDEK) GetSecret(ctx context.Context, volumeID string) (string, error) {
return "", ErrGetSecretIntegrated
}

Expand Down
18 changes: 9 additions & 9 deletions internal/kms/secretskms.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,19 @@ func (kms secretsKMS) Destroy() {
}

// FetchDEK returns passphrase from Kubernetes secrets.
func (kms secretsKMS) FetchDEK(key string) (string, error) {
func (kms secretsKMS) FetchDEK(ctx context.Context, key string) (string, error) {
return kms.passphrase, nil
}

// StoreDEK does nothing, as there is no passphrase per key (volume), so
// no need to store is anywhere.
func (kms secretsKMS) StoreDEK(key, value string) error {
func (kms secretsKMS) StoreDEK(ctx context.Context, key, value string) error {
return nil
}

// RemoveDEK is doing nothing as no new passphrases are saved with
// secretsKMS.
func (kms secretsKMS) RemoveDEK(key string) error {
func (kms secretsKMS) RemoveDEK(ctx context.Context, key string) error {
return nil
}

Expand Down Expand Up @@ -206,9 +206,9 @@ type encryptedMetedataDEK struct {
// the secretsKMS and the volumeID.
// The resulting encryptedDEK contains a JSON with the encrypted DEK and the
// nonce that was used for encrypting.
func (kms secretsMetadataKMS) EncryptDEK(volumeID, plainDEK string) (string, error) {
func (kms secretsMetadataKMS) EncryptDEK(ctx context.Context, volumeID, plainDEK string) (string, error) {
// use the passphrase from the secretKMS
passphrase, err := kms.secretsKMS.FetchDEK(volumeID)
passphrase, err := kms.secretsKMS.FetchDEK(ctx, volumeID)
if err != nil {
return "", fmt.Errorf("failed to get passphrase: %w", err)
}
Expand Down Expand Up @@ -236,9 +236,9 @@ func (kms secretsMetadataKMS) EncryptDEK(volumeID, plainDEK string) (string, err

// DecryptDEK takes the JSON formatted `encryptedMetadataDEK` contents, and it
// fetches secretKMS passphrase to decrypt the DEK.
func (kms secretsMetadataKMS) DecryptDEK(volumeID, encryptedDEK string) (string, error) {
func (kms secretsMetadataKMS) DecryptDEK(ctx context.Context, volumeID, encryptedDEK string) (string, error) {
// use the passphrase from the secretKMS
passphrase, err := kms.secretsKMS.FetchDEK(volumeID)
passphrase, err := kms.secretsKMS.FetchDEK(ctx, volumeID)
if err != nil {
return "", fmt.Errorf("failed to get passphrase: %w", err)
}
Expand All @@ -263,9 +263,9 @@ func (kms secretsMetadataKMS) DecryptDEK(volumeID, encryptedDEK string) (string,
return string(dek), nil
}

func (kms secretsMetadataKMS) GetSecret(volumeID string) (string, error) {
func (kms secretsMetadataKMS) GetSecret(ctx context.Context, volumeID string) (string, error) {
// use the passphrase from the secretKMS
return kms.secretsKMS.FetchDEK(volumeID)
return kms.secretsKMS.FetchDEK(ctx, volumeID)
}

// generateCipher returns a AEAD cipher based on a passphrase and salt
Expand Down
9 changes: 6 additions & 3 deletions internal/kms/secretskms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package kms

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -103,19 +104,21 @@ func TestWorkflowSecretsMetadataKMS(t *testing.T) {
// plainDEK is the (LUKS) passphrase for the volume
plainDEK := "usually created with generateNewEncryptionPassphrase()"

encryptedDEK, err := kms.EncryptDEK(volumeID, plainDEK)
ctx := context.TODO()

encryptedDEK, err := kms.EncryptDEK(ctx, volumeID, plainDEK)
assert.NoError(t, err)
assert.NotEqual(t, "", encryptedDEK)
assert.NotEqual(t, plainDEK, encryptedDEK)

// with an incorrect volumeID, decrypting should fail
decryptedDEK, err := kms.DecryptDEK("incorrect-volumeID", encryptedDEK)
decryptedDEK, err := kms.DecryptDEK(ctx, "incorrect-volumeID", encryptedDEK)
assert.Error(t, err)
assert.Equal(t, "", decryptedDEK)
assert.NotEqual(t, plainDEK, decryptedDEK)

// with the right volumeID, decrypting should return the plainDEK
decryptedDEK, err = kms.DecryptDEK(volumeID, encryptedDEK)
decryptedDEK, err = kms.DecryptDEK(ctx, volumeID, encryptedDEK)
assert.NoError(t, err)
assert.NotEqual(t, "", decryptedDEK)
assert.Equal(t, plainDEK, decryptedDEK)
Expand Down
7 changes: 4 additions & 3 deletions internal/kms/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package kms

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -395,7 +396,7 @@ func initVaultKMS(args ProviderInitArgs) (EncryptionKMS, error) {

// FetchDEK returns passphrase from Vault. The passphrase is stored in a
// data.data.passphrase structure.
func (kms *vaultKMS) FetchDEK(key string) (string, error) {
func (kms *vaultKMS) FetchDEK(ctx context.Context, key string) (string, error) {
// Since the second return variable loss.Version is not used, there it is ignored.
s, _, err := kms.secrets.GetSecret(filepath.Join(kms.vaultPassphrasePath, key), kms.keyContext)
if err != nil {
Expand All @@ -415,7 +416,7 @@ func (kms *vaultKMS) FetchDEK(key string) (string, error) {
}

// StoreDEK saves new passphrase in Vault.
func (kms *vaultKMS) StoreDEK(key, value string) error {
func (kms *vaultKMS) StoreDEK(ctx context.Context, key, value string) error {
data := map[string]interface{}{
"data": map[string]string{
"passphrase": value,
Expand All @@ -433,7 +434,7 @@ func (kms *vaultKMS) StoreDEK(key, value string) error {
}

// RemoveDEK deletes passphrase from Vault.
func (kms *vaultKMS) RemoveDEK(key string) error {
func (kms *vaultKMS) RemoveDEK(ctx context.Context, key string) error {
pathKey := filepath.Join(kms.vaultPassphrasePath, key)
err := kms.secrets.DeleteSecret(pathKey, kms.getDeleteKeyContext())
if err != nil {
Expand Down
Loading