Source code for booz_xform_jax.jax_api

"""Pure JAX API for end-to-end Boozer transforms.

This module provides a JIT-friendly, functional interface that avoids
Python loops over surfaces and keeps all arrays in JAX. It is intended
for end-to-end differentiation with vmec_jax -> booz_xform_jax -> neo_jax.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence, Tuple
import os

import jax
import jax.numpy as jnp

from .trig import _init_trig


@dataclass(frozen=True)
class BoozXformConstants:
    """Static constants for the JAX Boozer transform.

    The dataclass is frozen, so instances are immutable and hashable; this is
    what allows an instance to be passed as a static argument to ``jax.jit``.
    All fields are plain Python scalars.
    """

    # Number of field periods; toroidal mode numbers are stored as n * nfp.
    nfp: int
    # Boozer poloidal resolution: Boozer modes use m in range(mboz).
    mboz: int
    # Boozer toroidal resolution: Boozer modes use n in [-nboz, nboz] (n >= 0 for m == 0).
    nboz: int
    # Whether the non-stellarator-symmetric (sin/cos partner) terms are included.
    asym: bool
    # Full poloidal grid size, 2 * (2 * mboz + 1).
    ntheta: int
    # Toroidal grid size, 2 * (2 * nboz + 1), or 1 when nboz == 0.
    nzeta: int
    # Number of poloidal points on the half grid, ntheta // 2 + 1.
    nu2_b: int
    # Largest |m| among the non-Nyquist input modes.
    mmax_non: int
    # Largest |n // nfp| among the non-Nyquist input modes.
    nmax_non: int
    # Largest |m| among the Nyquist input modes.
    mmax_nyq: int
    # Largest |n // nfp| among the Nyquist input modes.
    nmax_nyq: int
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class BoozXformGrids:
    """Grid arrays for the JAX Boozer transform.

    Registered as a JAX pytree whose four leaves are the array fields below
    (in declaration order) and which carries no auxiliary data, so an
    instance can be passed through ``jax.jit`` / ``jax.vmap`` as a regular
    (non-static) argument.
    """

    # Flattened poloidal angle at every (theta, zeta) grid point.
    theta_grid: jnp.ndarray
    # Flattened toroidal angle at every (theta, zeta) grid point.
    zeta_grid: jnp.ndarray
    # Poloidal mode number m of each Boozer harmonic.
    xm_b: jnp.ndarray
    # Toroidal mode number n (already multiplied by nfp) of each Boozer harmonic.
    xn_b: jnp.ndarray

    def tree_flatten(self):
        """Return ``(children, aux)`` with the four arrays as children and no aux data."""
        children = (self.theta_grid, self.zeta_grid, self.xm_b, self.xn_b)
        return children, None

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild an instance from the four children; ``aux`` is unused."""
        theta_grid, zeta_grid, xm_b, xn_b = children
        return cls(theta_grid=theta_grid, zeta_grid=zeta_grid, xm_b=xm_b, xn_b=xn_b)
