Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create autodiff doc for mcsolve #65

Merged
merged 7 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion doc/source/autodiff.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,67 @@ should work:
result = solver.run(ket, [0, 1], e_ops=qt.num(2).to("jax"), args={"w":w})
return result.e_data[0][1].real

jax.grad(f)(0.5, solver)
jax.grad(f)(0.5, solver)


Auto differentiation in ``mcsolve``
===================================


Here is an example to use jax auto differentiation with `mcsolve`.
The automatic differentiation (`jax.grad`) in `mcsolve` does not support parallel map operations.
To ensure accurate gradient computations, please use the default serial execution instead of
parallel mapping within `mcsolve`.


.. code-block::

import qutip_jax
import qutip
import jax
import jax.numpy as jnp
from functools import partial
from qutip import mcsolve

# Use JAX backend for QuTiP
qutip_jax.set_as_default()

# Define time-dependent functions
@partial(jax.jit, static_argnames=("omega",))
def H_1_coeff(t, omega):
return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t)

# Define operators and states
size = 10
a = qutip.tensor(qutip.destroy(size), qutip.qeye(2)).to('jaxdia') # Annihilation operator
sm = qutip.qeye(size).to('jaxdia') & qutip.sigmax().to('jaxdia') # Example spin operator

# Define the Hamiltonian
H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm
H_1_op = sm * a.dag() + sm.dag() * a

# Initialize the Hamiltonian with time-dependent coefficients
H = [H_0, [H_1_op, qutip.coefficient(H_1_coeff, args={"omega": 1.0})]]

# Define initial states, mixed states are not supported
state = qutip.basis(size, size - 1).to('jax') & qutip.basis(2, 1).to('jax')

# Define collapse operators and observables
c_ops = [jnp.sqrt(0.1) * a]
e_ops = [a.dag() * a, sm.dag() * sm]

# Time list
tlist = jnp.linspace(0.0, 10.0, 101)

# Define the function for which we want to compute the gradient
def f(omega):
result = mcsolve(
H, state, tlist, c_ops, e_ops, ntraj=10,
args={"omega": omega},
options={"method": "diffrax"}
)
# Return the expectation value of the number operator at the final time
return result.expect[0][-1].real

# Compute the gradient
gradient = jax.grad(f)(1.0)
12 changes: 6 additions & 6 deletions doc/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ To enable JAX as the backend for QuTiP, you need to set the backend to `jax` usi
import qutip_jax

# Use JAX as the backend
qutip_jax.use_jax_backend()
qutip_jax.set_as_default()

Using `jax.jit` with QuTiP
--------------------------
Expand All @@ -35,7 +35,7 @@ Using `jax.jit` with QuTiP
import qutip_jax

# Use JAX as the backend
qutip_jax.use_jax_backend()
qutip_jax.set_as_default()

# Define states
psi = basis(2, 0).to("jax")
Expand All @@ -57,7 +57,7 @@ Using `jax.jit` with QuTiP
import qutip_jax

# Use JAX as the backend
qutip_jax.use_jax_backend()
qutip_jax.set_as_default()

# Define a density matrix
rho = ket2dm(psi).to("jax")
Expand Down Expand Up @@ -87,7 +87,7 @@ To compute the gradient, you need a function that returns a scalar. Note that `j
import qutip_jax

# Use JAX as the backend
qutip_jax.use_jax_backend()
qutip_jax.set_as_default()

# Define bra and ket states
bra_state = basis(2, 0).dag()
Expand Down Expand Up @@ -119,7 +119,7 @@ The `trace_dist` function supports `oper` states for gradient computation.
import qutip_jax

# Use JAX as the backend
qutip_jax.use_jax_backend()
qutip_jax.set_as_default()

# Define an operator state
oper_state = rand_dm(2)
Expand Down Expand Up @@ -147,5 +147,5 @@ If you want to switch back to the default backend (NumPy), use the following:

.. code-block:: python

qutip.settings.core["numpy_backend"] = np
qutip_jax.set_as_default(revert = True)

Loading