# Source code for rlaopt.linalg.preconditioners.preconditioner

"""Abstract base classes for preconditioners and configurations."""

from __future__ import annotations

from abc import ABC, abstractmethod

import torch
from linops import LinearOperator
from pydantic import BaseModel, ConfigDict


class PreconditionerConfig(BaseModel):
    """Common base for every preconditioner configuration.

    Concrete configurations subclass this model and declare their own fields.
    Because ``extra`` is set to ``"forbid"``, keyword arguments that do not
    match a declared field are rejected when the model is validated.
    """

    @property
    def is_identity(self) -> bool:
        """Report whether this configuration describes the identity preconditioner.

        Solvers branch on this property to skip preconditioning work. The base
        implementation always answers ``False``.
        """
        return False

    # Pydantic model settings: unknown fields are an error rather than ignored.
    model_config = ConfigDict(extra="forbid")
class Preconditioner(ABC):
    """Interface shared by every concrete preconditioner.

    A subclass supplies three hooks: ``_update`` (rebuild internal state from a
    matrix), ``_matmul_impl`` (apply the preconditioner) and
    ``_inverse_matmul_impl`` (apply its inverse). This base class validates
    inputs and exposes the hooks through ``P @ x`` and ``P.inv @ x``.
    """

    def __init__(self, config: PreconditionerConfig):
        """Store the configuration this preconditioner was built from.

        Args:
            config (PreconditionerConfig): Settings for the preconditioner.
        """
        # Stored as-is; nothing in this base class reads it again.
        self._config = config

    @property
    def inv(self) -> InvPreconditioner:
        """Return a handle whose ``@`` operator applies the inverse.

        A fresh ``InvPreconditioner`` wrapping ``self`` is built on every
        access. It only keeps a reference back to this preconditioner.

        Returns:
            InvPreconditioner: Wrapper that routes ``@`` to the inverse.
        """
        return InvPreconditioner(preconditioner=self)

    def __matmul__(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``P @ x`` after validating ``x``.

        Args:
            x (torch.Tensor): A 1D or 2D tensor.

        Returns:
            torch.Tensor: The value produced by ``_matmul_impl`` for ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
            ValueError: If ``x`` is neither 1D nor 2D.
        """
        _is_torch_tensor_1d_2d(x)
        product = self._matmul_impl(x)
        return product

    def _inverse_matmul(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``P^{-1} @ x`` after validating ``x``.

        This is the entry point used by ``InvPreconditioner.__matmul__``.

        Args:
            x (torch.Tensor): A 1D or 2D tensor.

        Returns:
            torch.Tensor: The value produced by ``_inverse_matmul_impl`` for ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
            ValueError: If ``x`` is neither 1D nor 2D.
        """
        _is_torch_tensor_1d_2d(x)
        solved = self._inverse_matmul_impl(x)
        return solved

    @abstractmethod
    def _update(self, A: torch.Tensor | LinearOperator, dtype: torch.dtype):
        """Recompute the preconditioner's internal state from ``A``.

        Args:
            A (torch.Tensor | LinearOperator): Matrix the preconditioner is
                built for.
            dtype (torch.dtype): Data type to use for the computation.
        """

    @abstractmethod
    def _matmul_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the preconditioner to ``x``.

        ``__matmul__`` calls this only after ``x`` has passed validation.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The preconditioner applied to ``x``.
        """

    @abstractmethod
    def _inverse_matmul_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the inverse of the preconditioner to ``x``.

        ``_inverse_matmul`` calls this only after ``x`` has passed validation.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The inverse preconditioner applied to ``x``.
        """


class InvPreconditioner:
    """Thin adapter that exposes a preconditioner's inverse via ``@``.

    ``Preconditioner.inv`` returns one of these so that ``P.inv @ x`` reads the
    same way as ``P @ x``.

    Attributes:
        preconditioner (Preconditioner): The wrapped preconditioner.
    """

    def __init__(self, preconditioner: Preconditioner):
        """Wrap ``preconditioner``.

        Args:
            preconditioner (Preconditioner): Preconditioner whose inverse is
                exposed.
        """
        self.preconditioner = preconditioner

    def __matmul__(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``P^{-1} @ x`` by delegating to the wrapped preconditioner.

        Input validation happens in ``Preconditioner._inverse_matmul``. Any
        ``TypeError`` or ``ValueError`` it raises propagates unchanged.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The inverse preconditioner applied to ``x``.
        """
        wrapped = self.preconditioner
        return wrapped._inverse_matmul(x)


def _is_torch_tensor_1d_2d(tensor: torch.Tensor):
    """Validate that ``tensor`` is a ``torch.Tensor`` of rank 1 or 2.

    Despite the ``is_`` prefix, this returns ``None`` and signals failure by
    raising.

    Args:
        tensor: Object to validate.

    Raises:
        TypeError: If ``tensor`` is not a ``torch.Tensor``.
        ValueError: If ``tensor.ndim`` is not 1 or 2.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("Input must be a torch tensor.")
    # ``ndim`` is a non-negative int, so this range test matches ``in (1, 2)``.
    if not 1 <= tensor.ndim <= 2:
        raise ValueError("Input tensor must be 1D or 2D.")