Skip to content

Commit 443fb68

Browse files
committed
refactor(test): provide means to validate metrics and observations
Some helpers are provided for introspecting metrics already (used in JWT cache tests). This change provides facilities to additionally validate emitted Observation events. A new Spec module is also implemented, adding basic tests of schema cache reloading - their main goal is to exercise the new infrastructure.
1 parent 78f231c commit 443fb68

File tree

5 files changed

+123
-1
lines changed

5 files changed

+123
-1
lines changed

postgrest.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ test-suite spec
219219
Feature.ConcurrentSpec
220220
Feature.CorsSpec
221221
Feature.ExtraSearchPathSpec
222+
Feature.MetricsSpec
222223
Feature.NoSuperuserSpec
223224
Feature.ObservabilitySpec
224225
Feature.OpenApi.DisabledOpenApiSpec

src/PostgREST/AppState.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ module PostgREST.AppState
1818
, init
1919
, initSockets
2020
, initWithPool
21+
, putConfig -- For tests TODO refactoring
2122
, putNextListenerDelay
2223
, putSchemaCache
2324
, putPgVersion

test/spec/Feature/MetricsSpec.hs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
{-# LANGUAGE DataKinds #-}
2+
{-# LANGUAGE FlexibleContexts #-}
3+
{-# LANGUAGE MonadComprehensions #-}
4+
{-# LANGUAGE TypeApplications #-}
5+
6+
module Feature.MetricsSpec where
7+
8+
import Network.Wai (Application)
9+
import qualified PostgREST.AppState as AppState
10+
import PostgREST.Config (AppConfig (configDbSchemas))
11+
import qualified PostgREST.Metrics as Metrics
12+
import PostgREST.Observation
13+
import Prometheus (getCounter, getVectorWith)
14+
import Protolude
15+
import SpecHelper
16+
import Test.Hspec (SpecWith, describe, it)
17+
18+
spec :: SpecWith (((AppState.AppState, Metrics.MetricsState), Chan Observation), Application)
19+
spec = describe "Server started with metrics enabled" $ do
20+
it "Should update pgrst_schema_cache_loads_total[SUCCESS]" $ do
21+
((appState, metrics), waitFor) <- prepareState
22+
23+
liftIO $ checkState' metrics [
24+
schemaCacheLoads "SUCCESS" (+1)
25+
] $ do
26+
AppState.schemaCacheLoader appState
27+
waitFor (1 * sec) "SchemaCacheLoadedObs" $ \x -> [ o | o@(SchemaCacheLoadedObs{}) <- pure x]
28+
29+
it "Should update pgrst_schema_cache_loads_total[ERROR]" $ do
30+
((appState, metrics), waitFor) <- prepareState
31+
32+
liftIO $ checkState' metrics [
33+
schemaCacheLoads "FAIL" (+1),
34+
schemaCacheLoads "SUCCESS" (+1)
35+
] $ do
36+
AppState.getConfig appState >>= \prev -> do
37+
AppState.putConfig appState $ prev { configDbSchemas = pure "bad_schema" }
38+
AppState.schemaCacheLoader appState
39+
waitFor (1 * sec) "SchemaCacheErrorObs" $ \x -> [ o | o@(SchemaCacheErrorObs{}) <- pure x]
40+
AppState.putConfig appState prev
41+
42+
-- wait up to 2 secs so that retry can happen
43+
waitFor (2 * sec) "SchemaCacheLoadedObs" $ \x -> [ o | o@(SchemaCacheLoadedObs{}) <- pure x]
44+
45+
where
46+
-- prometheus-client api to handle vectors is convoluted
47+
schemaCacheLoads label = expectField @"schemaCacheLoads" $
48+
fmap (sumSamples . findSamples label) . (`getVectorWith` getCounter)
49+
sumSamples = getSum . foldMap (Sum . round @Double @Int . snd)
50+
findSamples label = find ((== label) . fst)
51+
sec = 1000000

test/spec/Main.hs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import qualified Feature.Auth.NoJwtSecretSpec
2929
import qualified Feature.ConcurrentSpec
3030
import qualified Feature.CorsSpec
3131
import qualified Feature.ExtraSearchPathSpec
32+
import qualified Feature.MetricsSpec
3233
import qualified Feature.NoSuperuserSpec
3334
import qualified Feature.ObservabilitySpec
3435
import qualified Feature.OpenApi.DisabledOpenApiSpec
@@ -68,16 +69,26 @@ import qualified Feature.Query.UpdateSpec
6869
import qualified Feature.Query.UpsertSpec
6970
import qualified Feature.RollbackSpec
7071
import qualified Feature.RpcPreRequestGucsSpec
72+
import PostgREST.Observation (Observation (HasqlPoolObs))
7173

7274

7375
main :: IO ()
7476
main = do
77+
poolChan <- newChan
78+
-- make sure poolChan is not growing indefinitely
79+
-- start a thread that drains the channel
80+
-- this is necessary because test cases operate on
81+
-- copies so poolChan is never read from
82+
void $ forkIO $ forever $ readChan poolChan
83+
metricsState <- Metrics.init (configDbPoolSize testCfg)
7584
pool <- P.acquire $ P.settings
7685
[ P.size 3
7786
, P.acquisitionTimeout 10
7887
, P.agingTimeout 60
7988
, P.idlenessTimeout 60
8089
, P.staticConnectionSettings (toUtf8 $ configDbUri testCfg)
90+
-- make sure metrics are updated and pool observations published to poolChan
91+
, P.observationHandler $ (writeChan poolChan <> Metrics.observationMetrics metricsState) . HasqlPoolObs
8192
]
8293

8394
actualPgVersion <- either (panic . show) id <$> P.use pool (queryPgVersion False)
@@ -86,7 +97,6 @@ main = do
8697
baseSchemaCache <- loadSCache pool testCfg
8798
sockets <- AppState.initSockets testCfg
8899
loggerState <- Logger.init
89-
metricsState <- Metrics.init (configDbPoolSize testCfg)
90100

91101
let
92102
initApp sCache st config = do
@@ -95,6 +105,14 @@ main = do
95105
AppState.putSchemaCache appState (Just sCache)
96106
return (st, postgrest (configLogLevel config) appState (pure ()))
97107

108+
initObservationsApp sCache config = do
109+
-- duplicate poolChan as a starting point
110+
obsChan <- dupChan poolChan
111+
appState <- AppState.initWithPool sockets pool config loggerState metricsState (Metrics.observationMetrics metricsState <> writeChan obsChan)
112+
AppState.putPgVersion appState actualPgVersion
113+
AppState.putSchemaCache appState (Just sCache)
114+
return (((appState, metricsState), obsChan), postgrest (configLogLevel config) appState (pure ()))
115+
98116
-- For tests that run with the same schema cache
99117
app = initApp baseSchemaCache ()
100118

@@ -123,6 +141,7 @@ main = do
123141
obsApp = app testObservabilityCfg
124142
serverTiming = app testCfgServerTiming
125143
aggregatesEnabled = app testCfgAggregatesEnabled
144+
observationsApp = initObservationsApp baseSchemaCache testCfg
126145

127146
extraSearchPathApp = appDbs testCfgExtraSearchPath
128147
unicodeApp = appDbs testUnicodeCfg
@@ -278,6 +297,9 @@ main = do
278297
before (initApp baseSchemaCache metricsState testCfgJwtCache) $
279298
describe "Feature.Auth.JwtCacheSpec" Feature.Auth.JwtCacheSpec.spec
280299

300+
before observationsApp $
301+
describe "Feature.MetricsSpec" Feature.MetricsSpec.spec
302+
281303
where
282304
loadSCache pool conf =
283305
either (panic.show) id <$> P.use pool (HT.transaction HT.ReadCommitted HT.Read $ querySchemaCache conf)

test/spec/SpecHelper.hs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{-# LANGUAGE AllowAmbiguousTypes #-}
2+
{-# LANGUAGE DeriveAnyClass #-}
23
{-# LANGUAGE ExistentialQuantification #-}
34
{-# LANGUAGE FlexibleContexts #-}
45
{-# LANGUAGE RankNTypes #-}
@@ -41,10 +42,13 @@ import PostgREST.Config (AppConfig (..),
4142
LogLevel (..),
4243
OpenAPIMode (..),
4344
Verbosity (..), parseSecret)
45+
import PostgREST.Observation (Observation,
46+
observationMessage)
4447
import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier (..))
4548
import Prometheus (Counter, getCounter)
4649
import Protolude hiding (get, toS)
4750
import Protolude.Conv (toS)
51+
import System.Timeout (timeout)
4852
import Test.Hspec.Expectations.Contrib (annotate)
4953

5054
filterAndMatchCT :: BS.ByteString -> MatchHeader
@@ -381,3 +385,46 @@ expectCounter :: forall s st m. (KnownSymbol s, HasField s st Counter, MonadIO m
381385
expectCounter = expectField @s intCounter
382386
where
383387
intCounter = ((round @Double @Int) <$>) . getCounter
388+
389+
data TimeoutException = TimeoutException deriving (Show, Exception)
390+
391+
accumulateUntilTimeout :: Int -> (s -> a -> s) -> s -> IO a -> IO s
392+
accumulateUntilTimeout t f start act = do
393+
tid <- myThreadId
394+
-- mask to make sure TimeoutException is not thrown before starting the loop
395+
mask $ \unmask -> do
396+
-- start timeout thread unmasking exceptions
397+
ttid <- forkIOWithUnmask ($ (threadDelay t *> throwTo tid TimeoutException))
398+
-- unmask effect
399+
unmask (fix (\loop accum -> (act >>= loop . f accum) `onTimeout` pure accum) start)
400+
-- make sure we catch timeout if happens before entering the loop
401+
`onTimeout` pure start
402+
-- make sure timer thread is killed on other exceptions
403+
-- so that it won't throw TimeoutException later
404+
`onException` killThread ttid
405+
where
406+
onTimeout m a = m `catch` \TimeoutException -> a
407+
408+
409+
prepareState :: HasCallStack => Traversable f => WaiSession (f (Chan Observation)) (f (Int -> String -> (Observation -> Maybe a) -> IO ()))
410+
prepareState = getState >>= traverse (liftA2 (<$>) waitFor (liftIO . dupChan))
411+
where
412+
-- read messages from copy chan and once condition is met drain original to the same point
413+
-- upon timeout report error and messages remaining in the original chan
414+
-- that way we report messages since last successful read
415+
waitFor orig copy t msg f =
416+
timeout t (readUntil copy *> readUntil orig) >>= maybe failTimeout mempty
417+
where
418+
failTimeout =
419+
takeUntilTimeout decisecond (readChan orig)
420+
>>= expectationFailure
421+
. ("Timeout waiting for " <> msg <> " at " <> loc <> ". Remaining observations:\n" ++)
422+
. foldMap ((++ "\n") . show . observationMessage)
423+
readUntil = void . untilM (pure . isJust . f) . readChan
424+
loc = fromMaybe "(unknown)" . head $ (prettySrcLoc . snd <$> getCallStack callStack)
425+
-- execute effectful computation until result meets provided condition
426+
untilM cond m = fix $ \loop -> m >>= \a -> ifM (cond a) (pure a) loop
427+
-- duplicate the provided channel and construct wairFor function binding both channels
428+
-- accumulate effecful computation results into a list for specified time
429+
takeUntilTimeout t = fmap reverse . accumulateUntilTimeout t (flip (:)) []
430+
decisecond = 100000

0 commit comments

Comments
 (0)