# Source code for prismo.boundaries.mode_port

"""
Mode port boundary conditions for waveguide simulations.

This module implements mode ports that can inject and extract waveguide modes
at simulation boundaries, enabling accurate S-parameter calculations.
"""

from dataclasses import dataclass
from typing import Literal, Optional

import numpy as np

from prismo.core.fields import ElectromagneticFields
from prismo.core.grid import YeeGrid
from prismo.modes.solver import WaveguideMode
from prismo.utils.mode_matching import (
    compute_mode_overlap,
    interpolate_mode_to_grid,
)


@dataclass()
class ModePortConfig:
    """
    Geometry and mode content of a single mode port.

    Attributes
    ----------
    center : tuple[float, float, float]
        (x, y, z) position of the center of the port plane.
    size : tuple[float, float, float]
        Extent of the port region along x, y and z.
    direction : {'+x', '-x', '+y', '-y', '+z', '-z'}
        Axis normal to the port plane, prefixed with its orientation sign.
    modes : list[WaveguideMode]
        Modes this port injects and/or decomposes the fields into.
    inject : bool
        When True the port also acts as a source; when False (the default)
        it only extracts mode coefficients.
    """

    # Port placement.
    center: tuple[float, float, float]
    size: tuple[float, float, float]
    direction: Literal["+x", "-x", "+y", "-y", "+z", "-z"]
    # Mode content and role.
    modes: list[WaveguideMode]
    inject: bool = False
