# Source code for booz_xform_jax.io_utils

"""Input/output utilities for the JAX implementation of ``booz_xform``.

This module provides functions to read from and write to the NetCDF
formats used by the original ``booz_xform`` code.  In particular it
handles the ``boozmn`` file format, which stores Boozer‐coordinate
Fourier spectra and associated radial profiles.  These functions
operate on :class:`~booz_xform_jax.core.Booz_xform` instances and are
used by the methods :meth:`booz_xform_jax.core.Booz_xform.write_boozmn`
and :meth:`booz_xform_jax.core.Booz_xform.read_boozmn`.
"""

from __future__ import annotations

import numpy as _np
from typing import Optional

try:
    import jax.numpy as jnp
except ImportError as e:  # pragma: no cover
    raise ImportError(
        "The booz_xform_jax package requires JAX. Please install jax and "
        "jaxlib before using this module."
    ) from e

def write_boozmn(self, filename: str) -> None:
    """Write the computed Boozer Fourier spectra to a ``boozmn`` NetCDF file.

    The ``boozmn`` format is used by the original ``booz_xform`` package to
    store Boozer coordinates and related quantities.  This function writes
    the mode definitions, radial profiles and spectral coefficients held on
    the instance.  ``netCDF4`` is used when it can be imported and can open
    the output file; otherwise the file is written with
    ``scipy.io.netcdf_file`` (NetCDF3).

    With both backends the scalar quantities (``ns_b``, ``nfp_b``,
    ``mboz_b``, ``nboz_b``, ``mnboz_b``, ``aspect_b``, ``toroidal_flux_b``
    and ``lasym__logical__``) are stored as zero-dimensional *variables*,
    which is where :func:`read_boozmn` looks for them.

    Parameters
    ----------
    self : Booz_xform
        The instance containing the results of the Boozer transform.
    filename : str
        Path of the output NetCDF file.

    Raises
    ------
    RuntimeError
        If ``self.bmnc_b`` is ``None``, i.e. ``run()`` has not been called.
    """
    # Ensure run() has been called
    if self.bmnc_b is None:
        raise RuntimeError("run() must be called before write_boozmn()")

    if self.verbose > 0:
        print(f"[booz_xform_jax] Writing boozmn file: {filename}")

    # Sizes of the NetCDF dimensions
    ns_in_plus_1 = int(self.ns_in) + 1
    mnboz = int(self.mnboz)
    ns_b = int(self.ns_b)

    # jlist is compute_surfs shifted by 2 (1-based full-grid indices)
    jlist = _np.array([idx + 2 for idx in self.compute_surfs], dtype='i4')

    def full_grid_profile(values):
        # Radial profiles are stored on the 'radius' dimension with a zero
        # prepended as the first entry.
        profile = _np.zeros(ns_in_plus_1)
        profile[1:] = _np.asarray(values)
        return profile

    # Mandatory profiles first, then the optional ones that are present.
    radial_profiles = {
        'iota_b': full_grid_profile(self.iota),
        'buco_b': full_grid_profile(self.Boozer_I_all),
        'bvco_b': full_grid_profile(self.Boozer_G_all),
    }
    optional_profiles = (
        ('phip_b', self.phip),
        ('chi_b', self.chi),
        ('pres_b', self.pres),
        ('phi_b', self.phi),
    )
    for name, values in optional_profiles:
        if values is not None:
            radial_profiles[name] = full_grid_profile(values)

    # Spectral arrays are transposed to the on-disk (pack_rad, mn_mode)
    # layout.  pmns_b / pmnc_b are written as minus numns_b / numnc_b.
    numns_b = _np.asarray(self.numns_b).T
    spectra = {
        'bmnc_b': _np.asarray(self.bmnc_b).T,
        'rmnc_b': _np.asarray(self.rmnc_b).T,
        'zmns_b': _np.asarray(self.zmns_b).T,
        'numns_b': numns_b,
        'pmns_b': -numns_b,
        'gmn_b': _np.asarray(self.gmnc_b).T,
    }
    if self.asym:
        numnc_b = _np.asarray(self.numnc_b).T
        spectra['bmns_b'] = _np.asarray(self.bmns_b).T
        spectra['rmns_b'] = _np.asarray(self.rmns_b).T
        spectra['zmnc_b'] = _np.asarray(self.zmnc_b).T
        spectra['numnc_b'] = numnc_b
        spectra['pmnc_b'] = -numnc_b
        spectra['gmns_b'] = _np.asarray(self.gmns_b).T

    # Attempt to use netCDF4 for writing; fall back to SciPy if necessary.
    # The broad except is a deliberate best-effort: any failure to import
    # netCDF4 or to open the file with it triggers the SciPy fallback.
    try:
        import netCDF4 as nc  # type: ignore
        ds = nc.Dataset(filename, 'w')
        using_netcdf4 = True
    except Exception:
        using_netcdf4 = False
    if not using_netcdf4:
        # SciPy netcdf_file writes NetCDF3 format
        from scipy.io import netcdf_file  # type: ignore
        ds = netcdf_file(filename, 'w')  # type: ignore

    def put_scalar(name: str, value) -> None:
        # Zero-dimensional variable; '[...]' assignment is accepted by both
        # the netCDF4 and the SciPy variable objects.
        var = ds.createVariable(
            name, 'f8' if isinstance(value, float) else 'i4', ()
        )
        var[...] = value

    # try/finally guarantees the file handle is released even if one of the
    # writes below raises.
    try:
        # Define dimensions
        ds.createDimension('radius', ns_in_plus_1)
        ds.createDimension('mn_mode', mnboz)
        ds.createDimension('mn_modes', mnboz)
        ds.createDimension('comput_surfs', ns_b)
        ds.createDimension('pack_rad', ns_b)

        # Version string: a string variable with netCDF4, a global attribute
        # with the NetCDF3 (SciPy) backend.
        if using_netcdf4:
            vvar = ds.createVariable('version', str)
            vvar[...] = _np.array(['JAX booz_xform'], dtype='object')
        else:
            ds.version = 'JAX booz_xform'

        # Symmetry flag
        put_scalar('lasym__logical__', 1 if self.asym else 0)

        # ns_b in the original boozmn files is the number of VMEC radial grid
        # points (the full "radius" grid, length ns_in + 1), *not* the number
        # of packed Boozer surfaces (pack_rad).  Use ns_in_plus_1 here to
        # match the C++ booz_xform.
        put_scalar('ns_b', int(ns_in_plus_1))
        put_scalar('nfp_b', int(self.nfp))
        put_scalar('mboz_b', int(self.mboz))
        put_scalar('nboz_b', int(self.nboz))
        put_scalar('mnboz_b', int(self.mnboz))
        put_scalar('aspect_b', float(self.aspect))
        # Store toroidal flux for round-trip consistency
        put_scalar('toroidal_flux_b', float(self.toroidal_flux))

        # Write 1D arrays (same calls for both backends)
        ds.createVariable('jlist', 'i4', ('comput_surfs',))[:] = jlist
        ds.createVariable('ixm_b', 'i4', ('mn_modes',))[:] = _np.asarray(
            self.xm_b, dtype='i4'
        )
        ds.createVariable('ixn_b', 'i4', ('mn_modes',))[:] = _np.asarray(
            self.xn_b, dtype='i4'
        )
        for name, data in radial_profiles.items():
            ds.createVariable(name, 'f8', ('radius',))[:] = data

        # Write 2D arrays: dims (pack_rad, mn_mode)
        dims = ('pack_rad', 'mn_mode')
        for name, data in spectra.items():
            ds.createVariable(name, 'f8', dims)[:, :] = data
    finally:
        ds.close()

    if self.verbose > 0:
        print(f"[booz_xform_jax] Finished writing {filename}")
    return None
