
Update qchem demos to be JAX or JAX/JIT compatible #1211

Open
wants to merge 60 commits into base: master
Conversation

austingmhuang (Contributor)
Title:
Update qchem demos to be JAX or JAX/JIT compatible

Summary:
As we phase out Autograd, we need to ensure that our current qchem demos work with either vanilla NumPy or JAX. Demos that are JIT compatible were updated to use JIT as well. For demos that don't involve any differentiation, we opt for vanilla NumPy.

Notes: the qubit tapering demo is not JIT compatible due to two issues (see the sketch below):
1) Wire ordering is not ascending (not a bug, but might need attention).
2) qml.taper_operation() uses the Exp operator, which has a conditional statement (PennyLaneAI/pennylane#5993).
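
For context, a minimal sketch of the JAX/JIT pattern the updated demos move toward; the two-qubit circuit below is illustrative, not taken from any particular demo:

import jax
import jax.numpy as jnp
import pennylane as qml

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", wires=2)

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

print(circuit(jnp.array(0.54)))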

Relevant references:

Possible Drawbacks:

Related GitHub Issues:
[sc-69776]


github-actions bot commented Sep 5, 2024

👋 Hey, looks like you've updated some demos!

🐘 Don't forget to update the dateOfLastModification in the associated metadata files so your changes are reflected in Glass Onion (search and recommendations).

Please hide this comment once the field(s) are updated. Thanks!


github-actions bot commented Oct 1, 2024

Thank you for opening this pull request.

You can find the built site at this link.

Deployment Info:

  • Pull Request ID: 1211
  • Deployment SHA: 8f96c77e6e9445e15b29cec51b9198cb5cf39e69
    (The Deployment SHA refers to the latest commit hash the docs were built from)

Note: It may take several minutes for updates to this pull request to be reflected on the deployed site.

mudit2812 and others added 2 commits October 3, 2024 21:06
[sc-73935]

* Remove deprecated code from demos.
* Update various `from pennylane import numpy as np` to `import numpy as np`.
actions-user and others added 5 commits October 7, 2024 10:28
**Summary:**
Fixes a merge conflict between `master` and `dev` introduced by #1232 (due to divergent `dateOfLastModification`).

To reproduce this PR:
1. Run `git checkout dev`.
2. Run `git checkout -b merge-master-into-dev`.
3. Run `git merge master`.
4. Accept all incoming changes for merge conflicts on `dateOfLastModification`.

**Relevant GHA Workflow Runs:**
* https://github.com/PennyLaneAI/qml/actions/runs/11241666164
* https://github.com/PennyLaneAI/qml/actions/runs/11253088983

---------

Co-authored-by: David Wierichs <[email protected]>
Co-authored-by: Korbinian Kottmann <[email protected]>
Co-authored-by: Ivana Kurečić <[email protected]>
- Branched off dev
- Merged Master in via `--allow-unrelated-histories`

Co-authored-by: Jack Brown <[email protected]>
soranjh (Contributor) left a comment


Thanks @austingmhuang, I left my final set of comments. They are mainly about being consistent in using qml.grad and jax.config.update.

demonstrations/tutorial_adaptive_circuits.py (outdated; resolved)
demonstrations/tutorial_chemical_reactions.py (resolved)
demonstrations/tutorial_chemical_reactions.py (resolved)
import optax
import jax

jax.config.update("jax_enable_x64", True) # use double-precision numbers
Contributor

We should make sure that there is no warning about this in any of the demos.
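
For context, a minimal sketch of what the flag changes; without it, JAX arrays default to 32-bit precision:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double-precision numbers

x = jnp.array([1.0])
print(x.dtype)  # float64 with the flag set; float32 otherwise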


import pennylane as qml
import numpy as np

import jax

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
Contributor

Why do we need two jax config updates in this demo, while only one is used in the others?
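
For context, the two flags control independent settings; a minimal annotated sketch:

import jax

# Pin execution to the CPU backend; without this, JAX picks a GPU/TPU when one is visible.
jax.config.update("jax_platform_name", "cpu")

# Enable 64-bit floats; JAX defaults to 32-bit. Independent of the platform choice above.
jax.config.update("jax_enable_x64", True)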

@@ -169,7 +169,7 @@

 H_tapered = qml.taper(H, generators, paulixops, paulix_sector)
 H_tapered_coeffs, H_tapered_ops = H_tapered.terms()
-H_tapered = qml.Hamiltonian(np.real(H_tapered_coeffs), H_tapered_ops)
+H_tapered = qml.Hamiltonian(jnp.real(jnp.array(H_tapered_coeffs)), H_tapered_ops)
Contributor

So, the initial jax array gets changed to something else during the workflow? I was expecting no type change.
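
For context, a minimal sketch of the casting in question; the coefficients are made-up placeholders, and the point is only that jnp.array keeps the Hamiltonian data in JAX types:

import jax.numpy as jnp
import pennylane as qml

coeffs = [0.5 + 0j, -0.2 + 0j]  # placeholder for the coefficients from H_tapered.terms()
ops = [qml.PauliZ(0), qml.PauliZ(0) @ qml.PauliZ(1)]

H = qml.Hamiltonian(jnp.real(jnp.array(coeffs)), ops)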


def update_step(i, params, opt_state):
"""Perform a single gradient update step"""
grads = catalyst.grad(tapered_circuit)(params)
Contributor

Any reason for not using qml.grad here?
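
For context, a complete update step of this shape typically pairs the gradient with an optimizer update; this sketch uses jax.grad and optax instead of the catalyst.grad call from the diff, and the quadratic function is a stand-in for tapered_circuit:

import jax
import jax.numpy as jnp
import optax

def tapered_circuit(params):  # stand-in for the demo's QNode
    return jnp.sum(params ** 2)

opt = optax.sgd(learning_rate=0.4)

def update_step(i, params, opt_state):
    """Perform a single gradient update step."""
    grads = jax.grad(tapered_circuit)(params)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

params = jnp.array([0.1, 0.2])
opt_state = opt.init(params)
for i in range(5):
    params, opt_state = update_step(i, params, opt_state)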

demonstrations/tutorial_adaptive_circuits.py (resolved)
demonstrations/tutorial_chemical_reactions.py (resolved)
demonstrations/tutorial_classically_boosted_vqe.py (outdated; resolved)
@@ -134,7 +134,7 @@

 import jax

-jax.config.update("jax_platform_name", "cpu")
+jax.config.update('jax_platform_name', 'cpu')
Contributor

Do we really need this? Shouldn't this run on the CPU by default?

Contributor Author

It's in the original demo... I could delete it or just leave it be.

demonstrations/tutorial_eqnn_force_field.py (resolved)
demonstrations/tutorial_givens_rotations.py (outdated; resolved)
demonstrations/tutorial_givens_rotations.py (outdated; resolved)
@@ -44,7 +44,7 @@
 #

 import pennylane as qml
-from pennylane import numpy as np
+import numpy as np
Contributor

Don't we need jax.numpy here?

Contributor Author

No, this part of the demo is not JIT compiled, so we don't necessarily need JAX.

@@ -159,7 +157,7 @@

 # Let's take this opportunity to create the Hartree-Fock initial state, to compare the
 # other states against it later on.

-from pennylane import numpy as np
+import numpy as np

 hf_primer = ([[3, 0, 0]], np.array([1.0]))
 wf_hf = import_state(hf_primer)
Contributor

Does import_state work with jnp?

Contributor Author

When I did my checks, it worked with jnp as well. But we don't use jnp here since vanilla np works just as well (it might even be faster with vanilla np, though I'm not sure).
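
For reference, a minimal sketch of the jnp variant discussed here, assuming import_state is qml.qchem.import_state as used in the demo and that it accepts JAX arrays (as the checks above suggest):

import jax
import jax.numpy as jnp
from pennylane.qchem import import_state

jax.config.update("jax_enable_x64", True)

hf_primer = ([[3, 0, 0]], jnp.array([1.0]))
wf_hf = import_state(hf_primer)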

Contributor

I just thought that if we could use jnp instead, we wouldn't need to import numpy 🤔 It would make things more uniform (only if it is possible).

Contributor Author

I think it's preferred to only import jnp when it's necessary, e.g. in the VQE portion where we start doing jax.grad.

For the beginning of the demo, I think we should stick to vanilla numpy since initial state preparation itself doesn't require jnp; I personally find it a little weird to use jnp there. I can delete this additional import numpy as np, though, since numpy is already imported at the start of the demo 😅

8 participants