# Source code for prismo.backends.base

"""
Abstract base class for computational backends.

This module defines the interface that all backends must implement,
providing array operations for FDTD computations.
"""

from abc import ABC, abstractmethod
from typing import Any, Optional, Union

import numpy as np


class Backend(ABC):
    """
    Abstract base class for computational backends.

    All backends must implement this interface to provide array operations
    for electromagnetic field computations. Every member except
    ``__repr__`` is abstract, so a concrete subclass has to override all of
    them before it can be instantiated.

    Array arguments and return values are typed ``Any`` because the concrete
    array type is chosen by the implementing backend.
    """

    # Identification

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the backend (e.g., 'numpy', 'cupy')."""

    @property
    @abstractmethod
    def is_gpu(self) -> bool:
        """Whether this backend uses GPU acceleration."""

    # Array creation and conversion

    @abstractmethod
    def zeros(self, shape: tuple[int, ...], dtype: Any = None) -> Any:
        """
        Create an array filled with zeros.

        Args:
            shape: Dimensions of the array to create.
            dtype: Element type; ``None`` leaves the choice to the backend.

        Returns:
            A backend array of the requested shape.
        """

    @abstractmethod
    def ones(self, shape: tuple[int, ...], dtype: Any = None) -> Any:
        """
        Create an array filled with ones.

        Args:
            shape: Dimensions of the array to create.
            dtype: Element type; ``None`` leaves the choice to the backend.

        Returns:
            A backend array of the requested shape.
        """

    @abstractmethod
    def empty(self, shape: tuple[int, ...], dtype: Any = None) -> Any:
        """
        Create an uninitialized array.

        Args:
            shape: Dimensions of the array to create.
            dtype: Element type; ``None`` leaves the choice to the backend.

        Returns:
            A backend array of the requested shape with unspecified contents.
        """

    @abstractmethod
    def array(self, data: Any, dtype: Any = None) -> Any:
        """
        Convert data to backend array.

        Args:
            data: Input to convert.
            dtype: Element type; ``None`` leaves the choice to the backend.

        Returns:
            A backend array holding ``data``.
        """

    @abstractmethod
    def asarray(self, data: Any, dtype: Any = None) -> Any:
        """
        Convert data to backend array (no-copy if possible).

        Args:
            data: Input to convert.
            dtype: Element type; ``None`` leaves the choice to the backend.

        Returns:
            A backend array holding ``data``.
        """

    @abstractmethod
    def to_numpy(self, array: Any) -> np.ndarray:
        """
        Convert backend array to NumPy array (for CPU).

        Args:
            array: Backend array to convert.

        Returns:
            The data as a ``numpy.ndarray``.
        """

    @abstractmethod
    def copy(self, array: Any) -> Any:
        """Create a copy of an array."""

    # Mathematical operations

    @abstractmethod
    def sqrt(self, array: Any) -> Any:
        """Element-wise square root."""

    @abstractmethod
    def exp(self, array: Any) -> Any:
        """Element-wise exponential."""

    @abstractmethod
    def sin(self, array: Any) -> Any:
        """Element-wise sine."""

    @abstractmethod
    def cos(self, array: Any) -> Any:
        """Element-wise cosine."""

    @abstractmethod
    def abs(self, array: Any) -> Any:
        """Element-wise absolute value."""

    @abstractmethod
    def sum(
        self, array: Any, axis: Optional[Union[int, tuple[int, ...]]] = None
    ) -> Any:
        """
        Sum array elements.

        Args:
            array: Backend array to reduce.
            axis: Axis or axes to reduce over. NOTE(review): ``None``
                presumably reduces over all elements -- confirm against
                the concrete backends.
        """

    @abstractmethod
    def max(
        self, array: Any, axis: Optional[Union[int, tuple[int, ...]]] = None
    ) -> Any:
        """
        Maximum of array elements.

        Args:
            array: Backend array to reduce.
            axis: Axis or axes to reduce over. NOTE(review): ``None``
                presumably reduces over all elements -- confirm against
                the concrete backends.
        """

    @abstractmethod
    def min(
        self, array: Any, axis: Optional[Union[int, tuple[int, ...]]] = None
    ) -> Any:
        """
        Minimum of array elements.

        Args:
            array: Backend array to reduce.
            axis: Axis or axes to reduce over. NOTE(review): ``None``
                presumably reduces over all elements -- confirm against
                the concrete backends.
        """

    @abstractmethod
    def mean(
        self, array: Any, axis: Optional[Union[int, tuple[int, ...]]] = None
    ) -> Any:
        """
        Mean of array elements.

        Args:
            array: Backend array to reduce.
            axis: Axis or axes to reduce over. NOTE(review): ``None``
                presumably reduces over all elements -- confirm against
                the concrete backends.
        """

    # FFT operations

    @abstractmethod
    def fft(self, array: Any, axis: int = -1) -> Any:
        """
        1D Fast Fourier Transform.

        Args:
            array: Backend array to transform.
            axis: Axis to transform along (defaults to ``-1``).
        """

    @abstractmethod
    def ifft(self, array: Any, axis: int = -1) -> Any:
        """
        1D Inverse Fast Fourier Transform.

        Args:
            array: Backend array to transform.
            axis: Axis to transform along (defaults to ``-1``).
        """

    @abstractmethod
    def fft2(self, array: Any) -> Any:
        """2D Fast Fourier Transform."""

    @abstractmethod
    def ifft2(self, array: Any) -> Any:
        """2D Inverse Fast Fourier Transform."""

    # Linear algebra

    @abstractmethod
    def dot(self, a: Any, b: Any) -> Any:
        """Dot product of two arrays."""

    @abstractmethod
    def matmul(self, a: Any, b: Any) -> Any:
        """Matrix multiplication."""

    # Indexing and slicing helpers

    @abstractmethod
    def where(self, condition: Any, x: Any, y: Any) -> Any:
        """Return elements chosen from x or y depending on condition."""

    # Memory management

    @abstractmethod
    def synchronize(self) -> None:
        """Synchronize device (for GPU backends)."""

    @abstractmethod
    def get_memory_info(self) -> dict:
        """
        Get memory usage information.

        Returns:
            A dictionary; its keys are defined by the implementing backend.
        """

    # Type information

    @property
    @abstractmethod
    def float32(self) -> Any:
        """32-bit float dtype."""

    @property
    @abstractmethod
    def float64(self) -> Any:
        """64-bit float dtype."""

    @property
    @abstractmethod
    def complex64(self) -> Any:
        """64-bit complex dtype."""

    @property
    @abstractmethod
    def complex128(self) -> Any:
        """128-bit complex dtype."""

    @property
    @abstractmethod
    def int32(self) -> Any:
        """32-bit integer dtype."""

    @property
    @abstractmethod
    def int64(self) -> Any:
        """64-bit integer dtype."""

    # Constants

    @property
    @abstractmethod
    def pi(self) -> float:
        """Value of pi."""

    def __repr__(self) -> str:
        """
        String representation.

        Returns:
            ``ClassName(name='<name>', device=<GPU|CPU>)``, built from the
            concrete class name and the ``name`` / ``is_gpu`` properties.
        """
        device = "GPU" if self.is_gpu else "CPU"
        return f"{self.__class__.__name__}(name='{self.name}', device={device})"