Skip to content

Add GrassiaIIGeometric Distribution #528

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
269dd75
dist and rv init commit
ColtAllen Mar 29, 2025
b264161
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Apr 11, 2025
d734c68
docstrings
ColtAllen Apr 15, 2025
71bd632
Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-…
ColtAllen Apr 15, 2025
48e93f3
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Apr 15, 2025
93c4a60
unit tests
ColtAllen Apr 20, 2025
d2e72b5
alpha min value
ColtAllen Apr 20, 2025
8685005
revert alpha lim
ColtAllen Apr 21, 2025
026f182
small lam value tests
ColtAllen Apr 22, 2025
d12dd0b
ruff formatting
ColtAllen Apr 22, 2025
bcd9cac
TODOs
ColtAllen Apr 22, 2025
78be107
WIP add covar support to RV
ColtAllen Apr 22, 2025
f3ae359
Merge branch 'main' into grassia2geo-dist
ColtAllen Jun 20, 2025
8a30459
WIP time indexing
ColtAllen Jun 20, 2025
7c7afc8
WIP time indexing
ColtAllen Jun 20, 2025
fa9c1ec
Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-…
ColtAllen Jun 20, 2025
b957333
WIP symbolic indexing
ColtAllen Jun 20, 2025
d0c1d98
delete test_simple.py
ColtAllen Jun 20, 2025
264c55e
fix symbolic indexing errors
ColtAllen Jul 11, 2025
05e7c55
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Jul 11, 2025
0fa3390
clean up cursor code
ColtAllen Jul 11, 2025
5baa6f7
warn for ndims deprecation
ColtAllen Jul 11, 2025
a715ec7
clean up comments and final TODO
ColtAllen Jul 11, 2025
f3c0f29
remove ndims deprecation and extraneous code
ColtAllen Jul 11, 2025
a232e4c
revert changes to irrelevant test
ColtAllen Jul 12, 2025
ffc059f
remove time_covariate_vector default args
ColtAllen Jul 12, 2025
1d41eb7
revert remaining changes in irrelevant tests
ColtAllen Jul 12, 2025
47ad523
remove test_sampling_consistency
ColtAllen Jul 12, 2025
5b77263
checkpoint commit for log_cdf and test frameworks
ColtAllen Jul 12, 2025
eb7222f
checkpoint commit for log_cdf and test frameworks
ColtAllen Jul 12, 2025
b34e3d8
make C_t external function, code cleanup
ColtAllen Jul 12, 2025
9803321
rng_fn cleanup
ColtAllen Jul 13, 2025
5ff6853
WIP test frameworks
ColtAllen Jul 13, 2025
63a0b10
inverse cdf
ColtAllen Jul 15, 2025
932a046
covariate pos constraint and WIP RV
ColtAllen Jul 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions pymc_extras/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
return super().dist([r, alpha, time_covariate_vector], *args, **kwargs)

def logp(value, r, alpha, time_covariate_vector=None):
def logp(value, r, alpha, time_covariate_vector):
if time_covariate_vector is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like your logp doesn't handle ndim > 1 right? In that case raise NotImplementedError if value.ndim > 1 ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would hierarchical models still be supported if this were the case?

time_covariate_vector = pt.constant(0.0)
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
Expand Down Expand Up @@ -547,14 +547,12 @@ def C_t(t):
msg="r > 0, alpha > 0",
)

def logcdf(value, r, alpha, time_covariate_vector=None):
def logcdf(value, r, alpha, time_covariate_vector):
if time_covariate_vector is None:
time_covariate_vector = pt.constant(0.0)
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)

def C_t(t):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be moved outside the function so that it can be reused by logp?

Copy link
Author

@ColtAllen ColtAllen Jul 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how kosher this is, but due to how instantiation is handled in logp and logcdf, I had to move C_t outside the distribution class altogether to get it to work.

if t == 0:
return pt.constant(0.0)
if time_covariate_vector.ndim == 0:
return t
else:
Expand All @@ -576,7 +574,7 @@ def C_t(t):
msg="r > 0, alpha > 0",
)

def support_point(rv, size, r, alpha, time_covariate_vector=None):
def support_point(rv, size, r, alpha, time_covariate_vector):
"""Calculate a reasonable starting point for sampling.

For the GrassiaIIGeometric distribution, we use a point estimate based on
Expand Down