Skip to content

Commit

Permalink
Fix federationclient whitelist checks, improve performance
Browse files Browse the repository at this point in the history
  • Loading branch information
enaix committed Jan 17, 2025
1 parent 94deed7 commit b5ea31f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 25 deletions.
7 changes: 4 additions & 3 deletions federationapi/internal/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,11 @@ func NewFederationInternalAPI(
}
}

// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled, so we can connect to any server
// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled (we can connect to any server)
func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool {
stats := a.statistics.ForServer(s)
return stats.Whitelisted() || !a.cfg.EnableWhitelist
// Thread-safe, since DB access is performed in mutex and stats.Whitelisted is constant
stats := a.statistics.ForServer(s) // Calls mutex if the stats do not exist yet
return !a.cfg.EnableWhitelist || stats.Whitelisted() // Lazy eval
}

func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) {
Expand Down
50 changes: 28 additions & 22 deletions federationapi/internal/federationclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ const defaultTimeout = time.Second * 30
func (a *FederationInternalAPI) MakeJoin(
ctx context.Context, origin, s spec.ServerName, roomID, userID string,
) (res gomatrixserverlib.MakeJoinResponse, err error) {
if !a.IsWhitelistedOrAny(s) {
return &fclient.RespMakeJoin{}, nil
} // Is thread-safe, so we can omit ctx call
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.federation.MakeJoin(ctx, origin, s, roomID, userID)
Expand All @@ -29,6 +32,9 @@ func (a *FederationInternalAPI) MakeJoin(
func (a *FederationInternalAPI) SendJoin(
ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU,
) (res gomatrixserverlib.SendJoinResponse, err error) {
if !a.IsWhitelistedOrAny(s) {
return &fclient.RespSendJoin{}, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute*5)
defer cancel()
ires, err := a.federation.SendJoin(ctx, origin, s, event)
Expand All @@ -42,11 +48,11 @@ func (a *FederationInternalAPI) GetEventAuth(
ctx context.Context, origin, s spec.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (res fclient.RespEventAuth, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return fclient.RespEventAuth{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
})
Expand All @@ -59,11 +65,11 @@ func (a *FederationInternalAPI) GetEventAuth(
func (a *FederationInternalAPI) GetUserDevices(
ctx context.Context, origin, s spec.ServerName, userID string,
) (fclient.RespUserDevices, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return fclient.RespUserDevices{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetUserDevices(ctx, origin, s, userID)
})
Expand All @@ -76,11 +82,11 @@ func (a *FederationInternalAPI) GetUserDevices(
func (a *FederationInternalAPI) ClaimKeys(
ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string,
) (fclient.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return fclient.RespClaimKeys{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
})
Expand Down Expand Up @@ -108,11 +114,11 @@ func (a *FederationInternalAPI) QueryKeys(
func (a *FederationInternalAPI) Backfill(
ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string,
) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return gomatrixserverlib.Transaction{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
})
Expand All @@ -125,11 +131,11 @@ func (a *FederationInternalAPI) Backfill(
func (a *FederationInternalAPI) LookupState(
ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.StateResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return &fclient.RespState{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
})
Expand All @@ -143,11 +149,11 @@ func (a *FederationInternalAPI) LookupState(
func (a *FederationInternalAPI) LookupStateIDs(
ctx context.Context, origin, s spec.ServerName, roomID, eventID string,
) (res gomatrixserverlib.StateIDResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return fclient.RespStateIDs{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
})
Expand All @@ -161,11 +167,11 @@ func (a *FederationInternalAPI) LookupMissingEvents(
ctx context.Context, origin, s spec.ServerName, roomID string,
missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res fclient.RespMissingEvents, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return fclient.RespMissingEvents{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
})
Expand All @@ -178,11 +184,11 @@ func (a *FederationInternalAPI) LookupMissingEvents(
func (a *FederationInternalAPI) GetEvent(
ctx context.Context, origin, s spec.ServerName, eventID string,
) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return gomatrixserverlib.Transaction{}, nil
}
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEvent(ctx, origin, s, eventID)
})
Expand All @@ -195,11 +201,11 @@ func (a *FederationInternalAPI) GetEvent(
func (a *FederationInternalAPI) LookupServerKeys(
ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return []gomatrixserverlib.ServerKeys{}, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupServerKeys(ctx, s, keyRequests)
})
Expand All @@ -213,11 +219,11 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, origin, s spec.ServerName, r fclient.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion,
) (res fclient.MSC2836EventRelationshipsResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return res, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion)
})
Expand All @@ -230,11 +236,11 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
func (a *FederationInternalAPI) RoomHierarchies(
ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool,
) (res fclient.RoomHierarchyResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
if !a.IsWhitelistedOrAny(s) {
return res, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly)
})
Expand Down

0 comments on commit b5ea31f

Please sign in to comment.