Skip to content

Commit 7963c8f

Browse files
JustinPan-googOrbax Authors
authored andcommitted
Add a MemoryRegulator for dynamic memory limit adjustment.
PiperOrigin-RevId: 888856914
1 parent 09d2982 commit 7963c8f

File tree

2 files changed

+530
-0
lines changed

2 files changed

+530
-0
lines changed
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Manages memory limits for checkpoint operations via PID control and profiling."""
16+
17+
from __future__ import annotations
18+
19+
import abc
20+
import dataclasses
21+
from typing import Optional
22+
23+
# CONSTANT
24+
_BYTES_TO_GIB = 1024.0**3
25+
26+
27+
class MemoryProfiler(abc.ABC):
28+
"""A memory profiler providing feedback for memory regulation."""
29+
30+
def __init__(self):
31+
self._peak_usage_bytes = 0
32+
33+
@property
34+
def peak_usage_bytes(self) -> int:
35+
return self._peak_usage_bytes
36+
37+
@abc.abstractmethod
38+
def profiler_start(self) -> None:
39+
"""Starts the memory profiler."""
40+
raise NotImplementedError
41+
42+
@abc.abstractmethod
43+
def profiler_end(self) -> None:
44+
"""Stops the profiler."""
45+
raise NotImplementedError
46+
47+
@property
48+
def peak_usage_gib(self) -> float:
49+
"""Returns peak memory usage in GiB."""
50+
return self.peak_usage_bytes / _BYTES_TO_GIB
51+
52+
@abc.abstractmethod
53+
def get_prev_blocking_time_sec(self) -> float:
54+
"""Returns the previous iteration's blocking time in seconds."""
55+
raise NotImplementedError
56+
57+
@abc.abstractmethod
58+
def get_expected_surge_gib(self) -> float:
59+
"""Returns the expected memory surge for the next iteration in GiB."""
60+
raise NotImplementedError
61+
62+
63+
_profiler: Optional[MemoryProfiler] = None
64+
65+
66+
def register_memory_profiler(profiler: Optional[MemoryProfiler]) -> None:
67+
global _profiler
68+
_profiler = profiler
69+
70+
71+
def profiler_start() -> None:
72+
if _profiler:
73+
_profiler.profiler_start()
74+
75+
76+
def profiler_end() -> None:
77+
if _profiler:
78+
_profiler.profiler_end()
79+
80+
81+
def profiler_peak_usage_gib() -> float:
82+
if _profiler:
83+
return _profiler.peak_usage_gib
84+
return 0.0
85+
86+
87+
def get_prev_blocking_time_sec() -> float:
88+
if _profiler:
89+
return _profiler.get_prev_blocking_time_sec()
90+
return 0.0
91+
92+
93+
def get_expected_surge_gib() -> float:
94+
if _profiler:
95+
return _profiler.get_expected_surge_gib()
96+
return 0.0
97+
98+
99+
@dataclasses.dataclass
100+
class MemoryRegulator:
101+
"""Regulates maximum concurrent memory usage using a PID controller based on peak memory usage feedback.
102+
103+
For setting up the coefficients, we have the following guidelines:
104+
105+
| Coefficient | Suggested Range | Justification |
106+
| :--- | :--- | :--- |
107+
| kp | 0.30 - 0.60 | A moderate Kp (~0.4) safely scales the limit (e.g.,
108+
opening by 4 GiB for a 10 GiB gap). |
109+
| ki | <= 0.08 | Must be kept low to mitigate integral windup. |
110+
| kd | 0.10 - 0.30 | Acts as a brake against the rate of growth. A higher Kd
111+
(e.g., 0.2) can overpower Kp during rapid memory spikes, capping the limit to
112+
ensure a soft landing at the target. |
113+
114+
Attributes:
115+
max_memory_limit_gib: The maximum host memory limit in GiB
116+
target_ratio: The target ratio of host memory limit to use for peak memory
117+
min_memory_limit_gib: The minimum memory limit in GiB allowed for regulation
118+
kp: Proportional coefficient
119+
ki: Integral coefficient
120+
kd: Derivative coefficient
121+
integral: Integral term accumulated over time
122+
prev_error: Error term from the previous step
123+
integral_windup_limit: Upper and lower bounds for the integral term to
124+
prevent windup
125+
"""
126+
127+
max_memory_limit_gib: float
128+
target_ratio: float = 0.80
129+
min_memory_limit_gib: float = 10.0
130+
kp: float = 0.4
131+
ki: float = 0.05
132+
kd: float = 0.1
133+
134+
integral: float = dataclasses.field(init=False)
135+
prev_error: float = dataclasses.field(init=False)
136+
_prev_expected_surge_gib: float = dataclasses.field(init=False)
137+
integral_windup_limit: float = dataclasses.field(init=False)
138+
139+
def __post_init__(self):
140+
"""Post-initialization validation and field setup."""
141+
self.integral = 0.0
142+
self.prev_error = 0.0
143+
self._prev_expected_surge_gib = 0.0
144+
self.integral_windup_limit = 50.0
145+
146+
if self.max_memory_limit_gib <= 0:
147+
raise ValueError(
148+
'max_memory_limit_gib must be positive, got'
149+
f' {self.max_memory_limit_gib}'
150+
)
151+
if self.min_memory_limit_gib <= 0:
152+
raise ValueError(
153+
'min_memory_limit_gib must be positive, got'
154+
f' {self.min_memory_limit_gib}'
155+
)
156+
if (
157+
self.min_memory_limit_gib
158+
>= self.max_memory_limit_gib * self.target_ratio
159+
):
160+
raise ValueError(
161+
'min_memory_limit_gib must be less than target memory ('
162+
f'{self.max_memory_limit_gib * self.target_ratio} GiB)'
163+
)
164+
165+
def get_next_memory_limit(
166+
self,
167+
current_limit_gib: float,
168+
peak_memory_usage_gib: float,
169+
blocking_time_sec: float, # pylint: disable=unused-argument
170+
expected_surge_gib: float = 0.0,
171+
) -> float:
172+
"""Calculates the next memory limit using PID control and expected surge data.
173+
174+
The PID controller adjusts the memory limit based on feedback from
175+
`peak_memory_usage_gib` to guide usage towards
176+
`max_memory_limit_gib * target_ratio`.
177+
178+
If `expected_surge_gib` is positive, it signals an anticipated temporary
179+
increase in memory consumption. The regulator preemptively reduces the
180+
memory limit by this amount to create headroom and prevent potential OOMs.
181+
During such a surge, PID integral and error history are not updated, and
182+
the PID controller is prevented from increasing the limit. When the
183+
surge passes, `expected_surge_gib` should be reset to 0, and the memory
184+
limit will be restored.
185+
186+
Args:
187+
current_limit_gib: The current memory limit in GiB.
188+
peak_memory_usage_gib: The peak memory usage observed in GiB since the
189+
last adjustment.
190+
blocking_time_sec: The time in seconds that consumers were blocked waiting
191+
for memory in the last interval. Currently unused.
192+
expected_surge_gib: The anticipated memory surge in GiB. If 0, no surge
193+
is expected.
194+
195+
Returns:
196+
The calculated memory limit for the next interval in GiB.
197+
"""
198+
effective_host_limit = self.max_memory_limit_gib
199+
target_mem_gib = effective_host_limit * self.target_ratio
200+
201+
error_gib = target_mem_gib - peak_memory_usage_gib
202+
max_error_gib = effective_host_limit - peak_memory_usage_gib
203+
204+
# --- STANDARD PID MATH ---
205+
p_term = self.kp * error_gib
206+
207+
i_term = self.ki * self.integral
208+
d_term = self.kd * (error_gib - self.prev_error)
209+
210+
# Update history ONLY if not in an expected surge.
211+
# This preserves history for when the surge ends.
212+
if expected_surge_gib == 0:
213+
self.integral += error_gib
214+
self.integral = max(
215+
-self.integral_windup_limit,
216+
min(self.integral_windup_limit, self.integral),
217+
)
218+
self.prev_error = error_gib
219+
220+
base_adjustment = p_term + i_term + d_term
221+
222+
# --- CUSTOM LOGIC ---
223+
if max_error_gib < 0:
224+
# Prioritize memory space.
225+
# Force a reduction if we are over the hard limit, even if
226+
# the PID controller suggests an increase (e.g. due to recovery).
227+
# We take the more aggressive of either the PID drop or the raw overflow.
228+
adjustment = min(max_error_gib, base_adjustment)
229+
else:
230+
adjustment = base_adjustment
231+
232+
# Apply and clamp to hardware limits
233+
# Bypass PID adjustment if in an active surge to honor the manual drop.
234+
if expected_surge_gib > 0:
235+
# If in a surge, we allow the PID to reduce the limit further (negative
236+
# adjustment) but not increase it. This prevents "double counting"
237+
# the surge headroom while still allowing throttling.
238+
adjustment = min(0.0, adjustment)
239+
240+
# Surge delta handles immediate jump down and up.
241+
surge_delta = expected_surge_gib - self._prev_expected_surge_gib
242+
self._prev_expected_surge_gib = expected_surge_gib
243+
244+
new_limit_gib = current_limit_gib + adjustment - surge_delta
245+
new_limit_gib = max(
246+
self.min_memory_limit_gib, min(self.max_memory_limit_gib, new_limit_gib)
247+
)
248+
249+
return new_limit_gib

0 commit comments

Comments
 (0)