Source code for rlaopt.linalg.lin_sys
"""LinSys module for positive-definite linear systems."""
from warnings import warn
import torch
from linops import LinearOperator, aslinearoperator
[docs]
class LinSys(torch.nn.Module):
    """Module for regularized linear systems (A + reg * I)w = B.

    ``B`` and ``reg`` are registered as buffers and the current iterate ``w``
    is a ``torch.nn.Parameter``. ``B`` and ``w`` are always stored as 2D
    tensors of shape (N, K); a 1D right-hand side of shape (N,) is promoted
    to (N, 1).
    """

    def __init__(
        self,
        A: torch.Tensor | LinearOperator,
        B: torch.Tensor,
        reg: float | torch.Tensor = 0.0,
        w: torch.Tensor | None = None,
    ):
        """Initialize LinSys module.

        Args:
            A (torch.Tensor | LinearOperator): Square matrix defining
                the linear system.
            B (torch.Tensor): Right-hand side of the linear system. Must be 1D or 2D.
                If 1D with shape (N,), it is automatically resized
                to 2D with shape (N, 1).
            reg (float | torch.Tensor): Non-negative regularization parameter.
                Defaults to 0.0. Python ints are also accepted. Non-tensor
                values are converted to a 0-d tensor with B's device and dtype.
                Using a tensor allows for differentiation with respect to reg.
            w (torch.Tensor | None): Initial guess for the solution. Must have
                the same shape as B. Defaults to None (all zeros).

        Raises:
            TypeError: If A, B, reg or w has an unsupported type.
            ValueError: If shapes are inconsistent or reg is negative or NaN.
        """
        super().__init__()
        LinSys._check_inputs(A, B, reg, w)
        if w is None:
            w = torch.zeros_like(B)
        # Resize B to 2D for consistent processing.
        # When B is resized, w must be resized too (their shapes match).
        if B.ndim == 1:
            B = B.unsqueeze(-1)
            w = w.unsqueeze(-1)
        # Turn a plain Python number (int or float) into a 0-d tensor so it
        # can be registered as a buffer next to B.
        if not torch.is_tensor(reg):
            reg = torch.tensor(reg, device=B.device, dtype=B.dtype)
        # NOTE(review): A is stored as a plain attribute, not a buffer. Unless
        # LinearOperator is itself an nn.Module, Module.to()/cuda() will
        # presumably not move it. Confirm against linops.
        self.A = aslinearoperator(A)  # Make A a LinearOperator
        self.register_buffer("B", B)
        self.register_buffer("reg", reg)
        self.w = torch.nn.Parameter(w)

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        """Apply the linear operator (A + reg * I) to tensor v.

        Args:
            v (torch.Tensor): Input tensor. A 1D input triggers a warning and
                is unsqueezed to shape (N, 1).

        Returns:
            torch.Tensor: Result of applying the linear operator to v
                (2D, one column per column of v).
        """
        if v.ndim == 1:
            warn("Input tensor v is 1D. The input tensor will be unsqueezed to 2D.")
            v = v.unsqueeze(-1)
        return self.A @ v + self.reg * v

    def compute_residual(self, v: torch.Tensor) -> torch.Tensor:
        """Compute the residual B - (A + reg * I) v for a given tensor v.

        Args:
            v (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Residual of the linear system (2D, same layout as B).
        """
        return self.B - self.forward(v)

    def compute_residual_norm(
        self, v: torch.Tensor, relative: bool = False
    ) -> torch.Tensor:
        """Compute the column-wise residual 2-norm for a given tensor v.

        Args:
            v (torch.Tensor): Input tensor.
            relative (bool): If True, divide each column's residual norm by
                the norm of the matching column of B. A zero column in B
                gives inf/nan. Defaults to False.

        Returns:
            torch.Tensor: Residual norm of each column (norm over dim 0).
        """
        residual = self.compute_residual(v)
        res_norm = torch.linalg.norm(residual, dim=0, ord=2)
        if relative:
            # Divide out of place. The norm output is saved by autograd for
            # the backward pass, so an in-place `/=` would make backward
            # fail whenever v or reg requires grad.
            res_norm = res_norm / self.rhs_norm
        return res_norm

    @property
    def device(self) -> torch.device:
        """Get the device of the LinSys module.

        Returns:
            torch.device: Device of the buffer B.
        """
        return self.B.device

    @property
    def dtype(self) -> torch.dtype:
        """Get the data type of the LinSys module.

        Returns:
            torch.dtype: Data type of the buffer B.
        """
        return self.B.dtype

    @property
    def rhs_norm(self) -> torch.Tensor:
        """Get the column-wise 2-norm of the right-hand side B.

        Returns:
            torch.Tensor: Norm of each column of B (norm over dim 0).
        """
        return torch.linalg.norm(self.B, dim=0, ord=2)

    @staticmethod
    def _check_inputs(
        A: torch.Tensor | LinearOperator,
        B: torch.Tensor,
        reg: float | torch.Tensor,
        w: torch.Tensor | None,
    ):
        """Validate the constructor arguments.

        Raises:
            TypeError: If A, B, reg or w has an unsupported type.
            ValueError: If A is not square, B is not 1D/2D or does not match
                A's size, reg is not a non-negative scalar (NaN is rejected),
                or w does not have B's shape.
        """
        if not (torch.is_tensor(A) or isinstance(A, LinearOperator)):
            raise TypeError(
                "A must be a torch.Tensor or LinearOperator, "
                f"but received {type(A).__name__}."
            )
        if not torch.is_tensor(B):
            raise TypeError(
                f"B must be a torch.Tensor, but received {type(B).__name__}."
            )
        if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
            raise ValueError("A must be a square matrix or square linear operator.")
        if B.ndim not in (1, 2):
            raise ValueError(
                f"B must be a 1D or 2D tensor, but received a {B.ndim}D tensor."
            )
        if B.shape[0] != A.shape[0]:
            raise ValueError(
                "B must be a tensor whose first dimension matches A's size."
            )
        # bool is a subclass of int, so it is rejected explicitly.
        if isinstance(reg, bool) or not isinstance(reg, (int, float, torch.Tensor)):
            raise TypeError(
                f"reg must be a float or torch.Tensor, "
                f"but received {type(reg).__name__}."
            )
        if isinstance(reg, (int, float)):
            # `not reg >= 0` also rejects NaN, which `reg < 0` would accept.
            if not reg >= 0:
                raise ValueError(
                    f"reg must be a non-negative float, but received {reg}."
                )
        else:
            if reg.ndim != 0:
                raise ValueError("reg tensor must be a scalar (0-dimensional).")
            if not (reg >= 0).item():
                raise ValueError(
                    f"reg tensor must be non-negative, but received {reg.item()}."
                )
        if w is not None:
            if not torch.is_tensor(w):
                raise TypeError(
                    f"w must be a torch.Tensor, but received {type(w).__name__}."
                )
            if w.shape != B.shape:
                raise ValueError("w must have the same shape as B.")