def read_boozmn(self, filename: str) -> None:
    """Populate the instance from a ``boozmn`` NetCDF file.

    The file may come from the original ``booz_xform`` program or from
    :func:`write_boozmn`.  Mode definitions, radial profiles and spectral
    arrays are read; the spectra are transposed from the on-disk
    ``(pack_rad, mn_mode)`` layout.  Attributes already present on the
    instance are overwritten, except ``s_in``, which is only reconstructed
    when it is ``None``.

    Parameters
    ----------
    self : Booz_xform
        The instance to populate.
    filename : str
        Path to a ``boozmn`` NetCDF file.

    Raises
    ------
    ImportError
        If the ``netCDF4`` package cannot be imported.
    """
    try:
        import netCDF4 as nc  # type: ignore
    except ImportError as e:
        raise ImportError(
            "The netCDF4 package is required to read boozmn files. "
            "Install it via 'pip install netCDF4'"
        ) from e

    if self.verbose > 0:
        print(f"[booz_xform_jax] Reading boozmn file: {filename}")

    with nc.Dataset(filename, 'r') as ds:
        file_vars = ds.variables

        def scalar(key):
            # Value of a zero-dimensional variable as a Python scalar.
            return file_vars[key][...].item()

        def full_grid(key):
            # 1D variable on the 'radius' dimension as a NumPy array.
            return _np.asarray(file_vars[key][:])

        def spectrum(key):
            # 2D variable, transposed from the (pack_rad, mn_mode) layout.
            return _np.asarray(file_vars[key][:, :]).T

        # Symmetry flag
        self.asym = bool(scalar('lasym__logical__'))

        # Dimensions
        n_radius = ds.dimensions['radius'].size
        self.ns_in = n_radius - 1
        self.ns_b = ds.dimensions['pack_rad'].size
        self.mnboz = ds.dimensions['mn_mode'].size

        # Scalars
        self.nfp = int(scalar('nfp_b'))
        self.mboz = int(scalar('mboz_b'))
        self.nboz = int(scalar('nboz_b'))

        # Toroidal flux is optional; default to 0.0 when the file lacks it
        if 'toroidal_flux_b' in file_vars:
            self.toroidal_flux = float(scalar('toroidal_flux_b'))
        else:
            self.toroidal_flux = 0.0

        # Selected surfaces: jlist entries shifted back by 2
        self.compute_surfs = [int(j) - 2 for j in file_vars['jlist'][:]]

        # Mode lists
        self.xm_b = _np.asarray(file_vars['ixm_b'][:], dtype=int)
        self.xn_b = _np.asarray(file_vars['ixn_b'][:], dtype=int)

        # Radial profiles: the first stored entry is dropped
        self.iota = jnp.asarray(full_grid('iota_b')[1:], dtype=jnp.float64)
        buco = full_grid('buco_b')
        bvco = full_grid('bvco_b')
        self.Boozer_I_all = buco[1:]
        self.Boozer_G_all = bvco[1:]

        # Optional profiles are only set when present in the file
        optional = (
            ('phip_b', 'phip'),
            ('chi_b', 'chi'),
            ('pres_b', 'pres'),
            ('phi_b', 'phi'),
        )
        for key, attr_name in optional:
            if key in file_vars:
                setattr(
                    self,
                    attr_name,
                    jnp.asarray(full_grid(key)[1:], dtype=jnp.float64),
                )

        # Stellarator-symmetric spectra; numns_b is recovered as -pmns_b
        self.bmnc_b = spectrum('bmnc_b')
        self.rmnc_b = spectrum('rmnc_b')
        self.zmns_b = spectrum('zmns_b')
        self.numns_b = -spectrum('pmns_b')
        self.gmnc_b = spectrum('gmn_b')

        # Asymmetric spectra are read only when the flag is set
        if not self.asym:
            self.bmns_b = None
            self.rmns_b = None
            self.zmnc_b = None
            self.numnc_b = None
            self.gmns_b = None
        else:
            self.bmns_b = spectrum('bmns_b')
            self.rmns_b = spectrum('rmns_b')
            self.zmnc_b = spectrum('zmnc_b')
            self.numnc_b = -spectrum('pmnc_b')
            self.gmns_b = spectrum('gmns_b')

    # Boozer I and G restricted to the selected surfaces
    selected = self.compute_surfs
    self.Boozer_I = _np.asarray(self.Boozer_I_all)[selected]
    self.Boozer_G = _np.asarray(self.Boozer_G_all)[selected]

    # When s_in has not been set, rebuild it as the midpoints of a uniform
    # grid on [0, 1] with ns_in + 1 points.
    if self.s_in is None:
        nodes = _np.linspace(0.0, 1.0, n_radius)
        self.s_in = 0.5 * (nodes[:-1] + nodes[1:])

    # s on the selected surfaces
    self.s_b = _np.asarray(self.s_in)[selected]

    if self.verbose > 0:
        print("[booz_xform_jax] Finished reading boozmn file")
    return None