"""
Mode port boundary conditions for waveguide simulations.
This module implements mode ports that can inject and extract waveguide modes
at simulation boundaries, enabling accurate S-parameter calculations.
"""
from dataclasses import dataclass
from typing import Literal, Optional
import numpy as np
from prismo.core.fields import ElectromagneticFields
from prismo.core.grid import YeeGrid
from prismo.modes.solver import WaveguideMode
from prismo.utils.mode_matching import (
compute_mode_overlap,
interpolate_mode_to_grid,
)
@dataclass
class ModePortConfig:
    """
    Configuration for a mode port.

    Attributes
    ----------
    center : tuple[float, float, float]
        Port center position.
    size : tuple[float, float, float]
        Port size (transverse dimensions).
    direction : str
        Port normal direction and orientation ('+x', '-x', '+y', '-y', '+z', '-z').
        The axis letter is accepted in either case.
    modes : list[WaveguideMode]
        Modes supported by this port.
    inject : bool
        Whether to inject modes (source) or only extract (monitor).

    Raises
    ------
    ValueError
        If ``direction`` is not a '+'/'-' sign followed by one of 'x', 'y', 'z'.
    """

    center: tuple[float, float, float]
    size: tuple[float, float, float]
    direction: Literal["+x", "-x", "+y", "-y", "+z", "-z"]
    modes: list[WaveguideMode]
    inject: bool = False

    def __post_init__(self) -> None:
        """Validate ``direction`` at construction time."""
        direction = self.direction
        # Consumers read the sign from the first character and the axis from the
        # (lower-cased) last character; any other shape would be misparsed
        # silently, so fail fast here instead.
        if (
            not isinstance(direction, str)
            or len(direction) != 2
            or direction[0] not in "+-"
            or direction[1].lower() not in "xyz"
        ):
            raise ValueError(
                "direction must be one of '+x', '-x', '+y', '-y', '+z', '-z'; "
                f"got {direction!r}"
            )
