# Source code for prismo.backends.metal_backend
"""
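Metal backend for GPU computations on macOS.

This module implements the Backend interface on top of Metal buffers for
Apple Silicon and Intel Macs. Array data is stored in Metal buffers, but the
mathematical, FFT, and linear-algebra operations currently fall back to NumPy
on the host; Metal compute shaders are not implemented yet.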
"""
import platform
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
from .base import Backend
if TYPE_CHECKING:
    # Lets static type checkers resolve MTLBuffer even though the runtime
    # fallback below may rebind it to None.
    from Metal import MTLBuffer

try:
    import Metal
    from Metal import (
        MTLBuffer,
        MTLResourceStorageModePrivate,
        MTLResourceStorageModeShared,
    )

    METAL_AVAILABLE = True
except ImportError:
    # The Python `Metal` bindings cannot be imported: record that and bind
    # placeholders so this module still imports. MetalBackend.__init__ checks
    # METAL_AVAILABLE and refuses to construct in this case.
    METAL_AVAILABLE = False
    Metal = None
    # NOTE(review): MTLDevice is only bound on this fallback path and is not
    # used in this module; kept in case external code imports it.
    MTLDevice = None
    MTLBuffer = None  # type: ignore
class MetalBackend(Backend):
    """
    Metal-based backend for GPU computations on macOS.

    Arrays are stored in Metal buffers. Operations currently copy buffer
    contents to host memory, compute with NumPy, and upload the result into a
    new Metal buffer; no Metal compute shaders are used yet.
    Requires macOS and the Metal framework bindings to be available.

    Parameters
    ----------
    device_id : int, optional
        Index into ``Metal.MTLCopyAllDevices()`` of the device to use.
        Default is 0.
    use_unified_memory : bool, optional
        Whether to allocate buffers with shared (CPU-visible) storage mode.
        When False, buffers use private storage and every host transfer goes
        through a shared staging buffer plus a blit copy. Default is True.

    Notes
    -----
    Buffers carry no shape or dtype metadata. Operations that read a buffer
    back interpret it as a flat 1-D ``float64`` array (``bool`` for the
    ``where`` condition), so buffers holding other dtypes are reinterpreted
    byte-for-byte.
    """

    def __init__(self, device_id: int = 0, use_unified_memory: bool = True):
        if not METAL_AVAILABLE:
            raise RuntimeError(
                "Metal is not available. This backend requires macOS with Metal framework."
            )
        if platform.system() != "Darwin":
            raise RuntimeError("Metal backend requires macOS")

        self.device_id = device_id
        self.use_unified_memory = use_unified_memory

        # Get Metal device. Negative IDs are rejected explicitly: Python
        # indexing would otherwise silently pick a device from the end.
        devices = Metal.MTLCopyAllDevices()
        if not 0 <= device_id < len(devices):
            raise ValueError(
                f"Device ID {device_id} not available. Found {len(devices)} devices."
            )
        self.device = devices[device_id]
        self.device_name = self.device.name()

        # Create command queue (used for blit transfers with private storage)
        self.command_queue = self.device.newCommandQueue()

        # Storage mode for buffers
        self.storage_mode = (
            MTLResourceStorageModeShared
            if use_unified_memory
            else MTLResourceStorageModePrivate
        )

        # Cache for compiled compute pipelines (not populated yet)
        self._pipeline_cache = {}
        # Memory pool for buffer reuse (not populated yet)
        self._buffer_pool = {}

    @property
    def name(self) -> str:
        return "metal"

    @property
    def is_gpu(self) -> bool:
        return True

    # ------------------------------------------------------------------
    # Buffer allocation and host <-> device transfer
    # ------------------------------------------------------------------

    def _get_buffer(
        self, size: int, dtype: Any = None, storage_mode: Any = None
    ) -> MTLBuffer:
        """
        Allocate a new Metal buffer of ``size`` bytes.

        Parameters
        ----------
        size : int
            Buffer length in bytes.
        dtype : optional
            Accepted for interface compatibility; Metal buffers are untyped
            bytes, so it does not affect the allocation.
        storage_mode : optional
            Metal resource options to use instead of ``self.storage_mode``.

        Raises
        ------
        RuntimeError
            If the device returns no buffer.
        """
        # No pooling yet: every call allocates a fresh buffer.
        if storage_mode is None:
            storage_mode = self.storage_mode
        buffer = self.device.newBufferWithLength_options_(size, storage_mode)
        if buffer is None:
            raise RuntimeError(f"Failed to allocate a Metal buffer of {size} bytes")
        return buffer

    def _blit(self, source: MTLBuffer, destination: MTLBuffer, nbytes: int) -> None:
        """Copy ``nbytes`` from ``source`` to ``destination`` on the GPU and wait."""
        # NOTE(review): Apple documents that on macOS blit offsets and size
        # must be multiples of 4; confirm behavior for odd-sized arrays.
        command_buffer = self.command_queue.commandBuffer()
        encoder = command_buffer.blitCommandEncoder()
        encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(
            source, 0, destination, 0, nbytes
        )
        encoder.endEncoding()
        command_buffer.commit()
        # Block so callers can immediately read or reuse either buffer.
        command_buffer.waitUntilCompleted()

    def _numpy_to_metal_buffer(self, array: np.ndarray) -> MTLBuffer:
        """Copy the bytes of ``array`` (C order) into a new Metal buffer."""
        payload = array.tobytes()
        nbytes = array.nbytes
        buffer = self._get_buffer(nbytes, array.dtype)
        if self.use_unified_memory:
            # Shared storage is CPU-visible: write straight into the buffer.
            buffer.contents().as_buffer(nbytes)[:] = payload
        else:
            # Private storage is not CPU-accessible, so stage the bytes in a
            # shared buffer and let the GPU copy them across.
            staging = self._get_buffer(
                nbytes, storage_mode=MTLResourceStorageModeShared
            )
            staging.contents().as_buffer(nbytes)[:] = payload
            self._blit(staging, buffer, nbytes)
        return buffer

    def _metal_buffer_to_numpy(
        self, buffer: MTLBuffer, shape: tuple, dtype: Any
    ) -> np.ndarray:
        """
        Copy the whole contents of ``buffer`` into a new NumPy array.

        The bytes are interpreted as ``dtype`` and reshaped to ``shape``
        (``(-1,)`` reads every element as a flat array).
        """
        nbytes = buffer.length()
        if self.use_unified_memory:
            host_visible = buffer
        else:
            # Private storage: copy into a CPU-visible staging buffer first.
            host_visible = self._get_buffer(
                nbytes, storage_mode=MTLResourceStorageModeShared
            )
            self._blit(buffer, host_visible, nbytes)
        view = np.frombuffer(host_visible.contents().as_buffer(nbytes), dtype=dtype)
        # Copy so the result does not alias Metal-owned memory.
        return view.reshape(shape).copy()

    def _to_host(self, value: Any, dtype: Any = np.float64) -> Any:
        """Return a Metal buffer as a flat NumPy array of ``dtype``; pass others through."""
        if isinstance(value, MTLBuffer):
            # Shape is not tracked for buffers, so read every element as 1-D.
            return self._metal_buffer_to_numpy(value, (-1,), dtype)
        return value

    def _apply(self, func: Any, array: Any, **kwargs: Any) -> Any:
        """
        Evaluate ``func(array, **kwargs)`` with NumPy.

        A Metal-buffer input is read back to the host and the result is
        uploaded into a new Metal buffer; any other input is passed to
        ``func`` unchanged and its result returned as-is.
        """
        if isinstance(array, MTLBuffer):
            result = func(self._to_host(array), **kwargs)
            return self._numpy_to_metal_buffer(np.asarray(result))
        return func(array, **kwargs)

    def _apply_binary(self, func: Any, a: Any, b: Any) -> Any:
        """Two-operand variant of ``_apply``: a buffer result if either input is a buffer."""
        if isinstance(a, MTLBuffer) or isinstance(b, MTLBuffer):
            result = func(self._to_host(a), self._to_host(b))
            return self._numpy_to_metal_buffer(np.asarray(result))
        return func(a, b)

    # ------------------------------------------------------------------
    # Array creation and conversion
    # ------------------------------------------------------------------

    def zeros(self, shape: tuple[int, ...], dtype: Any = None) -> Any:
        """Create a Metal buffer holding an array of zeros (default float64)."""
        if dtype is None:
            dtype = np.float64
        return self._numpy_to_metal_buffer(np.zeros(shape, dtype=dtype))

    def ones(self, shape: tuple[int, ...], dtype: Any = None) -> Any:
        """Create a Metal buffer holding an array of ones (default float64)."""
        if dtype is None:
            dtype = np.float64
        return self._numpy_to_metal_buffer(np.ones(shape, dtype=dtype))

    def empty(self, shape: tuple[int, ...], dtype: Any = None) -> Any:
        """Create a Metal buffer holding an uninitialized array (default float64)."""
        if dtype is None:
            dtype = np.float64
        return self._numpy_to_metal_buffer(np.empty(shape, dtype=dtype))

    def array(self, data: Any, dtype: Any = None) -> Any:
        """Copy ``data`` (converted to ``dtype`` when given) into a new Metal buffer."""
        if isinstance(data, np.ndarray):
            if dtype is not None and data.dtype != dtype:
                data = data.astype(dtype)
            return self._numpy_to_metal_buffer(data)
        return self._numpy_to_metal_buffer(np.array(data, dtype=dtype))

    def asarray(self, data: Any, dtype: Any = None) -> Any:
        """Convert data to a backend array; currently always copies via ``array``."""
        return self.array(data, dtype)

    def to_numpy(self, array: Any) -> np.ndarray:
        """
        Return ``array`` as a NumPy array.

        Raises
        ------
        NotImplementedError
            For Metal buffers, whose shape and dtype are not tracked.
        TypeError
            For any other non-ndarray input.
        """
        if isinstance(array, MTLBuffer):
            raise NotImplementedError(
                "Buffer to NumPy conversion requires metadata tracking"
            )
        if isinstance(array, np.ndarray):
            return array
        raise TypeError(f"Cannot convert {type(array)} to NumPy array")

    def copy(self, array: Any) -> Any:
        """Return a new Metal buffer holding a copy of ``array``'s bytes."""
        if isinstance(array, MTLBuffer):
            nbytes = array.length()
            new_buffer = self._get_buffer(nbytes)
            if self.use_unified_memory:
                new_buffer.contents().as_buffer(nbytes)[
                    :
                ] = array.contents().as_buffer(nbytes)
            else:
                # Private buffers are not CPU-accessible; copy on the GPU.
                self._blit(array, new_buffer, nbytes)
            return new_buffer
        if isinstance(array, np.ndarray):
            # tobytes() inside the upload already copies the data.
            return self._numpy_to_metal_buffer(array)
        raise TypeError(f"Cannot copy {type(array)}")

    # ------------------------------------------------------------------
    # Mathematical operations (NumPy fallbacks; Metal shaders not yet written)
    # ------------------------------------------------------------------

    def sqrt(self, array: Any) -> Any:
        """Element-wise square root."""
        return self._apply(np.sqrt, array)

    def exp(self, array: Any) -> Any:
        """Element-wise exponential."""
        return self._apply(np.exp, array)

    def sin(self, array: Any) -> Any:
        """Element-wise sine."""
        return self._apply(np.sin, array)

    def cos(self, array: Any) -> Any:
        """Element-wise cosine."""
        return self._apply(np.cos, array)

    def abs(self, array: Any) -> Any:
        """Element-wise absolute value."""
        return self._apply(np.abs, array)

    def sum(
        self, array: Any, axis: Optional[Union[int, tuple[int, ...]]] = None
    ) -> Any:
        """Sum array elements (buffers are reduced as flat 1-D arrays)."""
        return self._apply(np.sum, array, axis=axis)

    def max(
        self, array: Any, axis: Optional[Union[int, tuple[int, ...]]] = None
    ) -> Any:
        """Maximum of array elements (buffers are reduced as flat 1-D arrays)."""
        return self._apply(np.max, array, axis=axis)

    def min(
        self, array: Any, axis: Optional[Union[int, tuple[int, ...]]] = None
    ) -> Any:
        """Minimum of array elements (buffers are reduced as flat 1-D arrays)."""
        return self._apply(np.min, array, axis=axis)

    def mean(
        self, array: Any, axis: Optional[Union[int, tuple[int, ...]]] = None
    ) -> Any:
        """Mean of array elements (buffers are reduced as flat 1-D arrays)."""
        return self._apply(np.mean, array, axis=axis)

    # ------------------------------------------------------------------
    # FFT operations
    # ------------------------------------------------------------------
    # NOTE(review): results are complex; a returned buffer read back later as
    # float64 yields interleaved real/imag parts until dtype is tracked.

    def fft(self, array: Any, axis: int = -1) -> Any:
        """1D Fast Fourier Transform."""
        return self._apply(np.fft.fft, array, axis=axis)

    def ifft(self, array: Any, axis: int = -1) -> Any:
        """1D Inverse Fast Fourier Transform."""
        return self._apply(np.fft.ifft, array, axis=axis)

    def fft2(self, array: Any) -> Any:
        """2D Fast Fourier Transform (buffer inputs are 1-D after readback)."""
        return self._apply(np.fft.fft2, array)

    def ifft2(self, array: Any) -> Any:
        """2D Inverse Fast Fourier Transform (buffer inputs are 1-D after readback)."""
        return self._apply(np.fft.ifft2, array)

    # ------------------------------------------------------------------
    # Linear algebra
    # ------------------------------------------------------------------

    def dot(self, a: Any, b: Any) -> Any:
        """Dot product of two arrays."""
        return self._apply_binary(np.dot, a, b)

    def matmul(self, a: Any, b: Any) -> Any:
        """Matrix multiplication."""
        return self._apply_binary(np.matmul, a, b)

    # ------------------------------------------------------------------
    # Indexing and slicing helpers
    # ------------------------------------------------------------------

    def where(self, condition: Any, x: Any, y: Any) -> Any:
        """
        Return elements chosen from x or y depending on condition.

        Buffer inputs are read back as flat arrays (``condition`` as bool,
        ``x``/``y`` as float64). The result is always a new Metal buffer.
        """
        result = np.where(
            self._to_host(condition, np.bool_),
            self._to_host(x),
            self._to_host(y),
        )
        return self._numpy_to_metal_buffer(result)

    # ------------------------------------------------------------------
    # Memory management
    # ------------------------------------------------------------------

    def synchronize(self) -> None:
        """Synchronize Metal device."""
        # No-op: every command buffer this backend submits (blit transfers)
        # is waited on before the submitting method returns.
        pass

    def get_memory_info(self) -> dict:
        """Return device identity, storage mode, and Metal-reported size limits."""
        # Metal doesn't report per-allocation usage like CUDA; these are the
        # limits the device exposes.
        return {
            "backend": "metal",
            "device": f"Metal:{self.device_id}",
            "device_name": self.device_name,
            "unified_memory": self.use_unified_memory,
            "storage_mode": "shared" if self.use_unified_memory else "private",
            "max_buffer_size": self.device.maxBufferLength(),
            "recommended_max_working_set_size": self.device.recommendedMaxWorkingSetSize(),
        }

    # ------------------------------------------------------------------
    # Type information
    # ------------------------------------------------------------------

    @property
    def float32(self) -> Any:
        return np.float32

    @property
    def float64(self) -> Any:
        return np.float64

    @property
    def complex64(self) -> Any:
        return np.complex64

    @property
    def complex128(self) -> Any:
        return np.complex128

    @property
    def int32(self) -> Any:
        return np.int32

    @property
    def int64(self) -> Any:
        return np.int64

    @property
    def pi(self) -> float:
        return np.pi

    def __repr__(self) -> str:
        """String representation."""
        return f"MetalBackend(device={self.device_id}, name='{self.device_name}', unified_memory={self.use_unified_memory})"