|
1 | 1 | import re |
2 | 2 | import secrets |
| 3 | +import time |
3 | 4 | import uuid |
4 | 5 | from contextlib import contextmanager |
5 | 6 | from datetime import datetime, timedelta, timezone |
6 | 7 |
|
| 8 | +from celery.utils.log import get_task_logger |
| 9 | +from config.env import env |
7 | 10 | from django.conf import settings |
8 | 11 | from django.contrib.auth.models import BaseUserManager |
9 | | -from django.db import DEFAULT_DB_ALIAS, connection, connections, models, transaction |
| 12 | +from django.db import ( |
| 13 | + DEFAULT_DB_ALIAS, |
| 14 | + OperationalError, |
| 15 | + connection, |
| 16 | + connections, |
| 17 | + models, |
| 18 | + transaction, |
| 19 | +) |
10 | 20 | from django_celery_beat.models import PeriodicTask |
11 | 21 | from psycopg2 import connect as psycopg2_connect |
12 | 22 | from psycopg2.extensions import AsIs, new_type, register_adapter, register_type |
13 | 23 | from rest_framework_json_api.serializers import ValidationError |
14 | 24 |
|
15 | | -from api.db_router import get_read_db_alias, reset_read_db_alias, set_read_db_alias |
| 25 | +from api.db_router import ( |
| 26 | + READ_REPLICA_ALIAS, |
| 27 | + get_read_db_alias, |
| 28 | + reset_read_db_alias, |
| 29 | + set_read_db_alias, |
| 30 | +) |
| 31 | + |
| 32 | +logger = get_task_logger(__name__) |
16 | 33 |
|
17 | 34 | DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test" |
18 | 35 | DB_PASSWORD = ( |
|
28 | 45 | POSTGRES_TENANT_VAR = "api.tenant_id" |
29 | 46 | POSTGRES_USER_VAR = "api.user_id" |
30 | 47 |
|
| 48 | +REPLICA_MAX_ATTEMPTS = env.int("POSTGRES_REPLICA_MAX_ATTEMPTS", default=3) |
| 49 | +REPLICA_RETRY_BASE_DELAY = env.float("POSTGRES_REPLICA_RETRY_BASE_DELAY", default=0.5) |
| 50 | + |
31 | 51 | SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);" |
32 | 52 |
|
33 | 53 |
|
@@ -71,24 +91,51 @@ def rls_transaction( |
71 | 91 | if db_alias not in connections: |
72 | 92 | db_alias = DEFAULT_DB_ALIAS |
73 | 93 |
|
74 | | - router_token = None |
75 | | - try: |
76 | | - if db_alias != DEFAULT_DB_ALIAS: |
77 | | - router_token = set_read_db_alias(db_alias) |
78 | | - |
79 | | - with transaction.atomic(using=db_alias): |
80 | | - conn = connections[db_alias] |
81 | | - with conn.cursor() as cursor: |
82 | | - try: |
83 | | - # just in case the value is a UUID object |
84 | | - uuid.UUID(str(value)) |
85 | | - except ValueError: |
86 | | - raise ValidationError("Must be a valid UUID") |
87 | | - cursor.execute(SET_CONFIG_QUERY, [parameter, value]) |
88 | | - yield cursor |
89 | | - finally: |
90 | | - if router_token is not None: |
91 | | - reset_read_db_alias(router_token) |
| 94 | + alias = db_alias |
| 95 | + is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS |
| 96 | + max_attempts = REPLICA_MAX_ATTEMPTS if is_replica else 1 |
| 97 | + |
| 98 | + for attempt in range(1, max_attempts + 1): |
| 99 | + router_token = None |
| 100 | + |
| 101 | + # On final attempt, fallback to primary |
| 102 | + if attempt == max_attempts and is_replica: |
| 103 | + logger.warning( |
| 104 | + f"RLS transaction failed after {attempt - 1} attempts on replica, " |
| 105 | + f"falling back to primary DB" |
| 106 | + ) |
| 107 | + alias = DEFAULT_DB_ALIAS |
| 108 | + |
| 109 | + conn = connections[alias] |
| 110 | + try: |
| 111 | + if alias != DEFAULT_DB_ALIAS: |
| 112 | + router_token = set_read_db_alias(alias) |
| 113 | + |
| 114 | + with transaction.atomic(using=alias): |
| 115 | + with conn.cursor() as cursor: |
| 116 | + try: |
| 117 | + # just in case the value is a UUID object |
| 118 | + uuid.UUID(str(value)) |
| 119 | + except ValueError: |
| 120 | + raise ValidationError("Must be a valid UUID") |
| 121 | + cursor.execute(SET_CONFIG_QUERY, [parameter, value]) |
| 122 | + yield cursor |
| 123 | + return |
| 124 | + except OperationalError as e: |
| 125 | + # If on primary or max attempts reached, raise |
| 126 | + if not is_replica or attempt == max_attempts: |
| 127 | + raise |
| 128 | + |
| 129 | + # Retry with exponential backoff |
| 130 | + delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1)) |
| 131 | + logger.info( |
| 132 | + f"RLS transaction failed on replica (attempt {attempt}/{max_attempts}), " |
| 133 | + f"retrying in {delay}s. Error: {e}" |
| 134 | + ) |
| 135 | + time.sleep(delay) |
| 136 | + finally: |
| 137 | + if router_token is not None: |
| 138 | + reset_read_db_alias(router_token) |
92 | 139 |
|
93 | 140 |
|
94 | 141 | class CustomUserManager(BaseUserManager): |
|
0 commit comments