-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for np.random.Generator #6566
base: main
Are you sure you want to change the base?
Changes from all commits
834ee27
9ce26d6
30a73e7
f455207
c6cfb37
0fc07df
d5ade4f
1e9c7ab
aff9556
b06f456
770e8fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,66 @@ | ||||||
# Copyright 2024 The Cirq Developers | ||||||
# | ||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
# you may not use this file except in compliance with the License. | ||||||
# You may obtain a copy of the License at | ||||||
# | ||||||
# https://www.apache.org/licenses/LICENSE-2.0 | ||||||
# | ||||||
# Unless required by applicable law or agreed to in writing, software | ||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
|
||||||
from typing import Union | ||||||
|
||||||
import numbers | ||||||
import numpy as np | ||||||
|
||||||
from cirq._doc import document | ||||||
from cirq.value.random_state import RANDOM_STATE_OR_SEED_LIKE | ||||||
|
||||||
PRNG_OR_SEED_LIKE = Union[None, int, np.random.RandomState, np.random.Generator] | ||||||
|
||||||
document( | ||||||
PRNG_OR_SEED_LIKE, | ||||||
"""A pseudorandom number generator or object that can be converted to one. | ||||||
|
||||||
If is an integer or None, turns into a `np.random.Generator` seeded with that value. | ||||||
If is an instance of `np.random.Generator` or a subclass of it, return as is. | ||||||
If is an instance of `np.random.RandomState` or has a `randint` method, returns | ||||||
`np.random.default_rng(rs.randint(2**31))` | ||||||
""", | ||||||
) | ||||||
|
||||||
|
||||||
def parse_prng( | ||||||
prng_or_seed: Union[PRNG_OR_SEED_LIKE, RANDOM_STATE_OR_SEED_LIKE] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we just have only the
|
||||||
) -> np.random.Generator: | ||||||
"""Interpret an object as a pseudorandom number generator. | ||||||
|
||||||
If `prng_or_seed` is an `np.random.Generator`, return it unmodified. | ||||||
If `prng_or_seed` is None or an integer, returns `np.random.default_rng(prng_or_seed)`. | ||||||
If `prng_or_seed` is an instance of `np.random.RandomState` or has a `randint` method, | ||||||
returns `np.random.default_rng(prng_or_seed.randint(2**31))`. | ||||||
|
||||||
Args: | ||||||
prng_or_seed: The object to be used as or converted to a pseudorandom | ||||||
number generator. | ||||||
|
||||||
Returns: | ||||||
The pseudorandom number generator object. | ||||||
|
||||||
Raises: | ||||||
TypeError: If `prng_or_seed` is can't be converted to an np.random.Generator. | ||||||
""" | ||||||
if isinstance(prng_or_seed, np.random.Generator): | ||||||
return prng_or_seed | ||||||
if prng_or_seed is None or isinstance(prng_or_seed, numbers.Integral): | ||||||
return np.random.default_rng(prng_or_seed if prng_or_seed is None else int(prng_or_seed)) | ||||||
Comment on lines
+59
to
+60
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we return a singleton Generator object for |
||||||
if isinstance(prng_or_seed, np.random.RandomState): | ||||||
return np.random.default_rng(prng_or_seed.randint(2**31)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can reuse the bit generator for a more genuine conversion.
Suggested change
|
||||||
randint = getattr(prng_or_seed, "randint", None) | ||||||
if randint is not None: | ||||||
return np.random.default_rng(randint(2**31)) | ||||||
raise TypeError(f"{prng_or_seed} can't be converted to a pseudorandom number generator") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit - maybe state the actual class here ?
Suggested change
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,48 @@ | ||||||||||||||||||||||||||||
# Copyright 2024 The Cirq Developers | ||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||
# you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||
# You may obtain a copy of the License at | ||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||
# https://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
from typing import List, Union | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
import pytest | ||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
import cirq | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def _sample(prng): | ||||||||||||||||||||||||||||
return tuple(prng.random(10)) | ||||||||||||||||||||||||||||
Comment on lines
+23
to
+24
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need this. One output from |
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def test_parse_rng() -> None: | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||
eq = cirq.testing.EqualsTester() | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# An `np.random.Generator` or a seed. | ||||||||||||||||||||||||||||
group_inputs: List[Union[int, np.random.Generator]] = [42, np.random.default_rng(42)] | ||||||||||||||||||||||||||||
group: List[np.random.Generator] = [cirq.value.parse_prng(s) for s in group_inputs] | ||||||||||||||||||||||||||||
eq.add_equality_group(*[_sample(g) for g in group]) | ||||||||||||||||||||||||||||
Comment on lines
+30
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let us not check cross-group inequality. Following the
Suggested change
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# A None seed. | ||||||||||||||||||||||||||||
prng = cirq.value.parse_prng(None) | ||||||||||||||||||||||||||||
eq.add_equality_group(_sample(prng)) | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a noop check for a single value. Perhaps replace with
if you are OK with the previous suggestion to have a singleton generator for None. |
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# RandomState PRNG. | ||||||||||||||||||||||||||||
prng = cirq.value.parse_prng(np.random.RandomState(42)) | ||||||||||||||||||||||||||||
eq.add_equality_group(_sample(prng)) | ||||||||||||||||||||||||||||
Comment on lines
+39
to
+41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can check reproducibility here -
Suggested change
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# np.random module | ||||||||||||||||||||||||||||
prng = cirq.value.parse_prng(np.random) | ||||||||||||||||||||||||||||
eq.add_equality_group(_sample(prng)) | ||||||||||||||||||||||||||||
Comment on lines
+43
to
+45
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should not support creation of generator from a module, not a good practice. I don't quite see a need for it, users can pass None for a default generator.
Suggested change
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
with pytest.raises(TypeError): | ||||||||||||||||||||||||||||
_ = cirq.value.parse_prng(1.0) |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -42,3 +42,5 @@ def rand(prng): | |||||||||||||||||||||
vals = [prng.rand() for prng in prngs] | ||||||||||||||||||||||
eq = cirq.testing.EqualsTester() | ||||||||||||||||||||||
eq.add_equality_group(*vals) | ||||||||||||||||||||||
|
||||||||||||||||||||||
eq.add_equality_group(cirq.value.parse_random_state(np.random.default_rng(0)).rand()) | ||||||||||||||||||||||
Comment on lines
+45
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let us follow the above style here. Creating a new EqualsTester will only check equality within the one group, we don't need to check inequality from other groups.
Suggested change
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.