"""VMEC input routines for the JAX implementation of ``booz_xform``.
This module contains functions for loading data from VMEC output files
and for initialising a :class:`~booz_xform_jax.core.Booz_xform` instance
with that data. The goal is to mimic the behaviour of the
``booz_xform`` C++ code while providing a Pythonic interface.
The functions defined here operate on instances of
:class:`~booz_xform_jax.core.Booz_xform`. They are not intended to be
called standalone; instead, call the corresponding methods on a
Booz_xform object (``init_from_vmec`` and ``read_wout``), which
delegate to these functions.
"""
from __future__ import annotations
import numpy as _np
from typing import Optional
try:
import jax.numpy as jnp
except ImportError as e: # pragma: no cover
raise ImportError(
"The booz_xform_jax package requires JAX. Please install jax and "
"jaxlib before using this module."
) from e
try:
import netCDF4 # type: ignore
except ImportError:
netCDF4 = None # pragma: no cover
try:
from scipy.io import netcdf_file # type: ignore
except ImportError:
netcdf_file = None # pragma: no cover
_RADIUS_DIM_NAMES = {"radius", "ns", "radial"}
_NONNYQ_MODE_DIM_NAMES = {"mn_mode", "mnmax", "mode", "modes"}
_NYQ_MODE_DIM_NAMES = {"mn_mode_nyq", "mnmax_nyq", "mode_nyq", "modes_nyq"}
def _dim_name(dim) -> str:
"""Normalize NetCDF dimension labels for layout inference."""
if isinstance(dim, bytes):
dim = dim.decode()
return str(dim).lower()
def _variable_dimensions(var) -> tuple[str, ...]:
"""Return NetCDF variable dimensions when the reader exposes them."""
return tuple(_dim_name(dim) for dim in getattr(var, "dimensions", ()) or ())
def _layout_from_dimensions(
dimensions: tuple[str, ...],
*,
radius_names: set[str],
mode_names: set[str],
) -> Optional[str]:
"""Infer whether a VMEC coefficient array is radius-mode or mode-radius.
VMEC NetCDF files usually label coefficient dimensions, e.g.
``("radius", "mn_mode")``. Those labels are the only reliable way
to orient square arrays when ``ns == mnmax``.
"""
if len(dimensions) != 2:
return None
first, second = dimensions
first_is_radius = first in radius_names
second_is_radius = second in radius_names
first_is_mode = first in mode_names
second_is_mode = second in mode_names
if first_is_radius and second_is_mode:
return "radius_mode"
if first_is_mode and second_is_radius:
return "mode_radius"
return None
def _set_layout_hint_from_variable(self, var, *, nyquist: bool = False) -> None:
dimensions = _variable_dimensions(var)
layout = _layout_from_dimensions(
dimensions,
radius_names=_RADIUS_DIM_NAMES,
mode_names=_NYQ_MODE_DIM_NAMES if nyquist else _NONNYQ_MODE_DIM_NAMES,
)
if layout is not None:
attr = "_vmec_nyq_layout" if nyquist else "_vmec_nonnyq_layout"
setattr(self, attr, layout)
def _infer_layout_from_shape(
arr: _np.ndarray,
*,
ns_full: int,
mode_count: Optional[int],
name: str,
layout_hint: Optional[str] = None,
) -> tuple[int, str]:
"""Infer coefficient layout, using explicit metadata before shape.
Shape-only inference is ambiguous when ``ns_full == mode_count``.
In that case callers that loaded a NetCDF file should pass a
``layout_hint`` derived from the variable dimension names.
"""
if arr.ndim != 2:
raise ValueError(f"{name} must be 2D, got shape {arr.shape}")
if layout_hint == "radius_mode":
if arr.shape[0] != ns_full:
raise ValueError(
f"{name} has layout hint radius_mode but shape {arr.shape}; "
f"expected first dimension ns={ns_full}"
)
return arr.shape[1], layout_hint
if layout_hint == "mode_radius":
if arr.shape[1] != ns_full:
raise ValueError(
f"{name} has layout hint mode_radius but shape {arr.shape}; "
f"expected second dimension ns={ns_full}"
)
return arr.shape[0], layout_hint
if arr.shape[0] == ns_full and arr.shape[1] != ns_full:
return arr.shape[1], "radius_mode"
if arr.shape[1] == ns_full and arr.shape[0] != ns_full:
return arr.shape[0], "mode_radius"
if (
mode_count is not None
and mode_count > 0
and arr.shape == (ns_full, mode_count)
and arr.shape != (mode_count, ns_full)
):
return mode_count, "radius_mode"
if (
mode_count is not None
and mode_count > 0
and arr.shape == (mode_count, ns_full)
and arr.shape != (ns_full, mode_count)
):
return mode_count, "mode_radius"
raise ValueError(
f"{name} has ambiguous or unexpected shape {arr.shape}; "
f"one dimension must equal ns={ns_full}. If ns equals the number "
"of modes, use read_wout() so NetCDF dimension names can disambiguate "
"the array orientation."
)
[docs]
def init_from_vmec(self, *args, s_in: Optional[_np.ndarray] = None) -> None:
"""Initialise a :class:`~booz_xform_jax.core.Booz_xform` instance with VMEC data.
This function accepts two calling conventions for compatibility
with the original C++ ``init_from_vmec`` function:
1. ``init_from_vmec(ns, iotas, rmnc, rmns, zmnc, zmns, lmnc, lmns,
bmnc, bmns, bsubumnc, bsubumns, bsubvmnc, bsubvmns[, phip, chi, pres, phi])``
where ``ns`` is an integer giving the number of radial
surfaces on the full VMEC grid (including the axis) and the
remaining arguments are 1D or 2D arrays with shapes matching
the VMEC documentation. The zeroth radial entry (the axis) is
discarded. Optional arrays ``phip``, ``chi``, ``pres`` and ``phi``
may appear at the end of the argument list; if present, all
four must be provided.
2. ``init_from_vmec(iotas, rmnc, rmns, zmnc, zmns, lmnc, lmns,
bmnc, bmns, bsubumnc, bsubumns, bsubvmnc, bsubvmns[, phip, chi, pres, phi])``
omitting the ``ns`` argument. In this case the length of the
``iotas`` array determines the number of surfaces on the full
grid. Optional flux arrays may appear at the end as in the
first calling convention.
Parameters
----------
self : Booz_xform
The instance to initialise.
*args : sequence
Positional arguments following one of the two calling
conventions described above.
s_in : ndarray, optional
Optional array of length ``ns`` giving the values of
normalized toroidal flux on the full radial grid. If
provided, the first entry should correspond to the axis and
will be discarded. When not provided, a uniform grid
between 0 and 1 is used.
Notes
-----
The zeroth radial entry (corresponding to the magnetic axis) is
ignored in all arrays. The remaining ``ns - 1`` surfaces
constitute the half‑grid on which the Boozer transform is
performed. All input arrays are defensively copied and
converted to ``jax.numpy.DeviceArray`` objects of dtype
``float64`` for computation; however, the radial coordinate array
``s_in`` is stored as a NumPy array so that it can be indexed
using Python lists in :meth:`~booz_xform_jax.core.Booz_xform.run` and
:func:`~booz_xform_jax.io_utils.read_boozmn`.
"""
if len(args) == 0:
raise TypeError("init_from_vmec requires at least one positional argument")
# Determine if first argument is ns (an integer) or iotas (array)
first = args[0]
if isinstance(first, (int, _np.integer)):
# Signature with ns provided
if len(args) < 2:
raise TypeError("init_from_vmec(ns, ...) missing iotas array")
ns_full = int(first)
iotas = _np.asarray(args[1])
arrays = args[2:]
else:
# Signature without ns; infer ns from iotas
iotas = _np.asarray(first)
ns_full = iotas.shape[0]
arrays = args[1:]
# Validate iotas
if iotas.ndim != 1 or iotas.shape[0] != ns_full:
raise ValueError("iotas must be a 1D array of length ns")
if ns_full < 2:
raise ValueError("ns must be at least 2 surfaces (including the axis)")
# Validate number of arrays: expect either 12 (no flux) or 16 (with flux)
if len(arrays) not in (12, 16):
expected = 14 if isinstance(first, (int, _np.integer)) else 13
expected_with_flux = expected + 4
raise TypeError(
f"init_from_vmec expects {expected} or {expected_with_flux} positional arguments, "
f"but {len(args)} were given"
)
# Determine number of half-grid surfaces and drop the axis
ns_in = ns_full - 1
self.ns_in = ns_in
# Store iota values as JAX array (drop axis)
self.iota = jnp.asarray(iotas[1:], dtype=jnp.float64)
# Unpack mandatory arrays: order rmnc, rmns, zmnc, zmns, lmnc, lmns,
# bmnc, bmns, bsubumnc, bsubumns, bsubvmnc, bsubvmns.
(
rmnc0,
rmns0,
zmnc0,
zmns0,
lmnc0,
lmns0,
bmnc0,
bmns0,
bsubumnc0,
bsubumns0,
bsubvmnc0,
bsubvmns0,
) = arrays[:12]
rmnc0 = _np.asarray(rmnc0)
rmns0 = _np.asarray(rmns0)
zmnc0 = _np.asarray(zmnc0)
zmns0 = _np.asarray(zmns0)
lmnc0 = _np.asarray(lmnc0)
lmns0 = _np.asarray(lmns0)
bmnc0 = _np.asarray(bmnc0)
bmns0 = _np.asarray(bmns0)
bsubumnc0 = _np.asarray(bsubumnc0)
bsubumns0 = _np.asarray(bsubumns0)
bsubvmnc0 = _np.asarray(bsubvmnc0)
bsubvmns0 = _np.asarray(bsubvmns0)
# Determine mnmax from rmnc0, allowing both (ns_full, mnmax)
# and (mnmax, ns_full) layouts. read_wout() stores a layout hint
# from NetCDF dimension names so square arrays are not ambiguous.
nonnyq_layout = getattr(self, "_vmec_nonnyq_layout", None)
mode_count_hint = int(getattr(self, "mnmax", 0) or 0)
if mode_count_hint <= 0 and getattr(self, "xm", None) is not None:
mode_count_hint = int(_np.asarray(self.xm).shape[0])
mnmax, nonnyq_layout = _infer_layout_from_shape(
rmnc0,
ns_full=ns_full,
mode_count=mode_count_hint if mode_count_hint > 0 else None,
name="rmnc0",
layout_hint=nonnyq_layout,
)
self.mnmax = mnmax
mnmax = self.mnmax
asym = self.asym
xm = _np.asarray(self.xm, dtype=int) # needs self.xm set from read_wout
# ------------------------------------------------------------------
# Canonicalize full-grid arrays to shape (ns_full, mnmax)
# SIMSOPT typically gives (mnmax, ns_full); some readers use (ns_full, mnmax).
# We unify to (ns_full, mnmax) for the interpolation logic below.
# ------------------------------------------------------------------
if nonnyq_layout == "radius_mode":
if rmnc0.shape != (ns_full, mnmax):
raise ValueError(
f"rmnc0 has unexpected shape {rmnc0.shape}; "
f"expected (ns={ns_full}, mn={mnmax})"
)
rmnc_full = rmnc0
zmns_full = zmns0
rmns_full = rmns0
zmnc_full = zmnc0
elif nonnyq_layout == "mode_radius":
if rmnc0.shape != (mnmax, ns_full):
raise ValueError(
f"rmnc0 has unexpected shape {rmnc0.shape}; "
f"expected (mn={mnmax}, ns={ns_full})"
)
rmnc_full = rmnc0.T
zmns_full = zmns0.T
rmns_full = rmns0.T
zmnc_full = zmnc0.T
else:
raise ValueError(
f"rmnc0 has unexpected shape {rmnc0.shape}; "
f"expected (ns={ns_full}, mn={mnmax}) or (mn={mnmax}, ns={ns_full})"
)
# --- Build full and half s grids like C++ ---
if s_in is not None:
# Treat s_in as the full-grid toroidal-flux coordinate, like VMEC
s_full = _np.asarray(s_in, dtype=float)
if s_full.shape[0] != ns_full:
raise ValueError("s_in must have length ns (full grid including axis)")
else:
hs = 1.0 / (ns_full - 1.0)
s_full = hs * _np.arange(ns_full)
sqrt_s_full = _np.sqrt(s_full)
sqrt_s_full[0] = 1.0 # avoid div-by-zero; rmnc(s=0)=0 for m>1 anyway
# Half grid: midpoints between full-grid points
s_half = 0.5 * (s_full[:-1] + s_full[1:])
sqrt_s_half = _np.sqrt(s_half)
# Store half-grid s_in (this is what C++ uses internally)
self.s_in = s_half.astype(float)
self.ns_in = ns_in
# --- Radial interpolation for R and Z on half grid: rmnc, zmns (+ asym parts) ---
# rmnc0, zmns0, rmns0, zmnc0 currently have shape (ns_full, mnmax)
# We build rmnc_half and zmns_half by interpolating between full-grid points, as in the C++ code.
rmnc_half = _np.empty((mnmax, ns_in), dtype=float)
zmns_half = _np.empty((mnmax, ns_in), dtype=float)
# For lambda harmonics (lmns and lmnc), the VMEC output already stores values on the
# half grid, so we do NOT perform radial interpolation. Instead, we drop the axis
# entry and reshape to (mnmax, ns_in), mimicking the original C++ code.
lmns_half = _np.empty((mnmax, ns_in), dtype=float)
if asym:
rmns_half = _np.empty((mnmax, ns_in), dtype=float)
zmnc_half = _np.empty((mnmax, ns_in), dtype=float)
lmnc_half = _np.empty((mnmax, ns_in), dtype=float)
else:
rmns_half = zmnc_half = lmnc_half = None
def copy_half_mesh(arr: _np.ndarray, name: str) -> _np.ndarray:
"""Convert VMEC half-mesh array to (mnmax, ns_in) by dropping the axis.
We support both (ns_full, mnmax) and (mnmax, ns_full) layouts,
mirroring the C++ Booz_xform::init_from_vmec behavior where
lmns0 is (mnmax, ns) and we keep columns j=1..ns-1.
"""
arr = _np.asarray(arr)
if arr.ndim != 2:
raise ValueError(f"{name} must be 2D, got shape {arr.shape}")
# Case 1: (radius, mode) = (ns_full, mnmax)
if nonnyq_layout == "radius_mode" and arr.shape == (ns_full, mnmax):
# Drop axis row, then transpose → (ns_in, mnmax) → (mnmax, ns_in)
return arr[1:, :].T
# Case 2: (mode, radius) = (mnmax, ns_full)
if nonnyq_layout == "mode_radius" and arr.shape == (mnmax, ns_full):
# Drop axis column → (mnmax, ns_in)
return arr[:, 1:]
raise ValueError(
f"{name} has unexpected shape {arr.shape}; "
f"expected (ns={ns_full}, mn={mnmax}) or (mn={mnmax}, ns={ns_full})"
)
# Drop axis and reshape lmns (half mesh)
lmns_half[:, :] = copy_half_mesh(lmns0, "lmns0")
# Asymmetric lambda (if present) is also on the half mesh.
if asym and lmnc0 is not None and lmnc0.size > 0:
lmnc_half[:, :] = copy_half_mesh(lmnc0, "lmnc0")
# Interpolate RMNC and ZMNS from full grid (ns_full) to half grid (ns_in).
# For even m: average adjacent full‑grid points
# For odd m: interpolate f/√s on the full grid and multiply by √s on the half grid.
# -------- fully vectorised interpolation over m --------
even_mask = (xm % 2 == 0)
odd_mask = ~even_mask
even_idx = _np.nonzero(even_mask)[0]
odd_idx = _np.nonzero(odd_mask)[0]
# Even m: simple average of adjacent full-grid points\
if even_idx.size > 0:
rmnc_half[even_idx, :] = 0.5 * (
rmnc_full[:-1, even_idx] + rmnc_full[1:, even_idx]
).T
zmns_half[even_idx, :] = 0.5 * (
zmns_full[:-1, even_idx] + zmns_full[1:, even_idx]
).T
if asym:
rmns_half[even_idx, :] = 0.5 * (
rmns_full[:-1, even_idx] + rmns_full[1:, even_idx]
).T
zmnc_half[even_idx, :] = 0.5 * (
zmnc_full[:-1, even_idx] + zmnc_full[1:, even_idx]
).T
# Odd m: interpolate f/√s on the full grid and multiply by √s on the half grid.
if odd_idx.size > 0:
# shapes: (ns_in, n_odd)
rmnc_odd = 0.5 * (
(rmnc_full[:-1, odd_idx] / sqrt_s_full[:-1, None]) +
(rmnc_full[1:, odd_idx] / sqrt_s_full[1:, None])
) * sqrt_s_half[:, None]
zmns_odd = 0.5 * (
(zmns_full[:-1, odd_idx] / sqrt_s_full[:-1, None]) +
(zmns_full[1:, odd_idx] / sqrt_s_full[1:, None])
) * sqrt_s_half[:, None]
rmnc_half[odd_idx, :] = rmnc_odd.T
zmns_half[odd_idx, :] = zmns_odd.T
if asym:
rmns_odd = 0.5 * (
(rmns_full[:-1, odd_idx] / sqrt_s_full[:-1, None]) +
(rmns_full[1:, odd_idx] / sqrt_s_full[1:, None])
) * sqrt_s_half[:, None]
zmnc_odd = 0.5 * (
(zmnc_full[:-1, odd_idx] / sqrt_s_full[:-1, None]) +
(zmnc_full[1:, odd_idx] / sqrt_s_full[1:, None])
) * sqrt_s_half[:, None]
rmns_half[odd_idx, :] = rmns_odd.T
zmnc_half[odd_idx, :] = zmnc_odd.T
# m = 1 special axis extrapolation (for all mn with m==1)
axis_idx = _np.nonzero(xm == 1)[0]
if axis_idx.size > 0:
rmnc_axis = (
1.5 * rmnc_full[1, axis_idx] / sqrt_s_full[1]
- 0.5 * rmnc_full[2, axis_idx] / sqrt_s_full[2]
) * sqrt_s_half[0]
zmns_axis = (
1.5 * zmns_full[1, axis_idx] / sqrt_s_full[1]
- 0.5 * zmns_full[2, axis_idx] / sqrt_s_full[2]
) * sqrt_s_half[0]
rmnc_half[axis_idx, 0] = rmnc_axis
zmns_half[axis_idx, 0] = zmns_axis
if asym:
rmns_axis = (
1.5 * rmns_full[1, axis_idx] / sqrt_s_full[1]
- 0.5 * rmns_full[2, axis_idx] / sqrt_s_full[2]
) * sqrt_s_half[0]
zmnc_axis = (
1.5 * zmnc_full[1, axis_idx] / sqrt_s_full[1]
- 0.5 * zmnc_full[2, axis_idx] / sqrt_s_full[2]
) * sqrt_s_half[0]
rmns_half[axis_idx, 0] = rmns_axis
zmnc_half[axis_idx, 0] = zmnc_axis
# Now store these in the same orientation as the C++ internal rmnc(jmn, js)
# i.e. (mnmax, ns_in) as JAX arrays:
self.rmnc = jnp.asarray(rmnc_half, dtype=jnp.float64)
self.zmns = jnp.asarray(zmns_half, dtype=jnp.float64)
self.lmns = jnp.asarray(lmns_half, dtype=jnp.float64)
if asym:
self.rmns = jnp.asarray(rmns_half, dtype=jnp.float64)
self.zmnc = jnp.asarray(zmnc_half, dtype=jnp.float64)
self.lmnc = jnp.asarray(lmnc_half, dtype=jnp.float64)
else:
self.rmns = None
self.zmnc = None
self.lmnc = None
# ------------------------------------------------------------------
# Nyquist arrays: canonicalize to (ns_full, mnmax_nyq)
# SIMSOPT/C++ style is typically (mnmax_nyq, ns_full).
# ------------------------------------------------------------------
nyq_layout = getattr(self, "_vmec_nyq_layout", None)
nyq_mode_count_hint = int(getattr(self, "mnmax_nyq", 0) or 0)
if nyq_mode_count_hint <= 0 and getattr(self, "xm_nyq", None) is not None:
nyq_mode_count_hint = int(_np.asarray(self.xm_nyq).shape[0])
mnmax_nyq, nyq_layout = _infer_layout_from_shape(
bmnc0,
ns_full=ns_full,
mode_count=nyq_mode_count_hint if nyq_mode_count_hint > 0 else None,
name="bmnc0",
layout_hint=nyq_layout,
)
if nyq_layout == "radius_mode":
# (ns_full, mnmax_nyq)
bmnc_full = bmnc0
bsubumnc_full = bsubumnc0
bsubvmnc_full = bsubvmnc0
bmns_full = bmns0
bsubumns_full = bsubumns0
bsubvmns_full = bsubvmns0
elif nyq_layout == "mode_radius":
# (mnmax_nyq, ns_full) → transpose
bmnc_full = bmnc0.T
bsubumnc_full = bsubumnc0.T
bsubvmnc_full = bsubvmnc0.T
bmns_full = bmns0.T
bsubumns_full = bsubumns0.T
bsubvmns_full = bsubvmns0.T
else:
raise ValueError(
f"bmnc0 has unexpected shape {bmnc0.shape}; "
f"one dimension must equal ns={ns_full}"
)
self.mnmax_nyq = mnmax_nyq
def strip_axis_nyq(arr_full: _np.ndarray, name: str) -> jnp.ndarray:
if arr_full.ndim != 2 or arr_full.shape[0] != ns_full:
raise ValueError(
f"strip_axis_nyq: {name} expected shape (ns={ns_full}, *), got {arr_full.shape}"
)
# arr_full: (ns_full, mnmax_nyq) → drop s=0 row → (ns_in, mnmax_nyq)
# → transpose to (mnmax_nyq, ns_in).
return jnp.asarray(arr_full[1:, :].T, dtype=jnp.float64)
self.bmnc = strip_axis_nyq(bmnc_full, "bmnc0")
self.bsubumnc = strip_axis_nyq(bsubumnc_full, "bsubumnc0")
self.bsubvmnc = strip_axis_nyq(bsubvmnc_full, "bsubvmnc0")
if self.asym:
self.bmns = strip_axis_nyq(bmns_full, "bmns0")
self.bsubumns = strip_axis_nyq(bsubumns_full, "bsubumns0")
self.bsubvmns = strip_axis_nyq(bsubvmns_full, "bsubvmns0")
else:
self.bmns = None
self.bsubumns = None
self.bsubvmns = None
# Store Boozer I and G profiles for all half-grid surfaces (as numpy).
# With our layout (mnmax_nyq, ns_in), the (m=0,n=0) mode is row 0.
self.Boozer_I_all = _np.asarray(self.bsubumnc[0, :])
self.Boozer_G_all = _np.asarray(self.bsubvmnc[0, :])
# Check for flux arrays
if len(arrays) == 16:
phip0, chi0, pres0, phi0 = ( _np.asarray(a) for a in arrays[12:16] )
ns_full = phip0.shape[0]
ns_in = ns_full - 1
two_pi = 2.0 * _np.pi
# Match C++ sizes and scaling
phip = _np.empty(ns_in + 1, dtype=float)
chi = _np.empty(ns_in + 1, dtype=float)
pres = _np.empty(ns_in + 1, dtype=float)
phi = _np.empty(ns_in + 1, dtype=float)
for j in range(ns_in + 1):
phip[j] = -phip0[j] / two_pi
chi[j] = chi0[j]
pres[j] = pres0[j]
phi[j] = phi0[j]
# Store flux profiles on the half grid (drop the axis), to be consistent
# with iota, rmnc, etc., which all live on ns_in surfaces.
self.phip = jnp.asarray(phip[1:], dtype=jnp.float64) # length ns_in
self.chi = jnp.asarray(chi[1:], dtype=jnp.float64)
self.pres = jnp.asarray(pres[1:], dtype=jnp.float64)
self.phi = jnp.asarray(phi[1:], dtype=jnp.float64)
# Toroidal flux: keep the full-grid last value (outer surface)
self.toroidal_flux = float(phi[ns_full - 1])
else:
self.phip = self.chi = self.pres = self.phi = None
self.toroidal_flux = 0.0
# Set default compute_surfs if not already set
if self.compute_surfs is None:
self.compute_surfs = list(range(ns_in))
else:
# Validate existing indices
cs = list(self.compute_surfs)
for idx in cs:
if idx < 0 or idx >= ns_in:
raise ValueError(
f"compute_surfs has an entry {idx} outside the range [0, {ns_in - 1}]"
)
self.compute_surfs = cs
return None
[docs]
def read_wout(self, filename: str, flux: bool = False) -> None:
"""Read a VMEC ``wout`` file and populate the internal arrays.
This routine loads the equilibrium data from a VMEC NetCDF file (the
file whose name begins with ``wout_``). The Fourier mode definitions,
the non‑Nyquist and Nyquist Fourier coefficients, and optional
flux profiles are read. Once the data are assembled, this
function calls :func:`init_from_vmec` on the instance to prepare
the arrays for the Boozer transformation.
Parameters
----------
self : Booz_xform
The instance to populate.
filename : str
Path to a VMEC wout NetCDF file.
flux : bool, optional
If ``True``, the flux profile arrays ``phipf`` (or ``phips``),
``chi``, ``pres`` and ``phi`` are read and passed to
:func:`init_from_vmec`. When ``False``, these arrays are
ignored.
"""
# Open file via netCDF4 or SciPy
if netCDF4 is not None:
ds = netCDF4.Dataset(filename, 'r') # type: ignore
use_scipy = False
elif netcdf_file is not None:
ds = netcdf_file(filename, 'r', mmap=False) # type: ignore
use_scipy = True
else:
raise RuntimeError("No NetCDF reader available. Install netCDF4 or SciPy.")
if self.verbose > 0:
print(f"[booz_xform_jax] Reading wout file: {filename}")
print(f"[booz_xform_jax] Using NetCDF reader: {'netCDF4' if not use_scipy else 'scipy.io.netcdf'}")
# Read symmetry flag
# In netCDF4 dimensions are names with double underscores
lasym_name = 'lasym__logical__'
if lasym_name in ds.variables:
lasym = bool(ds.variables[lasym_name][...].item())
else:
# Fallback for SciPy file
lasym = bool(getattr(ds, lasym_name, False))
self.asym = lasym
# Read field periodicity
self.nfp = int(ds.variables['nfp'][...].item())
# Non‑Nyquist dimension sizes
self.mpol = int(ds.variables['mpol'][...].item())
self.ntor = int(ds.variables['ntor'][...].item())
self.mnmax = int(ds.variables['mnmax'][...].item())
self.mnmax_nyq = int(ds.variables['mnmax_nyq'][...].item())
# Read mode number arrays
self.xm = _np.asarray(ds.variables['xm'][:], dtype=int)
self.xn = _np.asarray(ds.variables['xn'][:], dtype=int)
self.xm_nyq = _np.asarray(ds.variables['xm_nyq'][:], dtype=int)
self.xn_nyq = _np.asarray(ds.variables['xn_nyq'][:], dtype=int)
self.mpol_nyq = int(self.xm_nyq[-1])
self.ntor_nyq = int(self.xn_nyq[-1] // self.nfp)
self.ns_vmec = int(ds.variables['ns'][...].item())
if self.verbose > 0:
print(f"[booz_xform_jax] mpol={self.mpol}, ntor={self.ntor}, mnmax={self.mnmax}")
print(f"[booz_xform_jax] mpol_nyq={self.mpol_nyq}, ntor_nyq={self.ntor_nyq}, mnmax_nyq={self.mnmax_nyq}")
# Preserve NetCDF dimension metadata for init_from_vmec(). This is
# required when ns == mnmax because shape-only orientation inference
# cannot distinguish (radius, mode) from (mode, radius).
_set_layout_hint_from_variable(self, ds.variables["rmnc"], nyquist=False)
_set_layout_hint_from_variable(self, ds.variables["bmnc"], nyquist=True)
# Read non-Nyquist Fourier coefficients (shape (mnmax, ns))
rmnc0 = _np.asarray(ds.variables['rmnc'][:])
rmns0 = _np.asarray(ds.variables['rmns'][:]) if self.asym else _np.zeros_like(rmnc0)
zmnc0 = _np.asarray(ds.variables['zmnc'][:]) if self.asym else _np.zeros_like(rmnc0)
zmns0 = _np.asarray(ds.variables['zmns'][:])
lmnc0 = _np.asarray(ds.variables['lmnc'][:]) if self.asym else _np.zeros_like(rmnc0)
lmns0 = _np.asarray(ds.variables['lmns'][:])
# Read Nyquist Fourier coefficients (shape (mnmax_nyq, ns))
bmnc0 = _np.asarray(ds.variables['bmnc'][:])
bmns0 = _np.asarray(ds.variables['bmns'][:]) if self.asym else _np.zeros_like(bmnc0)
bsubumnc0 = _np.asarray(ds.variables['bsubumnc'][:])
bsubumns0 = _np.asarray(ds.variables['bsubumns'][:]) if self.asym else _np.zeros_like(bmnc0)
bsubvmnc0 = _np.asarray(ds.variables['bsubvmnc'][:])
bsubvmns0 = _np.asarray(ds.variables['bsubvmns'][:]) if self.asym else _np.zeros_like(bmnc0)
# Determine number of radial surfaces
ns = rmnc0.shape[0]
# Initialize variables for flux profiles
phip0 = chi0 = pres0 = phi0 = None
if flux:
# VMEC stores phipf as derivative of toroidal flux, sometimes named phipf or phips
if 'phipf' in ds.variables:
phip0 = _np.asarray(ds.variables['phipf'][:])
elif 'phips' in ds.variables:
phip0 = _np.asarray(ds.variables['phips'][:])
if 'chi' in ds.variables:
chi0 = _np.asarray(ds.variables['chi'][:])
if 'pres' in ds.variables:
pres0 = _np.asarray(ds.variables['pres'][:])
if 'phi' in ds.variables:
phi0 = _np.asarray(ds.variables['phi'][:])
# Extract iotas before closing the dataset; shape (ns,)
iotas = _np.asarray(ds.variables['iotas'][:])
# Record additional scalar quantities before closing. The aspect
# ratio and toroidal flux are stored on the full grid. We copy
# them now so that we can safely close the dataset.
try:
aspect0 = float(ds.variables['aspect'][...].item())
except Exception:
aspect0 = 0.0
# The toroidal flux is stored in phi0; if phi0 was read above we
# will extract the last value later; otherwise it remains zero.
# Extract iotas before closing the dataset; shape (ns,)
iotas = _np.asarray(ds.variables['iotas'][:])
# Close dataset
if use_scipy:
ds.close()
else:
ds.close()
# Build argument list for init_from_vmec
args = [ns, iotas]
# Non-Nyquist arrays
args.extend([
rmnc0,
rmns0,
zmnc0,
zmns0,
lmnc0,
lmns0,
bmnc0,
bmns0,
bsubumnc0,
bsubumns0,
bsubvmnc0,
bsubvmns0,
])
# Append flux arrays if requested and available
if flux and phip0 is not None and chi0 is not None and pres0 is not None and phi0 is not None:
args.extend([phip0, chi0, pres0, phi0])
# Call init_from_vmec on self
init_from_vmec(self, *args)
# Set aspect ratio from stored value
self.aspect = aspect0
return None
[docs]
def read_wout_data(self, wout, flux: bool = False) -> None:
"""Populate the instance from a VMEC wout-like object (e.g. vmec_jax.WoutData).
This mirrors :func:`read_wout` but accepts an in-memory object with the
standard VMEC fields as attributes instead of reading a NetCDF file.
Parameters
----------
self : Booz_xform
The instance to populate.
wout : object
A wout-like object with VMEC attributes (e.g. ``vmec_jax.WoutData``).
flux : bool, optional
If ``True``, attempt to populate flux profile arrays. When the
required fields are unavailable, the profiles are silently skipped
to match ``read_wout`` behavior.
"""
if self.verbose > 0:
print("[booz_xform_jax] Reading wout data from object")
lasym = bool(getattr(wout, "lasym", getattr(wout, "asym", False)))
self.asym = lasym
self.nfp = int(getattr(wout, "nfp"))
self.mpol = int(getattr(wout, "mpol"))
self.ntor = int(getattr(wout, "ntor"))
self.xm = _np.asarray(getattr(wout, "xm"), dtype=int)
self.xn = _np.asarray(getattr(wout, "xn"), dtype=int)
self.xm_nyq = _np.asarray(getattr(wout, "xm_nyq"), dtype=int)
self.xn_nyq = _np.asarray(getattr(wout, "xn_nyq"), dtype=int)
self.mnmax = int(self.xm.shape[0])
self.mnmax_nyq = int(self.xm_nyq.shape[0])
self.mpol_nyq = int(self.xm_nyq[-1]) if self.xm_nyq.size else 0
self.ntor_nyq = int(self.xn_nyq[-1] // self.nfp) if self.xn_nyq.size else 0
self.ns_vmec = int(getattr(wout, "ns", getattr(wout, "ns_vmec", 0)))
if self.verbose > 0:
print(f"[booz_xform_jax] mpol={self.mpol}, ntor={self.ntor}, mnmax={self.mnmax}")
print(f"[booz_xform_jax] mpol_nyq={self.mpol_nyq}, ntor_nyq={self.ntor_nyq}, mnmax_nyq={self.mnmax_nyq}")
rmnc0 = _np.asarray(getattr(wout, "rmnc"))
rmns0 = _np.asarray(getattr(wout, "rmns")) if self.asym else _np.zeros_like(rmnc0)
zmnc0 = _np.asarray(getattr(wout, "zmnc")) if self.asym else _np.zeros_like(rmnc0)
zmns0 = _np.asarray(getattr(wout, "zmns"))
lmnc0 = _np.asarray(getattr(wout, "lmnc")) if self.asym else _np.zeros_like(rmnc0)
lmns0 = _np.asarray(getattr(wout, "lmns"))
bmnc0 = _np.asarray(getattr(wout, "bmnc"))
bmns0 = _np.asarray(getattr(wout, "bmns")) if self.asym else _np.zeros_like(bmnc0)
bsubumnc0 = _np.asarray(getattr(wout, "bsubumnc"))
bsubumns0 = _np.asarray(getattr(wout, "bsubumns")) if self.asym else _np.zeros_like(bmnc0)
bsubvmnc0 = _np.asarray(getattr(wout, "bsubvmnc"))
bsubvmns0 = _np.asarray(getattr(wout, "bsubvmns")) if self.asym else _np.zeros_like(bmnc0)
ns = rmnc0.shape[0]
iotas = _np.asarray(getattr(wout, "iotas"))
aspect0 = float(getattr(wout, "aspect", 0.0))
def integrate_uniform(values: _np.ndarray) -> _np.ndarray:
values = _np.asarray(values, dtype=float)
if values.ndim != 1:
raise ValueError("flux arrays must be 1D")
if values.size < 2:
return values.copy()
ds = 1.0 / (values.size - 1.0)
out = _np.zeros_like(values, dtype=float)
out[1:] = _np.cumsum(0.5 * ds * (values[:-1] + values[1:]))
return out
phip0 = chi0 = pres0 = phi0 = None
if flux:
if hasattr(wout, "phipf"):
phip0 = _np.asarray(getattr(wout, "phipf"))
elif hasattr(wout, "phips"):
phip0 = _np.asarray(getattr(wout, "phips"))
if hasattr(wout, "chi"):
chi0 = _np.asarray(getattr(wout, "chi"))
elif hasattr(wout, "chipf"):
chi0 = integrate_uniform(_np.asarray(getattr(wout, "chipf")))
if hasattr(wout, "presf"):
pres0 = _np.asarray(getattr(wout, "presf"))
elif hasattr(wout, "pres"):
pres0 = _np.asarray(getattr(wout, "pres"))
if hasattr(wout, "phi"):
phi0 = _np.asarray(getattr(wout, "phi"))
elif phip0 is not None:
phi0 = integrate_uniform(phip0)
args = [ns, iotas]
args.extend([
rmnc0,
rmns0,
zmnc0,
zmns0,
lmnc0,
lmns0,
bmnc0,
bmns0,
bsubumnc0,
bsubumns0,
bsubvmnc0,
bsubvmns0,
])
if flux and phip0 is not None and chi0 is not None and pres0 is not None and phi0 is not None:
args.extend([phip0, chi0, pres0, phi0])
init_from_vmec(self, *args)
self.aspect = aspect0
return None