Update qchem demos to be JAX or JAX/JIT compatible #1211
base: master
Conversation
👋 Hey, looks like you've updated some demos! 🐘 Don't forget to update the `dateOfLastModification` field(s). Please hide this comment once the field(s) are updated. Thanks!
Thank you for opening this pull request. You can find the built site at this link.
Note: It may take several minutes for updates to this pull request to be reflected on the deployed site.
[sc-73935]
* Remove deprecated code from demos.
* Update various `from pennylane import numpy as np` to `import numpy as np`.
**Summary:** Fixes a merge conflict between `master` and `dev` introduced by #1232 (due to divergent `dateOfLastModification`). To reproduce this PR:
1. Run `git checkout dev`.
2. Run `git checkout -b merge-master-into-dev`.
3. Run `git merge master`.
4. Accept all incoming changes for merge conflicts on `dateOfLastModification`.

**Relevant GHA Workflow Runs:**
* https://github.com/PennyLaneAI/qml/actions/runs/11241666164
* https://github.com/PennyLaneAI/qml/actions/runs/11253088983

Co-authored-by: David Wierichs <[email protected]>
Co-authored-by: Korbinian Kottmann <[email protected]>
Co-authored-by: Ivana Kurečić <[email protected]>
- Branched off dev
- Merged master in via `--allow-unrelated-histories`

Co-authored-by: Jack Brown <[email protected]>
Thanks @austingmhuang, left my final set of comments. They are mainly about being consistent in using `qml.grad` and `jax.config.update`.
```python
import optax
import jax

jax.config.update("jax_enable_x64", True)  # use double-precision numbers
```
We should make sure that there is no warning about this in any of the demos.
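For reference, a minimal check of what the flag changes (standard JAX behavior; nothing here is taken from the demos themselves):

```python
import jax
import jax.numpy as jnp

# The flag must be set before arrays are created, or they stay float32
jax.config.update("jax_enable_x64", True)
print(jnp.array(1.0).dtype)  # float64 with the flag; float32 without it
```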
```python
import pennylane as qml
import numpy as np

import jax

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
```
Why do we need two jax config updates in this demo, when only one was used in the others?
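For context, a sketch of what each flag controls (standard JAX behavior, not specific to this demo):

```python
import jax

# Forces JAX onto the CPU even when a GPU/TPU backend is installed;
# on a CPU-only machine this is effectively a no-op.
jax.config.update("jax_platform_name", "cpu")

# Switches default dtypes from float32/int32 to float64/int64,
# which the precision-sensitive qchem demos rely on.
jax.config.update("jax_enable_x64", True)
```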
```diff
@@ -169,7 +169,7 @@
 H_tapered = qml.taper(H, generators, paulixops, paulix_sector)
 H_tapered_coeffs, H_tapered_ops = H_tapered.terms()
-H_tapered = qml.Hamiltonian(np.real(H_tapered_coeffs), H_tapered_ops)
+H_tapered = qml.Hamiltonian(jnp.real(jnp.array(H_tapered_coeffs)), H_tapered_ops)
```
So, the initial jax array gets changed to something else during the workflow? I was expecting no type change.
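A hedged reading of what the diff does, continuing from the snippet above (`jnp` is `jax.numpy`; the exact type `terms()` returns depends on the workflow upstream):

```python
import jax.numpy as jnp

coeffs, ops = H_tapered.terms()        # may come back as a list or NumPy array
coeffs = jnp.real(jnp.array(coeffs))   # cast explicitly back to a real JAX array
H_tapered = qml.Hamiltonian(coeffs, ops)
```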
```python
def update_step(i, params, opt_state):
    """Perform a single gradient update step"""
    grads = catalyst.grad(tapered_circuit)(params)
```
Any reason for not using `qml.grad` here?
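For comparison, a sketch of the full optax update step this snippet belongs to (the optimizer choice and learning rate are assumed, not taken from the diff; `tapered_circuit` follows the demo):

```python
import catalyst
import optax

opt = optax.sgd(learning_rate=0.3)  # hypothetical optimizer and learning rate

def update_step(i, params, opt_state):
    """Perform a single gradient update step."""
    # catalyst.grad differentiates inside qjit-compiled workflows,
    # which is presumably why qml.grad is not used here
    grads = catalyst.grad(tapered_circuit)(params)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state
```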
```diff
@@ -134,7 +134,7 @@
 import jax

-jax.config.update("jax_platform_name", "cpu")
+jax.config.update('jax_platform_name', 'cpu')
```
Do we really need this? Shouldn't this run on the CPU by default?
It's in the original demo... I could delete it or just leave it be.
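One quick way to check the reviewer's point on a given machine (standard JAX API; not from the demo):

```python
import jax

# On a CPU-only machine this already prints "cpu", so the explicit
# jax_platform_name update is only a safeguard against accelerator backends.
print(jax.default_backend())
```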
```diff
@@ -44,7 +44,7 @@
 #
 import pennylane as qml
-from pennylane import numpy as np
+import numpy as np
```
We don't need `jax.numpy` here?
No, this part of the demo is not JIT compiled so we don't necessarily need JAX.
```diff
@@ -159,7 +157,7 @@
 # Let's take this opportunity to create the Hartree-Fock initial state, to compare the
 # other states against it later on.

-from pennylane import numpy as np
+import numpy as np

 hf_primer = ([[3, 0, 0]], np.array([1.0]))
 wf_hf = import_state(hf_primer)
```
Does `import_state` work with `jnp`?
When I did my checks it worked with jnp as well. But we don't use jnp here since vanilla np works just as well (it might even be faster with vanilla np, though I'm not sure).
I just thought that if we could use `jnp` instead, we wouldn't need to import numpy 🤔 it makes things more uniform (only if it is possible).
I think it's preferred to only import jnp when it's necessary, e.g. the VQE portion where we start doing `jax.grad`. For the beginning of the demo, I think we should stick to vanilla numpy, since initial state preparation itself doesn't require `jnp`; I personally find it a little weird to use jnp there. I can delete this additional `import numpy as np` though, since numpy is imported at the start of the demo 😅
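A small sketch of the claim settled in this thread (`import_state` is assumed to be imported earlier in the demo, as in the diff; the `jnp` variant reflects the author's checks, not the merged code):

```python
import numpy as np
import jax.numpy as jnp

# Vanilla NumPy, as the demo keeps it for state preparation:
wf_hf = import_state(([[3, 0, 0]], np.array([1.0])))

# The jnp primer also worked in the author's checks, but jnp is not
# needed until the differentiated (VQE) portion of the demo:
wf_hf_jnp = import_state(([[3, 0, 0]], jnp.array([1.0])))
```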
Steps used:
* Checkout new branch from `dev`
* Merge `master` using `git merge master`
* Publish branch and open PR

Co-authored-by: Jack Brown <[email protected]>
Co-authored-by: Ivana Kurečić <[email protected]>
Co-authored-by: bellekaplan <[email protected]>
Title:
Update qchem demos to be JAX or JAX/JIT compatible
Summary:
As we phase out autograd, we need to ensure that our current qchem demos work with either vanilla numpy or JAX. Demos that are JIT compatible were updated to use JIT as well. For demos that don't involve any differentiation, we opt to use vanilla numpy.
Notes: the qubit tapering demo is not JIT compatible due to two issues:
1) Wire ordering is not ascending (not a bug, but might need attention)
2) `qml.taper_operation()` uses the Exp operator, which has a conditional statement (PennyLaneAI/pennylane#5993)
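For illustration, a minimal sketch of the JAX/JIT pattern the updated demos move toward (the molecule, wire choices, and gate are illustrative placeholders, not taken from any specific demo):

```python
import numpy as np
import jax
import jax.numpy as jnp
import pennylane as qml

jax.config.update("jax_enable_x64", True)  # qchem needs double precision

# Toy H2 molecule; geometry in atomic units (illustrative values)
symbols = ["H", "H"]
geometry = np.array([[0.0, 0.0, -0.66], [0.0, 0.0, 0.66]])
H, qubits = qml.qchem.molecular_hamiltonian(symbols, geometry)

dev = qml.device("default.qubit", wires=qubits)

@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.BasisState(np.array([1, 1, 0, 0]), wires=range(qubits))  # HF state
    qml.DoubleExcitation(param, wires=[0, 1, 2, 3])
    return qml.expval(H)

# JIT-compile the energy and its gradient in one pass
energy_and_grad = jax.jit(jax.value_and_grad(circuit))
print(energy_and_grad(jnp.array(0.1)))
```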
Relevant references:
Possible Drawbacks:
Related GitHub Issues:
[sc-69776]