diff --git a/opentdf-dev.yaml b/opentdf-dev.yaml index bc096b42b7..2c4e422d5d 100644 --- a/opentdf-dev.yaml +++ b/opentdf-dev.yaml @@ -44,8 +44,9 @@ services: # policy is enabled by default in mode 'all' # policy: # enabled: true - # list_request_limit_default: 1000 - # list_request_limit_max: 2500 + # list_request_limit_default: 1000 + # list_request_limit_max: 2500 + # cache_refresh_interval_seconds: 15 server: tls: enabled: false diff --git a/opentdf-example.yaml b/opentdf-example.yaml index d15d4a4477..e2891cdf5d 100644 --- a/opentdf-example.yaml +++ b/opentdf-example.yaml @@ -37,6 +37,7 @@ services: # enabled: true # list_request_limit_default: 1000 # list_request_limit_max: 2500 + # cache_refresh_interval_seconds: 15 server: auth: enabled: true diff --git a/service/pkg/config/config.go b/service/pkg/config/config.go index 50f49b62a3..ec26545099 100644 --- a/service/pkg/config/config.go +++ b/service/pkg/config/config.go @@ -14,6 +14,9 @@ import ( // ChangeHook is a function invoked when the configuration changes. type ChangeHook func(configServices ServicesMap) error +// ServicesStartedHook is a function invoked when all service registrations are complete. +type ServicesStartedHook func(context.Context) error + // Config structure holding all services. type ServicesMap map[string]ServiceConfig @@ -47,6 +50,8 @@ type Config struct { // Trace is for configuring open telemetry based tracing. Trace tracing.Config `mapstructure:"trace"` + // onServicesStartedHooks is a list of functions to call when all service registrations are complete. + onServicesStartedHooks []ServicesStartedHook // onConfigChangeHooks is a list of functions to call when the configuration changes. onConfigChangeHooks []ChangeHook // loaders is a list of configuration loaders. @@ -110,6 +115,11 @@ func (c *Config) AddOnConfigChangeHook(hook ChangeHook) { c.onConfigChangeHooks = append(c.onConfigChangeHooks, hook) } +// AddOnServicesStartedHook adds a hook to the list of hooks to call when all service registrations are complete. +func (c *Config) AddOnServicesStartedHook(hook ServicesStartedHook) { + c.onServicesStartedHooks = append(c.onServicesStartedHooks, hook) +} + // Watch starts watching the configuration for changes in all config loaders. func (c *Config) Watch(ctx context.Context) error { if len(c.loaders) == 0 { @@ -123,6 +133,19 @@ func (c *Config) Watch(ctx context.Context) error { return nil } +// RunServicesStartedHooks triggers the service hooks that run once all services are live. +func (c *Config) RunServicesStartedHooks(ctx context.Context) error { + if len(c.onServicesStartedHooks) == 0 { + return nil + } + for _, hook := range c.onServicesStartedHooks { + if err := hook(ctx); err != nil { + return err + } + } + return nil +} + // Close invokes close method on all config loaders. func (c *Config) Close(_ context.Context) error { if len(c.loaders) == 0 { diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index 505e5f4d27..23f42ea3f9 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -194,6 +194,10 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF return func() {}, fmt.Errorf("failed to register config update hook: %w", err) } + if err := svc.RegisterOnServicesStartedHook(ctx, cfg.AddOnServicesStartedHook); err != nil { + return func() {}, fmt.Errorf("failed to register on complete service registration hook: %w", err) + } + // Register Connect RPC Services if err := svc.RegisterConnectRPCServiceHandler(ctx, otdf.ConnectRPC); err != nil { logger.Info("service did not register a connect-rpc handler", slog.String("namespace", ns)) diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index fd54ca9d0e..54a4a82184 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -300,6 +300,11 @@ func Start(f ...StartOptions) error { } defer cfg.Close(ctx) + // Run the services started hooks + if err := cfg.RunServicesStartedHooks(ctx); err != nil { + return fmt.Errorf("failed to run service registration complete hooks: %w", err) + } + // Start the server logger.Info("starting opentdf") if err := otdf.Start(); err != nil { diff --git a/service/pkg/serviceregistry/serviceregistry.go b/service/pkg/serviceregistry/serviceregistry.go index 0efb58e096..8806625747 100644 --- a/service/pkg/serviceregistry/serviceregistry.go +++ b/service/pkg/serviceregistry/serviceregistry.go @@ -59,6 +59,8 @@ type ( RegisterFunc[S any] func(RegistrationParams) (impl S, HandlerServer HandlerServer) // Allow services to implement handling for config changes as direced by caller OnConfigUpdateHook func(context.Context, config.ServiceConfig) error + // Allow services to implement a callback to run when all services are registered + OnServicesStartedHook func(context.Context) error ) // DBRegister is a struct that holds the information needed to register a service with a database @@ -81,6 +83,7 @@ type IService interface { IsStarted() bool Shutdown() error RegisterConfigUpdateHook(ctx context.Context, hookAppender func(config.ChangeHook)) error + RegisterOnServicesStartedHook(ctx context.Context, hookAppender func(config.ServicesStartedHook)) error RegisterConnectRPCServiceHandler(context.Context, *server.ConnectRPC) error RegisterGRPCGatewayHandler(context.Context, *runtime.ServeMux, *grpc.ClientConn) error RegisterHTTPHandlers(context.Context, *runtime.ServeMux) error @@ -110,6 +113,8 @@ type ServiceOptions[S any] struct { ServiceDesc *grpc.ServiceDesc // OnConfigUpdate is a hook to handle in-service actions when config changes OnConfigUpdate OnConfigUpdateHook + // OnServicesStarted is a hook to handle in-service actions that should run when all services are registered + OnServicesStarted OnServicesStartedHook // RegisterFunc is the function that will be called to register the service RegisterFunc RegisterFunc[S] // HTTPHandlerFunc is the function that will be called to register extra http handlers @@ -192,6 +197,22 @@ func (s Service[S]) RegisterConfigUpdateHook(ctx context.Context, hookAppender f return nil } +// RegisterOnServicesStartedHook appends a registered service's onServicesStartedHook to any watching services. +func (s Service[S]) RegisterOnServicesStartedHook(_ context.Context, hookAppender func(config.ServicesStartedHook)) error { + // If no hook is registered, exit + if s.OnServicesStarted != nil { + var onChange config.ServicesStartedHook = func(ctx context.Context) error { + slog.Debug("OnServicesStarted hook called", + slog.String("namespace", s.GetNamespace()), + slog.String("service", s.GetServiceDesc().ServiceName), + ) + return s.OnServicesStarted(ctx) + } + hookAppender(onChange) + } + return nil +} + func (s Service[S]) RegisterConnectRPCServiceHandler(_ context.Context, connectRPC *server.ConnectRPC) error { if s.ConnectRPCFunc == nil { return errors.New("service did not register a handler") diff --git a/service/policy/attributes/attributes.go b/service/policy/attributes/attributes.go index c2ec32a430..cc19654ff5 100644 --- a/service/policy/attributes/attributes.go +++ b/service/policy/attributes/attributes.go @@ -6,6 +6,7 @@ import ( "log/slog" "connectrpc.com/connect" + "github.com/opentdf/platform/protocol/go/common" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" @@ -24,6 +25,14 @@ type AttributesService struct { //nolint:revive // AttributesService is a valid logger *logger.Logger config *policyconfig.Config trace.Tracer + cache *policyconfig.EntitlementPolicyCache // Cache for attributes and subject mappings +} + +func OnServicesStarted(svc *AttributesService) serviceregistry.OnServicesStartedHook { + return func(ctx context.Context) error { + svc.cache = policyconfig.GetSharedEntitlementPolicyCache(ctx, svc.dbClient, svc.logger, svc.config) + return nil + } } func OnConfigUpdate(as *AttributesService) serviceregistry.OnConfigUpdateHook { @@ -44,16 +53,18 @@ func OnConfigUpdate(as *AttributesService) serviceregistry.OnConfigUpdateHook { func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[attributesconnect.AttributesServiceHandler] { as := new(AttributesService) onUpdateConfigHook := OnConfigUpdate(as) + onStartHook := OnServicesStarted(as) return &serviceregistry.Service[attributesconnect.AttributesServiceHandler]{ Close: as.Close, ServiceOptions: serviceregistry.ServiceOptions[attributesconnect.AttributesServiceHandler]{ - Namespace: ns, - DB: dbRegister, - ServiceDesc: &attributes.AttributesService_ServiceDesc, - ConnectRPCFunc: attributesconnect.NewAttributesServiceHandler, - GRPCGatewayFunc: attributes.RegisterAttributesServiceHandler, - OnConfigUpdate: onUpdateConfigHook, + Namespace: ns, + DB: dbRegister, + ServiceDesc: &attributes.AttributesService_ServiceDesc, + ConnectRPCFunc: attributesconnect.NewAttributesServiceHandler, + GRPCGatewayFunc: attributes.RegisterAttributesServiceHandler, + OnConfigUpdate: onUpdateConfigHook, + OnServicesStarted: onStartHook, RegisterFunc: func(srp serviceregistry.RegistrationParams) (attributesconnect.AttributesServiceHandler, serviceregistry.HandlerServer) { logger := srp.Logger cfg, err := policyconfig.GetSharedPolicyConfig(srp.Config) @@ -65,15 +76,19 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer as.logger = logger as.dbClient = policydb.NewClient(srp.DBClient, logger, int32(cfg.ListRequestLimitMax), int32(cfg.ListRequestLimitDefault)) as.config = cfg + return as, nil }, }, } } -// Close gracefully shuts down the service, closing the database client. +// Close gracefully shuts down the attributes service's cache and database client. func (s *AttributesService) Close() { s.logger.Info("gracefully shutting down attributes service") + if s.cache != nil { + s.cache.Stop() + } s.dbClient.Close() } @@ -120,6 +135,50 @@ func (s *AttributesService) ListAttributes(ctx context.Context, state := req.Msg.GetState().String() s.logger.Debug("listing attribute definitions", slog.String("state", state)) + // If active state and caching enabled, return from cache instead of DB + isActiveState := req.Msg.GetState() == common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE + if s.cache.IsEnabled() && isActiveState { + s.logger.Debug("returning cached attributes") + + limit := req.Msg.GetPagination().GetLimit() + if limit <= 0 { + limit = int32(s.config.ListRequestLimitDefault) + } + offset := req.Msg.GetPagination().GetOffset() + + // Validate limit against the max configured value + maxLimit := int32(s.config.ListRequestLimitMax) + if maxLimit > 0 && limit > maxLimit { + return nil, db.StatusifyError(db.ErrListLimitTooLarge, db.ErrTextListLimitTooLarge) + } + + // Get all cached attributes + cachedAttrs, total, err := s.cache.ListCachedAttributes(ctx, limit, offset) + if err != nil { + s.logger.Error("failed to retrieve cached attributes", slog.Any("error", err)) + return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) + } + + // Calculate next offset using the same logic as the DB implementation + var nextOffset int32 + next := offset + limit + if next < total { + nextOffset = next + } // else nextOffset remains 0 + + rsp := &attributes.ListAttributesResponse{ + Attributes: cachedAttrs, + Pagination: &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + }, + } + return connect.NewResponse(rsp), nil + } + + s.logger.Debug("querying database for attributes") + rsp, err := s.dbClient.ListAttributes(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) @@ -365,7 +424,6 @@ func (s *AttributesService) DeactivateAttributeValue(ctx context.Context, req *c s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams) rsp.Value = updated - return connect.NewResponse(rsp), nil } diff --git a/service/policy/config/cache.go b/service/policy/config/cache.go new file mode 100644 index 0000000000..093d22e2bc --- /dev/null +++ b/service/policy/config/cache.go @@ -0,0 +1,269 @@ +package config + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/dgraph-io/ristretto" + "github.com/eko/gocache/lib/v4/cache" + ristretto_store "github.com/eko/gocache/store/ristretto/v4" + "github.com/opentdf/platform/protocol/go/policy" + "github.com/opentdf/platform/service/logger" + policydb "github.com/opentdf/platform/service/policy/db" +) + +// Shared service-level instance of EntitlementPolicyCache (attributes and subject mappings) +var ( + entitlementPolicyCacheInstance *EntitlementPolicyCache + entitlementPolicyCacheOnce sync.Once +) + +const ( + attributesCacheKey = "attributes" + subjectMappingsCacheKey = "subject_mappings" + + numCounters = 1000 // Number of counters for the ristretto cache + maxCost = 100000000 // Maximum cost for the ristretto cache (100MB) + bufferItems = 64 // Buffer items for the ristretto cache +) + +// EntitlementPolicyCache caches attributes and subject mappings with periodic refresh +type EntitlementPolicyCache struct { + dbClient policydb.PolicyDBClient + logger *logger.Logger + attributesCache *cache.Cache[[]*policy.Attribute] + subjectMappingCache *cache.Cache[[]*policy.SubjectMapping] + configuredRefreshInterval time.Duration + stopRefresh chan struct{} + refreshCompleted chan struct{} +} + +func (c *EntitlementPolicyCache) IsEnabled() bool { + return c != nil +} + +// Start initiates the cache and begins periodic refresh +func (c *EntitlementPolicyCache) Start(ctx context.Context) error { + // Reset channels in case Start is called multiple times + // Only reset if stopRefresh is closed or nil + select { + case <-c.stopRefresh: + // Channel was closed, recreate it + c.stopRefresh = make(chan struct{}) + c.refreshCompleted = make(chan struct{}) + default: + // Channel is still open, do nothing + } + + // Initial refresh + if err := c.Refresh(ctx); err != nil { + return fmt.Errorf("failed initial cache refresh: %w", err) + } + + // Begin periodic refresh if an interval is set + if c.configuredRefreshInterval > 0 { + c.logger.DebugContext(ctx, "Starting periodic cache refresh", + "interval_seconds", c.configuredRefreshInterval.Seconds()) + go c.periodicRefresh(ctx) + } else { + c.logger.DebugContext(ctx, "Periodic cache refresh is disabled (interval <= 0)") + } + + return nil +} + +// Timeout for the stop operation +var stopTimeout = 5 * time.Second + +// Stop stops the periodic refresh goroutine if it's running +func (c *EntitlementPolicyCache) Stop() { + // Only attempt to stop the refresh goroutine if an interval was set + if c.configuredRefreshInterval > 0 { + // Check if stopRefresh is already closed + select { + case <-c.stopRefresh: + // Channel is already closed, nothing to do + c.logger.DebugContext(context.Background(), "Stop called on already stopped cache") + return + default: + // Channel is still open, proceed with closing + // Signal the goroutine to stop + close(c.stopRefresh) + // Wait with a timeout for the refresh goroutine to complete + select { + case <-c.refreshCompleted: + // Goroutine completed successfully + case <-time.After(stopTimeout): + // Timeout as a safety mechanism in case the goroutine is stuck + c.logger.WarnContext(context.Background(), "Timed out waiting for refresh goroutine to complete") + } + } + } +} + +// Refresh manually refreshes the cache +func (c *EntitlementPolicyCache) Refresh(ctx context.Context) error { + attributes, err := c.dbClient.ListAllAttributes(ctx) + if err != nil { + return fmt.Errorf("failed to fetch attributes: %w", err) + } + err = c.attributesCache.Set(ctx, attributesCacheKey, attributes) + if err != nil { + return fmt.Errorf("failed to cache attributes: %w", err) + } + + subjectMappings, err := c.dbClient.ListAllSubjectMappings(ctx) + if err != nil { + return fmt.Errorf("failed to fetch subject mappings: %w", err) + } + err = c.subjectMappingCache.Set(ctx, subjectMappingsCacheKey, subjectMappings) + if err != nil { + return fmt.Errorf("failed to cache subject mappings: %w", err) + } + + c.logger.DebugContext(ctx, + "EntitlementPolicyCache refreshed", + "attributes_count", len(attributes), + "subject_mappings_count", len(subjectMappings), + ) + + return nil +} + +// ListCachedAttributes returns the cached attributes and overall total, where +// a limit of 0 and offset 0 returns all attributes +func (c *EntitlementPolicyCache) ListCachedAttributes(ctx context.Context, limit, offset int32) ([]*policy.Attribute, int32, error) { + attributes, err := c.attributesCache.Get(ctx, attributesCacheKey) + if err != nil { + return nil, 0, fmt.Errorf("failed to retrieve attributes from cache: %w", err) + } + + total := int32(len(attributes)) + // TODO: we may want to copy this so callers cannot modify the cached data + // If offset is beyond the length, return empty slice + if offset >= total { + return nil, 0, nil + } + // If limit is 0, return any attributes beyond the offset + if limit == 0 { + return attributes[offset:], total, nil + } + // Ensure we don't exceed the slice bounds + limited := min(offset+limit, total) + + return attributes[offset:limited], total, nil +} + +// ListCachedSubjectMappings returns the cached subject mappings and overall total, where +// a limit of 0 returns all subject mappings +func (c *EntitlementPolicyCache) ListCachedSubjectMappings(ctx context.Context, limit, offset int32) ([]*policy.SubjectMapping, int32, error) { + subjectMappings, err := c.subjectMappingCache.Get(ctx, subjectMappingsCacheKey) + if err != nil { + return nil, 0, fmt.Errorf("failed to retrieve subject mappings from cache: %w", err) + } + total := int32(len(subjectMappings)) + // TODO: we may want to copy this so callers cannot modify the cached data + // If offset is beyond the length, return empty slice + if offset >= total { + return nil, 0, nil + } + // If limit is 0, return any subject mappings beyond the offset + if limit == 0 { + return subjectMappings[offset:], total, nil + } + // Ensure we don't exceed the slice bounds + limited := min(offset+limit, total) + + return subjectMappings[offset:limited], total, nil +} + +// periodicRefresh refreshes the cache at the specified interval +func (c *EntitlementPolicyCache) periodicRefresh(ctx context.Context) { + //nolint:mnd // Half the refresh interval for the context timeout + waitTimeout := c.configuredRefreshInterval / 2 + + ticker := time.NewTicker(c.configuredRefreshInterval) + defer func() { + ticker.Stop() + // Always signal completion, regardless of how we exit + close(c.refreshCompleted) + }() + + for { + select { + case <-ticker.C: + // Create a child context that can be canceled if refresh takes too long + refreshCtx, cancel := context.WithTimeout(ctx, waitTimeout) + err := c.Refresh(refreshCtx) + cancel() // Always cancel the context to prevent leaks + if err != nil { + c.logger.ErrorContext(ctx, "Failed to refresh cache", "error", err) + } + case <-c.stopRefresh: + return + case <-ctx.Done(): + c.logger.DebugContext(ctx, "Context canceled, stopping periodic refresh") + return + } + } +} + +func GetSharedEntitlementPolicyCache( + ctx context.Context, + dbClient policydb.PolicyDBClient, + l *logger.Logger, + cfg *Config, +) *EntitlementPolicyCache { + if cfg.CacheRefreshIntervalSeconds == 0 { + l.DebugContext(ctx, "Entitlement policy cache is disabled, returning nil") + return nil + } + + var initErr error + entitlementPolicyCacheOnce.Do(func() { + l.DebugContext(ctx, "Initializing shared entitlement policy cache") + instance := &EntitlementPolicyCache{ + logger: l, + dbClient: dbClient, + configuredRefreshInterval: time.Duration(cfg.CacheRefreshIntervalSeconds) * time.Second, + stopRefresh: make(chan struct{}), + refreshCompleted: make(chan struct{}), + } + + ristrettoCache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: numCounters, + MaxCost: maxCost, + BufferItems: bufferItems, + }) + if err != nil { + panic(err) + } + ristrettoStore := ristretto_store.NewRistretto(ristrettoCache) + + attributesCache := cache.New[[]*policy.Attribute](ristrettoStore) + instance.attributesCache = attributesCache + + subjectMappingCache := cache.New[[]*policy.SubjectMapping](ristrettoStore) + instance.subjectMappingCache = subjectMappingCache + + // Try to start the cache + if err := instance.Start(ctx); err != nil { + l.ErrorContext(ctx, "Failed to start entitlement policy cache", "error", err) + initErr = err + return + } + + // Only set the instance if Start() succeeds + entitlementPolicyCacheInstance = instance + l.DebugContext(ctx, "Shared entitlement policy cache initialized") + }) + + // Log if we're returning nil due to an initialization error + if initErr != nil && entitlementPolicyCacheInstance == nil { + l.WarnContext(ctx, "Returning nil entitlement policy cache due to previous initialization error") + } + + return entitlementPolicyCacheInstance +} diff --git a/service/policy/config/config.go b/service/policy/config/config.go index 6f9423edc7..2164320e32 100644 --- a/service/policy/config/config.go +++ b/service/policy/config/config.go @@ -15,6 +15,10 @@ type Config struct { ListRequestLimitDefault int `mapstructure:"list_request_limit_default" default:"1000"` // Maximum pagination list limit allowed by policy services ListRequestLimitMax int `mapstructure:"list_request_limit_max" default:"2500"` + + // Interval in seconds to refresh the in-memory policy entitlement cache (attributes and subject mappings) + // Default: cache is disabled with refresh interval set to 0. + CacheRefreshIntervalSeconds int `mapstructure:"cache_refresh_interval_seconds" default:"0"` // Cache disabled by default } func (c Config) Validate() error { @@ -24,6 +28,7 @@ func (c Config) Validate() error { return nil } +// GetSharedPolicyConfig retrieves the shared policy configuration, applying defaults and validating it. func GetSharedPolicyConfig(cfg config.ServiceConfig) (*Config, error) { policyCfg := new(Config) diff --git a/service/policy/db/attributes.go b/service/policy/db/attributes.go index d778df7207..5e72ccad57 100644 --- a/service/policy/db/attributes.go +++ b/service/policy/db/attributes.go @@ -198,7 +198,7 @@ func (c PolicyDBClient) ListAllAttributes(ctx context.Context) ([]*policy.Attrib for { listed, err := c.ListAttributes(ctx, &attributes.ListAttributesRequest{ - State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, Pagination: &policy.PageRequest{ Limit: c.listCfg.limitMax, Offset: nextOffset, diff --git a/service/policy/db/subject_mappings.go b/service/policy/db/subject_mappings.go index 235208e744..83296abc56 100644 --- a/service/policy/db/subject_mappings.go +++ b/service/policy/db/subject_mappings.go @@ -440,6 +440,34 @@ func (c PolicyDBClient) ListSubjectMappings(ctx context.Context, r *subjectmappi }, nil } +// Loads all subject mappings into memory by making iterative db roundtrip requests of defaultObjectListAllLimit size +func (c PolicyDBClient) ListAllSubjectMappings(ctx context.Context) ([]*policy.SubjectMapping, error) { + var nextOffset int32 + smList := make([]*policy.SubjectMapping, 0) + + for { + listed, err := c.ListSubjectMappings(ctx, &subjectmapping.ListSubjectMappingsRequest{ + Pagination: &policy.PageRequest{ + Limit: c.listCfg.limitMax, + Offset: nextOffset, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to list all attributes: %w", err) + } + + nextOffset = listed.GetPagination().GetNextOffset() + smList = append(smList, listed.GetSubjectMappings()...) + + // offset becomes zero when list is exhausted + if nextOffset <= 0 { + break + } + } + + return smList, nil +} + // Mutates provided fields and returns the updated subject mapping func (c PolicyDBClient) UpdateSubjectMapping(ctx context.Context, r *subjectmapping.UpdateSubjectMappingRequest) (*policy.SubjectMapping, error) { id := r.GetId() diff --git a/service/policy/subjectmapping/subject_mapping.go b/service/policy/subjectmapping/subject_mapping.go index d7b1784695..8a985b195c 100644 --- a/service/policy/subjectmapping/subject_mapping.go +++ b/service/policy/subjectmapping/subject_mapping.go @@ -22,6 +22,15 @@ type SubjectMappingService struct { //nolint:revive // SubjectMappingService is dbClient policydb.PolicyDBClient logger *logger.Logger config *policyconfig.Config + // Cache for attributes and subject mappings + cache *policyconfig.EntitlementPolicyCache +} + +func OnServicesStarted(svc *SubjectMappingService) serviceregistry.OnServicesStartedHook { + return func(ctx context.Context) error { + svc.cache = policyconfig.GetSharedEntitlementPolicyCache(ctx, svc.dbClient, svc.logger, svc.config) + return nil + } } func OnConfigUpdate(smSvc *SubjectMappingService) serviceregistry.OnConfigUpdateHook { @@ -42,16 +51,18 @@ func OnConfigUpdate(smSvc *SubjectMappingService) serviceregistry.OnConfigUpdate func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[subjectmappingconnect.SubjectMappingServiceHandler] { smSvc := new(SubjectMappingService) onUpdateConfigHook := OnConfigUpdate(smSvc) + onStartHook := OnServicesStarted(smSvc) return &serviceregistry.Service[subjectmappingconnect.SubjectMappingServiceHandler]{ Close: smSvc.Close, ServiceOptions: serviceregistry.ServiceOptions[subjectmappingconnect.SubjectMappingServiceHandler]{ - Namespace: ns, - DB: dbRegister, - ServiceDesc: &sm.SubjectMappingService_ServiceDesc, - ConnectRPCFunc: subjectmappingconnect.NewSubjectMappingServiceHandler, - GRPCGatewayFunc: sm.RegisterSubjectMappingServiceHandler, - OnConfigUpdate: onUpdateConfigHook, + Namespace: ns, + DB: dbRegister, + ServiceDesc: &sm.SubjectMappingService_ServiceDesc, + ConnectRPCFunc: subjectmappingconnect.NewSubjectMappingServiceHandler, + GRPCGatewayFunc: sm.RegisterSubjectMappingServiceHandler, + OnConfigUpdate: onUpdateConfigHook, + OnServicesStarted: onStartHook, RegisterFunc: func(srp serviceregistry.RegistrationParams) (subjectmappingconnect.SubjectMappingServiceHandler, serviceregistry.HandlerServer) { logger := srp.Logger cfg, err := policyconfig.GetSharedPolicyConfig(srp.Config) @@ -63,6 +74,7 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer smSvc.logger = logger smSvc.dbClient = policydb.NewClient(srp.DBClient, logger, int32(cfg.ListRequestLimitMax), int32(cfg.ListRequestLimitDefault)) smSvc.config = cfg + return smSvc, nil }, }, @@ -72,7 +84,12 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer // Close gracefully shuts down the service, closing the database client. func (s *SubjectMappingService) Close() { s.logger.Info("gracefully shutting down subject mapping service") + s.dbClient.Close() + + if s.cache != nil { + s.cache.Stop() + } } /* --------------------------------------------------- @@ -109,6 +126,7 @@ func (s SubjectMappingService) CreateSubjectMapping(ctx context.Context, if err != nil { return nil, db.StatusifyError(err, db.ErrTextCreationFailed, slog.String("subjectMapping", req.Msg.String())) } + return connect.NewResponse(rsp), nil } @@ -117,6 +135,47 @@ func (s SubjectMappingService) ListSubjectMappings(ctx context.Context, ) (*connect.Response[sm.ListSubjectMappingsResponse], error) { s.logger.Debug("listing subject mappings") + // If caching enabled, return from cache instead of DB + if s.cache.IsEnabled() { + s.logger.Debug("returning cached subject mappings") + + limit := req.Msg.GetPagination().GetLimit() + if limit <= 0 { + limit = int32(s.config.ListRequestLimitDefault) + } + offset := req.Msg.GetPagination().GetOffset() + + // Validate limit against the max configured value + maxLimit := int32(s.config.ListRequestLimitMax) + if maxLimit > 0 && limit > maxLimit { + return nil, db.StatusifyError(db.ErrListLimitTooLarge, db.ErrTextListLimitTooLarge) + } + + // Get all cached subject mappings + cachedSubjectMappings, total, err := s.cache.ListCachedSubjectMappings(ctx, limit, offset) + if err != nil { + s.logger.Error("failed to retrieve cached subject mappings", slog.Any("error", err)) + return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) + } + + // Calculate next offset using the same logic as the DB implementation + var nextOffset int32 + next := offset + limit + if next < total { + nextOffset = next + } // else nextOffset remains 0 + + rsp := &sm.ListSubjectMappingsResponse{ + SubjectMappings: cachedSubjectMappings, + Pagination: &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + }, + } + return connect.NewResponse(rsp), nil + } + rsp, err := s.dbClient.ListSubjectMappings(ctx, req.Msg) if err != nil { return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed) @@ -180,6 +239,7 @@ func (s SubjectMappingService) UpdateSubjectMapping(ctx context.Context, if err != nil { return nil, err } + return connect.NewResponse(rsp), nil } diff --git a/service/policy/unsafe/unsafe.go b/service/policy/unsafe/unsafe.go index 2b598edcd8..1df64b8b63 100644 --- a/service/policy/unsafe/unsafe.go +++ b/service/policy/unsafe/unsafe.go @@ -61,6 +61,7 @@ func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *servicer unsafeSvc.logger = logger unsafeSvc.dbClient = policydb.NewClient(srp.DBClient, logger, int32(cfg.ListRequestLimitMax), int32(cfg.ListRequestLimitDefault)) unsafeSvc.config = cfg + return unsafeSvc, nil }, },