Skip to content

Commit

Permalink
new generic protojson marshaller, keep existing types on deprecated f…
Browse files Browse the repository at this point in the history
…astmarshaller
  • Loading branch information
nklaassen committed May 10, 2024
1 parent 98c15c4 commit aa7e4aa
Show file tree
Hide file tree
Showing 11 changed files with 187 additions and 180 deletions.
4 changes: 2 additions & 2 deletions lib/services/access_monitoring_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ func ValidateAccessMonitoringRule(accessMonitoringRule *accessmonitoringrulesv1.

// MarshalAccessMonitoringRule marshals AccessMonitoringRule resource to JSON.
func MarshalAccessMonitoringRule(accessMonitoringRule *accessmonitoringrulesv1.AccessMonitoringRule, opts ...MarshalOption) ([]byte, error) {
return MarshalProtoResource(accessMonitoringRule, opts...)
return FastMarshalProtoResourceDeprecated(accessMonitoringRule, opts...)
}

// UnmarshalAccessMonitoringRule unmarshals the AccessMonitoringRule resource.
func UnmarshalAccessMonitoringRule(data []byte, opts ...MarshalOption) (*accessmonitoringrulesv1.AccessMonitoringRule, error) {
return UnmarshalProtoResource[*accessmonitoringrulesv1.AccessMonitoringRule](data, opts...)
return FastUnmarshalProtoResourceDeprecated[*accessmonitoringrulesv1.AccessMonitoringRule](data, opts...)
}
38 changes: 2 additions & 36 deletions lib/services/crown_jewel.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ package services
import (
"context"

"github.com/gravitational/trace"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"

crownjewelv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/crownjewel/v1"
)

Expand All @@ -49,39 +44,10 @@ type CrownJewels interface {

// MarshalCrownJewel marshals the CrownJewel object into a JSON byte array.
func MarshalCrownJewel(object *crownjewelv1.CrownJewel, opts ...MarshalOption) ([]byte, error) {
cfg, err := CollectOptions(opts)
if err != nil {
return nil, trace.Wrap(err)
}
if !cfg.PreserveResourceID {
object = proto.Clone(object).(*crownjewelv1.CrownJewel)
object.Metadata.Revision = ""
}
data, err := protojson.Marshal(object)
if err != nil {
return nil, trace.Wrap(err)
}
return data, nil
return MarshalProtoResource(object, opts...)
}

// UnmarshalCrownJewel unmarshals the CrownJewel object from a JSON byte array.
func UnmarshalCrownJewel(data []byte, opts ...MarshalOption) (*crownjewelv1.CrownJewel, error) {
if len(data) == 0 {
return nil, trace.BadParameter("missing crown jewel data")
}
cfg, err := CollectOptions(opts)
if err != nil {
return nil, trace.Wrap(err)
}
var obj crownjewelv1.CrownJewel
if err := protojson.Unmarshal(data, &obj); err != nil {
return nil, trace.BadParameter(err.Error())
}
if cfg.Revision != "" {
obj.Metadata.Revision = cfg.Revision
}
if !cfg.Expires.IsZero() {
obj.Metadata.Expires = timestamppb.New(cfg.Expires)
}
return &obj, nil
return UnmarshalProtoResource[*crownjewelv1.CrownJewel](data, opts...)
}
4 changes: 2 additions & 2 deletions lib/services/kubewaitingcontainer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ func MarshalKubeWaitingContainer(in *kubewaitingcontainerpb.KubernetesWaitingCon
return nil, trace.Wrap(err)
}

return MarshalProtoResource(in, opts...)
return FastMarshalProtoResourceDeprecated(in, opts...)
}

// UnmarshalKubeWaitingContainer unmarshals a KubernetesWaitingContainer resource from JSON.
func UnmarshalKubeWaitingContainer(data []byte, opts ...MarshalOption) (*kubewaitingcontainerpb.KubernetesWaitingContainer, error) {
out, err := UnmarshalProtoResource[*kubewaitingcontainerpb.KubernetesWaitingContainer](data, opts...)
out, err := FastUnmarshalProtoResourceDeprecated[*kubewaitingcontainerpb.KubernetesWaitingContainer](data, opts...)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
6 changes: 4 additions & 2 deletions lib/services/local/databaseobject.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ func NewDatabaseObjectService(backend backend.Backend) (services.DatabaseObjects
service, err := generic.NewServiceWrapper(backend,
types.KindDatabaseObject,
databaseObjectPrefix,
services.MarshalProtoResource[*dbobjectv1.DatabaseObject],
services.UnmarshalProtoResource[*dbobjectv1.DatabaseObject],
//nolint:staticcheck // SA1019. Using this marshaler for json compatibility.
services.FastMarshalProtoResourceDeprecated[*dbobjectv1.DatabaseObject],
//nolint:staticcheck // SA1019. Using this unmarshaler for json compatibility.
services.FastUnmarshalProtoResourceDeprecated[*dbobjectv1.DatabaseObject],
)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
6 changes: 4 additions & 2 deletions lib/services/local/databaseobject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,11 @@ func TestMarshalDatabaseObjectRoundTrip(t *testing.T) {
obj, err := databaseobject.NewDatabaseObject("dummy-table", spec)
require.NoError(t, err)

out, err := services.MarshalProtoResource(obj)
//nolint:staticcheck // SA1019. Using this marshaler for json compatibility.
out, err := services.FastMarshalProtoResourceDeprecated(obj)
require.NoError(t, err)
newObj, err := services.UnmarshalProtoResource[*dbobjectv1.DatabaseObject](out)
//nolint:staticcheck // SA1019. Using this unmarshaler for json compatibility.
newObj, err := services.FastUnmarshalProtoResourceDeprecated[*dbobjectv1.DatabaseObject](out)
require.NoError(t, err)
require.True(t, proto.Equal(obj, newObj), "messages are not equal")
}
6 changes: 4 additions & 2 deletions lib/services/local/databaseobjectimportrule.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ func NewDatabaseObjectImportRuleService(backend backend.Backend) (services.Datab
service, err := generic.NewServiceWrapper(backend,
types.KindDatabaseObjectImportRule,
databaseObjectImportRulePrefix,
services.MarshalProtoResource[*databaseobjectimportrulev1.DatabaseObjectImportRule],
services.UnmarshalProtoResource[*databaseobjectimportrulev1.DatabaseObjectImportRule],
//nolint:staticcheck // SA1019. Using this marshaler for json compatibility.
services.FastMarshalProtoResourceDeprecated[*databaseobjectimportrulev1.DatabaseObjectImportRule],
//nolint:staticcheck // SA1019. Using this unmarshaler for json compatibility.
services.FastUnmarshalProtoResourceDeprecated[*databaseobjectimportrulev1.DatabaseObjectImportRule],
)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
6 changes: 4 additions & 2 deletions lib/services/local/databaseobjectimportrule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,11 @@ func TestMarshalDatabaseObjectImportRuleRoundTrip(t *testing.T) {
Spec: spec,
}

out, err := services.MarshalProtoResource(obj)
//nolint:staticcheck // SA1019. Using this marshaler for json compatibility.
out, err := services.FastMarshalProtoResourceDeprecated(obj)
require.NoError(t, err)
newObj, err := services.UnmarshalProtoResource[*databaseobjectimportrulev1.DatabaseObjectImportRule](out)
//nolint:staticcheck // SA1019. Using this unmarshaler for json compatibility.
newObj, err := services.FastUnmarshalProtoResourceDeprecated[*databaseobjectimportrulev1.DatabaseObjectImportRule](out)
require.NoError(t, err)
require.True(t, proto.Equal(obj, newObj), "messages are not equal")
}
8 changes: 2 additions & 6 deletions lib/services/local/vnet_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ package local

import (
"context"
"log/slog"
"net"

"github.com/gravitational/trace"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/gen/proto/go/teleport/vnet/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/backend"
Expand All @@ -37,8 +35,7 @@ const (
)

type VnetConfigService struct {
slog *slog.Logger
svc *generic.ServiceWrapper[*vnet.VnetConfig]
svc *generic.ServiceWrapper[*vnet.VnetConfig]
}

func NewVnetConfigService(backend backend.Backend) (*VnetConfigService, error) {
Expand All @@ -54,8 +51,7 @@ func NewVnetConfigService(backend backend.Backend) (*VnetConfigService, error) {
}

return &VnetConfigService{
svc: svc,
slog: slog.With(teleport.ComponentKey, "VnetConfig.local"),
svc: svc,
}, nil
}

Expand Down
56 changes: 8 additions & 48 deletions lib/services/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ import (
"context"

"github.com/gravitational/trace"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"

notificationsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/notifications/v1"
)
Expand Down Expand Up @@ -78,12 +75,12 @@ func MarshalNotification(notification *notificationsv1.Notification, opts ...Mar
return nil, trace.Wrap(err)
}

return MarshalProtoResource(notification, opts...)
return FastMarshalProtoResourceDeprecated(notification, opts...)
}

// UnmarshalNotification unmarshals a Notification resource from JSON.
func UnmarshalNotification(data []byte, opts ...MarshalOption) (*notificationsv1.Notification, error) {
return UnmarshalProtoResource[*notificationsv1.Notification](data, opts...)
return FastUnmarshalProtoResourceDeprecated[*notificationsv1.Notification](data, opts...)
}

// ValidateGlobalNotification verifies that the necessary fields are configured for a global notification object.
Expand Down Expand Up @@ -113,49 +110,12 @@ func MarshalGlobalNotification(globalNotification *notificationsv1.GlobalNotific
return nil, trace.Wrap(err)
}

cfg, err := CollectOptions(opts)
if err != nil {
return nil, trace.Wrap(err)
}
if !cfg.PreserveResourceID {
globalNotification = proto.Clone(globalNotification).(*notificationsv1.GlobalNotification)
//nolint:staticcheck // SA1019. Deprecated, but still needed.
globalNotification.Metadata.Id = 0
globalNotification.Metadata.Revision = ""
}
// We marshal with raw protojson here because utils.FastMarshal doesn't work with oneof.
data, err := protojson.Marshal(globalNotification)
if err != nil {
return nil, trace.Wrap(err)
}
return data, nil
return MarshalProtoResource(globalNotification, opts...)
}

// UnmarshalGlobalNotification unmarshals a GlobalNotification resource from JSON.
func UnmarshalGlobalNotification(data []byte, opts ...MarshalOption) (*notificationsv1.GlobalNotification, error) {
if len(data) == 0 {
return nil, trace.BadParameter("missing notification data")
}
cfg, err := CollectOptions(opts)
if err != nil {
return nil, trace.Wrap(err)
}
var obj notificationsv1.GlobalNotification
// We unmarshal with raw protojson here because utils.FastUnmarshal doesn't work with oneof.
if err = protojson.Unmarshal(data, &obj); err != nil {
return nil, trace.Wrap(err)
}
if cfg.ID != 0 {
//nolint:staticcheck // SA1019. Id is deprecated, but still needed.
obj.Metadata.Id = cfg.ID
}
if cfg.Revision != "" {
obj.Metadata.Revision = cfg.Revision
}
if !cfg.Expires.IsZero() {
obj.Metadata.Expires = timestamppb.New(cfg.Expires)
}
return &obj, nil
return UnmarshalProtoResource[*notificationsv1.GlobalNotification](data, opts...)
}

// ValidateUserNotificationState verifies that the necessary fields are configured for user notification state object.
Expand All @@ -177,12 +137,12 @@ func MarshalUserNotificationState(notificationState *notificationsv1.UserNotific
return nil, trace.Wrap(err)
}

return MarshalProtoResource(notificationState, opts...)
return FastMarshalProtoResourceDeprecated(notificationState, opts...)
}

// UnmarshalUserNotificationState unmarshals a UserNotificationState resource from JSON.
func UnmarshalUserNotificationState(data []byte, opts ...MarshalOption) (*notificationsv1.UserNotificationState, error) {
return UnmarshalProtoResource[*notificationsv1.UserNotificationState](data, opts...)
return FastUnmarshalProtoResourceDeprecated[*notificationsv1.UserNotificationState](data, opts...)
}

// ValidateUserLastSeenNotification verifies that the necessary fields are configured for a user's last seen notification timestamp object.
Expand All @@ -200,10 +160,10 @@ func MarshalUserLastSeenNotification(userLastSeenNotification *notificationsv1.U
return nil, trace.Wrap(err)
}

return MarshalProtoResource(userLastSeenNotification, opts...)
return FastMarshalProtoResourceDeprecated(userLastSeenNotification, opts...)
}

// UnmarshalUserLastSeenNotification unmarshals a UserLastSeenNotification resource from JSON.
func UnmarshalUserLastSeenNotification(data []byte, opts ...MarshalOption) (*notificationsv1.UserLastSeenNotification, error) {
return UnmarshalProtoResource[*notificationsv1.UserLastSeenNotification](data, opts...)
return FastUnmarshalProtoResourceDeprecated[*notificationsv1.UserLastSeenNotification](data, opts...)
}
Loading

0 comments on commit aa7e4aa

Please sign in to comment.