Source code for booz_xform_jax.jax_api
"""Pure JAX API for end-to-end Boozer transforms.
This module provides a JIT-friendly, functional interface that avoids
Python loops over surfaces and keeps all arrays in JAX. It is intended
for end-to-end differentiation with vmec_jax -> booz_xform_jax -> neo_jax.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple
import os
import jax
import jax.numpy as jnp
from .trig import _init_trig
[docs]
@dataclass(frozen=True)
class BoozXformConstants:
    """Immutable bundle of the scalar settings of a Boozer transform.

    The fields are annotated as plain ``int``/``bool`` values so the frozen
    dataclass is hashable; ``booz_xform_from_inputs`` relies on that when it
    passes an instance to ``jax.jit`` through ``static_argnames``.
    """

    # Toroidal periodicity factor; toroidal mode numbers are stored as n * nfp.
    nfp: int
    # Boozer resolution: poloidal m in [0, mboz) and toroidal n in [-nboz, nboz].
    mboz: int
    nboz: int
    # When True, the asymmetric harmonics (rmns, zmnc, lmnc, bmns, ...) are used.
    asym: bool
    # Full real-space grid sizes in theta and zeta.
    ntheta: int
    nzeta: int
    # ntheta // 2 + 1: theta rows sampled on [0, pi] when asym is False.
    nu2_b: int
    # Largest |m| and |n / nfp| found in the non-Nyquist input spectrum.
    mmax_non: int
    nmax_non: int
    # Largest |m| and |n / nfp| found in the Nyquist input spectrum.
    mmax_nyq: int
    nmax_nyq: int
[docs]
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class BoozXformGrids:
    """Array-valued companion of :class:`BoozXformConstants`.

    Registered as a JAX pytree: all four arrays are leaves (traced under
    ``jax.jit``) and no auxiliary static data is carried.
    """

    # Flattened theta / zeta coordinates of the real-space sample points.
    theta_grid: jnp.ndarray
    zeta_grid: jnp.ndarray
    # Boozer mode numbers: poloidal m and toroidal n (already multiplied by nfp).
    xm_b: jnp.ndarray
    xn_b: jnp.ndarray

    def tree_flatten(self):
        # Leaf order must match the unpacking in tree_unflatten.
        leaf_names = ("theta_grid", "zeta_grid", "xm_b", "xn_b")
        leaves = tuple(getattr(self, name) for name in leaf_names)
        return leaves, None

    @classmethod
    def tree_unflatten(cls, aux, children):
        # aux is always None (see tree_flatten); only the leaves matter.
        theta, zeta, modes_m, modes_n = children
        return cls(theta_grid=theta, zeta_grid=zeta, xm_b=modes_m, xn_b=modes_n)
def _prepare_mode_lists(mboz: int, nboz: int, nfp: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Prepare Boozer mode indices following the C++/Fortran convention."""
m_list: list[int] = []
n_list: list[int] = []
for m in range(mboz):
if m == 0:
for n in range(0, nboz + 1):
m_list.append(m)
n_list.append(n * nfp)
else:
for n in range(-nboz, nboz + 1):
m_list.append(m)
n_list.append(n * nfp)
return jnp.asarray(m_list, dtype=jnp.int32), jnp.asarray(n_list, dtype=jnp.int32)
def _prepare_grids(mboz: int, nboz: int, nfp: int, asym: bool) -> Tuple[int, int, int, jnp.ndarray, jnp.ndarray]:
"""Prepare flattened (theta, zeta) grids following BOOZ_XFORM conventions."""
ntheta_full = 2 * (2 * mboz + 1)
nzeta_full = 2 * (2 * nboz + 1) if nboz > 0 else 1
nu2_b = ntheta_full // 2 + 1
nu3_b = ntheta_full if asym else nu2_b
d_theta = (2.0 * jnp.pi) / ntheta_full
d_zeta = (2.0 * jnp.pi) / (nfp * nzeta_full)
theta_vals = jnp.arange(nu3_b) * d_theta
zeta_vals = jnp.arange(nzeta_full) * d_zeta
theta_grid = jnp.repeat(theta_vals, nzeta_full)
zeta_grid = jnp.tile(zeta_vals, nu3_b)
return int(ntheta_full), int(nzeta_full), int(nu2_b), theta_grid, zeta_grid
[docs]
def prepare_booz_xform_constants(
    *,
    nfp: int,
    mboz: int,
    nboz: int,
    asym: bool,
    xm: Sequence[int],
    xn: Sequence[int],
    xm_nyq: Sequence[int],
    xn_nyq: Sequence[int],
) -> tuple[BoozXformConstants, BoozXformGrids]:
    """Compute static constants and grids for the JAX Boozer transform.

    This helper runs eagerly on the host (it converts reductions to Python
    ints) and is meant to be called before JIT compilation.

    Args:
        nfp: Toroidal periodicity factor; must be >= 1.
        mboz: Number of Boozer poloidal modes (m = 0..mboz-1); must be >= 1.
        nboz: Maximum Boozer toroidal index; must be >= 0.
        asym: Whether asymmetric harmonics are included.
        xm, xn: Non-Nyquist input mode numbers (xn already scaled by nfp).
        xm_nyq, xn_nyq: Nyquist input mode numbers (xn_nyq scaled by nfp).

    Returns:
        ``(constants, grids)``. ``constants`` holds only Python scalars and is
        therefore hashable, so it can be used as a ``jax.jit`` static argument.

    Raises:
        ValueError: if ``nfp < 1``, ``mboz < 1`` or ``nboz < 0``.
    """
    # Coerce to plain Python scalars: 0-d JAX/NumPy arrays are unhashable and
    # would break ``static_argnames=("constants",)`` in the jitted entry point.
    nfp = int(nfp)
    mboz = int(mboz)
    nboz = int(nboz)
    asym = bool(asym)
    if nfp < 1:
        # nfp is used as an integer divisor of the toroidal mode numbers below.
        raise ValueError(f"nfp must be a positive integer, got {nfp}")
    if mboz < 1:
        # At least the m = 0 row is needed for the (0, 0) Boozer mode.
        raise ValueError(f"mboz must be >= 1, got {mboz}")
    if nboz < 0:
        raise ValueError(f"nboz must be >= 0, got {nboz}")

    xm_arr = jnp.asarray(xm, dtype=jnp.int32)
    xn_arr = jnp.asarray(xn, dtype=jnp.int32)
    xm_nyq_arr = jnp.asarray(xm_nyq, dtype=jnp.int32)
    xn_nyq_arr = jnp.asarray(xn_nyq, dtype=jnp.int32)

    # Highest poloidal / toroidal indices needed by the real-space trig tables.
    mmax_non = int(jnp.max(jnp.abs(xm_arr)))
    nmax_non = int(jnp.max(jnp.abs(xn_arr // nfp)))
    mmax_nyq = int(jnp.max(jnp.abs(xm_nyq_arr)))
    nmax_nyq = int(jnp.max(jnp.abs(xn_nyq_arr // nfp)))

    ntheta, nzeta, nu2_b, theta_grid, zeta_grid = _prepare_grids(mboz, nboz, nfp, asym)
    # _prepare_mode_lists already returns int32 arrays.
    xm_b, xn_b = _prepare_mode_lists(mboz, nboz, nfp)

    constants = BoozXformConstants(
        nfp=nfp,
        mboz=mboz,
        nboz=nboz,
        asym=asym,
        ntheta=ntheta,
        nzeta=nzeta,
        nu2_b=nu2_b,
        mmax_non=mmax_non,
        nmax_non=nmax_non,
        mmax_nyq=mmax_nyq,
        nmax_nyq=nmax_nyq,
    )
    grids = BoozXformGrids(
        theta_grid=theta_grid,
        zeta_grid=zeta_grid,
        xm_b=xm_b,
        xn_b=xn_b,
    )
    return constants, grids
[docs]
def prepare_booz_xform_constants_from_inputs(
    *,
    inputs,
    mboz: int,
    nboz: int,
    asym: bool,
) -> tuple[BoozXformConstants, BoozXformGrids]:
    """Build Boozer constants and grids directly from a VMEC -> Boozer bundle.

    Only ``inputs.nfp``, ``inputs.xm``, ``inputs.xn``, ``inputs.xm_nyq`` and
    ``inputs.xn_nyq`` are read; everything else is delegated to
    :func:`prepare_booz_xform_constants`.
    """
    mode_numbers = {
        field: jnp.asarray(getattr(inputs, field))
        for field in ("xm", "xn", "xm_nyq", "xn_nyq")
    }
    return prepare_booz_xform_constants(
        nfp=int(jnp.asarray(inputs.nfp)),
        mboz=int(mboz),
        nboz=int(nboz),
        asym=bool(asym),
        **mode_numbers,
    )
def _surface_transform(
    rmnc: jnp.ndarray,
    rmns: jnp.ndarray,
    zmnc: jnp.ndarray,
    zmns: jnp.ndarray,
    lmnc: jnp.ndarray,
    lmns: jnp.ndarray,
    bmnc: jnp.ndarray,
    bsubumnc: jnp.ndarray,
    bsubvmnc: jnp.ndarray,
    iota: jnp.ndarray,
    *,
    constants: BoozXformConstants,
    grids: BoozXformGrids,
    tcos_non: jnp.ndarray,
    tsin_non: jnp.ndarray,
    tcos_nyq: jnp.ndarray,
    tsin_nyq: jnp.ndarray,
    m_non_f: jnp.ndarray,
    n_non_f: jnp.ndarray,
    m_nyq_f: jnp.ndarray,
    n_nyq_f: jnp.ndarray,
    idx_theta0: jnp.ndarray,
    idx_thetapi: jnp.ndarray,
    m_b: jnp.ndarray,
    abs_n_b: jnp.ndarray,
    sign_b: jnp.ndarray,
    bmns: Optional[jnp.ndarray] = None,
    bsubumns: Optional[jnp.ndarray] = None,
    bsubvmns: Optional[jnp.ndarray] = None,
    fourier_mode: str = "vectorized",
    trig_f32: bool = False,
) -> Tuple[jnp.ndarray, ...]:
    """Compute Boozer spectra for a single surface.

    The positional arguments are one surface's VMEC coefficient vectors plus
    ``iota``. ``tcos_*``/``tsin_*`` are precomputed real-space tables with one
    row per grid point and one column per input mode. ``m_*_f``/``n_*_f`` are
    the matching float mode numbers (n already scaled by nfp).

    Args:
        idx_theta0, idx_thetapi: Flattened grid indices of the theta = 0 and
            theta = pi rows. They get half weight when ``constants.asym`` is
            False (trapezoid end points on [0, pi]).
        m_b, abs_n_b, sign_b: Boozer mode indices, ``|n / nfp|`` and the sign
            of n used to assemble the Boozer trig tables.
        bmns, bsubumns, bsubvmns: Asymmetric Nyquist spectra; only used when
            ``constants.asym`` is True. Each is applied independently of the
            others.
        fourier_mode: ``"vectorized"`` builds full (points x modes) Boozer
            tables. ``"streamed"`` loops over Boozer modes with
            ``jax.lax.fori_loop`` and never materializes them.
        trig_f32: Cast the Boozer trig tables to float32.

    Returns:
        ``(bmnc_b, bmns_b, rmnc_b, rmns_b, zmnc_b, zmns_b, numnc_b, numns_b,
        gmnc_b, gmns_b, Boozer_I, Boozer_G)``. The five asymmetric spectra
        (bmns_b, rmns_b, zmnc_b, numnc_b, gmns_b) are zeros when
        ``constants.asym`` is False.
    """
    nfp = constants.nfp
    theta_grid = grids.theta_grid
    zeta_grid = grids.zeta_grid

    # Boozer I/G from the m = n = 0 Nyquist mode. size=1 keeps the lookup
    # shape-static under jit/vmap.
    idx00 = jnp.where((m_nyq_f == 0) & (n_nyq_f == 0), size=1)[0][0]
    Boozer_I = bsubumnc[idx00]
    Boozer_G = bsubvmnc[idx00]

    # w spectrum from B_theta and B_zeta. Safe denominators avoid inf/NaN
    # tangents at m = n = 0 when this kernel is differentiated.
    m_nonzero = m_nyq_f != 0.0
    n_nonzero_only = jnp.logical_and(~m_nonzero, n_nyq_f != 0.0)
    m_nyq_safe = jnp.where(m_nonzero, m_nyq_f, 1.0)
    n_nyq_safe = jnp.where(n_nonzero_only, n_nyq_f, 1.0)
    wmns = jnp.where(
        m_nonzero,
        bsubumnc / m_nyq_safe,
        jnp.where(n_nonzero_only, -bsubvmnc / n_nyq_safe, 0.0),
    )
    if constants.asym and bsubumns is not None and bsubvmns is not None:
        wmnc = jnp.where(
            m_nonzero,
            -bsubumns / m_nyq_safe,
            jnp.where(n_nonzero_only, bsubvmns / n_nyq_safe, 0.0),
        )
    else:
        wmnc = None

    # Non-Nyquist R, Z, lambda and lambda derivatives on the real-space grid.
    r = jnp.einsum("ij,j->i", tcos_non, rmnc)
    z = jnp.einsum("ij,j->i", tsin_non, zmns)
    lam = jnp.einsum("ij,j->i", tsin_non, lmns)
    dlam_dth = jnp.einsum("ij,j->i", tcos_non, lmns * m_non_f)
    dlam_dze = -jnp.einsum("ij,j->i", tcos_non, lmns * n_non_f)
    if constants.asym:
        r = r + jnp.einsum("ij,j->i", tsin_non, rmns)
        z = z + jnp.einsum("ij,j->i", tcos_non, zmnc)
        lam = lam + jnp.einsum("ij,j->i", tcos_non, lmnc)
        dlam_dth = dlam_dth - jnp.einsum("ij,j->i", tsin_non, lmnc * m_non_f)
        dlam_dze = dlam_dze + jnp.einsum("ij,j->i", tsin_non, lmnc * n_non_f)

    # Nyquist w, its derivatives, and |B|.
    w = jnp.einsum("ij,j->i", tsin_nyq, wmns)
    dw_dth = jnp.einsum("ij,j->i", tcos_nyq, wmns * m_nyq_f)
    dw_dze = -jnp.einsum("ij,j->i", tcos_nyq, wmns * n_nyq_f)
    bmod = jnp.einsum("ij,j->i", tcos_nyq, bmnc)
    # The asymmetric w terms and the asymmetric |B| terms come from different
    # inputs. They are applied independently so that omitting one input does
    # not silently drop the other's contribution. wmnc is only non-None when
    # constants.asym is True.
    if wmnc is not None:
        w = w + jnp.einsum("ij,j->i", tcos_nyq, wmnc)
        dw_dth = dw_dth - jnp.einsum("ij,j->i", tsin_nyq, wmnc * m_nyq_f)
        dw_dze = dw_dze + jnp.einsum("ij,j->i", tsin_nyq, wmnc * n_nyq_f)
    if constants.asym and bmns is not None:
        bmod = bmod + jnp.einsum("ij,j->i", tsin_nyq, bmns)

    # Boozer angles and the Jacobian of (theta_B, zeta_B) w.r.t. VMEC angles.
    GI = Boozer_G + iota * Boozer_I
    one_over_GI = 1.0 / GI
    nu = one_over_GI * (w - Boozer_I * lam)
    theta_B = theta_grid + lam + iota * nu
    zeta_B = zeta_grid + nu
    dnu_dze = one_over_GI * (dw_dze - Boozer_I * dlam_dze)
    dnu_dth = one_over_GI * (dw_dth - Boozer_I * dlam_dth)
    dB_dvmec = (1.0 + dlam_dth) * (1.0 + dnu_dze) + (iota - dlam_dze) * dnu_dth

    # Boozer trig tables evaluated on (theta_B, zeta_B).
    cosm_b, sinm_b, cosn_b, sinn_b = _init_trig(
        theta_B, zeta_B, constants.mboz, constants.nboz, nfp
    )
    if trig_f32:
        cosm_b = cosm_b.astype(jnp.float32)
        sinm_b = sinm_b.astype(jnp.float32)
        cosn_b = cosn_b.astype(jnp.float32)
        sinn_b = sinn_b.astype(jnp.float32)
    if not constants.asym:
        # Half weight on the theta = 0 and theta = pi end rows of the half grid.
        cosm_b = cosm_b.at[idx_theta0, :].set(cosm_b[idx_theta0, :] * 0.5)
        cosm_b = cosm_b.at[idx_thetapi, :].set(cosm_b[idx_thetapi, :] * 0.5)
        sinm_b = sinm_b.at[idx_theta0, :].set(sinm_b[idx_theta0, :] * 0.5)
        sinm_b = sinm_b.at[idx_thetapi, :].set(sinm_b[idx_thetapi, :] * 0.5)

    boozer_jac = GI / (bmod * bmod)
    if constants.asym:
        fourier_factor0 = 2.0 / (constants.ntheta * constants.nzeta)
    else:
        fourier_factor0 = 2.0 / ((constants.nu2_b - 1) * constants.nzeta)
    # dtype=float is JAX's default float type: float64 with x64 enabled,
    # float32 otherwise. That is the same result jnp.float64 gave, but without
    # the "float64 not available" warning in x32 mode.
    fourier_factor = jnp.ones((m_b.shape[0],), dtype=float) * fourier_factor0
    # Boozer mode 0 is (m, n) = (0, 0) and gets half weight.
    fourier_factor = fourier_factor.at[0].set(fourier_factor0 * 0.5)

    # Integrands shared by both projection strategies.
    base_b = bmod * dB_dvmec
    base_r = r * dB_dvmec
    base_z = z * dB_dvmec
    base_nu = nu * dB_dvmec
    base_g = boozer_jac * dB_dvmec

    if fourier_mode == "streamed":
        sign_flat = jnp.reshape(sign_b, (-1,))
        n_modes = m_b.shape[0]

        def body(k, state):
            (
                bmnc_b,
                bmns_b,
                rmnc_b,
                rmns_b,
                zmnc_b,
                zmns_b,
                numnc_b,
                numns_b,
                gmnc_b,
                gmns_b,
            ) = state
            m_idx = m_b[k]
            n_idx = abs_n_b[k]
            sign = sign_flat[k]
            cosm = jax.lax.dynamic_index_in_dim(cosm_b, m_idx, axis=1, keepdims=False)
            sinm = jax.lax.dynamic_index_in_dim(sinm_b, m_idx, axis=1, keepdims=False)
            cosn = jax.lax.dynamic_index_in_dim(cosn_b, n_idx, axis=1, keepdims=False)
            sinn = jax.lax.dynamic_index_in_dim(sinn_b, n_idx, axis=1, keepdims=False)
            tcos = cosm * cosn + sinm * sinn * sign
            tsin = sinm * cosn - cosm * sinn * sign
            ff = fourier_factor[k]
            bmnc_b = bmnc_b.at[k].set(ff * jnp.sum(tcos * base_b))
            rmnc_b = rmnc_b.at[k].set(ff * jnp.sum(tcos * base_r))
            zmns_b = zmns_b.at[k].set(ff * jnp.sum(tsin * base_z))
            numns_b = numns_b.at[k].set(ff * jnp.sum(tsin * base_nu))
            gmnc_b = gmnc_b.at[k].set(ff * jnp.sum(tcos * base_g))
            if constants.asym:
                bmns_b = bmns_b.at[k].set(ff * jnp.sum(tsin * base_b))
                rmns_b = rmns_b.at[k].set(ff * jnp.sum(tsin * base_r))
                zmnc_b = zmnc_b.at[k].set(ff * jnp.sum(tcos * base_z))
                numnc_b = numnc_b.at[k].set(ff * jnp.sum(tcos * base_nu))
                gmns_b = gmns_b.at[k].set(ff * jnp.sum(tsin * base_g))
            return (
                bmnc_b,
                bmns_b,
                rmnc_b,
                rmns_b,
                zmnc_b,
                zmns_b,
                numnc_b,
                numns_b,
                gmnc_b,
                gmns_b,
            )

        zeros = jnp.zeros((n_modes,), dtype=base_b.dtype)
        (
            bmnc_b,
            bmns_b,
            rmnc_b,
            rmns_b,
            zmnc_b,
            zmns_b,
            numnc_b,
            numns_b,
            gmnc_b,
            gmns_b,
        ) = jax.lax.fori_loop(0, n_modes, body, (zeros,) * 10)
    else:
        cosm_b_m = jnp.take(cosm_b, m_b, axis=1)
        sinm_b_m = jnp.take(sinm_b, m_b, axis=1)
        cosn_b_n = jnp.take(cosn_b, abs_n_b, axis=1)
        sinn_b_n = jnp.take(sinn_b, abs_n_b, axis=1)
        tcos_modes = cosm_b_m * cosn_b_n + sinm_b_m * sinn_b_n * sign_b
        tsin_modes = sinm_b_m * cosn_b_n - cosm_b_m * sinn_b_n * sign_b

        def project_cos(field: jnp.ndarray) -> jnp.ndarray:
            return fourier_factor * jnp.einsum("ij,i->j", tcos_modes, field)

        def project_sin(field: jnp.ndarray) -> jnp.ndarray:
            return fourier_factor * jnp.einsum("ij,i->j", tsin_modes, field)

        bmnc_b = project_cos(base_b)
        rmnc_b = project_cos(base_r)
        zmns_b = project_sin(base_z)
        numns_b = project_sin(base_nu)
        gmnc_b = project_cos(base_g)
        if constants.asym:
            bmns_b = project_sin(base_b)
            rmns_b = project_sin(base_r)
            zmnc_b = project_cos(base_z)
            numnc_b = project_cos(base_nu)
            gmns_b = project_sin(base_g)
        else:
            zeros = jnp.zeros_like(bmnc_b)
            bmns_b = zeros
            rmns_b = zeros
            zmnc_b = zeros
            numnc_b = zeros
            gmns_b = zeros

    return (
        bmnc_b,
        bmns_b,
        rmnc_b,
        rmns_b,
        zmnc_b,
        zmns_b,
        numnc_b,
        numns_b,
        gmnc_b,
        gmns_b,
        Boozer_I,
        Boozer_G,
    )
[docs]
def booz_xform_jax_impl(
    rmnc: jnp.ndarray,
    zmns: jnp.ndarray,
    lmns: jnp.ndarray,
    bmnc: jnp.ndarray,
    bsubumnc: jnp.ndarray,
    bsubvmnc: jnp.ndarray,
    iota: jnp.ndarray,
    *,
    xm: jnp.ndarray,
    xn: jnp.ndarray,
    xm_nyq: jnp.ndarray,
    xn_nyq: jnp.ndarray,
    constants: BoozXformConstants,
    grids: BoozXformGrids,
    rmns: Optional[jnp.ndarray] = None,
    zmnc: Optional[jnp.ndarray] = None,
    lmnc: Optional[jnp.ndarray] = None,
    bmns: Optional[jnp.ndarray] = None,
    bsubumns: Optional[jnp.ndarray] = None,
    bsubvmns: Optional[jnp.ndarray] = None,
    surface_indices: Optional[jnp.ndarray] = None,
) -> dict:
    """JAX-native Boozer transform over all (or selected) surfaces.

    Spectral inputs have the surface dimension first: shape (ns, mn_non) for
    non-Nyquist arrays and (ns, mn_nyq) for Nyquist arrays; ``iota`` has
    shape (ns,). The surfaces are processed with ``jax.vmap``.

    Args:
        xm, xn, xm_nyq, xn_nyq: Input mode numbers (converted to int32).
        constants, grids: Output of :func:`prepare_booz_xform_constants`.
        rmns, zmnc, lmnc, bmns, bsubumns, bsubvmns: Optional asymmetric
            spectra. Missing ones are replaced by zeros.
        surface_indices: Optional indices (array or integer sequence) of the
            surfaces to transform.

    Environment:
        ``BOOZ_XFORM_JAX_FOURIER_MODE`` (``vectorized``/``streamed``) and
        ``BOOZ_XFORM_JAX_TRIG_F32`` select implementation details. Under
        ``jax.jit`` they are read at trace time, so changing them afterwards
        does not affect an already-compiled function.

    Returns:
        Dict of Boozer outputs (``bmnc_b``, ``rmnc_b``, ... , ``jlist``).
        ``ns_b`` is the number of surfaces in the *input* arrays, before any
        selection.

    Raises:
        ValueError: if ``BOOZ_XFORM_JAX_FOURIER_MODE`` has an unsupported value.
    """
    ns_b_full = int(rmnc.shape[0])
    if surface_indices is not None:
        # Accept plain sequences too: jnp.take rejects lists, and the jlist
        # arithmetic below needs an array.
        surface_indices = jnp.asarray(surface_indices)

        def _select(arr):
            return None if arr is None else jnp.take(arr, surface_indices, axis=0)

        rmnc = _select(rmnc)
        zmns = _select(zmns)
        lmns = _select(lmns)
        bmnc = _select(bmnc)
        bsubumnc = _select(bsubumnc)
        bsubvmnc = _select(bsubvmnc)
        iota = _select(iota)
        rmns = _select(rmns)
        zmnc = _select(zmnc)
        lmnc = _select(lmnc)
        bmns = _select(bmns)
        bsubumns = _select(bsubumns)
        bsubvmns = _select(bsubvmns)

    xm_non_j = jnp.asarray(xm, dtype=jnp.int32)
    xn_non_j = jnp.asarray(xn, dtype=jnp.int32)
    xm_nyq_j = jnp.asarray(xm_nyq, dtype=jnp.int32)
    xn_nyq_j = jnp.asarray(xn_nyq, dtype=jnp.int32)

    # NOTE: read at trace time when this function runs under jax.jit.
    fourier_mode = os.getenv("BOOZ_XFORM_JAX_FOURIER_MODE", "vectorized").strip().lower()
    if fourier_mode not in {"vectorized", "streamed"}:
        raise ValueError(f"Unsupported BOOZ_XFORM_JAX_FOURIER_MODE '{fourier_mode}'")
    trig_f32 = os.getenv("BOOZ_XFORM_JAX_TRIG_F32", "0").strip().lower() in {"1", "true", "yes", "on"}

    # Precompute trig tables and mode combinations once for all surfaces.
    cosm, sinm, cosn, sinn = _init_trig(
        grids.theta_grid, grids.zeta_grid, constants.mmax_non, constants.nmax_non, constants.nfp
    )
    cosm_nyq, sinm_nyq, cosn_nyq, sinn_nyq = _init_trig(
        grids.theta_grid, grids.zeta_grid, constants.mmax_nyq, constants.nmax_nyq, constants.nfp
    )
    if trig_f32:
        cosm = cosm.astype(jnp.float32)
        sinm = sinm.astype(jnp.float32)
        cosn = cosn.astype(jnp.float32)
        sinn = sinn.astype(jnp.float32)
        cosm_nyq = cosm_nyq.astype(jnp.float32)
        sinm_nyq = sinm_nyq.astype(jnp.float32)
        cosn_nyq = cosn_nyq.astype(jnp.float32)
        sinn_nyq = sinn_nyq.astype(jnp.float32)

    tcos_non, tsin_non = _mode_trig_tables(
        cosm, sinm, cosn, sinn, xm_non_j, xn_non_j, constants.nfp
    )
    tcos_nyq, tsin_nyq = _mode_trig_tables(
        cosm_nyq, sinm_nyq, cosn_nyq, sinn_nyq, xm_nyq_j, xn_nyq_j, constants.nfp
    )
    # dtype=float -> float64 with x64 enabled, float32 otherwise. This matches
    # jnp.float64 without the x32 "dtype not available" warning.
    m_non_f = xm_non_j.astype(float)
    n_non_f = xn_non_j.astype(float)
    m_nyq_f = xm_nyq_j.astype(float)
    n_nyq_f = xn_nyq_j.astype(float)

    # Flattened indices of the theta = 0 and theta = pi rows (the grid is
    # theta-major with nzeta points per theta row).
    idx_theta0 = jnp.arange(0, constants.nzeta)
    idx_thetapi = jnp.arange(
        (constants.nu2_b - 1) * constants.nzeta, constants.nu2_b * constants.nzeta
    )
    m_b = grids.xm_b
    abs_n_b = jnp.abs(grids.xn_b // constants.nfp)
    sign_b = jnp.where(grids.xn_b < 0, -1.0, 1.0)[None, :]

    def _surf(
        _rmnc,
        _rmns,
        _zmnc,
        _zmns,
        _lmnc,
        _lmns,
        _bmnc,
        _bsubumnc,
        _bsubvmnc,
        _iota,
        _bmns,
        _bsubumns,
        _bsubvmns,
    ):
        return _surface_transform(
            _rmnc,
            _rmns,
            _zmnc,
            _zmns,
            _lmnc,
            _lmns,
            _bmnc,
            _bsubumnc,
            _bsubvmnc,
            _iota,
            constants=constants,
            grids=grids,
            tcos_non=tcos_non,
            tsin_non=tsin_non,
            tcos_nyq=tcos_nyq,
            tsin_nyq=tsin_nyq,
            m_non_f=m_non_f,
            n_non_f=n_non_f,
            m_nyq_f=m_nyq_f,
            n_nyq_f=n_nyq_f,
            idx_theta0=idx_theta0,
            idx_thetapi=idx_thetapi,
            m_b=m_b,
            abs_n_b=abs_n_b,
            sign_b=sign_b,
            bmns=_bmns,
            bsubumns=_bsubumns,
            bsubvmns=_bsubvmns,
            fourier_mode=fourier_mode,
            trig_f32=trig_f32,
        )

    vmap_fn = jax.vmap(_surf)
    # vmap needs concrete arrays for every argument, so missing asymmetric
    # spectra become zeros.
    rmns_in = rmns if rmns is not None else jnp.zeros_like(rmnc)
    zmnc_in = zmnc if zmnc is not None else jnp.zeros_like(zmns)
    lmnc_in = lmnc if lmnc is not None else jnp.zeros_like(lmns)
    bmns_in = bmns if bmns is not None else jnp.zeros_like(bmnc)
    bsubumns_in = bsubumns if bsubumns is not None else jnp.zeros_like(bsubumnc)
    bsubvmns_in = bsubvmns if bsubvmns is not None else jnp.zeros_like(bsubvmnc)
    (
        bmnc_b,
        bmns_b,
        rmnc_b,
        rmns_b,
        zmnc_b,
        zmns_b,
        numnc_b,
        numns_b,
        gmnc_b,
        gmns_b,
        Boozer_I,
        Boozer_G,
    ) = vmap_fn(
        rmnc,
        rmns_in,
        zmnc_in,
        zmns,
        lmnc_in,
        lmns,
        bmnc,
        bsubumnc,
        bsubvmnc,
        iota,
        bmns_in,
        bsubumns_in,
        bsubvmns_in,
    )

    ns_b = bmnc_b.shape[0]
    if surface_indices is None:
        jlist = jnp.arange(2, ns_b + 2)
    else:
        jlist = surface_indices + 2
    return {
        "nfp_b": jnp.asarray(constants.nfp),
        "ns_b": jnp.asarray(ns_b_full),
        "ixm_b": jnp.asarray(grids.xm_b),
        "ixn_b": jnp.asarray(grids.xn_b),
        "iota_b": iota,
        "buco_b": Boozer_I,
        "bvco_b": Boozer_G,
        "rmnc_b": rmnc_b,
        "rmns_b": rmns_b,
        "zmnc_b": zmnc_b,
        "zmns_b": zmns_b,
        "numnc_b": numnc_b,
        "numns_b": numns_b,
        "pmnc_b": -numnc_b,
        "pmns_b": -numns_b,
        "bmnc_b": bmnc_b,
        "bmns_b": bmns_b,
        "gmnc_b": gmnc_b,
        "gmns_b": gmns_b,
        # BOOZ_XFORM/netCDF-compatible spelling for the Jacobian harmonics.
        "gmn_b": gmnc_b,
        "jlist": jlist,
    }


def _mode_trig_tables(
    cosm: jnp.ndarray,
    sinm: jnp.ndarray,
    cosn: jnp.ndarray,
    sinn: jnp.ndarray,
    xm: jnp.ndarray,
    xn: jnp.ndarray,
    nfp: int,
) -> tuple:
    """Combine per-m and per-|n| trig columns into per-mode cos/sin tables.

    Assumes ``_init_trig`` returns columns cos/sin(m*theta) and
    cos/sin(|n|*nfp*zeta) indexed by m and |n| (TODO confirm against
    ``.trig``). Under that assumption the result is cos/sin(m*theta - n*zeta)
    for every (xm, xn) mode, with one column per mode.
    """
    cos_m = jnp.take(cosm, xm, axis=1)
    sin_m = jnp.take(sinm, xm, axis=1)
    abs_n = jnp.abs(xn // nfp)
    cos_n = jnp.take(cosn, abs_n, axis=1)
    sin_n = jnp.take(sinn, abs_n, axis=1)
    # Negative n flips the sign of the sin(n*zeta) contribution.
    sign = jnp.where(xn < 0, -1.0, 1.0)[None, :]
    tcos = cos_m * cos_n + sin_m * sin_n * sign
    tsin = sin_m * cos_n - cos_m * sin_n * sign
    return tcos, tsin
[docs]
def booz_xform_from_inputs(
    *,
    inputs,
    constants: BoozXformConstants,
    grids: BoozXformGrids,
    surface_indices: Optional[jnp.ndarray] = None,
    jit: bool = True,
) -> dict:
    """Run :func:`booz_xform_jax_impl` on the arrays carried by an input bundle.

    ``inputs`` must expose ``rmnc``, ``zmns``, ``lmns``, ``bmnc``,
    ``bsubumnc``, ``bsubvmnc``, ``iota``, ``xm``, ``xn``, ``xm_nyq`` and
    ``xn_nyq``. The asymmetric spectra (``rmns``, ``zmnc``, ``lmnc``, ``bmns``,
    ``bsubumns``, ``bsubvmns``) are optional and default to ``None``.

    With ``jit=True`` the transform is wrapped in ``jax.jit`` with
    ``constants`` marked static; otherwise it runs eagerly.
    """
    asymmetric = {
        name: getattr(inputs, name, None)
        for name in ("rmns", "zmnc", "lmnc", "bmns", "bsubumns", "bsubvmns")
    }
    transform = (
        jax.jit(booz_xform_jax_impl, static_argnames=("constants",))
        if jit
        else booz_xform_jax_impl
    )
    return transform(
        rmnc=inputs.rmnc,
        zmns=inputs.zmns,
        lmns=inputs.lmns,
        bmnc=inputs.bmnc,
        bsubumnc=inputs.bsubumnc,
        bsubvmnc=inputs.bsubvmnc,
        iota=inputs.iota,
        xm=inputs.xm,
        xn=inputs.xn,
        xm_nyq=inputs.xm_nyq,
        xn_nyq=inputs.xn_nyq,
        constants=constants,
        grids=grids,
        surface_indices=surface_indices,
        **asymmetric,
    )
[docs]
def booz_xform_jax(
    *,
    rmnc: jnp.ndarray,
    zmns: jnp.ndarray,
    lmns: jnp.ndarray,
    bmnc: jnp.ndarray,
    bsubumnc: jnp.ndarray,
    bsubvmnc: jnp.ndarray,
    iota: jnp.ndarray,
    xm: Sequence[int],
    xn: Sequence[int],
    xm_nyq: Sequence[int],
    xn_nyq: Sequence[int],
    nfp: int,
    mboz: int,
    nboz: int,
    asym: bool = False,
    rmns: Optional[jnp.ndarray] = None,
    zmnc: Optional[jnp.ndarray] = None,
    lmnc: Optional[jnp.ndarray] = None,
    bmns: Optional[jnp.ndarray] = None,
    bsubumns: Optional[jnp.ndarray] = None,
    bsubvmns: Optional[jnp.ndarray] = None,
    surface_indices: Optional[Sequence[int]] = None,
) -> dict:
    """One-call, host-side convenience wrapper around :func:`booz_xform_jax_impl`.

    The static constants and grids are derived eagerly on the host with
    :func:`prepare_booz_xform_constants` (using ``jax.numpy``). Every array
    argument is converted with ``jnp.asarray`` and mode numbers and
    ``surface_indices`` are cast to int32. The transform itself is not wrapped
    in ``jax.jit`` here. For a fully jitted pipeline, precompute the constants
    once and call :func:`booz_xform_jax_impl` (or ``booz_xform_from_inputs``)
    directly.
    """
    constants, grids = prepare_booz_xform_constants(
        nfp=nfp,
        mboz=mboz,
        nboz=nboz,
        asym=asym,
        xm=xm,
        xn=xn,
        xm_nyq=xm_nyq,
        xn_nyq=xn_nyq,
    )

    def as_int32(values):
        return jnp.asarray(values, dtype=jnp.int32)

    def as_optional(values):
        return None if values is None else jnp.asarray(values)

    selected = None if surface_indices is None else as_int32(surface_indices)
    return booz_xform_jax_impl(
        rmnc=jnp.asarray(rmnc),
        zmns=jnp.asarray(zmns),
        lmns=jnp.asarray(lmns),
        bmnc=jnp.asarray(bmnc),
        bsubumnc=jnp.asarray(bsubumnc),
        bsubvmnc=jnp.asarray(bsubvmnc),
        iota=jnp.asarray(iota),
        xm=as_int32(xm),
        xn=as_int32(xn),
        xm_nyq=as_int32(xm_nyq),
        xn_nyq=as_int32(xn_nyq),
        constants=constants,
        grids=grids,
        rmns=as_optional(rmns),
        zmnc=as_optional(zmnc),
        lmnc=as_optional(lmnc),
        bmns=as_optional(bmns),
        bsubumns=as_optional(bsubumns),
        bsubvmns=as_optional(bsubvmns),
        surface_indices=selected,
    )