class ModePort:
    """
    Mode port for injecting and extracting waveguide modes.

    A mode port acts as both a source (injecting modes) and a monitor
    (extracting mode amplitudes) at a plane in the simulation domain.

    Parameters
    ----------
    config : ModePortConfig
        Port configuration.
    name : str, optional
        Port name for identification. Defaults to ``"ModePort_<id>"``.
    enabled : bool, optional
        Enable/disable port, default=True.

    Raises
    ------
    ValueError
        If ``config.direction`` is not a sign ('+' or '-') followed by an
        axis letter x, y or z (the letter is accepted case-insensitively).

    Examples
    --------
    >>> from prismo.modes.solver import ModeSolver
    >>> # Solve for waveguide modes
    >>> mode_solver = ModeSolver(wavelength=1.55e-6, x=x, y=y, epsilon=eps)
    >>> modes = mode_solver.solve(num_modes=2, mode_type='TE')
    >>>
    >>> # Create mode port
    >>> config = ModePortConfig(
    ...     center=(0.0, 0.0, 0.0),
    ...     size=(2e-6, 2e-6, 0.0),
    ...     direction='+z',
    ...     modes=modes,
    ...     inject=True,
    ... )
    >>> port = ModePort(config, name='input_port')
    """

    # A coordinate/spacing ratio within this many cells of an integer is
    # snapped to that integer before truncation. Without it, round-off such
    # as 0.3 / 0.1 == 2.9999999999999996 would place the port one cell early.
    _INDEX_SNAP_TOL = 1e-6

    # Components written by ``_add_mode_to_fields`` for each normal axis:
    # the two tangential E components followed by the normal H component.
    _INJECT_COMPONENTS = {
        "x": ("Ey", "Ez", "Hx"),
        "y": ("Ex", "Ez", "Hy"),
        "z": ("Ex", "Ey", "Hz"),
    }

    # Order of the arrays returned by ``_extract_field_slice``.
    _FIELD_COMPONENTS = ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz")

    def __init__(
        self,
        config: ModePortConfig,
        name: Optional[str] = None,
        enabled: bool = True,
    ):
        direction = config.direction
        # Reject malformed directions up front instead of silently treating
        # them as a z-normal port.
        if (
            len(direction) != 2
            or direction[0] not in "+-"
            or direction[1].lower() not in ("x", "y", "z")
        ):
            raise ValueError(
                f"Invalid port direction {direction!r}; expected one of "
                "'+x', '-x', '+y', '-y', '+z', '-z'"
            )

        self.config = config
        self.name = name or f"ModePort_{id(self)}"
        self.enabled = enabled

        # Grid information (set by ``initialize``).
        self._grid: Optional[YeeGrid] = None
        # Index/slice description of the port plane, e.g. keys "x_idx",
        # "y_slice", "z_slice", "normal_axis" (filled by _setup_port_region).
        self._port_region: dict[str, object] = {}

        # Mode profiles resampled onto the port's grid points.
        self._interpolated_modes: list[WaveguideMode] = []

        # Recorded coefficient history per mode index, and matching times.
        self._mode_coefficients: dict[int, list[complex]] = {
            i: [] for i in range(len(config.modes))
        }
        self._time_points: list[float] = []

        # '+' -> +1, '-' -> -1; axis letter normalised to lower case.
        self.sign = +1 if direction[0] == "+" else -1
        self.axis = direction[-1].lower()

    def initialize(self, grid: YeeGrid) -> None:
        """
        Initialize the mode port on the simulation grid.

        Computes the port's index region and (re)interpolates all configured
        modes onto it. Safe to call again with a new grid.

        Parameters
        ----------
        grid : YeeGrid
            Simulation grid.
        """
        self._grid = grid
        self._setup_port_region()
        self._interpolate_modes()

    @classmethod
    def _grid_index(cls, coord: float, origin: float, spacing: float) -> int:
        """Convert a physical coordinate to a grid index (truncating, FP-robust)."""
        ratio = (coord - origin) / spacing
        nearest = round(ratio)
        if abs(ratio - nearest) < cls._INDEX_SNAP_TOL:
            return int(nearest)
        # Same truncation toward zero as a plain int() conversion.
        return int(ratio)

    def _setup_port_region(self) -> None:
        """Set up port region in grid coordinates."""
        grid = self._grid
        if grid is None:
            raise RuntimeError("Port must be initialized with a grid first")

        center = self.config.center
        size = self.config.size

        def span(axis_num: int, step: float) -> slice:
            """Index slice covering ``center +/- size/2`` along one axis."""
            half = size[axis_num] / 2
            origin = grid.origin[axis_num]
            start = self._grid_index(center[axis_num] - half, origin, step)
            stop = self._grid_index(center[axis_num] + half, origin, step)
            return slice(start, stop)

        def z_span() -> slice:
            # 2D grids have a single z layer; dz/origin[2] are not consulted.
            return span(2, grid.dz) if grid.is_3d else slice(0, 1)

        if self.axis == "x":
            self._port_region = {
                "x_idx": self._grid_index(center[0], grid.origin[0], grid.dx),
                "y_slice": span(1, grid.dy),
                "z_slice": z_span(),
                "normal_axis": 0,
            }
        elif self.axis == "y":
            self._port_region = {
                "x_slice": span(0, grid.dx),
                "y_idx": self._grid_index(center[1], grid.origin[1], grid.dy),
                "z_slice": z_span(),
                "normal_axis": 1,
            }
        else:  # z
            self._port_region = {
                "x_slice": span(0, grid.dx),
                "y_slice": span(1, grid.dy),
                "z_idx": (
                    self._grid_index(center[2], grid.origin[2], grid.dz)
                    if grid.is_3d
                    else 0
                ),
                "normal_axis": 2,
            }

    def _interpolate_modes(self) -> None:
        """Interpolate mode profiles to simulation grid."""
        grid = self._grid
        if grid is None:
            raise RuntimeError("Port must be initialized first")
        region = self._port_region

        def coords(key: str, axis_num: int, step: float) -> np.ndarray:
            """Physical coordinates of the grid points in ``region[key]``."""
            sl = region[key]
            return np.arange(sl.start, sl.stop) * step + grid.origin[axis_num]

        if self.axis == "x":
            first = coords("y_slice", 1, grid.dy)
            second = coords("z_slice", 2, grid.dz) if grid.is_3d else np.array([0.0])
        elif self.axis == "y":
            first = coords("x_slice", 0, grid.dx)
            second = coords("z_slice", 2, grid.dz) if grid.is_3d else np.array([0.0])
        else:  # z
            first = coords("x_slice", 0, grid.dx)
            second = coords("y_slice", 1, grid.dy)

        # Rebuild (rather than append to) the list so that re-initializing
        # the port does not duplicate every mode.
        self._interpolated_modes = [
            interpolate_mode_to_grid(mode, first, second)
            for mode in self.config.modes
        ]

    def inject_fields(
        self,
        fields: ElectromagneticFields,
        time: float,
        dt: float,
        mode_amplitudes: Optional[list[complex]] = None,
    ) -> None:
        """
        Inject mode fields into the simulation.

        For each mode, adds ``Re(amplitude * exp(i * 2*pi*f * time))`` times
        the mode's tangential E and normal H patterns to the fields at the
        port plane. Does nothing if the port is disabled or not a source.

        Parameters
        ----------
        fields : ElectromagneticFields
            Field arrays to update in place.
        time : float
            Current simulation time.
        dt : float
            Time step (currently unused).
        mode_amplitudes : List[complex], optional
            Complex amplitudes for each mode. If None, uses unit amplitude.
            If shorter than the mode list, only the leading modes are driven.
        """
        if not self.enabled or not self.config.inject:
            return

        if mode_amplitudes is None:
            mode_amplitudes = [1.0] * len(self._interpolated_modes)

        for mode, amplitude in zip(self._interpolated_modes, mode_amplitudes):
            omega = 2 * np.pi * mode.frequency
            phase = omega * time
            # Time-domain injection uses the real part of the phasor.
            amp_real = (amplitude * np.exp(1j * phase)).real
            self._add_mode_to_fields(fields, mode, amp_real)

    def _plane_index(self) -> tuple:
        """Index tuple selecting the port plane in a 3-D field array."""
        region = self._port_region
        if self.axis == "z":
            return (region["x_slice"], region["y_slice"], region.get("z_idx", 0))
        if self.axis == "x":
            return (region["x_idx"], region["y_slice"], region["z_slice"])
        return (region["x_slice"], region["y_idx"], region["z_slice"])

    def _add_mode_to_fields(
        self,
        fields: ElectromagneticFields,
        mode: WaveguideMode,
        amplitude: float,
    ) -> None:
        """
        Add mode field pattern to simulation fields.

        Adds ``amplitude * component.real`` for the two tangential E
        components and the normal H component of the mode. Components the
        ``fields`` object does not have are skipped.

        Parameters
        ----------
        fields : ElectromagneticFields
            Field arrays to update in place.
        mode : WaveguideMode
            Mode with field patterns shaped like the port plane.
        amplitude : float
            Real-valued amplitude factor.
        """
        index = self._plane_index()
        for component in self._INJECT_COMPONENTS[self.axis]:
            if hasattr(fields, component):
                getattr(fields, component)[index] += (
                    amplitude * getattr(mode, component).real
                )

    def _transverse_spacing(self) -> tuple[float, float]:
        """
        Grid spacings along the two transverse directions of the port plane.

        For 2-D grids the (single-layer) z direction has no meaningful
        spacing; the other transverse spacing is reused in its place.
        """
        grid = self._grid
        if self.axis == "x":
            return grid.dy, (grid.dz if grid.is_3d else grid.dy)
        if self.axis == "y":
            return grid.dx, (grid.dz if grid.is_3d else grid.dx)
        return grid.dx, grid.dy

    def extract_mode_coefficients(
        self,
        fields: ElectromagneticFields,
        time: float,
    ) -> list[complex]:
        """
        Extract mode coefficients from simulation fields.

        Uses overlap integrals to decompose fields into mode amplitudes and
        appends the result (and ``time``) to the port's history.

        Parameters
        ----------
        fields : ElectromagneticFields
            Current simulation fields.
        time : float
            Current time.

        Returns
        -------
        List[complex]
            Mode coefficients for each port mode (zeros, not recorded, when
            the port is disabled).

        Raises
        ------
        RuntimeError
            If the port is enabled but ``initialize`` has not been called.
        """
        if not self.enabled:
            return [0.0] * len(self._interpolated_modes)
        if self._grid is None:
            raise RuntimeError("Port must be initialized first")

        Ex, Ey, Ez, Hx, Hy, Hz = self._extract_field_slice(fields)
        d_first, d_second = self._transverse_spacing()

        coefficients = [
            compute_mode_overlap(
                Ex,
                Ey,
                Ez,
                Hx,
                Hy,
                Hz,
                mode,
                direction=self.axis,
                dx=d_first,
                dy=d_second,
            )
            for mode in self._interpolated_modes
        ]

        for mode_idx, coeff in enumerate(coefficients):
            self._mode_coefficients[mode_idx].append(coeff)
        self._time_points.append(time)

        return coefficients

    def _extract_field_slice(
        self,
        fields: ElectromagneticFields,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Extract field components at port plane.

        Returns
        -------
        Tuple of arrays (Ex, Ey, Ez, Hx, Hy, Hz) at port location. A
        component missing from ``fields`` is returned as ``np.zeros((1, 1))``.
        """
        index = self._plane_index()
        return tuple(
            getattr(fields, component)[index]
            if hasattr(fields, component)
            else np.zeros((1, 1))
            for component in self._FIELD_COMPONENTS
        )

    def get_mode_coefficient(self, mode_index: int) -> list[complex]:
        """
        Get time series of mode coefficient.

        Parameters
        ----------
        mode_index : int
            Mode index (``KeyError`` if out of range).

        Returns
        -------
        List[complex]
            Mode coefficient vs time (the internal list, not a copy).
        """
        return self._mode_coefficients[mode_index]

    def get_time_points(self) -> list[float]:
        """
        Get recorded time points.

        Returns
        -------
        List[float]
            Time points where coefficients were recorded (the internal list,
            not a copy).
        """
        return self._time_points