diff --git a/internal/csi-common/controllerserver-default.go b/internal/csi-common/controllerserver-default.go index fcc261d78a6..2ce290928b9 100644 --- a/internal/csi-common/controllerserver-default.go +++ b/internal/csi-common/controllerserver-default.go @@ -47,3 +47,19 @@ func (cs *DefaultControllerServer) ControllerGetCapabilities( Capabilities: cs.Driver.capabilities, }, nil } + +// GroupControllerGetCapabilities implements the default +// GroupControllerGetCapabilities GRPC callout. +func (cs *DefaultControllerServer) GroupControllerGetCapabilities( + ctx context.Context, + req *csi.GroupControllerGetCapabilitiesRequest, +) (*csi.GroupControllerGetCapabilitiesResponse, error) { + log.TraceLog(ctx, "Using default GroupControllerGetCapabilities") + if cs.Driver == nil { + return nil, status.Error(codes.Unimplemented, "Group controller server is not enabled") + } + + return &csi.GroupControllerGetCapabilitiesResponse{ + Capabilities: cs.Driver.groupCapabilities, + }, nil +} diff --git a/internal/csi-common/driver.go b/internal/csi-common/driver.go index 31c89070e08..4062845b4cb 100644 --- a/internal/csi-common/driver.go +++ b/internal/csi-common/driver.go @@ -31,9 +31,10 @@ type CSIDriver struct { nodeID string version string // topology constraints that this nodeserver will advertise - topology map[string]string - capabilities []*csi.ControllerServiceCapability - vc []*csi.VolumeCapability_AccessMode + topology map[string]string + capabilities []*csi.ControllerServiceCapability + groupCapabilities []*csi.GroupControllerServiceCapability + vc []*csi.VolumeCapability_AccessMode } // NewCSIDriver Creates a NewCSIDriver object. Assumes vendor @@ -116,3 +117,34 @@ func (d *CSIDriver) AddVolumeCapabilityAccessModes( func (d *CSIDriver) GetVolumeCapabilityAccessModes() []*csi.VolumeCapability_AccessMode { return d.vc } + +// AddControllerServiceCapabilities stores the group controller capabilities +// in driver object. +func (d *CSIDriver) AddGroupControllerServiceCapabilities(cl []csi.GroupControllerServiceCapability_RPC_Type) { + csc := make([]*csi.GroupControllerServiceCapability, 0, len(cl)) + + for _, c := range cl { + log.DefaultLog("Enabling group controller service capability: %v", c.String()) + csc = append(csc, NewGroupControllerServiceCapability(c)) + } + + d.groupCapabilities = csc +} + +// ValidateGroupControllerServiceRequest validates the group controller +// plugin capabilities. +// +//nolint:interfacer // c can be of type fmt.Stringer, but that does not make the API clearer +func (d *CSIDriver) ValidateGroupControllerServiceRequest(c csi.GroupControllerServiceCapability_RPC_Type) error { + if c == csi.GroupControllerServiceCapability_RPC_UNKNOWN { + return nil + } + + for _, capability := range d.groupCapabilities { + if c == capability.GetRpc().GetType() { + return nil + } + } + + return status.Error(codes.InvalidArgument, c.String()) +} diff --git a/internal/csi-common/server.go b/internal/csi-common/server.go index 13157f222fa..727ef57f80f 100644 --- a/internal/csi-common/server.go +++ b/internal/csi-common/server.go @@ -45,6 +45,7 @@ type Servers struct { IS csi.IdentityServer CS csi.ControllerServer NS csi.NodeServer + GS csi.GroupControllerServer } // NewNonBlockingGRPCServer return non-blocking GRPC. @@ -109,6 +110,9 @@ func (s *nonBlockingGRPCServer) serve(endpoint string, srv Servers) { if srv.NS != nil { csi.RegisterNodeServer(server, srv.NS) } + if srv.GS != nil { + csi.RegisterGroupControllerServer(server, srv.GS) + } log.DefaultLog("Listening for connections on address: %#v", listener.Addr()) err = server.Serve(listener) diff --git a/internal/csi-common/utils.go b/internal/csi-common/utils.go index 080f9df9347..daf6170eeb6 100644 --- a/internal/csi-common/utils.go +++ b/internal/csi-common/utils.go @@ -95,6 +95,18 @@ func NewControllerServiceCapability(ctrlCap csi.ControllerServiceCapability_RPC_ } } +// NewGroupControllerServiceCapability returns group controller capabilities. +func NewGroupControllerServiceCapability(ctrlCap csi.GroupControllerServiceCapability_RPC_Type, +) *csi.GroupControllerServiceCapability { + return &csi.GroupControllerServiceCapability{ + Type: &csi.GroupControllerServiceCapability_Rpc{ + Rpc: &csi.GroupControllerServiceCapability_RPC{ + Type: ctrlCap, + }, + }, + } +} + // NewMiddlewareServerOption creates a new grpc.ServerOption that configures a // common format for log messages and other gRPC related handlers. func NewMiddlewareServerOption() grpc.ServerOption { @@ -133,6 +145,13 @@ func getReqID(req interface{}) string { case *csi.NodeExpandVolumeRequest: reqID = r.VolumeId + + case *csi.CreateVolumeGroupSnapshotRequest: + reqID = r.Name + case *csi.DeleteVolumeGroupSnapshotRequest: + reqID = r.GroupSnapshotId + case *csi.GetVolumeGroupSnapshotRequest: + reqID = r.GroupSnapshotId } return reqID diff --git a/internal/csi-common/utils_test.go b/internal/csi-common/utils_test.go index e5687986aba..ddb16648a34 100644 --- a/internal/csi-common/utils_test.go +++ b/internal/csi-common/utils_test.go @@ -65,6 +65,16 @@ func TestGetReqID(t *testing.T) { &csi.NodeExpandVolumeRequest{ VolumeId: fakeID, }, + + &csi.CreateVolumeGroupSnapshotRequest{ + Name: fakeID, + }, + &csi.DeleteVolumeGroupSnapshotRequest{ + GroupSnapshotId: fakeID, + }, + &csi.GetVolumeGroupSnapshotRequest{ + GroupSnapshotId: fakeID, + }, } for _, r := range req { if got := getReqID(r); got != fakeID {