Skip to content

Commit

Permalink
REF/ENH: Cache control for CompositeCurves and id on Objects (#570)
Browse files Browse the repository at this point in the history
Co-authored-by: JHM Darbyshire (M1) <[email protected]>
  • Loading branch information
attack68 and attack68 authored Dec 25, 2024
1 parent 423e0e7 commit cf19c28
Show file tree
Hide file tree
Showing 9 changed files with 388 additions and 60 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dist/
downloads/
eggs/
.eggs/
.asv/
lib/
lib64/
parts/
Expand Down
5 changes: 5 additions & 0 deletions docs/source/i_whatsnew.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ email contact, see `rateslib <https://rateslib.com>`_.
introduced in v1.3.0. This should not be noticeable on round trips, i.e. using
``from_json`` on the output from ``to_json``.
(`552 <https://github.com/attack68/rateslib/pull/552>`_)
* - Refactor
- Internal ``_cache_id`` management is introduced to mutable objects such as *Curves*,
*FXRates* and *FXForwards* to allow auto-mutate detection of associated objects and ensure
consistent method results.
(`570 <https://github.com/attack68/rateslib/pull/570>`_)

1.6.0 (30th November 2024)
****************************
Expand Down
69 changes: 63 additions & 6 deletions python/rateslib/curves/curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from rateslib.calendars import CalInput, add_tenor, dcf
from rateslib.calendars.dcfs import _DCF1d
from rateslib.calendars.rs import CalTypes, get_calendar
from rateslib.default import NoInput, PlotOutput, _drb, plot
from rateslib.default import NoInput, PlotOutput, _drb, _validate_caches, plot
from rateslib.dual import ( # type: ignore[attr-defined]
Arr1dF64,
Arr1dObj,
Expand Down Expand Up @@ -254,6 +254,10 @@ def __init__( # type: ignore[no-untyped-def]

self._set_ad_order(order=ad)

@property
def _cache_id(self) -> int:
return self._cache_id_store

def __eq__(self, other: Any) -> bool:
"""Test two curves are identical"""
if type(self) is not type(other):
Expand Down Expand Up @@ -381,8 +385,7 @@ def __getitem__(self, date: datetime) -> DualTypes:
# self.spline cannot be None becuase self.t is given and it has been calibrated
val = self._op_exp(self.spline.ppev_single(date_posix)) # type: ignore[union-attr]

self._maybe_add_to_cache(date, val)
return val
return self._cached_value(date, val)

# Licence: Creative Commons - Attribution-NonCommercial-NoDerivatives 4.0 International
# Commercial use of this code, and/or copying and redistribution is prohibited.
Expand Down Expand Up @@ -589,10 +592,12 @@ def clear_cache(self) -> None:
Alternatively the curve caching as a feature can be set to *False* in ``defaults``.
"""
self._cache: dict[datetime, DualTypes] = dict()
self._cache_id_store: int = hash(uuid4())

def _maybe_add_to_cache(self, date: datetime, val: DualTypes) -> None:
def _cached_value(self, date: datetime, val: DualTypes) -> DualTypes:
if defaults.curve_caching:
self._cache[date] = val
return val

def csolve(self) -> None:
"""
Expand Down Expand Up @@ -2257,6 +2262,7 @@ def __init__(

# validate
self._validate_curve_collection()
self._clear_cache()

def _validate_curve_collection(self) -> None:
"""Perform checks to ensure CompositeCurve can exist"""
Expand Down Expand Up @@ -2303,6 +2309,7 @@ def _check_init_attribute(self, attr: str) -> None:
f"Cannot composite curves with different attributes, got for '{attr}': {attrs},",
)

@_validate_caches
def rate( # type: ignore[override]
self,
effective: datetime,
Expand Down Expand Up @@ -2385,7 +2392,10 @@ def rate( # type: ignore[override]

return _

@_validate_caches
def __getitem__(self, date: datetime) -> DualTypes:
if defaults.curve_caching and date in self._cache:
return self._cache[date]
if self._base_type == "dfs":
# will return a composited discount factor
if date == self.curves[0].node_dates[0]:
Expand All @@ -2400,20 +2410,21 @@ def __getitem__(self, date: datetime) -> DualTypes:
avg_rate = ((1.0 / curve[date]) ** (1.0 / days) - 1) / d
total_rate += avg_rate
_: DualTypes = 1.0 / (1 + total_rate * d) ** days
return _
return self._cached_value(date, _)

elif self._base_type == "values":
# will return a composited rate
_ = 0.0
for curve in self.curves:
_ += curve[date]
return _
return self._cached_value(date, _)

else:
raise TypeError(
f"Base curve type is unrecognised: {self._base_type}",
) # pragma: no cover

@_validate_caches
def shift(
self,
spread: DualTypes,
Expand Down Expand Up @@ -2460,6 +2471,7 @@ def shift(
_.collateral = _drb(None, collateral)
return _

@_validate_caches
def translate(self, start: datetime, t: bool = False) -> CompositeCurve:
"""
Create a new curve with an initial node date moved forward keeping all else
Expand All @@ -2482,8 +2494,10 @@ def translate(self, start: datetime, t: bool = False) -> CompositeCurve:
-------
CompositeCurve
"""
# cache check unnecessary since translate is constructed from up-to-date objects directly
return CompositeCurve(curves=[curve.translate(start, t) for curve in self.curves])

@_validate_caches
def roll(self, tenor: datetime | str) -> CompositeCurve:
"""
Create a new curve with its shape translated in time
Expand All @@ -2506,8 +2520,10 @@ def roll(self, tenor: datetime | str) -> CompositeCurve:
-------
CompositeCurve
"""
# cache check unnecessary since roll is constructed from up-to-date objects directly
return CompositeCurve(curves=[curve.roll(tenor) for curve in self.curves])

@_validate_caches
def index_value(self, date: datetime, interpolation: str = "daily") -> DualTypes:
"""
Calculate the accrued value of the index from the ``index_base``, which is taken
Expand All @@ -2520,6 +2536,32 @@ def index_value(self, date: datetime, interpolation: str = "daily") -> DualTypes
def _get_node_vector(self) -> Arr1dObj | Arr1dF64:
raise NotImplementedError("Instances of CompositeCurve do not have solvable variables.")

@property
def _cache_id_associate(self) -> int:
return hash(sum(curve._cache_id for curve in self.curves))

def _clear_cache(self) -> None:
"""
Clear the cache of values on a *CompositeCurve* type.
Returns
-------
None
Notes
-----
This method is called automatically when any of the composited curves
are detected to have been mutated, via their ``_cache_id``, which therefore
invalidates the cache on a composite curve.
"""
self._cache: dict[datetime, DualTypes] = dict()
self._cache_id_store = self._cache_id_associate

def _validate_cache(self) -> None:
if self._cache_id != self._cache_id_associate:
# If any of the associated curves have been mutated then the cache is invalidated
self._clear_cache()


class MultiCsaCurve(CompositeCurve):
"""
Expand Down Expand Up @@ -2561,6 +2603,7 @@ def __init__(
self.multi_csa_max_step = min(1825, multi_csa_max_step)
super().__init__(curves, id)

@_validate_caches
def rate( # type: ignore[override]
self,
effective: datetime,
Expand Down Expand Up @@ -2603,6 +2646,7 @@ def rate( # type: ignore[override]
_: DualTypes = (df_num / df_den - 1) * 100 / (d * n)
return _

@_validate_caches
def __getitem__(self, date: datetime) -> DualTypes:
# will return a composited discount factor
if date == self.curves[0].node_dates[0]:
Expand Down Expand Up @@ -2647,6 +2691,8 @@ def _get_step(step: int) -> int:
_ *= min_ratio
return _

@_validate_caches
# unnecessary because up-to-date objects are referred to directly
def translate(self, start: datetime, t: bool = False) -> MultiCsaCurve:
"""
Create a new curve with an initial node date moved forward keeping all else
Expand Down Expand Up @@ -2675,6 +2721,8 @@ def translate(self, start: datetime, t: bool = False) -> MultiCsaCurve:
multi_csa_min_step=self.multi_csa_min_step,
)

@_validate_caches
# unnecessary because up-to-date objects are referred to directly
def roll(self, tenor: datetime | str) -> MultiCsaCurve:
"""
Create a new curve with its shape translated in time
Expand Down Expand Up @@ -2703,6 +2751,7 @@ def roll(self, tenor: datetime | str) -> MultiCsaCurve:
multi_csa_min_step=self.multi_csa_min_step,
)

@_validate_caches
def shift(
self,
spread: DualTypes,
Expand Down Expand Up @@ -2875,13 +2924,21 @@ def __init__(
self.node_dates = [self.fx_forwards.immediate, self.terminal]

def __getitem__(self, date: datetime) -> DualTypes:
self.fx_forwards._validate_cache() # manually handle cache check

_1: DualTypes = self.fx_forwards._rate_with_path(self.pair, date, path=self.path)[0]
_2: DualTypes = self.fx_forwards.fx_rates_immediate._fx_array_el(
self.cash_idx, self.coll_idx
)
_3: DualTypes = self.fx_forwards.fx_curves[self.coll_pair][date]
return _1 / _2 * _3

@property
def _cache_id(self) -> int:
# the state of the cache for a ProxyCurve is fully dependent on the state of the
# cache of its contained FXForwards object, which is what derives its calculations.
return self.fx_forwards._cache_id

def to_json(self) -> str: # pragma: no cover # type: ignore
"""
Not implemented for :class:`~rateslib.fx.ProxyCurve` s.
Expand Down
22 changes: 21 additions & 1 deletion python/rateslib/default.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import os
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Any, ParamSpec, TypeVar

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -390,3 +391,22 @@ def _drb(default: Any, possible_ni: Any | NoInput) -> Any:
def _make_py_json(json: str, class_name: str) -> str:
"""Modifies the output JSON output for Rust structs wrapped by Python classes."""
return '{"Py":' + json + "}"


P = ParamSpec("P")
R = TypeVar("R")


def _validate_caches(func: Callable[P, R]) -> Callable[P, R]:
"""
Add a decorator to a class instance method to first validate the cache before performing
additional operations. If a change is detected the implemented `validate_cache` function
is responsible for resetting the cache and updating any `cache_id`s.
"""

def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
self = args[0]
self._validate_cache() # type: ignore[attr-defined]
return func(*args, **kwargs)

return wrapper
Loading

0 comments on commit cf19c28

Please sign in to comment.