def _prepare_mode_lists(mboz: int, nboz: int, nfp: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Enumerate the Boozer (m, n) harmonics in the C++/Fortran ordering.

    For m == 0 only n = 0..nboz is kept; for every other m in range(mboz) the
    full range n = -nboz..nboz is used. Returned toroidal numbers are
    multiplied by ``nfp``. Both outputs are int32 arrays of equal length.
    """
    mode_pairs = [
        (m, n * nfp)
        for m in range(mboz)
        for n in range(0 if m == 0 else -nboz, nboz + 1)
    ]
    poloidal = [m for m, _ in mode_pairs]
    toroidal = [n for _, n in mode_pairs]
    return (
        jnp.asarray(poloidal, dtype=jnp.int32),
        jnp.asarray(toroidal, dtype=jnp.int32),
    )


def _prepare_grids(mboz: int, nboz: int, nfp: int, asym: bool) -> Tuple[int, int, int, jnp.ndarray, jnp.ndarray]:
    """Build the flattened (theta, zeta) grids following BOOZ_XFORM conventions.

    Returns ``(ntheta_full, nzeta_full, nu2_b, theta_grid, zeta_grid)``. The
    grids are flattened with theta as the slow index and zeta as the fast
    index. Only the first ``nu2_b`` theta rows are generated unless ``asym``
    is true, in which case all ``ntheta_full`` rows are generated.
    """
    ntheta_full = 2 * (2 * mboz + 1)
    nzeta_full = 1 if nboz <= 0 else 2 * (2 * nboz + 1)
    nu2_b = ntheta_full // 2 + 1
    n_theta_rows = ntheta_full if asym else nu2_b

    d_theta = (2.0 * jnp.pi) / ntheta_full
    d_zeta = (2.0 * jnp.pi) / (nfp * nzeta_full)
    theta_1d = jnp.arange(n_theta_rows) * d_theta
    zeta_1d = jnp.arange(nzeta_full) * d_zeta

    # theta varies slowly (each value repeated nzeta_full times); zeta cycles.
    theta_flat = jnp.repeat(theta_1d, nzeta_full)
    zeta_flat = jnp.tile(zeta_1d, n_theta_rows)
    return int(ntheta_full), int(nzeta_full), int(nu2_b), theta_flat, zeta_flat
def prepare_booz_xform_constants(
    *,
    nfp: int,
    mboz: int,
    nboz: int,
    asym: bool,
    xm: Sequence[int],
    xn: Sequence[int],
    xm_nyq: Sequence[int],
    xn_nyq: Sequence[int],
) -> tuple[BoozXformConstants, BoozXformGrids]:
    """Compute static constants and grids for the JAX Boozer transform.

    This helper must run eagerly, before JIT compilation: it converts
    reductions to Python ``int`` values, which requires concrete (non-traced)
    inputs.

    Parameters
    ----------
    nfp, mboz, nboz, asym
        Field periods, Boozer poloidal/toroidal resolution and the
        asymmetry flag; stored unchanged on the returned constants.
    xm, xn
        Poloidal / toroidal mode numbers of the non-Nyquist input spectra.
    xm_nyq, xn_nyq
        Poloidal / toroidal mode numbers of the Nyquist input spectra.

    Returns
    -------
    (BoozXformConstants, BoozXformGrids)
        Hashable scalar constants and the array-valued grids / Boozer mode
        lists.

    Raises
    ------
    ValueError
        If any of the four mode lists is empty (the maximum mode number is
        undefined in that case).
    """
    xm_arr = jnp.asarray(xm, dtype=jnp.int32)
    xn_arr = jnp.asarray(xn, dtype=jnp.int32)
    xm_nyq_arr = jnp.asarray(xm_nyq, dtype=jnp.int32)
    xn_nyq_arr = jnp.asarray(xn_nyq, dtype=jnp.int32)

    # Fail with a message naming the argument instead of the generic
    # zero-size reduction error that jnp.max would raise.
    for label, arr in (
        ("xm", xm_arr),
        ("xn", xn_arr),
        ("xm_nyq", xm_nyq_arr),
        ("xn_nyq", xn_nyq_arr),
    ):
        if arr.size == 0:
            raise ValueError(f"{label} must contain at least one mode index")

    # Trig tables are indexed by |m| and by |n| in units of nfp.
    mmax_non = int(jnp.max(jnp.abs(xm_arr)))
    nmax_non = int(jnp.max(jnp.abs(xn_arr // nfp)))
    mmax_nyq = int(jnp.max(jnp.abs(xm_nyq_arr)))
    nmax_nyq = int(jnp.max(jnp.abs(xn_nyq_arr // nfp)))

    ntheta, nzeta, nu2_b, theta_grid, zeta_grid = _prepare_grids(mboz, nboz, nfp, asym)
    xm_b, xn_b = _prepare_mode_lists(mboz, nboz, nfp)

    constants = BoozXformConstants(
        nfp=nfp,
        mboz=mboz,
        nboz=nboz,
        asym=asym,
        ntheta=ntheta,
        nzeta=nzeta,
        nu2_b=nu2_b,
        mmax_non=mmax_non,
        nmax_non=nmax_non,
        mmax_nyq=mmax_nyq,
        nmax_nyq=nmax_nyq,
    )
    # _prepare_mode_lists already returns int32 arrays, so no re-conversion.
    grids = BoozXformGrids(
        theta_grid=theta_grid,
        zeta_grid=zeta_grid,
        xm_b=xm_b,
        xn_b=xn_b,
    )
    return constants, grids
def prepare_booz_xform_constants_from_inputs(
    *,
    inputs,
    mboz: int,
    nboz: int,
    asym: bool,
) -> tuple[BoozXformConstants, BoozXformGrids]:
    """Convenience wrapper using a VMEC -> Boozer input bundle.

    Reads ``nfp``, ``xm``, ``xn``, ``xm_nyq`` and ``xn_nyq`` from ``inputs``
    and forwards them, together with the requested Boozer resolution, to
    :func:`prepare_booz_xform_constants`. ``nfp``, ``mboz``, ``nboz`` and
    ``asym`` are coerced to plain Python scalars first, so this must be
    called with concrete (non-traced) values.
    """
    return prepare_booz_xform_constants(
        nfp=int(jnp.asarray(inputs.nfp)),
        mboz=int(mboz),
        nboz=int(nboz),
        asym=bool(asym),
        xm=jnp.asarray(inputs.xm),
        xn=jnp.asarray(inputs.xn),
        xm_nyq=jnp.asarray(inputs.xm_nyq),
        xn_nyq=jnp.asarray(inputs.xn_nyq),
    )
def _surface_transform(
    rmnc: jnp.ndarray,
    rmns: jnp.ndarray,
    zmnc: jnp.ndarray,
    zmns: jnp.ndarray,
    lmnc: jnp.ndarray,
    lmns: jnp.ndarray,
    bmnc: jnp.ndarray,
    bsubumnc: jnp.ndarray,
    bsubvmnc: jnp.ndarray,
    iota: jnp.ndarray,
    *,
    constants: BoozXformConstants,
    grids: BoozXformGrids,
    tcos_non: jnp.ndarray,
    tsin_non: jnp.ndarray,
    tcos_nyq: jnp.ndarray,
    tsin_nyq: jnp.ndarray,
    m_non_f: jnp.ndarray,
    n_non_f: jnp.ndarray,
    m_nyq_f: jnp.ndarray,
    n_nyq_f: jnp.ndarray,
    idx_theta0: jnp.ndarray,
    idx_thetapi: jnp.ndarray,
    m_b: jnp.ndarray,
    abs_n_b: jnp.ndarray,
    sign_b: jnp.ndarray,
    bmns: Optional[jnp.ndarray] = None,
    bsubumns: Optional[jnp.ndarray] = None,
    bsubvmns: Optional[jnp.ndarray] = None,
    fourier_mode: str = "vectorized",
    trig_f32: bool = False,
) -> Tuple[jnp.ndarray, ...]:
    """Compute Boozer spectra for a single surface.

    Parameters
    ----------
    rmnc, rmns, zmnc, zmns, lmnc, lmns
        Non-Nyquist Fourier coefficients of R, Z and lambda for this surface,
        one entry per non-Nyquist mode. The ``rmns``/``zmnc``/``lmnc``
        partners are only used when ``constants.asym`` is true.
    bmnc, bsubumnc, bsubvmnc
        Nyquist cosine coefficients of |B|, B_theta and B_zeta.
    iota
        Rotational transform on this surface.
    constants, grids
        Static scalars and the flattened (theta, zeta) grid.
    tcos_non, tsin_non, tcos_nyq, tsin_nyq
        Precomputed cos/sin tables contracted as ``"ij,j->i"`` with the
        coefficient vectors, i.e. indexed [grid point, mode].
    m_non_f, n_non_f, m_nyq_f, n_nyq_f
        Mode numbers used as multipliers/divisors for the derivatives and
        the w spectrum.
    idx_theta0, idx_thetapi
        Grid-row indices whose Boozer poloidal trig rows are halved in the
        symmetric case (named for the theta = 0 and theta = pi rows --
        confirm against the caller).
    m_b, abs_n_b, sign_b
        Per-Boozer-mode column indices into the Boozer trig tables and the
        sign applied to the sin(n) terms.
    bmns, bsubumns, bsubvmns
        Optional Nyquist sine coefficients, used only when
        ``constants.asym`` is true.
    fourier_mode
        ``"streamed"`` accumulates one Boozer mode at a time in a
        ``fori_loop``; any other value uses the fully vectorized projection.
    trig_f32
        If true, the Boozer trig tables are cast to float32.

    Returns
    -------
    tuple
        ``(bmnc_b, bmns_b, rmnc_b, rmns_b, zmnc_b, zmns_b, numnc_b, numns_b,
        gmnc_b, gmns_b, Boozer_I, Boozer_G)``.
    """
    nfp = constants.nfp
    theta_grid = grids.theta_grid
    zeta_grid = grids.zeta_grid

    # Boozer I/G from m=n=0 Nyquist mode
    # ``size=1`` keeps the lookup shape static so it works under jit/vmap.
    idx00 = jnp.where((m_nyq_f == 0) & (n_nyq_f == 0), size=1)[0][0]
    Boozer_I = bsubumnc[idx00]
    Boozer_G = bsubvmnc[idx00]

    # w spectrum from B_theta and B_zeta. Safe denominators avoid inf/NaN
    # tangents at m=n=0 when this kernel is differentiated.
    m_nonzero = m_nyq_f != 0.0
    n_nonzero_only = jnp.logical_and(~m_nonzero, n_nyq_f != 0.0)
    m_nyq_safe = jnp.where(m_nonzero, m_nyq_f, 1.0)
    n_nyq_safe = jnp.where(n_nonzero_only, n_nyq_f, 1.0)
    # m != 0: B_theta / m; m == 0 and n != 0: -B_zeta / n; m == n == 0: 0.
    wmns = jnp.where(
        m_nonzero,
        bsubumnc / m_nyq_safe,
        jnp.where(n_nonzero_only, -bsubvmnc / n_nyq_safe, 0.0),
    )
    if constants.asym and bsubumns is not None and bsubvmns is not None:
        # Cosine partner of w: same selection with opposite signs.
        wmnc = jnp.where(
            m_nonzero,
            -bsubumns / m_nyq_safe,
            jnp.where(n_nonzero_only, bsubvmns / n_nyq_safe, 0.0),
        )
    else:
        wmnc = None

    # Non-Nyquist R, Z, lambda and derivatives
    r = jnp.einsum("ij,j->i", tcos_non, rmnc)
    z = jnp.einsum("ij,j->i", tsin_non, zmns)
    lam = jnp.einsum("ij,j->i", tsin_non, lmns)
    dlam_dth = jnp.einsum("ij,j->i", tcos_non, lmns * m_non_f)
    dlam_dze = -jnp.einsum("ij,j->i", tcos_non, lmns * n_non_f)
    if constants.asym:
        r = r + jnp.einsum("ij,j->i", tsin_non, rmns)
        z = z + jnp.einsum("ij,j->i", tcos_non, zmnc)
        lam = lam + jnp.einsum("ij,j->i", tcos_non, lmnc)
        dlam_dth = dlam_dth - jnp.einsum("ij,j->i", tsin_non, lmnc * m_non_f)
        dlam_dze = dlam_dze + jnp.einsum("ij,j->i", tsin_non, lmnc * n_non_f)

    # Nyquist w, derivatives, and |B|
    w = jnp.einsum("ij,j->i", tsin_nyq, wmns)
    dw_dth = jnp.einsum("ij,j->i", tcos_nyq, wmns * m_nyq_f)
    dw_dze = -jnp.einsum("ij,j->i", tcos_nyq, wmns * n_nyq_f)
    bmod = jnp.einsum("ij,j->i", tcos_nyq, bmnc)
    if constants.asym and wmnc is not None and bmns is not None:
        w = w + jnp.einsum("ij,j->i", tcos_nyq, wmnc)
        dw_dth = dw_dth - jnp.einsum("ij,j->i", tsin_nyq, wmnc * m_nyq_f)
        dw_dze = dw_dze + jnp.einsum("ij,j->i", tsin_nyq, wmnc * n_nyq_f)
        bmod = bmod + jnp.einsum("ij,j->i", tsin_nyq, bmns)

    # Boozer angles and derivatives
    GI = Boozer_G + iota * Boozer_I
    one_over_GI = 1.0 / GI
    nu = one_over_GI * (w - Boozer_I * lam)
    theta_B = theta_grid + lam + iota * nu
    zeta_B = zeta_grid + nu
    dnu_dze = one_over_GI * (dw_dze - Boozer_I * dlam_dze)
    dnu_dth = one_over_GI * (dw_dth - Boozer_I * dlam_dth)
    # Weight multiplying every field before projection onto Boozer modes.
    dB_dvmec = (1.0 + dlam_dth) * (1.0 + dnu_dze) + (iota - dlam_dze) * dnu_dth

    # Boozer trig tables on (theta_B, zeta_B)
    # NOTE(review): _init_trig is defined in .trig; from its use below it
    # returns tables indexed [grid point, mode index] -- confirm there.
    cosm_b, sinm_b, cosn_b, sinn_b = _init_trig(
        theta_B, zeta_B, constants.mboz, constants.nboz, nfp
    )
    if trig_f32:
        cosm_b = cosm_b.astype(jnp.float32)
        sinm_b = sinm_b.astype(jnp.float32)
        cosn_b = cosn_b.astype(jnp.float32)
        sinn_b = sinn_b.astype(jnp.float32)

    if not constants.asym:
        # Symmetric case: halve the poloidal trig rows selected by
        # idx_theta0 and idx_thetapi so those grid rows get half weight.
        cosm_b = cosm_b.at[idx_theta0, :].set(cosm_b[idx_theta0, :] * 0.5)
        cosm_b = cosm_b.at[idx_thetapi, :].set(cosm_b[idx_thetapi, :] * 0.5)
        sinm_b = sinm_b.at[idx_theta0, :].set(sinm_b[idx_theta0, :] * 0.5)
        sinm_b = sinm_b.at[idx_thetapi, :].set(sinm_b[idx_thetapi, :] * 0.5)

    boozer_jac = GI / (bmod * bmod)

    # Normalization of the discrete projection; the first Boozer mode gets
    # half the factor of every other mode.
    if constants.asym:
        fourier_factor0 = 2.0 / (constants.ntheta * constants.nzeta)
    else:
        fourier_factor0 = 2.0 / ((constants.nu2_b - 1) * constants.nzeta)
    fourier_factor = jnp.ones((m_b.shape[0],), dtype=jnp.float64) * fourier_factor0
    fourier_factor = fourier_factor.at[0].set(fourier_factor0 * 0.5)

    if fourier_mode == "streamed":
        # One Boozer mode per loop iteration: the per-mode cos/sin columns
        # are built inside the loop instead of as full [grid, mode] tables.
        base_b = bmod * dB_dvmec
        base_r = r * dB_dvmec
        base_z = z * dB_dvmec
        base_nu = nu * dB_dvmec
        base_g = boozer_jac * dB_dvmec
        m_b_f = m_b
        abs_n_b_f = abs_n_b
        sign_b_f = jnp.reshape(sign_b, (-1,))

        def init_out():
            # Ten zero-filled output spectra, one slot per Boozer mode.
            zeros = jnp.zeros((m_b_f.shape[0],), dtype=base_b.dtype)
            return zeros, zeros, zeros, zeros, zeros, zeros, zeros, zeros, zeros, zeros

        def body(k, state):
            # Fill entry k of every output spectrum.
            (
                bmnc_b,
                bmns_b,
                rmnc_b,
                rmns_b,
                zmnc_b,
                zmns_b,
                numnc_b,
                numns_b,
                gmnc_b,
                gmns_b,
            ) = state
            m_idx = m_b_f[k]
            n_idx = abs_n_b_f[k]
            sign = sign_b_f[k]
            cosm = jax.lax.dynamic_index_in_dim(cosm_b, m_idx, axis=1, keepdims=False)
            sinm = jax.lax.dynamic_index_in_dim(sinm_b, m_idx, axis=1, keepdims=False)
            cosn = jax.lax.dynamic_index_in_dim(cosn_b, n_idx, axis=1, keepdims=False)
            sinn = jax.lax.dynamic_index_in_dim(sinn_b, n_idx, axis=1, keepdims=False)
            # Angle-difference identities; ``sign`` flips the sin(n) terms
            # for negative n since only |n| columns are tabulated.
            tcos = cosm * cosn + sinm * sinn * sign
            tsin = sinm * cosn - cosm * sinn * sign
            ff = fourier_factor[k]
            bmnc_b = bmnc_b.at[k].set(ff * jnp.sum(tcos * base_b))
            rmnc_b = rmnc_b.at[k].set(ff * jnp.sum(tcos * base_r))
            zmns_b = zmns_b.at[k].set(ff * jnp.sum(tsin * base_z))
            numns_b = numns_b.at[k].set(ff * jnp.sum(tsin * base_nu))
            gmnc_b = gmnc_b.at[k].set(ff * jnp.sum(tcos * base_g))
            if constants.asym:
                bmns_b = bmns_b.at[k].set(ff * jnp.sum(tsin * base_b))
                rmns_b = rmns_b.at[k].set(ff * jnp.sum(tsin * base_r))
                zmnc_b = zmnc_b.at[k].set(ff * jnp.sum(tcos * base_z))
                numnc_b = numnc_b.at[k].set(ff * jnp.sum(tcos * base_nu))
                gmns_b = gmns_b.at[k].set(ff * jnp.sum(tsin * base_g))
            return (
                bmnc_b,
                bmns_b,
                rmnc_b,
                rmns_b,
                zmnc_b,
                zmns_b,
                numnc_b,
                numns_b,
                gmnc_b,
                gmns_b,
            )

        (
            bmnc_b,
            bmns_b,
            rmnc_b,
            rmns_b,
            zmnc_b,
            zmns_b,
            numnc_b,
            numns_b,
            gmnc_b,
            gmns_b,
        ) = jax.lax.fori_loop(
            0, m_b_f.shape[0], body, init_out()
        )
    else:
        # Vectorized path: build the full [grid, mode] tables at once.
        cosm_b_m = jnp.take(cosm_b, m_b, axis=1)
        sinm_b_m = jnp.take(sinm_b, m_b, axis=1)
        cosn_b_n = jnp.take(cosn_b, abs_n_b, axis=1)
        sinn_b_n = jnp.take(sinn_b, abs_n_b, axis=1)
        tcos_modes = cosm_b_m * cosn_b_n + sinm_b_m * sinn_b_n * sign_b
        tsin_modes = sinm_b_m * cosn_b_n - cosm_b_m * sinn_b_n * sign_b

        base_b = bmod * dB_dvmec
        base_r = r * dB_dvmec
        base_z = z * dB_dvmec
        base_nu = nu * dB_dvmec
        base_g = boozer_jac * dB_dvmec

        def project_cos(field: jnp.ndarray) -> jnp.ndarray:
            # Sum over grid points against the cosine table, then normalize.
            return fourier_factor * jnp.einsum("ij,i->j", tcos_modes, field)

        def project_sin(field: jnp.ndarray) -> jnp.ndarray:
            # Sum over grid points against the sine table, then normalize.
            return fourier_factor * jnp.einsum("ij,i->j", tsin_modes, field)

        bmnc_b = project_cos(base_b)
        rmnc_b = project_cos(base_r)
        zmns_b = project_sin(base_z)
        numns_b = project_sin(base_nu)
        gmnc_b = project_cos(base_g)

        zeros = jnp.zeros_like(bmnc_b)
        if constants.asym:
            bmns_b = project_sin(base_b)
            rmns_b = project_sin(base_r)
            zmnc_b = project_cos(base_z)
            numnc_b = project_cos(base_nu)
            gmns_b = project_sin(base_g)
        else:
            # Symmetric case: the partner spectra are identically zero.
            bmns_b = zeros
            rmns_b = zeros
            zmnc_b = zeros
            numnc_b = zeros
            gmns_b = zeros

    return (
        bmnc_b,
        bmns_b,
        rmnc_b,
        rmns_b,
        zmnc_b,
        zmns_b,
        numnc_b,
        numns_b,
        gmnc_b,
        gmns_b,
        Boozer_I,
        Boozer_G,
    )
def booz_xform_jax_impl(
    rmnc: jnp.ndarray,
    zmns: jnp.ndarray,
    lmns: jnp.ndarray,
    bmnc: jnp.ndarray,
    bsubumnc: jnp.ndarray,
    bsubvmnc: jnp.ndarray,
    iota: jnp.ndarray,
    *,
    xm: jnp.ndarray,
    xn: jnp.ndarray,
    xm_nyq: jnp.ndarray,
    xn_nyq: jnp.ndarray,
    constants: BoozXformConstants,
    grids: BoozXformGrids,
    rmns: Optional[jnp.ndarray] = None,
    zmnc: Optional[jnp.ndarray] = None,
    lmnc: Optional[jnp.ndarray] = None,
    bmns: Optional[jnp.ndarray] = None,
    bsubumns: Optional[jnp.ndarray] = None,
    bsubvmns: Optional[jnp.ndarray] = None,
    surface_indices: Optional[jnp.ndarray] = None,
) -> dict:
    """JAX-native Boozer transform over all (or selected) surfaces.

    All inputs must be JAX arrays with surface dimension first, i.e. shape
    (ns, mn_non) for non-Nyquist arrays and (ns, mn_nyq) for Nyquist arrays.

    Parameters
    ----------
    rmnc, zmns, lmns
        Non-Nyquist spectra of R, Z and lambda.
    bmnc, bsubumnc, bsubvmnc
        Nyquist spectra of |B|, B_theta and B_zeta.
    iota
        Rotational transform, one value per surface.
    xm, xn, xm_nyq, xn_nyq
        Mode numbers of the non-Nyquist and Nyquist spectra.
    constants, grids
        Output of :func:`prepare_booz_xform_constants`.
    rmns, zmnc, lmnc, bmns, bsubumns, bsubvmns
        Optional partner spectra; zeros are substituted for any that are
        omitted.
    surface_indices
        Optional indices (array or integer sequence) selecting a subset of
        surfaces along the leading axis.

    Returns
    -------
    dict
        Boozer spectra and metadata keyed with BOOZ_XFORM-style names
        (``bmnc_b``, ``rmnc_b``, ``zmns_b``, ``numns_b``, ``gmnc_b``, ...).
        ``ns_b`` is the surface count *before* selection and ``jlist`` holds
        the selected surface indices offset by 2.

    Raises
    ------
    ValueError
        If ``BOOZ_XFORM_JAX_FOURIER_MODE`` is neither ``vectorized`` nor
        ``streamed``.

    Notes
    -----
    ``BOOZ_XFORM_JAX_FOURIER_MODE`` and ``BOOZ_XFORM_JAX_TRIG_F32`` are read
    from the environment each time this Python function body runs. Under
    ``jax.jit`` that is at trace time only, so changing them afterwards does
    not affect an already-compiled call.
    """
    ns_b_full = int(rmnc.shape[0])

    if surface_indices is not None:
        # Coerce so that list/tuple selections work both for ``jnp.take``
        # and for the ``surface_indices + 2`` offset used for ``jlist``.
        surface_indices = jnp.asarray(surface_indices)

        def _select(arr):
            """Pick the requested surfaces; optional inputs may be None."""
            if arr is None:
                return None
            return jnp.take(arr, surface_indices, axis=0)

        rmnc = _select(rmnc)
        zmns = _select(zmns)
        lmns = _select(lmns)
        bmnc = _select(bmnc)
        bsubumnc = _select(bsubumnc)
        bsubvmnc = _select(bsubvmnc)
        iota = _select(iota)
        rmns = _select(rmns)
        zmnc = _select(zmnc)
        lmnc = _select(lmnc)
        bmns = _select(bmns)
        bsubumns = _select(bsubumns)
        bsubvmns = _select(bsubvmns)

    xm_non_j = jnp.asarray(xm, dtype=jnp.int32)
    xn_non_j = jnp.asarray(xn, dtype=jnp.int32)
    xm_nyq_j = jnp.asarray(xm_nyq, dtype=jnp.int32)
    xn_nyq_j = jnp.asarray(xn_nyq, dtype=jnp.int32)

    fourier_mode = os.getenv("BOOZ_XFORM_JAX_FOURIER_MODE", "vectorized").strip().lower()
    if fourier_mode not in {"vectorized", "streamed"}:
        raise ValueError(f"Unsupported BOOZ_XFORM_JAX_FOURIER_MODE '{fourier_mode}'")
    trig_f32 = os.getenv("BOOZ_XFORM_JAX_TRIG_F32", "0").strip().lower() in {"1", "true", "yes", "on"}

    def _mode_tables(mode_m, mode_n, mmax, nmax):
        """Return cos/sin(m*theta - n*zeta) tables, indexed [grid point, mode]."""
        cosm, sinm, cosn, sinn = _init_trig(
            grids.theta_grid, grids.zeta_grid, mmax, nmax, constants.nfp
        )
        if trig_f32:
            cosm = cosm.astype(jnp.float32)
            sinm = sinm.astype(jnp.float32)
            cosn = cosn.astype(jnp.float32)
            sinn = sinn.astype(jnp.float32)
        cosm_m = jnp.take(cosm, mode_m, axis=1)
        sinm_m = jnp.take(sinm, mode_m, axis=1)
        # Toroidal tables are indexed by |n| / nfp; the sign of n is
        # reapplied on the sin(n) terms.
        abs_n = jnp.abs(mode_n // constants.nfp)
        cosn_n = jnp.take(cosn, abs_n, axis=1)
        sinn_n = jnp.take(sinn, abs_n, axis=1)
        sign = jnp.where(mode_n < 0, -1.0, 1.0)[None, :]
        tcos = cosm_m * cosn_n + sinm_m * sinn_n * sign
        tsin = sinm_m * cosn_n - cosm_m * sinn_n * sign
        return tcos, tsin

    # Precompute trig tables and mode combinations once for all surfaces.
    tcos_non, tsin_non = _mode_tables(
        xm_non_j, xn_non_j, constants.mmax_non, constants.nmax_non
    )
    m_non_f = xm_non_j.astype(jnp.float64)
    n_non_f = xn_non_j.astype(jnp.float64)

    tcos_nyq, tsin_nyq = _mode_tables(
        xm_nyq_j, xn_nyq_j, constants.mmax_nyq, constants.nmax_nyq
    )
    m_nyq_f = xm_nyq_j.astype(jnp.float64)
    n_nyq_f = xn_nyq_j.astype(jnp.float64)

    # Flattened-grid rows of the first theta value and of theta row nu2_b - 1.
    idx_theta0 = jnp.arange(0, constants.nzeta)
    idx_thetapi = jnp.arange(
        (constants.nu2_b - 1) * constants.nzeta, constants.nu2_b * constants.nzeta
    )

    m_b = grids.xm_b
    abs_n_b = jnp.abs(grids.xn_b // constants.nfp)
    sign_b = jnp.where(grids.xn_b < 0, -1.0, 1.0)[None, :]

    def _surf(
        _rmnc,
        _rmns,
        _zmnc,
        _zmns,
        _lmnc,
        _lmns,
        _bmnc,
        _bsubumnc,
        _bsubvmnc,
        _iota,
        _bmns,
        _bsubumns,
        _bsubvmns,
    ):
        # Per-surface kernel; everything shared across surfaces is closed over.
        return _surface_transform(
            _rmnc,
            _rmns,
            _zmnc,
            _zmns,
            _lmnc,
            _lmns,
            _bmnc,
            _bsubumnc,
            _bsubvmnc,
            _iota,
            constants=constants,
            grids=grids,
            tcos_non=tcos_non,
            tsin_non=tsin_non,
            tcos_nyq=tcos_nyq,
            tsin_nyq=tsin_nyq,
            m_non_f=m_non_f,
            n_non_f=n_non_f,
            m_nyq_f=m_nyq_f,
            n_nyq_f=n_nyq_f,
            idx_theta0=idx_theta0,
            idx_thetapi=idx_thetapi,
            m_b=m_b,
            abs_n_b=abs_n_b,
            sign_b=sign_b,
            bmns=_bmns,
            bsubumns=_bsubumns,
            bsubvmns=_bsubvmns,
            fourier_mode=fourier_mode,
            trig_f32=trig_f32,
        )

    vmap_fn = jax.vmap(_surf)

    # vmap needs an array in every slot, so missing partners become zeros.
    rmns_in = rmns if rmns is not None else jnp.zeros_like(rmnc)
    zmnc_in = zmnc if zmnc is not None else jnp.zeros_like(zmns)
    lmnc_in = lmnc if lmnc is not None else jnp.zeros_like(lmns)
    bmns_in = bmns if bmns is not None else jnp.zeros_like(bmnc)
    bsubumns_in = bsubumns if bsubumns is not None else jnp.zeros_like(bsubumnc)
    bsubvmns_in = bsubvmns if bsubvmns is not None else jnp.zeros_like(bsubvmnc)

    (
        bmnc_b,
        bmns_b,
        rmnc_b,
        rmns_b,
        zmnc_b,
        zmns_b,
        numnc_b,
        numns_b,
        gmnc_b,
        gmns_b,
        Boozer_I,
        Boozer_G,
    ) = vmap_fn(
        rmnc,
        rmns_in,
        zmnc_in,
        zmns,
        lmnc_in,
        lmns,
        bmnc,
        bsubumnc,
        bsubvmnc,
        iota,
        bmns_in,
        bsubumns_in,
        bsubvmns_in,
    )

    ns_b = bmnc_b.shape[0]
    if surface_indices is None:
        jlist = jnp.arange(2, ns_b + 2)
    else:
        jlist = surface_indices + 2

    return {
        "nfp_b": jnp.asarray(constants.nfp),
        "ns_b": jnp.asarray(ns_b_full),
        "ixm_b": jnp.asarray(grids.xm_b),
        "ixn_b": jnp.asarray(grids.xn_b),
        "iota_b": iota,
        "buco_b": Boozer_I,
        "bvco_b": Boozer_G,
        "rmnc_b": rmnc_b,
        "rmns_b": rmns_b,
        "zmnc_b": zmnc_b,
        "zmns_b": zmns_b,
        "numnc_b": numnc_b,
        "numns_b": numns_b,
        "pmnc_b": -numnc_b,
        "pmns_b": -numns_b,
        "bmnc_b": bmnc_b,
        "bmns_b": bmns_b,
        "gmnc_b": gmnc_b,
        "gmns_b": gmns_b,
        # BOOZ_XFORM/netCDF-compatible spelling for the Jacobian harmonics.
        "gmn_b": gmnc_b,
        "jlist": jlist,
    }
def booz_xform_from_inputs(
    *,
    inputs,
    constants: BoozXformConstants,
    grids: BoozXformGrids,
    surface_indices: Optional[jnp.ndarray] = None,
    jit: bool = True,
) -> dict:
    """Run the JAX Boozer transform using a VMEC -> Boozer input bundle.

    Parameters
    ----------
    inputs
        Object exposing ``rmnc``, ``zmns``, ``lmns``, ``bmnc``, ``bsubumnc``,
        ``bsubvmnc``, ``iota``, ``xm``, ``xn``, ``xm_nyq`` and ``xn_nyq``.
        The partner spectra ``rmns``, ``zmnc``, ``lmnc``, ``bmns``,
        ``bsubumns`` and ``bsubvmns`` are optional attributes.
    constants, grids
        Output of :func:`prepare_booz_xform_constants`.
    surface_indices
        Optional indices (array or integer sequence) of the surfaces to
        transform.
    jit
        If true (default), run :func:`booz_xform_jax_impl` through
        ``jax.jit`` with ``constants`` as a static argument.

    Returns
    -------
    dict
        The output dictionary of :func:`booz_xform_jax_impl`.
    """
    if surface_indices is not None:
        # Convert up front, as booz_xform_jax does, so a Python sequence is
        # passed as one index array rather than a pytree of traced scalars.
        surface_indices = jnp.asarray(surface_indices)

    booz_fn = booz_xform_jax_impl
    if jit:
        booz_fn = jax.jit(booz_xform_jax_impl, static_argnames=("constants",))
    return booz_fn(
        rmnc=inputs.rmnc,
        zmns=inputs.zmns,
        lmns=inputs.lmns,
        bmnc=inputs.bmnc,
        bsubumnc=inputs.bsubumnc,
        bsubvmnc=inputs.bsubvmnc,
        iota=inputs.iota,
        xm=inputs.xm,
        xn=inputs.xn,
        xm_nyq=inputs.xm_nyq,
        xn_nyq=inputs.xn_nyq,
        constants=constants,
        grids=grids,
        rmns=getattr(inputs, "rmns", None),
        zmnc=getattr(inputs, "zmnc", None),
        lmnc=getattr(inputs, "lmnc", None),
        bmns=getattr(inputs, "bmns", None),
        bsubumns=getattr(inputs, "bsubumns", None),
        bsubvmns=getattr(inputs, "bsubvmns", None),
        surface_indices=surface_indices,
    )
def booz_xform_jax(
    *,
    rmnc: jnp.ndarray,
    zmns: jnp.ndarray,
    lmns: jnp.ndarray,
    bmnc: jnp.ndarray,
    bsubumnc: jnp.ndarray,
    bsubvmnc: jnp.ndarray,
    iota: jnp.ndarray,
    xm: Sequence[int],
    xn: Sequence[int],
    xm_nyq: Sequence[int],
    xn_nyq: Sequence[int],
    nfp: int,
    mboz: int,
    nboz: int,
    asym: bool = False,
    rmns: Optional[jnp.ndarray] = None,
    zmnc: Optional[jnp.ndarray] = None,
    lmnc: Optional[jnp.ndarray] = None,
    bmns: Optional[jnp.ndarray] = None,
    bsubumns: Optional[jnp.ndarray] = None,
    bsubvmns: Optional[jnp.ndarray] = None,
    surface_indices: Optional[Sequence[int]] = None,
) -> dict:
    """Host-side convenience wrapper for :func:`booz_xform_jax_impl`.

    The static constants and grids are built eagerly with
    :func:`prepare_booz_xform_constants` (which needs concrete mode lists),
    every array argument is converted with ``jnp.asarray``, and the
    transform is then run un-jitted. For full JIT, call
    :func:`booz_xform_jax_impl` directly with precomputed constants.

    Parameters
    ----------
    rmnc, zmns, lmns, bmnc, bsubumnc, bsubvmnc, iota
        Required input spectra and rotational transform, surface axis first.
    xm, xn, xm_nyq, xn_nyq
        Mode numbers of the non-Nyquist and Nyquist spectra.
    nfp, mboz, nboz, asym
        Field periods, Boozer resolution and asymmetry flag.
    rmns, zmnc, lmnc, bmns, bsubumns, bsubvmns
        Optional partner spectra; omitted ones are passed on as ``None``.
    surface_indices
        Optional integer sequence selecting surfaces; converted to an int32
        array.

    Returns
    -------
    dict
        The output dictionary of :func:`booz_xform_jax_impl`.
    """
    constants, grids = prepare_booz_xform_constants(
        nfp=nfp,
        mboz=mboz,
        nboz=nboz,
        asym=asym,
        xm=xm,
        xn=xn,
        xm_nyq=xm_nyq,
        xn_nyq=xn_nyq,
    )

    surf_idx = None
    if surface_indices is not None:
        surf_idx = jnp.asarray(surface_indices, dtype=jnp.int32)

    return booz_xform_jax_impl(
        rmnc=jnp.asarray(rmnc),
        zmns=jnp.asarray(zmns),
        lmns=jnp.asarray(lmns),
        bmnc=jnp.asarray(bmnc),
        bsubumnc=jnp.asarray(bsubumnc),
        bsubvmnc=jnp.asarray(bsubvmnc),
        iota=jnp.asarray(iota),
        xm=jnp.asarray(xm, dtype=jnp.int32),
        xn=jnp.asarray(xn, dtype=jnp.int32),
        xm_nyq=jnp.asarray(xm_nyq, dtype=jnp.int32),
        xn_nyq=jnp.asarray(xn_nyq, dtype=jnp.int32),
        constants=constants,
        grids=grids,
        rmns=jnp.asarray(rmns) if rmns is not None else None,
        zmnc=jnp.asarray(zmnc) if zmnc is not None else None,
        lmnc=jnp.asarray(lmnc) if lmnc is not None else None,
        bmns=jnp.asarray(bmns) if bmns is not None else None,
        bsubumns=jnp.asarray(bsubumns) if bsubumns is not None else None,
        bsubvmns=jnp.asarray(bsubvmns) if bsubvmns is not None else None,
        surface_indices=surf_idx,
    )