"""Module for Variable class."""
import torch
from typing_extensions import Self
from rlaopt.expression.expression import Expression
from rlaopt.expression.tree import ExprTree
from rlaopt.settings import VAR_PREFIX
from rlaopt.utils.counter import get_id
class Variable(Expression):
    """Leaf optimization variable wrapping a torch.nn.Parameter.

    Variable represents a trainable parameter in an optimization problem.
    It extends Expression to provide automatic differentiation, device
    management, and integration with PyTorch's optimization ecosystem.

    Variables register their parameters under the variable's own name (not
    just 'value') to improve debugging and state_dict readability.

    Args:
        size_or_tensor: Either a size (an ``int`` such as ``5`` or a tuple
            such as ``(5,)`` / ``(5, 10)``), in which case the variable is
            zero-initialized, or an existing ``torch.Tensor`` to wrap.
        requires_grad: Whether the parameter should track gradients.
        var_id: Optional custom ID for the variable. A fresh ID is obtained
            from ``get_id()`` when omitted.
        name: Optional custom name for the variable. Defaults to
            ``f"{VAR_PREFIX}{id}"``. The name is passed to
            ``torch.nn.Module.register_parameter``, so it must be acceptable
            there (non-empty, no ``'.'``, and not clashing with an existing
            attribute such as ``value``, ``name`` or ``id``).
        dtype: Data type for the variable tensor. Applied both when creating
            from a size and when wrapping an existing tensor.
        device: Device to place the variable on. Applied both when creating
            from a size and when wrapping an existing tensor.

    Raises:
        TypeError: If ``size_or_tensor`` is not an int, tuple or tensor, or
            if ``name`` is not a string.

    Examples:
        >>> # Create from size
        >>> x = Variable((5,), name='weights')
        >>> x.value.shape
        torch.Size([5])
        >>> # Create matrix variable
        >>> A = Variable((3, 4,), name='matrix')
        >>> A.value.shape
        torch.Size([3, 4])
        >>> # Create from existing tensor
        >>> data = torch.randn(10)
        >>> y = Variable(data, name='initialized')
    """

    def __init__(
        self,
        size_or_tensor: int | tuple[int, ...] | torch.Tensor,
        requires_grad: bool = True,
        var_id: int | None = None,
        name: str | None = None,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
    ):
        """Constructor method for Variable.

        Builds the initial tensor (zeros of the given size, or the supplied
        tensor), assigns the ID and name, then registers the tensor as a
        ``torch.nn.Parameter`` under the variable's name.

        Raises:
            TypeError: If ``size_or_tensor`` has an unsupported type or
                ``name`` is not a string.
        """
        super().__init__()
        if isinstance(size_or_tensor, torch.Tensor):
            data = size_or_tensor
            # Honor an explicit dtype/device rather than silently ignoring
            # it. ``Tensor.to`` returns the same tensor when no conversion is
            # needed, so storage stays shared with the caller's tensor then.
            if dtype is not None or device is not None:
                data = data.to(dtype=dtype, device=device)
        elif isinstance(size_or_tensor, (tuple, int)):  # Accept int or tuple
            # Normalize a bare int to a 1-D size tuple.
            size = (
                (size_or_tensor,) if isinstance(size_or_tensor, int) else size_or_tensor
            )
            data = torch.zeros(size, dtype=dtype, device=device)
        else:
            raise TypeError(
                f"size must be int, tuple[int, ...] or torch.Tensor, "
                f"got {type(size_or_tensor)}"
            )
        self._set_id_and_name(var_id, name)
        # Register parameter with variable's name for better state_dict readability
        self.register_parameter(self._name, torch.nn.Parameter(data, requires_grad))

    def _set_id_and_name(self, var_id=None, name=None):
        """Set ID and name attributes.

        Args:
            var_id: Optional custom ID (generates unique ID via ``get_id()``
                if None).
            name: Optional custom name (defaults to ``VAR_PREFIX`` followed
                by the ID if None).

        Raises:
            TypeError: If name is not a string.
        """
        # Set ID
        if var_id is None:
            self._id = get_id()
        else:
            self._id = var_id
        # Set name
        if name is None:
            self._name = f"{VAR_PREFIX}{self._id}"
        elif isinstance(name, str):
            self._name = name
        else:
            raise TypeError(f"Expected name to be a string, got {type(name)} instead.")

    @classmethod
    def like(cls, expr: Expression) -> Self:
        """Create a new Variable with same shape, dtype, and device as an expression.

        This is useful for creating auxiliary variables in decomposition methods
        that match the properties of an existing expression. The new variable
        is zero-initialized and gets a freshly generated ID and default name.

        Args:
            expr: Expression to match properties from.

        Returns:
            Self: New variable with same shape, dtype, and device as expr's value.
        """
        # Only the shape/dtype/device metadata of the result is needed, so
        # evaluate without recording an autograd graph.
        with torch.no_grad():
            value = expr.forward()
        return cls(value.shape, dtype=value.dtype, device=value.device)

    @property
    def value(self) -> torch.nn.Parameter:
        """Get the parameter value.

        Returns the underlying torch.nn.Parameter using the variable's name.
        This allows accessing parameters as x.value while storing them with
        meaningful names in the state_dict.

        Returns:
            torch.nn.Parameter: The parameter tensor.

        Examples:
            >>> x = Variable((5,), name='alpha')
            >>> x.value.data = torch.ones(5)
            >>> x.value
            Parameter containing:
            tensor([1., 1., 1., 1., 1.], requires_grad=True)
        """
        return getattr(self, self._name)

    @value.setter
    def value(self, val: torch.Tensor):
        """Set the parameter value.

        Replaces the entire parameter with a new torch.nn.Parameter wrapping
        the provided tensor. Preserves the requires_grad setting from the
        existing parameter if it exists.

        Args:
            val: New tensor value to wrap in a Parameter.

        Examples:
            >>> x = Variable((5,), name='beta')
            >>> x.value = torch.randn(5)
            >>> x.value.shape
            torch.Size([5])
        """
        requires_grad = getattr(self.value, "requires_grad", True)
        # Re-registering under the same name replaces the existing entry.
        self.register_parameter(self._name, torch.nn.Parameter(val, requires_grad))

    @property
    def id(self) -> int:
        """Get the unique identifier of the variable.

        Returns:
            int: The variable's unique ID.
        """
        return self._id

    @property
    def name(self) -> str:
        """Get the name of the variable.

        Returns:
            str: The variable's name.
        """
        return self._name

    def __repr__(self):
        """Full representation of the Variable.

        Returns:
            str: Detailed string representation including all attributes.
        """
        info_components = [
            f"Variable(name='{self.name}'",
            f"id='{self.id}'",
            f"shape={tuple(self.value.shape)}",
            f"dtype={self.value.dtype}",
            f"device='{self.value.device}'",
            f"requires_grad={self.value.requires_grad}",
        ]
        info = ", ".join(info_components)
        return info + ")"

    def __str__(self):
        """Shortened representation of the Variable.

        Returns:
            str: Brief string representation.
        """
        return f"Variable '{self.name}' with shape {self.value.shape}"

    def is_smooth(self):
        """Variables are smooth (identity function is differentiable).

        Returns:
            bool: Always True.
        """
        return True

    def forward(self) -> torch.Tensor:
        """Evaluate the variable (returns its current value).

        Returns:
            torch.Tensor: The parameter tensor.
        """
        return self.value

    def tree(self) -> ExprTree:
        """Return tree representation for Variable (leaf node).

        Returns:
            ExprTree: Leaf node labelled ``Variable(<name>)``.
        """
        return ExprTree(f"Variable({self._name})")

    def is_affine(self) -> bool:
        """Variables are affine expressions.

        Returns:
            bool: Always True.
        """
        return True