Skip to content

Commit

Permalink
ENH: Allow pickle with NumPy Generator
Browse files Browse the repository at this point in the history
  • Loading branch information
bashtage committed Jul 19, 2022
1 parent 2aa1ddb commit 8846a2c
Show file tree
Hide file tree
Showing 37 changed files with 168 additions and 112 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ The RNGs include:
## Status

* Builds and passes all tests on:
* Linux 32/64 bit, Python 2.7, 3.5, 3.6, 3.7
* Linux (ARM/ARM64), Python 3.7
* OSX 64-bit, Python 2.7, 3.5, 3.6, 3.7
* Windows 32/64 bit, Python 2.7, 3.5, 3.6, 3.7
* Linux 32/64 bit, Python 3.7, 3.8, 3.9, 3.10
* Linux (ARM/ARM64), Python 3.8
* OSX 64-bit, Python 3.9
* Windows 32/64 bit, Python 3.7, 3.8, 3.9, 3.10
* FreeBSD 64-bit

## Version
Expand Down
6 changes: 4 additions & 2 deletions ci/azure/azure_template_posix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ jobs:
python.version: '3.8'
coverage: true
NUMPY: 1.17.0
python37_latest:
python.version: '3.7'
python38_latest:
python.version: '3.8'
python39_latest:
python.version: '3.9'
python310_latest:
python.version: '3.10'
python36_latest:
python.version: '3.6'
python38_mid_conda:
python.version: '3.8'
use.conda: true
Expand Down
4 changes: 4 additions & 0 deletions ci/azure/azure_template_windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ jobs:
vmImage: ${{ parameters.vmImage }}
strategy:
matrix:
python37_win_latest:
python.version: '3.7'
python38_win_latest:
python.version: '3.8'
python39_win_latest:
python.version: '3.9'
python310_win_latest:
python.version: '3.10'
maxParallel: 10

steps:
Expand Down
10 changes: 10 additions & 0 deletions doc/source/change-log.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ Change Log
You should be using :class:`numpy.random.Generator` or
:class:`numpy.random.RandomState` which are maintained.

v1.23.1
=======
- Registered the bit generators included in ``randomgen`` with NumPy
so that NumPy :class:`~numpy.random.Generator` instances can be pickled
and unpickled when using a ``randomstate`` bit generator.
- Changed the canonical name of the bit generators to be their fully qualified
name. For example, :class:`~randomgen.pcg64.PCG64` is not named ``"randomgen.pcg64.PCG64"``
instead of ``"PCG64"``. This was done to avoid ambiguity with NumPy's supplied
bit generators with the same name.

v1.23.0
=======
- Removed ``Generator`` and ``RandomState``.
Expand Down
2 changes: 2 additions & 0 deletions randomgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
from typing import List, Union

from randomgen._register import BitGenerators
from randomgen.aes import AESCounter
from randomgen.chacha import ChaCha
from randomgen.dsfmt import DSFMT
Expand Down Expand Up @@ -37,6 +38,7 @@

__all__ = [
"AESCounter",
"BitGenerators",
"ChaCha",
"DSFMT",
"EFIIX64",
Expand Down
55 changes: 5 additions & 50 deletions randomgen/_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
from randomgen.common import BitGenerator
from randomgen.dsfmt import DSFMT
from randomgen.efiix64 import EFIIX64
from randomgen.generator import ExtendedGenerator, Generator
from randomgen.generator import ExtendedGenerator
from randomgen.hc128 import HC128
from randomgen.jsf import JSF
from randomgen.lxm import LXM
from randomgen.mt64 import MT64
from randomgen.mt19937 import MT19937
from randomgen.mtrand import RandomState
from randomgen.pcg32 import PCG32
from randomgen.pcg64 import PCG64, PCG64DXSM, LCG128Mix
from randomgen.philox import Philox
Expand Down Expand Up @@ -54,6 +53,10 @@
"RDRAND": RDRAND,
}

# Assign the fully qualified name for future proofness
for value in list(BitGenerators.values()):
BitGenerators[f"{value.__module__}.{value.__name__}"] = value


