Source code for booz_xform_jax.core

"""
Core classes for the JAX implementation of ``booz_xform``.

This module defines the :class:`Booz_xform` class, which is the primary
interface for converting Fourier data from a VMEC equilibrium
(spectral representation in VMEC coordinates) to a spectral
representation in Boozer coordinates.

Pedagogical overview
====================

**What problem are we solving?**

Given a VMEC MHD equilibrium we know (on a set of half-grid radial
surfaces) its Fourier representation in angles (θ, ζ):

  * Geometry: R(θ, ζ, s), Z(θ, ζ, s) and the poloidal angle shift
    λ(θ, ζ, s),
  * Magnetic‐field strength and covariant components:
    |B|(θ, ζ, s), B_θ(θ, ζ, s), B_ζ(θ, ζ, s).

The goal of BOOZ_XFORM is to construct a spectral representation of
the same equilibrium but in **Boozer angles** (θ_B, ζ_B), where the
magnetic field lines are straight and the contravariant components of
B take a particularly simple form. The result is stored as
Fourier coefficients B_{m,n}(s), R_{m,n}(s), Z_{m,n}(s), ν_{m,n}(s),
and Jacobian harmonics on a chosen subset of radial surfaces.

**High-level algorithm (per radial surface)**

For each selected radial surface, the core algorithm follows the
original C++ / Fortran implementation closely:

 1. Build a tensor-product grid in VMEC angles (θ, ζ) and flatten it
    to a vector of length N = N_θ × N_ζ.

 2. Using the *non-Nyquist* VMEC spectrum, synthesise:
        - R(θ, ζ), Z(θ, ζ), λ(θ, ζ),
        - ∂λ/∂θ, ∂λ/∂ζ.

 3. Using the *Nyquist* VMEC spectrum, construct:
        - an auxiliary function w(θ, ζ),
        - its derivatives ∂w/∂θ, ∂w/∂ζ,
        - |B|(θ, ζ).

 4. From the Nyquist spectra of B_θ and B_ζ, recover the Boozer
    profiles I(s) and G(s) and the auxiliary Nyquist spectrum of w.

 5. Compute the “field-line label” ν(θ, ζ) from equation (10) of the
    BOOZ_XFORM theory, and then the Boozer angles:
        θ_B = θ + λ + ι ν,
        ζ_B = ζ + ν,
    where ι is the rotational transform on this surface.

 6. From the derivatives of w and λ, construct ∂ν/∂θ and ∂ν/∂ζ and
    hence the factor dB/d(vmec) appearing in the Fourier integrals.

 7. On the (θ_B, ζ_B) grid, precompute trigonometric tables and
    perform the 2D Fourier integrals that define the Boozer
    coefficients B_{m,n}, R_{m,n}, Z_{m,n}, ν_{m,n} and the
    Boozer-Jacobian harmonics.

This module provides a **vectorised, JAX-based implementation** of the
above steps. The main performance principles are:

  * Precompute trigonometric tables on the (θ, ζ) grid once per run.
  * Hoist all per-mode cos/sin combinations that do not depend on the
    surface index out of the radial loop.
  * Replace explicit Python loops over Fourier modes by
    `jax.numpy.einsum` and broadcasting.
  * Keep the outer loop over radial surfaces in Python (a typical
    equilibrium has tens of surfaces, whereas the number of grid
    points N can be in the thousands, so most work is still inside
    JAX kernels).

Public API
==========

The external API mirrors the original BOOZ_XFORM library:

  * Create an instance of :class:`Booz_xform`.
  * Call :meth:`read_wout` or :meth:`init_from_vmec` to populate
    VMEC data.
  * Optionally call :meth:`register_surfaces` to select a subset of
    radial surfaces (by index or by normalised toroidal flux s).
  * Call :meth:`run()` to perform the Boozer transform. The resulting
    Boozer spectra and profiles are stored on the instance
    (``bmnc_b``, ``rmnc_b``, ``zmns_b``, ``numns_b``, ``gmnc_b``,
    etc., plus Boozer I/G and the chosen radial grid ``s_b``).
  * Use :meth:`write_boozmn` / :meth:`read_boozmn` and plotting helpers
    (defined in other modules) as in the original code.

This file is deliberately **pedagogical**: in addition to the
performance-oriented vectorisation, it includes detailed comments
explaining each mathematical step and its relationship to the
published BOOZ_XFORM theory and to the original implementation.
"""

from __future__ import annotations

import numpy as _np
from dataclasses import dataclass, field
from typing import Iterable, List, Optional

try:
    import jax
    import jax.numpy as jnp

    # The original BOOZ_XFORM (and VMEC) use double precision
    # throughout. We enable 64-bit mode globally so that JAX matches
    # the reference implementation and regression tests can compare
    # against double-precision reference outputs.
    from jax import config as _jax_config
    _jax_config.update("jax_enable_x64", True)
except ImportError as e:  # pragma: no cover
    raise ImportError(
        "The booz_xform_jax package requires JAX. Please install jax and "
        "jaxlib before using this module."
    ) from e

from .vmec import init_from_vmec, read_wout, read_wout_data
from .io_utils import write_boozmn, read_boozmn
from .jax_api import booz_xform_jax_impl, prepare_booz_xform_constants
from .trig import _init_trig, _init_trig_np


# -----------------------------------------------------------------------------
# Trigonometric table helper
# -----------------------------------------------------------------------------


# -----------------------------------------------------------------------------
# Main Booz_xform class
# -----------------------------------------------------------------------------


