From 3d5666d4c6b6f90c7f0801d11f31a0846170d6da Mon Sep 17 00:00:00 2001 From: Victor Lyuboslavsky Date: Wed, 26 Feb 2025 10:47:05 -0600 Subject: [PATCH] Added GET enterprise API endpoint. (#26555) For #26218 - Added `GET /api/_version_/fleet/android_enterprise` andpoint and tests - Set up some testing infrastructure for Android service tests -- see new README.md # Checklist for submitter - [x] Added/updated automated tests - [x] Manual QA for all new/changed functionality --- Makefile | 2 +- server/authz/policy.rego | 2 +- server/mdm/android/arch_test.go | 1 + server/mdm/android/datastore.go | 2 +- server/mdm/android/mock/android.go | 4 + server/mdm/android/mock/datastore.go | 113 +++++++++++++ server/mdm/android/mock/datastore_setup.go | 28 ++++ server/mdm/android/mock/proxy.go | 77 +++++++++ server/mdm/android/mock/proxy_setup.go | 23 +++ server/mdm/android/mysql/enterprises.go | 2 +- server/mdm/android/mysql/enterprises_test.go | 6 +- .../mysql/{mysql_test.go => testing_utils.go} | 12 +- server/mdm/android/proxy.go | 22 +++ server/mdm/android/service.go | 20 +++ .../mdm/android/service/enterprises_test.go | 135 +++++++++++++++ server/mdm/android/service/handler.go | 15 +- server/mdm/android/service/proxy/proxy.go | 16 +- server/mdm/android/service/pubsub.go | 2 +- server/mdm/android/service/service.go | 85 ++++++---- server/mdm/android/tests/README.md | 5 + .../tests/enterprise/enterprise_test.go | 52 ++++++ server/mdm/android/tests/http.go | 49 ++++++ server/mdm/android/tests/testing_utils.go | 157 ++++++++++++++++++ server/service/endpoint_utils_test.go | 5 +- server/service/handler.go | 15 +- server/service/middleware/log/log.go | 13 ++ server/service/testing_client.go | 46 +---- server/test/httptest/README.md | 2 + server/test/httptest/http.go | 55 ++++++ 29 files changed, 854 insertions(+), 112 deletions(-) create mode 100644 server/mdm/android/mock/android.go create mode 100644 server/mdm/android/mock/datastore.go create mode 100644 server/mdm/android/mock/datastore_setup.go create mode 100644 server/mdm/android/mock/proxy.go create mode 100644 server/mdm/android/mock/proxy_setup.go rename server/mdm/android/mysql/{mysql_test.go => testing_utils.go} (71%) create mode 100644 server/mdm/android/proxy.go create mode 100644 server/mdm/android/service/enterprises_test.go create mode 100644 server/mdm/android/tests/README.md create mode 100644 server/mdm/android/tests/enterprise/enterprise_test.go create mode 100644 server/mdm/android/tests/http.go create mode 100644 server/mdm/android/tests/testing_utils.go create mode 100644 server/test/httptest/README.md create mode 100644 server/test/httptest/http.go diff --git a/Makefile b/Makefile index cd54e4c20a98..839f2e8fd8d5 100644 --- a/Makefile +++ b/Makefile @@ -239,7 +239,7 @@ generate-dev: .prefix NODE_ENV=development yarn run webpack --progress --watch generate-mock: .prefix - go generate github.com/fleetdm/fleet/v4/server/mock github.com/fleetdm/fleet/v4/server/mock/mockresult github.com/fleetdm/fleet/v4/server/service/mock + go generate github.com/fleetdm/fleet/v4/server/mock github.com/fleetdm/fleet/v4/server/mock/mockresult github.com/fleetdm/fleet/v4/server/service/mock github.com/fleetdm/fleet/v4/server/mdm/android/mock generate-doc: .prefix go generate github.com/fleetdm/fleet/v4/server/fleet diff --git a/server/authz/policy.rego b/server/authz/policy.rego index 49821732255f..937c73d2799d 100644 --- a/server/authz/policy.rego +++ b/server/authz/policy.rego @@ -1027,5 +1027,5 @@ allow { allow { object.type == "android_enterprise" subject.global_role == admin - action == write + action == [read, write][_] } diff --git a/server/mdm/android/arch_test.go b/server/mdm/android/arch_test.go index 3f07acf5b431..fd753ad88186 100644 --- a/server/mdm/android/arch_test.go +++ b/server/mdm/android/arch_test.go @@ -22,6 +22,7 @@ func TestAllAndroidPackageDependencies(t *testing.T) { "github.com/fleetdm/fleet/v4/server/service/middleware/auth", "github.com/fleetdm/fleet/v4/server/service/middleware/authzcheck", "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils", + "github.com/fleetdm/fleet/v4/server/service/middleware/log", "github.com/fleetdm/fleet/v4/server/service/middleware/ratelimit", ). ShouldNotDependOn( diff --git a/server/mdm/android/datastore.go b/server/mdm/android/datastore.go index 8ff78066871c..e5495f1916fe 100644 --- a/server/mdm/android/datastore.go +++ b/server/mdm/android/datastore.go @@ -11,7 +11,7 @@ type Datastore interface { GetEnterpriseByID(ctx context.Context, ID uint) (*EnterpriseDetails, error) GetEnterprise(ctx context.Context) (*Enterprise, error) UpdateEnterprise(ctx context.Context, enterprise *EnterpriseDetails) error - DeleteEnterprises(ctx context.Context) error + DeleteAllEnterprises(ctx context.Context) error DeleteOtherEnterprises(ctx context.Context, ID uint) error CreateDeviceTx(ctx context.Context, tx sqlx.ExtContext, device *Device) (*Device, error) diff --git a/server/mdm/android/mock/android.go b/server/mdm/android/mock/android.go new file mode 100644 index 000000000000..7f7da7086137 --- /dev/null +++ b/server/mdm/android/mock/android.go @@ -0,0 +1,4 @@ +package mock + +//go:generate go run ../../../mock/mockimpl/impl.go -o proxy.go "p *Proxy" "android.Proxy" +//go:generate go run ../../../mock/mockimpl/impl.go -o datastore.go "ds *Datastore" "android.Datastore" diff --git a/server/mdm/android/mock/datastore.go b/server/mdm/android/mock/datastore.go new file mode 100644 index 000000000000..ac0c9f23d47a --- /dev/null +++ b/server/mdm/android/mock/datastore.go @@ -0,0 +1,113 @@ +// Automatically generated by mockimpl. DO NOT EDIT! + +package mock + +import ( + "context" + "sync" + + "github.com/fleetdm/fleet/v4/server/mdm/android" + "github.com/jmoiron/sqlx" +) + +var _ android.Datastore = (*Datastore)(nil) + +type CreateEnterpriseFunc func(ctx context.Context) (uint, error) + +type GetEnterpriseByIDFunc func(ctx context.Context, ID uint) (*android.EnterpriseDetails, error) + +type GetEnterpriseFunc func(ctx context.Context) (*android.Enterprise, error) + +type UpdateEnterpriseFunc func(ctx context.Context, enterprise *android.EnterpriseDetails) error + +type DeleteAllEnterprisesFunc func(ctx context.Context) error + +type DeleteOtherEnterprisesFunc func(ctx context.Context, ID uint) error + +type CreateDeviceTxFunc func(ctx context.Context, tx sqlx.ExtContext, device *android.Device) (*android.Device, error) + +type UpdateDeviceTxFunc func(ctx context.Context, tx sqlx.ExtContext, device *android.Device) error + +type Datastore struct { + CreateEnterpriseFunc CreateEnterpriseFunc + CreateEnterpriseFuncInvoked bool + + GetEnterpriseByIDFunc GetEnterpriseByIDFunc + GetEnterpriseByIDFuncInvoked bool + + GetEnterpriseFunc GetEnterpriseFunc + GetEnterpriseFuncInvoked bool + + UpdateEnterpriseFunc UpdateEnterpriseFunc + UpdateEnterpriseFuncInvoked bool + + DeleteAllEnterprisesFunc DeleteAllEnterprisesFunc + DeleteAllEnterprisesFuncInvoked bool + + DeleteOtherEnterprisesFunc DeleteOtherEnterprisesFunc + DeleteOtherEnterprisesFuncInvoked bool + + CreateDeviceTxFunc CreateDeviceTxFunc + CreateDeviceTxFuncInvoked bool + + UpdateDeviceTxFunc UpdateDeviceTxFunc + UpdateDeviceTxFuncInvoked bool + + mu sync.Mutex +} + +func (ds *Datastore) CreateEnterprise(ctx context.Context) (uint, error) { + ds.mu.Lock() + ds.CreateEnterpriseFuncInvoked = true + ds.mu.Unlock() + return ds.CreateEnterpriseFunc(ctx) +} + +func (ds *Datastore) GetEnterpriseByID(ctx context.Context, ID uint) (*android.EnterpriseDetails, error) { + ds.mu.Lock() + ds.GetEnterpriseByIDFuncInvoked = true + ds.mu.Unlock() + return ds.GetEnterpriseByIDFunc(ctx, ID) +} + +func (ds *Datastore) GetEnterprise(ctx context.Context) (*android.Enterprise, error) { + ds.mu.Lock() + ds.GetEnterpriseFuncInvoked = true + ds.mu.Unlock() + return ds.GetEnterpriseFunc(ctx) +} + +func (ds *Datastore) UpdateEnterprise(ctx context.Context, enterprise *android.EnterpriseDetails) error { + ds.mu.Lock() + ds.UpdateEnterpriseFuncInvoked = true + ds.mu.Unlock() + return ds.UpdateEnterpriseFunc(ctx, enterprise) +} + +func (ds *Datastore) DeleteAllEnterprises(ctx context.Context) error { + ds.mu.Lock() + ds.DeleteAllEnterprisesFuncInvoked = true + ds.mu.Unlock() + return ds.DeleteAllEnterprisesFunc(ctx) +} + +func (ds *Datastore) DeleteOtherEnterprises(ctx context.Context, ID uint) error { + ds.mu.Lock() + ds.DeleteOtherEnterprisesFuncInvoked = true + ds.mu.Unlock() + return ds.DeleteOtherEnterprisesFunc(ctx, ID) +} + +func (ds *Datastore) CreateDeviceTx(ctx context.Context, tx sqlx.ExtContext, device *android.Device) (*android.Device, error) { + ds.mu.Lock() + ds.CreateDeviceTxFuncInvoked = true + ds.mu.Unlock() + return ds.CreateDeviceTxFunc(ctx, tx, device) +} + +func (ds *Datastore) UpdateDeviceTx(ctx context.Context, tx sqlx.ExtContext, device *android.Device) error { + ds.mu.Lock() + ds.UpdateDeviceTxFuncInvoked = true + ds.mu.Unlock() + return ds.UpdateDeviceTxFunc(ctx, tx, device) +} diff --git a/server/mdm/android/mock/datastore_setup.go b/server/mdm/android/mock/datastore_setup.go new file mode 100644 index 000000000000..18d62845975a --- /dev/null +++ b/server/mdm/android/mock/datastore_setup.go @@ -0,0 +1,28 @@ +package mock + +import ( + "context" + + "github.com/fleetdm/fleet/v4/server/mdm/android" +) + +func (s *Datastore) InitCommonMocks() { + s.CreateEnterpriseFunc = func(ctx context.Context) (uint, error) { + return 1, nil + } + s.UpdateEnterpriseFunc = func(ctx context.Context, enterprise *android.EnterpriseDetails) error { + return nil + } + s.GetEnterpriseFunc = func(ctx context.Context) (*android.Enterprise, error) { + return &android.Enterprise{}, nil + } + s.GetEnterpriseByIDFunc = func(ctx context.Context, ID uint) (*android.EnterpriseDetails, error) { + return &android.EnterpriseDetails{}, nil + } + s.DeleteAllEnterprisesFunc = func(ctx context.Context) error { + return nil + } + s.DeleteOtherEnterprisesFunc = func(ctx context.Context, ID uint) error { + return nil + } +} diff --git a/server/mdm/android/mock/proxy.go b/server/mdm/android/mock/proxy.go new file mode 100644 index 000000000000..15a90a12c988 --- /dev/null +++ b/server/mdm/android/mock/proxy.go @@ -0,0 +1,77 @@ +// Automatically generated by mockimpl. DO NOT EDIT! + +package mock + +import ( + "context" + "sync" + + "github.com/fleetdm/fleet/v4/server/mdm/android" + "google.golang.org/api/androidmanagement/v1" +) + +var _ android.Proxy = (*Proxy)(nil) + +type SignupURLsCreateFunc func(callbackURL string) (*android.SignupDetails, error) + +type EnterprisesCreateFunc func(ctx context.Context, req android.ProxyEnterprisesCreateRequest) (string, string, error) + +type EnterprisesPoliciesPatchFunc func(enterpriseID string, policyName string, policy *androidmanagement.Policy) error + +type EnterprisesEnrollmentTokensCreateFunc func(enterpriseName string, token *androidmanagement.EnrollmentToken) (*androidmanagement.EnrollmentToken, error) + +type EnterpriseDeleteFunc func(enterpriseID string) error + +type Proxy struct { + SignupURLsCreateFunc SignupURLsCreateFunc + SignupURLsCreateFuncInvoked bool + + EnterprisesCreateFunc EnterprisesCreateFunc + EnterprisesCreateFuncInvoked bool + + EnterprisesPoliciesPatchFunc EnterprisesPoliciesPatchFunc + EnterprisesPoliciesPatchFuncInvoked bool + + EnterprisesEnrollmentTokensCreateFunc EnterprisesEnrollmentTokensCreateFunc + EnterprisesEnrollmentTokensCreateFuncInvoked bool + + EnterpriseDeleteFunc EnterpriseDeleteFunc + EnterpriseDeleteFuncInvoked bool + + mu sync.Mutex +} + +func (p *Proxy) SignupURLsCreate(callbackURL string) (*android.SignupDetails, error) { + p.mu.Lock() + p.SignupURLsCreateFuncInvoked = true + p.mu.Unlock() + return p.SignupURLsCreateFunc(callbackURL) +} + +func (p *Proxy) EnterprisesCreate(ctx context.Context, req android.ProxyEnterprisesCreateRequest) (string, string, error) { + p.mu.Lock() + p.EnterprisesCreateFuncInvoked = true + p.mu.Unlock() + return p.EnterprisesCreateFunc(ctx, req) +} + +func (p *Proxy) EnterprisesPoliciesPatch(enterpriseID string, policyName string, policy *androidmanagement.Policy) error { + p.mu.Lock() + p.EnterprisesPoliciesPatchFuncInvoked = true + p.mu.Unlock() + return p.EnterprisesPoliciesPatchFunc(enterpriseID, policyName, policy) +} + +func (p *Proxy) EnterprisesEnrollmentTokensCreate(enterpriseName string, token *androidmanagement.EnrollmentToken) (*androidmanagement.EnrollmentToken, error) { + p.mu.Lock() + p.EnterprisesEnrollmentTokensCreateFuncInvoked = true + p.mu.Unlock() + return p.EnterprisesEnrollmentTokensCreateFunc(enterpriseName, token) +} + +func (p *Proxy) EnterpriseDelete(enterpriseID string) error { + p.mu.Lock() + p.EnterpriseDeleteFuncInvoked = true + p.mu.Unlock() + return p.EnterpriseDeleteFunc(enterpriseID) +} diff --git a/server/mdm/android/mock/proxy_setup.go b/server/mdm/android/mock/proxy_setup.go new file mode 100644 index 000000000000..99caaef35d77 --- /dev/null +++ b/server/mdm/android/mock/proxy_setup.go @@ -0,0 +1,23 @@ +package mock + +import ( + "context" + + "github.com/fleetdm/fleet/v4/server/mdm/android" + "google.golang.org/api/androidmanagement/v1" +) + +func (p *Proxy) InitCommonMocks() { + p.EnterpriseDeleteFunc = func(enterpriseID string) error { + return nil + } + p.SignupURLsCreateFunc = func(callbackURL string) (*android.SignupDetails, error) { + return &android.SignupDetails{}, nil + } + p.EnterprisesCreateFunc = func(ctx context.Context, req android.ProxyEnterprisesCreateRequest) (string, string, error) { + return "enterpriseName", "projects/project/topics/topic", nil + } + p.EnterprisesPoliciesPatchFunc = func(enterpriseID string, policyName string, policy *androidmanagement.Policy) error { + return nil + } +} diff --git a/server/mdm/android/mysql/enterprises.go b/server/mdm/android/mysql/enterprises.go index 6527352c509e..aedcd122cbc8 100644 --- a/server/mdm/android/mysql/enterprises.go +++ b/server/mdm/android/mysql/enterprises.go @@ -77,7 +77,7 @@ func (ds *Datastore) DeleteOtherEnterprises(ctx context.Context, id uint) error return nil } -func (ds *Datastore) DeleteEnterprises(ctx context.Context) error { +func (ds *Datastore) DeleteAllEnterprises(ctx context.Context) error { stmt := `DELETE FROM android_enterprises` _, err := ds.Writer(ctx).ExecContext(ctx, stmt) if err != nil { diff --git a/server/mdm/android/mysql/enterprises_test.go b/server/mdm/android/mysql/enterprises_test.go index 58128fde1aef..c645f04f1512 100644 --- a/server/mdm/android/mysql/enterprises_test.go +++ b/server/mdm/android/mysql/enterprises_test.go @@ -20,7 +20,7 @@ func TestEnterprise(t *testing.T) { }{ {"CreateGetEnterprise", testCreateGetEnterprise}, {"UpdateEnterprise", testUpdateEnterprise}, - {"DeleteEnterprises", testDeleteEnterprises}, + {"DeleteAllEnterprises", testDeleteEnterprises}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -75,7 +75,7 @@ func testUpdateEnterprise(t *testing.T, ds *Datastore) { } func testDeleteEnterprises(t *testing.T, ds *Datastore) { - err := ds.DeleteEnterprises(testCtx()) + err := ds.DeleteAllEnterprises(testCtx()) require.NoError(t, err) err = ds.DeleteOtherEnterprises(testCtx(), 9999) require.NoError(t, err) @@ -108,7 +108,7 @@ func testDeleteEnterprises(t *testing.T, ds *Datastore) { _, err = ds.GetEnterpriseByID(testCtx(), tempEnterprise.ID) assert.True(t, fleet.IsNotFound(err)) - err = ds.DeleteEnterprises(testCtx()) + err = ds.DeleteAllEnterprises(testCtx()) require.NoError(t, err) _, err = ds.GetEnterpriseByID(testCtx(), enterprise.ID) assert.True(t, fleet.IsNotFound(err)) diff --git a/server/mdm/android/mysql/mysql_test.go b/server/mdm/android/mysql/testing_utils.go similarity index 71% rename from server/mdm/android/mysql/mysql_test.go rename to server/mdm/android/mysql/testing_utils.go index ab75c0546029..7017ecb8b29b 100644 --- a/server/mdm/android/mysql/mysql_test.go +++ b/server/mdm/android/mysql/testing_utils.go @@ -12,9 +12,11 @@ import ( "github.com/stretchr/testify/require" ) -// Android MySQL testing utilities. +// Android MySQL testing utilities. This file should contain VERY LITTLE code since it is also compiled into the production binary. +// Whenever possible, new code should go into a dedicated testing package (e.g. mdm/android/mysql/tests/testing_utils.go). // These utilities are used to create a MySQL Datastore for testing the Android MDM MySQL implementation. -// They are located in the same package as the implementation to prevent a circular dependency. +// They are located in the same package as the implementation to prevent a circular dependency. If put it in a different package, +// the circular dependency would be: mysql -> testing_utils -> mysql func CreateMySQLDS(t testing.TB) *Datastore { return createMySQLDSWithOptions(t, nil) @@ -22,14 +24,14 @@ func CreateMySQLDS(t testing.TB) *Datastore { func createMySQLDSWithOptions(t testing.TB, opts *testing_utils.DatastoreTestOptions) *Datastore { cleanTestName, opts := testing_utils.ProcessOptions(t, opts) - ds := initializeDatabase(t, cleanTestName, opts) + ds := InitializeDatabase(t, cleanTestName, opts) t.Cleanup(func() { Close(ds) }) return ds } -// initializeDatabase loads the dumped schema into a newly created database in MySQL. +// InitializeDatabase loads the dumped schema into a newly created database in MySQL. // This is much faster than running the full set of migrations on each test. -func initializeDatabase(t testing.TB, testName string, opts *testing_utils.DatastoreTestOptions) *Datastore { +func InitializeDatabase(t testing.TB, testName string, opts *testing_utils.DatastoreTestOptions) *Datastore { _, filename, _, _ := runtime.Caller(0) schemaPath := path.Join(path.Dir(filename), "schema.sql") testing_utils.LoadSchema(t, testName, opts, schemaPath) diff --git a/server/mdm/android/proxy.go b/server/mdm/android/proxy.go new file mode 100644 index 000000000000..8e9c61952283 --- /dev/null +++ b/server/mdm/android/proxy.go @@ -0,0 +1,22 @@ +package android + +import ( + "context" + + "google.golang.org/api/androidmanagement/v1" +) + +type Proxy interface { + SignupURLsCreate(callbackURL string) (*SignupDetails, error) + EnterprisesCreate(ctx context.Context, req ProxyEnterprisesCreateRequest) (string, string, error) + EnterprisesPoliciesPatch(enterpriseID string, policyName string, policy *androidmanagement.Policy) error + EnterprisesEnrollmentTokensCreate(enterpriseName string, token *androidmanagement.EnrollmentToken) (*androidmanagement.EnrollmentToken, error) + EnterpriseDelete(enterpriseID string) error +} + +type ProxyEnterprisesCreateRequest struct { + androidmanagement.Enterprise + EnterpriseToken string + SignupUrlName string + PubSubPushURL string +} diff --git a/server/mdm/android/service.go b/server/mdm/android/service.go index 929d926bfdfe..6583c4762acb 100644 --- a/server/mdm/android/service.go +++ b/server/mdm/android/service.go @@ -7,9 +7,29 @@ import ( type Service interface { EnterpriseSignup(ctx context.Context) (*SignupDetails, error) EnterpriseSignupCallback(ctx context.Context, enterpriseID uint, enterpriseToken string) error + GetEnterprise(ctx context.Context) (*Enterprise, error) DeleteEnterprise(ctx context.Context) error // CreateEnrollmentToken creates an enrollment token for a new Android device. CreateEnrollmentToken(ctx context.Context, enrollSecret string) (*EnrollmentToken, error) ProcessPubSubPush(ctx context.Context, token string, message *PubSubMessage) error } + +// ///////////////////////////////////////////// +// Android API request and response structs + +type DefaultResponse struct { + Err error `json:"error,omitempty"` +} + +func (r DefaultResponse) Error() error { return r.Err } + +type GetEnterpriseResponse struct { + EnterpriseID string `json:"android_enterprise_id"` + DefaultResponse +} + +type EnterpriseSignupResponse struct { + Url string `json:"android_enterprise_signup_url"` + DefaultResponse +} diff --git a/server/mdm/android/service/enterprises_test.go b/server/mdm/android/service/enterprises_test.go new file mode 100644 index 000000000000..9a99ed730c35 --- /dev/null +++ b/server/mdm/android/service/enterprises_test.go @@ -0,0 +1,135 @@ +package service + +import ( + "context" + "os" + "testing" + + "github.com/fleetdm/fleet/v4/server/authz" + "github.com/fleetdm/fleet/v4/server/contexts/viewer" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mdm/android" + android_mock "github.com/fleetdm/fleet/v4/server/mdm/android/mock" + "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/ptr" + kitlog "github.com/go-kit/log" + "github.com/stretchr/testify/require" +) + +func TestEnterprisesAuth(t *testing.T) { + proxy := android_mock.Proxy{} + proxy.InitCommonMocks() + logger := kitlog.NewLogfmtLogger(os.Stdout) + fleetDS := InitCommonDSMocks() + svc, err := NewServiceWithProxy(logger, fleetDS, &proxy) + require.NoError(t, err) + + testCases := []struct { + name string + user *fleet.User + shouldFailWrite bool + shouldFailRead bool + }{ + { + "global admin", + &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}, + false, + false, + }, + { + "global maintainer", + &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)}, + true, + true, + }, + { + "global gitops", + &fleet.User{GlobalRole: ptr.String(fleet.RoleGitOps)}, + true, + true, + }, + { + "global observer", + &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)}, + true, + true, + }, + { + "global observer+", + &fleet.User{GlobalRole: ptr.String(fleet.RoleObserverPlus)}, + true, + true, + }, + { + "team admin", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}}, + true, + true, + }, + { + "team maintainer", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}}, + true, + true, + }, + { + "team observer", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}}, + true, + true, + }, + { + "team observer+", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserverPlus}}}, + true, + true, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user}) + + _, err := svc.GetEnterprise(ctx) + checkAuthErr(t, tt.shouldFailRead, err) + + err = svc.DeleteEnterprise(ctx) + checkAuthErr(t, tt.shouldFailWrite, err) + + _, err = svc.EnterpriseSignup(ctx) + checkAuthErr(t, tt.shouldFailWrite, err) + }) + } + + t.Run("unauthorized", func(t *testing.T) { + err := svc.EnterpriseSignupCallback(context.Background(), 1, "token") + checkAuthErr(t, false, err) + }) +} + +func checkAuthErr(t *testing.T, shouldFail bool, err error) { + t.Helper() + if shouldFail { + require.Error(t, err) + var forbiddenError *authz.Forbidden + require.ErrorAs(t, err, &forbiddenError) + } else { + require.NoError(t, err) + } +} + +func InitCommonDSMocks() *mock.Store { + fleetDS := mock.Store{} + ds := android_mock.Datastore{} + ds.InitCommonMocks() + + fleetDS.GetAndroidDSFunc = func() android.Datastore { + return &ds + } + fleetDS.AppConfigFunc = func(_ context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + fleetDS.SetAndroidEnabledAndConfiguredFunc = func(_ context.Context, configured bool) error { + return nil + } + return &fleetDS +} diff --git a/server/mdm/android/service/handler.go b/server/mdm/android/service/handler.go index b38aa553c0f4..a5accf634069 100644 --- a/server/mdm/android/service/handler.go +++ b/server/mdm/android/service/handler.go @@ -18,20 +18,21 @@ const pubSubPushPath = "/api/v1/fleet/android_enterprise/pubsub" func attachFleetAPIRoutes(r *mux.Router, fleetSvc fleet.Service, svc android.Service, opts []kithttp.ServerOption) { - // user-authenticated endpoints + // ////////////////////////////////////////// + // User-authenticated endpoints ue := newUserAuthenticatedEndpointer(fleetSvc, svc, opts, r, apiVersions()...) ue.GET("/api/_version_/fleet/android_enterprise/signup_url", enterpriseSignupEndpoint, nil) + ue.GET("/api/_version_/fleet/android_enterprise", getEnterpriseEndpoint, nil) ue.DELETE("/api/_version_/fleet/android_enterprise", deleteEnterpriseEndpoint, nil) - // unauthenticated endpoints - // They typically do one-time authentication by verifying that a valid secret token is provided with the request. + // ////////////////////////////////////////// + // Unauthenticated endpoints + // These endpoints should do custom one-time authentication by verifying that a valid secret token is provided with the request. ne := newNoAuthEndpointer(fleetSvc, svc, opts, r, apiVersions()...) - ne.GET("/api/_version_/fleet/android_enterprise/{id:[0-9]+}/connect", enterpriseSignupCallbackEndpoint, - enterpriseSignupCallbackRequest{}) - ne.GET("/api/_version_/fleet/android_enterprise/enrollment_token", enrollmentTokenEndpoint, - enrollmentTokenRequest{}) + ne.GET("/api/_version_/fleet/android_enterprise/{id:[0-9]+}/connect", enterpriseSignupCallbackEndpoint, enterpriseSignupCallbackRequest{}) + ne.GET("/api/_version_/fleet/android_enterprise/enrollment_token", enrollmentTokenEndpoint, enrollmentTokenRequest{}) ne.POST(pubSubPushPath, pubSubPushEndpoint, pubSubPushRequest{}) } diff --git a/server/mdm/android/service/proxy/proxy.go b/server/mdm/android/service/proxy/proxy.go index 18676126b88a..72eb0fbb5651 100644 --- a/server/mdm/android/service/proxy/proxy.go +++ b/server/mdm/android/service/proxy/proxy.go @@ -32,6 +32,9 @@ type Proxy struct { mgmt *androidmanagement.Service } +// Compile-time check to ensure that Proxy implements android.Proxy. +var _ android.Proxy = &Proxy{} + func NewProxy(ctx context.Context, logger kitlog.Logger) *Proxy { if androidServiceCredentials == "" { return nil @@ -74,28 +77,27 @@ func (p *Proxy) SignupURLsCreate(callbackURL string) (*android.SignupDetails, er }, nil } -func (p *Proxy) EnterprisesCreate(ctx context.Context, enabledNotificationTypes []string, enterpriseToken string, - signupUrlName string, pushURL string) (string, string, error) { +func (p *Proxy) EnterprisesCreate(ctx context.Context, req android.ProxyEnterprisesCreateRequest) (string, string, error) { if p == nil || p.mgmt == nil { return "", "", errors.New("android management service not initialized") } - topicName, err := p.createPubSubTopic(ctx, pushURL) + topicName, err := p.createPubSubTopic(ctx, req.PubSubPushURL) if err != nil { return "", "", fmt.Errorf("creating PubSub topic: %w", err) } enterprise, err := p.mgmt.Enterprises.Create(&androidmanagement.Enterprise{ - EnabledNotificationTypes: enabledNotificationTypes, + EnabledNotificationTypes: req.EnabledNotificationTypes, PubsubTopic: topicName, }). ProjectId(androidProjectID). - EnterpriseToken(enterpriseToken). - SignupUrlName(signupUrlName). + EnterpriseToken(req.EnterpriseToken). + SignupUrlName(req.SignupUrlName). Do() switch { case googleapi.IsNotModified(err): - return "", "", fmt.Errorf("android enterprise %s was already created", signupUrlName) + return "", "", fmt.Errorf("android enterprise %s was already created", req.SignupUrlName) case err != nil: return "", "", fmt.Errorf("creating enterprise: %w", err) } diff --git a/server/mdm/android/service/pubsub.go b/server/mdm/android/service/pubsub.go index 3383e352c927..7e710d203115 100644 --- a/server/mdm/android/service/pubsub.go +++ b/server/mdm/android/service/pubsub.go @@ -26,7 +26,7 @@ type pubSubPushRequest struct { func pubSubPushEndpoint(ctx context.Context, request interface{}, svc android.Service) fleet.Errorer { req := request.(*pubSubPushRequest) err := svc.ProcessPubSubPush(ctx, req.Token, &req.PubSubMessage) - return defaultResponse{Err: err} + return android.DefaultResponse{Err: err} } func (svc *Service) ProcessPubSubPush(ctx context.Context, token string, message *android.PubSubMessage) error { diff --git a/server/mdm/android/service/service.go b/server/mdm/android/service/service.go index f2ab5b4f651a..e786f67c8320 100644 --- a/server/mdm/android/service/service.go +++ b/server/mdm/android/service/service.go @@ -24,43 +24,39 @@ type Service struct { authz *authz.Authorizer ds android.Datastore fleetDS fleet.Datastore - proxy *proxy.Proxy + proxy android.Proxy } func NewService( ctx context.Context, logger kitlog.Logger, fleetDS fleet.Datastore, +) (android.Service, error) { + prx := proxy.NewProxy(ctx, logger) + return NewServiceWithProxy(logger, fleetDS, prx) +} + +func NewServiceWithProxy( + logger kitlog.Logger, + fleetDS fleet.Datastore, + proxy android.Proxy, ) (android.Service, error) { authorizer, err := authz.NewAuthorizer() if err != nil { return nil, fmt.Errorf("new authorizer: %w", err) } - prx := proxy.NewProxy(ctx, logger) - return &Service{ logger: logger, authz: authorizer, ds: fleetDS.GetAndroidDS(), fleetDS: fleetDS, - proxy: prx, + proxy: proxy, }, nil } -type defaultResponse struct { - Err error `json:"error,omitempty"` -} - -func (r defaultResponse) Error() error { return r.Err } - -func newErrResponse(err error) defaultResponse { - return defaultResponse{Err: err} -} - -type androidEnterpriseSignupResponse struct { - Url string `json:"android_enterprise_signup_url"` - defaultResponse +func newErrResponse(err error) android.DefaultResponse { + return android.DefaultResponse{Err: err} } func enterpriseSignupEndpoint(ctx context.Context, _ interface{}, svc android.Service) fleet.Errorer { @@ -68,7 +64,7 @@ func enterpriseSignupEndpoint(ctx context.Context, _ interface{}, svc android.Se if err != nil { return newErrResponse(err) } - return androidEnterpriseSignupResponse{Url: result.Url} + return android.EnterpriseSignupResponse{Url: result.Url} } func (svc *Service) EnterpriseSignup(ctx context.Context) (*android.SignupDetails, error) { @@ -125,7 +121,7 @@ type enterpriseSignupCallbackRequest struct { func enterpriseSignupCallbackEndpoint(ctx context.Context, request interface{}, svc android.Service) fleet.Errorer { req := request.(*enterpriseSignupCallbackRequest) err := svc.EnterpriseSignupCallback(ctx, req.ID, req.EnterpriseToken) - return defaultResponse{Err: err} + return android.DefaultResponse{Err: err} } func (svc *Service) EnterpriseSignupCallback(ctx context.Context, id uint, enterpriseToken string) error { @@ -156,10 +152,19 @@ func (svc *Service) EnterpriseSignupCallback(ctx context.Context, id uint, enter name, topicName, err := svc.proxy.EnterprisesCreate( ctx, - []string{android.PubSubEnrollment, android.PubSubStatusReport, android.PubSubCommand, android.PubSubUsageLogs}, - enterpriseToken, - enterprise.SignupName, - appConfig.ServerSettings.ServerURL+pubSubPushPath+"?token="+pubSubToken, + android.ProxyEnterprisesCreateRequest{ + Enterprise: androidmanagement.Enterprise{ + EnabledNotificationTypes: []string{ + android.PubSubEnrollment, + android.PubSubStatusReport, + android.PubSubCommand, + android.PubSubUsageLogs, + }, + }, + EnterpriseToken: enterpriseToken, + SignupUrlName: enterprise.SignupName, + PubSubPushURL: appConfig.ServerSettings.ServerURL + pubSubPushPath + "?token=" + pubSubToken, + }, ) if err != nil { return ctxerr.Wrap(ctx, err, "creating enterprise") @@ -222,9 +227,31 @@ func topicIDFromName(name string) (string, error) { return name[lastSlash+1:], nil } +func getEnterpriseEndpoint(ctx context.Context, _ interface{}, svc android.Service) fleet.Errorer { + enterprise, err := svc.GetEnterprise(ctx) + if err != nil { + return android.DefaultResponse{Err: err} + } + return android.GetEnterpriseResponse{EnterpriseID: enterprise.EnterpriseID} +} + +func (svc *Service) GetEnterprise(ctx context.Context) (*android.Enterprise, error) { + if err := svc.authz.Authorize(ctx, &android.Enterprise{}, fleet.ActionRead); err != nil { + return nil, err + } + enterprise, err := svc.ds.GetEnterprise(ctx) + switch { + case fleet.IsNotFound(err): + return nil, fleet.NewInvalidArgumentError("enterprise", "No enterprise found").WithStatus(http.StatusNotFound) + case err != nil: + return nil, ctxerr.Wrap(ctx, err, "getting enterprise") + } + return enterprise, nil +} + func deleteEnterpriseEndpoint(ctx context.Context, _ interface{}, svc android.Service) fleet.Errorer { err := svc.DeleteEnterprise(ctx) - return defaultResponse{Err: err} + return android.DefaultResponse{Err: err} } func (svc *Service) DeleteEnterprise(ctx context.Context) error { @@ -246,7 +273,7 @@ func (svc *Service) DeleteEnterprise(ctx context.Context) error { } } - err = svc.ds.DeleteEnterprises(ctx) + err = svc.ds.DeleteAllEnterprises(ctx) if err != nil { return ctxerr.Wrap(ctx, err, "deleting enterprises") } @@ -263,18 +290,18 @@ type enrollmentTokenRequest struct { EnrollSecret string `query:"enroll_secret"` } -type androidEnrollmentTokenResponse struct { +type enrollmentTokenResponse struct { *android.EnrollmentToken - defaultResponse + android.DefaultResponse } func enrollmentTokenEndpoint(ctx context.Context, request interface{}, svc android.Service) fleet.Errorer { req := request.(*enrollmentTokenRequest) token, err := svc.CreateEnrollmentToken(ctx, req.EnrollSecret) if err != nil { - return defaultResponse{Err: err} + return android.DefaultResponse{Err: err} } - return androidEnrollmentTokenResponse{EnrollmentToken: token} + return enrollmentTokenResponse{EnrollmentToken: token} } func (svc *Service) CreateEnrollmentToken(ctx context.Context, enrollSecret string) (*android.EnrollmentToken, error) { diff --git a/server/mdm/android/tests/README.md b/server/mdm/android/tests/README.md new file mode 100644 index 000000000000..5a0a375e39d8 --- /dev/null +++ b/server/mdm/android/tests/README.md @@ -0,0 +1,5 @@ +This package contains API Android tests with the real Android service and Android MySQL database. + +We use testify Suite to run these tests. Since [testify Suite does not support parallel execution](https://github.com/stretchr/testify/issues/187), +we put each test in their own package/directory. This allows these tests to run in parallel because each package is a separate compile unit. If you +create a large test, please put it in a separate file within the same Suite/package. diff --git a/server/mdm/android/tests/enterprise/enterprise_test.go b/server/mdm/android/tests/enterprise/enterprise_test.go new file mode 100644 index 000000000000..d8793239a616 --- /dev/null +++ b/server/mdm/android/tests/enterprise/enterprise_test.go @@ -0,0 +1,52 @@ +package enterprise_test + +import ( + "net/http" + "testing" + + "github.com/fleetdm/fleet/v4/server/mdm/android" + "github.com/fleetdm/fleet/v4/server/mdm/android/tests" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +func TestServiceEnterprise(t *testing.T) { + testingSuite := new(enterpriseTestSuite) + suite.Run(t, testingSuite) +} + +type enterpriseTestSuite struct { + tests.WithServer +} + +func (s *enterpriseTestSuite) SetupSuite() { + s.WithServer.SetupSuite(s.T(), "androidEnterpriseTestSuite") + s.Token = "bozo" +} + +func (s *enterpriseTestSuite) TearDownSuite() { + s.WithServer.TearDownSuite() +} + +func (s *enterpriseTestSuite) TestGetEnterprise() { + // Enterprise doesn't exist. + var resp android.GetEnterpriseResponse + s.DoJSON("GET", "/api/v1/fleet/android_enterprise", nil, http.StatusNotFound, &resp) + + // Create enterprise + var signupResp android.EnterpriseSignupResponse + s.DoJSON("GET", "/api/v1/fleet/android_enterprise/signup_url", nil, http.StatusOK, &signupResp) + assert.Equal(s.T(), tests.EnterpriseSignupURL, signupResp.Url) + s.T().Logf("callbackURL: %s", s.ProxyCallbackURL) + const enterpriseToken = "enterpriseToken" + s.DoJSON("GET", s.ProxyCallbackURL, nil, http.StatusOK, &resp, "enterpriseToken", enterpriseToken) + + // Now enterprise exists and we can retrieve it. + resp = android.GetEnterpriseResponse{} + s.DoJSON("GET", "/api/v1/fleet/android_enterprise", nil, http.StatusOK, &resp) + assert.Equal(s.T(), tests.EnterpriseID, resp.EnterpriseID) + + // Delete enterprise and make sure we can't find it. + s.Do("DELETE", "/api/v1/fleet/android_enterprise", nil, http.StatusOK) + s.DoJSON("GET", "/api/v1/fleet/android_enterprise", nil, http.StatusNotFound, &resp) +} diff --git a/server/mdm/android/tests/http.go b/server/mdm/android/tests/http.go new file mode 100644 index 000000000000..71a17b600981 --- /dev/null +++ b/server/mdm/android/tests/http.go @@ -0,0 +1,49 @@ +package tests + +import ( + "fmt" + "io" + "net/http" + + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/test/httptest" + "github.com/go-json-experiment/json" + "github.com/stretchr/testify/require" +) + +func (ts *WithServer) DoJSON(verb, path string, params interface{}, expectedStatusCode int, v interface{}, queryParams ...string) { + resp := ts.Do(verb, path, params, expectedStatusCode, queryParams...) + err := json.UnmarshalRead(resp.Body, v) + require.NoError(ts.T(), err) + if e, ok := v.(fleet.Errorer); ok { + require.NoError(ts.T(), e.Error()) + } +} + +func (ts *WithServer) Do(verb, path string, params interface{}, expectedStatusCode int, queryParams ...string) *http.Response { + j, err := json.Marshal(params) + require.NoError(ts.T(), err) + + resp := ts.DoRaw(verb, path, j, expectedStatusCode, queryParams...) + + ts.T().Cleanup(func() { + resp.Body.Close() + }) + return resp +} + +func (ts *WithServer) DoRaw(verb string, path string, rawBytes []byte, expectedStatusCode int, queryParams ...string) *http.Response { + return ts.DoRawWithHeaders(verb, path, rawBytes, expectedStatusCode, map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", ts.Token), + }, queryParams...) +} + +func (ts *WithServer) DoRawWithHeaders( + verb string, path string, rawBytes []byte, expectedStatusCode int, headers map[string]string, queryParams ...string, +) *http.Response { + return httptest.DoHTTPReq(ts.T(), decodeJSON, verb, rawBytes, ts.Server.URL+path, headers, expectedStatusCode, queryParams...) +} + +func decodeJSON(r io.Reader, v interface{}) error { + return json.UnmarshalRead(r, v) +} diff --git a/server/mdm/android/tests/testing_utils.go b/server/mdm/android/tests/testing_utils.go new file mode 100644 index 000000000000..acdf58378fe5 --- /dev/null +++ b/server/mdm/android/tests/testing_utils.go @@ -0,0 +1,157 @@ +package tests + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/fleetdm/fleet/v4/server/config" + "github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql/testing_utils" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mdm/android" + proxy_mock "github.com/fleetdm/fleet/v4/server/mdm/android/mock" + "github.com/fleetdm/fleet/v4/server/mdm/android/mysql" + "github.com/fleetdm/fleet/v4/server/mdm/android/service" + ds_mock "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/fleetdm/fleet/v4/server/service/middleware/auth" + "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" + "github.com/fleetdm/fleet/v4/server/service/middleware/log" + kithttp "github.com/go-kit/kit/transport/http" + kitlog "github.com/go-kit/log" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/api/androidmanagement/v1" +) + +const ( + EnterpriseSignupURL = "https://enterprise.google.com/signup/android/email?origin=android&thirdPartyToken=B4D779F1C4DD9A440" + EnterpriseID = "LC02k5wxw7" +) + +type WithServer struct { + suite.Suite + DS *mysql.Datastore + FleetDS ds_mock.Store + Server *httptest.Server + Token string + AppConfig fleet.AppConfig + Proxy proxy_mock.Proxy + ProxyCallbackURL string +} + +func (ts *WithServer) SetupSuite(t *testing.T, dbName string) { + ts.DS = CreateNamedMySQLDS(t, dbName) + ts.createCommonDSMocks() + + ts.Proxy = proxy_mock.Proxy{} + ts.createCommonProxyMocks(t) + + fleetSvc := mockService{} + logger := kitlog.NewLogfmtLogger(os.Stdout) + svc, err := service.NewServiceWithProxy(logger, &ts.FleetDS, &ts.Proxy) + require.NoError(t, err) + + ts.Server = runServerForTests(t, logger, &fleetSvc, svc) +} + +func (ts *WithServer) createCommonDSMocks() { + ts.FleetDS.GetAndroidDSFunc = func() android.Datastore { + return ts.DS + } + ts.FleetDS.AppConfigFunc = func(_ context.Context) (*fleet.AppConfig, error) { + return &ts.AppConfig, nil + } + ts.FleetDS.SetAndroidEnabledAndConfiguredFunc = func(_ context.Context, configured bool) error { + ts.AppConfig.MDM.AndroidEnabledAndConfigured = configured + return nil + } +} + +func (ts *WithServer) createCommonProxyMocks(t *testing.T) { + ts.Proxy.SignupURLsCreateFunc = func(callbackURL string) (*android.SignupDetails, error) { + ts.ProxyCallbackURL = callbackURL + return &android.SignupDetails{ + Url: EnterpriseSignupURL, + Name: "signupUrls/Cb08124d0999c464f", + }, nil + } + ts.Proxy.EnterprisesCreateFunc = func(ctx context.Context, req android.ProxyEnterprisesCreateRequest) (string, string, error) { + return EnterpriseID, "projects/android/topics/ae98ed130-5ce2-4ddb-a90a-191ec76976d5", nil + } + ts.Proxy.EnterprisesPoliciesPatchFunc = func(enterpriseID string, policyName string, policy *androidmanagement.Policy) error { + assert.Equal(t, EnterpriseID, enterpriseID) + return nil + } + ts.Proxy.EnterpriseDeleteFunc = func(enterpriseID string) error { + assert.Equal(t, EnterpriseID, enterpriseID) + return nil + } +} + +func (ts *WithServer) TearDownSuite() { + mysql.Close(ts.DS) +} + +type mockService struct { + mock.Mock + fleet.Service +} + +func (m *mockService) GetSessionByKey(ctx context.Context, sessionKey string) (*fleet.Session, error) { + return &fleet.Session{UserID: 1}, nil +} + +func (m *mockService) UserUnauthorized(ctx context.Context, userId uint) (*fleet.User, error) { + return &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}, nil +} + +func runServerForTests(t *testing.T, logger kitlog.Logger, fleetSvc fleet.Service, androidSvc android.Service) *httptest.Server { + + fleetAPIOptions := []kithttp.ServerOption{ + kithttp.ServerBefore( + kithttp.PopulateRequestContext, + auth.SetRequestsContexts(fleetSvc), + ), + kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: logger}), + kithttp.ServerErrorEncoder(endpoint_utils.EncodeError), + kithttp.ServerAfter( + kithttp.SetContentType("application/json; charset=utf-8"), + log.LogRequestEnd(logger), + ), + } + + r := mux.NewRouter() + service.GetRoutes(fleetSvc, androidSvc)(r, fleetAPIOptions) + rootMux := http.NewServeMux() + rootMux.HandleFunc("/api/", r.ServeHTTP) + + server := httptest.NewUnstartedServer(rootMux) + serverConfig := config.ServerConfig{} + server.Config = serverConfig.DefaultHTTPServer(testCtx(), rootMux) + require.NotZero(t, server.Config.WriteTimeout) + server.Config.Handler = rootMux + server.Start() + t.Cleanup(func() { + server.Close() + }) + return server +} + +func testCtx() context.Context { + return context.Background() +} + +func CreateNamedMySQLDS(t *testing.T, name string) *mysql.Datastore { + if _, ok := os.LookupEnv("MYSQL_TEST"); !ok { + t.Skip("MySQL tests are disabled") + } + ds := mysql.InitializeDatabase(t, name, new(testing_utils.DatastoreTestOptions)) + t.Cleanup(func() { mysql.Close(ds) }) + return ds +} diff --git a/server/service/endpoint_utils_test.go b/server/service/endpoint_utils_test.go index 5831f7ddbfac..ec0009b73a37 100644 --- a/server/service/endpoint_utils_test.go +++ b/server/service/endpoint_utils_test.go @@ -16,6 +16,7 @@ import ( "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/service/middleware/auth" "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" + "github.com/fleetdm/fleet/v4/server/service/middleware/log" "github.com/go-kit/kit/endpoint" kithttp "github.com/go-kit/kit/transport/http" kitlog "github.com/go-kit/log" @@ -292,7 +293,7 @@ func TestEndpointer(t *testing.T) { kithttp.ServerErrorEncoder(endpoint_utils.EncodeError), kithttp.ServerAfter( kithttp.SetContentType("application/json; charset=utf-8"), - logRequestEnd(kitlog.NewNopLogger()), + log.LogRequestEnd(kitlog.NewNopLogger()), checkLicenseExpiration(svc), ), } @@ -412,7 +413,7 @@ func TestEndpointerCustomMiddleware(t *testing.T) { kithttp.ServerErrorEncoder(endpoint_utils.EncodeError), kithttp.ServerAfter( kithttp.SetContentType("application/json; charset=utf-8"), - logRequestEnd(kitlog.NewNopLogger()), + log.LogRequestEnd(kitlog.NewNopLogger()), checkLicenseExpiration(svc), ), } diff --git a/server/service/handler.go b/server/service/handler.go index f14eaa5a4af2..0d01ad9b9da2 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -11,7 +11,6 @@ import ( eeservice "github.com/fleetdm/fleet/v4/ee/server/service" "github.com/fleetdm/fleet/v4/server/config" - "github.com/fleetdm/fleet/v4/server/contexts/logging" "github.com/fleetdm/fleet/v4/server/contexts/publicip" "github.com/fleetdm/fleet/v4/server/fleet" apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" @@ -26,6 +25,7 @@ import ( scepserver "github.com/fleetdm/fleet/v4/server/mdm/scep/server" "github.com/fleetdm/fleet/v4/server/service/middleware/auth" "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" + "github.com/fleetdm/fleet/v4/server/service/middleware/log" "github.com/fleetdm/fleet/v4/server/service/middleware/mdmconfigured" "github.com/fleetdm/fleet/v4/server/service/middleware/ratelimit" kithttp "github.com/go-kit/kit/transport/http" @@ -42,17 +42,6 @@ import ( microsoft_mdm "github.com/fleetdm/fleet/v4/server/mdm/microsoft" ) -func logRequestEnd(logger kitlog.Logger) func(context.Context, http.ResponseWriter) context.Context { - return func(ctx context.Context, w http.ResponseWriter) context.Context { - logCtx, ok := logging.FromContext(ctx) - if !ok { - return ctx - } - logCtx.Log(ctx, logger) - return ctx - } -} - func checkLicenseExpiration(svc fleet.Service) func(context.Context, http.ResponseWriter) context.Context { return func(ctx context.Context, w http.ResponseWriter) context.Context { license, err := svc.License(ctx) @@ -103,7 +92,7 @@ func MakeHandler( kithttp.ServerErrorEncoder(endpoint_utils.EncodeError), kithttp.ServerAfter( kithttp.SetContentType("application/json; charset=utf-8"), - logRequestEnd(logger), + log.LogRequestEnd(logger), checkLicenseExpiration(svc), ), } diff --git a/server/service/middleware/log/log.go b/server/service/middleware/log/log.go index f1c7e5c1e526..007e2fbb1c85 100644 --- a/server/service/middleware/log/log.go +++ b/server/service/middleware/log/log.go @@ -2,9 +2,11 @@ package log import ( "context" + "net/http" "github.com/fleetdm/fleet/v4/server/contexts/logging" "github.com/go-kit/kit/endpoint" + kitlog "github.com/go-kit/log" ) // Logged wraps an endpoint and adds the error if the context supports it @@ -24,3 +26,14 @@ func Logged(next endpoint.Endpoint) endpoint.Endpoint { return res, nil } } + +func LogRequestEnd(logger kitlog.Logger) func(context.Context, http.ResponseWriter) context.Context { + return func(ctx context.Context, w http.ResponseWriter) context.Context { + logCtx, ok := logging.FromContext(ctx) + if !ok { + return ctx + } + logCtx.Log(ctx, logger) + return ctx + } +} diff --git a/server/service/testing_client.go b/server/service/testing_client.go index fea1a46c5e80..64d02f02be94 100644 --- a/server/service/testing_client.go +++ b/server/service/testing_client.go @@ -26,9 +26,9 @@ import ( "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/live_query/live_query_mock" "github.com/fleetdm/fleet/v4/server/pubsub" - "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" "github.com/fleetdm/fleet/v4/server/sso" "github.com/fleetdm/fleet/v4/server/test" + fleet_httptest "github.com/fleetdm/fleet/v4/server/test/httptest" "github.com/ghodss/yaml" kitlog "github.com/go-kit/log" "github.com/jmoiron/sqlx" @@ -244,47 +244,11 @@ func (ts *withServer) Do(verb, path string, params interface{}, expectedStatusCo func (ts *withServer) DoRawWithHeaders( verb string, path string, rawBytes []byte, expectedStatusCode int, headers map[string]string, queryParams ...string, ) *http.Response { - t := ts.s.T() - - requestBody := io.NopCloser(bytes.NewBuffer(rawBytes)) - req, err := http.NewRequest(verb, ts.server.URL+path, requestBody) - require.NoError(t, err) - for key, val := range headers { - req.Header.Add(key, val) - } - - opts := []fleethttp.ClientOpt{} - if expectedStatusCode >= 300 && expectedStatusCode <= 399 { - opts = append(opts, fleethttp.WithFollowRedir(false)) - } - client := fleethttp.NewClient(opts...) - - if len(queryParams)%2 != 0 { - require.Fail(t, "need even number of params: key value") - } - if len(queryParams) > 0 { - q := req.URL.Query() - for i := 0; i < len(queryParams); i += 2 { - q.Add(queryParams[i], queryParams[i+1]) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := client.Do(req) - require.NoError(t, err) - - if resp.StatusCode != expectedStatusCode { - defer resp.Body.Close() - var je endpoint_utils.JsonError - err := json.NewDecoder(resp.Body).Decode(&je) - if err != nil { - t.Logf("Error trying to decode response body as Fleet jsonError: %s", err) - require.Equal(t, expectedStatusCode, resp.StatusCode, fmt.Sprintf("response: %+v", resp)) - } - require.Equal(t, expectedStatusCode, resp.StatusCode, fmt.Sprintf("Fleet jsonError: %+v", je)) - } + return fleet_httptest.DoHTTPReq(ts.s.T(), decodeJSON, verb, rawBytes, ts.server.URL+path, headers, expectedStatusCode, queryParams...) +} - return resp +func decodeJSON(r io.Reader, v interface{}) error { + return json.NewDecoder(r).Decode(v) } func (ts *withServer) DoRaw(verb string, path string, rawBytes []byte, expectedStatusCode int, queryParams ...string) *http.Response { diff --git a/server/test/httptest/README.md b/server/test/httptest/README.md new file mode 100644 index 000000000000..28f7a61526c9 --- /dev/null +++ b/server/test/httptest/README.md @@ -0,0 +1,2 @@ +These HTTP test functions are in a separate package to prevent circular dependencies. +The circular dependency may be caused due to dependency on "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" diff --git a/server/test/httptest/http.go b/server/test/httptest/http.go new file mode 100644 index 000000000000..768121b0f519 --- /dev/null +++ b/server/test/httptest/http.go @@ -0,0 +1,55 @@ +package httptest + +import ( + "bytes" + "fmt" + "io" + "net/http" + "testing" + + "github.com/fleetdm/fleet/v4/pkg/fleethttp" + "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" + "github.com/stretchr/testify/require" +) + +func DoHTTPReq(t *testing.T, jsonDecoder func(r io.Reader, v interface{}) error, verb string, rawBytes []byte, urlPath string, + headers map[string]string, expectedStatusCode int, queryParams ...string) *http.Response { + requestBody := io.NopCloser(bytes.NewBuffer(rawBytes)) + req, err := http.NewRequest(verb, urlPath, requestBody) + require.NoError(t, err) + for key, val := range headers { + req.Header.Add(key, val) + } + + opts := []fleethttp.ClientOpt{} + if expectedStatusCode >= 300 && expectedStatusCode <= 399 { + opts = append(opts, fleethttp.WithFollowRedir(false)) + } + client := fleethttp.NewClient(opts...) + + if len(queryParams)%2 != 0 { + require.Fail(t, "need even number of params: key value") + } + if len(queryParams) > 0 { + q := req.URL.Query() + for i := 0; i < len(queryParams); i += 2 { + q.Add(queryParams[i], queryParams[i+1]) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := client.Do(req) + require.NoError(t, err) + + if resp.StatusCode != expectedStatusCode { + defer resp.Body.Close() + var je endpoint_utils.JsonError + err := jsonDecoder(resp.Body, &je) + if err != nil { + t.Logf("Error trying to decode response body as Fleet jsonError: %s", err) + require.Equal(t, expectedStatusCode, resp.StatusCode, fmt.Sprintf("response: %+v", resp)) + } + require.Equal(t, expectedStatusCode, resp.StatusCode, fmt.Sprintf("Fleet jsonError: %+v", je)) + } + return resp +}