Seeking guidance for landing spot of scipy.stats.levy_stable in Jax #21054

Open
tjhunter opened this issue May 3, 2024 · 0 comments
Labels
enhancement New feature or request

Comments

tjhunter commented May 3, 2024

Hello,
thank you again for having released this excellent framework. I have implemented the Lévy alpha-stable distribution in JAX (levy-stable-jax). I would like some guidance / alignment before a potential PR:

  • would this implementation (or part of it) fit in JAX?
  • if not, what is the best way to add a reference to it in the documentation?

The alpha-stable distribution is one of the standard 1-dimensional distributions available in scipy (link), and it has many appealing properties for modeling heavy-tailed data such as stock market returns. It is currently not in JAX. Because it is challenging to implement correctly, all implementations come in two flavors:

  • exact, quadrature-based code, such as in scipy or R. It is slow and hard to differentiate.
  • approximate, interpolation-based code, such as in levy-stable-jax. It is very fast and easy to differentiate, but must ship tabulated values (e.g., pylevy, Nolan's reference STABLE program).
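To illustrate why the interpolation-based flavor is cheap to differentiate, here is a minimal JAX sketch. It is not how levy-stable-jax works internally: it assumes a hypothetical one-dimensional table (`Z_GRID`, `LOGPDF_TABLE`) over the standardized variable and, as a stand-in for real alpha-stable values, fills it with the closed-form Cauchy log-pdf (alpha = 1, beta = 0). A real table would also be indexed by alpha and beta.

```python
import jax
import jax.numpy as jnp

# Hypothetical tabulation grid over the standardized variable z = (x - loc) / scale.
# Stand-in values: the Cauchy log-pdf (alpha=1, beta=0), which has a closed form.
# A real alpha-stable table would also be indexed by alpha and beta.
Z_GRID = jnp.linspace(-20.0, 20.0, 2001)
LOGPDF_TABLE = -jnp.log(jnp.pi) - jnp.log1p(Z_GRID**2)


def interp_logpdf(x, loc, scale):
    """Log-pdf from tabulated values via linear interpolation (sketch only)."""
    z = (x - loc) / scale
    # jnp.interp clamps outside the grid; a real implementation must handle the tails.
    return jnp.interp(z, Z_GRID, LOGPDF_TABLE) - jnp.log(scale)


def neg_log_likelihood(params, data):
    """Negative log-likelihood of the data under the interpolated density."""
    loc, scale = params
    return -jnp.sum(interp_logpdf(data, loc, scale))


data = jnp.array([-1.3, 0.2, 0.5, 2.7, 10.0])
params = (jnp.array(0.0), jnp.array(1.0))
# Gradients flow through the interpolation like any other JAX operation.
nll, grads = jax.value_and_grad(neg_log_likelihood)(params, data)
```

Because `jnp.interp` is built from ordinary JAX operations, `jax.value_and_grad` returns gradients with respect to `loc` and `scale` directly; gradients with respect to alpha and beta would additionally require interpolating across those table axes.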

levy-stable-jax's implementation has a thorough test suite, and it is vectorized and differentiable across all 4 parameters of the distribution. It is especially tuned for fast maximum likelihood estimation, which aligns well with JAX's typical audience. However, because of the tabulated values, the package is at least 5 MB, which is probably too much to add to JAX itself. It could probably be reduced to about 500 kB, but I would only consider this extra work if there is a reasonable likelihood that this distribution would land in JAX.

In general, I do not think an exact-algorithm version of levy_stable is likely to come to JAX: all known formulas for the pdf rely on numerical quadrature or on truncating infinite series. Any implementation would at the very least make JAX depend on quadax (and the ability to differentiate across all parameters for all values would still be an open question).

@tjhunter tjhunter added the enhancement New feature or request label May 3, 2024