def _get_bitgenerator(bit_generator_name: str) -> Type[BitGenerator]:
"""
Expand All @@ -75,29 +78,6 @@ def _decode(name: Union[str, bytes]) -> str:
return name.decode("ascii")


def __generator_ctor(bit_generator_name: Union[bytes, str] = "MT19937") -> Generator:
"""
Pickling helper function that returns a Generator object
Parameters
----------
bit_generator_name: str
String containing the core BitGenerator
Returns
-------
rg: Generator
Generator using the named core BitGenerator
"""
bit_generator_name = _decode(bit_generator_name)
assert isinstance(bit_generator_name, str)
bit_generator = _get_bitgenerator(bit_generator_name)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
bit_gen = bit_generator()
return Generator(bit_gen)


def __extended_generator_ctor(
bit_generator_name: Union[str, bytes] = "MT19937"
) -> ExtendedGenerator:
Expand Down Expand Up @@ -146,28 +126,3 @@ def __bit_generator_ctor(
warnings.filterwarnings("ignore", category=FutureWarning)
bit_gen = bit_generator()
return bit_gen


def __randomstate_ctor(
bit_generator_name: Union[str, bytes] = "MT19937"
) -> RandomState:
"""
Pickling helper function that returns a legacy RandomState-like object
Parameters
----------
bit_generator_name: str
String containing the core BitGenerator
Returns
-------
rs: RandomState
Legacy RandomState using the named core BitGenerator
"""
bit_generator_name = _decode(bit_generator_name)
assert isinstance(bit_generator_name, str)
bit_generator = _get_bitgenerator(bit_generator_name)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
bit_gen = bit_generator()
return RandomState(bit_gen)
62 changes: 62 additions & 0 deletions randomgen/_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from numpy.random._pickle import BitGenerators

from randomgen.aes import AESCounter
from randomgen.chacha import ChaCha
from randomgen.dsfmt import DSFMT
from randomgen.efiix64 import EFIIX64
from randomgen.hc128 import HC128
from randomgen.jsf import JSF
from randomgen.lxm import LXM
from randomgen.mt64 import MT64
from randomgen.mt19937 import MT19937
from randomgen.pcg32 import PCG32
from randomgen.pcg64 import PCG64, PCG64DXSM, LCG128Mix
from randomgen.philox import Philox
from randomgen.rdrand import RDRAND
from randomgen.romu import Romu
from randomgen.sfc import SFC64
from randomgen.sfmt import SFMT
from randomgen.speck128 import SPECK128
from randomgen.threefry import ThreeFry
from randomgen.wrapper import UserBitGenerator
from randomgen.xoroshiro128 import Xoroshiro128
from randomgen.xorshift1024 import Xorshift1024
from randomgen.xoshiro256 import Xoshiro256
from randomgen.xoshiro512 import Xoshiro512

bit_generators = [
AESCounter,
ChaCha,
DSFMT,
EFIIX64,
HC128,
JSF,
LXM,
MT19937,
MT64,
PCG32,
PCG64,
PCG64DXSM,
LCG128Mix,
Philox,
RDRAND,
Romu,
SFC64,
SFMT,
SPECK128,
ThreeFry,
UserBitGenerator,
Xoroshiro128,
Xorshift1024,
Xoshiro256,
Xoshiro512,
]

for bitgen in bit_generators:
key = f"{bitgen.__name__}"
if key not in BitGenerators:
BitGenerators[key] = bitgen
full_key = f"{bitgen.__module__}.{bitgen.__name__}"
BitGenerators[full_key] = bitgen

__all__ = ["BitGenerators"]
4 changes: 2 additions & 2 deletions randomgen/aes.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ cdef class AESCounter(BitGenerator):
for i in range(16 * 4):
state[i] = self.rng_state.state[i]
offset = self.rng_state.offset
return {"bit_generator": type(self).__name__,
return {"bit_generator": fully_qualified_name(self),
"s": {"state": state, "seed": seed, "counter": counter,
"offset": offset},
"has_uint32": self.rng_state.has_uint32,
Expand All @@ -290,7 +290,7 @@ cdef class AESCounter(BitGenerator):
if not isinstance(value, dict):
raise TypeError("state must be a dict")
bitgen = value.get("bit_generator", "")
if bitgen != type(self).__name__:
if bitgen not in (type(self).__name__, fully_qualified_name(self)):
raise ValueError("state must be for a {0} "
"PRNG".format(type(self).__name__))
state =value["s"]["state"]
Expand Down
4 changes: 2 additions & 2 deletions randomgen/chacha.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ cdef class ChaCha(BitGenerator):
for i in range(2):
ctr[i] = self.rng_state.ctr[i]

return {"bit_generator": type(self).__name__,
return {"bit_generator": fully_qualified_name(self),
"state": {"block": block, "keysetup": keysetup, "ctr": ctr,
"rounds": self.rng_state.rounds}}

Expand All @@ -288,7 +288,7 @@ cdef class ChaCha(BitGenerator):
if not isinstance(value, dict):
raise TypeError("state must be a dict")
bitgen = value.get("bit_generator", "")
if bitgen != type(self).__name__:
if bitgen not in (type(self).__name__, fully_qualified_name(self)):
raise ValueError("state must be for a {0} "
"PRNG".format(type(self).__name__))

Expand Down
3 changes: 3 additions & 0 deletions randomgen/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,6 @@ cdef inline void compute_complex(double *rv_r, double *rv_i, double loc_r,

rv_i[0] = loc_i + scale_i * (rho * rv_r[0] + scale_c * rv_i[0])
rv_r[0] = loc_r + scale_r * rv_r[0]


cdef object fully_qualified_name(instance)
18 changes: 18 additions & 0 deletions randomgen/common.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1332,3 +1332,21 @@ cdef object cont_f(void *func, bitgen_t *state, object size, object lock,
return randoms
else:
return out


cdef object fully_qualified_name(instance):
"""
Return the module and class name

