Add GrassiaIIGeometric Distribution #528

Open
wants to merge 35 commits into base: main
Changes from 1 commit
35 commits
269dd75
dist and rv init commit
ColtAllen Mar 29, 2025
b264161
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Apr 11, 2025
d734c68
docstrings
ColtAllen Apr 15, 2025
71bd632
Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-…
ColtAllen Apr 15, 2025
48e93f3
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Apr 15, 2025
93c4a60
unit tests
ColtAllen Apr 20, 2025
d2e72b5
alpha min value
ColtAllen Apr 20, 2025
8685005
revert alpha lim
ColtAllen Apr 21, 2025
026f182
small lam value tests
ColtAllen Apr 22, 2025
d12dd0b
ruff formatting
ColtAllen Apr 22, 2025
bcd9cac
TODOs
ColtAllen Apr 22, 2025
78be107
WIP add covar support to RV
ColtAllen Apr 22, 2025
f3ae359
Merge branch 'main' into grassia2geo-dist
ColtAllen Jun 20, 2025
8a30459
WIP time indexing
ColtAllen Jun 20, 2025
7c7afc8
WIP time indexing
ColtAllen Jun 20, 2025
fa9c1ec
Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-…
ColtAllen Jun 20, 2025
b957333
WIP symbolic indexing
ColtAllen Jun 20, 2025
d0c1d98
delete test_simple.py
ColtAllen Jun 20, 2025
264c55e
fix symbolic indexing errors
ColtAllen Jul 11, 2025
05e7c55
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Jul 11, 2025
0fa3390
clean up cursor code
ColtAllen Jul 11, 2025
5baa6f7
warn for ndims deprecation
ColtAllen Jul 11, 2025
a715ec7
clean up comments and final TODO
ColtAllen Jul 11, 2025
f3c0f29
remove ndims deprecation and extraneous code
ColtAllen Jul 11, 2025
a232e4c
revert changes to irrelevant test
ColtAllen Jul 12, 2025
ffc059f
remove time_covariate_vector default args
ColtAllen Jul 12, 2025
1d41eb7
revert remaining changes in irrelevant tests
ColtAllen Jul 12, 2025
47ad523
remove test_sampling_consistency
ColtAllen Jul 12, 2025
5b77263
checkpoint commit for log_cdf and test frameworks
ColtAllen Jul 12, 2025
eb7222f
checkpoint commit for log_cdf and test frameworks
ColtAllen Jul 12, 2025
b34e3d8
make C_t external function, code cleanup
ColtAllen Jul 12, 2025
9803321
rng_fn cleanup
ColtAllen Jul 13, 2025
5ff6853
WIP test frameworks
ColtAllen Jul 13, 2025
63a0b10
inverse cdf
ColtAllen Jul 15, 2025
932a046
covariate pos constraint and WIP RV
ColtAllen Jul 15, 2025

10 changes: 6 additions & 4 deletions pymc_extras/distributions/discrete.py
@@ -425,10 +425,12 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
         # Calculate exp(time_covariate_vector) for all samples
         exp_time_covar = np.exp(
             time_covariate_vector
-        ).mean()  # must average over time for correct broadcasting
+        ).mean()  # Approximation required to return a t-scalar from a covariate vector
         lam_covar = lam * exp_time_covar

-        samples = np.ceil(rng.exponential(size=size) / lam_covar)
+        # Take uniform draws from the inverse CDF
+        u = rng.uniform(size=size)
+        samples = np.ceil(np.log(1 - u) / (-lam_covar))
Member

Are you sure about log(1 - u)? When I worked it out on paper, it seemed like it should just be log(u). Alternatively, you could keep the exponential draw, in which case the minus sign on lam_covar can be skipped.
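A quick numerical check of the point raised here, as a sketch rather than the PR's code: if u is Uniform(0, 1) then so is 1 - u, so -log(1 - u) and -log(u) are both Exponential(1) draws, and the three samplers below should agree in distribution. The value of `lam_covar`, the seed, and the sample size are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
lam_covar = 0.3  # hypothetical rate, chosen only for illustration
n = 200_000

u = rng.uniform(size=n)  # draws lie in [0, 1), so 1 - u is never 0
via_log1m = np.ceil(np.log(1 - u) / (-lam_covar))       # form added in the diff
via_log = np.ceil(-np.log(u) / lam_covar)               # log(u) form
via_exp = np.ceil(rng.exponential(size=n) / lam_covar)  # form removed in the diff

# Each sampler is geometric with success probability 1 - exp(-lam_covar),
# so every sample mean should sit near 1 / (1 - exp(-lam_covar)).
expected_mean = 1.0 / (1.0 - np.exp(-lam_covar))
means = [via_log1m.mean(), via_log.mean(), via_exp.mean()]
print(expected_mean, means)
```

One practical difference: because `rng.uniform` samples the half-open interval [0, 1), the log(1 - u) form can never evaluate log(0), whereas log(u) can in principle.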

Author

I'm rather new to deriving inverse CDFs (and I did this particular derivation at 2am last night), but here's my general understanding:

cdf = 1 - S(t) = 1 - exp(-lam * C(t))

# set the cdf equal to a uniform draw u and solve for t
u = 1 - exp(-lam * C(t))

1 - u = exp(-lam * C(t))

log(1 - u) = -lam * C(t)

C(t) = sum(np.exp(time_covariate_vector)) # sum over t dim

# t is an index, so we must approximate to get a solvable value
C(t) ≈ t * exp(time_covariate_vector.mean())

log(1 - u) = -lam * t * exp(time_covariate_vector.mean())

t = log(1 - u) / (-lam * exp(time_covariate_vector.mean()))
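To make the approximation step in this derivation concrete, here is a small sketch comparing the exact cumulative sum for C(t) with two closed-form stand-ins: the exp-of-mean form written in the derivation, and the mean-of-exp form that appears in `rng_fn` in the diff. The covariate values are hypothetical and chosen only for illustration.

```python
import numpy as np

# Hypothetical covariate path over 5 periods, for illustration only
time_covariate_vector = np.array([0.0, 0.5, -0.2, 1.0, 0.3])
t = np.arange(1, len(time_covariate_vector) + 1)

# Exact cumulative covariate effect: C(t) = sum_{i<=t} exp(x_i)
C_exact = np.cumsum(np.exp(time_covariate_vector))

# Approximation written in the derivation: C(t) ~ t * exp(mean(x))
C_exp_of_mean = t * np.exp(time_covariate_vector.mean())

# Variant computed in rng_fn in the diff: t * mean(exp(x))
C_mean_of_exp = t * np.exp(time_covariate_vector).mean()

for row in zip(t, C_exact, C_exp_of_mean, C_mean_of_exp):
    print(row)
```

The mean-of-exp variant matches the exact sum at t equal to the vector length by construction, while the exp-of-mean form can never exceed it there (Jensen's inequality). Neither tracks the exact C(t) at intermediate t unless the covariates are constant.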

Member

t is a vector so I don't think that makes sense. I'm not sure you can even invert the CDF of a multivariate distribution, because it won't be 1-1 in general.

Author

@ColtAllen ColtAllen Jul 15, 2025

t == len(C(t)). I'm not happy with the approximation, but it was the only way I could think of to use the inverse CDF.

The only alternative is t geometric draws for each sample covariate vector. To provide time context, the vector has to be aggregated in some way, be it sum, mean, or product. I might just have to start experimenting with PPCs in a notebook to see which aggregation option seems most viable.

Member

> The only alternative is t geometric draws for each sample covariate vector.

From your description of the situation, that sounds the most reasonable.
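One way the per-period approach discussed in this thread might look, as a rough NumPy sketch and not the PR's implementation. It assumes S(t) = exp(-lam * sum_{i<=t} exp(x_i)), as in the derivation above, and that lam is Gamma distributed with shape r and rate alpha. The helper name `draw_event_times`, the parameter values, and the convention of returning len(x) + 1 when no event occurs within the vector are all hypothetical.

```python
import numpy as np

def draw_event_times(rng, lam, time_covariate_vector):
    """Draw one event time per entry of lam via per-period Bernoulli trials.

    Given survival so far, the chance of the event in period i is
    1 - exp(-lam * exp(x_i)). Returns len(x) + 1 when no event occurs
    within the covariate vector (illustrative censoring convention).
    """
    lam = np.asarray(lam, dtype=float)[:, None]         # shape (n, 1)
    x = np.asarray(time_covariate_vector, dtype=float)  # shape (T,)
    p = 1.0 - np.exp(-lam * np.exp(x))                  # (n, T) per-period hazards
    hit = rng.uniform(size=p.shape) < p                 # one Bernoulli trial per period
    first = hit.argmax(axis=1) + 1                      # 1-based index of first success
    return np.where(hit.any(axis=1), first, x.shape[0] + 1)

rng = np.random.default_rng(1)
r, alpha = 2.0, 5.0  # hypothetical parameter values
lam = rng.gamma(shape=r, scale=1.0 / alpha, size=100_000)  # assumed Gamma(r, rate=alpha)
samples = draw_event_times(rng, lam, np.zeros(50))

# With all covariates at zero, P(T > 1) should be near (alpha / (alpha + 1)) ** r
print((samples > 1).mean(), (alpha / (alpha + 1)) ** r)
```

This avoids aggregating the covariate vector at all, at the cost of T uniform draws per sample instead of one.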


return samples

@@ -581,5 +583,5 @@ def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.Te
     # If t_idx exceeds length of time_covariate_vector, use last value
     max_idx = pt.shape(time_covariate_vector)[0] - 1
     safe_idx = pt.minimum(t_idx, max_idx)
-    covariate_value = time_covariate_vector[safe_idx]
-    return t * pt.exp(covariate_value)
+    covariate_value = time_covariate_vector[..., safe_idx]
+    return pt.exp(covariate_value).sum()
Member

sum(-1)? You should test with batched values to be sure you get the right things out. I also see you do pt.shape(...)[0] instead of pt.shape(...)[-1] above, so I think more is missing before this actually works with batch dimensions.
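The batch-dimension concern can be illustrated with a NumPy stand-in. This is not the PR's PyTensor code: the function name `C_t_batched` and the input values are hypothetical, and it assumes the time axis is the last axis of the covariate array.

```python
import numpy as np

def C_t_batched(t, time_covariate_vector):
    """NumPy stand-in for a batch-aware C(t); hypothetical helper.

    time_covariate_vector has shape (..., T); t is a positive integer.
    Returns sum_{i<=t} exp(x_i) along the last (time) axis, shape (...).
    """
    x = np.asarray(time_covariate_vector, dtype=float)
    n_periods = x.shape[-1]  # time axis is last, hence [-1] rather than [0]
    # Past the end of the vector, reuse the last covariate value
    idx = np.minimum(np.arange(t), n_periods - 1)
    return np.exp(x[..., idx]).sum(axis=-1)  # reduce over time only

batch = np.array([[0.0, 0.0, 0.0],
                  [1.0, 1.0, 1.0]])  # 2 batch rows, 3 time periods
out = C_t_batched(2, batch)           # one value per batch row
collapsed = np.exp(batch[..., :2]).sum()  # bare sum() mixes the batch rows
print(out.shape, np.ndim(collapsed))
```

A bare `.sum()` reduces over every axis and returns a single scalar for the whole batch, while `.sum(axis=-1)` keeps one value per batch row.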

2 changes: 1 addition & 1 deletion tests/distributions/test_discrete.py
@@ -249,7 +249,7 @@ def test_random_edge_cases(self):
         # Test with small r and large alpha values
         r_vals = [0.1, 0.5]
         alpha_vals = [5.0, 10.0]
-        time_cov_vals = [0.0, 1.0]
+        time_cov_vals = [[0.0], [1.0]]

         for r in r_vals:
             for alpha in alpha_vals: