Skip to content

Commit

Permalink
Lazy import of scipy.stats (#1268)
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored Mar 11, 2025
1 parent 00fea0e commit b3da2a4
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pytensor/tensor/random/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Literal

import numpy as np
import scipy.stats as stats
from numpy import broadcast_shapes as np_broadcast_shapes
from numpy import einsum as np_einsum
from numpy import sqrt as np_sqrt
Expand All @@ -21,6 +20,11 @@
)


# Scipy.stats is considerably slow to import
# We import scipy.stats lazily inside `ScipyRandomVariable`
stats = None


try:
broadcast_shapes = np.broadcast_shapes
except AttributeError:
Expand Down Expand Up @@ -57,6 +61,9 @@ def rng_fn_scipy(cls, rng, *args, **kwargs):

@classmethod
def rng_fn(cls, *args, **kwargs):
global stats
if stats is None:
import scipy.stats as stats
size = args[-1]
res = cls.rng_fn_scipy(*args, **kwargs)

Expand Down

0 comments on commit b3da2a4

Please sign in to comment.