Skip to content

Commit c5e88f4

Browse files
authored
feat(rls-transaction): add retry for read replica connections (#9064)
1 parent 5d4415d commit c5e88f4

File tree

5 files changed

+623
-21
lines changed

5 files changed

+623
-21
lines changed

.env

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ POSTGRES_DB=prowler_db
3535
# POSTGRES_REPLICA_USER=prowler
3636
# POSTGRES_REPLICA_PASSWORD=postgres
3737
# POSTGRES_REPLICA_DB=prowler_db
38+
# POSTGRES_REPLICA_MAX_ATTEMPTS=3
39+
# POSTGRES_REPLICA_RETRY_BASE_DELAY=0.5
3840

3941
# Celery-Prowler task settings
4042
TASK_RETRY_DELAY_SECONDS=0.1

api/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,17 @@ All notable changes to the **Prowler API** are documented in this file.
1313
- Support muting findings based on simple rules with custom reason [(#9051)](https://github.com/prowler-cloud/prowler/pull/9051)
1414
- Support C5 compliance framework for the GCP provider [(#9097)](https://github.com/prowler-cloud/prowler/pull/9097)
1515

16+
---
17+
1618
## [1.14.1] (Prowler 5.13.1)
1719

1820
### Fixed
1921
- `/api/v1/overviews/providers` collapses data by provider type so the UI receives a single aggregated record per cloud family even when multiple accounts exist [(#9053)](https://github.com/prowler-cloud/prowler/pull/9053)
22+
- RLS database transactions now retry with exponential backoff and fall back to the primary database when Aurora read replica connections fail during scale-down events [(#9064)](https://github.com/prowler-cloud/prowler/pull/9064)
2023
- Security Hub integrations stop failing when they read relationships via the replica by allowing replica relations and saving updates through the primary [(#9080)](https://github.com/prowler-cloud/prowler/pull/9080)
2124

25+
---
26+
2227
## [1.14.0] (Prowler 5.13.0)
2328

2429
### Added

api/src/backend/api/db_utils.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,35 @@
11
import re
22
import secrets
3+
import time
34
import uuid
45
from contextlib import contextmanager
56
from datetime import datetime, timedelta, timezone
67

8+
from celery.utils.log import get_task_logger
9+
from config.env import env
710
from django.conf import settings
811
from django.contrib.auth.models import BaseUserManager
9-
from django.db import DEFAULT_DB_ALIAS, connection, connections, models, transaction
12+
from django.db import (
13+
DEFAULT_DB_ALIAS,
14+
OperationalError,
15+
connection,
16+
connections,
17+
models,
18+
transaction,
19+
)
1020
from django_celery_beat.models import PeriodicTask
1121
from psycopg2 import connect as psycopg2_connect
1222
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
1323
from rest_framework_json_api.serializers import ValidationError
1424

15-
from api.db_router import get_read_db_alias, reset_read_db_alias, set_read_db_alias
25+
from api.db_router import (
26+
READ_REPLICA_ALIAS,
27+
get_read_db_alias,
28+
reset_read_db_alias,
29+
set_read_db_alias,
30+
)
31+
32+
logger = get_task_logger(__name__)
1633

1734
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
1835
DB_PASSWORD = (
@@ -28,6 +45,9 @@
2845
POSTGRES_TENANT_VAR = "api.tenant_id"
2946
POSTGRES_USER_VAR = "api.user_id"
3047

48+
REPLICA_MAX_ATTEMPTS = env.int("POSTGRES_REPLICA_MAX_ATTEMPTS", default=3)
49+
REPLICA_RETRY_BASE_DELAY = env.float("POSTGRES_REPLICA_RETRY_BASE_DELAY", default=0.5)
50+
3151
SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);"
3252

3353

@@ -71,24 +91,51 @@ def rls_transaction(
7191
if db_alias not in connections:
7292
db_alias = DEFAULT_DB_ALIAS
7393

74-
router_token = None
75-
try:
76-
if db_alias != DEFAULT_DB_ALIAS:
77-
router_token = set_read_db_alias(db_alias)
78-
79-
with transaction.atomic(using=db_alias):
80-
conn = connections[db_alias]
81-
with conn.cursor() as cursor:
82-
try:
83-
# just in case the value is a UUID object
84-
uuid.UUID(str(value))
85-
except ValueError:
86-
raise ValidationError("Must be a valid UUID")
87-
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
88-
yield cursor
89-
finally:
90-
if router_token is not None:
91-
reset_read_db_alias(router_token)
94+
alias = db_alias
95+
is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS
96+
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica else 1
97+
98+
for attempt in range(1, max_attempts + 1):
99+
router_token = None
100+
101+
# On final attempt, fallback to primary
102+
if attempt == max_attempts and is_replica:
103+
logger.warning(
104+
f"RLS transaction failed after {attempt - 1} attempts on replica, "
105+
f"falling back to primary DB"
106+
)
107+
alias = DEFAULT_DB_ALIAS
108+
109+
conn = connections[alias]
110+
try:
111+
if alias != DEFAULT_DB_ALIAS:
112+
router_token = set_read_db_alias(alias)
113+
114+
with transaction.atomic(using=alias):
115+
with conn.cursor() as cursor:
116+
try:
117+
# just in case the value is a UUID object
118+
uuid.UUID(str(value))
119+
except ValueError:
120+
raise ValidationError("Must be a valid UUID")
121+
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
122+
yield cursor
123+
return
124+
except OperationalError as e:
125+
# If on primary or max attempts reached, raise
126+
if not is_replica or attempt == max_attempts:
127+
raise
128+
129+
# Retry with exponential backoff
130+
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
131+
logger.info(
132+
f"RLS transaction failed on replica (attempt {attempt}/{max_attempts}), "
133+
f"retrying in {delay}s. Error: {e}"
134+
)
135+
time.sleep(delay)
136+
finally:
137+
if router_token is not None:
138+
reset_read_db_alias(router_token)
92139

93140

94141
class CustomUserManager(BaseUserManager):
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Tests for rls_transaction retry and fallback logic."""
2+
3+
import pytest
4+
from django.db import DEFAULT_DB_ALIAS
5+
from rest_framework_json_api.serializers import ValidationError
6+
7+
from api.db_utils import rls_transaction
8+
9+
10+
@pytest.mark.django_db
11+
class TestRLSTransaction:
12+
"""Simple integration tests for rls_transaction using real DB."""
13+
14+
@pytest.fixture
15+
def tenant(self, tenants_fixture):
16+
return tenants_fixture[0]
17+
18+
def test_success_on_primary(self, tenant):
19+
"""Basic: transaction succeeds on primary database."""
20+
with rls_transaction(str(tenant.id), using=DEFAULT_DB_ALIAS) as cursor:
21+
cursor.execute("SELECT 1")
22+
result = cursor.fetchone()
23+
assert result == (1,)
24+
25+
def test_invalid_uuid_raises_validation_error(self):
26+
"""Invalid UUID raises ValidationError before DB operations."""
27+
with pytest.raises(ValidationError, match="Must be a valid UUID"):
28+
with rls_transaction("not-a-uuid", using=DEFAULT_DB_ALIAS):
29+
pass
30+
31+
def test_custom_parameter_name(self, tenant):
32+
"""Test custom RLS parameter name."""
33+
custom_param = "api.custom_id"
34+
with rls_transaction(
35+
str(tenant.id), parameter=custom_param, using=DEFAULT_DB_ALIAS
36+
) as cursor:
37+
cursor.execute("SELECT current_setting(%s, true)", [custom_param])
38+
result = cursor.fetchone()
39+
assert result == (str(tenant.id),)

0 commit comments

Comments
 (0)