# Source code for prismo.core.solver

"""
Maxwell equation updates for FDTD simulations.

This module implements the core Maxwell equation updates using finite differences
on the Yee grid. The updates follow Faraday's law and Ampère's law:

∂H/∂t = -(1/μ₀) ∇ × E
∂E/∂t = (1/ε₀) ∇ × H

The curl operations are discretized using finite differences on the staggered
Yee grid to maintain second-order accuracy.
"""

from typing import Optional, Union

import numpy as np

from prismo.backends import Backend, get_backend
from prismo.solvers.base import TimeDomainSolver

from .fields import ElectromagneticFields
from .grid import YeeGrid


class MaxwellUpdater:
    """
    Core Maxwell equation updater for FDTD simulations.

    This class implements the discrete curl operations and time updates
    for electromagnetic fields on the Yee grid.

    Parameters
    ----------
    grid : YeeGrid
        The computational grid.
    dt : float
        Time step in seconds.
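    material_arrays : dict, optional
        Material property arrays (ε, μ, σ). If None, vacuum properties are used.
    backend : str or Backend, optional
        Array backend, given either by name or as a Backend instance.
        If None, a backend is selected automatically.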
    """

    def __init__(
        self,
        grid: YeeGrid,
        dt: float,
        material_arrays: Optional[dict] = None,
        backend: Optional[Union[str, Backend]] = None,
    ):
        """
        Set up the updater for one grid and one time step.

        Parameters
        ----------
        grid : YeeGrid
            Grid the fields live on; it is asked for its Courant number,
            dimensions and spacing.
        dt : float
            Time step.
        material_arrays : dict, optional
            Material arrays keyed by ``"eps_rel"``, ``"mu_rel"``,
            ``"sigma_e"`` and ``"sigma_m"``. None means vacuum everywhere.
        backend : str or Backend, optional
            Backend name, Backend instance, or None to let ``get_backend()``
            choose.

        Raises
        ------
        TypeError
            If ``backend`` is not None, a string, or a Backend instance.
        ValueError
            If the grid reports a Courant number of 1 or more for ``dt``.
        """
        self.grid = grid
        self.dt = dt

        # Resolve the array backend: auto-selected, looked up by name, or
        # used as given.
        if backend is None:
            self.backend = get_backend()
        elif isinstance(backend, str):
            self.backend = get_backend(backend)
        elif isinstance(backend, Backend):
            self.backend = backend
        else:
            raise TypeError("backend must be a Backend instance or string name")

        # Vacuum constants
        self.eps0 = 8.854187817e-12  # permittivity (F/m)
        self.mu0 = 4 * self.backend.pi * 1e-7  # permeability (H/m)
        self.c = 299792458.0  # speed of light (m/s)

        # Refuse time steps the grid reports as too large.
        courant = self.grid.get_courant_number(dt)
        if courant >= 1.0:
            raise ValueError(
                f"Time step dt={dt:.2e} violates Courant condition (S={courant:.3f} >= 1)"
            )

        # Materials first; the update coefficients are derived from them.
        self._initialize_materials(material_arrays)
        self._compute_update_coefficients()

    def _initialize_materials(self, material_arrays: Optional[dict]) -> None:
        """Initialize material property arrays."""
        # Get grid dimensions
        nx, ny, nz = self.grid.dimensions

        if material_arrays is None:
            # Vacuum everywhere
            self.eps_rel = self.backend.ones(
                (nx, ny, nz), dtype=self.backend.float64
            )  # Relative permittivity
            self.mu_rel = self.backend.ones(
                (nx, ny, nz), dtype=self.backend.float64
            )  # Relative permeability
            self.sigma_e = self.backend.zeros(
                (nx, ny, nz), dtype=self.backend.float64
            )  # Electric conductivity
            self.sigma_m = self.backend.zeros(
                (nx, ny, nz), dtype=self.backend.float64
            )  # Magnetic conductivity (usually 0)
        else:
            # Use provided material arrays (convert to backend arrays)
            self.eps_rel = self.backend.asarray(
                material_arrays.get("eps_rel", np.ones((nx, ny, nz)))
            )
            self.mu_rel = self.backend.asarray(
                material_arrays.get("mu_rel", np.ones((nx, ny, nz)))
            )
            self.sigma_e = self.backend.asarray(
                material_arrays.get("sigma_e", np.zeros((nx, ny, nz)))
            )
            self.sigma_m = self.backend.asarray(
                material_arrays.get("sigma_m", np.zeros((nx, ny, nz)))
            )

    def _compute_update_coefficients(self) -> None:
        """Precompute update coefficients for E and H field updates."""
        # For E-field update: E^(n+1) = Ca * E^n + Cb * curl(H^(n+1/2))
        # Ca and Cb account for material properties and conductivity

        # E-field coefficients (with conductivity)
        eps = self.eps0 * self.eps_rel
        sigma_dt_2eps = self.sigma_e * self.dt / (2 * eps)

        self.Ca = (1 - sigma_dt_2eps) / (1 + sigma_dt_2eps)
        self.Cb = (self.dt / eps) / (1 + sigma_dt_2eps)

        # H-field coefficients (usually no magnetic conductivity)
        mu = self.mu0 * self.mu_rel
        sigma_m_dt_2mu = self.sigma_m * self.dt / (2 * mu)

        self.Da = (1 - sigma_m_dt_2mu) / (1 + sigma_m_dt_2mu)
        self.Db = (self.dt / mu) / (1 + sigma_m_dt_2mu)

        # Grid spacing for curl operations
        self.dx, self.dy, self.dz = self.grid.spacing

    def update_magnetic_fields(self, fields: ElectromagneticFields) -> None:
        """
        Update magnetic field components using Faraday's law.

        ∂H/∂t = -(1/μ) ∇ × E

        Parameters
        ----------
        fields : ElectromagneticFields
            Electromagnetic fields to update.
        """
        if self.grid.is_2d:
            self._update_h_fields_2d(fields)
        else:
            self._update_h_fields_3d(fields)

    def update_electric_fields(self, fields: ElectromagneticFields) -> None:
        """
        Update electric field components using Ampère's law.

        ∂E/∂t = (1/ε) ∇ × H

        Parameters
        ----------
        fields : ElectromagneticFields
            Electromagnetic fields to update.
        """
        if self.grid.is_2d:
            self._update_e_fields_2d(fields)
        else:
            self._update_e_fields_3d(fields)

    def _update_h_fields_3d(self, fields: ElectromagneticFields) -> None:
        """Update H-field components for 3D case."""
        Ex, Ey, Ez = fields.Ex, fields.Ey, fields.Ez
        Hx, Hy, Hz = fields.Hx, fields.Hy, fields.Hz

        # Hx update: ∂Hx/∂t = -(1/μ) * (∂Ez/∂y - ∂Ey/∂z)
        # Hx is at (i+1/2, j, k), shape (nx-1, ny, nz)
        # Ez is at (i+1/2, j+1/2, k), shape (nx-1, ny-1, nz)
        # Ey is at (i+1/2, j, k+1/2), shape (nx-1, ny, nz-1)

        # For ∂Ez/∂y: compute (Ez[j+1/2] - Ez[j-1/2])/dy
        # This gives curl_ez_y at (i+1/2, j, k) with shape (nx-1, ny-1, nz)
        curl_ez_y = (Ez[:, 1:, :] - Ez[:, :-1, :]) / self.dy  # Shape: (nx-1, ny-1, nz)
        
        # For ∂Ey/∂z: compute (Ey[k+1/2] - Ey[k-1/2])/dz
        # This gives curl_ey_z at (i+1/2, j, k) with shape (nx-1, ny, nz-1)
        curl_ey_z = (Ey[:, :, 1:] - Ey[:, :, :-1]) / self.dz  # Shape: (nx-1, ny, nz-1)

        # Update Hx only in the interior region where both curl components are defined
        # Common region: (nx-1, ny-1, nz-1)
        # Boundary points (j=ny-1, k=nz-1) are handled by boundary conditions
        ny_common = min(Hx.shape[1], curl_ez_y.shape[1])
        nz_common = min(Hx.shape[2], curl_ey_z.shape[2])

        da_hx = self._interpolate_to_hx_points(self.Da)[
            : Hx.shape[0], :ny_common, :nz_common
        ]
        db_hx = self._interpolate_to_hx_points(self.Db)[
            : Hx.shape[0], :ny_common, :nz_common
        ]

        # Update Hx in the common region where both curls are defined
        # curl_ez_y: (nx-1, ny-1, nz) -> use [:ny_common, :nz_common]
        # curl_ey_z: (nx-1, ny, nz-1) -> use [:ny_common, :nz_common]
        Hx[:, :ny_common, :nz_common] = (
            da_hx * Hx[:, :ny_common, :nz_common]
            - db_hx * (
                curl_ez_y[:, :ny_common, :nz_common]
                - curl_ey_z[:, :ny_common, :nz_common]
            )
        )

        # Hy update: ∂Hy/∂t = -(1/μ) * (∂Ex/∂z - ∂Ez/∂x)
        # Hy is at (i, j+1/2, k), shape (nx, ny-1, nz)
        # Ex is at (i, j+1/2, k+1/2), shape (nx, ny-1, nz-1)
        # Ez is at (i+1/2, j+1/2, k), shape (nx-1, ny-1, nz)

        curl_ex_z = (Ex[:, :, 1:] - Ex[:, :, :-1]) / self.dz  # Shape: (nx, ny-1, nz-1)
        curl_ez_x = (Ez[1:, :, :] - Ez[:-1, :, :]) / self.dx  # Shape: (nx-1, ny-1, nz)

        # Update Hy only in the interior region where both curl components are defined
        nx_common = min(Hy.shape[0], curl_ex_z.shape[0], curl_ez_x.shape[0])
        nz_common = min(Hy.shape[2], curl_ex_z.shape[2], curl_ez_x.shape[2])

        da_hy = self._interpolate_to_hy_points(self.Da)[
            :nx_common, : Hy.shape[1], :nz_common
        ]
        db_hy = self._interpolate_to_hy_points(self.Db)[
            :nx_common, : Hy.shape[1], :nz_common
        ]

        Hy[:nx_common, :, :nz_common] = (
            da_hy * Hy[:nx_common, :, :nz_common]
            - db_hy * (
                curl_ex_z[:nx_common, :, :nz_common]
                - curl_ez_x[:nx_common, :, :nz_common]
            )
        )

        # Hz update: ∂Hz/∂t = -(1/μ) * (∂Ey/∂x - ∂Ex/∂y)
        # Hz is at (i, j, k+1/2), shape (nx, ny, nz-1)
        # Ey is at (i+1/2, j, k+1/2), shape (nx-1, ny, nz-1)
        # Ex is at (i, j+1/2, k+1/2), shape (nx, ny-1, nz-1)

        curl_ey_x = (Ey[1:, :, :] - Ey[:-1, :, :]) / self.dx  # Shape: (nx-1, ny, nz-1)
        curl_ex_y = (Ex[:, 1:, :] - Ex[:, :-1, :]) / self.dy  # Shape: (nx, ny-1, nz-1)

        # Update Hz only in the interior region where both curl components are defined
        nx_common = min(Hz.shape[0], curl_ey_x.shape[0], curl_ex_y.shape[0])
        ny_common = min(Hz.shape[1], curl_ey_x.shape[1], curl_ex_y.shape[1])

        da_hz = self._interpolate_to_hz_points(self.Da)[
            :nx_common, :ny_common, : Hz.shape[2]
        ]
        db_hz = self._interpolate_to_hz_points(self.Db)[
            :nx_common, :ny_common, : Hz.shape[2]
        ]

        Hz[:nx_common, :ny_common, :] = (
            da_hz * Hz[:nx_common, :ny_common, :]
            - db_hz * (
                curl_ey_x[:nx_common, :ny_common, :]
                - curl_ex_y[:nx_common, :ny_common, :]
            )
        )

    def _update_e_fields_3d(self, fields: ElectromagneticFields) -> None:
        """Update E-field components for 3D case."""
        Ex, Ey, Ez = fields.Ex, fields.Ey, fields.Ez
        Hx, Hy, Hz = fields.Hx, fields.Hy, fields.Hz

        # Ex update: ∂Ex/∂t = (1/ε) * (∂Hz/∂y - ∂Hy/∂z)
        # Ex is at (i, j+1/2, k+1/2), shape (nx, ny-1, nz-1)
        # Hz is at (i, j, k+1/2), shape (nx, ny, nz-1)
        # Hy is at (i, j+1/2, k), shape (nx, ny-1, nz)

        curl_hz_y = (Hz[:, 1:, :] - Hz[:, :-1, :]) / self.dy  # Shape: (nx, ny-1, nz-1)
        curl_hy_z = (Hy[:, :, 1:] - Hy[:, :, :-1]) / self.dz  # Shape: (nx, ny-1, nz-1)

        ca_ex = self._interpolate_to_ex_points(self.Ca)[
            : Ex.shape[0], : Ex.shape[1], : Ex.shape[2]
        ]
        cb_ex = self._interpolate_to_ex_points(self.Cb)[
            : Ex.shape[0], : Ex.shape[1], : Ex.shape[2]
        ]

        Ex[:, :, :] = ca_ex * Ex + cb_ex * (curl_hz_y - curl_hy_z)

        # Ey update: ∂Ey/∂t = (1/ε) * (∂Hx/∂z - ∂Hz/∂x)
        # Ey is at (i+1/2, j, k+1/2), shape (nx-1, ny, nz-1)
        # Hx is at (i+1/2, j, k), shape (nx-1, ny, nz)
        # Hz is at (i, j, k+1/2), shape (nx, ny, nz-1)

        curl_hx_z = (Hx[:, :, 1:] - Hx[:, :, :-1]) / self.dz  # Shape: (nx-1, ny, nz-1)
        curl_hz_x = (Hz[1:, :, :] - Hz[:-1, :, :]) / self.dx  # Shape: (nx-1, ny, nz-1)

        ca_ey = self._interpolate_to_ey_points(self.Ca)[
            : Ey.shape[0], : Ey.shape[1], : Ey.shape[2]
        ]
        cb_ey = self._interpolate_to_ey_points(self.Cb)[
            : Ey.shape[0], : Ey.shape[1], : Ey.shape[2]
        ]

        Ey[:, :, :] = ca_ey * Ey + cb_ey * (curl_hx_z - curl_hz_x)

        # Ez update: ∂Ez/∂t = (1/ε) * (∂Hy/∂x - ∂Hx/∂y)
        # Ez is at (i+1/2, j+1/2, k), shape (nx-1, ny-1, nz)
        # Hy is at (i, j+1/2, k), shape (nx, ny-1, nz)
        # Hx is at (i+1/2, j, k), shape (nx-1, ny, nz)

        curl_hy_x = (Hy[1:, :, :] - Hy[:-1, :, :]) / self.dx  # Shape: (nx-1, ny-1, nz)
        curl_hx_y = (Hx[:, 1:, :] - Hx[:, :-1, :]) / self.dy  # Shape: (nx-1, ny-1, nz)

        ca_ez = self._interpolate_to_ez_points(self.Ca)[
            : Ez.shape[0], : Ez.shape[1], : Ez.shape[2]
        ]
        cb_ez = self._interpolate_to_ez_points(self.Cb)[
            : Ez.shape[0], : Ez.shape[1], : Ez.shape[2]
        ]

        Ez[:, :, :] = ca_ez * Ez + cb_ez * (curl_hy_x - curl_hx_y)

    def _update_h_fields_2d(self, fields: ElectromagneticFields) -> None:
        """Update H-field components for 2D case."""
        # In 2D we support two modes:
        # TE mode: Ez, Hx, Hy (Ez is dominant, H fields in-plane)
        # TM mode: Hz, Ex, Ey (Hz is dominant, E fields in-plane)

        Ez = fields.Ez
        Hx, Hy = fields.Hx, fields.Hy

        # TM mode (Ez dominant): Update Hx and Hy from Ez
        if np.max(np.abs(Ez)) > 0:
            # Hx update: ∂Hx/∂t = -(1/μ) * (-∂Ez/∂y) = (1/μ) * ∂Ez/∂y
            # Hx is at (i+1/2, j), Ez is at (i+1/2, j+1/2)
            # ∂Ez/∂y at Hx position: (Ez[i,j+1/2] - Ez[i,j-1/2])/dy
            nx_hx, ny_hx = Hx.shape
            nx_ez, ny_ez = Ez.shape

            if nx_ez >= nx_hx and ny_ez > 0:
                curl_ez_y = (
                    Ez[:nx_hx, 1:] - Ez[:nx_hx, :-1]
                ) / self.dy  # Shape: (nx_hx, ny_ez-1)

                da_hx = self._interpolate_to_hx_points_2d(self.Da[:, :, 0])
                db_hx = self._interpolate_to_hx_points_2d(self.Db[:, :, 0])

                # Apply update only where curl is defined
                ny_curl = min(ny_hx, curl_ez_y.shape[1])
                Hx[:, :ny_curl] = (
                    da_hx[:nx_hx, :ny_curl] * Hx[:, :ny_curl]
                    + db_hx[:nx_hx, :ny_curl] * curl_ez_y[:, :ny_curl]
                )

            # Hy update: ∂Hy/∂t = -(1/μ) * (∂Ez/∂x)
            # Hy is at (i, j+1/2), Ez is at (i+1/2, j+1/2)
            # ∂Ez/∂x at Hy position: (Ez[i+1/2,j] - Ez[i-1/2,j])/dx
            nx_hy, ny_hy = Hy.shape

            if nx_ez > 0 and ny_ez >= ny_hy:
                curl_ez_x = (
                    Ez[1:, :ny_hy] - Ez[:-1, :ny_hy]
                ) / self.dx  # Shape: (nx_ez-1, ny_hy)

                da_hy = self._interpolate_to_hy_points_2d(self.Da[:, :, 0])
                db_hy = self._interpolate_to_hy_points_2d(self.Db[:, :, 0])

                # Apply update only where curl is defined
                nx_curl = min(nx_hy, curl_ez_x.shape[0])
                Hy[:nx_curl, :] = (
                    da_hy[:nx_curl, :ny_hy] * Hy[:nx_curl, :]
                    - db_hy[:nx_curl, :ny_hy] * curl_ez_x[:nx_curl, :]
                )

        # TE mode (Hz dominant): Update Hz from Ex, Ey
        Hz = fields.Hz
        Ex, Ey = fields.Ex, fields.Ey

        if np.max(np.abs(Ex)) > 0 or np.max(np.abs(Ey)) > 0:
            # Hz update: ∂Hz/∂t = -(1/μ) * (∂Ey/∂x - ∂Ex/∂y)
            # Hz is at (i, j), Ex is at (i, j+1/2), Ey is at (i+1/2, j)

            # Hz update: ∂Hz/∂t = -(1/μ) * (∂Ey/∂x - ∂Ex/∂y)
            # Hz is at (i, j), Ex is at (i, j+1/2), Ey is at (i+1/2, j)
            
            # Compute curl components with proper shape handling
            curl_ey_x = np.zeros_like(Hz)
            if Ey.shape[0] > 1 and Ey.shape[1] > 0 and self.dx > 0:
                # ∂Ey/∂x: (Ey[i+1,j] - Ey[i,j])/dx
                # Ey has shape (nx-1, ny), Hz has shape (nx, ny)
                # Compute derivative for overlapping region
                nx_curl_ey = min(Ey.shape[0] - 1, Hz.shape[0] - 1)
                ny_curl_ey = min(Ey.shape[1], Hz.shape[1])
                if nx_curl_ey > 0 and ny_curl_ey > 0:
                    # Clamp Ey values to prevent overflow
                    ey_slice1 = np.clip(Ey[1 : nx_curl_ey + 1, :ny_curl_ey], -1e10, 1e10)
                    ey_slice2 = np.clip(Ey[:nx_curl_ey, :ny_curl_ey], -1e10, 1e10)
                    curl_ey_x[:nx_curl_ey, :ny_curl_ey] = (ey_slice1 - ey_slice2) / self.dx
                    # Clamp curl result to prevent overflow
                    curl_ey_x[:nx_curl_ey, :ny_curl_ey] = np.clip(
                        curl_ey_x[:nx_curl_ey, :ny_curl_ey], -1e15, 1e15
                    )

            curl_ex_y = np.zeros_like(Hz)
            if Ex.shape[0] > 0 and Ex.shape[1] > 1 and self.dy > 0:
                # ∂Ex/∂y: (Ex[i,j+1] - Ex[i,j])/dy
                # Ex has shape (nx, ny-1), Hz has shape (nx, ny)
                # Compute derivative for overlapping region
                nx_curl_ex = min(Ex.shape[0], Hz.shape[0])
                ny_curl_ex = min(Ex.shape[1] - 1, Hz.shape[1] - 1)
                if nx_curl_ex > 0 and ny_curl_ex > 0:
                    # Clamp Ex values to prevent overflow
                    ex_slice1 = np.clip(Ex[:nx_curl_ex, 1 : ny_curl_ex + 1], -1e10, 1e10)
                    ex_slice2 = np.clip(Ex[:nx_curl_ex, :ny_curl_ex], -1e10, 1e10)
                    curl_ex_y[:nx_curl_ex, :ny_curl_ex] = (ex_slice1 - ex_slice2) / self.dy
                    # Clamp curl result to prevent overflow
                    curl_ex_y[:nx_curl_ex, :ny_curl_ex] = np.clip(
                        curl_ex_y[:nx_curl_ex, :ny_curl_ex], -1e15, 1e15
                    )

            da_hz = self._interpolate_to_hz_points_2d(self.Da[:, :, 0])
            db_hz = self._interpolate_to_hz_points_2d(self.Db[:, :, 0])

            nx_hz, ny_hz = Hz.shape
            # Update only in the common region where both curls are defined
            nx_common = min(nx_hz, curl_ey_x.shape[0], curl_ex_y.shape[0])
            ny_common = min(ny_hz, curl_ey_x.shape[1], curl_ex_y.shape[1])
            
            if nx_common > 0 and ny_common > 0:
                # Compute curl difference with overflow protection
                curl_diff = curl_ey_x[:nx_common, :ny_common] - curl_ex_y[:nx_common, :ny_common]
                curl_diff = np.clip(curl_diff, -1e15, 1e15)
                
                # Clamp Hz before update to prevent overflow
                hz_old = np.clip(Hz[:nx_common, :ny_common], -1e10, 1e10)
                
                Hz[:nx_common, :ny_common] = (
                    da_hz[:nx_common, :ny_common] * hz_old
                    - db_hz[:nx_common, :ny_common] * curl_diff
                )
                
                # Clamp result to prevent overflow
                Hz[:nx_common, :ny_common] = np.clip(
                    Hz[:nx_common, :ny_common], -1e10, 1e10
                )

    def _update_e_fields_2d(self, fields: ElectromagneticFields) -> None:
        """Update E-field components for 2D case."""
        # TM mode (Ez dominant): Update Ez from Hx, Hy
        Ez = fields.Ez
        Hx, Hy = fields.Hx, fields.Hy

        if Ez.shape[0] > 0 and Ez.shape[1] > 0:
            # Ez update: ∂Ez/∂t = (1/ε) * (∂Hy/∂x - ∂Hx/∂y)
            # Ez is at (i+1/2, j+1/2), Hx is at (i+1/2, j), Hy is at (i, j+1/2)

            curl_hy_x = (Hy[1:, :] - Hy[:-1, :]) / self.dx  # Shape: (nx-1, ny)
            curl_hx_y = (Hx[:, 1:] - Hx[:, :-1]) / self.dy  # Shape: (nx, ny-1)

            # Both curls need to match Ez shape (nx-1, ny-1)
            # Take appropriate slices
            ca_ez = self._interpolate_to_ez_points_2d(self.Ca[:, :, 0])
            cb_ez = self._interpolate_to_ez_points_2d(self.Cb[:, :, 0])

            nx_ez, ny_ez = Ez.shape
            ca_ez = ca_ez[:nx_ez, :ny_ez]
            cb_ez = cb_ez[:nx_ez, :ny_ez]

            Ez[:, :] = ca_ez * Ez + cb_ez * (
                curl_hy_x[:nx_ez, :ny_ez] - curl_hx_y[:nx_ez, :ny_ez]
            )

        # TE mode (Hz dominant): Update Ex, Ey from Hz
        Hz = fields.Hz
        Ex, Ey = fields.Ex, fields.Ey

        if Hz.shape[0] > 0 and Hz.shape[1] > 0:
            # Ex update: ∂Ex/∂t = (1/ε) * ∂Hz/∂y
            # Ex is at (i, j+1/2), Hz is at (i, j)
            if Ex.shape[0] > 0 and Ex.shape[1] > 0:
                curl_hz_y = (Hz[:, 1:] - Hz[:, :-1]) / self.dy  # Shape matches Ex

                ca_ex = self._interpolate_to_ex_points_2d(self.Ca[:, :, 0])
                cb_ex = self._interpolate_to_ex_points_2d(self.Cb[:, :, 0])

                nx_ex, ny_ex = Ex.shape
                ca_ex = ca_ex[:nx_ex, :ny_ex]
                cb_ex = cb_ex[:nx_ex, :ny_ex]

                Ex[:, :] = ca_ex * Ex + cb_ex * curl_hz_y

            # Ey update: ∂Ey/∂t = -(1/ε) * ∂Hz/∂x
            # Ey is at (i+1/2, j), Hz is at (i, j)
            if Ey.shape[0] > 0 and Ey.shape[1] > 0:
                curl_hz_x = (Hz[1:, :] - Hz[:-1, :]) / self.dx  # Shape matches Ey

                ca_ey = self._interpolate_to_ey_points_2d(self.Ca[:, :, 0])
                cb_ey = self._interpolate_to_ey_points_2d(self.Cb[:, :, 0])

                nx_ey, ny_ey = Ey.shape
                ca_ey = ca_ey[:nx_ey, :ny_ey]
                cb_ey = cb_ey[:nx_ey, :ny_ey]

                Ey[:, :] = ca_ey * Ey - cb_ey * curl_hz_x

    def _interpolate_to_ex_points(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Ex field points."""
        # Ex is at (i, j+1/2, k+1/2) - average over y and z
        return 0.25 * (
            array[:, :-1, :-1]
            + array[:, 1:, :-1]
            + array[:, :-1, 1:]
            + array[:, 1:, 1:]
        )

    def _interpolate_to_ey_points(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Ey field points."""
        # Ey is at (i+1/2, j, k+1/2) - average over x and z
        return 0.25 * (
            array[:-1, :, :-1]
            + array[1:, :, :-1]
            + array[:-1, :, 1:]
            + array[1:, :, 1:]
        )

    def _interpolate_to_ez_points(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Ez field points."""
        # Ez is at (i+1/2, j+1/2, k) - average over x and y
        return 0.25 * (
            array[:-1, :-1, :]
            + array[1:, :-1, :]
            + array[:-1, 1:, :]
            + array[1:, 1:, :]
        )

    def _interpolate_to_hx_points(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Hx field points."""
        # Hx is at (i+1/2, j, k) - average over x
        return 0.5 * (array[:-1, :, :] + array[1:, :, :])

    def _interpolate_to_hy_points(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Hy field points."""
        # Hy is at (i, j+1/2, k) - average over y
        return 0.5 * (array[:, :-1, :] + array[:, 1:, :])

    def _interpolate_to_hz_points(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Hz field points."""
        # Hz is at (i, j, k+1/2) - average over z
        return 0.5 * (array[:, :, :-1] + array[:, :, 1:])

    def _interpolate_to_ex_points_2d(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Ex field points (2D)."""
        # Ex is at (i, j+1/2) - average over y
        return 0.5 * (array[:, :-1] + array[:, 1:])

    def _interpolate_to_ey_points_2d(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Ey field points (2D)."""
        # Ey is at (i+1/2, j) - average over x
        return 0.5 * (array[:-1, :] + array[1:, :])

    def _interpolate_to_ez_points_2d(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Ez field points (2D)."""
        # Ez is at (i+1/2, j+1/2) - average over both x and y
        return 0.25 * (
            array[:-1, :-1] + array[1:, :-1] + array[:-1, 1:] + array[1:, 1:]
        )

    def _interpolate_to_hz_points_2d(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Hz field points (2D)."""
        # Hz is at (i, j) - no averaging needed for cell centers
        return array

    def _interpolate_to_hx_points_2d(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Hx field points (2D)."""
        # Hx is at (i+1/2, j) - average over x
        return 0.5 * (array[:-1, :] + array[1:, :])

    def _interpolate_to_hy_points_2d(self, array: np.ndarray) -> np.ndarray:
        """Interpolate material properties to Hy field points (2D)."""
        # Hy is at (i, j+1/2) - average over y
        return 0.5 * (array[:, :-1] + array[:, 1:])

    def step(self, fields: ElectromagneticFields) -> None:
        """
        Perform one complete FDTD time step.

        The leap-frog time stepping follows:
        1. Update H-fields using E-fields at time n
        2. Update E-fields using H-fields at time n+1/2

        Parameters
        ----------
        fields : ElectromagneticFields
            Electromagnetic fields to advance in time.
        """
        # Update H-fields (time n -> n+1/2)
        self.update_magnetic_fields(fields)

        # Update E-fields (time n -> n+1)
        self.update_electric_fields(fields)

    def get_time_step(self) -> float:
        """Get the time step."""
        return self.dt

    def get_courant_number(self) -> float:
        """Get the Courant number for this updater."""
        return self.grid.get_courant_number(self.dt)

    def __repr__(self) -> str:
        """String representation."""
        courant = self.get_courant_number()
        return (
            f"MaxwellUpdater(dt={self.dt:.2e}s, "
            f"Courant={courant:.3f}, grid={self.grid.dimensions})"
        )


class FDTDSolver(TimeDomainSolver):
    """
    High-level FDTD solver combining grid, fields, and updater.

    This class provides a convenient interface for running FDTD simulations
    with automatic time stepping and field management.

    Parameters
    ----------
    grid : YeeGrid
        Computational grid.
    dt : float, optional
        Time step. If None, automatically computed for stability.
    material_arrays : dict, optional
        Material property arrays.
    backend : str or Backend, optional
        Array backend, given by name or as an instance. If None, a backend
        is selected automatically.
    """

    def __init__(
        self,
        grid: YeeGrid,
        dt: Optional[float] = None,
        material_arrays: Optional[dict] = None,
        backend: Optional[Union[str, Backend]] = None,
    ):
        """
        Create the field container and the Maxwell updater for ``grid``.

        Raises
        ------
        TypeError
            If ``backend`` is not None, a string, or a Backend instance.
        """
        # NOTE(review): self.time and self.step_count are used below but not
        # assigned here; presumably TimeDomainSolver.__init__ sets them.
        super().__init__(grid)
        self.grid = grid

        # Initialize backend
        if isinstance(backend, str):
            self.backend = get_backend(backend)
        elif isinstance(backend, Backend):
            self.backend = backend
        elif backend is None:
            self.backend = get_backend()  # Auto-select
        else:
            raise TypeError("backend must be a Backend instance or string name")

        # Auto-determine time step if not provided
        if dt is None:
            dt = grid.suggest_time_step(safety_factor=0.95)

        self.fields = ElectromagneticFields(grid, backend=self.backend)
        self.updater = MaxwellUpdater(grid, dt, material_arrays, backend=self.backend)

    def run(self, total_time: float, callback: Optional[callable] = None) -> None:
        """
        Run FDTD simulation for specified time.

        The number of steps is ``ceil(total_time / dt)``, so the simulated
        time may overshoot ``total_time`` by less than one step.

        Parameters
        ----------
        total_time : float
            Total simulation time in seconds.
        callback : callable, optional
            Function called after each time step: callback(solver, step_num).
        """
        dt = self.updater.get_time_step()
        num_steps = int(np.ceil(total_time / dt))
        self.run_steps(num_steps, callback)

    def run_steps(self, num_steps: int, callback: Optional[callable] = None) -> None:
        """
        Run FDTD simulation for specified number of steps.

        Parameters
        ----------
        num_steps : int
            Number of time steps to run.
        callback : callable, optional
            Function called after each time step: callback(solver, step_num),
            where step_num counts from 0 within this call.
        """
        dt = self.updater.get_time_step()

        for step in range(num_steps):
            # Perform one FDTD step
            self.updater.step(self.fields)

            # Update time tracking
            self.time += dt
            self.step_count += 1

            # Call user callback if provided
            if callback is not None:
                callback(self, step)

    def step(self, fields: Optional[ElectromagneticFields] = None) -> None:
        """
        Perform a single FDTD time step.

        The solver's time and step counter advance even when external
        ``fields`` are passed in.

        Parameters
        ----------
        fields : ElectromagneticFields, optional
            Fields to update. If None, uses internal fields.
        """
        if fields is None:
            fields = self.fields

        self.updater.step(fields)
        self.time += self.updater.get_time_step()
        self.step_count += 1

    def get_fields(self) -> ElectromagneticFields:
        """Get the current electromagnetic fields."""
        return self.fields

    def get_time_step(self) -> float:
        """
        Get the time step used by the solver.

        Returns
        -------
        float
            Time step in seconds.
        """
        return self.updater.get_time_step()

    def reset(self) -> None:
        """Reset simulation to initial state."""
        self.fields.zero_fields()
        self.time = 0.0
        self.step_count = 0

    def get_simulation_info(self) -> dict:
        """Get information about the current simulation state."""
        return {
            "time": self.time,
            "step_count": self.step_count,
            "dt": self.updater.get_time_step(),
            "courant_number": self.updater.get_courant_number(),
            "grid_dimensions": self.grid.dimensions,
            "is_2d": self.grid.is_2d,
            "field_energy": self.fields.get_field_energy(),
        }

    def __repr__(self) -> str:
        """String representation."""
        info = self.get_simulation_info()
        return (
            f"FDTDSolver(t={info['time']:.2e}s, steps={info['step_count']}, "
            f"E_total={info['field_energy']:.2e}J)"
        )