JAX Geometry And Jacobian Plan
Purpose
This page records the implementation plan for making booz_xform_jax a
stable differentiable geometry provider for downstream codes such as FAX, while
remaining useful for any code that needs Boozer harmonics, Boozer-coordinate
geometry, or Jacobian sensitivities.
The immediate downstream contract is a pure-JAX transform that returns
surface-major arrays with shape (ns_selected, mnboz) for Boozer spectra and
keeps BOOZ_XFORM-compatible variable names. The key Jacobian outputs are:
- gmnc_b: cosine harmonics of the Boozer Jacobian-related quantity.
- gmns_b: sine harmonics of the same quantity when lasym is true.
- gmn_b: compatibility alias for gmnc_b matching the BOOZ_XFORM/netCDF variable name.
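Consumers that need the Jacobian-related quantity itself, rather than its spectrum, can resum one surface of these harmonics on an angular grid. This is a minimal sketch. It assumes the BOOZ_XFORM phase convention \(\cos(m\theta_B - n\zeta_B)\), with ixn_b already including the field-period factor. eval_boozer_series is a hypothetical helper, not part of the package API:

```python
import jax.numpy as jnp

def eval_boozer_series(cos_coeffs, xm, xn, theta_b, zeta_b, sin_coeffs=None):
    """Hypothetical helper: resum one surface of Boozer harmonics on a grid.

    cos_coeffs, sin_coeffs: shape (mnboz,), e.g. gmnc_b[i] and gmns_b[i].
    xm, xn: mode numbers (assumed: xn already includes the field-period factor).
    theta_b, zeta_b: 2-D Boozer angle grids of matching shape.
    """
    # Phase m*theta_B - n*zeta_B for every mode, shape (mnboz, nt, nz).
    phase = (xm[:, None, None] * theta_b[None]
             - xn[:, None, None] * zeta_b[None])
    values = jnp.tensordot(cos_coeffs, jnp.cos(phase), axes=1)
    if sin_coeffs is not None:  # sine terms only exist when lasym is true
        values = values + jnp.tensordot(sin_coeffs, jnp.sin(phase), axes=1)
    return values

# Usage with made-up coefficients for two modes: (m, n) = (0, 0) and (1, 5).
xm = jnp.array([0.0, 1.0])
xn = jnp.array([0.0, 5.0])
theta_b = jnp.array([[0.0], [jnp.pi / 2]])
zeta_b = jnp.zeros((2, 1))
sym = eval_boozer_series(jnp.array([2.0, 0.5]), xm, xn, theta_b, zeta_b)
asym = eval_boozer_series(jnp.array([2.0, 0.5]), xm, xn, theta_b, zeta_b,
                          sin_coeffs=jnp.array([0.0, 0.3]))
```

With real outputs, the coefficient arguments would be one row of the surface-major arrays, for example gmnc_b[i] and, for asymmetric runs, gmns_b[i].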
Physics Contract
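The Boozer-coordinate covariant representation uses flux functions \(I(s)\) and \(G(s)\):
\[ \mathbf{B} = \beta\,\nabla\psi + I(s)\,\nabla\theta_B + G(s)\,\nabla\zeta_B . \]
Following the BOOZ_XFORM derivation used by the HiddenSymmetries documentation, the transform forms the Jacobian-related quantity from these flux functions, the rotational transform \(\iota(s)\), and \(|B|\):
\[ \sqrt{g}_B = \frac{G(s) + \iota(s)\, I(s)}{B^2} . \]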
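The JAX kernel reconstructs \(|B|\), the VMEC-to-Boozer angle shift, and the coordinate-transformation factor on an angular grid. It then projects the weighted fields onto Boozer Fourier modes. For the Jacobian harmonics, the implemented projection is the same one already used by the reference Python path:
\[ g^{c}_{mn} = C_{mn} \sum_{j,k} \frac{G + \iota I}{B^2}\, \cos\!\left(m\,\theta_B - n\,\zeta_B\right) J_{\mathrm{VMEC}\rightarrow B}, \]
Here every grid quantity is evaluated at the angular grid point indexed by \(j, k\). \(C_{mn}\) is the BOOZ_XFORM quadrature normalization, and \(J_{\mathrm{VMEC}\rightarrow B}\) is the angular coordinate factor used for all Boozer-space projections in this package. When lasym is true, gmns_b follows from the same sum with \(\sin(m\theta_B - n\zeta_B)\) in place of the cosine.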
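The projection can be sketched in a few lines of JAX. This is an illustration, not the package kernel, and it rests on several assumptions. All function names are hypothetical. \(C_{mn}\) is assumed to be \(1/N\) for the (0, 0) mode and \(2/N\) otherwise on a uniform grid of \(N\) points. The projected quantity is assumed to be \((G + \iota I)/B^2\). The example uses the identity angle map (\(\theta_B = \theta\), \(\zeta_B = \zeta\), \(J = 1\)), so known amplitudes can be recovered.

The same quadrature is evaluated in two ways, mirroring the vectorized and streamed Fourier modes. The first materializes the full basis. The second maps over modes with jax.lax.map.

```python
import jax
import jax.numpy as jnp

def boozer_jacobian(G, I, iota, mod_b):
    # Jacobian-related quantity, assumed here to be (G + iota*I) / B^2.
    return (G + iota * I) / mod_b**2

def quadrature_norm(xm, xn, npts):
    # Assumed C_mn: 1/N for the (0, 0) mode, 2/N otherwise (uniform grid, N points).
    return jnp.where((xm == 0) & (xn == 0), 1.0, 2.0) / npts

def project_cos_vectorized(f, theta_b, zeta_b, jac, xm, xn):
    # Materializes the full (mnboz, ngrid) cosine basis: one matmul, more memory.
    phase = jnp.outer(xm, theta_b.ravel()) - jnp.outer(xn, zeta_b.ravel())
    return quadrature_norm(xm, xn, f.size) * (jnp.cos(phase) @ (f * jac).ravel())

def project_cos_streamed(f, theta_b, zeta_b, jac, xm, xn):
    # Maps over modes with lax.map: one basis row at a time, less memory.
    weighted, th, ze = (f * jac).ravel(), theta_b.ravel(), zeta_b.ravel()

    def one_mode(mn):
        m, n = mn
        return jnp.sum(jnp.cos(m * th - n * ze) * weighted)

    return quadrature_norm(xm, xn, f.size) * jax.lax.map(one_mode, (xm, xn))

# Identity angle map (theta_B = theta, zeta_B = zeta, J = 1) on a uniform grid, nfp = 5.
nfp = 5
theta = jnp.linspace(0.0, 2 * jnp.pi, 16, endpoint=False)
zeta = jnp.linspace(0.0, 2 * jnp.pi / nfp, 16, endpoint=False)
th, ze = jnp.meshgrid(theta, zeta, indexing="ij")
jac = jnp.ones_like(th)
xm = jnp.array([0.0, 0.0, 1.0, 1.0, 1.0])
xn = jnp.array([0.0, 5.0, -5.0, 0.0, 5.0])

# A field with known harmonics: amplitudes 1.5, 0.1, 0, 0, 0.25 for the modes above.
f = 1.5 + 0.1 * jnp.cos(5 * ze) + 0.25 * jnp.cos(th - 5 * ze)
coeffs = project_cos_vectorized(f, th, ze, jac, xm, xn)
coeffs_streamed = project_cos_streamed(f, th, ze, jac, xm, xn)

# Constant |B| gives a constant Jacobian, so only the (0, 0) mode survives.
g = boozer_jacobian(G=1.0, I=0.1, iota=0.4, mod_b=2.0 * jnp.ones_like(th))
gmnc = project_cos_vectorized(g, th, ze, jac, xm, xn)
```

jax.lax.map lowers to a scan. As a result, the streamed form keeps only one basis row alive at a time, at the cost of a sequential loop. The vectorized form does all modes in one matrix product.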
Milestones
- Public Jacobian harmonics: Expose the already-computed gmnc_b array from the JAX API and provide gmn_b as a BOOZ_XFORM-compatible alias. This is the first hard contract needed by FAX continuum and mode-structure operators.
- Stable differentiable API: Keep booz_xform_jax_impl as the low-level primitive for composed JAX programs. Geometry constants and mode lists remain static inputs. VMEC spectra, current profiles, and rotational transform arrays remain differentiable numerical inputs.
- Jacobian access patterns: Support direct scalar objectives with jax.grad and jax.value_and_grad. For large geometry-to-physics maps, prefer matrix-free products through jax.jvp, jax.vjp, and jax.linearize rather than materializing dense Jacobians. Use jax.jacfwd or jax.jacrev only when the input and output sizes make a dense Jacobian affordable:
  - jax.jacfwd costs one JVP per input and suits tall Jacobians (few inputs).
  - jax.jacrev costs one VJP per output and suits wide Jacobians (few outputs).
- Performance modes: Preserve the default vectorized Fourier projection for speed and the BOOZ_XFORM_JAX_FOURIER_MODE=streamed path for lower memory. Both paths must produce the same Jacobian harmonics within regression tolerances.
- Validation gates: Every geometry output used by downstream codes should have:
  - parity against Booz_xform.run on bundled VMEC cases,
  - vectorized-versus-streamed parity,
  - JIT-versus-non-JIT parity where practical,
  - finite-gradient tests through representative scalar objectives,
  - NetCDF name compatibility checks for gmn_b.
- Full asymmetric-output expansion: The JAX path accepts the asymmetric VMEC inputs rmns, zmnc, lmnc, bmns, bsubumns, and bsubvmns. When lasym is true, it exposes bmns_b, rmns_b, zmnc_b, numnc_b, pmnc_b, and gmns_b in addition to the symmetric spectra. The validation suite compares these arrays against the bundled asymmetric BOOZ_XFORM reference file and checks the vectorized, streamed, JIT, and autodiff paths.
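The Jacobian access patterns can be illustrated on a small stand-in map. toy_geometry_map is hypothetical: three inputs stand in for VMEC spectra, and four outputs stand in for Boozer harmonics. Only the jax.* transformations are real API. The dense jax.jacfwd and jax.jacrev results are formed purely to check the matrix-free products, which is the diagnostic role this page assigns to dense Jacobians.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a geometry-to-physics map: 3 inputs -> 4 outputs.
A = jnp.arange(12.0).reshape(4, 3) / 10.0

def toy_geometry_map(x):
    return jnp.sin(A @ x) * (1.0 + x[0] ** 2)

def objective(x):
    # Scalar loss: reverse mode returns the full gradient in one backward pass.
    return jnp.sum(toy_geometry_map(x) ** 2)

x0 = jnp.array([0.3, -0.2, 0.5])
v = jnp.array([1.0, 0.0, -1.0])        # input-space tangent direction
u = jnp.array([1.0, 2.0, 0.0, -1.0])   # output-space cotangent

loss, grad = jax.value_and_grad(objective)(x0)

# Matrix-free forward product J @ v.
y, jv = jax.jvp(toy_geometry_map, (x0,), (v,))

# Matrix-free reverse product u @ J.
_, pullback = jax.vjp(toy_geometry_map, x0)
(uj,) = pullback(u)

# linearize reuses the primal pass for repeated JVPs at the same point.
_, f_jvp = jax.linearize(toy_geometry_map, x0)
jv_again = f_jvp(v)

# Dense Jacobians only as a small diagnostic (4 x 3 here).
J_fwd = jax.jacfwd(toy_geometry_map)(x0)
J_rev = jax.jacrev(toy_geometry_map)(x0)
```

For a scalar loss, jax.grad costs a small constant multiple of one forward evaluation, regardless of the number of inputs. That is why scalar objectives are the preferred interface for optimization.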
Downstream FAX Contract
FAX and other spectral MHD tools should consume the JAX output dictionary without relying on object attributes. The minimum stable keys are:
- ixm_b and ixn_b for Boozer mode numbers,
- iota_b, buco_b, and bvco_b for flux functions,
- bmnc_b for \(|B|\),
- rmnc_b and zmns_b for stellarator-symmetric geometry,
- bmns_b, rmns_b, and zmnc_b when asymmetric spectra are present,
- pmns_b, plus pmnc_b when asymmetric spectra are present, for the legacy stored toroidal-angle shift,
- gmnc_b or gmn_b, plus gmns_b when lasym is true, for Jacobian harmonics,
- jlist for selected 1-based full-grid surface indices, using the BOOZ_XFORM full-grid convention compute_surfs + 2.
The shape convention is intentionally surface-major in JAX outputs. Writers and legacy BOOZ_XFORM files may transpose this layout for file compatibility.
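A lightweight check of this contract can run before any downstream solve. The sketch below is hypothetical: check_booz_outputs and to_file_layout are not package functions. The key sets transcribe the minimum stable keys of the downstream contract, with the asymmetric keys required only when lasym is true.

```python
import jax.numpy as jnp

# Minimum stable keys of the downstream contract (Jacobian key handled separately).
REQUIRED_KEYS = ("ixm_b", "ixn_b", "iota_b", "buco_b", "bvco_b",
                 "bmnc_b", "rmnc_b", "zmns_b", "pmns_b", "jlist")
# Additional keys expected only when lasym is true.
ASYM_KEYS = ("bmns_b", "rmns_b", "zmnc_b", "pmnc_b", "gmns_b")

def check_booz_outputs(out, compute_surfs, lasym=False):
    """Hypothetical validator for a booz_xform_jax output dictionary."""
    wanted = REQUIRED_KEYS + (ASYM_KEYS if lasym else ())
    missing = [key for key in wanted if key not in out]
    if "gmnc_b" not in out and "gmn_b" not in out:
        missing.append("gmnc_b or gmn_b")
    if missing:
        raise KeyError(f"missing keys: {missing}")
    jacobian = out["gmnc_b"] if "gmnc_b" in out else out["gmn_b"]
    # Surface-major layout: (ns_selected, mnboz).
    expected = (len(compute_surfs), out["ixm_b"].shape[0])
    for name, arr in (("bmnc_b", out["bmnc_b"]), ("rmnc_b", out["rmnc_b"]),
                      ("zmns_b", out["zmns_b"]), ("Jacobian", jacobian)):
        if arr.shape != expected:
            raise ValueError(f"{name}: shape {arr.shape}, expected {expected}")
    # When both names are present, gmn_b must alias gmnc_b.
    if "gmnc_b" in out and "gmn_b" in out:
        if not jnp.array_equal(out["gmn_b"], out["gmnc_b"]):
            raise ValueError("gmn_b does not match gmnc_b")
    # jlist: 1-based full-grid indices, BOOZ_XFORM convention compute_surfs + 2.
    if not jnp.array_equal(jnp.asarray(out["jlist"]), jnp.asarray(compute_surfs) + 2):
        raise ValueError("jlist does not equal compute_surfs + 2")
    return True

def to_file_layout(arr):
    # Hypothetical writer step: legacy files may store (mnboz, ns_selected).
    return jnp.transpose(arr)

# Usage on a synthetic symmetric output with 2 surfaces and 3 modes.
compute_surfs = [0, 3]
zeros = jnp.zeros((2, 3))
example = {"ixm_b": jnp.array([0, 0, 1]), "ixn_b": jnp.array([0, 5, 0]),
           "iota_b": jnp.zeros(2), "buco_b": jnp.zeros(2), "bvco_b": jnp.zeros(2),
           "bmnc_b": zeros, "rmnc_b": zeros, "zmns_b": zeros, "pmns_b": zeros,
           "gmnc_b": zeros, "gmn_b": zeros, "jlist": jnp.array([2, 5])}
ok = check_booz_outputs(example, compute_surfs)
file_shape = to_file_layout(example["bmnc_b"]).shape
```

Transposing only at the writer keeps every in-memory consumer on the surface-major convention.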
Implementation Notes
- Keep numerical work in jax.numpy and jax.lax primitives so that jit, grad, jvp, vjp, and vmap batching remain valid.
- Mark only small configuration objects as static in jax.jit. JAX recompiles for every new static value, and static arguments must be hashable. Static arguments should therefore be mode lists (as tuples), grid sizes, and constants, not frequently changing spectra.
- Treat dense Jacobian formation as a diagnostic path, not the default for production optimization. FAX objectives should generally use scalar losses with reverse-mode gradients or matrix-free JVP/VJP products.
- Keep tests physics-based: compare spectral coefficients and transformation identities, not only array shapes.
- Guard differentiable divisions in the auxiliary w spectrum with safe denominators at m = n = 0 so that covariant-field gradients remain finite.
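The static-argument and safe-denominator notes can be seen together in a w-spectrum sketch. w_spectrum is hypothetical, not the package function. It assumes the \(\sin(m\theta - n\zeta)\) convention for w with covariant cosine spectra bsubumnc and bsubvmnc. It also assumes n already includes the field-period factor. From \(\partial w/\partial\theta = B_\theta - I\) and \(\partial w/\partial\zeta = B_\zeta - G\), this gives \(w_{mn} = B_{\theta,mn}/m\) for \(m \neq 0\) and \(w_{mn} = -B_{\zeta,mn}/n\) for \(m = 0, n \neq 0\).

The mode lists are static, hashable tuples, while the spectra stay differentiable. jnp.where still differentiates the unselected branch, so a raw division by zero there would turn the gradient into NaN.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("xm", "xn"))
def w_spectrum(bsubumnc, bsubvmnc, xm, xn):
    """Hypothetical w sine spectrum from covariant cosine spectra.

    xm, xn: static, hashable tuples of mode numbers (xn includes nfp).
    bsubumnc, bsubvmnc: differentiable arrays of shape (len(xm),).
    """
    m = jnp.asarray(xm, dtype=bsubumnc.dtype)
    n = jnp.asarray(xn, dtype=bsubumnc.dtype)
    # Safe denominators: where() still differentiates the unselected branch,
    # so a raw 1/0 there would make the gradient NaN.
    m_safe = jnp.where(m != 0, m, 1.0)
    n_safe = jnp.where(n != 0, n, 1.0)
    from_theta = bsubumnc / m_safe   # m != 0, from dw/dtheta = B_theta - I
    from_zeta = -bsubvmnc / n_safe   # m == 0, n != 0, from dw/dzeta = B_zeta - G
    return jnp.where(m != 0, from_theta,
                     jnp.where(n != 0, from_zeta, 0.0))  # w_00 = 0

xm, xn = (0, 0, 1, 1), (0, 5, 0, 5)
bsubumnc = jnp.array([1.0, 2.0, 3.0, 4.0])
bsubvmnc = jnp.array([10.0, 20.0, 30.0, 40.0])
w = w_spectrum(bsubumnc, bsubvmnc, xm=xm, xn=xn)

def loss(bu, bv):
    return jnp.sum(w_spectrum(bu, bv, xm=xm, xn=xn) ** 2)

g_u, g_v = jax.grad(loss, argnums=(0, 1))(bsubumnc, bsubvmnc)
```

Changing bsubumnc or bsubvmnc reuses the compiled function. Passing a different xm or xn tuple triggers a recompile, which is acceptable because mode lists rarely change.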
References
- STELLOPT BOOZ_XFORM documentation records the BOOZ_XFORM output variables, including gmn_b.
- HiddenSymmetries BOOZ_XFORM theory notes describe the Boozer-coordinate equations and Jacobian relation used here.
- JAX automatic differentiation API documents grad, jacfwd, jacrev, jvp, vjp, and linearize.
- JAX forward- and reverse-mode autodiff guide explains the computational tradeoffs between JVPs, VJPs, and dense Jacobians.
- JAX JIT compilation guide documents when static arguments are appropriate and why they affect recompilation.