class ModePort:
    """
    Mode port for injecting and extracting waveguide modes.

    A mode port acts as both a source (injecting modes) and a monitor
    (extracting mode amplitudes) at a plane in the simulation domain.

    Parameters
    ----------
    config : ModePortConfig
        Port configuration.
    name : str, optional
        Port name for identification. Defaults to ``"ModePort_<id>"``.
    enabled : bool, optional
        Enable/disable port, default=True.

    Raises
    ------
    ValueError
        If ``config.direction`` is not a '+'/'-' sign followed by 'x', 'y' or 'z'.

    Examples
    --------
    >>> from prismo.modes.solver import ModeSolver
    >>> # Solve for waveguide modes
    >>> mode_solver = ModeSolver(wavelength=1.55e-6, x=x, y=y, epsilon=eps)
    >>> modes = mode_solver.solve(num_modes=2, mode_type='TE')
    >>>
    >>> # Create mode port
    >>> config = ModePortConfig(
    ...     center=(0.0, 0.0, 0.0),
    ...     size=(2e-6, 2e-6, 0.0),
    ...     direction='+z',
    ...     modes=modes,
    ...     inject=True,
    ... )
    >>> port = ModePort(config, name='input_port')
    """

    # A coordinate that falls short of a grid line by less than this fraction
    # of a cell still maps to that line. This absorbs floating-point noise such
    # as 0.3 / 0.1 == 2.9999999999999996, which int() truncation maps to 2.
    _INDEX_TOLERANCE = 1e-6

    # Grid spacing attribute name for each dimension (0 = x, 1 = y, 2 = z).
    _SPACING_ATTRS = ("dx", "dy", "dz")

    # Field components written into the simulation for each port normal axis:
    # the two tangential E components plus the normal H component.
    _INJECTED_COMPONENTS = {
        "x": ("Ey", "Ez", "Hx"),
        "y": ("Ex", "Ez", "Hy"),
        "z": ("Ex", "Ey", "Hz"),
    }

    # Component order returned by ``_extract_field_slice``.
    _FIELD_COMPONENTS = ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz")

    def __init__(
        self,
        config: ModePortConfig,
        name: Optional[str] = None,
        enabled: bool = True,
    ):
        """Store the configuration and parse the port direction."""
        direction = config.direction
        # The sign comes from the first character and the axis from the last.
        # Reject anything else instead of silently defaulting to a negative
        # sign or falling through to the z-axis code path.
        if (
            not isinstance(direction, str)
            or len(direction) != 2
            or direction[0] not in "+-"
            or direction[1].lower() not in "xyz"
        ):
            raise ValueError(
                "direction must be one of '+x', '-x', '+y', '-y', '+z', '-z'; "
                f"got {direction!r}"
            )

        self.config = config
        self.name = name or f"ModePort_{id(self)}"
        self.enabled = enabled

        # Grid information (set by initialize()).
        self._grid: Optional[YeeGrid] = None
        # Keys: "<normal>_idx" (int plane index), "<transverse>_slice" (slice)
        # for the two other axes, and "normal_axis" (0, 1 or 2).
        self._port_region: dict[str, object] = {}

        # Mode profiles interpolated onto the simulation grid.
        self._interpolated_modes: list[WaveguideMode] = []

        # Mode coefficient time series, keyed by mode index.
        self._mode_coefficients: dict[int, list[complex]] = {
            i: [] for i in range(len(config.modes))
        }
        self._time_points: list[float] = []

        # Parsed direction: +1/-1 orientation and lower-case axis letter.
        self.sign = +1 if direction[0] == "+" else -1
        self.axis = direction[1].lower()

    def initialize(self, grid: YeeGrid) -> None:
        """
        Initialize the mode port on the simulation grid.

        Safe to call again (e.g. with a new grid): the port region and the
        interpolated mode profiles are rebuilt from scratch.

        Parameters
        ----------
        grid : YeeGrid
            Simulation grid.

        Raises
        ------
        ValueError
            If the port plane lies before the grid origin along its normal axis.
        """
        self._grid = grid
        self._setup_port_region()
        self._interpolate_modes()

    # ------------------------------------------------------------------
    # Grid geometry helpers
    # ------------------------------------------------------------------

    def _grid_index(self, position: float, dim: int) -> int:
        """Return the index of the grid cell containing ``position`` along ``dim``."""
        grid = self._grid
        spacing = getattr(grid, self._SPACING_ATTRS[dim])
        offset = (position - grid.origin[dim]) / spacing
        return int(np.floor(offset + self._INDEX_TOLERANCE))

    def _transverse_slice(self, dim: int) -> slice:
        """Return the index range covered by the port along transverse ``dim``."""
        if dim == 2 and not self._grid.is_3d:
            # A 2D grid has a single z layer; the z extent of the port is ignored.
            return slice(0, 1)
        half = self.config.size[dim] / 2
        center = self.config.center[dim]
        # Clamp at 0 because a negative start would wrap around in NumPy indexing.
        start = max(0, self._grid_index(center - half, dim))
        stop = max(start, self._grid_index(center + half, dim))
        return slice(start, stop)

    def _transverse_coords(self, dim: int) -> np.ndarray:
        """Return physical coordinates of the port's grid points along ``dim``."""
        if dim == 2 and not self._grid.is_3d:
            return np.array([0.0])
        span = self._port_region[f"{'xyz'[dim]}_slice"]
        spacing = getattr(self._grid, self._SPACING_ATTRS[dim])
        return np.arange(span.start, span.stop) * spacing + self._grid.origin[dim]

    def _plane_index(self) -> tuple:
        """Return the (x, y, z) index tuple that selects the port plane."""
        region = self._port_region
        normal = region["normal_axis"]
        return tuple(
            region[f"{label}_idx"] if dim == normal else region[f"{label}_slice"]
            for dim, label in enumerate("xyz")
        )

    def _setup_port_region(self) -> None:
        """Compute the port plane index and transverse slices in grid coordinates."""
        if self._grid is None:
            raise RuntimeError("Port must be initialized with a grid first")

        normal = "xyz".index(self.axis)
        if normal == 2 and not self._grid.is_3d:
            # A z-normal port on a 2D grid sits on the single z layer.
            plane_idx = 0
        else:
            plane_idx = self._grid_index(self.config.center[normal], normal)
            if plane_idx < 0:
                # A negative index would silently address the far side of the grid.
                raise ValueError(
                    f"Port plane at {self.axis}={self.config.center[normal]!r} "
                    "lies before the grid origin"
                )

        region: dict[str, object] = {}
        for dim, label in enumerate("xyz"):
            if dim == normal:
                region[f"{label}_idx"] = plane_idx
            else:
                region[f"{label}_slice"] = self._transverse_slice(dim)
        region["normal_axis"] = normal
        self._port_region = region

    def _interpolate_modes(self) -> None:
        """Interpolate every configured mode profile onto the port's grid points."""
        if self._grid is None:
            raise RuntimeError("Port must be initialized first")

        normal = "xyz".index(self.axis)
        # Transverse coordinates in axis order: (y, z), (x, z) or (x, y).
        first, second = (
            self._transverse_coords(dim) for dim in range(3) if dim != normal
        )
        # Rebuild rather than append, so re-initializing does not duplicate modes.
        self._interpolated_modes = [
            interpolate_mode_to_grid(mode, first, second)
            for mode in self.config.modes
        ]

    # ------------------------------------------------------------------
    # Injection
    # ------------------------------------------------------------------

    def inject_fields(
        self,
        fields: ElectromagneticFields,
        time: float,
        dt: float,
        mode_amplitudes: Optional[list[complex]] = None,
    ) -> None:
        """
        Inject mode fields into the simulation.

        For each mode, the drive ``Re(a * exp(1j * omega * time))`` is computed,
        with ``omega = 2 * pi * mode.frequency``. That drive times the real part
        of the interpolated mode profile is added to the fields at the port
        plane. Does nothing when the port is disabled or ``config.inject`` is
        False.

        Parameters
        ----------
        fields : ElectromagneticFields
            Field arrays to update in place.
        time : float
            Current simulation time.
        dt : float
            Time step (currently unused).
        mode_amplitudes : list[complex], optional
            Complex amplitude for each mode. If None, uses unit amplitude.

        Raises
        ------
        RuntimeError
            If the port has not been initialized with a grid.
        ValueError
            If ``mode_amplitudes`` does not have one entry per mode.
        """
        if not self.enabled or not self.config.inject:
            return
        if self._grid is None:
            raise RuntimeError("Port must be initialized with a grid first")

        modes = self._interpolated_modes
        if mode_amplitudes is None:
            mode_amplitudes = [1.0] * len(modes)
        elif len(mode_amplitudes) != len(modes):
            # zip() would silently drop the surplus modes or amplitudes.
            raise ValueError(
                f"Expected {len(modes)} mode amplitudes, got {len(mode_amplitudes)}"
            )

        for mode, amplitude in zip(modes, mode_amplitudes):
            omega = 2 * np.pi * mode.frequency
            # Real part of the complex phasor gives the time-domain drive.
            drive = (amplitude * np.exp(1j * omega * time)).real
            self._add_mode_to_fields(fields, mode, drive)

    def _add_mode_to_fields(
        self,
        fields: ElectromagneticFields,
        mode: WaveguideMode,
        amplitude: float,
    ) -> None:
        """
        Add a scaled mode field pattern to the simulation fields at the port plane.

        The components written are the two tangential E components and the
        normal H component (e.g. Ex, Ey, Hz for a z-normal port). Components
        that ``fields`` does not have are skipped. All components are written
        at the same grid indices.

        NOTE(review): no half-cell Yee staggering correction is applied here;
        confirm whether interpolate_mode_to_grid already accounts for it.

        Parameters
        ----------
        fields : ElectromagneticFields
            Field arrays to update in place.
        mode : WaveguideMode
            Mode with field patterns on the port's grid points.
        amplitude : float
            Real-valued amplitude factor.
        """
        index = self._plane_index()
        for component in self._INJECTED_COMPONENTS[self.axis]:
            if hasattr(fields, component):
                target = getattr(fields, component)
                target[index] += amplitude * getattr(mode, component).real

    # ------------------------------------------------------------------
    # Extraction
    # ------------------------------------------------------------------

    def _extract_field_slice(
        self,
        fields: ElectromagneticFields,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Extract field components at the port plane.

        Returns
        -------
        tuple of np.ndarray
            (Ex, Ey, Ez, Hx, Hy, Hz) at the port location. Present components
            are views into the field arrays. Components missing from ``fields``
            are zero arrays shaped like the port plane.
        """
        index = self._plane_index()
        # Shape of the port plane: one entry per sliced (transverse) axis.
        plane_shape = tuple(
            max(0, part.stop - part.start) for part in index if isinstance(part, slice)
        )
        return tuple(
            getattr(fields, component)[index]
            if hasattr(fields, component)
            else np.zeros(plane_shape)
            for component in self._FIELD_COMPONENTS
        )

    def get_mode_coefficient(self, mode_index: int) -> list[complex]:
        """
        Get time series of mode coefficient.

        Parameters
        ----------
        mode_index : int
            Mode index.

        Returns
        -------
        list[complex]
            Mode coefficient vs time (the internal list, not a copy).
        """
        return self._mode_coefficients[mode_index]

    def get_time_points(self) -> list[float]:
        """
        Get recorded time points.

        Returns
        -------
        list[float]
            Time points where coefficients were recorded (the internal list).
        """
        return self._time_points