From 8846a2cdeb3e6637a6b4d396baae98a4cd3d6373 Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Mon, 18 Jul 2022 23:08:54 +0100 Subject: [PATCH] ENH: Allow pickle with NumPy Generator --- README.md | 8 +-- ci/azure/azure_template_posix.yml | 6 +- ci/azure/azure_template_windows.yml | 4 ++ doc/source/change-log.rst | 10 ++++ randomgen/__init__.py | 2 + randomgen/_pickle.py | 55 ++---------------- randomgen/_register.py | 62 +++++++++++++++++++++ randomgen/aes.pyx | 4 +- randomgen/chacha.pyx | 4 +- randomgen/common.pxd | 3 + randomgen/common.pyx | 18 ++++++ randomgen/dsfmt.pyx | 4 +- randomgen/efiix64.pyx | 4 +- randomgen/hc128.pyx | 4 +- randomgen/jsf.pyx | 4 +- randomgen/lxm.pyx | 4 +- randomgen/mt19937.pyx | 4 +- randomgen/mt64.pyx | 4 +- randomgen/pcg32.pyx | 4 +- randomgen/pcg64.pyx | 12 ++-- randomgen/philox.pyx | 4 +- randomgen/rdrand.pyx | 4 +- randomgen/romu.pyx | 4 +- randomgen/sfc.pyx | 4 +- randomgen/sfmt.pyx | 4 +- randomgen/speck128.pyx | 4 +- randomgen/tests/test_direct.py | 4 +- randomgen/tests/test_extended_generator.py | 2 +- randomgen/tests/test_lcg128mix_pcg64dxsm.py | 2 - randomgen/tests/test_smoke.py | 1 - randomgen/threefry.pyx | 4 +- randomgen/xoroshiro128.pyx | 4 +- randomgen/xorshift1024.pyx | 4 +- randomgen/xoshiro256.pyx | 4 +- randomgen/xoshiro512.pyx | 4 +- setup.cfg | 3 +- setup.py | 4 +- 37 files changed, 168 insertions(+), 112 deletions(-) create mode 100644 randomgen/_register.py diff --git a/README.md b/README.md index ae9774f47..52197485d 100644 --- a/README.md +++ b/README.md @@ -98,10 +98,10 @@ The RNGs include: ## Status * Builds and passes all tests on: - * Linux 32/64 bit, Python 2.7, 3.5, 3.6, 3.7 - * Linux (ARM/ARM64), Python 3.7 - * OSX 64-bit, Python 2.7, 3.5, 3.6, 3.7 - * Windows 32/64 bit, Python 2.7, 3.5, 3.6, 3.7 + * Linux 32/64 bit, Python 3.7, 3.8, 3.9, 3.10 + * Linux (ARM/ARM64), Python 3.8 + * OSX 64-bit, Python 3.9 + * Windows 32/64 bit, Python 3.7, 3.8, 3.9, 3.10 * FreeBSD 64-bit ## Version diff --git a/ci/azure/azure_template_posix.yml b/ci/azure/azure_template_posix.yml index 2f5b2d925..ebc5434a3 100644 --- a/ci/azure/azure_template_posix.yml +++ b/ci/azure/azure_template_posix.yml @@ -26,12 +26,14 @@ jobs: python.version: '3.8' coverage: true NUMPY: 1.17.0 + python37_latest: + python.version: '3.7' + python38_latest: + python.version: '3.8' python39_latest: python.version: '3.9' python310_latest: python.version: '3.10' - python36_latest: - python.version: '3.6' python38_mid_conda: python.version: '3.8' use.conda: true diff --git a/ci/azure/azure_template_windows.yml b/ci/azure/azure_template_windows.yml index 4ef7f53db..8d3066934 100644 --- a/ci/azure/azure_template_windows.yml +++ b/ci/azure/azure_template_windows.yml @@ -16,10 +16,14 @@ jobs: vmImage: ${{ parameters.vmImage }} strategy: matrix: + python37_win_latest: + python.version: '3.7' python38_win_latest: python.version: '3.8' python39_win_latest: python.version: '3.9' + python310_win_latest: + python.version: '3.10' maxParallel: 10 steps: diff --git a/doc/source/change-log.rst b/doc/source/change-log.rst index 24c815df9..1bd4fa274 100644 --- a/doc/source/change-log.rst +++ b/doc/source/change-log.rst @@ -13,6 +13,16 @@ Change Log You should be using :class:`numpy.random.Generator` or :class:`numpy.random.RandomState` which are maintained. +v1.23.1 +======= +- Registered the bit generators included in ``randomgen`` with NumPy + so that NumPy :class:`~numpy.random.Generator` instances can be pickled + and unpickled when using a ``randomstate`` bit generator. +- Changed the canonical name of the bit generators to be their fully qualified + name. For example, :class:`~randomgen.pcg64.PCG64` is not named ``"randomgen.pcg64.PCG64"`` + instead of ``"PCG64"``. This was done to avoid ambiguity with NumPy's supplied + bit generators with the same name. + v1.23.0 ======= - Removed ``Generator`` and ``RandomState``. diff --git a/randomgen/__init__.py b/randomgen/__init__.py index cdb02e71e..a04c9ff2a 100644 --- a/randomgen/__init__.py +++ b/randomgen/__init__.py @@ -2,6 +2,7 @@ import sys from typing import List, Union +from randomgen._register import BitGenerators from randomgen.aes import AESCounter from randomgen.chacha import ChaCha from randomgen.dsfmt import DSFMT @@ -37,6 +38,7 @@ __all__ = [ "AESCounter", + "BitGenerators", "ChaCha", "DSFMT", "EFIIX64", diff --git a/randomgen/_pickle.py b/randomgen/_pickle.py index 46afb30d8..dea88d0ef 100644 --- a/randomgen/_pickle.py +++ b/randomgen/_pickle.py @@ -6,13 +6,12 @@ from randomgen.common import BitGenerator from randomgen.dsfmt import DSFMT from randomgen.efiix64 import EFIIX64 -from randomgen.generator import ExtendedGenerator, Generator +from randomgen.generator import ExtendedGenerator from randomgen.hc128 import HC128 from randomgen.jsf import JSF from randomgen.lxm import LXM from randomgen.mt64 import MT64 from randomgen.mt19937 import MT19937 -from randomgen.mtrand import RandomState from randomgen.pcg32 import PCG32 from randomgen.pcg64 import PCG64, PCG64DXSM, LCG128Mix from randomgen.philox import Philox @@ -54,6 +53,10 @@ "RDRAND": RDRAND, } +# Assign the fully qualified name for future proofness +for value in list(BitGenerators.values()): + BitGenerators[f"{value.__module__}.{value.__name__}"] = value + def _get_bitgenerator(bit_generator_name: str) -> Type[BitGenerator]: """ @@ -75,29 +78,6 @@ def _decode(name: Union[str, bytes]) -> str: return name.decode("ascii") -def __generator_ctor(bit_generator_name: Union[bytes, str] = "MT19937") -> Generator: - """ - Pickling helper function that returns a Generator object - - Parameters - ---------- - bit_generator_name: str - String containing the core BitGenerator - - Returns - ------- - rg: Generator - Generator using the named core BitGenerator - """ - bit_generator_name = _decode(bit_generator_name) - assert isinstance(bit_generator_name, str) - bit_generator = _get_bitgenerator(bit_generator_name) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - bit_gen = bit_generator() - return Generator(bit_gen) - - def __extended_generator_ctor( bit_generator_name: Union[str, bytes] = "MT19937" ) -> ExtendedGenerator: @@ -146,28 +126,3 @@ def __bit_generator_ctor( warnings.filterwarnings("ignore", category=FutureWarning) bit_gen = bit_generator() return bit_gen - - -def __randomstate_ctor( - bit_generator_name: Union[str, bytes] = "MT19937" -) -> RandomState: - """ - Pickling helper function that returns a legacy RandomState-like object - - Parameters - ---------- - bit_generator_name: str - String containing the core BitGenerator - - Returns - ------- - rs: RandomState - Legacy RandomState using the named core BitGenerator - """ - bit_generator_name = _decode(bit_generator_name) - assert isinstance(bit_generator_name, str) - bit_generator = _get_bitgenerator(bit_generator_name) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - bit_gen = bit_generator() - return RandomState(bit_gen) diff --git a/randomgen/_register.py b/randomgen/_register.py new file mode 100644 index 000000000..8caeb8c80 --- /dev/null +++ b/randomgen/_register.py @@ -0,0 +1,62 @@ +from numpy.random._pickle import BitGenerators + +from randomgen.aes import AESCounter +from randomgen.chacha import ChaCha +from randomgen.dsfmt import DSFMT +from randomgen.efiix64 import EFIIX64 +from randomgen.hc128 import HC128 +from randomgen.jsf import JSF +from randomgen.lxm import LXM +from randomgen.mt64 import MT64 +from randomgen.mt19937 import MT19937 +from randomgen.pcg32 import PCG32 +from randomgen.pcg64 import PCG64, PCG64DXSM, LCG128Mix +from randomgen.philox import Philox +from randomgen.rdrand import RDRAND +from randomgen.romu import Romu +from randomgen.sfc import SFC64 +from randomgen.sfmt import SFMT +from randomgen.speck128 import SPECK128 +from randomgen.threefry import ThreeFry +from randomgen.wrapper import UserBitGenerator +from randomgen.xoroshiro128 import Xoroshiro128 +from randomgen.xorshift1024 import Xorshift1024 +from randomgen.xoshiro256 import Xoshiro256 +from randomgen.xoshiro512 import Xoshiro512 + +bit_generators = [ + AESCounter, + ChaCha, + DSFMT, + EFIIX64, + HC128, + JSF, + LXM, + MT19937, + MT64, + PCG32, + PCG64, + PCG64DXSM, + LCG128Mix, + Philox, + RDRAND, + Romu, + SFC64, + SFMT, + SPECK128, + ThreeFry, + UserBitGenerator, + Xoroshiro128, + Xorshift1024, + Xoshiro256, + Xoshiro512, +] + +for bitgen in bit_generators: + key = f"{bitgen.__name__}" + if key not in BitGenerators: + BitGenerators[key] = bitgen + full_key = f"{bitgen.__module__}.{bitgen.__name__}" + BitGenerators[full_key] = bitgen + +__all__ = ["BitGenerators"] diff --git a/randomgen/aes.pyx b/randomgen/aes.pyx index b1126837e..15ae256b7 100644 --- a/randomgen/aes.pyx +++ b/randomgen/aes.pyx @@ -277,7 +277,7 @@ cdef class AESCounter(BitGenerator): for i in range(16 * 4): state[i] = self.rng_state.state[i] offset = self.rng_state.offset - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "s": {"state": state, "seed": seed, "counter": counter, "offset": offset}, "has_uint32": self.rng_state.has_uint32, @@ -290,7 +290,7 @@ cdef class AESCounter(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) state =value["s"]["state"] diff --git a/randomgen/chacha.pyx b/randomgen/chacha.pyx index 969f44a7e..b7f1fb1a1 100644 --- a/randomgen/chacha.pyx +++ b/randomgen/chacha.pyx @@ -279,7 +279,7 @@ cdef class ChaCha(BitGenerator): for i in range(2): ctr[i] = self.rng_state.ctr[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"block": block, "keysetup": keysetup, "ctr": ctr, "rounds": self.rng_state.rounds}} @@ -288,7 +288,7 @@ cdef class ChaCha(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) diff --git a/randomgen/common.pxd b/randomgen/common.pxd index bcb72f813..7b3451744 100644 --- a/randomgen/common.pxd +++ b/randomgen/common.pxd @@ -148,3 +148,6 @@ cdef inline void compute_complex(double *rv_r, double *rv_i, double loc_r, rv_i[0] = loc_i + scale_i * (rho * rv_r[0] + scale_c * rv_i[0]) rv_r[0] = loc_r + scale_r * rv_r[0] + + +cdef object fully_qualified_name(instance) \ No newline at end of file diff --git a/randomgen/common.pyx b/randomgen/common.pyx index 2c2dd94d3..b546d58bf 100644 --- a/randomgen/common.pyx +++ b/randomgen/common.pyx @@ -1332,3 +1332,21 @@ cdef object cont_f(void *func, bitgen_t *state, object size, object lock, return randoms else: return out + + +cdef object fully_qualified_name(instance): + """ + Return the module and class name + + Parameters + ---------- + instance + A bit generator instance + + Returns + ------- + str + The fully qualified name + """ + typ = type(instance) + return f"{typ.__module__}.{typ.__name__}" diff --git a/randomgen/dsfmt.pyx b/randomgen/dsfmt.pyx index 5973ba56c..8ae9e5e0e 100644 --- a/randomgen/dsfmt.pyx +++ b/randomgen/dsfmt.pyx @@ -290,7 +290,7 @@ cdef class DSFMT(BitGenerator): buffered_uniforms = np.empty(DSFMT_N64, dtype=np.double) for i in range(DSFMT_N64): buffered_uniforms[i] = self.rng_state.buffered_uniforms[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"state": np.asarray(state), "idx": self.rng_state.state.idx}, "buffer_loc": self.rng_state.buffer_loc, @@ -302,7 +302,7 @@ cdef class DSFMT(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) state = check_state_array(value["state"]["state"], 2*DSFMT_N_PLUS_1, diff --git a/randomgen/efiix64.pyx b/randomgen/efiix64.pyx index 970b917f0..89e8457b5 100644 --- a/randomgen/efiix64.pyx +++ b/randomgen/efiix64.pyx @@ -182,7 +182,7 @@ cdef class EFIIX64(BitGenerator): "a": self.rng_state.a, "b": self.rng_state.b, "c": self.rng_state.c } - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": state, "has_uint32": self.rng_state.has_uint32, "uinteger": self.rng_state.uinteger} @@ -195,7 +195,7 @@ cdef class EFIIX64(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) state = value["state"] diff --git a/randomgen/hc128.pyx b/randomgen/hc128.pyx index 6f50a1a43..b39b3a4fb 100644 --- a/randomgen/hc128.pyx +++ b/randomgen/hc128.pyx @@ -202,7 +202,7 @@ cdef class HC128(BitGenerator): buf_arr = np.PyArray_DATA(buffer) for i in range(16): buf_arr[i] = self.rng_state.buffer[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"p": p, "q": q, "hc_idx": self.rng_state.hc_idx, @@ -220,7 +220,7 @@ cdef class HC128(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) state = value["state"] diff --git a/randomgen/jsf.pyx b/randomgen/jsf.pyx index 519febf76..172fc277f 100644 --- a/randomgen/jsf.pyx +++ b/randomgen/jsf.pyx @@ -301,7 +301,7 @@ cdef class JSF(BitGenerator): b = self.rng_state.b.u32 c = self.rng_state.c.u32 d = self.rng_state.d.u32 - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"a": a, "b": b, "c": c, "d": d, "p": self.rng_state.p, "q": self.rng_state.q, @@ -316,7 +316,7 @@ cdef class JSF(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) self.size = value["size"] diff --git a/randomgen/lxm.pyx b/randomgen/lxm.pyx index 7abc371e1..51562fdff 100644 --- a/randomgen/lxm.pyx +++ b/randomgen/lxm.pyx @@ -270,7 +270,7 @@ cdef class LXM(BitGenerator): x = np.empty(4, dtype=np.uint64) for i in range(4): x[i] = self.rng_state.x[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"x": x, "lcg_state": self.rng_state.lcg_state, "b": self.rng_state.b, @@ -283,7 +283,7 @@ cdef class LXM(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) state = check_state_array(value["state"]["x"], 4, 64, "x") diff --git a/randomgen/mt19937.pyx b/randomgen/mt19937.pyx index 3ca8d230a..232908d3a 100644 --- a/randomgen/mt19937.pyx +++ b/randomgen/mt19937.pyx @@ -316,7 +316,7 @@ cdef class MT19937(BitGenerator): for i in range(RK_STATE_LEN): key[i] = self.rng_state.key[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"key": key, "pos": self.rng_state.pos}} @state.setter @@ -330,7 +330,7 @@ cdef class MT19937(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) key = check_state_array(value["state"]["key"], RK_STATE_LEN, 32, "key") diff --git a/randomgen/mt64.pyx b/randomgen/mt64.pyx index 5b60e3acc..1b817655b 100644 --- a/randomgen/mt64.pyx +++ b/randomgen/mt64.pyx @@ -190,7 +190,7 @@ cdef class MT64(BitGenerator): for i in range(312): key[i] = self.rng_state.mt[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"key": key, "pos": self.rng_state.mti}, "has_uint32": self.rng_state.has_uint32, "uinteger": self.rng_state.uinteger} @@ -200,7 +200,7 @@ cdef class MT64(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) key = check_state_array(value["state"]["key"], 312, 64, "key") diff --git a/randomgen/pcg32.pyx b/randomgen/pcg32.pyx index c1329ef6d..62acdb4dc 100644 --- a/randomgen/pcg32.pyx +++ b/randomgen/pcg32.pyx @@ -185,7 +185,7 @@ cdef class PCG32(BitGenerator): Dictionary containing the information required to describe the state of the PRNG """ - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"state": self.rng_state.pcg_state.state, "inc": self.rng_state.pcg_state.inc}} @@ -194,7 +194,7 @@ cdef class PCG32(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) self.rng_state.pcg_state.state = value["state"]["state"] diff --git a/randomgen/pcg64.pyx b/randomgen/pcg64.pyx index 8200bddbb..e7f3b1c9c 100644 --- a/randomgen/pcg64.pyx +++ b/randomgen/pcg64.pyx @@ -322,7 +322,7 @@ cdef class PCG64(BitGenerator): &use_dxsm, &has_uint32, &uinteger) state = int(state_vec[0]) * 2**64 + int(state_vec[1]) inc = int(state_vec[2]) * 2**64 + int(state_vec[3]) - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"state": state, "inc": inc}, "variant": self.variant, "has_uint32": has_uint32, @@ -336,7 +336,7 @@ cdef class PCG64(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "RNG".format(type(self).__name__)) state_vec = np.empty(4, dtype=np.uint64) @@ -794,7 +794,7 @@ cdef class LCG128Mix(BitGenerator): state = int(state_vec[0]) * 2**64 + int(state_vec[1]) inc = int(inc_vec[0]) * 2**64 + int(inc_vec[1]) mult = int(mult_vec[0]) * 2**64 + int(mult_vec[1]) - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"state": state, "inc": inc, "multiplier": mult, @@ -811,7 +811,7 @@ cdef class LCG128Mix(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "RNG".format(type(self).__name__)) state_vec = np.empty(2, dtype=np.uint64) @@ -1075,7 +1075,7 @@ cdef class PCG64DXSM(PCG64): &use_dxsm, &has_uint32, &uinteger) state = int(state_vec[0]) * 2**64 + int(state_vec[1]) inc = int(state_vec[2]) * 2**64 + int(state_vec[3]) - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"state": state, "inc": inc}, "has_uint32": has_uint32, "uinteger": uinteger} @@ -1088,7 +1088,7 @@ cdef class PCG64DXSM(PCG64): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "RNG".format(type(self).__name__)) state_vec = np.empty(4, dtype=np.uint64) diff --git a/randomgen/philox.pyx b/randomgen/philox.pyx index 927a0bf47..42eb49e4c 100644 --- a/randomgen/philox.pyx +++ b/randomgen/philox.pyx @@ -367,7 +367,7 @@ cdef class Philox(BitGenerator): else: # self.n == 4 and self.w == 64 key[i] = self.rng_state.state.state4x64.key.v[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"counter": ctr, "key": key}, "buffer": buffer, "buffer_pos": self.rng_state.buffer_pos, @@ -381,7 +381,7 @@ cdef class Philox(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) # Default for previous version diff --git a/randomgen/rdrand.pyx b/randomgen/rdrand.pyx index ec822f5d2..5fbbd24be 100644 --- a/randomgen/rdrand.pyx +++ b/randomgen/rdrand.pyx @@ -441,7 +441,7 @@ cdef class RDRAND(BitGenerator): for i in range(BUFFER_SIZE): buffer[i] = self.rng_state.buffer[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "status": self.rng_state.status, "retries": self.rng_state.retries, "buffer_loc": self.rng_state.buffer_loc, @@ -453,7 +453,7 @@ cdef class RDRAND(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) self.rng_state.retries = value["retries"] diff --git a/randomgen/romu.pyx b/randomgen/romu.pyx index 183d37b3d..ac78ffedd 100644 --- a/randomgen/romu.pyx +++ b/randomgen/romu.pyx @@ -176,7 +176,7 @@ cdef class Romu(BitGenerator): Dictionary containing the information required to describe the state of the PRNG """ - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"w": self.rng_state.w, "x":self.rng_state.x, "y":self.rng_state.y, @@ -191,7 +191,7 @@ cdef class Romu(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) self.rng_state.w = value["state"]["w"] diff --git a/randomgen/sfc.pyx b/randomgen/sfc.pyx index 32977cd87..308567999 100644 --- a/randomgen/sfc.pyx +++ b/randomgen/sfc.pyx @@ -334,7 +334,7 @@ cdef class SFC64(BitGenerator): Dictionary containing the information required to describe the state of the PRNG """ - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"a": self.rng_state.a, "b":self.rng_state.b, "c":self.rng_state.c, @@ -349,7 +349,7 @@ cdef class SFC64(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) self.rng_state.a = value["state"]["a"] diff --git a/randomgen/sfmt.pyx b/randomgen/sfmt.pyx index 0ef0be9c5..6877869f9 100644 --- a/randomgen/sfmt.pyx +++ b/randomgen/sfmt.pyx @@ -296,7 +296,7 @@ cdef class SFMT(BitGenerator): buffered_uint64 = np.empty(SFMT_N64, dtype=np.uint64) for i in range(SFMT_N64): buffered_uint64[i] = self.rng_state.buffered_uint64[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"state": state_arr, "idx": self.rng_state.state.idx}, "buffer_loc": self.rng_state.buffer_loc, @@ -310,7 +310,7 @@ cdef class SFMT(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) state = check_state_array(value["state"]["state"], 4 * SFMT_N, 32, diff --git a/randomgen/speck128.pyx b/randomgen/speck128.pyx index 9d3abecd4..9c155b938 100644 --- a/randomgen/speck128.pyx +++ b/randomgen/speck128.pyx @@ -290,7 +290,7 @@ cdef class SPECK128(BitGenerator): arr[2*i] = self.rng_state.round_key[i].u64[0] arr[2*i+1] = self.rng_state.round_key[i].u64[1] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"ctr": ctr, "buffer": buffer.view(np.uint64), "round_key": round_key, @@ -308,7 +308,7 @@ cdef class SPECK128(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) diff --git a/randomgen/tests/test_direct.py b/randomgen/tests/test_direct.py index 886e579be..78838428b 100644 --- a/randomgen/tests/test_direct.py +++ b/randomgen/tests/test_direct.py @@ -1154,7 +1154,9 @@ def test_state_tuple(self): bit_generator = rs.bit_generator state = bit_generator.state desired = rs.integers(2**16) - tup = (state["bit_generator"], state["state"]["key"], state["state"]["pos"]) + # Due to changes + bg_name = state["bit_generator"].split(".")[-1] + tup = (bg_name, state["state"]["key"], state["state"]["pos"]) bit_generator.state = tup actual = rs.integers(2**16) assert_equal(actual, desired) diff --git a/randomgen/tests/test_extended_generator.py b/randomgen/tests/test_extended_generator.py index 08e57067d..0b52f8fce 100644 --- a/randomgen/tests/test_extended_generator.py +++ b/randomgen/tests/test_extended_generator.py @@ -392,7 +392,7 @@ def test_wishart_broadcast(df, scale_dim): pcg = PCG64(0, mode="sequence") eg = ExtendedGenerator(pcg) scale = np.eye(dim) - for i in range(scale_dim): + for _ in range(scale_dim): scale = np.array([scale, scale]) w = eg.wishart(df, scale) assert w.shape[-2:] == (dim, dim) diff --git a/randomgen/tests/test_lcg128mix_pcg64dxsm.py b/randomgen/tests/test_lcg128mix_pcg64dxsm.py index f8b3dd47d..f5f115f96 100644 --- a/randomgen/tests/test_lcg128mix_pcg64dxsm.py +++ b/randomgen/tests/test_lcg128mix_pcg64dxsm.py @@ -125,10 +125,8 @@ def test_ctypes(): so_loc = os.path.join(base, "libctypes_testing.so") try: cmd = ["gcc", "-c", "-Wall", "-Werror", "-fpic", c_loc, "-o", o_loc] - print(" ".join(cmd)) subprocess.call(cmd) cmd = ["gcc", "-shared", "-o", so_loc, o_loc] - print(" ".join(cmd)) subprocess.call(cmd) if not os.path.exists(so_loc): raise FileNotFoundError(f"{so_loc} does not exist") diff --git a/randomgen/tests/test_smoke.py b/randomgen/tests/test_smoke.py index e1e61708b..be0a3da06 100644 --- a/randomgen/tests/test_smoke.py +++ b/randomgen/tests/test_smoke.py @@ -542,7 +542,6 @@ def test_dirichlet(self): s = self.rg.dirichlet((10, 5, 3), 20) assert_(s.shape == (20, 3)) - @pytest.mark.skip(reason="Doesn't work since can't register bit generators") def test_pickle(self): pick = pickle.dumps(self.rg) unpick = pickle.loads(pick) diff --git a/randomgen/threefry.pyx b/randomgen/threefry.pyx index 3d39006bc..46b913932 100644 --- a/randomgen/threefry.pyx +++ b/randomgen/threefry.pyx @@ -348,7 +348,7 @@ cdef class ThreeFry(BitGenerator): else: buffer[i] = self.rng_state.buffer[i].u32 - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"counter": ctr, "key": key}, "buffer": buffer, "buffer_pos": self.rng_state.buffer_pos, @@ -362,7 +362,7 @@ cdef class ThreeFry(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) # Default for previous version diff --git a/randomgen/xoroshiro128.pyx b/randomgen/xoroshiro128.pyx index 8931d5a87..37f6de931 100644 --- a/randomgen/xoroshiro128.pyx +++ b/randomgen/xoroshiro128.pyx @@ -287,7 +287,7 @@ cdef class Xoroshiro128(BitGenerator): state = np.empty(2, dtype=np.uint64) state[0] = self.rng_state.s[0] state[1] = self.rng_state.s[1] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "s": state, "plusplus": self._plusplus, "has_uint32": self.rng_state.has_uint32, @@ -298,7 +298,7 @@ cdef class Xoroshiro128(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) state = check_state_array(value["s"], 2, 64, "s") diff --git a/randomgen/xorshift1024.pyx b/randomgen/xorshift1024.pyx index 32a484890..ece4df680 100644 --- a/randomgen/xorshift1024.pyx +++ b/randomgen/xorshift1024.pyx @@ -265,7 +265,7 @@ cdef class Xorshift1024(BitGenerator): s = np.empty(16, dtype=np.uint64) for i in range(16): s[i] = self.rng_state.s[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "state": {"s": s, "p": self.rng_state.p}, "has_uint32": self.rng_state.has_uint32, "uinteger": self.rng_state.uinteger} @@ -275,7 +275,7 @@ cdef class Xorshift1024(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) state = check_state_array(value["state"]["s"], 16, 64, "s") diff --git a/randomgen/xoshiro256.pyx b/randomgen/xoshiro256.pyx index fc39cc7ca..fc9596d26 100644 --- a/randomgen/xoshiro256.pyx +++ b/randomgen/xoshiro256.pyx @@ -268,7 +268,7 @@ cdef class Xoshiro256(BitGenerator): state[1] = self.rng_state.s[1] state[2] = self.rng_state.s[2] state[3] = self.rng_state.s[3] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "s": state, "has_uint32": self.rng_state.has_uint32, "uinteger": self.rng_state.uinteger} @@ -278,7 +278,7 @@ cdef class Xoshiro256(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) state = check_state_array(value["s"], 4, 64, "s") diff --git a/randomgen/xoshiro512.pyx b/randomgen/xoshiro512.pyx index c0f08c595..2d912482d 100644 --- a/randomgen/xoshiro512.pyx +++ b/randomgen/xoshiro512.pyx @@ -259,7 +259,7 @@ cdef class Xoshiro512(BitGenerator): state = np.empty(8, dtype=np.uint64) for i in range(8): state[i] = self.rng_state.s[i] - return {"bit_generator": type(self).__name__, + return {"bit_generator": fully_qualified_name(self), "s": state, "has_uint32": self.rng_state.has_uint32, "uinteger": self.rng_state.uinteger} @@ -269,7 +269,7 @@ cdef class Xoshiro512(BitGenerator): if not isinstance(value, dict): raise TypeError("state must be a dict") bitgen = value.get("bit_generator", "") - if bitgen != type(self).__name__: + if bitgen not in (type(self).__name__, fully_qualified_name(self)): raise ValueError("state must be for a {0} " "PRNG".format(type(self).__name__)) state = check_state_array(value["s"], 8, 64, "s") diff --git a/setup.cfg b/setup.cfg index 32268189d..cb7e23c2e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,7 @@ [metadata] description_file = README.md -license_file = LICENSE.md +license_files = + LICENSE.md [flake8] max-line-length = 99 diff --git a/setup.py b/setup.py index ccd550d21..37a2480c4 100644 --- a/setup.py +++ b/setup.py @@ -321,10 +321,10 @@ def bit_generator( "Operating System :: Unix", "Programming Language :: C", "Programming Language :: Cython", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Topic :: Adaptive Technologies", "Topic :: Artistic Software", "Topic :: Office/Business :: Financial", @@ -384,5 +384,5 @@ def is_pure(self): ], zip_safe=False, install_requires=install_required, - python_requires=">=3.6", + python_requires=">=3.7", )