@@ -28,15 +28,65 @@ class Distribution(abc.ABC):
2828 If a float, the distribution is transformed to its corresponding
2929 log distribution with the given base (e.g., Normal -> Log10Normal).
3030 If ``False``, no transformation is applied.
31+ :param trunc: The truncation points (lower, upper) of the distribution
32+ or ``None`` if the distribution is not truncated.
3133 """
3234
33- def __init__ (self , log : bool | float = False ):
35+ def __init__ (
36+ self , * , log : bool | float = False , trunc : tuple [float , float ] = None
37+ ):
3438 if log is True :
3539 log = np .exp (1 )
40+
41+ if trunc == (- np .inf , np .inf ):
42+ trunc = None
43+
44+ if trunc is not None and trunc [0 ] > trunc [1 ]:
45+ raise ValueError (
46+ "The lower truncation limit must be smaller "
47+ "than the upper truncation limit."
48+ )
49+
3650 self ._logbase = log
51+ self ._trunc = trunc
52+
53+ self ._cd_low = None
54+ self ._cd_high = None
55+ self ._truncation_normalizer = 1
56+
57+ if self ._trunc is not None :
58+ try :
59+ # the cumulative density of the transformed distribution at the
60+ # truncation limits
61+ self ._cd_low = self ._cdf_transformed_untruncated (
62+ self .trunc_low
63+ )
64+ self ._cd_high = self ._cdf_transformed_untruncated (
65+ self .trunc_high
66+ )
67+ # normalization factor for the PDF of the transformed
68+ # distribution to account for truncation
69+ self ._truncation_normalizer = 1 / (
70+ self ._cd_high - self ._cd_low
71+ )
72+ except NotImplementedError :
73+ pass
74+
75+ @property
76+ def trunc_low (self ) -> float :
77+ """The lower truncation limit of the transformed distribution."""
78+ return self ._trunc [0 ] if self ._trunc else - np .inf
79+
80+ @property
81+ def trunc_high (self ) -> float :
82+ """The upper truncation limit of the transformed distribution."""
83+ return self ._trunc [1 ] if self ._trunc else np .inf
3784
38- def _undo_log (self , x : np .ndarray | float ) -> np .ndarray | float :
39- """Undo the log transformation.
85+ def _exp (self , x : np .ndarray | float ) -> np .ndarray | float :
86+ """Exponentiate / undo the log transformation according.
87+
88+ Exponentiate if a log transformation is applied to the distribution.
89+ Otherwise, return the input.
4090
4191 :param x: The sample to transform.
4292 :return: The transformed sample
@@ -45,9 +95,12 @@ def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float:
4595 return x
4696 return self ._logbase ** x
4797
48- def _apply_log (self , x : np .ndarray | float ) -> np .ndarray | float :
98+ def _log (self , x : np .ndarray | float ) -> np .ndarray | float :
4999 """Apply the log transformation.
50100
101+ Compute the log of x with the specified base if a log transformation
102+ is applied to the distribution. Otherwise, return the input.
103+
51104 :param x: The value to transform.
52105 :return: The transformed value.
53106 """
@@ -61,12 +114,17 @@ def sample(self, shape=None) -> np.ndarray:
61114 :param shape: The shape of the sample.
62115 :return: A sample from the distribution.
63116 """
64- sample = self ._sample (shape )
65- return self ._undo_log (sample )
117+ sample = (
118+ self ._exp (self ._sample (shape ))
119+ if self ._trunc is None
120+ else self ._inverse_transform_sample (shape )
121+ )
122+
123+ return sample
66124
67125 @abc .abstractmethod
68126 def _sample (self , shape = None ) -> np .ndarray :
69- """Sample from the underlying distribution.
127+ """Sample from the underlying distribution, accounting for truncation .
70128
71129 :param shape: The shape of the sample.
72130 :return: A sample from the underlying distribution,
@@ -85,7 +143,11 @@ def pdf(self, x):
85143 chain_rule_factor = (
86144 (1 / (x * np .log (self ._logbase ))) if self ._logbase else 1
87145 )
88- return self ._pdf (self ._apply_log (x )) * chain_rule_factor
146+ return (
147+ self ._pdf (self ._log (x ))
148+ * chain_rule_factor
149+ * self ._truncation_normalizer
150+ )
89151
90152 @abc .abstractmethod
91153 def _pdf (self , x ):
@@ -104,13 +166,71 @@ def logbase(self) -> bool | float:
104166 """
105167 return self ._logbase
106168
169+ def cdf (self , x ):
170+ """Cumulative distribution function at x.
171+
172+ :param x: The value at which to evaluate the CDF.
173+ :return: The value of the CDF at ``x``.
174+ """
175+ return self ._cdf_transformed_untruncated (x ) - self ._cd_low
176+
177+ def _cdf_transformed_untruncated (self , x ):
178+ """Cumulative distribution function of the transformed, but untruncated
179+ distribution at x.
180+
181+ :param x: The value at which to evaluate the CDF.
182+ :return: The value of the CDF at ``x``.
183+ """
184+ return self ._cdf_untransformed_untruncated (self ._log (x ))
185+
186+ def _cdf_untransformed_untruncated (self , x ):
187+ """Cumulative distribution function of the underlying
188+ (untransformed, untruncated) distribution at x.
189+
190+ :param x: The value at which to evaluate the CDF.
191+ :return: The value of the CDF at ``x``.
192+ """
193+ raise NotImplementedError
194+
195+ def _ppf_untransformed_untruncated (self , q ):
196+ """Percent point function of the underlying
197+ (untransformed, untruncated) distribution at q.
198+
199+ :param q: The quantile at which to evaluate the PPF.
200+ :return: The value of the PPF at ``q``.
201+ """
202+ raise NotImplementedError
203+
204+ def _ppf_transformed_untruncated (self , q ):
205+ """Percent point function of the transformed, but untruncated
206+ distribution at q.
207+
208+ :param q: The quantile at which to evaluate the PPF.
209+ :return: The value of the PPF at ``q``.
210+ """
211+ return self ._exp (self ._ppf_untransformed_untruncated (q ))
212+
213+ def _inverse_transform_sample (self , shape ):
214+ """Generate an inverse transform sample from the transformed and
215+ truncated distribution.
216+
217+ :param shape: The shape of the sample.
218+ :return: The sample.
219+ """
220+ uniform_sample = np .random .uniform (
221+ low = self ._cd_low , high = self ._cd_high , size = shape
222+ )
223+ return self ._ppf_transformed_untruncated (uniform_sample )
224+
107225
108226class Normal (Distribution ):
109227 """A (log-)normal distribution.
110228
111229 :param loc: The location parameter of the distribution.
112230 :param scale: The scale parameter of the distribution.
113- :param truncation: The truncation limits of the distribution.
231+ :param trunc: The truncation limits of the distribution.
232+ ``None`` if the distribution is not truncated. The truncation limits
233+ are the truncation limits of the transformed distribution.
114234 :param log: If ``True``, the distribution is transformed to a log-normal
115235 distribution. If a float, the distribution is transformed to a
116236 log-normal distribution with the given base.
@@ -124,19 +244,15 @@ def __init__(
124244 self ,
125245 loc : float ,
126246 scale : float ,
127- truncation : tuple [float , float ] | None = None ,
247+ trunc : tuple [float , float ] | None = None ,
128248 log : bool | float = False ,
129249 ):
130- super ().__init__ (log = log )
131250 self ._loc = loc
132251 self ._scale = scale
133- self ._truncation = truncation
134-
135- if truncation is not None :
136- raise NotImplementedError ("Truncation is not yet implemented." )
252+ super ().__init__ (log = log , trunc = trunc )
137253
138254 def __repr__ (self ):
139- trunc = f", truncation ={ self ._truncation } " if self ._truncation else ""
255+ trunc = f", trunc ={ self ._trunc } " if self ._trunc else ""
140256 log = f", log={ self ._logbase } " if self ._logbase else ""
141257 return f"Normal(loc={ self ._loc } , scale={ self ._scale } { trunc } { log } )"
142258
@@ -146,6 +262,12 @@ def _sample(self, shape=None):
146262 def _pdf (self , x ):
147263 return norm .pdf (x , loc = self ._loc , scale = self ._scale )
148264
265+ def _cdf_untransformed_untruncated (self , x ):
266+ return norm .cdf (x , loc = self ._loc , scale = self ._scale )
267+
268+ def _ppf_untransformed_untruncated (self , q ):
269+ return norm .ppf (q , loc = self ._loc , scale = self ._scale )
270+
149271 @property
150272 def loc (self ):
151273 """The location parameter of the underlying distribution."""
@@ -177,9 +299,9 @@ def __init__(
177299 * ,
178300 log : bool | float = False ,
179301 ):
180- super ().__init__ (log = log )
181302 self ._low = low
182303 self ._high = high
304+ super ().__init__ (log = log )
183305
184306 def __repr__ (self ):
185307 log = f", log={ self ._logbase } " if self ._logbase else ""
@@ -191,13 +313,21 @@ def _sample(self, shape=None):
191313 def _pdf (self , x ):
192314 return uniform .pdf (x , loc = self ._low , scale = self ._high - self ._low )
193315
316+ def _cdf_untransformed_untruncated (self , x ):
317+ return uniform .cdf (x , loc = self ._low , scale = self ._high - self ._low )
318+
319+ def _ppf_untransformed_untruncated (self , q ):
320+ return uniform .ppf (q , loc = self ._low , scale = self ._high - self ._low )
321+
194322
195323class Laplace (Distribution ):
196324 """A (log-)Laplace distribution.
197325
198326 :param loc: The location parameter of the distribution.
199327 :param scale: The scale parameter of the distribution.
200- :param truncation: The truncation limits of the distribution.
328+ :param trunc: The truncation limits of the distribution.
329+ ``None`` if the distribution is not truncated. The truncation limits
330+ are the truncation limits of the transformed distribution.
201331 :param log: If ``True``, the distribution is transformed to a log-Laplace
202332 distribution. If a float, the distribution is transformed to a
203333 log-Laplace distribution with the given base.
@@ -211,18 +341,15 @@ def __init__(
211341 self ,
212342 loc : float ,
213343 scale : float ,
214- truncation : tuple [float , float ] | None = None ,
344+ trunc : tuple [float , float ] | None = None ,
215345 log : bool | float = False ,
216346 ):
217- super ().__init__ (log = log )
218347 self ._loc = loc
219348 self ._scale = scale
220- self ._truncation = truncation
221- if truncation is not None :
222- raise NotImplementedError ("Truncation is not yet implemented." )
349+ super ().__init__ (log = log , trunc = trunc )
223350
224351 def __repr__ (self ):
225- trunc = f", truncation ={ self ._truncation } " if self ._truncation else ""
352+ trunc = f", trunc ={ self ._trunc } " if self ._trunc else ""
226353 log = f", log={ self ._logbase } " if self ._logbase else ""
227354 return f"Laplace(loc={ self ._loc } , scale={ self ._scale } { trunc } { log } )"
228355
@@ -232,6 +359,12 @@ def _sample(self, shape=None):
232359 def _pdf (self , x ):
233360 return laplace .pdf (x , loc = self ._loc , scale = self ._scale )
234361
362+ def _cdf_untransformed_untruncated (self , x ):
363+ return laplace .cdf (x , loc = self ._loc , scale = self ._scale )
364+
365+ def _ppf_untransformed_untruncated (self , q ):
366+ return laplace .ppf (q , loc = self ._loc , scale = self ._scale )
367+
235368 @property
236369 def loc (self ):
237370 """The location parameter of the underlying distribution."""
0 commit comments