Source code for rlaopt.linalg.preconditioners.nystrom

"""Nyström preconditioner and configuration."""

from typing import Literal
from warnings import warn

import torch
from linops import LinearOperator
from pydantic import Field, model_validator
from typing_extensions import Self

from rlaopt.linalg.preconditioners.preconditioner import (
    Preconditioner,
    PreconditionerConfig,
)
from rlaopt.linalg.spectral_estimators import randomized_powering


class NystromConfig(PreconditionerConfig):
    """Settings that control how the Nyström preconditioner is built.

    Attributes are declared as pydantic fields; see each field's
    ``description`` for its meaning. After field validation,
    ``validate_rank_max`` fills in ``rank_max`` when it was left unset and
    rejects a ``rank_max`` smaller than ``rank_init``.
    """

    # TODO(pratik): add option for sketching method

    # Number of sketch columns used for the first approximation.
    rank_init: int = Field(
        gt=0,
        description="Initial rank of the Nyström approximation.",
    )
    # Upper bound on the sketch size; None means "same as rank_init".
    rank_max: int | None = Field(
        default=None,
        description="Maximum allowable rank.",
    )
    # Iteration budget handed to the randomized error estimator.
    num_power_iters: int = Field(
        default=10,
        gt=0,
        description="Number of power iterations for error estimation in rank adaptation.",
    )
    # Rank adaptation stops once the estimated error drops to this level.
    error_tolerance: float = Field(
        default=1e-2,
        gt=0.0,
        description="Error tolerance for rank adaptation.",
    )
    # Required (no default); always added to the preconditioner's diagonal.
    base_damping: float = Field(
        ge=0.0,
        description="Base damping parameter.",
    )
    damping_mode: Literal["adaptive", "non_adaptive"] = Field(
        default="adaptive",
        description=(
            "Damping mode: 'adaptive' adjusts based on smallest eigenvalue, "
            "'non_adaptive' uses base_damping only."
        ),
    )

    @model_validator(mode="after")
    def validate_rank_max(self) -> Self:
        """Default ``rank_max`` to ``rank_init`` and enforce their ordering.

        Returns:
            The validated model instance, with ``rank_max`` set to
            ``rank_init`` if it was ``None``.

        Raises:
            ValueError: If ``rank_max`` was provided and is smaller than
                ``rank_init``.
        """
        # Unset maximum: no rank growth beyond the initial rank.
        if self.rank_max is None:
            self.rank_max = self.rank_init
            return self

        if self.rank_max < self.rank_init:
            raise ValueError(
                f"rank_max ({self.rank_max}) must be >= rank_init ({self.rank_init})"
            )
        return self
