Expression Module

Expressions for modeling optimization problems.

Core Classes

class rlaopt.expression.Expression

Bases: Module, ABC

Abstract base class for all expressions.

Expression extends torch.nn.Module to provide automatic parameter tracking, gradient computation, and device management. All concrete expression types (Variable, SumExpression, etc.) inherit from this base class.

Expressions support operator overloading for natural mathematical syntax:
  • Arithmetic: +, -, *, /, @, **

  • Comparisons: intended for constraints (planned for a future release)

  • Composition: Expressions can be nested arbitrarily
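The composition rules above can be illustrated with a minimal plain-Python sketch of operator overloading. It does not use rlaopt or PyTorch; `Node`, `Var`, `Const`, `BinOp`, and `wrap` are hypothetical names chosen for illustration. Each overloaded operator returns a new node instead of a number, so an ordinary Python formula becomes a nested expression tree that can be evaluated later.

```python
import operator


class Node:
    """Hypothetical base: overloaded operators build nodes instead of computing."""

    def __add__(self, other):
        return BinOp("Add", operator.add, self, wrap(other))

    def __radd__(self, other):
        # Called for `2 + x`, where the left operand is a plain number.
        return BinOp("Add", operator.add, wrap(other), self)

    def __mul__(self, other):
        return BinOp("Mul", operator.mul, self, wrap(other))

    def __rmul__(self, other):
        return BinOp("Mul", operator.mul, wrap(other), self)

    def __pow__(self, exponent):
        return BinOp("Pow", operator.pow, self, wrap(exponent))


def wrap(value):
    """Promote plain numbers to constant leaves."""
    return value if isinstance(value, Node) else Const(value)


class Var(Node):
    def __init__(self, name):
        self.name = name

    def evaluate(self, env):
        return env[self.name]

    def __repr__(self):
        return self.name


class Const(Node):
    def __init__(self, value):
        self.value = value

    def evaluate(self, env):
        return self.value

    def __repr__(self):
        return repr(self.value)


class BinOp(Node):
    def __init__(self, kind, fn, left, right):
        self.kind, self.fn, self.left, self.right = kind, fn, left, right

    def evaluate(self, env):
        # Children are evaluated recursively, so nesting depth is arbitrary.
        return self.fn(self.left.evaluate(env), self.right.evaluate(env))

    def __repr__(self):
        return f"{self.kind}({self.left!r}, {self.right!r})"


x, y = Var("x"), Var("y")
expr = 2 * x + y ** 2
print(expr)                                 # Add(Mul(2, x), Pow(y, 2))
print(expr.evaluate({"x": 1.5, "y": 3.0}))  # 12.0
```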

Expression is an abstract class and cannot be instantiated directly; instantiate one of its concrete subclasses instead.

__init__()

Initialize the Expression base class.

abstractmethod is_smooth()

Check if the expression is smooth (differentiable everywhere).

Smoothness is important for choosing optimization algorithms. Smooth expressions can use gradient-based methods, while non-smooth expressions require specialized algorithms like proximal methods or subgradient methods.

Returns:

True if the expression is smooth, False otherwise.

Return type:

bool
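Smoothness is often decided structurally: each atom declares whether it is smooth, and a composite node is smooth when all of its children are. The sketch below assumes that rule; `Leaf`, `Abs`, and `Sum` are hypothetical classes, not rlaopt types. The rule is conservative, since it can report False for an expression that becomes smooth after simplification (|x| - |x|, for instance).

```python
class Leaf:
    """Hypothetical variable or constant leaf: the identity map is smooth."""

    def is_smooth(self):
        return True


class Abs:
    """Hypothetical elementwise absolute value: not differentiable at 0."""

    def __init__(self, child):
        self.child = child

    def is_smooth(self):
        return False


class Sum:
    """Hypothetical sum node: smooth if every term is smooth."""

    def __init__(self, *terms):
        self.terms = terms

    def is_smooth(self):
        return all(term.is_smooth() for term in self.terms)


x, y = Leaf(), Leaf()
print(Sum(x, y).is_smooth())       # True
print(Sum(x, Abs(y)).is_smooth())  # False: smooth + non-smooth is non-smooth
```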

abstractmethod forward()

Evaluate the expression using current parameter values.

This method evaluates the expression with the current values of all registered parameters.

Returns:

The evaluated result.

Return type:

torch.Tensor

abstractmethod tree()

Return a tree representation of the expression structure.

Returns an ExprTree object showing how expressions are composed. Useful for testing that expression optimizations (constant folding, flattening, etc.) are working correctly.

Each subclass implements this method to expose its own structure without leaking implementation details to the base class.

Returns:

Tree structure with node type and children.

Return type:

ExprTree

is_affine()

Check if the expression is affine.

An expression is affine if it can be expressed as a linear function of variables plus a constant.

Returns:

True if the expression is affine, False otherwise.

Return type:

bool
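Affinity can also be checked with structural rules. This sketch assumes that a sum is affine when every term is, and that a product is affine only when one factor is constant and the other is affine; `Var`, `Const`, `Add`, and `Mul` are hypothetical classes, not rlaopt types. The rules are conservative and may miss expressions that only become affine after simplification.

```python
class Var:
    def is_affine(self):
        return True

    def is_constant(self):
        return False


class Const:
    def __init__(self, value):
        self.value = value

    def is_affine(self):
        return True

    def is_constant(self):
        return True


class Add:
    def __init__(self, *terms):
        self.terms = terms

    def is_affine(self):
        return all(t.is_affine() for t in self.terms)

    def is_constant(self):
        return all(t.is_constant() for t in self.terms)


class Mul:
    def __init__(self, left, right):
        self.left, self.right = left, right

    def is_affine(self):
        # c * f(x) stays affine; f(x) * g(x) with both non-constant does not.
        if self.left.is_constant():
            return self.right.is_affine()
        if self.right.is_constant():
            return self.left.is_affine()
        return False

    def is_constant(self):
        return self.left.is_constant() and self.right.is_constant()


x, y = Var(), Var()
print(Add(Mul(Const(3.0), x), y, Const(1.0)).is_affine())  # True: 3x + y + 1
print(Mul(x, y).is_affine())                               # False: x * y
```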

evaluate(variable_values)

Evaluate the expression at specified variable values.

Unlike forward(), which uses the currently stored values, this method evaluates the expression at variable values supplied by the caller. This is useful for line searches, parameter exploration, and similar tasks.

Note that the user can specify a partial set of variables; any variables not included will be evaluated at their current stored values. Any variables in the input that are not part of this expression will be ignored.

Parameters:

variable_values (TensorDict) – Dictionary mapping variable names to their values.

Returns:

The evaluated result.

Return type:

torch.Tensor

Examples

>>> x = Variable((5,), name='x')
>>> new_params = TensorDict({'x': torch.ones(5)})
>>> result = x.evaluate(new_params)
>>> torch.equal(result, new_params['x'])
True
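The partial-input rule described above amounts to overlaying the supplied values on the stored ones and dropping names the expression does not use. This plain-Python sketch uses ordinary dicts in place of TensorDict; `merged_values` and `f` are hypothetical helpers, not rlaopt functions.

```python
def merged_values(stored, overrides):
    """Return the values an evaluation would use (hypothetical helper).

    Names in `overrides` replace the stored value, names missing from
    `overrides` keep their stored value, and names in `overrides` that the
    expression does not use are dropped.
    """
    return {name: overrides.get(name, value) for name, value in stored.items()}


def f(v):
    # Stand-in for the expression x + 2 * y.
    return v["x"] + 2 * v["y"]


stored = {"x": 1.0, "y": 2.0}     # the expression's current variable values
overrides = {"x": 5.0, "z": 9.0}  # partial input plus a name the expression does not use

values = merged_values(stored, overrides)
print(values)     # {'x': 5.0, 'y': 2.0}
print(f(values))  # 9.0
print(stored)     # {'x': 1.0, 'y': 2.0}: the helper builds a new dict
```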
property variable_values: TensorDict

Get variables as a dictionary.

Returns:

Dictionary of variable names to variable tensors.

Return type:

TensorDict

get_variable_names()

Returns the list of variable names in order.

Return type:

list[str]

get_variable_shapes()

Returns a dictionary mapping variable names to their shapes.

Return type:

dict[str, Size]

update_variables(variable_values)

Update variables from a TensorDict.

Note that this method allows partial updates; only the variables specified in the input TensorDict will be updated, while others will remain unchanged. Any variables in the input that are not part of this expression will be ignored.

Parameters:

variable_values (TensorDict) – TensorDict with new variable values.

Examples

>>> x = Variable((5,), name='x')
>>> x.update_variables(TensorDict({'x': torch.ones(5)}))
>>> torch.equal(x.value, torch.ones(5))
True
select_relevant_variables(variable_values)

Select variables relevant to this expression from a TensorDict.

Parameters:

variable_values (TensorDict) – TensorDict containing variable values.

Returns:

TensorDict with only variables relevant to this expression.

Return type:

TensorDict
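Selecting the relevant variables amounts to filtering the input by the expression's variable names. In this plain-Python sketch, a dict stands in for TensorDict and `select_relevant` is a hypothetical function, not the method itself. It silently skips names that are absent from the input; that is an assumption, since the behavior for missing names is not specified above.

```python
def select_relevant(variable_names, variable_values):
    """Keep only the entries whose names the expression uses (hypothetical)."""
    return {
        name: variable_values[name]
        for name in variable_names
        if name in variable_values  # assumption: missing names are skipped
    }


names = ["x", "y"]                       # e.g. what get_variable_names() might return
values = {"x": 1.0, "z": 3.0, "y": 2.0}  # "z" belongs to some other expression
print(select_relevant(names, values))    # {'x': 1.0, 'y': 2.0}
```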

transpose()

Create a transpose operation.

Double transposes are simplified: A.T.T returns A.

Returns:

The transposed expression, or the underlying expression A when self is already a transpose A.T.

Return type:

Expression
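The double-transpose simplification can be applied when the transpose node is built: transposing a node that is already a transpose returns its child instead of adding another layer. A sketch with hypothetical `Expr`, `Leaf`, and `Transpose` classes, not rlaopt's own:

```python
class Expr:
    def transpose(self):
        return Transpose(self)

    @property
    def T(self):
        # Shorthand property, mirroring the T property documented below.
        return self.transpose()


class Leaf(Expr):
    def __init__(self, name):
        self.name = name


class Transpose(Expr):
    def __init__(self, child):
        self.child = child

    def transpose(self):
        # (A^T)^T = A, so return the original node instead of nesting.
        return self.child


A = Leaf("A")
print(type(A.T).__name__)  # Transpose
print(A.T.T is A)          # True
```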

property T: Self

Transpose property (shorthand for transpose()).

Returns:

Transposed expression.

Return type:

Expression

sum(dim=None)

Create a sum reduction operation.

Parameters:

dim – Dimension to sum over. If None (the default), the sum is taken over all dimensions.

Returns:

Expression computing the sum.

Return type:

ReduceSum

__add__(other)

Add two expressions or an expression and a scalar.

Parameters:

other – Expression, float, or int to add.

Returns:

Sum of self and other, simplified where possible.

Return type:

Expression
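The exact simplifications applied by __add__ are not specified here. The sketch below assumes two that the tree() documentation mentions, constant folding and flattening, using hypothetical `Expr`, `Var`, `Const`, `Sum`, and `add` names; rlaopt's actual rules may differ.

```python
class Expr:
    def __add__(self, other):
        return add(self, other)

    def __radd__(self, other):
        return add(other, self)


class Var(Expr):
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


class Const(Expr):
    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return repr(self.value)


class Sum(Expr):
    def __init__(self, terms):
        self.terms = list(terms)

    def __repr__(self):
        return "Sum(" + ", ".join(map(repr, self.terms)) + ")"


def add(a, b):
    """Build a + b with two simplifications (hypothetical rules)."""
    a = a if isinstance(a, Expr) else Const(a)  # promote plain numbers
    b = b if isinstance(b, Expr) else Const(b)
    if isinstance(a, Const) and isinstance(b, Const):
        return Const(a.value + b.value)          # constant folding
    a_terms = a.terms if isinstance(a, Sum) else [a]
    b_terms = b.terms if isinstance(b, Sum) else [b]
    return Sum(a_terms + b_terms)                # flattening: no Sum inside Sum


x, y, z = Var("x"), Var("y"), Var("z")
print(Const(1) + 2)  # 3
print(x + y + z)     # Sum(x, y, z), not Sum(Sum(x, y), z)
print(1 + x)         # Sum(1, x)
```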

__radd__(other)

Add a scalar and an expression (reverse operation).

Parameters:

other – Float or int to add.

Returns:

Sum of other and self, simplified where possible.

Return type:

Expression

__sub__(other)

Subtract an expression or scalar from this expression.

Parameters:

other – Expression, float, or int to subtract.

Returns:

Difference of self and other, simplified where possible.

Return type:

Expression

__rsub__(other)

Subtract this expression from a scalar (reverse operation).

Parameters:

other – Float or int to subtract from.

Returns:

Difference of other and self, simplified where possible.

Return type:

Expression

__neg__()

Negate this expression.

Returns:

Negation of self.

Return type:

Expression

__mul__(other)

Multiply this expression by another (elementwise).

Parameters:

other – Expression, float, or int to multiply.

Returns:

Elementwise product of self and other.

When a sum is multiplied by a scalar, the scalar is automatically distributed over its terms.

Return type:

Expression

__rmul__(other)

Multiply a scalar by this expression (reverse operation).

Parameters:

other – Float or int to multiply.

Returns:

Elementwise product of other and self.

When a scalar multiplies a sum, the scalar is automatically distributed over its terms.

Return type:

Expression

__truediv__(other)

Divide this expression by a scalar.

Parameters:

other – Float or int to divide by.

Returns:

Quotient of self and other, simplified where possible.

When a sum is divided by a scalar, the division is automatically distributed over its terms.

Return type:

Expression
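The distribution rules described for __mul__, __rmul__, and __truediv__ can be sketched with two hypothetical classes, `Term` (a coefficient times a named leaf) and `Sum`. Multiplying or dividing a `Sum` by a scalar rewrites it as a new `Sum` of scaled terms rather than wrapping it in a product node; this illustrates the rewrite only and is not rlaopt's implementation.

```python
class Term:
    """Hypothetical scaled leaf: coef * name."""

    def __init__(self, name, coef=1.0):
        self.name, self.coef = name, coef

    def __mul__(self, scalar):
        return Term(self.name, self.coef * scalar)

    __rmul__ = __mul__  # scalar multiplication commutes

    def __repr__(self):
        return f"{self.coef}*{self.name}"


class Sum:
    """Hypothetical sum node."""

    def __init__(self, *terms):
        self.terms = terms

    def __mul__(self, scalar):
        # c * (t1 + t2 + ...) -> c*t1 + c*t2 + ...
        return Sum(*(t * scalar for t in self.terms))

    __rmul__ = __mul__

    def __truediv__(self, scalar):
        # (t1 + t2 + ...) / c -> t1/c + t2/c + ..., as multiplication by 1/c
        return self * (1.0 / scalar)

    def __repr__(self):
        return " + ".join(map(repr, self.terms))


s = Sum(Term("x"), Term("y", 2.0))
print(3 * s)  # 3.0*x + 6.0*y
print(s / 2)  # 0.5*x + 1.0*y
```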

__matmul__(other)

Matrix multiply this expression by another.

Parameters:

other – Expression, float, or int to multiply.

Returns:

Matrix product of self and other.

Return type:

Expression

__rmatmul__(other)

Matrix multiply a value by this expression (reverse operation).

Parameters:

other – Expression, float, or int to multiply.

Returns:

Matrix product of other and self.

Return type:

Expression

__pow__(exponent)

Raise this expression to a power (elementwise).

Parameters:

exponent – Power to raise to.

Returns:

Result of exponentiation.

Return type:

Expression

class rlaopt.expression.Variable(size_or_tensor, requires_grad=True, var_id=None, name=None, dtype=None, device=None)

Bases: Expression

Leaf optimization variable wrapping a torch.nn.Parameter.

Variable represents a trainable parameter in an optimization problem. It extends Expression to provide automatic differentiation, device management, and integration with PyTorch’s optimization ecosystem.

Variables register their parameters with meaningful names (not just ‘value’) to improve debugging and state_dict readability.

Parameters:
  • size_or_tensor (int | tuple[int, ...] | Tensor) – Either a size (an int, or a tuple such as (5,) or (5, 10)) or an existing torch.Tensor to wrap.

  • requires_grad (bool) – Whether the parameter should track gradients.

  • var_id (int | None) – Optional custom ID for the variable.

  • name (str | None) – Optional custom name for the variable.

  • dtype (dtype | None) – Data type for the variable tensor.

  • device (device | None) – Device to place the variable on.

Examples

>>> # Create from size
>>> x = Variable((5,), name='weights')
>>> x.value.shape
torch.Size([5])
>>> # Create matrix variable
>>> A = Variable((3, 4), name='matrix')
>>> A.value.shape
torch.Size([3, 4])
>>> # Create from existing tensor
>>> data = torch.randn(10)
>>> y = Variable(data, name='initialized')
__init__(size_or_tensor, requires_grad=True, var_id=None, name=None, dtype=None, device=None)

Constructor method for Variable. The parameters are the same as those documented for the class.

property value: Parameter

Get the parameter value.

Returns the underlying torch.nn.Parameter using the variable’s name. This allows accessing parameters as x.value while storing them with meaningful names in the state_dict.

Returns:

The parameter tensor.

Return type:

torch.nn.Parameter

Examples

>>> x = Variable((5,), name='alpha')
>>> x.value.data = torch.ones(5)
>>> x.value
Parameter containing:
tensor([1., 1., 1., 1., 1.], requires_grad=True)
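The naming scheme described above, where the parameter is stored under the variable's own name but read through the fixed `value` property, can be sketched in plain Python. Here a dict stands in for the parameter registry of torch.nn.Module, and `NamedVariable` and `_params` are hypothetical names. In PyTorch itself, a module can register a parameter under an arbitrary name with torch.nn.Module.register_parameter(name, param).

```python
class NamedVariable:
    """Hypothetical stand-in for a Variable that stores its value under its name."""

    def __init__(self, name, initial):
        self._name = name
        self._params = {name: initial}  # plays the role of the module's parameter registry

    @property
    def value(self):
        # Look the parameter up by the variable's own name.
        return self._params[self._name]

    def state_dict(self):
        return dict(self._params)


alpha = NamedVariable("alpha", [0.0, 0.0])
alpha.value[:] = [1.0, 1.0]  # in-place update, in the spirit of x.value.data = ...
print(alpha.value)           # [1.0, 1.0]
print(alpha.state_dict())    # {'alpha': [1.0, 1.0]}, keyed by name rather than 'value'
```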
property id: int

Get the unique identifier of the variable.

Returns:

The variable’s unique ID.

Return type:

int

property name: str

Get the name of the variable.

Returns:

The variable’s name.

Return type:

str

__repr__()

Full representation of the Variable.

Returns:

Detailed string representation including all attributes.

Return type:

str

__str__()

Shortened representation of the Variable.

Returns:

Brief string representation.

Return type:

str

is_smooth()

Variables are smooth (the identity function is differentiable everywhere).

Returns:

Always True.

Return type:

bool

forward()

Evaluate the variable (returns its current value).

Returns:

The parameter tensor.

Return type:

torch.Tensor

tree()

Return tree representation for Variable (leaf node).

Returns:

Leaf node with type ‘Variable’.

Return type:

ExprTree

is_affine()

Variables are affine expressions.

Returns:

Always True.

Return type:

bool

class rlaopt.expression.Constant(value)

Bases: Expression

Constant value expression.

Represents a constant (non-trainable) value in an expression tree. Constants are stored as buffers rather than parameters, so they don’t receive gradients and don’t appear in parameter optimization.

Parameters:

value (float | int | Tensor) – The constant value.

_value

The constant value stored as a buffer.

Examples

>>> c = Constant(3.14)
>>> c.forward()
tensor(3.1400)
>>> c2 = Constant(torch.ones(5))
>>> c2.forward().shape
torch.Size([5])
__init__(value)

Initialize a constant expression.

Parameters:

value (float | int | Tensor) – The constant value to store.

property value

Get the constant value.

Returns:

The constant value.

Return type:

torch.Tensor

is_smooth()

Constants are smooth (trivially differentiable).

Returns:

Always True.

Return type:

bool

forward()

Evaluate the constant (returns its stored value).

Returns:

The constant value.

Return type:

torch.Tensor

tree()

Return tree representation for Constant (leaf node).

Returns:

Leaf node with type ‘Constant’.

Return type:

ExprTree

is_affine()

Constants are affine expressions.

Returns:

Always True.

Return type:

bool

__neg__()

Negate the constant (keeps it as a constant).

Returns:

Negated constant.

Return type:

Constant

class rlaopt.expression.ExprTree(node_type, *children, is_commutative=False)

Bases: object

Tree representation of expression structure.

Used for testing expression optimizations and visualizing expression structure. Each node has a type (class name) and zero or more children.

Parameters:
  • node_type (str)

  • *children (Self)

  • is_commutative (bool)

__init__(node_type, *children, is_commutative=False)

Initialize an expression tree node.

Parameters:
  • node_type (str) – The expression class name.

  • *children (Self) – Child ExprTree nodes.

  • is_commutative (bool) – Whether this operation is commutative (e.g., addition, multiplication). For commutative operations, child order doesn’t matter for equality/hashing.

__eq__(other)

Check structural equality of two expression trees.

For commutative operations, children can be in any order. For other operations, order matters.

Parameters:

other – Another ExprTree to compare with.

Returns:

True if trees have the same structure and node types.

Return type:

bool

__hash__()

Hash the tree for use in sets and dictionaries.

For commutative operations, the hash does not depend on the order of the children.

Returns:

Hash of the tree structure.

Return type:

int
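Order-insensitive equality and hashing for commutative nodes can be obtained by comparing children as multisets and hashing them in a canonical (sorted) order. `Tree` below is a hypothetical, stripped-down class with the same constructor shape as ExprTree; it sketches the idea and is not the rlaopt implementation. Sorting the child hashes keeps the hash consistent with equality: trees that compare equal always hash equally.

```python
from collections import Counter


class Tree:
    def __init__(self, node_type, *children, is_commutative=False):
        self.node_type = node_type
        self.children = children
        self.is_commutative = is_commutative

    def __eq__(self, other):
        if not isinstance(other, Tree):
            return NotImplemented
        if (self.node_type, self.is_commutative) != (other.node_type, other.is_commutative):
            return False
        if self.is_commutative:
            # Multiset comparison: order is ignored, multiplicity is not.
            return Counter(self.children) == Counter(other.children)
        return self.children == other.children

    def __hash__(self):
        child_hashes = [hash(c) for c in self.children]
        if self.is_commutative:
            child_hashes.sort()  # canonical order, so equal trees hash equally
        return hash((self.node_type, self.is_commutative, tuple(child_hashes)))


a, b = Tree("Variable"), Tree("Constant")
s1 = Tree("Sum", a, b, is_commutative=True)
s2 = Tree("Sum", b, a, is_commutative=True)
m1, m2 = Tree("MatMul", a, b), Tree("MatMul", b, a)

print(s1 == s2, hash(s1) == hash(s2))  # True True
print(m1 == m2)                        # False: order matters for a non-commutative node
print(len({s1, s2}))                   # 1
```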

__repr__()

Return a code-like representation of the tree.

Returns:

String representation suitable for debugging.

Return type:

str

__str__()

Return a pretty-printed tree visualization.

Returns:

Human-readable tree with indentation and branches.

Return type:

str

depth()

Calculate the depth of the tree.

Returns:

Maximum depth from this node to any leaf.

Return type:

int

count_nodes()

Count the total number of nodes in the tree.

Returns:

Total number of nodes (including this one).

Return type:

int
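Both depth() and count_nodes() are simple recursions over the children. The sketch below uses a hypothetical minimal `Node` class and assumes that a single leaf has depth 1; ExprTree may count depth in edges instead, in which case a leaf would have depth 0.

```python
class Node:
    def __init__(self, node_type, *children):
        self.node_type = node_type
        self.children = children

    def depth(self):
        # Assumption: a leaf counts as depth 1.
        return 1 + max((c.depth() for c in self.children), default=0)

    def count_nodes(self):
        # This node plus every node in each child subtree.
        return 1 + sum(c.count_nodes() for c in self.children)


# Tree for (x + 2) * y
tree = Node("Mul", Node("Add", Node("Variable"), Node("Constant")), Node("Variable"))
print(tree.depth())        # 3
print(tree.count_nodes())  # 5
```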