Parameters
----------
instance
A bit generator instance

Returns
-------
str
The fully qualified name
"""
typ = type(instance)
return f"{typ.__module__}.{typ.__name__}"
4 changes: 2 additions & 2 deletions randomgen/dsfmt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ cdef class DSFMT(BitGenerator):
buffered_uniforms = np.empty(DSFMT_N64, dtype=np.double)
for i in range(DSFMT_N64):
buffered_uniforms[i] = self.rng_state.buffered_uniforms[i]
return {"bit_generator": type(self).__name__,
return {"bit_generator": fully_qualified_name(self),
"state": {"state": np.asarray(state),
"idx": self.rng_state.state.idx},
"buffer_loc": self.rng_state.buffer_loc,
Expand All @@ -302,7 +302,7 @@ cdef class DSFMT(BitGenerator):
if not isinstance(value, dict):
raise TypeError("state must be a dict")
bitgen = value.get("bit_generator", "")
if bitgen != type(self).__name__:
if bitgen not in (type(self).__name__, fully_qualified_name(self)):
raise ValueError("state must be for a {0} "
"PRNG".format(type(self).__name__))
state = check_state_array(value["state"]["state"], 2*DSFMT_N_PLUS_1,
Expand Down
4 changes: 2 additions & 2 deletions randomgen/efiix64.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ cdef class EFIIX64(BitGenerator):
"a": self.rng_state.a,
"b": self.rng_state.b,
"c": self.rng_state.c }
return {"bit_generator": type(self).__name__,
return {"bit_generator": fully_qualified_name(self),
"state": state,
"has_uint32": self.rng_state.has_uint32,
"uinteger": self.rng_state.uinteger}
Expand All @@ -195,7 +195,7 @@ cdef class EFIIX64(BitGenerator):
if not isinstance(value, dict):
raise TypeError("state must be a dict")
bitgen = value.get("bit_generator", "")
if bitgen != type(self).__name__:
if bitgen not in (type(self).__name__, fully_qualified_name(self)):
raise ValueError("state must be for a {0} "
"PRNG".format(type(self).__name__))
state = value["state"]
Expand Down
4 changes: 2 additions & 2 deletions randomgen/hc128.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ cdef class HC128(BitGenerator):
buf_arr = <uint32_t *>np.PyArray_DATA(buffer)
for i in range(16):
buf_arr[i] = self.rng_state.buffer[i]
return {"bit_generator": type(self).__name__,
return {"bit_generator": fully_qualified_name(self),
"state": {"p": p,
"q": q,
"hc_idx": self.rng_state.hc_idx,
Expand All @@ -220,7 +220,7 @@ cdef class HC128(BitGenerator):
if not isinstance(value, dict):
raise TypeError("state must be a dict")
bitgen = value.get("bit_generator", "")
if bitgen != type(self).__name__:
if bitgen not in (type(self).__name__, fully_qualified_name(self)):
raise ValueError("state must be for a {0} "
"PRNG".format(type(self).__name__))
state = value["state"]
Expand Down
4 changes: 2 additions & 2 deletions randomgen/jsf.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ cdef class JSF(BitGenerator):
b = self.rng_state.b.u32
c = self.rng_state.c.u32
d = self.rng_state.d.u32
return {"bit_generator": type(self).__name__,
return {"bit_generator": fully_qualified_name(self),
"state": {"a": a, "b": b, "c": c, "d": d,
"p": self.rng_state.p,
"q": self.rng_state.q,
Expand All @@ -316,7 +316,7 @@ cdef class JSF(BitGenerator):
if not isinstance(value, dict):
raise TypeError("state must be a dict")
bitgen = value.get("bit_generator", "")
if bitgen != type(self).__name__:
if bitgen not in (type(self).__name__, fully_qualified_name(self)):
raise ValueError("state must be for a {0} "
"PRNG".format(type(self).__name__))
self.size = value["size"]
Expand Down
4 changes: 2 additions & 2 deletions randomgen/lxm.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ cdef class LXM(BitGenerator):
x = np.empty(4, dtype=np.uint64)
for i in range(4):
x[i] = self.rng_state.x[i]
return {"bit_generator": type(self).__name__,
return {"bit_generator": fully_qualified_name(self),
"state": {"x": x,
"lcg_state": self.rng_state.lcg_state,
"b": self.rng_state.b,
Expand All @@ -283,7 +283,7 @@ cdef class LXM(BitGenerator):
if not isinstance(value, dict):
raise TypeError("state must be a dict")
bitgen = value.get("bit_generator", "")
if bitgen != type(self).__name__:
if bitgen not in (type(self).__name__, fully_qualified_name(self)):
raise ValueError("state must be for a {0} "
"PRNG".format(type(self).__name__))
state = check_state_array(value["state"]["x"], 4, 64, "x")
Expand Down
Loading

0 comments on commit 8846a2c

Please sign in to comment.