class Nystrom(Preconditioner):
    """Nyström preconditioner implementation with adaptive rank.

    After ``_update`` the preconditioner represents the operator
    ``U @ diag(S) @ U.T + current_damping * I``, where ``U`` and ``S`` are the
    approximate eigenvectors / eigenvalues produced by a randomized Nyström
    sketch of the input matrix.
    """

    def __init__(self, config: NystromConfig):
        """Initialize the Nyström preconditioner with the given configuration.

        Args:
            config (NystromConfig): Configuration for the Nyström preconditioner.
        """
        super().__init__(config)
        # Approximate eigenvectors and eigenvalues; populated by `_update`.
        self.U = None
        self.S = None
        # Cached Cholesky factor used by the low-precision inverse; built lazily.
        self.L = None
        self.current_damping = None
        self.norm = None
        self.using_low_precision = False

    def _update(self, A: torch.Tensor | LinearOperator, dtype: torch.dtype) -> None:
        """Update the Nyström preconditioner based on the matrix A.

        The sketch starts with ``rank_init`` columns and doubles in size until
        the estimated approximation error is at most ``error_tolerance`` or the
        sketch has ``rank_max`` columns.

        Args:
            A: Matrix or linear operator to approximate. It must expose
                ``device`` and ``shape`` and support ``A @ X``.
                NOTE(review): the Cholesky of the sketched core presumably
                requires A to be symmetric positive semidefinite - confirm.
            dtype: Floating-point dtype used for the sketch. Any dtype other
                than ``torch.float64`` selects the low-precision inverse path.

        Warns:
            UserWarning: If the rank budget is exhausted while the estimated
                error is still above ``error_tolerance``.
        """
        # Unpack config
        num_cols_to_add = self._config.rank_init
        error_tolerance = self._config.error_tolerance
        num_power_iters = self._config.num_power_iters
        rank_max = self._config.rank_max
        damping_mode = self._config.damping_mode
        base_damping = self._config.base_damping

        # Recompute the flag on every update: previously it was only ever set
        # to True, so a float64 update after a low-precision one kept using the
        # low-precision inverse path.
        self.using_low_precision = dtype != torch.float64

        device = A.device
        n = A.shape[0]

        # Initialize sketching matrix and sketch (zero columns to start with)
        Omega = torch.empty((n, 0), dtype=dtype, device=device)
        Y = Omega.clone()

        # Placeholders for estimated eigenvectors and eigenvalues; overwritten
        # on the first pass through the loop.
        U = torch.empty(0, dtype=dtype, device=device)
        S = torch.empty(0, dtype=dtype, device=device)

        # Start error at infinity to enter the loop
        error = torch.inf
        final_pass = False

        # Compute Nyström approximation with adaptive rank
        while error > error_tolerance:
            # Extend the sketching matrix and the sketch with fresh columns
            Omega_new = _generate_ortho_embedding(
                dimension=n,
                sketch_size=num_cols_to_add,
                dtype=dtype,
                device=device,
            )
            Y_new = A @ Omega_new
            Omega = torch.hstack([Omega, Omega_new])
            Y = torch.hstack([Y, Y_new])

            # Compute core
            Core = Omega.T @ Y

            # Shift the diagonal for stability so the Cholesky factorization
            # does not fail on a numerically semidefinite core.
            shift = torch.finfo(Y.dtype).eps * torch.trace(Core)
            Core.diagonal().add_(shift)
            L = torch.linalg.cholesky(Core, upper=False)

            # Get eigendecomposition; the shift is removed again from the
            # eigenvalue estimates and negatives are clamped to zero.
            B = torch.linalg.solve_triangular(L, Y.T, upper=False)
            U, sigma, _ = torch.linalg.svd(B.T, full_matrices=False)
            S = torch.nn.functional.relu(sigma**2 - shift)

            # Make sure to estimate error before possibly breaking early
            error = _randomized_power_err_est(A, U, S, num_power_iters, dtype)

            if final_pass:
                break

            # Plan the next sketch size: double the rank, capped at rank_max.
            current_rank = Omega.shape[1]
            remaining = rank_max - current_rank
            if remaining <= 0:
                # Rank budget exhausted. Another pass would add zero columns
                # and merely repeat the factorizations on identical data.
                break
            if 2 * current_rank > rank_max:
                num_cols_to_add = remaining
                final_pass = True
            else:
                num_cols_to_add = current_rank

        if error > error_tolerance:
            warn(
                f"Reached maximum rank {rank_max} before achieving "
                f"desired error tolerance {error_tolerance}."
            )

        self.U = U
        self.S = S

        # Recalculate damping; S[-1] is the last (smallest, since SVD returns
        # singular values in descending order) eigenvalue estimate.
        if damping_mode == "adaptive":
            self.current_damping = base_damping + self.S[-1]
        else:
            self.current_damping = base_damping

        # Get norm of preconditioner
        self.norm = S[0] + self.current_damping

        # Reset L for inverse computations (it depends on S, U and damping)
        self.L = None

    def _matmul_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Nyström preconditioner to the input tensor x.

        Computes ``U @ (S * (U.T @ x)) + current_damping * x`` for a 1-D
        vector or a 2-D matrix whose columns are vectors.
        """
        # For 2-D input, S must scale rows of U.T @ x, hence the column shape.
        S_safe = self.S if x.ndim == 1 else self.S.unsqueeze(-1)
        return self.U @ (S_safe * (self.U.T @ x)) + self.current_damping * x

    def _inverse_matmul_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the inverse of the Nyström preconditioner to the input tensor x.

        Args:
            x: 1-D vector or 2-D matrix of column vectors.

        Returns:
            Tensor with the same number of dimensions as ``x``.
        """
        x_in = x.unsqueeze(-1) if x.ndim == 1 else x
        damping = self.current_damping
        UTx = self.U.T @ x_in

        # If we are not in double precision, we try to take a more numerically
        # stable approach that requires an additional Cholesky factorization.
        if self.using_low_precision:
            if self.L is None:
                # Woodbury-style inner matrix; factorized once and cached until
                # the next `_update`.
                # NOTE(review): S comes out of a relu and may contain exact
                # zeros, in which case S**-1 is inf - confirm this cannot
                # happen in practice or guard it.
                self.L = torch.linalg.cholesky(
                    damping * torch.diag(self.S**-1) + self.U.T @ self.U,
                )

            L_inv_UTx = torch.linalg.solve_triangular(self.L, UTx, upper=False)
            LT_inv_L_inv_UTx = torch.linalg.solve_triangular(
                self.L.T, L_inv_UTx, upper=True
            )
            x_in = 1 / damping * (x_in - self.U @ LT_inv_L_inv_UTx)
        else:
            # Closed form that relies on U having orthonormal columns.
            x_in = 1 / damping * (x_in - self.U @ UTx) + self.U @ torch.divide(
                UTx, (self.S + damping).unsqueeze(-1)
            )

        return x_in.squeeze(-1) if x.ndim == 1 else x_in


def _generate_ortho_embedding(
    dimension: int, sketch_size: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Generate an orthogonal random embedding matrix.

    Args:
        dimension: Number of rows of the embedding.
        sketch_size: Number of columns of the embedding.
        dtype: Dtype of the generated matrix.
        device: Device on which the matrix is created.

    Returns:
        The Q factor of a reduced QR factorization of a
        ``dimension x sketch_size`` standard Gaussian matrix.
    """
    gaussian = torch.randn(dimension, sketch_size, dtype=dtype, device=device)
    Omega = torch.linalg.qr(gaussian, mode="reduced")[0]
    return Omega


def _randomized_power_err_est(
    A: torch.Tensor | LinearOperator,
    U: torch.Tensor,
    S: torch.Tensor,
    num_iters: int,
    dtype: torch.dtype,
) -> float:
    """Estimate approximation error of the Nyström method.

    Runs ``randomized_powering`` on the residual operator
    ``E = A - U @ diag(S) @ U.T``.

    Args:
        A: Matrix or linear operator being approximated.
        U: Approximate eigenvectors (columns).
        S: Approximate eigenvalues, one per column of ``U``.
        num_iters: Passed to ``randomized_powering`` as ``max_iters``.
        dtype: Currently unused; kept for interface compatibility.

    Returns:
        Whatever ``randomized_powering`` returns for the residual operator.
    """

    def E_linop(v: torch.Tensor) -> torch.Tensor:
        UTv = U.T @ v
        # Scale rows of U.T @ v by S for both vector (1-D) and matrix (2-D)
        # inputs; a bare `S * UTv` would broadcast along the wrong axis in 2-D.
        scale = S if UTv.ndim == 1 else S.unsqueeze(-1)
        return A @ v - U @ (scale * UTv)

    return randomized_powering(
        E_linop, shape=A.shape, max_iters=num_iters, device=A.device
    )