[docs] @dataclass class Booz_xform: """ Class implementing the Boozer coordinate transformation using JAX. Instances of :class:`Booz_xform` encapsulate all data required to convert the spectral representation of a VMEC equilibrium (in VMEC angles) to a spectral representation in Boozer coordinates. Typical usage ------------- >>> bx = Booz_xform() >>> bx.read_wout("wout_mycase.nc", flux=True) # or init_from_vmec(...) >>> bx.register_surfaces([0.2, 0.5, 0.8]) # select surfaces in s-space >>> bx.run() >>> bx.write_boozmn("boozmn_mycase.nc") After :meth:`run` completes, the Boozer spectra are stored in attributes like ``bmnc_b``, ``rmnc_b``, ``zmns_b``, etc., and the Boozer I/G profiles and radial grid on ``Boozer_I``, ``Boozer_G``, and ``s_b``. Attributes ---------- nfp : int Field periodicity (number of field periods) of the equilibrium. asym : bool Whether the VMEC equilibrium is non-stellarator-symmetric. If ``False``, only the symmetric Fourier coefficients are used (cosine/sine combinations that respect stellarator symmetry). If ``True``, additional “ns” arrays are populated and used. verbose : int or bool Controls diagnostic printing during :meth:`run`. Historically this was an integer (0, 1, 2, …). In this implementation any truthy value enables basic per-surface diagnostics; setting ``verbose > 1`` prints additional information. mpol, ntor : int Maximum poloidal and toroidal mode numbers in the non-Nyquist VMEC spectrum, read from the wout file. mnmax : int Total number of *non-Nyquist* VMEC Fourier modes. For symmetric equilibria this is typically ``mpol * (2*ntor + 1)``. xm, xn : ndarray of int, shape (mnmax,) Mode list for the non-Nyquist VMEC spectrum: poloidal and toroidal mode numbers (with xn stored as :math:`n n_{fp}` to match VMEC conventions). xm_nyq, xn_nyq : ndarray of int Mode list for the Nyquist spectrum used to reconstruct w and |B|. Sizes and ranges mirror those in the original BOOZ_XFORM. mpol_nyq, ntor_nyq, mnmax_nyq : int Nyquist resolutions and total number of Nyquist modes in the VMEC input, read from the wout file. s_in : ndarray, shape (ns_in,) Radial coordinate values on the VMEC half grid (excluding the magnetic axis). This is stored as a NumPy array (host side) so that we can use standard Python indexing and ``numpy.argmin`` when mapping floating-point s values to nearest indices. iota : jax.numpy.ndarray, shape (ns_in,) Rotational transform on the VMEC half grid. rmnc, rmns, zmnc, zmns, lmnc, lmns : jax.numpy.ndarray Non-Nyquist VMEC Fourier coefficients on the half grid, with dimensions ``(mnmax, ns_in)``. Asymmetric quantities are set to ``None`` when ``asym`` is ``False``. bmnc, bmns, bsubumnc, bsubumns, bsubvmnc, bsubvmns : jax.numpy.ndarray Nyquist VMEC Fourier coefficients on the half grid, with dimensions ``(mnmax_nyq, ns_in)``. Asymmetric quantities are set to ``None`` when ``asym`` is ``False``. These are used to reconstruct |B| and the covariant components B_θ, B_ζ. Boozer_I_all, Boozer_G_all : ndarray, shape (ns_in,) Boozer I(s) and G(s) profiles on the full half grid. These correspond to the m=0, n=0 components of ``bsubumnc`` and ``bsubvmnc`` and are stored as NumPy arrays. phip, chi, pres, phi : jax.numpy.ndarray, shape (ns_in,), optional Optional radial profiles read from the VMEC file when the ``flux`` flag is passed to :meth:`read_wout`. They are not used directly in the Boozer transform but are convenient to have available for post-processing. aspect : float Aspect ratio of the equilibrium (copied from VMEC). toroidal_flux : float Total toroidal flux of the equilibrium (copied from VMEC). compute_surfs : list[int] or None Indices of the half-grid surfaces on which to compute the Boozer transform. Indices run from 0 to ``ns_in-1``. ``None`` (default) means “all surfaces”. s_b : ndarray, shape (ns_b,) Radial coordinate values on the subset of surfaces selected by ``compute_surfs``. Populated by :meth:`run` and :meth:`read_boozmn`. ns_in : int Number of half-grid surfaces (excluding the axis) in the VMEC input. ns_b : int Number of surfaces selected for the Boozer transform (i.e. ``len(compute_surfs)``). Boozer_I, Boozer_G : ndarray, shape (ns_b,) Boozer I and G profiles restricted to the selected surfaces. mboz, nboz : int Maximum poloidal and toroidal mode numbers in the *Boozer* spectrum. If not explicitly set by the user, these default to ``mpol`` and ``ntor`` respectively (mirroring the original BOOZ_XFORM behaviour). mnboz : int Total number of Boozer harmonics retained. The enumeration follows the original code: * m runs from 0, 1, …, mboz-1 * for m = 0, n runs 0, 1, …, nboz * for m > 0, n runs -nboz, …, -1, 0, 1, …, nboz The toroidal index is stored as ``xn_b = n * nfp``. xm_b, xn_b : ndarray of int, shape (mnboz,) Boozer mode list as described above. bmnc_b, bmns_b, rmnc_b, rmns_b, zmnc_b, zmns_b, numnc_b, numns_b, gmnc_b, gmns_b : ndarray Boozer Fourier coefficients on the selected surfaces. Each has shape ``(mnboz, ns_b)``. Asymmetric arrays are ``None`` when ``asym`` is ``False``. The “c” suffix denotes cosine-like coefficients and the “s” suffix sine-like coefficients, following the usual VMEC/BOOZ_XFORM conventions. _prepared : bool Internal flag indicating whether the angular grids and related bookkeeping (θ, ζ, grid sizes) have been initialised. """ # VMEC parameters read from the wout file nfp: int = 1 asym: bool = False # Verbosity as described in the docstring verbose: int | bool = 1 mpol: int = 0 ntor: int = 0 mnmax: int = 0 xm: Optional[_np.ndarray] = None xn: Optional[_np.ndarray] = None xm_nyq: Optional[_np.ndarray] = None xn_nyq: Optional[_np.ndarray] = None mpol_nyq: Optional[int] = None ntor_nyq: Optional[int] = None mnmax_nyq: Optional[int] = None # Input arrays on the VMEC half grid (radial index runs over ns_in) s_in: Optional[_np.ndarray] = None iota: Optional[jnp.ndarray] = None rmnc: Optional[jnp.ndarray] = None rmns: Optional[jnp.ndarray] = None zmnc: Optional[jnp.ndarray] = None zmns: Optional[jnp.ndarray] = None lmnc: Optional[jnp.ndarray] = None lmns: Optional[jnp.ndarray] = None bmnc: Optional[jnp.ndarray] = None bmns: Optional[jnp.ndarray] = None bsubumnc: Optional[jnp.ndarray] = None bsubumns: Optional[jnp.ndarray] = None bsubvmnc: Optional[jnp.ndarray] = None bsubvmns: Optional[jnp.ndarray] = None Boozer_I_all: Optional[_np.ndarray] = None Boozer_G_all: Optional[_np.ndarray] = None phip: Optional[jnp.ndarray] = None chi: Optional[jnp.ndarray] = None pres: Optional[jnp.ndarray] = None phi: Optional[jnp.ndarray] = None aspect: float = 0.0 toroidal_flux: float = 0.0 # Derived quantities set by init_from_vmec or read_boozmn compute_surfs: Optional[List[int]] = field(default=None) s_b: Optional[_np.ndarray] = None ns_in: Optional[int] = None ns_b: Optional[int] = None Boozer_I: Optional[_np.ndarray] = None Boozer_G: Optional[_np.ndarray] = None mboz: Optional[int] = None nboz: Optional[int] = None mnboz: Optional[int] = None xm_b: Optional[_np.ndarray] = None xn_b: Optional[_np.ndarray] = None bmnc_b: Optional[_np.ndarray] = None bmns_b: Optional[_np.ndarray] = None rmnc_b: Optional[_np.ndarray] = None rmns_b: Optional[_np.ndarray] = None zmnc_b: Optional[_np.ndarray] = None zmns_b: Optional[_np.ndarray] = None numnc_b: Optional[_np.ndarray] = None numns_b: Optional[_np.ndarray] = None gmnc_b: Optional[_np.ndarray] = None gmns_b: Optional[_np.ndarray] = None # Bookkeeping _prepared: bool = False # whether mode lists and grids have been prepared # ------------------------------------------------------------------ # Delegated methods from external modules # ------------------------------------------------------------------
[docs] def init_from_vmec(self, *args, s_in: Optional[_np.ndarray] = None) -> None: """ Load Fourier data from VMEC into this instance. This method simply delegates to :func:`booz_xform_jax.vmec.init_from_vmec`. See that function for the full list of arguments and options. Parameters ---------- *args : Passed directly to :func:`init_from_vmec`. s_in : Optional replacement radial grid of normalised toroidal flux. If provided, its first element should correspond to the axis; this element will be discarded so that ``s_in[0]`` on the instance is the first half-grid surface away from the axis. """ init_from_vmec(self, *args, s_in=s_in)
[docs] def read_wout(self, filename: str, flux: bool = False) -> None: """ Read a VMEC ``wout`` file and populate the internal arrays. This is a thin wrapper around :func:`booz_xform_jax.vmec.read_wout`. In addition to the core Fourier coefficients needed for the Boozer transform, optional flux profile arrays can be loaded when ``flux=True``. Parameters ---------- filename : Path to the VMEC wout file. flux : If ``True``, also read radial profile arrays (φ', χ, p, …). """ read_wout(self, filename, flux)
[docs] def read_wout_data(self, wout, flux: bool = False) -> None: """ Populate the instance from an in-memory VMEC wout object. This is a thin wrapper around :func:`booz_xform_jax.vmec.read_wout_data`. Parameters ---------- wout : A VMEC wout-like object (e.g. ``vmec_jax.WoutData``). flux : If ``True``, also read radial profile arrays (φ', χ, p, …) when available. """ read_wout_data(self, wout, flux)
[docs] def write_boozmn(self, filename: str) -> None: """ Write the computed Boozer spectra to a NetCDF file. This delegates to :func:`booz_xform_jax.io_utils.write_boozmn`. The file format (NetCDF3 vs NetCDF4) depends on the availability of the ``netCDF4`` package and mirrors the behaviour of the original BOOZ_XFORM code. """ write_boozmn(self, filename)
[docs] def read_boozmn(self, filename: str) -> None: """ Read Boozer spectra from an existing ``boozmn`` file. This delegates to :func:`booz_xform_jax.io_utils.read_boozmn` and populates the current instance with the data from that file, including mode definitions, radial profiles, and Boozer spectra. """ read_boozmn(self, filename)
# ------------------------------------------------------------------ # Internal helper routines for preparing mode lists and grids # ------------------------------------------------------------------ def _prepare_mode_lists(self) -> None: """ Construct lists of Boozer mode indices based on ``mboz`` and ``nboz``. The enumeration mirrors the original C++ implementation: * m runs from 0, 1, ..., ``mboz - 1``. * For m == 0, n runs 0, 1, ..., nboz (only non-negative toroidal indices). * For m > 0, n runs -nboz, ..., -1, 0, 1, ..., nboz. The toroidal indices are stored as ``xn_b = n * nfp`` to match VMEC conventions (i.e. actual Fourier angle is ``xn_b * ζ``). The resulting arrays are stored on ``self.xm_b`` and ``self.xn_b``, and the total number of modes on ``self.mnboz``. """ if self.mboz is None or self.nboz is None: raise RuntimeError("mboz and nboz must be set before preparing mode lists") m_list: List[int] = [] n_list: List[int] = [] for m in range(self.mboz): if m == 0: # m = 0 → keep only non-negative n for n in range(0, self.nboz + 1): m_list.append(m) n_list.append(n * self.nfp) else: # m > 0 → keep full range of n for n in range(-self.nboz, self.nboz + 1): m_list.append(m) n_list.append(n * self.nfp) self.xm_b = _np.asarray(m_list, dtype=int) self.xn_b = _np.asarray(n_list, dtype=int) self.mnboz = len(self.xm_b) def _setup_grids(self) -> None: """ Set up the (theta, zeta) grid and basic bookkeeping. This routine constructs a tensor-product grid in VMEC angles, following the grid-sizing logic from the original BOOZ_XFORM code. The grid is slightly larger than the nominal Boozer resolution to comfortably resolve products of harmonics. For symmetric equilibria (``asym == False``) we exploit stellarator symmetry to restrict θ to [0, π] plus the end points. In that case: * ``ntheta_full = 2 * (2*mboz + 1)`` * we use only the first ``nu2_b = ntheta_full//2 + 1`` rows in θ, i.e. 0 ≤ θ ≤ π, and apply special 1/2 weights to the θ=0 and θ=π rows in the Fourier integrals. For asymmetric equilibria (``asym == True``) we use the full range θ ∈ [0, 2π); then ``nu3_b = ntheta_full``. The flattened grids are stored on ``self._theta_grid`` and ``self._zeta_grid``, and grid sizes on: * ``self._ntheta`` – total θ points in the full grid. * ``self._nzeta`` – total ζ points. * ``self._n_theta_zeta`` – product grid size. * ``self._nu2_b`` – number of θ rows used in the symmetric case. """ if self._prepared: return if self.mboz is None or self.nboz is None: raise RuntimeError("mboz and nboz must be set before setting up grids") # Nominal angular resolutions (full θ range) ntheta_full = 2 * (2 * self.mboz + 1) nzeta_full = 2 * (2 * self.nboz + 1) if self.nboz > 0 else 1 nu2_b = ntheta_full // 2 + 1 # number of θ rows in [0, π] if self.asym: # Asymmetric case: keep all θ rows in [0, 2π) nu3_b = ntheta_full else: # Symmetric case: exploit θ → 2π - θ symmetry, keep [0, π] nu3_b = nu2_b d_theta = (2.0 * jnp.pi) / ntheta_full d_zeta = (2.0 * jnp.pi) / (self.nfp * nzeta_full) theta_vals = jnp.arange(nu3_b) * d_theta zeta_vals = jnp.arange(nzeta_full) * d_zeta # Build flattened tensor-product grid: # # θ_j = θ_i for i fixed, repeated over all ζ # ζ_j = ζ_k tiled over θ rows # self._theta_grid = jnp.repeat(theta_vals, nzeta_full) self._zeta_grid = jnp.tile(zeta_vals, nu3_b) self._ntheta = int(ntheta_full) self._nzeta = int(nzeta_full) self._n_theta_zeta = int(nu3_b * nzeta_full) self._nu2_b = nu2_b self._prepared = True # ------------------------------------------------------------------ # Main transform # ------------------------------------------------------------------
[docs] def run(self, jit: bool = False) -> None: """ Perform the Boozer coordinate transformation on selected surfaces. Parameters ---------- jit : bool, optional Placeholder flag (currently unused). The transform is implemented entirely in terms of JAX array operations (``jax.numpy`` and ``einsum``). To avoid large compile times on CPU, we do **not** wrap the entire :meth:`run` in a single :func:`jax.jit` by default. Small helpers such as :func:`_init_trig` *are* jitted. Advanced users who want full JIT compilation can wrap :meth:`run` externally, but should be aware that this may lead to long compilation times for large Boozer resolutions. Notes ----- The implementation follows the algorithm outlined in the module docstring and in the BOOZ_XFORM documentation. The main difference from a direct translation of the Fortran/C++ code is that all loops over Fourier modes are vectorised. Only the loop over radial surfaces remains as a Python loop. """ _verbose = bool(self.verbose) if _verbose: pass # Header printed after grid setup # Basic sanity checks: VMEC data must be initialised. if self.rmnc is None or self.bmnc is None: raise RuntimeError("VMEC data must be initialised before running the transform") if self.ns_in is None: raise RuntimeError("ns_in must be set; did init_from_vmec run correctly?") ns_in = int(self.ns_in) if ns_in <= 0: raise RuntimeError("ns_in must be positive; did init_from_vmec run correctly?") # ------------------------------------------------------------------ # Surface selection # ------------------------------------------------------------------ # Default: compute on all surfaces. if self.compute_surfs is None: self.compute_surfs = list(range(ns_in)) else: for idx in self.compute_surfs: if idx < 0 or idx >= ns_in: raise ValueError( f"compute_surfs has an entry {idx} outside [0, {ns_in - 1}]" ) # ------------------------------------------------------------------ # Boozer mode lists and grids # ------------------------------------------------------------------ # Default Boozer resolution: match VMEC angular resolution. if self.mboz is None: if self.mpol is None: raise RuntimeError("mboz is not set and mpol is not available") self.mboz = int(self.mpol) if self.nboz is None: if self.ntor is None: raise RuntimeError("nboz is not set and ntor is not available") self.nboz = int(self.ntor) if self.mnboz is None or self.xm_b is None or self.xn_b is None: self._prepare_mode_lists() self._setup_grids() if _verbose: print( f" 0 <= mboz <= {int(self.mboz) - 1:4d}" f" {-int(self.nboz):4d} <= nboz <= {int(self.nboz):4d}" ) print(f" nu_boz = {self._ntheta:5d} nv_boz = {self._nzeta:5d}") print() print( " OUTBOARD (u=0)" " JS INBOARD (u=pi)" ) print("-" * 77) print( " v |B|vmec |B|booz Error" " |B|vmec |B|booz Error" ) print() theta_grid = self._theta_grid zeta_grid = self._zeta_grid # ------------------------------------------------------------------ # Precompute trig tables for VMEC spectra (non-Nyquist and Nyquist) # and hoist all per-mode trig combinations out of the surface loop. # ------------------------------------------------------------------ xm_non_np = _np.asarray(self.xm, dtype=int) xn_non_np = _np.asarray(self.xn, dtype=int) xm_nyq_np = _np.asarray(self.xm_nyq, dtype=int) xn_nyq_np = _np.asarray(self.xn_nyq, dtype=int) # Non-Nyquist (geometry, λ): mmax_non = int(_np.max(_np.abs(xm_non_np))) nmax_non = int(_np.max(_np.abs(xn_non_np // self.nfp))) cosm, sinm, cosn, sinn = _init_trig( theta_grid, zeta_grid, mmax_non, nmax_non, self.nfp ) # Nyquist (w, |B|): mmax_nyq = int(_np.max(_np.abs(xm_nyq_np))) nmax_nyq = int(_np.max(_np.abs(xn_nyq_np // self.nfp))) cosm_nyq, sinm_nyq, cosn_nyq, sinn_nyq = _init_trig( theta_grid, zeta_grid, mmax_nyq, nmax_nyq, self.nfp ) # Convert mode index lists to JAX arrays once (reused per surface). xm_non = jnp.asarray(xm_non_np, dtype=jnp.int32) xn_non = jnp.asarray(xn_non_np, dtype=jnp.int32) xm_nyq = jnp.asarray(xm_nyq_np, dtype=jnp.int32) xn_nyq = jnp.asarray(xn_nyq_np, dtype=jnp.int32) xm_b_j = jnp.asarray(self.xm_b, dtype=jnp.int32) xn_b_j = jnp.asarray(self.xn_b, dtype=jnp.int32) # Index of (m=0, n=0) Nyquist mode → Boozer I, G. idx00_candidates = _np.where((xm_nyq_np == 0) & (xn_nyq_np == 0))[0] if len(idx00_candidates) == 0: raise RuntimeError("Could not find (m=0,n=0) Nyquist mode in xm_nyq/xn_nyq") idx00 = int(idx00_candidates[0]) # ------------------------- # Hoisted non-Nyquist trig combinations # ------------------------- # Shapes: # cosm_m_non, sinm_m_non : (N, mnmax_non) # cosn_n_non, sinn_n_non : (N, mnmax_non) cosm_m_non = cosm[:, xm_non_np] sinm_m_non = sinm[:, xm_non_np] abs_n_non = jnp.abs(xn_non // self.nfp) abs_n_non_idx = _np.asarray(abs_n_non, dtype=int) cosn_n_non = cosn[:, abs_n_non_idx] sinn_n_non = sinn[:, abs_n_non_idx] sign_non = jnp.where(xn_non < 0, -1.0, 1.0)[None, :] # tcos_non / tsin_non: trigonometric factors multiplying # Fourier coefficients for rmnc, zmns, lmns, etc. tcos_non = cosm_m_non * cosn_n_non + sinm_m_non * sinn_n_non * sign_non tsin_non = sinm_m_non * cosn_n_non - cosm_m_non * sinn_n_non * sign_non m_non_f = xm_non.astype(jnp.float64) n_non_f = xn_non.astype(jnp.float64) # ------------------------- # Hoisted Nyquist trig combinations # ------------------------- cosm_m_nyq = cosm_nyq[:, xm_nyq_np] sinm_m_nyq = sinm_nyq[:, xm_nyq_np] abs_n_nyq = jnp.abs(xn_nyq // self.nfp) abs_n_nyq_idx = _np.asarray(abs_n_nyq, dtype=int) cosn_n_nyq = cosn_nyq[:, abs_n_nyq_idx] sinn_n_nyq = sinn_nyq[:, abs_n_nyq_idx] sign_nyq = jnp.where(xn_nyq < 0, -1.0, 1.0)[None, :] tcos_nyq = cosm_m_nyq * cosn_n_nyq + sinm_m_nyq * sinn_n_nyq * sign_nyq tsin_nyq = sinm_m_nyq * cosn_n_nyq - cosm_m_nyq * sinn_n_nyq * sign_nyq m_nyq_f = xm_nyq.astype(jnp.float64) n_nyq_f = xn_nyq.astype(jnp.float64) # ------------------------------------------------------------------ # Convert all hoisted JAX arrays to NumPy once. # This eliminates every JAX dispatch and device→host sync from the # per-surface loop — replacing jnp.einsum with numpy matmul (@). # ------------------------------------------------------------------ tcos_non_np = _np.asarray(tcos_non) # (N, mnmax_non) tsin_non_np = _np.asarray(tsin_non) tcos_nyq_np = _np.asarray(tcos_nyq) # (N, mnmax_nyq) tsin_nyq_np = _np.asarray(tsin_nyq) m_non_f_np = _np.asarray(m_non_f) # (mnmax_non,) n_non_f_np = _np.asarray(n_non_f) m_nyq_f_np = _np.asarray(m_nyq_f) # (mnmax_nyq,) n_nyq_f_np = _np.asarray(n_nyq_f) theta_grid_np = _np.asarray(theta_grid) # (N,) zeta_grid_np = _np.asarray(zeta_grid) # VMEC coefficient arrays (per-surface slices become plain numpy views) rmnc_arr = _np.asarray(self.rmnc) # (mnmax_non, ns_in) zmns_arr = _np.asarray(self.zmns) lmns_arr = _np.asarray(self.lmns) bmnc_arr = _np.asarray(self.bmnc) # (mnmax_nyq, ns_in) bsubumnc_arr = _np.asarray(self.bsubumnc) bsubvmnc_arr = _np.asarray(self.bsubvmnc) iota_arr = _np.asarray(self.iota) # (ns_in,) if self.asym: rmns_arr = _np.asarray(self.rmns) if self.rmns is not None else None zmnc_arr = _np.asarray(self.zmnc) if self.zmnc is not None else None lmnc_arr = _np.asarray(self.lmnc) if self.lmnc is not None else None bmns_arr = _np.asarray(self.bmns) if self.bmns is not None else None bsubumns_arr = _np.asarray(self.bsubumns) if self.bsubumns is not None else None bsubvmns_arr = _np.asarray(self.bsubvmns) if self.bsubvmns is not None else None else: rmns_arr = zmnc_arr = lmnc_arr = bmns_arr = None bsubumns_arr = bsubvmns_arr = None # Hoist wmns boolean masks and safe-divisors out of the surface loop. m_nonzero_np = m_nyq_f_np != 0.0 n_nonzero_only_np = ~m_nonzero_np & (n_nyq_f_np != 0.0) m_nyq_f_safe = _np.where(m_nonzero_np, m_nyq_f_np, 1.0) n_nyq_f_safe = _np.where(n_nonzero_only_np, n_nyq_f_np, 1.0) # ------------------------------------------------------------------ # Output arrays (NumPy, host side) # ------------------------------------------------------------------ ns_b = len(self.compute_surfs) self.ns_b = ns_b mnboz = int(self.mnboz) bmnc_b = _np.zeros((mnboz, ns_b), dtype=float) rmnc_b = _np.zeros((mnboz, ns_b), dtype=float) zmns_b = _np.zeros((mnboz, ns_b), dtype=float) numns_b = _np.zeros((mnboz, ns_b), dtype=float) gmnc_b = _np.zeros((mnboz, ns_b), dtype=float) if self.asym: bmns_b = _np.zeros((mnboz, ns_b), dtype=float) rmns_b = _np.zeros((mnboz, ns_b), dtype=float) zmnc_b = _np.zeros((mnboz, ns_b), dtype=float) numnc_b = _np.zeros((mnboz, ns_b), dtype=float) gmns_b = _np.zeros((mnboz, ns_b), dtype=float) else: bmns_b = rmns_b = zmnc_b = numnc_b = gmns_b = None # Batch-extract Boozer I and G for all selected surfaces at once # (avoids one device→host transfer per surface inside the loop). _surfs_np = _np.asarray(self.compute_surfs, dtype=int) Boozer_I = _np.asarray(self.bsubumnc[idx00, _surfs_np], dtype=float) Boozer_G = _np.asarray(self.bsubvmnc[idx00, _surfs_np], dtype=float) # ------------------------------------------------------------------ # Batch VMEC synthesis — compute all surface fields in one DGEMM. # tcos_non_np is (N, mnmax_non); rmnc_arr[:, surfs] is (mnmax_non, ns_b) # Result: (N, ns_b) per field. Keeps the trig matrix in L3/L4 cache # rather than re-reading it for each of the ns_b surfaces. # ------------------------------------------------------------------ _lmns_s = lmns_arr[:, _surfs_np] # (mnmax_non, ns_b) _lmns_m_s = _lmns_s * m_non_f_np[:, None] # pre-scaled _lmns_n_s = _lmns_s * n_non_f_np[:, None] _r_all = tcos_non_np @ rmnc_arr[:, _surfs_np] # (N, ns_b) _z_all = tsin_non_np @ zmns_arr[:, _surfs_np] _lam_all = tsin_non_np @ _lmns_s _dlam_dth_all = tcos_non_np @ _lmns_m_s _dlam_dze_all = -(tcos_non_np @ _lmns_n_s) # wmns for all surfaces at once: (mnmax_nyq, ns_b) _bsubumnc_s = bsubumnc_arr[:, _surfs_np] # (mnmax_nyq, ns_b) _bsubvmnc_s = bsubvmnc_arr[:, _surfs_np] _wmns_all = _np.where(m_nonzero_np[:, None], _bsubumnc_s / m_nyq_f_safe[:, None], _np.where(n_nonzero_only_np[:, None], -_bsubvmnc_s / n_nyq_f_safe[:, None], 0.0)) _wmns_m_s = _wmns_all * m_nyq_f_np[:, None] _wmns_n_s = _wmns_all * n_nyq_f_np[:, None] _w_all = tsin_nyq_np @ _wmns_all # (N, ns_b) _dw_dth_all = tcos_nyq_np @ _wmns_m_s _dw_dze_all = -(tcos_nyq_np @ _wmns_n_s) _bmod_all = tcos_nyq_np @ bmnc_arr[:, _surfs_np] # (N, ns_b) if self.asym: if lmnc_arr is not None: _lmnc_s = lmnc_arr[:, _surfs_np] _lmnc_m_s = _lmnc_s * m_non_f_np[:, None] _lmnc_n_s = _lmnc_s * n_non_f_np[:, None] _r_all = _r_all + tsin_non_np @ rmns_arr[:, _surfs_np] _z_all = _z_all + tcos_non_np @ zmnc_arr[:, _surfs_np] _lam_all = _lam_all + tcos_non_np @ _lmnc_s _dlam_dth_all = _dlam_dth_all - tsin_non_np @ _lmnc_m_s _dlam_dze_all = _dlam_dze_all + tsin_non_np @ _lmnc_n_s if bsubumns_arr is not None: _bsubumns_s = bsubumns_arr[:, _surfs_np] _bsubvmns_s = bsubvmns_arr[:, _surfs_np] _wmnc_all = _np.where(m_nonzero_np[:, None], -_bsubumns_s / m_nyq_f_safe[:, None], _np.where(n_nonzero_only_np[:, None], _bsubvmns_s / n_nyq_f_safe[:, None], 0.0)) _wmnc_m_s = _wmnc_all * m_nyq_f_np[:, None] _wmnc_n_s = _wmnc_all * n_nyq_f_np[:, None] _w_all = _w_all + tcos_nyq_np @ _wmnc_all _dw_dth_all = _dw_dth_all - tsin_nyq_np @ _wmnc_m_s _dw_dze_all = _dw_dze_all + tsin_nyq_np @ _wmnc_n_s _bmod_all = _bmod_all + tsin_nyq_np @ bmns_arr[:, _surfs_np] # Pre-allocate reusable scratch buffers for the double-spectral step. # This eliminates ~10 small allocations per surface. _N = int(theta_grid_np.shape[0]) _nb = int(self.nboz) + 1 _mb = int(self.mboz) + 1 _fcn_buf = _np.empty((_N, _nb), dtype=float) _fsn_buf = _np.empty((_N, _nb), dtype=float) _Xc_buf = _np.empty((_mb, _nb), dtype=float) _Xs_buf = _np.empty((_mb, _nb), dtype=float) _Ysc_buf = _np.empty((_mb, _nb), dtype=float) _Ycs_buf = _np.empty((_mb, _nb), dtype=float) # ------------------------------------------------------------------ # Hoist Boozer-mode index arrays out of the surface loop. # These depend only on xm_b / xn_b which are constant across surfaces. # Computing them inside the loop triggers repeated device→host syncs. # ------------------------------------------------------------------ _m_b_np_idx = _np.asarray(xm_b_j, dtype=int) # (mnboz,) _abs_n_b_np = _np.asarray(jnp.abs(xn_b_j // self.nfp), dtype=int) # (mnboz,) _sign_b_hoisted = jnp.where(xn_b_j < 0, -1.0, 1.0)[None, :] # (1, mnboz) # Fourier normalisation factor (constant: depends only on grid sizes) _fourier_factor0 = ( 2.0 / (self._ntheta * self._nzeta) if self.asym else 2.0 / ((self._nu2_b - 1) * self._nzeta) ) _fourier_factor = jnp.ones((mnboz,), dtype=jnp.float64) * _fourier_factor0 _fourier_factor = _fourier_factor.at[0].set(_fourier_factor0 * 0.5) # ------------------------------------------------------------------ # Chunk size for memory-bounded Fourier integrals. # Each chunk allocates 2×(N×L)×8 bytes for tcos_c / tsin_c. # Default cap: 200 MB; override via BOOZ_XFORM_JAX_CHUNK_BYTES env var. # ------------------------------------------------------------------ # NumPy copies of constant mode index/weight arrays _m_b_chunk = _np.asarray(_m_b_np_idx) # (mnboz,) int _n_b_chunk = _np.asarray(_abs_n_b_np) # (mnboz,) int _ff_chunk = _np.asarray(_fourier_factor) # (mnboz,) float _sgn_chunk = _np.asarray(_sign_b_hoisted[0]) # (mnboz,) float # NumPy copies for the verbose modbooz reconstruction if _verbose: _xm_b_np_f = _np.asarray(xm_b_j, dtype=float) _xn_b_np_f = _np.asarray(xn_b_j, dtype=float) # Convenience indices for symmetric θ integration (θ=0 and θ=π rows). # NumPy integer arrays — used for in-place *= 0.5 on numpy trig tables. idx_theta0 = _np.arange(0, self._nzeta) idx_thetapi = _np.arange( (self._nu2_b - 1) * self._nzeta, self._nu2_b * self._nzeta ) # Fixed-point indices for Fortran-style accuracy check # (u=0,v=0), (u=pi,v=0), (u=0,v=pi), (u=pi,v=pi) nv2_b_idx = self._nzeta // 2 # 0-based index for v=pi idx_00 = 0 idx_pi0 = (self._nu2_b - 1) * self._nzeta idx_0pi = nv2_b_idx idx_pipi = (self._nu2_b - 1) * self._nzeta + nv2_b_idx # ------------------------------------------------------------------ # Loop over surfaces js_b (Python loop; heavy math is vectorised) # ------------------------------------------------------------------ for js_b, js in enumerate(self.compute_surfs): if isinstance(self.verbose, int) and self.verbose > 1: print(f"[booz_xform_jax] Solving surface js_b={js_b}, js={js}") # ------------------------------------------------------------------ # 2) Boozer I and G (already batch-extracted before the loop) # ------------------------------------------------------------------ Boozer_I_js = Boozer_I[js_b] Boozer_G_js = Boozer_G[js_b] # ------------------------------------------------------------------ # 3) R, Z, λ and derivatives — sliced from pre-batched arrays # ------------------------------------------------------------------ r = _r_all[:, js_b] z = _z_all[:, js_b] lam = _lam_all[:, js_b] dlam_dth = _dlam_dth_all[:, js_b] dlam_dze = _dlam_dze_all[:, js_b] # ------------------------------------------------------------------ # 4) w, ∂w/∂θ, ∂w/∂ζ and |B| — sliced from pre-batched arrays # ------------------------------------------------------------------ w = _w_all[:, js_b] dw_dth = _dw_dth_all[:, js_b] dw_dze = _dw_dze_all[:, js_b] bmod = _bmod_all[:, js_b] # ------------------------------------------------------------------ # 5) ν, Boozer angles, their derivatives, J_B, and dB/d(vmec) # ------------------------------------------------------------------ this_iota = float(iota_arr[js]) GI = Boozer_G_js + this_iota * Boozer_I_js one_over_GI = 1.0 / GI # ν from eq (10): ν = (w - I λ) / (G + ι I) nu = one_over_GI * (w - Boozer_I_js * lam) # Boozer angles from eq (3): # θ_B = θ + λ + ι ν # ζ_B = ζ + ν theta_B = theta_grid_np + lam + this_iota * nu zeta_B = zeta_grid_np + nu # Derivatives of ν: dnu_dze = one_over_GI * (dw_dze - Boozer_I_js * dlam_dze) dnu_dth = one_over_GI * (dw_dth - Boozer_I_js * dlam_dth) # Eq (12): dB/d(vmec) factor dB_dvmec = (1.0 + dlam_dth) * (1.0 + dnu_dze) + \ (this_iota - dlam_dze) * dnu_dth # Store VMEC-space |B| at 4 fixed points for accuracy check later if _verbose: bmodv = ( float(bmod[idx_00]), # (u=0, v=0) float(bmod[idx_pi0]), # (u=pi, v=0) float(bmod[idx_0pi]), # (u=0, v=pi) float(bmod[idx_pipi]), # (u=pi, v=pi) ) u_b = ( float(theta_B[idx_00]), float(theta_B[idx_pi0]), float(theta_B[idx_0pi]), float(theta_B[idx_pipi]), ) v_b = ( float(zeta_B[idx_00]), float(zeta_B[idx_pi0]), float(zeta_B[idx_0pi]), float(zeta_B[idx_pipi]), ) # ------------------------------------------------------------------ # 6) Boozer trig tables on (theta_B, zeta_B) — pure NumPy # ------------------------------------------------------------------ cosm_b, sinm_b, cosn_b, sinn_b = _init_trig_np( theta_B, zeta_B, int(self.mboz), int(self.nboz), self.nfp ) # Boozer Jacobian: J_B = (G + ι I) / |B|² = GI / |B|² boozer_jac = GI / (bmod * bmod) # ------------------------------------------------------------------ # 7) Final Fourier integrals — double-spectral decomposition # ------------------------------------------------------------------ # Separability of the Boozer trig factor: # tcos[i, j] = cosm[i,m_j]*cosn[i,n_j] + sinm[i,m_j]*sinn[i,n_j]*sgn_j # tsin[i, j] = sinm[i,m_j]*cosn[i,n_j] - cosm[i,m_j]*sinn[i,n_j]*sgn_j # # The Fourier integral factors as two tiny DGEMM calls per field: # X_c[m, n] = cosm.T @ (field * cosn) shape (mboz+1, nboz+1) # X_s[m, n] = sinm.T @ (field * sinn) # Y_sc[m,n] = sinm.T @ (field * cosn) # Y_cs[m,n] = cosm.T @ (field * sinn) # # Then scatter: cos_out[j] = ff_j*(X_c[m_j,n_j] + sgn_j*X_s[m_j,n_j]) # sin_out[j] = ff_j*(Y_sc[m_j,n_j] - sgn_j*Y_cs[m_j,n_j]) # # Peak memory: O(N*(mboz+1)) not O(N*mnboz); no chunk loop needed. # Apply symmetric half-weight to a dB copy (not to trig tables) _dB = dB_dvmec.copy() if not self.asym else dB_dvmec if not self.asym: _dB[idx_theta0] *= 0.5 _dB[idx_thetapi] *= 0.5 # Reusable transpose views (no copy — BLAS handles Fortran order) _cmT = cosm_b.T # (mboz+1, N) _smT = sinm_b.T # (mboz+1, N) _cos_out = _np.empty((5 if self.asym else 3, mnboz), dtype=float) _sin_out = _np.empty((5 if self.asym else 2, mnboz), dtype=float) if self.asym: # All 5 fields contribute to both cosine and sine output _field_list = [bmod, r, z, nu, boozer_jac] for k, fk in enumerate(_field_list): _fkdB = fk * _dB # (N,) weighted field _np.multiply(_fkdB[:, None], cosn_b, out=_fcn_buf) # (N, nboz+1) _np.multiply(_fkdB[:, None], sinn_b, out=_fsn_buf) # (N, nboz+1) _np.dot(_cmT, _fcn_buf, out=_Xc_buf) # (mboz+1, nboz+1) _np.dot(_smT, _fsn_buf, out=_Xs_buf) _np.dot(_smT, _fcn_buf, out=_Ysc_buf) _np.dot(_cmT, _fsn_buf, out=_Ycs_buf) _cos_out[k] = _ff_chunk * ( _Xc_buf[_m_b_chunk, _n_b_chunk] + _sgn_chunk * _Xs_buf[_m_b_chunk, _n_b_chunk] ) _sin_out[k] = _ff_chunk * ( _Ysc_buf[_m_b_chunk, _n_b_chunk] - _sgn_chunk * _Ycs_buf[_m_b_chunk, _n_b_chunk] ) else: # Cosine-output fields: bmod, r, jac for k, fk in enumerate([bmod, r, boozer_jac]): _fkdB = fk * _dB _np.multiply(_fkdB[:, None], cosn_b, out=_fcn_buf) _np.multiply(_fkdB[:, None], sinn_b, out=_fsn_buf) _np.dot(_cmT, _fcn_buf, out=_Xc_buf) _np.dot(_smT, _fsn_buf, out=_Xs_buf) _cos_out[k] = _ff_chunk * ( _Xc_buf[_m_b_chunk, _n_b_chunk] + _sgn_chunk * _Xs_buf[_m_b_chunk, _n_b_chunk] ) # Sine-output fields: z, nu for k, fk in enumerate([z, nu]): _fkdB = fk * _dB _np.multiply(_fkdB[:, None], cosn_b, out=_fcn_buf) _np.multiply(_fkdB[:, None], sinn_b, out=_fsn_buf) _np.dot(_smT, _fcn_buf, out=_Ysc_buf) _np.dot(_cmT, _fsn_buf, out=_Ycs_buf) _sin_out[k] = _ff_chunk * ( _Ysc_buf[_m_b_chunk, _n_b_chunk] - _sgn_chunk * _Ycs_buf[_m_b_chunk, _n_b_chunk] ) # Write to NumPy output buffers (no .asarray needed — already host) bmnc_b[:, js_b] = _cos_out[0] rmnc_b[:, js_b] = _cos_out[1] if self.asym: zmnc_b[:, js_b] = _cos_out[2] numnc_b[:, js_b] = _cos_out[3] gmnc_b[:, js_b] = _cos_out[4] bmns_b[:, js_b] = _sin_out[0] rmns_b[:, js_b] = _sin_out[1] zmns_b[:, js_b] = _sin_out[2] numns_b[:, js_b] = _sin_out[3] gmns_b[:, js_b] = _sin_out[4] else: gmnc_b[:, js_b] = _cos_out[2] zmns_b[:, js_b] = _sin_out[0] numns_b[:, js_b] = _sin_out[1] # Fortran-style accuracy check: reconstruct |B| at 4 fixed # Boozer-angle points and compare with VMEC real-space |B|. if _verbose: # jrad is the 1-based full-grid index (Fortran convention) jrad = js + 2 # Vectorised modbooz: u_b/v_b are (4,), _xm_b_np_f/_xn_b_np_f (mnboz,) # -> angles (4, mnboz), then sum over modes u_b_arr = _np.array(u_b) # (4,) v_b_arr = _np.array(v_b) # (4,) sgn_arr = _np.where(_xn_b_np_f >= 0, 1.0, -1.0) # (mnboz,) n_abs_arr = _np.abs(_xn_b_np_f) / self.nfp # (mnboz,) cosm_4 = _np.cos(_xm_b_np_f[None, :] * u_b_arr[:, None]) # (4, mnboz) sinm_4 = _np.sin(_xm_b_np_f[None, :] * u_b_arr[:, None]) cosn_4 = _np.cos(n_abs_arr[None, :] * v_b_arr[:, None] * self.nfp) sinn_4 = _np.sin(n_abs_arr[None, :] * v_b_arr[:, None] * self.nfp) cost_4 = cosm_4 * cosn_4 + sinm_4 * sinn_4 * sgn_arr[None, :] # (4, mnboz) bmodb_arr = cost_4 @ bmnc_b[:, js_b] # already numpy if self.asym: sint_4 = sinm_4 * cosn_4 - cosm_4 * sinn_4 * sgn_arr[None, :] bmodb_arr = bmodb_arr + sint_4 @ bmns_b[:, js_b] bmodv_arr = _np.array(bmodv) err_arr = _np.abs(bmodb_arr - bmodv_arr) / _np.maximum( _np.abs(bmodb_arr), _np.maximum(_np.abs(bmodv_arr), 1e-30) ) bmodb = bmodb_arr.tolist() err = err_arr.tolist() print( f" 0 {bmodv[0]:11.3E}{bmodb[0]:11.3E}{err[0]:11.3E}" f"{jrad:5d} {bmodv[1]:11.3E}{bmodb[1]:11.3E}{err[1]:11.3E}" ) print( f" pi {bmodv[2]:11.3E}{bmodb[2]:11.3E}{err[2]:11.3E}" f" {bmodv[3]:11.3E}{bmodb[3]:11.3E}{err[3]:11.3E}" ) # ------------------------------------------------------------------ # Store results on the instance # ------------------------------------------------------------------ self.bmnc_b = bmnc_b self.rmnc_b = rmnc_b self.zmns_b = zmns_b self.numns_b = numns_b self.gmnc_b = gmnc_b if self.asym: self.bmns_b = bmns_b self.rmns_b = rmns_b self.zmnc_b = zmnc_b self.numnc_b = numnc_b self.gmns_b = gmns_b self.Boozer_I = Boozer_I self.Boozer_G = Boozer_G self.s_b = _np.asarray(self.s_in)[self.compute_surfs]
[docs] def run_jax(self, *, jit: bool = True) -> dict: """Run a JAX-native Boozer transform (no Python surface loop). This method returns a mapping compatible with boozmn field names. The mapping includes ``gmnc_b`` and its BOOZ_XFORM-compatible ``gmn_b`` alias for Boozer Jacobian harmonics, plus asymmetric spectra when ``asym`` is true. It is intended for end-to-end JIT/differentiable workflows and does not populate the instance attributes (unlike `run`). """ if self.rmnc is None or self.bmnc is None: raise RuntimeError("VMEC data must be initialised before running the transform") if self.ns_in is None: raise RuntimeError("ns_in must be set; did init_from_vmec run correctly?") # Default surfaces: all half-grid surfaces. if self.compute_surfs is None: compute_surfs = list(range(int(self.ns_in))) else: compute_surfs = list(self.compute_surfs) # Default Boozer resolution: match VMEC angular resolution. if self.mboz is None: if self.mpol is None: raise RuntimeError("mboz is not set and mpol is not available") self.mboz = int(self.mpol) if self.nboz is None: if self.ntor is None: raise RuntimeError("nboz is not set and ntor is not available") self.nboz = int(self.ntor) if self.mnboz is None or self.xm_b is None or self.xn_b is None: self._prepare_mode_lists() constants, grids = prepare_booz_xform_constants( nfp=int(self.nfp), mboz=int(self.mboz), nboz=int(self.nboz), asym=bool(self.asym), xm=self.xm, xn=self.xn, xm_nyq=self.xm_nyq, xn_nyq=self.xn_nyq, ) # Ensure surface dimension is first (ns, mn) rmnc = jnp.asarray(_np.asarray(self.rmnc)).T zmns = jnp.asarray(_np.asarray(self.zmns)).T lmns = jnp.asarray(_np.asarray(self.lmns)).T bmnc = jnp.asarray(_np.asarray(self.bmnc)).T bsubumnc = jnp.asarray(_np.asarray(self.bsubumnc)).T bsubvmnc = jnp.asarray(_np.asarray(self.bsubvmnc)).T iota = jnp.asarray(_np.asarray(self.iota)) rmns = jnp.asarray(_np.asarray(self.rmns)).T if self.asym and self.rmns is not None else None zmnc = jnp.asarray(_np.asarray(self.zmnc)).T if self.asym and self.zmnc is not None else None lmnc = jnp.asarray(_np.asarray(self.lmnc)).T if self.asym and self.lmnc is not None else None bmns = jnp.asarray(_np.asarray(self.bmns)).T if self.asym and self.bmns is not None else None bsubumns = ( jnp.asarray(_np.asarray(self.bsubumns)).T if self.asym and self.bsubumns is not None else None ) bsubvmns = ( jnp.asarray(_np.asarray(self.bsubvmns)).T if self.asym and self.bsubvmns is not None else None ) surface_indices = jnp.asarray(compute_surfs, dtype=jnp.int32) booz_fn = booz_xform_jax_impl if jit: booz_fn = jax.jit(booz_xform_jax_impl, static_argnames=("constants",)) return booz_fn( rmnc=rmnc, zmns=zmns, lmns=lmns, bmnc=bmnc, bsubumnc=bsubumnc, bsubvmnc=bsubvmnc, iota=iota, xm=jnp.asarray(self.xm, dtype=jnp.int32), xn=jnp.asarray(self.xn, dtype=jnp.int32), xm_nyq=jnp.asarray(self.xm_nyq, dtype=jnp.int32), xn_nyq=jnp.asarray(self.xn_nyq, dtype=jnp.int32), constants=constants, grids=grids, rmns=rmns, zmnc=zmnc, lmnc=lmnc, bmns=bmns, bsubumns=bsubumns, bsubvmns=bsubvmns, surface_indices=surface_indices, )
# ------------------------------------------------------------------ # Surface registration (unchanged API) # ------------------------------------------------------------------
[docs] def register_surfaces(self, s: Iterable[int | float] | int | float) -> None: """ Register one or more surfaces on which to compute the transform. This method mirrors the original C++ ``register`` routine. It accepts either integer half-grid indices or floating-point radial coordinate values in normalised toroidal flux space. Parameters ---------- s : int, float, or iterable of these Surfaces to register: * If an integer, it is interpreted as an index on the VMEC half grid (0 ≤ index < ns_in). * If a float, it should lie in [0, 1] and is interpreted as a normalised toroidal flux value. We then choose the nearest index based on ``self.s_in``. Notes ----- * Any new surfaces are **appended** to the existing :attr:`compute_surfs` list (duplicates are removed). * Surfaces outside the valid index range produce a :class:`ValueError`. * The method does not perform the transform; you must call :meth:`run` afterwards. """ # Normalise input to a list if isinstance(s, (int, float)): ss = [s] else: ss = list(s) if self.compute_surfs is None: current = set() else: current = set(self.compute_surfs) for val in ss: if isinstance(val, int): # Integer: treated as direct index idx = val else: # Float: map to nearest index based on s_in sval = float(val) if sval < 0.0 or sval > 1.0: raise ValueError("Normalized toroidal flux values must lie in [0,1]") idx = int(_np.argmin(_np.abs(self.s_in - sval))) # type: ignore[arg-type] if idx < 0 or idx >= int(self.ns_in): raise ValueError( f"Surface index {idx} is outside the range [0, {int(self.ns_in) - 1}]" ) current.add(idx) self.compute_surfs = sorted(current) # Respect the verbose flag: only print when truthy if bool(self.verbose): print(f"[booz_xform_jax] Registered surfaces: {self.compute_surfs}") return None