diff --git a/Dockerfile b/Dockerfile index 27e7b39ad..6c2e0cfa6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ # # base installs required dependencies and runs go mod download to cache dependencies # -FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22-alpine AS base +FROM --platform=${BUILDPLATFORM} docker.io/golang:1.23-alpine AS base RUN apk --update --no-cache add bash build-base curl git # diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index b7cd88562..976a62725 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/element-hq/dendrite/clientapi" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/auth/authtypes" "github.com/element-hq/dendrite/federationapi/statistics" "github.com/element-hq/dendrite/internal/httputil" @@ -446,7 +447,8 @@ func TestOutputAppserviceEvent(t *testing.T) { } usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - clientapi.AddPublicRoutes(processCtx, routers, cfg, natsInstance, nil, rsAPI, nil, nil, nil, usrAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: usrAPI} + clientapi.AddPublicRoutes(processCtx, routers, cfg, natsInstance, nil, rsAPI, nil, nil, nil, usrAPI, nil, nil, &userVerifier, caching.DisableMetrics) createAccessTokens(t, accessTokens, usrAPI, processCtx.Context(), routers) room := test.NewRoom(t, alice) @@ -537,7 +539,7 @@ func TestOutputAppserviceEvent(t *testing.T) { } // Start the syncAPI to have `/joined_members` available - syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, usrAPI, rsAPI, caches, caching.DisableMetrics) + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, usrAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics) // start the consumer appservice.NewInternalAPI(processCtx, cfg, natsInstance, usrAPI, rsAPI) diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 660b84a46..e23aad8bd 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -1,6 +1,6 @@ #syntax=docker/dockerfile:1.2 -FROM golang:1.22-bookworm as build +FROM golang:1.23-bookworm as build RUN apt-get update && apt-get install -y sqlite3 WORKDIR /build diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile index 8fc847650..c2af16495 100644 --- a/build/scripts/ComplementLocal.Dockerfile +++ b/build/scripts/ComplementLocal.Dockerfile @@ -8,7 +8,7 @@ # # Use these mounts to make use of this dockerfile: # COMPLEMENT_HOST_MOUNTS='/your/local/dendrite:/dendrite:ro;/your/go/path:/go:ro' -FROM golang:1.22-bookworm +FROM golang:1.23-bookworm RUN apt-get update && apt-get install -y sqlite3 ENV SERVER_NAME=localhost diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index 0026842d8..48843eb08 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -1,6 +1,6 @@ #syntax=docker/dockerfile:1.2 -FROM golang:1.22-bookworm as build +FROM golang:1.23-bookworm as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index d3c5bcee0..02e89649d 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package clientapi import ( @@ -11,6 +16,8 @@ import ( "testing" "time" + "github.com/element-hq/dendrite/userapi/types" + "github.com/element-hq/dendrite/federationapi" "github.com/element-hq/dendrite/internal/caching" "github.com/element-hq/dendrite/internal/httputil" @@ -27,6 +34,7 @@ import ( "github.com/tidwall/gjson" capi "github.com/element-hq/dendrite/clientapi/api" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/test" "github.com/element-hq/dendrite/test/testrig" "github.com/element-hq/dendrite/userapi" @@ -48,7 +56,8 @@ func TestAdminCreateToken(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, bob: {}, @@ -199,7 +208,8 @@ func TestAdminListRegistrationTokens(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, bob: {}, @@ -317,7 +327,8 @@ func TestAdminGetRegistrationToken(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, bob: {}, @@ -418,7 +429,8 @@ func TestAdminDeleteRegistrationToken(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, bob: {}, @@ -512,7 +524,8 @@ func TestAdminUpdateRegistrationToken(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, bob: {}, @@ -697,7 +710,8 @@ func TestAdminResetPassword(t *testing.T) { // Needed for changing the password/login userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -794,7 +808,8 @@ func TestPurgeRoom(t *testing.T) { rsAPI.SetFederationAPI(fsAPI, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics) // Create the room if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { @@ -802,7 +817,7 @@ func TestPurgeRoom(t *testing.T) { } // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -872,8 +887,10 @@ func TestAdminEvacuateRoom(t *testing.T) { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -976,8 +993,10 @@ func TestAdminEvacuateUser(t *testing.T) { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -1059,8 +1078,10 @@ func TestAdminMarkAsStale(t *testing.T) { rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -1147,8 +1168,10 @@ func TestAdminQueryEventReports(t *testing.T) { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -1376,8 +1399,10 @@ func TestEventReportsGetDelete(t *testing.T) { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -1473,3 +1498,834 @@ func TestEventReportsGetDelete(t *testing.T) { }) }) } + +func TestAdminCheckUsernameAvailable(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: alice.AccountType, + Localpart: alice.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + testCases := []struct { + name string + accessToken string + userID string + wantOK bool + isAvailable bool + }{ + {name: "Missing auth", accessToken: "", wantOK: false, userID: alice.Localpart, isAvailable: false}, + {name: "Alice - user exists", accessToken: adminToken, wantOK: true, userID: alice.Localpart, isAvailable: false}, + {name: "Bob - user does not exist", accessToken: adminToken, wantOK: true, userID: "bob", isAvailable: true}, + } + + for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v1/username_available?username="+tc.userID) + if tc.accessToken != "" { + req.Header.Set("Authorization", "Bearer "+tc.accessToken) + } + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK || !tc.wantOK && rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + if tc.wantOK { + b := make(map[string]bool, 1) + _ = json.NewDecoder(rec.Body).Decode(&b) + available, ok := b["available"] + if !ok { + t.Fatal("'available' not found in body") + } + if available != tc.isAvailable { + t.Fatalf("expected 'available' to be %t, got %t instead", tc.isAvailable, available) + } + } + }) + } + }) +} + +func TestAdminUserDeviceRetrieveCreate(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice, bob} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID+"/devices") + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + t.Run("Retrieve device", func(t *testing.T) { + var deviceRes uapi.PerformDeviceCreationResponse + if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{ + Localpart: alice.Localpart, + ServerName: cfg.Global.ServerName, + }, &deviceRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID+"/devices") + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + var body struct { + Total int `json:"total"` + Devices []struct { + DeviceID string `json:"device_id"` + } `json:"devices"` + } + _ = json.NewDecoder(rec.Body).Decode(&body) + if body.Total != 1 { + t.Errorf("expected 1 device, got %d", body.Total) + } + if len(body.Devices) != 1 { + t.Errorf("expected 1 device, got %d", len(body.Devices)) + } + }) + + t.Run("Create device", func(t *testing.T) { + reqBody := struct { + DeviceID string `json:"device_id"` + }{DeviceID: "devBob"} + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v2/users/"+bob.ID+"/devices", test.WithJSONBody(t, reqBody)) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusCreated { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusCreated, rec.Code, rec.Body.String()) + } + + var res uapi.QueryDevicesResponse + _ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: bob.ID}, &res) + if len(res.Devices) != 1 { + t.Errorf("expected 1 device, got %d", len(res.Devices)) + } + if res.Devices[0].ID != "devBob" { + t.Errorf("expected device to be devBob, got %s", res.Devices[0].ID) + } + }) + + }) +} + +func TestAdminUserDeviceDelete(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+alice.ID+"/devices/anything") + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + t.Run("Delete existing device", func(t *testing.T) { + var deviceRes uapi.PerformDeviceCreationResponse + if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{ + Localpart: alice.Localpart, + ServerName: cfg.Global.ServerName, + }, &deviceRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+alice.ID+"/devices/"+deviceRes.Device.ID) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + var rs uapi.QueryDevicesResponse + _ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: alice.ID}, &rs) + if len(rs.Devices) > 0 { + t.Errorf("expected 0 devices, got %d", len(rs.Devices)) + } + }) + + t.Run("Delete non-existing user's devices", func(t *testing.T) { + req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+bob.ID+"/devices/anything") + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + }) +} + +func TestAdminUserDevicesDelete(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + type payload struct { + Devices []string `json:"devices"` + } + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v2/users/"+alice.ID+"/delete_devices") + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + t.Run("Delete existing user's devices", func(t *testing.T) { + var deviceRes uapi.PerformDeviceCreationResponse + if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{ + Localpart: alice.Localpart, + ServerName: cfg.Global.ServerName, + }, &deviceRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + req := test.NewRequest( + t, + http.MethodPost, + "/_synapse/admin/v2/users/"+alice.ID+"/delete_devices", + test.WithJSONBody(t, payload{Devices: []string{deviceRes.Device.ID}}), + ) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + var rs uapi.QueryDevicesResponse + _ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: alice.ID}, &rs) + if len(rs.Devices) > 0 { + t.Errorf("expected 0 devices, got %d", len(rs.Devices)) + } + }) + + t.Run("Delete non-existing user's devices", func(t *testing.T) { + req := test.NewRequest( + t, + http.MethodPost, + "/_synapse/admin/v2/users/"+bob.ID+"/delete_devices", + test.WithJSONBody(t, payload{Devices: []string{"anyDevID"}}), + ) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + }) +} + +func TestAdminDeactivateAccount(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+alice.ID) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + t.Run("Deactivate existing account", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+alice.ID) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + var rs uapi.QueryAccountByLocalpartResponse + _ = userAPI.QueryAccountByLocalpart(ctx, &uapi.QueryAccountByLocalpartRequest{Localpart: alice.Localpart, ServerName: cfg.Global.ServerName}, &rs) + if !rs.Account.Deactivated { + t.Fatalf("expected account is deactivated") + } + }) + + t.Run("Deactivate non-existing account", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+bob.ID) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + }) +} + +func TestAdminAllowCrossSigningReplacementWithoutUIA(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/users/"+alice.ID+"/_allow_cross_signing_replacement_without_uia") + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + for _, u := range []*test.User{alice} { + var userRes uapi.PerformAccountCreationResponse + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, &userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + _ = userAPI.KeyDatabase.StoreCrossSigningKeysForUser(ctx, alice.ID, types.CrossSigningKeyMap{ + fclient.CrossSigningKeyPurposeMaster: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), + fclient.CrossSigningKeyPurposeSelfSigning: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), + fclient.CrossSigningKeyPurposeUserSigning: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), + }) + + } + + testCases := []struct { + Name string + User *test.User + Code int + }{ + {Name: "existing user", User: alice, Code: 200}, + {Name: "non-existing user", User: bob, Code: 404}, + } + + now := time.Now() + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/users/"+tc.User.ID+"/_allow_cross_signing_replacement_without_uia") + req.Header.Set("Authorization", "Bearer "+adminToken) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + + if rec.Code != tc.Code { + t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String()) + } + + if rec.Code == 200 { + buf := make(map[string]int64, 1) + _ = json.NewDecoder(rec.Body).Decode(&buf) + if ts := buf["updatable_without_uia_before_ms"]; ts <= now.UnixMilli() { + t.Fatalf("expected updatable_without_uia_before_ms is in future, got %d", ts) + } + } + }) + } + }) +} + +func TestAdminCreateOrModifyAccount(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + type threePID struct { + Medium string `json:"medium"` + Address string `json:"address"` + } + type adminCreateOrModifyAccountRequest struct { + DisplayName string `json:"displayname"` + AvatarURL string `json:"avatar_url"` + ThreePIDs []threePID `json:"threepids"` + } + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPut, "/_synapse/admin/v2/users/"+alice.ID) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + testCases := []struct { + Name string + User *test.User + Payload adminCreateOrModifyAccountRequest + Expected struct { + DisplayName, + AvatarURL string + ThreePIDs []string + } + Code int + }{ + { + Name: fmt.Sprintf("Modify user %s", alice.ID), + User: alice, + Payload: adminCreateOrModifyAccountRequest{ + DisplayName: "alice", + AvatarURL: "https://alice-avatar.example.com", + ThreePIDs: []threePID{ + { + Medium: "email", + Address: "alice@example.com", + }, + }, + }, + Expected: struct { + DisplayName, AvatarURL string + ThreePIDs []string + }{ + // In order to avoid any confusion and undesired behaviour, we do not change display name and avatar url if account already exists + DisplayName: alice.Localpart, + AvatarURL: "", + ThreePIDs: []string{"alice@example.com"}, + }, + Code: http.StatusOK, + }, + { + Name: fmt.Sprintf("Create user %s", bob.ID), + User: bob, + Payload: adminCreateOrModifyAccountRequest{ + DisplayName: "bob", + AvatarURL: "https://bob-avatar.example.com", + ThreePIDs: []threePID{ + { + Medium: "email", + Address: "bob@example.com", + }, + }, + }, + Expected: struct { + DisplayName, AvatarURL string + ThreePIDs []string + }{ + DisplayName: "bob", + AvatarURL: "https://bob-avatar.example.com", + ThreePIDs: []string{"bob@example.com"}, + }, + Code: http.StatusCreated, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + req := test.NewRequest( + t, + http.MethodPut, + "/_synapse/admin/v2/users/"+tc.User.ID, + test.WithJSONBody(t, tc.Payload), + ) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != tc.Code { + t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String()) + } + + p, _ := userAPI.QueryProfile(ctx, tc.User.ID) + if p.DisplayName != tc.Expected.DisplayName { + t.Fatalf("expected display name %s, got %s", tc.Expected.DisplayName, p.DisplayName) + } + if p.AvatarURL != tc.Expected.AvatarURL { + t.Fatalf("expected avatar_url %s, got %s", tc.Expected.AvatarURL, p.AvatarURL) + } + var threePidRs uapi.QueryThreePIDsForLocalpartResponse + _ = userAPI.QueryThreePIDsForLocalpart( + ctx, + &uapi.QueryThreePIDsForLocalpartRequest{Localpart: tc.User.Localpart, ServerName: cfg.Global.ServerName}, + &threePidRs, + ) + if len(threePidRs.ThreePIDs) != 1 { + t.Fatalf("expected 1 3pid got %d", len(threePidRs.ThreePIDs)) + } + tp := threePidRs.ThreePIDs[0] + if tp.Medium != "email" { + t.Fatalf("expected 3pid medium email got %s", tp.Medium) + } + if tp.Address != tc.Payload.ThreePIDs[0].Address { + t.Fatalf("expected 3pid address %s got %s", tc.Expected.ThreePIDs[0], tp.Address) + } + }) + } + }) +} + +func TestAdminRetrieveAccount(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + testCase := []struct { + Name string + User *test.User + Code int + Body string + }{ + { + Name: "Retrieve existing account", + User: alice, + Code: http.StatusOK, + Body: fmt.Sprintf(`{"display_name":"%s","avatar_url":"","deactivated":false}`, alice.Localpart), + }, + { + Name: "Retrieve non-existing account", + User: bob, + Code: http.StatusNotFound, + Body: "", + }, + } + + for _, tc := range testCase { + t.Run("Retrieve existing account", func(t *testing.T) { + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+tc.User.ID) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != tc.Code { + t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String()) + } + + if tc.Body != "" && tc.Body != rec.Body.String() { + t.Fatalf("expected body %s, got %s", tc.Body, rec.Body.String()) + } + }) + } + }) +} diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index c32ed0fae..4e3612ce1 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -16,8 +16,6 @@ import ( "strings" "github.com/element-hq/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib/spec" - "github.com/matrix-org/util" ) // OWASP recommends at least 128 bits of entropy for tokens: https://www.owasp.org/index.php/Insufficient_Session-ID_Length @@ -37,51 +35,6 @@ type AccountDatabase interface { GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error) } -// VerifyUserFromRequest authenticates the HTTP request, -// on success returns Device of the requester. -// Finds local user or an application service user. -// Note: For an AS user, AS dummy device is returned. -// On failure returns an JSON error response which can be sent to the client. -func VerifyUserFromRequest( - req *http.Request, userAPI api.QueryAcccessTokenAPI, -) (*api.Device, *util.JSONResponse) { - // Try to find the Application Service user - token, err := ExtractAccessToken(req) - if err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: spec.MissingToken(err.Error()), - } - } - var res api.QueryAccessTokenResponse - err = userAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{ - AccessToken: token, - AppServiceUserID: req.URL.Query().Get("user_id"), - }, &res) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") - return nil, &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - if res.Err != "" { - if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison - return nil, &util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden(res.Err), - } - } - } - if res.Device == nil { - return nil, &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: spec.UnknownToken("Unknown token"), - } - } - return res.Device, nil -} - // GenerateAccessToken creates a new access token. Returns an error if failed to generate // random bytes. func GenerateAccessToken() (string, error) { diff --git a/clientapi/auth/default_user_verifier.go b/clientapi/auth/default_user_verifier.go new file mode 100644 index 000000000..e6ecaf23e --- /dev/null +++ b/clientapi/auth/default_user_verifier.go @@ -0,0 +1,65 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + +package auth + +import ( + "net/http" + "strings" + + "github.com/element-hq/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" +) + +// DefaultUserVerifier implements UserVerifier interface +type DefaultUserVerifier struct { + UserAPI api.QueryAccessTokenAPI +} + +// VerifyUserFromRequest authenticates the HTTP request, +// on success returns Device of the requester. +// Finds local user or an application service user. +// Note: For an AS user, AS dummy device is returned. +// On failure returns an JSON error response which can be sent to the client. +func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) { + ctx := req.Context() + util.GetLogger(ctx).Debug("Default VerifyUserFromRequest") + // Try to find the Application Service user + token, err := ExtractAccessToken(req) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.MissingToken(err.Error()), + } + } + var res api.QueryAccessTokenResponse + err = d.UserAPI.QueryAccessToken(ctx, &api.QueryAccessTokenRequest{ + AccessToken: token, + AppServiceUserID: req.URL.Query().Get("user_id"), + }, &res) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.QueryAccessToken failed") + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if res.Err != "" { + if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison + return nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden(res.Err), + } + } + } + if res.Device == nil { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.UnknownToken("Unknown token"), + } + } + return res.Device, nil +} diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index dbf862ca6..1c3bc4711 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -36,7 +36,9 @@ func AddPublicRoutes( fsAPI federationAPI.ClientFederationAPI, userAPI userapi.ClientUserAPI, userDirectoryProvider userapi.QuerySearchProfilesAPI, - extRoomsProvider api.ExtraPublicRoomsProvider, enableMetrics bool, + extRoomsProvider api.ExtraPublicRoomsProvider, + userVerifier httputil.UserVerifier, + enableMetrics bool, ) { js, natsClient := natsInstance.Prepare(processContext, &cfg.Global.JetStream) @@ -55,6 +57,7 @@ func AddPublicRoutes( cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, syncProducer, transactionsCache, fsAPI, - extRoomsProvider, natsClient, enableMetrics, + extRoomsProvider, natsClient, + userVerifier, enableMetrics, ) } diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go index ad2d4ad48..b844699da 100644 --- a/clientapi/clientapi_test.go +++ b/clientapi/clientapi_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/element-hq/dendrite/appservice" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/auth/authtypes" "github.com/element-hq/dendrite/clientapi/routing" "github.com/element-hq/dendrite/clientapi/threepid" @@ -127,9 +128,10 @@ func TestGetPutDevices(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -176,9 +178,10 @@ func TestDeleteDevice(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -281,9 +284,10 @@ func TestDeleteDevices(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -449,8 +453,9 @@ func TestSetDisplayname(t *testing.T) { rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} - AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -561,8 +566,9 @@ func TestSetAvatarURL(t *testing.T) { rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} - AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -638,8 +644,9 @@ func TestTyping(t *testing.T) { rsAPI.SetFederationAPI(nil, nil) // Needed to create accounts userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -723,8 +730,9 @@ func TestMembership(t *testing.T) { // Needed to create accounts userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) rsAPI.SetUserAPI(userAPI) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -962,8 +970,9 @@ func TestCapabilities(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -1010,9 +1019,10 @@ func TestTurnserver(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} //rsAPI.SetUserAPI(userAPI) // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -1109,8 +1119,9 @@ func Test3PID(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -1285,9 +1296,10 @@ func TestPushRules(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -1672,9 +1684,10 @@ func TestKeys(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -2134,9 +2147,10 @@ func TestKeyBackup(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -2238,9 +2252,10 @@ func TestGetMembership(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -2301,9 +2316,10 @@ func TestCreateRoomInvite(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -2376,9 +2392,10 @@ func TestReportEvent(t *testing.T) { if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 48e58209c..7cc03ebeb 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -1,7 +1,13 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package routing import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -20,16 +26,24 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/exp/constraints" + appserviceAPI "github.com/element-hq/dendrite/appservice/api" clientapi "github.com/element-hq/dendrite/clientapi/api" + "github.com/element-hq/dendrite/clientapi/auth/authtypes" + clienthttputil "github.com/element-hq/dendrite/clientapi/httputil" + "github.com/element-hq/dendrite/clientapi/userutil" "github.com/element-hq/dendrite/internal/httputil" roomserverAPI "github.com/element-hq/dendrite/roomserver/api" "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/setup/jetstream" "github.com/element-hq/dendrite/userapi/api" userapi "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/shared" ) -var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") +var ( + validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") + deviceDisplayName = "OIDC-native client" +) func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { if !cfg.RegistrationRequiresToken { @@ -496,6 +510,486 @@ func AdminDownloadState(req *http.Request, device *api.Device, rsAPI roomserverA } } +func AdminCheckUsernameAvailable( + req *http.Request, + userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, +) util.JSONResponse { + username := req.URL.Query().Get("username") + if username == "" { + return util.MessageResponse(http.StatusBadRequest, "Query parameter 'username' is missing or empty") + } + rq := userapi.QueryAccountAvailabilityRequest{Localpart: username, ServerName: cfg.Matrix.ServerName} + rs := userapi.QueryAccountAvailabilityResponse{} + if err := userAPI.QueryAccountAvailability(req.Context(), &rq, &rs); err != nil { + return util.ErrorResponse(err) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: map[string]bool{"available": rs.Available}, + } +} + +func AdminUserDeviceRetrieveCreate( + req *http.Request, + userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, +) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID := vars["userID"] + local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam(err.Error()), + } + } + logger := util.GetLogger(req.Context()) + + switch req.Method { + case http.MethodPost: + var payload struct { + DeviceID string `json:"device_id"` + } + if resErr := clienthttputil.UnmarshalJSONRequest(req, &payload); resErr != nil { + return *resErr + } + + userDeviceExists := false + { + var rs api.QueryDevicesResponse + if err = userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil { + logger.WithError(err).Error("QueryDevices") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if !rs.UserExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("Given user ID does not exist"), + } + } + for i := range rs.Devices { + if d := rs.Devices[i]; d.ID == payload.DeviceID && d.UserID == userID { + userDeviceExists = true + break + } + } + } + + if !userDeviceExists { + var rs userapi.PerformDeviceCreationResponse + if err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ + Localpart: local, + ServerName: domain, + DeviceID: &payload.DeviceID, + DeviceDisplayName: &deviceDisplayName, + IPAddr: "", + UserAgent: req.UserAgent(), + NoDeviceListUpdate: false, + FromRegistration: false, + }, &rs); err != nil { + logger.WithError(err).Error("PerformDeviceCreation") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + logger.WithError(err).Debug("PerformDeviceCreation succeeded") + } + return util.JSONResponse{ + Code: http.StatusCreated, + JSON: struct{}{}, + } + case http.MethodGet: + var res userapi.QueryDevicesResponse + if err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &res); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + + jsonDevices := make([]deviceJSON, 0, len(res.Devices)) + for i := range res.Devices { + d := &res.Devices[i] + jsonDevices = append(jsonDevices, deviceJSON{ + DeviceID: d.ID, + DisplayName: d.DisplayName, + LastSeenIP: d.LastSeenIP, + LastSeenTS: d.LastSeenTS, + }) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct { + Devices []deviceJSON `json:"devices"` + Total int `json:"total"` + }{ + Devices: jsonDevices, + Total: len(res.Devices), + }, + } + default: + return util.JSONResponse{ + Code: http.StatusMethodNotAllowed, + JSON: struct{}{}, + } + } +} + +func AdminUserDeviceDelete( + req *http.Request, + userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, +) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID := vars["userID"] + deviceID := vars["deviceID"] + logger := util.GetLogger(req.Context()) + + // XXX: we probably have to delete session from the sessions dict + // like we do in DeleteDeviceById. If so, we have to fi + var device *api.Device + { + var rs api.QueryDevicesResponse + if err := userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil { + logger.WithError(err).Error("userAPI.QueryDevices failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if !rs.UserExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("Given user ID does not exist"), + } + } + for i := range rs.Devices { + if d := rs.Devices[i]; d.ID == deviceID && d.UserID == userID { + device = &d + break + } + } + } + + if device != nil { + // XXX: this response struct can completely removed everywhere as it doesn't + // have any functional purpose + var res api.PerformDeviceDeletionResponse + if err := userAPI.PerformDeviceDeletion(req.Context(), &api.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: []string{device.ID}, + }, &res); err != nil { + logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func AdminUserDevicesDelete( + req *http.Request, + userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, +) util.JSONResponse { + logger := util.GetLogger(req.Context()) + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID := vars["userID"] + + if req.Body == nil { + return util.MessageResponse(http.StatusBadRequest, "body is required") + } + var payload struct { + Devices []string `json:"devices"` + } + + if err = json.NewDecoder(req.Body).Decode(&payload); err != nil { + logger.WithError(err).Error("unable to decode device deletion request") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + defer req.Body.Close() // nolint: errcheck + + { + // XXX: this response struct can completely removed everywhere as it doesn't + // have any functional purpose + var rs api.PerformDeviceDeletionResponse + if err := userAPI.PerformDeviceDeletion(req.Context(), &api.PerformDeviceDeletionRequest{ + UserID: userID, + DeviceIDs: payload.Devices, + }, &rs); err != nil { + logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func AdminDeactivateAccount( + req *http.Request, + userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, +) util.JSONResponse { + logger := util.GetLogger(req.Context()) + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID := vars["userID"] + local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + + // TODO: "erase" field must also be processed here + // see https://github.com/element-hq/synapse/blob/develop/docs/admin_api/user_admin_api.md#deactivate-account + + var rs api.PerformAccountDeactivationResponse + if err := userAPI.PerformAccountDeactivation(req.Context(), &api.PerformAccountDeactivationRequest{ + Localpart: local, ServerName: domain, + }, &rs); err != nil { + logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func AdminAllowCrossSigningReplacementWithoutUIA( + req *http.Request, + userAPI userapi.ClientUserAPI, +) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userIDstr, ok := vars["userID"] + userID, err := spec.NewUserID(userIDstr, false) + if !ok || err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.MissingParam("User not found."), + } + } + + var rs api.QueryAccountByLocalpartResponse + err = userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{ + Localpart: userID.Local(), + ServerName: userID.Domain(), + }, &rs) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountByLocalpart") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(err.Error()), + } + } else if errors.Is(err, sql.ErrNoRows) { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("User not found."), + } + } + switch req.Method { + case http.MethodPost: + ts := sessions.allowCrossSigningKeysReplacement(userID.String()) + return util.JSONResponse{ + Code: http.StatusOK, + JSON: map[string]int64{"updatable_without_uia_before_ms": ts}, + } + default: + return util.JSONResponse{ + Code: http.StatusMethodNotAllowed, + JSON: spec.Unknown("Method not allowed."), + } + } + +} + +type adminCreateOrModifyAccountRequest struct { + DisplayName string `json:"displayname"` + AvatarURL string `json:"avatar_url"` + ThreePIDs []struct { + Medium string `json:"medium"` + Address string `json:"address"` + } `json:"threepids"` + // TODO: the following fields are not used by dendrite, but they are used in Synapse. + // Password string `json:"password"` + // LogoutDevices bool `json:"logout_devices"` + // ExternalIDs []struct{ + // AuthProvider string `json:"auth_provider"` + // ExternalID string `json:"external_id"` + // } `json:"external_ids"` + // Admin bool `json:"admin"` + // Deactivated bool `json:"deactivated"` + // Locked bool `json:"locked"` +} + +func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI) util.JSONResponse { + logger := util.GetLogger(req.Context()) + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID := vars["userID"] + local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam(userID), + } + } + var r adminCreateOrModifyAccountRequest + if resErr := clienthttputil.UnmarshalJSONRequest(req, &r); resErr != nil { + logger.Debugf("UnmarshalJSONRequest failed: %+v", *resErr) + return *resErr + } + logger.Debugf("adminCreateOrModifyAccountRequest is: %#v", r) + statusCode := http.StatusOK + + // TODO: Ideally, the following commands should be executed in one transaction. + // can we propagate the tx object and pass it in context? + var res userapi.PerformAccountCreationResponse + err = userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeUser, + Localpart: local, + ServerName: domain, + OnConflict: api.ConflictUpdate, + AvatarURL: r.AvatarURL, + DisplayName: r.DisplayName, + }, &res) + if err != nil { + logger.WithError(err).Error("userAPI.PerformAccountCreation") + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if res.AccountCreated { + statusCode = http.StatusCreated + } + + if l := len(r.ThreePIDs); l > 0 { + logger.Debugf("Trying to bulk save 3PID associations: %+v", r.ThreePIDs) + threePIDs := make([]authtypes.ThreePID, 0, len(r.ThreePIDs)) + for i := range r.ThreePIDs { + tpid := &r.ThreePIDs[i] + threePIDs = append(threePIDs, authtypes.ThreePID{Medium: tpid.Medium, Address: tpid.Address}) + } + err = userAPI.PerformBulkSaveThreePIDAssociation(req.Context(), &userapi.PerformBulkSaveThreePIDAssociationRequest{ + ThreePIDs: threePIDs, + Localpart: local, + ServerName: domain, + }, &struct{}{}) + if err == shared.Err3PIDInUse { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } else if err != nil { + logger.WithError(err).Error("userAPI.PerformSaveThreePIDAssociation") + return util.ErrorResponse(err) + } + } + + return util.JSONResponse{ + Code: statusCode, + JSON: nil, + } +} + +func AdminRetrieveAccount(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + logger := util.GetLogger(req.Context()) + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID, ok := vars["userID"] + if !ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.MissingParam("Expecting user ID."), + } + } + local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam(err.Error()), + } + } + + body := struct { + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` + Deactivated bool `json:"deactivated"` + }{} + + var rs api.QueryAccountByLocalpartResponse + err = userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{Localpart: local, ServerName: domain}, &rs) + if err == sql.ErrNoRows { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("User '%s' not found", userID)), + } + } else if err != nil { + logger.WithError(err).Error("userAPI.QueryAccountByLocalpart") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(err.Error()), + } + } + body.Deactivated = rs.Account.Deactivated + + profile, err := userAPI.QueryProfile(req.Context(), userID) + if err != nil { + if err == appserviceAPI.ErrProfileNotExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(err.Error()), + } + } + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(err.Error()), + } + } + body.AvatarURL = profile.AvatarURL + body.DisplayName = profile.DisplayName + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: body, + } +} + // GetEventReports returns reported events for a given user/room. func GetEventReports( req *http.Request, diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 26e0014b5..a0f7f06e1 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -9,20 +9,21 @@ package routing import ( "context" "net/http" + "strings" "time" - "github.com/matrix-org/gomatrixserverlib/fclient" - "github.com/sirupsen/logrus" - "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/auth/authtypes" "github.com/element-hq/dendrite/clientapi/httputil" "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) +const CrossSigningResetStage = "org.matrix.cross_signing_reset" + type crossSigningRequest struct { api.PerformUploadDeviceKeysRequest Auth newPasswordAuth `json:"auth"` @@ -30,6 +31,7 @@ type crossSigningRequest struct { type UploadKeysAPI interface { QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) + QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) api.UploadDeviceKeysAPI } @@ -38,6 +40,7 @@ func UploadCrossSigningDeviceKeys( keyserverAPI UploadKeysAPI, device *api.Device, accountAPI auth.GetAccountByPassword, cfg *config.ClientAPI, ) util.JSONResponse { + logger := util.GetLogger(req.Context()) uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} @@ -55,78 +58,102 @@ func UploadCrossSigningDeviceKeys( }, &keyResp) if keyResp.Error != nil { - logrus.WithError(keyResp.Error).Error("Failed to query keys") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown(keyResp.Error.Error()), - } + logger.WithError(keyResp.Error).Error("Failed to query keys") + return convertKeyError(keyResp.Error) } existingMasterKey, hasMasterKey := keyResp.MasterKeys[device.UserID] - requireUIA := false - if hasMasterKey { - // If we have a master key, check if any of the existing keys differ. If they do, - // we need to re-authenticate the user. - requireUIA = keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) - } - if requireUIA { - sessionID := uploadReq.Auth.Session - if sessionID == "" { - sessionID = util.RandomString(sessionIDLength) - } - if uploadReq.Auth.Type != authtypes.LoginTypePassword { + if hasMasterKey { + if !keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) { + // If we have a master key, check if any of the existing keys differ. If they don't + // we return 200 as keys are still valid and there's nothing to do. return util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: newUserInteractiveResponse( - sessionID, - []authtypes.Flow{ - { - Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, - }, - }, - nil, - ), + Code: http.StatusOK, + JSON: struct{}{}, } } - typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountAPI, - Config: cfg, - } - if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { - return *authErr - } - sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) - } - uploadReq.UserID = device.UserID - keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) + // With MSC3861, UIA is not possible. Instead, the auth service has to explicitly mark the master key as replaceable. + if cfg.MSCs.MSC3861Enabled() { + masterKeyResp := api.QueryMasterKeysResponse{} + keyserverAPI.QueryMasterKeys(req.Context(), &api.QueryMasterKeysRequest{UserID: device.UserID}, &masterKeyResp) - if err := uploadRes.Error; err != nil { - switch { - case err.IsInvalidSignature: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.InvalidSignature(err.Error()), + if masterKeyResp.Error != nil { + logger.WithError(masterKeyResp.Error).Error("Failed to query master key") + return convertKeyError(masterKeyResp.Error) } - case err.IsMissingParam: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.MissingParam(err.Error()), + + requireUIA := !sessions.isCrossSigningKeysReplacementAllowed(device.UserID) && masterKeyResp.Key != nil + if requireUIA { + url := "" + if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" { + url = strings.Join([]string{m.AccountManagementURL, "?action=", CrossSigningResetStage}, "") + } else { + url = m.Issuer + } + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + "dummy", + []authtypes.Flow{ + { + Stages: []authtypes.LoginType{CrossSigningResetStage}, + }, + }, + map[string]interface{}{ + CrossSigningResetStage: map[string]string{ + "url": url, + }, + }, + strings.Join([]string{ + "To reset your end-to-end encryption cross-signing identity, you first need to approve it at", + url, + "and then try again.", + }, " "), + ), + } } - case err.IsInvalidParam: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.InvalidParam(err.Error()), + sessions.restrictCrossSigningKeysReplacement(device.UserID) + } else { + sessionID := uploadReq.Auth.Session + if sessionID == "" { + sessionID = util.RandomString(sessionIDLength) } - default: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown(err.Error()), + + if uploadReq.Auth.Type != authtypes.LoginTypePassword { + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + sessionID, + []authtypes.Flow{ + { + Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, + }, + }, + nil, + "", + ), + } } + typePassword := auth.LoginTypePassword{ + GetAccountByPassword: accountAPI, + Config: cfg, + } + if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { + return *authErr + } + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) } } + uploadReq.UserID = device.UserID + keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) + + if err := uploadRes.Error; err != nil { + return convertKeyError(err) + } + return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, @@ -160,28 +187,7 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes) if err := uploadRes.Error; err != nil { - switch { - case err.IsInvalidSignature: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.InvalidSignature(err.Error()), - } - case err.IsMissingParam: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.MissingParam(err.Error()), - } - case err.IsInvalidParam: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.InvalidParam(err.Error()), - } - default: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown(err.Error()), - } - } + return convertKeyError(err) } return util.JSONResponse{ @@ -189,3 +195,28 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie JSON: struct{}{}, } } + +func convertKeyError(err *api.KeyError) util.JSONResponse { + switch { + case err.IsInvalidSignature: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidSignature(err.Error()), + } + case err.IsMissingParam: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.MissingParam(err.Error()), + } + case err.IsInvalidParam: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam(err.Error()), + } + default: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown(err.Error()), + } + } +} diff --git a/clientapi/routing/key_crosssigning_test.go b/clientapi/routing/key_crosssigning_test.go index 0ebb91e07..0db15ab92 100644 --- a/clientapi/routing/key_crosssigning_test.go +++ b/clientapi/routing/key_crosssigning_test.go @@ -19,20 +19,31 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) +// TODO: add more tests to cover cases related to MSC3861 + type mockKeyAPI struct { - t *testing.T - userResponses map[string]api.QueryKeysResponse + t *testing.T + queryKeysData map[string]api.QueryKeysResponse + queryMasterKeysData map[string]api.QueryMasterKeysResponse } func (m mockKeyAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { - res.MasterKeys = m.userResponses[req.UserID].MasterKeys - res.SelfSigningKeys = m.userResponses[req.UserID].SelfSigningKeys - res.UserSigningKeys = m.userResponses[req.UserID].UserSigningKeys + res.MasterKeys = m.queryKeysData[req.UserID].MasterKeys + res.SelfSigningKeys = m.queryKeysData[req.UserID].SelfSigningKeys + res.UserSigningKeys = m.queryKeysData[req.UserID].UserSigningKeys if m.t != nil { m.t.Logf("QueryKeys: %+v => %+v", req, res) } } +func (m mockKeyAPI) QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) { + res.Key = m.queryMasterKeysData[req.UserID].Key + res.Error = m.queryMasterKeysData[req.UserID].Error + if m.t != nil { + m.t.Logf("QueryMasterKeys: %+v => %+v", req, res) + } +} + func (m mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { // Just a dummy upload which always succeeds } @@ -53,13 +64,19 @@ func Test_UploadCrossSigningDeviceKeys_ValidRequest(t *testing.T) { req.Header.Set("Content-Type", "application/json") keyserverAPI := &mockKeyAPI{ - userResponses: map[string]api.QueryKeysResponse{ + queryKeysData: map[string]api.QueryKeysResponse{ + "@user:example.com": {}, + }, + queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ "@user:example.com": {}, }, } device := &api.Device{UserID: "@user:example.com", ID: "device"} - cfg := &config.ClientAPI{} - + cfg := &config.ClientAPI{ + MSCs: &config.MSCs{ + MSCs: []string{}, + }, + } res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg) if res.Code != http.StatusOK { t.Fatalf("expected status %d, got %d", http.StatusOK, res.Code) @@ -101,18 +118,30 @@ func Test_UploadCrossSigningDeviceKeys_Unauthorised(t *testing.T) { keyserverAPI := &mockKeyAPI{ t: t, - userResponses: map[string]api.QueryKeysResponse{ + queryKeysData: map[string]api.QueryKeysResponse{ "@user:example.com": { MasterKeys: map[string]fclient.CrossSigningKey{ - "@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}}, + "@user:example.com": { + UserID: "@user:example.com", + Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster}, + Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}}, }, SelfSigningKeys: nil, UserSigningKeys: nil, }, }, + queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ + "@user:example.com": { + Key: spec.Base64Bytes("key1"), + }, + }, } device := &api.Device{UserID: "@user:example.com", ID: "device"} - cfg := &config.ClientAPI{} + cfg := &config.ClientAPI{ + MSCs: &config.MSCs{ + MSCs: []string{}, + }, + } res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg) if res.Code != http.StatusUnauthorized { @@ -132,8 +161,11 @@ func Test_UploadCrossSigningDeviceKeys_InvalidJSON(t *testing.T) { keyserverAPI := &mockKeyAPI{} device := &api.Device{UserID: "@user:example.com", ID: "device"} - cfg := &config.ClientAPI{} - + cfg := &config.ClientAPI{ + MSCs: &config.MSCs{ + MSCs: []string{}, + }, + } res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg) if res.Code != http.StatusBadRequest { t.Fatalf("expected status %d, got %d", http.StatusBadRequest, res.Code) @@ -151,13 +183,22 @@ func Test_UploadCrossSigningDeviceKeys_ExistingKeysMismatch(t *testing.T) { req.Header.Set("Content-Type", "application/json") keyserverAPI := &mockKeyAPI{ - userResponses: map[string]api.QueryKeysResponse{ + queryKeysData: map[string]api.QueryKeysResponse{ "@user:example.com": { MasterKeys: map[string]fclient.CrossSigningKey{ - "@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")}}, + "@user:example.com": { + UserID: "@user:example.com", + Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster}, + Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")}, + }, }, }, }, + queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ + "@user:example.com": { + Key: spec.Base64Bytes("different_key"), + }, + }, } device := &api.Device{UserID: "@user:example.com", ID: "device"} diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go index b987e6f23..f88dc6690 100644 --- a/clientapi/routing/login_test.go +++ b/clientapi/routing/login_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/auth/authtypes" "github.com/element-hq/dendrite/internal/caching" "github.com/element-hq/dendrite/internal/httputil" @@ -50,9 +51,10 @@ func TestLogin(t *testing.T) { rsAPI.SetFederationAPI(nil, nil) // Needed for /login userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - Setup(routers, cfg, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, nil, caching.DisableMetrics) + Setup(routers, cfg, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, nil, &userVerifier, caching.DisableMetrics) // Create password password := util.RandomString(8) diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 59d9594d6..6258155db 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -67,6 +67,7 @@ func Password( }, }, nil, + "", ), } } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index b75d38a62..922d5e901 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -172,24 +172,20 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( - req *http.Request, profileAPI userapi.ProfileAPI, + req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI, ) util.JSONResponse { - if userID != device.UserID { + if userID != device.UserID && device.AccountType != userapi.AccountTypeOIDCService { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("userID does not match the current user"), } } - var r eventutil.UserProfile - if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { - return *resErr - } - + logger := util.GetLogger(req.Context()) localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + logger.WithError(err).Error("gomatrixserverlib.SplitID failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, @@ -203,6 +199,28 @@ func SetDisplayName( } } + if device.AccountType == userapi.AccountTypeOIDCService { + // When a request is made on behalf of an OIDC provider service, the original device object refers + // to the provider's pseudo-device and includes only the AccountTypeOIDCService flag. To continue, + // we need to replace the admin's device with the user's device + var rs userapi.QueryDevicesResponse + err = userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &rs) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if len(rs.Devices) > 0 { + device = &rs.Devices[0] + } + } + + var r eventutil.UserProfile + if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { + return *resErr + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -211,9 +229,9 @@ func SetDisplayName( } } - profile, changed, err := profileAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) + profile, changed, err := userAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") + logger.WithError(err).Error("profileAPI.SetDisplayName failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 5544dccd3..da43a6b01 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -66,11 +66,17 @@ type sessionsDict struct { // If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2, // the delete request will fail for device2 since the UIA was initiated by trying to delete device1. deleteSessionToDeviceID map[string]string + // crossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating + // cross-signing keys without UIA. + crossSigningKeysReplacement map[string]*time.Timer } // defaultTimeout is the timeout used to clean up sessions const defaultTimeOut = time.Minute * 5 +// crossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA +const crossSigningKeysReplacementDuration = time.Minute * 10 + // getCompletedStages returns the completed stages for a session. func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType { d.RLock() @@ -119,13 +125,54 @@ func (d *sessionsDict) deleteSession(sessionID string) { } } +func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 { + d.Lock() + defer d.Unlock() + ts := time.Now().Add(crossSigningKeysReplacementDuration).UnixMilli() + t, ok := d.crossSigningKeysReplacement[userID] + if ok { + t.Reset(crossSigningKeysReplacementDuration) + return ts + } + d.crossSigningKeysReplacement[userID] = time.AfterFunc( + crossSigningKeysReplacementDuration, + func() { + d.restrictCrossSigningKeysReplacement(userID) + }, + ) + return ts +} + +func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool { + d.RLock() + defer d.RUnlock() + _, ok := d.crossSigningKeysReplacement[userID] + return ok +} + +func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) { + d.Lock() + defer d.Unlock() + t, ok := d.crossSigningKeysReplacement[userID] + if ok { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + delete(d.crossSigningKeysReplacement, userID) + } +} + func newSessionsDict() *sessionsDict { return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), - sessionCompletedResult: make(map[string]registerResponse), - params: make(map[string]registerRequest), - timer: make(map[string]*time.Timer), - deleteSessionToDeviceID: make(map[string]string), + sessions: make(map[string][]authtypes.LoginType), + sessionCompletedResult: make(map[string]registerResponse), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + deleteSessionToDeviceID: make(map[string]string), + crossSigningKeysReplacement: make(map[string]*time.Timer), } } @@ -234,6 +281,7 @@ type userInteractiveResponse struct { Completed []authtypes.LoginType `json:"completed"` Params map[string]interface{} `json:"params"` Session string `json:"session"` + Msg string `json:"msg,omitempty"` } // newUserInteractiveResponse will return a struct to be sent back to the client @@ -242,9 +290,10 @@ func newUserInteractiveResponse( sessionID string, fs []authtypes.Flow, params map[string]interface{}, + msg string, ) userInteractiveResponse { return userInteractiveResponse{ - fs, sessions.getCompletedStages(sessionID), params, sessionID, + fs, sessions.getCompletedStages(sessionID), params, sessionID, msg, } } @@ -817,7 +866,7 @@ func checkAndCompleteFlow( return util.JSONResponse{ Code: http.StatusUnauthorized, JSON: newUserInteractiveResponse(sessionID, - cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params), + cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params, ""), } } diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 71cc0ca67..8529d7c59 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -669,3 +669,57 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) { assert.Equal(t, expectedDisplayName, profile.DisplayName) }) } + +func TestCrossSigningKeysReplacement(t *testing.T) { + userID := "@user:example.com" + + t.Run("Can add new session", func(t *testing.T) { + s := newSessionsDict() + assert.Empty(t, s.crossSigningKeysReplacement) + s.allowCrossSigningKeysReplacement(userID) + assert.Len(t, s.crossSigningKeysReplacement, 1) + assert.Contains(t, s.crossSigningKeysReplacement, userID) + }) + + t.Run("Can check if session exists or not", func(t *testing.T) { + s := newSessionsDict() + t.Run("exists", func(t *testing.T) { + s.allowCrossSigningKeysReplacement(userID) + assert.Len(t, s.crossSigningKeysReplacement, 1) + assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID)) + }) + + t.Run("not exists", func(t *testing.T) { + assert.False(t, s.isCrossSigningKeysReplacementAllowed("@random:test.com")) + }) + }) + + t.Run("Can deactivate session", func(t *testing.T) { + s := newSessionsDict() + assert.Empty(t, s.crossSigningKeysReplacement) + t.Run("not exists", func(t *testing.T) { + s.restrictCrossSigningKeysReplacement("@random:test.com") + assert.Empty(t, s.crossSigningKeysReplacement) + }) + + t.Run("exists", func(t *testing.T) { + s.allowCrossSigningKeysReplacement(userID) + s.restrictCrossSigningKeysReplacement(userID) + assert.Empty(t, s.crossSigningKeysReplacement) + }) + }) + + t.Run("Can erase expired sessions", func(t *testing.T) { + s := newSessionsDict() + s.allowCrossSigningKeysReplacement(userID) + assert.Len(t, s.crossSigningKeysReplacement, 1) + assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID)) + timer := s.crossSigningKeysReplacement[userID] + + // pretending the timer is expired + timer.Reset(time.Millisecond) + time.Sleep(time.Millisecond * 500) + + assert.Empty(t, s.crossSigningKeysReplacement) + }) +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index f0aa087db..15a5addfb 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -67,7 +67,9 @@ func Setup( transactionsCache *transactions.Cache, federationSender federationAPI.ClientFederationAPI, extRoomsProvider api.ExtraPublicRoomsProvider, - natsClient *nats.Conn, enableMetrics bool, + natsClient *nats.Conn, + userVerifier httputil.UserVerifier, + enableMetrics bool, ) { cfg := &dendriteCfg.ClientAPI mscCfg := &dendriteCfg.MSCs @@ -171,19 +173,19 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } dendriteAdminRouter.Handle("/admin/registrationTokens/new", - httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_registration_tokens_new", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminCreateNewRegistrationToken(req, cfg, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/registrationTokens", - httputil.MakeAdminAPI("admin_list_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_list_registration_tokens", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminListRegistrationTokens(req, cfg, userAPI) }), ).Methods(http.MethodGet, http.MethodOptions) dendriteAdminRouter.Handle("/admin/registrationTokens/{token}", - httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_get_registration_token", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { switch req.Method { case http.MethodGet: return AdminGetRegistrationToken(req, cfg, userAPI) @@ -202,43 +204,43 @@ func Setup( ).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", - httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_evacuate_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminEvacuateRoom(req, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}", - httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_evacuate_user", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminEvacuateUser(req, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", - httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_purge_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminPurgeRoom(req, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", - httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_reset_password", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminResetPassword(req, cfg, device, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/downloadState/{serverName}/{roomID}", - httputil.MakeAdminAPI("admin_download_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_download_state", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminDownloadState(req, device, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) dendriteAdminRouter.Handle("/admin/fulltext/reindex", - httputil.MakeAdminAPI("admin_fultext_reindex", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_fultext_reindex", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminReindex(req, cfg, device, natsClient) }), ).Methods(http.MethodGet, http.MethodOptions) dendriteAdminRouter.Handle("/admin/refreshDevices/{userID}", - httputil.MakeAdminAPI("admin_refresh_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_refresh_devices", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminMarkAsStale(req, cfg, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -252,7 +254,7 @@ func Setup( } synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}", - httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_server_notice", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r @@ -273,7 +275,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) synapseAdminRouter.Handle("/admin/v1/send_server_notice", - httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_server_notice", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r @@ -301,12 +303,12 @@ func Setup( unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter() v3mux.Handle("/createRoom", - httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("createRoom", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/join/{roomIDOrAlias}", - httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -328,9 +330,102 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) + if mscCfg.MSC3861Enabled() { + m := mscCfg.MSC3861 + + unstableMux.Handle("/org.matrix.msc2965/auth_issuer", + httputil.MakeExternalAPI("auth_issuer", func(r *http.Request) util.JSONResponse { + return util.JSONResponse{Code: http.StatusOK, JSON: map[string]string{ + "issuer": m.Issuer, + }} + })).Methods(http.MethodGet) + synapseAdminRouter.Handle("/admin/v1/username_available", + httputil.MakeServiceAdminAPI("admin_username_available", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminCheckUsernameAvailable(r, userAPI, cfg) + })).Methods(http.MethodGet) + synapseAdminRouter.Handle("/admin/v1/deactivate/{userID}", + httputil.MakeServiceAdminAPI("admin_deactivate_user", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminDeactivateAccount(r, userAPI, cfg) + })).Methods(http.MethodPost) + synapseAdminRouter.Handle("/admin/v2/users/{userID}", + httputil.MakeServiceAdminAPI("admin_manage_user", m.AdminToken, func(r *http.Request) util.JSONResponse { + switch r.Method { + case http.MethodGet: + return AdminRetrieveAccount(r, cfg, userAPI) + case http.MethodPut: + return AdminCreateOrModifyAccount(r, userAPI, cfg) + default: + return util.JSONResponse{Code: http.StatusMethodNotAllowed, JSON: nil} + } + })).Methods(http.MethodPut, http.MethodGet) + synapseAdminRouter.Handle("/admin/v2/users/{userID}/devices", + httputil.MakeServiceAdminAPI("admin_create_retrieve_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminUserDeviceRetrieveCreate(r, userAPI, cfg) + })).Methods(http.MethodPost, http.MethodGet) + synapseAdminRouter.Handle("/admin/v2/users/{userID}/devices/{deviceID}", + httputil.MakeServiceAdminAPI("admin_delete_user_device", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminUserDeviceDelete(r, userAPI, cfg) + })).Methods(http.MethodDelete) + synapseAdminRouter.Handle("/admin/v2/users/{userID}/delete_devices", + httputil.MakeServiceAdminAPI("admin_delete_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminUserDevicesDelete(r, userAPI, cfg) + })).Methods(http.MethodPost) + synapseAdminRouter.Handle("/admin/v1/users/{userID}/_allow_cross_signing_replacement_without_uia", + httputil.MakeServiceAdminAPI("admin_allow_cross_signing_replacement_without_uia", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminAllowCrossSigningReplacementWithoutUIA(r, userAPI) + })).Methods(http.MethodPost) + } else { + // If msc3861 is enabled, these endpoints are either redundant or replaced by Matrix Auth Service (MAS) + // Once we migrate to MAS completely, these endpoints should be removed + + v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + if r := rateLimits.Limit(req, nil); r != nil { + return *r + } + return Register(req, userAPI, cfg) + })).Methods(http.MethodPost, http.MethodOptions) + + v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + if r := rateLimits.Limit(req, nil); r != nil { + return *r + } + return RegisterAvailable(req, cfg, userAPI) + })).Methods(http.MethodGet, http.MethodOptions) + + // Stub endpoints required by Element + + v3mux.Handle("/login", + httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + if r := rateLimits.Limit(req, nil); r != nil { + return *r + } + return Login(req, userAPI, cfg) + }), + ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) + + v3mux.Handle("/auth/{authType}/fallback/web", + httputil.MakeHTTPAPI("auth_fallback", userVerifier, enableMetrics, func(w http.ResponseWriter, req *http.Request) { + vars := mux.Vars(req) + AuthFallback(w, req, vars["authType"], cfg) + }), + ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) + + v3mux.Handle("/logout", + httputil.MakeAuthAPI("logout", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return Logout(req, userAPI, device) + }), + ).Methods(http.MethodPost, http.MethodOptions) + + v3mux.Handle("/logout/all", + httputil.MakeAuthAPI("logout", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return LogoutAll(req, userAPI, device) + }), + ).Methods(http.MethodPost, http.MethodOptions) + } + if mscCfg.Enabled("msc2753") { v3mux.Handle("/peek/{roomIDOrAlias}", - httputil.MakeAuthAPI(spec.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Peek, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -345,12 +440,12 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) } v3mux.Handle("/joined_rooms", - httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("joined_rooms", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetJoinedRooms(req, device, rsAPI) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/join", - httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -372,7 +467,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/leave", - httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("membership", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -386,7 +481,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/unpeek", - httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("unpeek", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -397,7 +492,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/ban", - httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("membership", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -406,7 +501,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/invite", - httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("membership", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -418,7 +513,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/kick", - httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("membership", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -427,7 +522,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/unban", - httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("membership", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -436,7 +531,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/send/{eventType}", - httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_message", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -445,7 +540,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", - httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_message", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -456,7 +551,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -464,7 +559,7 @@ func Setup( return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -472,7 +567,7 @@ func Setup( return GetAliases(req, rsAPI, device, vars["roomID"]) })).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -483,7 +578,7 @@ func Setup( return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -493,7 +588,7 @@ func Setup( }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", - httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_message", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -505,7 +600,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", - httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_message", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -519,7 +614,7 @@ func Setup( // TODO: clear based on some criteria roomHierarchyPaginationCache := NewRoomHierarchyPaginationCache() v1mux.Handle("/rooms/{roomID}/hierarchy", - httputil.MakeAuthAPI("spaces", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("spaces", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -528,20 +623,6 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { - if r := rateLimits.Limit(req, nil); r != nil { - return *r - } - return Register(req, userAPI, cfg) - })).Methods(http.MethodPost, http.MethodOptions) - - v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { - if r := rateLimits.Limit(req, nil); r != nil { - return *r - } - return RegisterAvailable(req, cfg, userAPI) - })).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -553,7 +634,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/directory/room/{roomAlias}", - httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("directory_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -563,7 +644,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/directory/room/{roomAlias}", - httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("directory_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -582,7 +663,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/directory/list/room/{roomID}", - httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("directory_list", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -591,7 +672,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/directory/list/appservice/{networkID}/{roomID}", - httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("directory_list", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -602,7 +683,7 @@ func Setup( // Undocumented endpoint v3mux.Handle("/directory/list/appservice/{networkID}/{roomID}", - httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("directory_list", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -617,20 +698,8 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - v3mux.Handle("/logout", - httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return Logout(req, userAPI, device) - }), - ).Methods(http.MethodPost, http.MethodOptions) - - v3mux.Handle("/logout/all", - httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return LogoutAll(req, userAPI, device) - }), - ).Methods(http.MethodPost, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/typing/{userID}", - httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_typing", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -642,7 +711,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/redact/{eventID}", - httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_redact", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -651,7 +720,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", - httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_redact", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -662,7 +731,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/sendToDevice/{eventType}/{txnID}", - httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_to_device", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -676,7 +745,7 @@ func Setup( // rather than r0. It's an exact duplicate of the above handler. // TODO: Remove this if/when sytest is fixed! unstableMux.Handle("/sendToDevice/{eventType}/{txnID}", - httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_to_device", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -687,7 +756,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/account/whoami", - httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("whoami", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -696,7 +765,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/password", - httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("password", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -705,7 +774,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/account/deactivate", - httputil.MakeAuthAPI("deactivate", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("deactivate", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -713,28 +782,10 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - // Stub endpoints required by Element - - v3mux.Handle("/login", - httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { - if r := rateLimits.Limit(req, nil); r != nil { - return *r - } - return Login(req, userAPI, cfg) - }), - ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - - v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTTPAPI("auth_fallback", userAPI, enableMetrics, func(w http.ResponseWriter, req *http.Request) { - vars := mux.Vars(req) - AuthFallback(w, req, vars["authType"], cfg) - }), - ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - // Push rules v3mux.Handle("/pushrules", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("missing trailing slash"), @@ -743,13 +794,13 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetAllPushRules(req.Context(), device, userAPI) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("scope, kind and rule ID must be specified"), @@ -758,7 +809,7 @@ func Setup( ).Methods(http.MethodPut) v3mux.Handle("/pushrules/{scope}/", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -768,7 +819,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("missing trailing slash after scope"), @@ -777,7 +828,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope:[^/]+/?}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("kind and rule ID must be specified"), @@ -786,7 +837,7 @@ func Setup( ).Methods(http.MethodPut) v3mux.Handle("/pushrules/{scope}/{kind}/", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -796,7 +847,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope}/{kind}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("missing trailing slash after kind"), @@ -805,7 +856,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("rule ID must be specified"), @@ -814,7 +865,7 @@ func Setup( ).Methods(http.MethodPut) v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -824,7 +875,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -838,7 +889,7 @@ func Setup( ).Methods(http.MethodPut) v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -848,7 +899,7 @@ func Setup( ).Methods(http.MethodDelete) v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -858,7 +909,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -890,7 +941,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/profile/{userID}/avatar_url", - httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("profile_avatar_url", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -915,7 +966,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/profile/{userID}/displayname", - httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("profile_displayname", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -932,19 +983,19 @@ func Setup( threePIDClient := base.CreateClient(dendriteCfg, nil) // TODO: Move this somewhere else, e.g. pass in as parameter v3mux.Handle("/account/3pid", - httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetAssociated3PIDs(req, userAPI, device) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/3pid", - httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CheckAndSave3PIDAssociation(req, userAPI, device, cfg, threePIDClient) }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/account/3pid/delete", - httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Forget3PID(req, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -956,7 +1007,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/voip/turnServer", - httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("turn_server", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -965,13 +1016,13 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/protocols", - httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_protocols", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Protocols(req, asAPI, device, "") }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/protocol/{protocolID}", - httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_protocols", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -981,7 +1032,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/user/{protocolID}", - httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_user", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -991,13 +1042,13 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/user", - httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_user", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return User(req, asAPI, device, "", req.URL.Query()) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/location/{protocolID}", - httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_location", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1007,7 +1058,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/location", - httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_location", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Location(req, asAPI, device, "", req.URL.Query()) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) @@ -1023,7 +1074,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userID}/account_data/{type}", - httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("user_account_data", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1033,7 +1084,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", - httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("user_account_data", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1043,7 +1094,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/user/{userID}/account_data/{type}", - httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("user_account_data", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1053,7 +1104,7 @@ func Setup( ).Methods(http.MethodGet) v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", - httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("user_account_data", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1063,7 +1114,7 @@ func Setup( ).Methods(http.MethodGet) v3mux.Handle("/admin/whois/{userID}", - httputil.MakeAuthAPI("admin_whois", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("admin_whois", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1073,7 +1124,7 @@ func Setup( ).Methods(http.MethodGet) v3mux.Handle("/user/{userID}/openid/request_token", - httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("openid_request_token", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1086,7 +1137,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/user_directory/search", - httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("userdirectory_search", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1112,7 +1163,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/read_markers", - httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_read_markers", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1125,7 +1176,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/forget", - httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_forget", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1138,7 +1189,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/upgrade", - httputil.MakeAuthAPI("rooms_upgrade", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_upgrade", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1148,13 +1199,13 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/devices", - httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_devices", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetDevicesByLocalpart(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", - httputil.MakeAuthAPI("get_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_device", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1164,7 +1215,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", - httputil.MakeAuthAPI("device_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("device_data", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1174,7 +1225,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", - httputil.MakeAuthAPI("delete_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("delete_device", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1184,25 +1235,25 @@ func Setup( ).Methods(http.MethodDelete, http.MethodOptions) v3mux.Handle("/delete_devices", - httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("delete_devices", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return DeleteDevices(req, userInteractiveAuth, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/notifications", - httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_notifications", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetNotifications(req, device, userAPI) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushers", - httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_pushers", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetPushers(req, device, userAPI) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushers/set", - httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("set_pushers", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1212,7 +1263,7 @@ func Setup( // Stub implementations for sytest v3mux.Handle("/events", - httputil.MakeAuthAPI("events", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("events", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "chunk": []interface{}{}, "start": "", @@ -1222,7 +1273,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/initialSync", - httputil.MakeAuthAPI("initial_sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("initial_sync", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "end": "", }} @@ -1230,7 +1281,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/rooms/{roomId}/tags", - httputil.MakeAuthAPI("get_tags", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_tags", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1240,7 +1291,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", - httputil.MakeAuthAPI("put_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("put_tag", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1250,7 +1301,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", - httputil.MakeAuthAPI("delete_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("delete_tag", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1260,7 +1311,7 @@ func Setup( ).Methods(http.MethodDelete, http.MethodOptions) v3mux.Handle("/capabilities", - httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("capabilities", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1270,7 +1321,7 @@ func Setup( // Key Backup Versions (Metadata) - getBackupKeysVersion := httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + getBackupKeysVersion := httputil.MakeAuthAPI("get_backup_keys_version", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1278,11 +1329,11 @@ func Setup( return KeyBackupVersion(req, userAPI, device, vars["version"]) }) - getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return KeyBackupVersion(req, userAPI, device, "") }) - putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1290,7 +1341,7 @@ func Setup( return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"]) }) - deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1298,7 +1349,7 @@ func Setup( return DeleteKeyBackupVersion(req, userAPI, device, vars["version"]) }) - postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CreateKeyBackupVersion(req, userAPI, device) }) @@ -1317,7 +1368,7 @@ func Setup( // Inserting E2E Backup Keys // Bulk room and session - putBackupKeys := httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + putBackupKeys := httputil.MakeAuthAPI("put_backup_keys", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { version := req.URL.Query().Get("version") if version == "" { return util.JSONResponse{ @@ -1334,7 +1385,7 @@ func Setup( }) // Single room bulk session - putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1366,7 +1417,7 @@ func Setup( }) // Single room, single session - putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1408,11 +1459,11 @@ func Setup( // Querying E2E Backup Keys - getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "") }) - getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1420,7 +1471,7 @@ func Setup( return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], "") }) - getBackupKeysRoomSession := httputil.MakeAuthAPI("get_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + getBackupKeysRoomSession := httputil.MakeAuthAPI("get_backup_keys_room_session", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1440,11 +1491,11 @@ func Setup( // Cross-signing device keys - postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadCrossSigningDeviceKeys(req, userAPI, device, userAPI.QueryAccountByPassword, cfg) }) - postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadCrossSigningDeviceSignatures(req, userAPI, device) }, httputil.WithAllowGuests()) @@ -1456,27 +1507,27 @@ func Setup( // Supplying a device ID is deprecated. v3mux.Handle("/keys/upload/{deviceID}", - httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("keys_upload", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/upload", - httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("keys_upload", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", - httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("keys_query", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/claim", - httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("keys_claim", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return ClaimKeys(req, userAPI) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", - httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1489,7 +1540,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/presence/{userId}/status", - httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("set_presence", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1498,7 +1549,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/presence/{userId}/status", - httputil.MakeAuthAPI("get_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_presence", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1508,7 +1559,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/joined_members", - httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_members", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1518,7 +1569,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/report/{eventID}", - httputil.MakeAuthAPI("report_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("report_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1528,7 +1579,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) synapseAdminRouter.Handle("/admin/v1/event_reports", - httputil.MakeAdminAPI("admin_report_events", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_report_events", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { from := parseUint64OrDefault(req.URL.Query().Get("from"), 0) limit := parseUint64OrDefault(req.URL.Query().Get("limit"), 100) dir := req.URL.Query().Get("dir") @@ -1542,7 +1593,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) synapseAdminRouter.Handle("/admin/v1/event_reports/{reportID}", - httputil.MakeAdminAPI("admin_report_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_report_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1552,7 +1603,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) synapseAdminRouter.Handle("/admin/v1/event_reports/{reportID}", - httputil.MakeAdminAPI("admin_report_event_delete", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_report_event_delete", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index e1ac179a7..519d5e47d 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -55,7 +55,7 @@ var latest, _ = semver.NewVersion("v6.6.6") // Dummy version, used as "HEAD" // due to the error: // When using COPY with more than one source file, the destination must be a directory and end with a / // We need to run a postgres anyway, so use the dockerfile associated with Complement instead. -const DockerfilePostgreSQL = `FROM golang:1.22-bookworm as build +const DockerfilePostgreSQL = `FROM golang:1.23-bookworm as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build ARG BINARY @@ -99,7 +99,7 @@ ENV BINARY=dendrite EXPOSE 8008 8448 CMD /build/run_dendrite.sh` -const DockerfileSQLite = `FROM golang:1.22-bookworm as build +const DockerfileSQLite = `FROM golang:1.23-bookworm as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build ARG BINARY diff --git a/dendrite-sample.yaml b/dendrite-sample.yaml index 2afdc33f1..bfa17051c 100644 --- a/dendrite-sample.yaml +++ b/dendrite-sample.yaml @@ -303,8 +303,27 @@ media_api: # Configuration for enabling experimental MSCs on this homeserver. mscs: mscs: + # - msc3861 # (Next-gen auth, see https://github.com/matrix-org/matrix-doc/pull/3861) # - msc2836 # (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) + # This block has no effect if the feature is not activated in the list above + # msc3861: + # # OIDC issuer advertised by the service. + # # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#http + # issuer: "https://mas.example.com/" + + # # Credentials used for authenticating requests coming from dendrite to auth service. + # # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#clients + # client_id: 01JFNM9MCHKV6A7A0C0RBHMYC0 + # client_secret: c85731184ac8f9aea76cf48146046b454473ca667a0cd1fd52a43034a0662eed + + # # The service token used for authenticating requests coming from auth service to dendrite. + # # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#matrix + # admin_token: ttJORW9oV4Wf4DJ63GdZEYekE2KElP4g + + # # URL of the account page on the auth service side + # account_management_url: "https://mas.example.com/account" + # Configuration for the Sync API. sync_api: # This option controls which HTTP header to inspect to find the real remote IP diff --git a/federationapi/storage/storage_wasm.go b/federationapi/storage/storage_wasm.go index 9f630f37d..10ed7d2a1 100644 --- a/federationapi/storage/storage_wasm.go +++ b/federationapi/storage/storage_wasm.go @@ -14,7 +14,7 @@ import ( "github.com/element-hq/dendrite/internal/caching" "github.com/element-hq/dendrite/internal/sqlutil" "github.com/element-hq/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // NewDatabase opens a new database diff --git a/go.mod b/go.mod index b1fe33b77..37bde3b26 100644 --- a/go.mod +++ b/go.mod @@ -156,6 +156,6 @@ require ( nhooyr.io/websocket v1.8.7 // indirect ) -go 1.22 +go 1.23 toolchain go1.23.2 diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index d32557679..65a2db2e0 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -58,17 +58,26 @@ func WithAuth() AuthAPIOption { } } +// UserVerifier verifies users by their access tokens. Currently, there are two interface implementations: +// DefaultUserVerifier and MSC3861UserVerifier. The first one checks if the token exists in the server's database, +// whereas the latter passes the token for verification to MAS and acts in accordance with MAS's response. +type UserVerifier interface { + // VerifyUserFromRequest authenticates the HTTP request, + // on success returns Device of the requester. + VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) +} + // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. func MakeAuthAPI( - metricsName string, userAPI userapi.QueryAcccessTokenAPI, + metricsName string, userVerifier UserVerifier, f func(*http.Request, *userapi.Device) util.JSONResponse, checks ...AuthAPIOption, ) http.Handler { h := func(req *http.Request) util.JSONResponse { logger := util.GetLogger(req.Context()) - device, err := auth.VerifyUserFromRequest(req, userAPI) + device, err := userVerifier.VerifyUserFromRequest(req) if err != nil { - logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code) + logger.Debugf("VerifyUserFromRequest %s -> HTTP %d: JSON %+v", req.RemoteAddr, err.Code, err.JSON) return *err } // add the user ID to the logger @@ -122,11 +131,11 @@ func MakeAuthAPI( // MakeAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be // completed by a user that is a server administrator. func MakeAdminAPI( - metricsName string, userAPI userapi.QueryAcccessTokenAPI, + metricsName string, userVerifier UserVerifier, f func(*http.Request, *userapi.Device) util.JSONResponse, ) http.Handler { - return MakeAuthAPI(metricsName, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - if device.AccountType != userapi.AccountTypeAdmin { + return MakeAuthAPI(metricsName, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if device == nil || device.AccountType != userapi.AccountTypeAdmin { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("This API can only be used by admin users."), @@ -136,6 +145,38 @@ func MakeAdminAPI( }) } +// MakeServiceAdminAPI is a wrapper around MakeExternalAPI which enforces that the request can only be +// completed by a trusted service e.g. Matrix Auth Service (MAS). +func MakeServiceAdminAPI( + metricsName, serviceToken string, + f func(*http.Request) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + logger := util.GetLogger(req.Context()) + token, err := auth.ExtractAccessToken(req) + + if err != nil { + logger.Debugf("ExtractAccessToken %s -> HTTP %d", req.RemoteAddr, http.StatusUnauthorized) + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.MissingToken(err.Error()), + } + } + if token != serviceToken { + logger.Debugf("Invalid service token '%s'", token) + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.UnknownToken(token), + } + } + // add the service addr to the logger + logger = logger.WithField("service_useragent", req.UserAgent()) + req = req.WithContext(util.ContextWithLogger(req.Context(), logger)) + return f(req) + } + return MakeExternalAPI(metricsName, h) +} + // MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler. // This is used for APIs that are called from the internet. func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler { @@ -200,7 +241,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // MakeHTTPAPI adds Span metrics to the HTML Handler function // This is used to serve HTML alongside JSON error messages -func MakeHTTPAPI(metricsName string, userAPI userapi.QueryAcccessTokenAPI, enableMetrics bool, f func(http.ResponseWriter, *http.Request), checks ...AuthAPIOption) http.Handler { +func MakeHTTPAPI(metricsName string, userVerifier UserVerifier, enableMetrics bool, f func(http.ResponseWriter, *http.Request), checks ...AuthAPIOption) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { if req.Method == http.MethodOptions { util.SetCORSHeaders(w) @@ -220,7 +261,7 @@ func MakeHTTPAPI(metricsName string, userAPI userapi.QueryAcccessTokenAPI, enabl if opts.WithAuth { logger := util.GetLogger(req.Context()) - _, jsonErr := auth.VerifyUserFromRequest(req, userAPI) + _, jsonErr := userVerifier.VerifyUserFromRequest(req) if jsonErr != nil { w.WriteHeader(jsonErr.Code) if err := json.NewEncoder(w).Encode(jsonErr.JSON); err != nil { diff --git a/internal/httputil/httpapi_test.go b/internal/httputil/httpapi_test.go index 23797a5ea..c9dd933cf 100644 --- a/internal/httputil/httpapi_test.go +++ b/internal/httputil/httpapi_test.go @@ -10,6 +10,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/matrix-org/util" ) func TestWrapHandlerInBasicAuth(t *testing.T) { @@ -99,3 +101,68 @@ func TestWrapHandlerInBasicAuth(t *testing.T) { }) } } + +func TestMakeServiceAdminAPI(t *testing.T) { + serviceToken := "valid_secret_token" + type args struct { + f func(*http.Request) util.JSONResponse + serviceToken string + } + + f := func(*http.Request) util.JSONResponse { + return util.JSONResponse{Code: http.StatusOK} + } + + tests := []struct { + name string + args args + want int + reqAuth bool + }{ + { + name: "service token valid", + args: args{ + f: f, + serviceToken: serviceToken, + }, + want: http.StatusOK, + reqAuth: true, + }, + { + name: "service token invalid", + args: args{ + f: f, + serviceToken: "invalid_service_token", + }, + want: http.StatusForbidden, + reqAuth: true, + }, + { + name: "service token is missing", + args: args{ + f: f, + serviceToken: "", + }, + want: http.StatusUnauthorized, + reqAuth: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := MakeServiceAdminAPI("metrics", serviceToken, tt.args.f) + + req := httptest.NewRequest("GET", "http://localhost/admin/v1/username_available", nil) + if tt.reqAuth { + req.Header.Add("Authorization", "Bearer "+tt.args.serviceToken) + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + resp := w.Result() + + if resp.StatusCode != tt.want { + t.Errorf("Expected status code %d, got %d", resp.StatusCode, tt.want) + } + }) + } +} diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index 307009323..d8955bdd6 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -14,7 +14,6 @@ import ( "github.com/element-hq/dendrite/mediaapi/routing" "github.com/element-hq/dendrite/mediaapi/storage" "github.com/element-hq/dendrite/setup/config" - userapi "github.com/element-hq/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" ) @@ -24,10 +23,10 @@ func AddPublicRoutes( routers httputil.Routers, cm *sqlutil.Connections, cfg *config.Dendrite, - userAPI userapi.MediaUserAPI, client *fclient.Client, fedClient fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, + userVerifier httputil.UserVerifier, ) { mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database) if err != nil { @@ -35,6 +34,6 @@ func AddPublicRoutes( } routing.Setup( - routers, cfg, mediaDB, userAPI, client, fedClient, keyRing, + routers, cfg, mediaDB, client, fedClient, keyRing, userVerifier, ) } diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 45da8eba6..3d198f0d0 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -42,10 +42,10 @@ func Setup( routers httputil.Routers, cfg *config.Dendrite, db storage.Database, - userAPI userapi.MediaUserAPI, client *fclient.Client, federationClient fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, + userVerifier httputil.UserVerifier, ) { rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting) @@ -58,7 +58,7 @@ func Setup( } uploadHandler := httputil.MakeAuthAPI( - "upload", userAPI, + "upload", userVerifier, func(req *http.Request, dev *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, dev); r != nil { return *r @@ -67,7 +67,7 @@ func Setup( }, ) - configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + configHandler := httputil.MakeAuthAPI("config", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -97,13 +97,13 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) // v1 client endpoints requiring auth - downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()) + downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userVerifier, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()) v1mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/thumbnail/{serverName}/{mediaId}", - httputil.MakeHTTPAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()), + httputil.MakeHTTPAPI("thumbnail", userVerifier, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()), ).Methods(http.MethodGet, http.MethodOptions) // same, but for federation diff --git a/relayapi/storage/storage_wasm.go b/relayapi/storage/storage_wasm.go index 86ba972a9..69f4fa174 100644 --- a/relayapi/storage/storage_wasm.go +++ b/relayapi/storage/storage_wasm.go @@ -13,7 +13,7 @@ import ( "github.com/element-hq/dendrite/internal/sqlutil" "github.com/element-hq/dendrite/relayapi/storage/sqlite3" "github.com/element-hq/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // NewDatabase opens a new database diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 48911d2bb..01d4b47dd 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/federationapi/statistics" "github.com/element-hq/dendrite/internal/caching" "github.com/element-hq/dendrite/internal/eventutil" @@ -267,7 +268,8 @@ func TestPurgeRoom(t *testing.T) { rsAPI.SetFederationAPI(fsAPI, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, fsAPI.IsBlacklistedOrBackingOff) - syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics) // Create the room if err = api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 85dfe0beb..3e683059c 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -74,34 +74,45 @@ func (c *ClientAPI) Defaults(opts DefaultOpts) { func (c *ClientAPI) Verify(configErrs *ConfigErrors) { c.TURN.Verify(configErrs) c.RateLimiting.Verify(configErrs) - if c.RecaptchaEnabled { - if c.RecaptchaSiteVerifyAPI == "" { - c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify" - } - if c.RecaptchaApiJsUrl == "" { - c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js" + + if c.MSCs.MSC3861Enabled() { + if !c.RegistrationDisabled || c.RecaptchaEnabled { + configErrs.Add( + "You have enabled the experimental feature MSC3861 which implements the delegated authentication via OIDC. " + + "As a result, the feature conflicts with the standard Dendrite's registration and login flows and cannot be used if any of those is enabled. " + + "You need to disable registration (client_api.registration_disabled) and recapthca (client_api.enable_registration_captcha) options to proceed.", + ) } - if c.RecaptchaFormField == "" { - c.RecaptchaFormField = "g-recaptcha-response" + } else { + if c.RecaptchaEnabled { + if c.RecaptchaSiteVerifyAPI == "" { + c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify" + } + if c.RecaptchaApiJsUrl == "" { + c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js" + } + if c.RecaptchaFormField == "" { + c.RecaptchaFormField = "g-recaptcha-response" + } + if c.RecaptchaSitekeyClass == "" { + c.RecaptchaSitekeyClass = "g-recaptcha" + } + checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) + checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) + checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) + checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass) } - if c.RecaptchaSitekeyClass == "" { - c.RecaptchaSitekeyClass = "g-recaptcha" + // Ensure there is any spam counter measure when enabling registration + if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled && !c.RecaptchaEnabled { + configErrs.Add( + "You have tried to enable open registration without any secondary verification methods " + + "(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " + + "increasing the risk that your server will be used to send spam or abuse, and may result in " + + "your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " + + "start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " + + "should set the registration_disabled option in your Dendrite config.", + ) } - checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) - checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) - checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) - checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass) - } - // Ensure there is any spam counter measure when enabling registration - if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled && !c.RecaptchaEnabled { - configErrs.Add( - "You have tried to enable open registration without any secondary verification methods " + - "(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " + - "increasing the risk that your server will be used to send spam or abuse, and may result in " + - "your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " + - "start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " + - "should set the registration_disabled option in your Dendrite config.", - ) } } diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index ce491cd72..fb0c547fe 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -4,11 +4,16 @@ type MSCs struct { Matrix *Global `yaml:"-"` // The MSCs to enable. Supported MSCs include: + // 'msc3861': Delegate auth to an OIDC provider - https://github.com/matrix-org/matrix-spec-proposals/pull/3861 // 'msc2444': Peeking over federation - https://github.com/matrix-org/matrix-doc/pull/2444 // 'msc2753': Peeking via /sync - https://github.com/matrix-org/matrix-doc/pull/2753 // 'msc2836': Threading - https://github.com/matrix-org/matrix-doc/pull/2836 MSCs []string `yaml:"mscs"` + // MSC3861 contains config related to the experimental feature MSC3861. + // It takes effect only if 'msc3861' is included in 'MSCs' array. + MSC3861 *MSC3861 `yaml:"msc3861,omitempty"` + Database DatabaseOptions `yaml:"database,omitempty"` } @@ -34,4 +39,27 @@ func (c *MSCs) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString)) } + if m := c.MSC3861; m != nil && c.MSC3861Enabled() { + m.Verify(configErrs) + } +} + +func (c *MSCs) MSC3861Enabled() bool { + return c.Enabled("msc3861") && c.MSC3861 != nil +} + +type MSC3861 struct { + Issuer string `yaml:"issuer"` + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + AdminToken string `yaml:"admin_token"` + AccountManagementURL string `yaml:"account_management_url"` +} + +func (m *MSC3861) Verify(configErrs *ConfigErrors) { + checkNotEmpty(configErrs, "mscs.msc3861.issuer", string(m.Issuer)) + checkNotEmpty(configErrs, "mscs.msc3861.client_id", string(m.ClientID)) + checkNotEmpty(configErrs, "mscs.msc3861.client_secret", string(m.ClientSecret)) + checkNotEmpty(configErrs, "mscs.msc3861.admin_token", string(m.AdminToken)) + checkNotEmpty(configErrs, "mscs.msc3861.account_management_url", string(m.AccountManagementURL)) } diff --git a/setup/monolith.go b/setup/monolith.go index 36d6794d6..8d8fadc90 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -7,9 +7,12 @@ package setup import ( + "net/http" + appserviceAPI "github.com/element-hq/dendrite/appservice/api" "github.com/element-hq/dendrite/clientapi" "github.com/element-hq/dendrite/clientapi/api" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/federationapi" federationAPI "github.com/element-hq/dendrite/federationapi/api" "github.com/element-hq/dendrite/internal/caching" @@ -27,6 +30,7 @@ import ( userapi "github.com/element-hq/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/util" ) // Monolith represents an instantiation of all dependencies required to build @@ -46,6 +50,8 @@ type Monolith struct { // Optional ExtPublicRoomsProvider api.ExtraPublicRoomsProvider ExtUserDirectoryProvider userapi.QuerySearchProfilesAPI + + UserVerifierProvider *UserVerifierProvider } // AddAllPublicRoutes attaches all public paths to the given router @@ -58,6 +64,10 @@ func (m *Monolith) AddAllPublicRoutes( caches *caching.Caches, enableMetrics bool, ) { + if m.UserVerifierProvider == nil { + m.UserVerifierProvider = NewUserVerifierProvider(&auth.DefaultUserVerifier{UserAPI: m.UserAPI}) + } + userDirectoryProvider := m.ExtUserDirectoryProvider if userDirectoryProvider == nil { userDirectoryProvider = m.UserAPI @@ -65,15 +75,29 @@ func (m *Monolith) AddAllPublicRoutes( clientapi.AddPublicRoutes( processCtx, routers, cfg, natsInstance, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), m.FederationAPI, m.UserAPI, userDirectoryProvider, - m.ExtPublicRoomsProvider, enableMetrics, + m.ExtPublicRoomsProvider, m.UserVerifierProvider, enableMetrics, ) federationapi.AddPublicRoutes( processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics, ) - mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.FedClient, m.KeyRing) - syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics) + mediaapi.AddPublicRoutes(routers, cm, cfg, m.Client, m.FedClient, m.KeyRing, m.UserVerifierProvider) + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, m.UserVerifierProvider, enableMetrics) if m.RelayAPI != nil { relayapi.AddPublicRoutes(routers, cfg, m.KeyRing, m.RelayAPI) } } + +type UserVerifierProvider struct { + UserVerifier httputil.UserVerifier +} + +func (u *UserVerifierProvider) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) { + return u.UserVerifier.VerifyUserFromRequest(req) +} + +func NewUserVerifierProvider(userVerifier httputil.UserVerifier) *UserVerifierProvider { + return &UserVerifierProvider{ + UserVerifier: userVerifier, + } +} diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 4322e8a2b..847e836ab 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -98,7 +98,7 @@ func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsRespons // Enable this MSC func Enable( cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, - userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, + userVerifier httputil.UserVerifier, keyRing gomatrixserverlib.JSONVerifier, ) error { db, err := NewDatabase(cm, &cfg.MSCs.Database) if err != nil { @@ -124,7 +124,7 @@ func Enable( }) routers.Client.Handle("/unstable/event_relationships", - httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI, fsAPI)), + httputil.MakeAuthAPI("eventRelationships", userVerifier, eventRelationshipHandler(db, rsAPI, fsAPI)), ).Methods(http.MethodPost, http.MethodOptions) routers.Federation.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 5b85e6707..024f175c7 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/setup/process" "github.com/element-hq/dendrite/syncapi/synctypes" "github.com/gorilla/mux" @@ -571,7 +572,8 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve processCtx := process.NewProcessContext() cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) routers := httputil.NewRouters() - err := msc2836.Enable(cfg, cm, routers, rsAPI, nil, userAPI, nil) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + err := msc2836.Enable(cfg, cm, routers, rsAPI, nil, &userVerifier, nil) if err != nil { t.Fatalf("failed to enable MSC2836: %s", err) } diff --git a/setup/mscs/msc3861/msc3861.go b/setup/mscs/msc3861/msc3861.go new file mode 100644 index 000000000..9e3d00123 --- /dev/null +++ b/setup/mscs/msc3861/msc3861.go @@ -0,0 +1,25 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + +package msc3861 + +import ( + "github.com/element-hq/dendrite/setup" + "github.com/matrix-org/gomatrixserverlib/fclient" +) + +func Enable(m *setup.Monolith) error { + client := fclient.NewClient() + userVerifier, err := newMSC3861UserVerifier( + m.UserAPI, m.Config.Global.ServerName, + m.Config.MSCs.MSC3861, !m.Config.ClientAPI.GuestsDisabled, + client, + ) + if err != nil { + return err + } + m.UserVerifierProvider.UserVerifier = userVerifier + return nil +} diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go new file mode 100644 index 000000000..0de9b7fc9 --- /dev/null +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -0,0 +1,451 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + +package msc3861 + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "slices" + "strings" + + "github.com/element-hq/dendrite/clientapi/auth" + "github.com/element-hq/dendrite/setup/config" + "github.com/element-hq/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" +) + +const externalAuthProvider string = "oauth-delegated" + +// Scopes as defined by MSC2967 +// https://github.com/matrix-org/matrix-spec-proposals/pull/2967 +const ( + scopeMatrixAPI string = "urn:matrix:org.matrix.msc2967.client:api:*" + scopeMatrixGuest string = "urn:matrix:org.matrix.msc2967.client:api:guest" + scopeMatrixDevicePrefix string = "urn:matrix:org.matrix.msc2967.client:device:" +) + +type errCode string + +const ( + codeIntrospectionNot2xx errCode = "introspectionIsNot2xx" + codeInvalidClientToken errCode = "invalidClientToken" + codeAuthError errCode = "authError" + codeMxidError errCode = "mxidError" + codeOpenidConfigEndpointNon2xx errCode = "openidConfigEndpointNon2xx" + codeOpenidConfigDecodingFailed errCode = "openidConfigDecodingFailed" +) + +// MSC3861UserVerifier implements UserVerifier interface +type MSC3861UserVerifier struct { + userAPI api.UserInternalAPI + serverName spec.ServerName + cfg *config.MSC3861 + httpClient *fclient.Client + openIdConfig *OpenIDConfiguration + allowGuest bool +} + +func newMSC3861UserVerifier( + userAPI api.UserInternalAPI, + serverName spec.ServerName, + cfg *config.MSC3861, + allowGuest bool, + client *fclient.Client, +) (*MSC3861UserVerifier, error) { + if cfg == nil { + return nil, errors.New("unable to create MSC3861UserVerifier object as 'cfg' param is nil") + } + + if client == nil { + return nil, errors.New("unable to create MSC3861UserVerifier object as 'client' param is nil") + } + + openIdConfig, err := fetchOpenIDConfiguration(client, cfg.Issuer) + if err != nil { + return nil, err + } + + return &MSC3861UserVerifier{ + userAPI: userAPI, + serverName: serverName, + cfg: cfg, + openIdConfig: openIdConfig, + allowGuest: allowGuest, + httpClient: client, + }, nil +} + +type mscError struct { + Code errCode + Msg string +} + +func (r *mscError) Error() string { + return fmt.Sprintf("%s: %s", r.Code, r.Msg) +} + +// VerifyUserFromRequest authenticates the HTTP request, on success returns Device of the requester. +func (m *MSC3861UserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) { + util.GetLogger(req.Context()).Debug("MSC3861.VerifyUserFromRequest") + // Try to find the Application Service user + token, err := auth.ExtractAccessToken(req) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.MissingToken(err.Error()), + } + } + // TODO: try to get appservice user first. See https://github.com/element-hq/synapse/blob/develop/synapse/api/auth/msc3861_delegated.py#L273 + userData, err := m.getUserByAccessToken(req.Context(), token) + if err != nil { + switch e := err.(type) { + case (*mscError): + switch e.Code { + case codeIntrospectionNot2xx, codeOpenidConfigDecodingFailed, codeOpenidConfigEndpointNon2xx: + return nil, &util.JSONResponse{ + Code: http.StatusServiceUnavailable, + JSON: spec.Unknown(e.Error()), + } + case codeInvalidClientToken: + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.UnknownToken(e.Error()), + } + case codeAuthError, codeMxidError: + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(e.Error()), + } + default: + r := util.ErrorResponse(err) + return nil, &r + } + default: + r := util.ErrorResponse(err) + return nil, &r + } + } + + // Do not record requests from MAS using the virtual `__oidc_admin` user. + if token != m.cfg.AdminToken { + // XXX: not sure which exact data we should record here. See the link for reference + // https://github.com/element-hq/synapse/blob/develop/synapse/api/auth/base.py#L365 + } + + if !m.allowGuest && userData.IsGuest { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.Forbidden(strings.Join([]string{"Insufficient scope: ", scopeMatrixAPI}, "")), + } + } + + return userData.Device, nil +} + +type requester struct { + Device *api.Device + UserID *spec.UserID + Scope []string + IsGuest bool +} + +// nolint: gocyclo +func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token string) (*requester, error) { + var userID *spec.UserID + logger := util.GetLogger(ctx) + + if adminToken := m.cfg.AdminToken; adminToken != "" && token == adminToken { + // XXX: This is a temporary solution so that the admin API can be called by + // the OIDC provider. This will be removed once we have OIDC client + // credentials grant support in matrix-authentication-service. + // XXX: that user doesn't exist and won't be provisioned. + adminUser, err := createUserID("__oidc_admin", m.serverName) + if err != nil { + return nil, err + } + return &requester{ + UserID: adminUser, + Scope: []string{"urn:synapse:admin:*"}, + Device: &api.Device{UserID: adminUser.Local(), AccountType: api.AccountTypeOIDCService}, + }, nil + } + + introspectionResult, err := m.introspectToken(ctx, token) + if err != nil { + logger.WithError(err).Error("MSC3861UserVerifier:introspectToken") + return nil, err + } + + if !introspectionResult.Active { + return nil, &mscError{Code: codeInvalidClientToken, Msg: "Token is not active"} + } + + scopes := introspectionResult.Scopes() + hasUserScope, hasGuestScope := slices.Contains(scopes, scopeMatrixAPI), slices.Contains(scopes, scopeMatrixGuest) + if !hasUserScope && !hasGuestScope { + return nil, &mscError{Code: codeInvalidClientToken, Msg: "No scope in token granting user rights"} + } + + sub := introspectionResult.Sub + if sub == "" { + return nil, &mscError{Code: codeInvalidClientToken, Msg: "Invalid sub claim in the introspection result"} + } + + localpart := "" + { + var rs api.QueryLocalpartExternalIDResponse + if err = m.userAPI.QueryExternalUserIDByLocalpartAndProvider(ctx, &api.QueryLocalpartExternalIDRequest{ + ExternalID: sub, + AuthProvider: externalAuthProvider, + }, &rs); err != nil && err != sql.ErrNoRows { + return nil, err + } + if l := rs.LocalpartExternalID; l != nil { + localpart = l.Localpart + } + } + + if localpart == "" { + // If we could not find a user via the external_id, it either does not exist, + // or the external_id was never recorded + username := introspectionResult.Username + if username == "" { + return nil, &mscError{Code: codeAuthError, Msg: "Invalid username claim in the introspection result"} + } + userID, err = createUserID(username, m.serverName) + if err != nil { + logger.WithError(err).Error("getUserByAccessToken:createUserID") + return nil, err + } + + // First try to find a user from the username claim + var account *api.Account + { + var rs api.QueryAccountByLocalpartResponse + err = m.userAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{Localpart: userID.Local(), ServerName: userID.Domain()}, &rs) + if err != nil && err != sql.ErrNoRows { + logger.WithError(err).Error("QueryAccountByLocalpart") + return nil, err + } + account = rs.Account + } + + if account == nil { + // If the user does not exist, we should create it on the fly + var rs api.PerformAccountCreationResponse + if err = m.userAPI.PerformAccountCreation(ctx, &api.PerformAccountCreationRequest{ + AccountType: api.AccountTypeUser, + Localpart: userID.Local(), + ServerName: userID.Domain(), + }, &rs); err != nil { + logger.WithError(err).Error("PerformAccountCreation") + return nil, err + } + } + + if err = m.userAPI.PerformLocalpartExternalUserIDCreation(ctx, &api.PerformLocalpartExternalUserIDCreationRequest{ + Localpart: userID.Local(), + ExternalID: sub, + AuthProvider: externalAuthProvider, + }); err != nil { + logger.WithError(err).Error("PerformLocalpartExternalUserIDCreation") + return nil, err + } + + localpart = userID.Local() + } + + if userID == nil { + userID, err = createUserID(localpart, m.serverName) + if err != nil { + logger.WithError(err).Error("getUserByAccessToken:createUserID") + return nil, err + } + } + + deviceIDs := make([]string, 0, 1) + for i := range scopes { + if s := scopes[i]; strings.HasPrefix(s, scopeMatrixDevicePrefix) { + deviceIDs = append(deviceIDs, s[len(scopeMatrixDevicePrefix):]) + } + } + + if len(deviceIDs) != 1 { + logger.Errorf("Invalid device IDs in scope: %+v", deviceIDs) + return nil, &mscError{Code: codeAuthError, Msg: "Invalid device IDs in scope"} + } + + var device *api.Device + + deviceID := deviceIDs[0] + if len(deviceID) > 255 || len(deviceID) < 1 { + return nil, &mscError{ + Code: codeAuthError, + Msg: strings.Join([]string{"Invalid device ID in scope: ", deviceID}, ""), + } + } + + userDeviceExists := false + { + var rs api.QueryDevicesResponse + err := m.userAPI.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: userID.String()}, &rs) + if err != nil && err != sql.ErrNoRows { + return nil, err + } + + for i := range rs.Devices { + if d := &rs.Devices[i]; d.ID == deviceID { + userDeviceExists = true + device = d + break + } + } + } + if !userDeviceExists { + var rs api.PerformDeviceCreationResponse + deviceDisplayName := "OIDC-native client" + if err := m.userAPI.PerformDeviceCreation(ctx, &api.PerformDeviceCreationRequest{ + Localpart: localpart, + ServerName: m.serverName, + AccessToken: "", + DeviceID: &deviceID, + DeviceDisplayName: &deviceDisplayName, + // TODO: Cannot add IPAddr and Useragent values here. Should we care about it here? + }, &rs); err != nil { + logger.WithError(err).Error("PerformDeviceCreation") + return nil, err + } + device = rs.Device + logger.Debugf("PerformDeviceCreationResponse is: %+v", rs) + } + + return &requester{ + Device: device, + UserID: userID, + Scope: scopes, + IsGuest: hasGuestScope && !hasUserScope, + }, nil +} + +func createUserID(local string, serverName spec.ServerName) (*spec.UserID, error) { + userID, err := spec.NewUserID(strings.Join([]string{"@", local, ":", string(serverName)}, ""), false) + if err != nil { + return nil, &mscError{Code: codeMxidError, Msg: err.Error()} + } + return userID, nil +} + +func (m *MSC3861UserVerifier) introspectToken(ctx context.Context, token string) (*introspectionResponse, error) { + formBody := url.Values{"token": []string{token}} + encoded := formBody.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.openIdConfig.IntrospectionEndpoint, strings.NewReader(encoded)) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(m.cfg.ClientID, m.cfg.ClientSecret) + + resp, err := m.httpClient.DoHTTPRequest(ctx, req) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint: errcheck + + if c := resp.StatusCode; c/100 != 2 { + return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status, "' response"}, "")) + } + var ir introspectionResponse + if err := json.NewDecoder(resp.Body).Decode(&ir); err != nil { + return nil, err + } + return &ir, nil +} + +type OpenIDConfiguration struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + JWKsURI string `json:"jwks_uri"` + RegistrationEndpoint string `json:"registration_endpoint"` + ScopesSupported []string `json:"scopes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + ResponseModesSupported []string `json:"response_modes_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + TokenEndpointAuthSigningAlgCaluesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported"` + RevocationEnpoint string `json:"revocation_endpoint"` + RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported"` + RevocationEndpointAuthSigningAlgValues []string `json:"revocation_endpoint_auth_signing_alg_values_supported"` + IntrospectionEndpoint string `json:"introspection_endpoint"` + IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported"` + IntrospectionEndpointAuthSigningAlgValues []string `json:"introspection_endpoint_auth_signing_alg_values_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` + UserinfoSigningAlgValuesSupported []string `json:"userinfo_signing_alg_values_supported"` + DisplayValuesSupported []string `json:"display_values_supported"` + ClaimTypesSupported []string `json:"claim_types_supported"` + ClaimsSupported []string `json:"claims_supported"` + ClaimsParameterSupported bool `json:"claims_parameter_supported"` + RequestParameterSupported bool `json:"request_parameter_supported"` + RequestURIParameterSupported bool `json:"request_uri_parameter_supported"` + PromptValuesSupported []string `json:"prompt_values_supported"` + DeviceAuthorizaEndpoint string `json:"device_authorization_endpoint"` + AccountManagementURI string `json:"account_management_uri"` + AccountManagementActionsSupported []string `json:"account_management_actions_supported"` +} + +func fetchOpenIDConfiguration(httpClient *fclient.Client, authHostURL string) (*OpenIDConfiguration, error) { + u, err := url.Parse(authHostURL) + if err != nil { + return nil, err + } + u = u.JoinPath(".well-known/openid-configuration") + req, err := http.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + resp, err := httpClient.DoHTTPRequest(context.Background(), req) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint: errcheck + if resp.StatusCode != http.StatusOK { + return nil, &mscError{Code: codeOpenidConfigEndpointNon2xx, Msg: ".well-known/openid-configuration endpoint returned non-200 response"} + } + var oic OpenIDConfiguration + if err := json.NewDecoder(resp.Body).Decode(&oic); err != nil { + return nil, &mscError{Code: codeOpenidConfigDecodingFailed, Msg: err.Error()} + } + return &oic, nil +} + +// introspectionResponse as described in the RFC https://datatracker.ietf.org/doc/html/rfc7662#section-2.2 +type introspectionResponse struct { + Active bool `json:"active"` // required + Scope string `json:"scope"` // optional + Username string `json:"username"` // optional + TokenType string `json:"token_type"` // optional + Exp *int64 `json:"exp"` // optional + Iat *int64 `json:"iat"` // optional + Nfb *int64 `json:"nfb"` // optional + Sub string `json:"sub"` // optional + Jti string `json:"jti"` // optional + Aud string `json:"aud"` // optional + Iss string `json:"iss"` // optional +} + +func (i *introspectionResponse) Scopes() []string { + return strings.Split(i.Scope, " ") +} diff --git a/setup/mscs/msc3861/msc3861_user_verifier_test.go b/setup/mscs/msc3861/msc3861_user_verifier_test.go new file mode 100644 index 000000000..fd1d22a92 --- /dev/null +++ b/setup/mscs/msc3861/msc3861_user_verifier_test.go @@ -0,0 +1,234 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + +package msc3861 + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "errors" + + "github.com/element-hq/dendrite/federationapi/statistics" + "github.com/element-hq/dendrite/internal/caching" + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/element-hq/dendrite/roomserver" + "github.com/element-hq/dendrite/setup/config" + "github.com/element-hq/dendrite/setup/jetstream" + "github.com/element-hq/dendrite/test" + "github.com/element-hq/dendrite/test/testrig" + "github.com/element-hq/dendrite/userapi" + uapi "github.com/element-hq/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) { + return &statistics.ServerStatistics{}, nil +} + +type roundTripper struct { + roundTrip func(request *http.Request) (*http.Response, error) +} + +func (rt *roundTripper) RoundTrip(request *http.Request) (*http.Response, error) { + return rt.roundTrip(request) +} + +func TestVerifyUserFromRequest(t *testing.T) { + aliceUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bobUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + + roundTrip := func(request *http.Request) (*http.Response, error) { + var ( + respBody string + statusCode int + ) + + switch request.URL.String() { + case "https://mas.example.com/.well-known/openid-configuration": + respBody = `{"introspection_endpoint": "https://mas.example.com/oauth2/introspect"}` + statusCode = http.StatusOK + case "https://mas.example.com/oauth2/introspect": + _ = request.ParseForm() + + switch request.Form.Get("token") { + case "validTokenUserExistsTokenActive": + statusCode = http.StatusOK + resp := introspectionResponse{ + Active: true, + Scope: "urn:matrix:org.matrix.msc2967.client:device:devAlice urn:matrix:org.matrix.msc2967.client:api:*", + Sub: "111111111111111111", + Username: aliceUser.Localpart, + } + b, _ := json.Marshal(resp) + respBody = string(b) + case "validTokenUserDoesNotExistTokenActive": + statusCode = http.StatusOK + resp := introspectionResponse{ + Active: true, + Scope: "urn:matrix:org.matrix.msc2967.client:device:devBob urn:matrix:org.matrix.msc2967.client:api:*", + Sub: "222222222222222222", + Username: bobUser.Localpart, + } + b, _ := json.Marshal(resp) + respBody = string(b) + case "validTokenUserExistsTokenInactive": + statusCode = http.StatusOK + resp := introspectionResponse{Active: false} + b, _ := json.Marshal(resp) + respBody = string(b) + default: + return nil, errors.New("Request URL not supported by stub") + } + } + + respReader := io.NopCloser(strings.NewReader(respBody)) + resp := http.Response{ + StatusCode: statusCode, + Body: respReader, + ContentLength: int64(len(respBody)), + Header: map[string][]string{"Content-Type": {"application/json"}}, + } + return &resp, nil + } + + httpClient := fclient.NewClient( + fclient.WithTransport(&roundTripper{roundTrip: roundTrip}), + ) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{ + Issuer: "https://mas.example.com", + } + cfg.ClientAPI.RateLimiting.Enabled = false + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + // Needed for /login + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier, err := newMSC3861UserVerifier( + userAPI, + cfg.Global.ServerName, + cfg.MSCs.MSC3861, + false, + httpClient, + ) + if err != nil { + t.Fatal(err.Error()) + } + u, _ := url.Parse("https://example.com/something") + + t.Run("existing user and active token", func(t *testing.T) { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', aliceUser.ID) + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: aliceUser.AccountType, + Localpart: localpart, + ServerName: serverName, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + if !userRes.AccountCreated { + t.Fatalf("account not created") + } + httpReq := http.Request{ + URL: u, + Header: map[string][]string{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer validTokenUserExistsTokenActive"}, + }, + } + device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq) + if jsonResp != nil { + t.Fatalf("JSONResponse is not expected: %+v", jsonResp) + } + deviceRes := uapi.QueryDevicesResponse{} + if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{ + UserID: aliceUser.ID, + }, &deviceRes); err != nil { + t.Errorf("failed to query user devices") + } + if !deviceRes.UserExists { + t.Fatalf("user does not exist") + } + if l := len(deviceRes.Devices); l != 1 { + t.Fatalf("Incorrect number of user devices. Got %d, want 1", l) + } + if device.ID != deviceRes.Devices[0].ID { + t.Fatalf("Device IDs do not match: %s != %s", device.ID, deviceRes.Devices[0].ID) + } + }) + + t.Run("inactive token", func(t *testing.T) { + httpReq := http.Request{ + URL: u, + Header: map[string][]string{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer validTokenUserExistsTokenInactive"}, + }, + } + device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq) + if jsonResp == nil { + t.Fatal("JSONResponse is expected to be nil") + } + if device != nil { + t.Fatalf("Device is not nil: %+v", device) + } + if jsonResp.Code != http.StatusUnauthorized { + t.Fatalf("Incorrect status code: want=401, got=%d", jsonResp.Code) + } + mErr, _ := jsonResp.JSON.(spec.MatrixError) + if mErr.ErrCode != spec.ErrorUnknownToken { + t.Fatalf("Unexpected error code: want=%s, got=%s", spec.ErrorUnknownToken, mErr.ErrCode) + } + }) + + t.Run("non-existing user", func(t *testing.T) { + httpReq := http.Request{ + URL: u, + Header: map[string][]string{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer validTokenUserDoesNotExistTokenActive"}, + }, + } + device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq) + if jsonResp != nil { + t.Fatalf("JSONResponse is not expected: %+v", jsonResp) + } + deviceRes := uapi.QueryDevicesResponse{} + if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{ + UserID: bobUser.ID, + }, &deviceRes); err != nil { + t.Errorf("failed to query user devices") + } + if !deviceRes.UserExists { + t.Fatalf("user does not exist") + } + if l := len(deviceRes.Devices); l != 1 { + t.Fatalf("Incorrect number of user devices. Got %d, want 1", l) + } + if device.ID != deviceRes.Devices[0].ID { + t.Fatalf("Device IDs do not match: %s != %s", device.ID, deviceRes.Devices[0].ID) + } + }) + }) +} diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go index fc360b5d8..3881b8e0c 100644 --- a/setup/mscs/mscs.go +++ b/setup/mscs/mscs.go @@ -16,6 +16,7 @@ import ( "github.com/element-hq/dendrite/setup" "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/setup/mscs/msc2836" + "github.com/element-hq/dendrite/setup/mscs/msc3861" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -34,9 +35,11 @@ func Enable(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Rout func EnableMSC(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, monolith *setup.Monolith, msc string, caches *caching.Caches) error { switch msc { case "msc2836": - return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) + return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserVerifierProvider, monolith.KeyRing) case "msc2444": // enabled inside federationapi case "msc2753": // enabled inside clientapi + case "msc3861": + return msc3861.Enable(monolith) default: logrus.Warnf("EnableMSC: unknown MSC '%s', this MSC is either not supported or is natively supported by Dendrite", msc) } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index dcc78c859..484736988 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -36,16 +36,17 @@ func Setup( lazyLoadCache caching.LazyLoadCache, fts fulltext.Indexer, rateLimits *httputil.RateLimits, + userVerifier httputil.UserVerifier, ) { v1unstablemux := csMux.PathPrefix("/{apiversion:(?:v1|unstable)}/").Subrouter() v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() // TODO: Add AS support for all handlers below. - v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingSyncRequest(req, device) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r @@ -58,7 +59,7 @@ func Setup( }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/event/{eventID}", - httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_get_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -68,7 +69,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/filter", - httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("put_filter", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -78,7 +79,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/user/{userId}/filter/{filterId}", - httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_filter", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -87,12 +88,12 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingKeyChangeRequest(req, device) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/context/{eventId}", - httputil.MakeAuthAPI("context", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("context", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -108,7 +109,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}", - httputil.MakeAuthAPI("relations", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relations", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -122,7 +123,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}", - httputil.MakeAuthAPI("relation_type", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relation_type", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -136,7 +137,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}", - httputil.MakeAuthAPI("relation_type_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relation_type_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -150,7 +151,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/search", - httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("search", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if !cfg.Fulltext.Enabled { return util.JSONResponse{ Code: http.StatusNotImplemented, @@ -173,7 +174,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/members", - httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_members", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 2b1dc9958..a45173dbe 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -42,6 +42,7 @@ func AddPublicRoutes( userAPI userapi.SyncUserAPI, rsAPI api.SyncRoomserverAPI, caches caching.LazyLoadCache, + userVerifier httputil.UserVerifier, enableMetrics bool, ) { js, natsClient := natsInstance.Prepare(processContext, &dendriteCfg.Global.JetStream) @@ -149,5 +150,6 @@ func AddPublicRoutes( routers.Client, requestPool, syncDB, userAPI, rsAPI, &dendriteCfg.SyncAPI, caches, fts, rateLimits, + userVerifier, ) } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 4e1fa7dfb..f6c0c898a 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" @@ -119,6 +120,20 @@ func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe return nil } +type mockUserVerifier struct { + accessTokenToDeviceAndResponse map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + } +} + +func (u *mockUserVerifier) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) { + if pair, ok := u.accessTokenToDeviceAndResponse[req.URL.Query().Get("access_token")]; ok { + return pair.Device, pair.Response + } + return nil, nil +} + func TestSyncAPIAccessTokens(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { testSyncAccessTokens(t, dbType) @@ -146,12 +161,16 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) msgs := toNATSMsgs(t, cfg, room.Events()...) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) + uv := &mockUserVerifier{} + + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, uv, caching.DisableMetrics) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { name string req *http.Request + device *userapi.Device + response *util.JSONResponse wantCode int wantJoinedRooms []string }{ @@ -160,6 +179,11 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { req: test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "timeout": "0", })), + device: nil, + response: &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.UnknownToken("Unknown token"), + }, wantCode: 401, }, { @@ -168,6 +192,11 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { "access_token": "foo", "timeout": "0", })), + device: nil, + response: &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.UnknownToken("Unknown token"), + }, wantCode: 401, }, { @@ -176,11 +205,25 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { "access_token": alice.AccessToken, "timeout": "0", })), + device: &alice, + response: nil, wantCode: 200, wantJoinedRooms: []string{room.ID}, }, } + uv.accessTokenToDeviceAndResponse = make(map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }, len(testCases)) + for _, tc := range testCases { + + uv.accessTokenToDeviceAndResponse[tc.req.URL.Query().Get("access_token")] = struct { + Device *userapi.Device + Response *util.JSONResponse + }{Device: tc.device, Response: tc.response} + } + syncUntil(t, routers, alice.AccessToken, false, func(syncBody string) bool { // wait for the last sent eventID to come down sync path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) @@ -241,12 +284,20 @@ func testSyncEventFormatPowerLevels(t *testing.T, dbType test.DBType) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} + uv := mockUserVerifier{ + accessTokenToDeviceAndResponse: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + alice.AccessToken: {Device: &alice, Response: nil}, + }, + } defer close() jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) msgs := toNATSMsgs(t, cfg, room.Events()...) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &uv, caching.DisableMetrics) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { @@ -399,7 +450,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { // m.room.history_visibility msgs := toNATSMsgs(t, cfg, room.Events()...) sinceTokens := make([]string, len(msgs)) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, nil, caching.DisableMetrics) for i, msg := range msgs { testrig.MustPublishMsgs(t, jsctx, msg) time.Sleep(100 * time.Millisecond) @@ -487,7 +538,15 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, caching.DisableMetrics) + uv := mockUserVerifier{ + accessTokenToDeviceAndResponse: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + alice.AccessToken: {Device: &alice, Response: nil}, + }, + } + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &uv, caching.DisableMetrics) w := httptest.NewRecorder() routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, @@ -609,7 +668,16 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { // Use the actual internal roomserver API rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, caching.DisableMetrics) + uv := mockUserVerifier{ + accessTokenToDeviceAndResponse: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + aliceDev.AccessToken: {Device: &aliceDev, Response: nil}, + bobDev.AccessToken: {Device: &bobDev, Response: nil}, + }, + } + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &uv, caching.DisableMetrics) for _, tc := range testCases { testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) @@ -878,8 +946,17 @@ func TestGetMembership(t *testing.T) { // Use an actual roomserver for this rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) + uv := mockUserVerifier{ + accessTokenToDeviceAndResponse: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + aliceDev.AccessToken: {Device: &aliceDev, Response: nil}, + bobDev.AccessToken: {Device: &bobDev, Response: nil}, + }, + } - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &uv, caching.DisableMetrics) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -946,10 +1023,18 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) defer close() natsInstance := jetstream.NATSInstance{} + uv := mockUserVerifier{ + accessTokenToDeviceAndResponse: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + alice.AccessToken: {Device: &alice, Response: nil}, + }, + } jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &uv, caching.DisableMetrics) producer := producers.SyncAPIProducer{ TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), @@ -1172,7 +1257,16 @@ func testContext(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, caching.DisableMetrics) + uv := mockUserVerifier{ + accessTokenToDeviceAndResponse: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + alice.AccessToken: {Device: &alice, Response: nil}, + }, + } + + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, &uv, caching.DisableMetrics) room := test.NewRoom(t, user) @@ -1351,9 +1445,17 @@ func TestRemoveEditedEventFromSearchIndex(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) + uv := mockUserVerifier{ + accessTokenToDeviceAndResponse: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + alice.AccessToken: {Device: &alice, Response: nil}, + }, + } room := test.NewRoom(t, user) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &uv, caching.DisableMetrics) if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) @@ -1416,6 +1518,7 @@ func searchRequest(t *testing.T, router *mux.Router, accessToken, searchTerm str assert.NoError(t, err) return body } + func syncUntil(t *testing.T, routers httputil.Routers, accessToken string, skip bool, diff --git a/userapi/api/api.go b/userapi/api/api.go index 264821296..9b1319986 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -31,7 +31,8 @@ type UserInternalAPI interface { FederationUserAPI QuerySearchProfilesAPI // used by p2p demos - QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) + QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, req *QueryLocalpartExternalIDRequest, res *QueryLocalpartExternalIDResponse) (err error) + PerformLocalpartExternalUserIDCreation(ctx context.Context, req *PerformLocalpartExternalUserIDCreationRequest) (err error) } // api functions required by the appservice api @@ -47,7 +48,7 @@ type RoomserverUserAPI interface { // api functions required by the media api type MediaUserAPI interface { - QueryAcccessTokenAPI + QueryAccessTokenAPI } // api functions required by the federation api @@ -64,7 +65,7 @@ type FederationUserAPI interface { // api functions required by the sync api type SyncUserAPI interface { - QueryAcccessTokenAPI + QueryAccessTokenAPI SyncKeyAPI QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error @@ -75,7 +76,7 @@ type SyncUserAPI interface { // api functions required by the client api type ClientUserAPI interface { - QueryAcccessTokenAPI + QueryAccessTokenAPI LoginTokenInternalAPI UserLoginAPI ClientKeyAPI @@ -87,6 +88,7 @@ type ClientUserAPI interface { QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) @@ -109,6 +111,7 @@ type ClientUserAPI interface { QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error + PerformBulkSaveThreePIDAssociation(ctx context.Context, req *PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error } type KeyBackupAPI interface { @@ -130,7 +133,7 @@ type QuerySearchProfilesAPI interface { } // common function for creating authenticated endpoints (used in client/media/sync api) -type QueryAcccessTokenAPI interface { +type QueryAccessTokenAPI interface { QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error } @@ -316,6 +319,9 @@ type PerformAccountCreationRequest struct { Localpart string // Required: The localpart for this account. Ignored if account type is guest. ServerName spec.ServerName // optional: if not specified, default server name used instead + DisplayName string // optional: this is populated only by MAS. In the legacy flow it's not used + AvatarURL string // optional: this is populated only by MAS. In the legacy flow it's not used + AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. Password string // optional: if missing then this account will be a passwordless account OnConflict Conflict @@ -455,6 +461,7 @@ type Account struct { ServerName spec.ServerName AppServiceID string AccountType AccountType + Deactivated bool // TODO: Associations (e.g. with application services) } @@ -471,6 +478,14 @@ type OpenIDTokenAttributes struct { ExpiresAtMS int64 } +// LocalpartExternalID represents a connection between Matrix account and OpenID Connect provider +type LocalpartExternalID struct { + Localpart string + ExternalID string + AuthProvider string + CreatedTS int64 +} + // UserInfo is for returning information about the user an OpenID token was issued for type UserInfo struct { Sub string // The Matrix user's ID who generated the token @@ -514,6 +529,8 @@ const ( AccountTypeAdmin AccountType = 3 // AccountTypeAppService indicates this is an appservice account AccountTypeAppService AccountType = 4 + // AccountTypeOIDC indicates this is an account belonging to Matrix Authentication Service (MAS) + AccountTypeOIDCService AccountType = 5 ) type QueryPushersRequest struct { @@ -636,6 +653,12 @@ type PerformSaveThreePIDAssociationRequest struct { Medium string } +type PerformBulkSaveThreePIDAssociationRequest struct { + ThreePIDs []authtypes.ThreePID + Localpart string + ServerName spec.ServerName +} + type QueryAccountByLocalpartRequest struct { Localpart string ServerName spec.ServerName @@ -645,10 +668,26 @@ type QueryAccountByLocalpartResponse struct { Account *Account } +type QueryLocalpartExternalIDRequest struct { + ExternalID string + AuthProvider string +} + +type QueryLocalpartExternalIDResponse struct { + LocalpartExternalID *LocalpartExternalID +} + +type PerformLocalpartExternalUserIDCreationRequest struct { + Localpart string + ExternalID string + AuthProvider string +} + // API functions required by the clientapi type ClientKeyAPI interface { UploadDeviceKeysAPI QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) + QueryMasterKeys(ctx context.Context, req *QueryMasterKeysRequest, res *QueryMasterKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) @@ -910,6 +949,16 @@ type QueryKeysResponse struct { Error *KeyError } +type QueryMasterKeysRequest struct { + UserID string +} + +type QueryMasterKeysResponse struct { + Key spec.Base64Bytes + // Set if there was a fatal error processing this query + Error *KeyError +} + type QueryKeyChangesRequest struct { // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning Offset int64 diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go index fe5d9f7d9..7b245a691 100644 --- a/userapi/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -455,10 +455,11 @@ func (a *UserInternalAPI) processOtherSignatures( func (a *UserInternalAPI) crossSigningKeysFromDatabase( ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, ) { + logger := logrus.WithContext(ctx) for targetUserID := range req.UserToDevices { keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil { - logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) + logger.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) continue } @@ -471,7 +472,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase( sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) if err != nil && err != sql.ErrNoRows { - logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) + logger.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) continue } diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 6cb11bcd2..eb7597ab9 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -234,6 +234,19 @@ func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *a return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) } +func (a *UserInternalAPI) QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) { + crossSigningKeyMap, err := a.KeyDatabase.CrossSigningKeysDataForUserAndKeyType(ctx, req.UserID, fclient.CrossSigningKeyPurposeMaster) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query user cross signing master keys: %s", err), + } + return + } + if key, ok := crossSigningKeyMap[fclient.CrossSigningKeyPurposeMaster]; ok { + res.Key = key + } +} + // nolint:gocyclo func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { var respMu sync.Mutex @@ -272,7 +285,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque DeviceIDs: dids, }, &queryRes) if err != nil { - util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") + util.GetLogger(ctx).WithError(err).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") } if res.DeviceKeys[userID] == nil { diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 666e75f93..2b500c95d 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -7,6 +7,7 @@ package internal import ( + "cmp" "context" "database/sql" "encoding/json" @@ -247,10 +248,17 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil { + displayName := cmp.Or(req.DisplayName, req.Localpart) + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, displayName); err != nil { return fmt.Errorf("a.DB.SetDisplayName: %w", err) } + if req.AvatarURL != "" { + if _, _, err := a.DB.SetAvatarURL(ctx, req.Localpart, serverName, req.AvatarURL); err != nil { + return fmt.Errorf("a.DB.SetAvatarURL: %w", err) + } + } + postRegisterJoinRooms(a.Config, acc, a.RSAPI) res.AccountCreated = true @@ -298,6 +306,8 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") + // TODO: Since we have deleted access_token's unique constraint from the db, + // we probably should check its uniqueness if msc3861 is disabled dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err @@ -594,6 +604,15 @@ func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api. return } +func (a *UserInternalAPI) PerformLocalpartExternalUserIDCreation(ctx context.Context, req *api.PerformLocalpartExternalUserIDCreationRequest) (err error) { + return a.DB.CreateLocalpartExternalID(ctx, req.Localpart, req.ExternalID, req.AuthProvider) +} + +func (a *UserInternalAPI) QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, req *api.QueryLocalpartExternalIDRequest, res *api.QueryLocalpartExternalIDResponse) (err error) { + res.LocalpartExternalID, err = a.DB.GetLocalpartForExternalID(ctx, req.ExternalID, req.AuthProvider) + return +} + // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // creating a 'device'. func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { @@ -970,4 +989,8 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium) } +func (a *UserInternalAPI) PerformBulkSaveThreePIDAssociation(ctx context.Context, req *api.PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error { + return a.DB.BulkSaveThreePIDAssociation(ctx, req.ThreePIDs, req.Localpart, req.ServerName) +} + const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 2a46a7fd7..3cf7e7659 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -119,6 +119,7 @@ type Pusher interface { type ThreePID interface { SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName spec.ServerName, medium string) (err error) + BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName spec.ServerName, err error) GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error) @@ -134,6 +135,12 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } +type LocalpartExternalID interface { + CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error + GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) + DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error +} + type UserDatabase interface { Account AccountData @@ -147,6 +154,7 @@ type UserDatabase interface { Statistics ThreePID RegistrationTokens + LocalpartExternalID } type KeyChangeDatabase interface { @@ -219,6 +227,7 @@ type KeyDatabase interface { CrossSigningKeysForUser(ctx context.Context, userID string) (map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey, error) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) + CrossSigningKeysDataForUserAndKeyType(ctx context.Context, userID string, keyType fclient.CrossSigningKeyPurpose) (types.CrossSigningKeyMap, error) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 489017fb9..19c16230b 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -55,7 +55,7 @@ const deactivateAccountSQL = "" + "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" + "SELECT localpart, server_name, appservice_id, account_type, is_deactivated FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE" @@ -116,7 +116,7 @@ func (s *accountsStatements) InsertAccount( localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { - createdTimeMS := time.Now().UnixNano() / 1000000 + createdTimeMS := spec.AsTimestamp(time.Now()) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error @@ -135,6 +135,7 @@ func (s *accountsStatements) InsertAccount( ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, + Deactivated: false, }, nil } @@ -167,7 +168,7 @@ func (s *accountsStatements) SelectAccountByLocalpart( var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType, &acc.Deactivated) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index a8566e69b..f05f7845a 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -32,15 +32,20 @@ const selectCrossSigningKeysForUserSQL = "" + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + " WHERE user_id = $1" +const selectCrossSigningKeysForUserAndKeyTypeSQL = "" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1 AND key_type = $2" + const upsertCrossSigningKeysForUserSQL = "" + "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + " VALUES($1, $2, $3)" + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt } func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -51,8 +56,14 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro if err != nil { return nil, err } + m := sqlutil.NewMigrator(db) + err = m.Up(context.Background()) + if err != nil { + return nil, err + } return s, sqlutil.StatementList{ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, }.Prepare(db) } @@ -82,6 +93,35 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( return } +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUserAndKeyType( + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, +) (r types.CrossSigningKeyMap, err error) { + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return nil, fmt.Errorf("unknown key purpose %q", keyType) + } + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserAndKeyTypeStmt).QueryContext(ctx, userID, keyTypeInt) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectCrossSigningKeysForUserAndKeyType: rows.close() failed") + r = types.CrossSigningKeyMap{} + for rows.Next() { + var keyTypeInt int16 + var keyData spec.Base64Bytes + if err = rows.Scan(&keyTypeInt, &keyData); err != nil { + return nil, err + } + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = keyData + } + err = rows.Err() + return +} + func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, ) error { diff --git a/userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go b/userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go new file mode 100644 index 000000000..e88423361 --- /dev/null +++ b/userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go @@ -0,0 +1,30 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` +ALTER TABLE userapi_devices DROP CONSTRAINT userapi_devices_pkey;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + ALTER TABLE userapi_devices ADD CONSTRAINT userapi_devices_pkey PRIMARY KEY (access_token);`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index b5feea07f..e76452447 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -116,10 +116,16 @@ func NewPostgresDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.Dev return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "userapi: add last_seen_ts", - Up: deltas.UpLastSeenTSIP, - }) + m.AddMigrations( + sqlutil.Migration{ + Version: "userapi: add last_seen_ts", + Up: deltas.UpLastSeenTSIP, + }, + sqlutil.Migration{ + Version: "userapi: drop primary key constraint", + Up: deltas.UpDropPrimaryKeyConstraint, + }, + ) err = m.Up(context.Background()) if err != nil { return nil, err diff --git a/userapi/storage/postgres/localpart_external_ids_table.go b/userapi/storage/postgres/localpart_external_ids_table.go new file mode 100644 index 000000000..bc2adac20 --- /dev/null +++ b/userapi/storage/postgres/localpart_external_ids_table.go @@ -0,0 +1,102 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/tables" + log "github.com/sirupsen/logrus" +) + +const localpartExternalIDsSchema = ` +-- Stores data about connections between accounts and third-party auth providers +CREATE TABLE IF NOT EXISTS userapi_localpart_external_ids ( + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- The external ID + external_id TEXT NOT NULL, + -- Auth provider ID (see OIDCProvider.IDPID) + auth_provider TEXT NOT NULL, + -- When this connection was created, as a unix timestamp. + created_ts BIGINT NOT NULL, + + CONSTRAINT userapi_localpart_external_ids_external_id_auth_provider_unique UNIQUE(external_id, auth_provider), + CONSTRAINT userapi_localpart_external_ids_localpart_external_id_auth_provider_unique UNIQUE(localpart, external_id, auth_provider) +); + +-- This index allows efficient lookup of the local user by the external ID +CREATE INDEX IF NOT EXISTS userapi_external_id_auth_provider_idx ON userapi_localpart_external_ids(external_id, auth_provider); +` + +const insertUserExternalIDSQL = "" + + "INSERT INTO userapi_localpart_external_ids(localpart, external_id, auth_provider, created_ts) VALUES ($1, $2, $3, $4)" + +const selectUserExternalIDSQL = "" + + "SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +const deleteUserExternalIDSQL = "" + + "DELETE FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +type localpartExternalIDStatements struct { + db *sql.DB + insertUserExternalIDStmt *sql.Stmt + selectUserExternalIDStmt *sql.Stmt + deleteUserExternalIDStmt *sql.Stmt +} + +func NewPostgresLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDsTable, error) { + s := &localpartExternalIDStatements{ + db: db, + } + _, err := db.Exec(localpartExternalIDsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertUserExternalIDStmt, insertUserExternalIDSQL}, + {&s.selectUserExternalIDStmt, selectUserExternalIDSQL}, + {&s.deleteUserExternalIDStmt, deleteUserExternalIDSQL}, + }.Prepare(db) +} + +// Select selects an existing OpenID Connect connection from the database +func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) { + ret := api.LocalpartExternalID{ + ExternalID: externalID, + AuthProvider: authProvider, + } + err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan( + &ret.Localpart, &ret.CreatedTS, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + log.WithError(err).Error("Unable to retrieve localpart from the db") + return nil, err + } + + return &ret, nil +} + +// Insert creates a new record representing an OpenID Connect connection between Matrix and external accounts. +func (u *localpartExternalIDStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix()) + return err +} + +// Delete deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account. +func (u *localpartExternalIDStatements) Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, externalID, authProvider) + return err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index c7fb9d29b..eff12a64a 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -97,6 +97,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties if err != nil { return nil, fmt.Errorf("NewPostgresStatsTable: %w", err) } + localpartExternalIDsTable, err := NewPostgresLocalpartExternalIDsTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteLocalpartExternalIDsTable: %w", err) + } m = sqlutil.NewMigrator(db) m.AddMigrations(sqlutil.Migration{ @@ -123,6 +127,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties Notifications: notificationsTable, RegistrationTokens: registationTokensTable, Stats: statsTable, + LocalpartExternalIDs: localpartExternalIDsTable, ServerName: serverName, DB: db, Writer: writer, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 44ace733e..834c76488 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -49,6 +49,7 @@ type Database struct { Notifications tables.NotificationTable Pushers tables.PusherTable Stats tables.StatsTable + LocalpartExternalIDs tables.LocalpartExternalIDsTable LoginTokenLifetime time.Duration ServerName spec.ServerName BcryptCost int @@ -352,6 +353,41 @@ func (d *Database) SaveThreePIDAssociation( }) } +// BulkSaveThreePIDAssociation recreates 3PIDs for a user. +// If the third-party identifier is already part of an association, returns Err3PIDInUse. +// Returns an error if there was a problem talking to the database. +func (d *Database) BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error) { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + oldThreePIDs, err := d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName) + if err != nil { + return err + } + for _, t := range oldThreePIDs { + if err := d.ThreePIDs.DeleteThreePID(ctx, txn, t.Address, t.Medium); err != nil { + return err + } + } + for _, t := range threePIDs { + // if 3PID is associated with another user, return Err3PIDInUse + user, _, err := d.ThreePIDs.SelectLocalpartForThreePID( + ctx, txn, t.Address, t.Medium, + ) + if err != nil { + return err + } + + if len(user) > 0 && user != localpart { + return Err3PIDInUse + } + + if err = d.ThreePIDs.InsertThreePID(ctx, txn, t.Address, t.Medium, localpart, serverName); err != nil { + return err + } + } + return nil + }) +} + // RemoveThreePIDAssociation removes the association involving a given third-party // identifier. // If no association exists involving this third-party identifier, returns nothing. @@ -870,6 +906,18 @@ func (d *Database) UpsertPusher( }) } +func (d *Database) CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error { + return d.LocalpartExternalIDs.Insert(ctx, nil, localpart, externalID, authProvider) +} + +func (d *Database) GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) { + return d.LocalpartExternalIDs.Select(ctx, nil, externalID, authProvider) +} + +func (d *Database) DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error { + return d.LocalpartExternalIDs.Delete(ctx, nil, externalID, authProvider) +} + // GetPushers returns the pushers matching the given localpart. func (d *Database) GetPushers( ctx context.Context, localpart string, serverName spec.ServerName, @@ -1124,6 +1172,11 @@ func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID st return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) } +// CrossSigningKeysForUserAndKeyType returns the latest known cross-signing keys for a user and key type, if any. +func (d *KeyDatabase) CrossSigningKeysDataForUserAndKeyType(ctx context.Context, userID string, keyType fclient.CrossSigningKeyPurpose) (types.CrossSigningKeyMap, error) { + return d.CrossSigningKeysTable.SelectCrossSigningKeysForUserAndKeyType(ctx, nil, userID, keyType) +} + // CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) @@ -1132,8 +1185,8 @@ func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserI // StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for keyType, keyData := range keyMap { - if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { + for keyType, key := range keyMap { + if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, key); err != nil { return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) } } @@ -1141,7 +1194,7 @@ func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID s }) } -// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. +// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/device. func (d *KeyDatabase) StoreCrossSigningSigsForTarget( ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 66cc7c060..1090ec3ed 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -54,7 +54,7 @@ const deactivateAccountSQL = "" + "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" + "SELECT localpart, server_name, appservice_id, account_type, is_deactivated FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0" @@ -116,7 +116,7 @@ func (s *accountsStatements) InsertAccount( ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { - createdTimeMS := time.Now().UnixNano() / 1000000 + createdTimeMS := spec.AsTimestamp(time.Now()) stmt := s.insertAccountStmt var err error @@ -135,6 +135,7 @@ func (s *accountsStatements) InsertAccount( ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, + Deactivated: false, }, nil } @@ -167,7 +168,7 @@ func (s *accountsStatements) SelectAccountByLocalpart( var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType, &acc.Deactivated) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index dd8923d30..c57ffd398 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -32,14 +32,19 @@ const selectCrossSigningKeysForUserSQL = "" + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + " WHERE user_id = $1" +const selectCrossSigningKeysForUserAndKeyTypeSQL = "" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1 AND key_type = $2" + const upsertCrossSigningKeysForUserSQL = "" + "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + " VALUES($1, $2, $3)" type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt } func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -50,8 +55,14 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) if err != nil { return nil, err } + m := sqlutil.NewMigrator(db) + err = m.Up(context.Background()) + if err != nil { + return nil, err + } return s, sqlutil.StatementList{ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, }.Prepare(db) } @@ -81,9 +92,37 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( return } +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUserAndKeyType( + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, +) (r types.CrossSigningKeyMap, err error) { + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return nil, fmt.Errorf("unknown key purpose %q", keyType) + } + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserAndKeyTypeStmt).QueryContext(ctx, userID, keyTypeInt) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectCrossSigningKeysForUserAndKeyType: rows.close() failed") + r = types.CrossSigningKeyMap{} + for rows.Next() { + var keyTypeInt int16 + var keyData spec.Base64Bytes + if err = rows.Scan(&keyTypeInt, &keyData); err != nil { + return nil, err + } + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = keyData + } + err = rows.Err() + return +} + func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, -) error { + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error { keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] if !ok { return fmt.Errorf("unknown key purpose %q", keyType) diff --git a/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go b/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go new file mode 100644 index 000000000..d66053530 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go @@ -0,0 +1,70 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp; + CREATE TABLE userapi_devices ( + access_token TEXT, + session_id INTEGER, + device_id TEXT , + localpart TEXT , + server_name TEXT NOT NULL, + created_ts BIGINT, + display_name TEXT, + last_seen_ts BIGINT, + ip TEXT, + user_agent TEXT, + UNIQUE (localpart, device_id) + ); + INSERT + INTO userapi_devices ( + access_token, session_id, device_id, localpart, server_name, created_ts, display_name, last_seen_ts, ip, user_agent + ) SELECT + access_token, session_id, device_id, localpart, server_name, created_ts, display_name, created_ts, '', '' + FROM userapi_devices_tmp; + DROP TABLE userapi_devices_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` +ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp; +CREATE TABLE userapi_devices ( + access_token TEXT PRIMARY KEY, + session_id INTEGER, + device_id TEXT , + localpart TEXT , + server_name TEXT NOT NULL, + created_ts BIGINT, + display_name TEXT, + last_seen_ts BIGINT, + ip TEXT, + user_agent TEXT, + UNIQUE (localpart, device_id) + ); + INSERT + INTO userapi_devices ( + access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent + ) SELECT + access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', '' + FROM userapi_devices_tmp; + DROP TABLE userapi_devices_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index d5d1fed3d..2eb88109a 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -102,10 +102,16 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.Devic return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "userapi: add last_seen_ts", - Up: deltas.UpLastSeenTSIP, - }) + m.AddMigrations( + sqlutil.Migration{ + Version: "userapi: add last_seen_ts", + Up: deltas.UpLastSeenTSIP, + }, + sqlutil.Migration{ + Version: "userapi: drop primary key constraint", + Up: deltas.UpDropPrimaryKeyConstraint, + }, + ) if err = m.Up(context.Background()); err != nil { return nil, err } diff --git a/userapi/storage/sqlite3/localpart_external_ids_table.go b/userapi/storage/sqlite3/localpart_external_ids_table.go new file mode 100644 index 000000000..acbd5a7e9 --- /dev/null +++ b/userapi/storage/sqlite3/localpart_external_ids_table.go @@ -0,0 +1,102 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/tables" + log "github.com/sirupsen/logrus" +) + +const localpartExternalIDsSchema = ` +-- Stores data about connections between accounts and third-party auth providers +CREATE TABLE IF NOT EXISTS userapi_localpart_external_ids ( + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- The external ID + external_id TEXT NOT NULL, + -- Auth provider ID (see OIDCProvider.IDPID) + auth_provider TEXT NOT NULL, + -- When this connection was created, as a unix timestamp. + created_ts BIGINT NOT NULL, + + UNIQUE(external_id, auth_provider), + UNIQUE(localpart, external_id, auth_provider) +); + +-- This index allows efficient lookup of the local user by the external ID +CREATE INDEX IF NOT EXISTS userapi_external_id_auth_provider_idx ON userapi_localpart_external_ids(external_id, auth_provider); +` + +const insertLocalpartExternalIDSQL = "" + + "INSERT INTO userapi_localpart_external_ids(localpart, external_id, auth_provider, created_ts) VALUES ($1, $2, $3, $4)" + +const selectLocalpartExternalIDSQL = "" + + "SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +const deleteLocalpartExternalIDSQL = "" + + "DELETE FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +type localpartExternalIDStatements struct { + db *sql.DB + insertUserExternalIDStmt *sql.Stmt + selectUserExternalIDStmt *sql.Stmt + deleteUserExternalIDStmt *sql.Stmt +} + +func NewSQLiteLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDsTable, error) { + s := &localpartExternalIDStatements{ + db: db, + } + _, err := db.Exec(localpartExternalIDsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertUserExternalIDStmt, insertLocalpartExternalIDSQL}, + {&s.selectUserExternalIDStmt, selectLocalpartExternalIDSQL}, + {&s.deleteUserExternalIDStmt, deleteLocalpartExternalIDSQL}, + }.Prepare(db) +} + +// Select selects an existing OpenID Connect connection from the database +func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) { + ret := api.LocalpartExternalID{ + ExternalID: externalID, + AuthProvider: authProvider, + } + err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan( + &ret.Localpart, &ret.CreatedTS, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + log.WithError(err).Error("Unable to retrieve localpart from the db") + return nil, err + } + + return &ret, nil +} + +// Insert creates a new record representing an OpenID Connect connection between Matrix and external accounts. +func (u *localpartExternalIDStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix()) + return err +} + +// Delete deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account. +func (u *localpartExternalIDStatements) Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, externalID, authProvider) + return err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 6d906191f..80ecaf83c 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -94,6 +94,10 @@ func NewUserDatabase(ctx context.Context, conMan *sqlutil.Connections, dbPropert if err != nil { return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) } + localpartExternalIDsTable, err := NewSQLiteLocalpartExternalIDsTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteUserExternalIDsTable: %w", err) + } m = sqlutil.NewMigrator(db) m.AddMigrations(sqlutil.Migration{ @@ -119,6 +123,7 @@ func NewUserDatabase(ctx context.Context, conMan *sqlutil.Connections, dbPropert Pushers: pusherTable, Notifications: notificationsTable, Stats: statsTable, + LocalpartExternalIDs: localpartExternalIDsTable, ServerName: serverName, DB: db, Writer: writer, diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 3a2afdf06..309aecd66 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -14,7 +14,7 @@ import ( "github.com/element-hq/dendrite/internal/sqlutil" "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/userapi/storage/sqlite3" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func NewUserDatabase( diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 44f31a5c5..434702761 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -127,6 +127,12 @@ type StatsTable interface { UpsertDailyStats(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error } +type LocalpartExternalIDsTable interface { + Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) + Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error + Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error +} + type NotificationFilter uint32 const ( @@ -192,6 +198,7 @@ type StaleDeviceLists interface { type CrossSigningKeys interface { SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) + SelectCrossSigningKeysForUserAndKeyType(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose) (r types.CrossSigningKeyMap, err error) UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 6e33ced01..3ff6adfb3 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -445,8 +445,6 @@ func TestAccountData(t *testing.T) { func TestDevices(t *testing.T) { ctx := context.Background() - dupeAccessToken := util.RandomString(8) - displayName := "testing" creationTests := []struct { @@ -468,15 +466,6 @@ func TestDevices(t *testing.T) { name: "explicit local user", inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, }, - { - name: "dupe token - ok", - inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, - }, - { - name: "dupe token - not ok", - inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, - wantErr: true, - }, { name: "test3 second device", // used to test deletion later inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},