Internals

The package is built around a generic multi-block ADMM interface that can serve as the basis for other ADMM-based algorithms.

The basic loop of any ADMM algorithm looks as follows:

  1. Initialize state
  2. Repeat until convergence or maximum iterations
    1. Update primal blocks in specified order
    2. Update dual variables (constraint multipliers)
    3. Evaluate augmented Lagrangian objective
    4. Check for convergence
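The sketch below runs these four steps on a toy two-block problem using scaled-form ADMM. The problem is: minimise 0.5*(x - a)**2 + 0.5*(z - b)**2 subject to x - z = 0, and its solution is x = z = (a + b) / 2. The problem, its closed-form block updates and the function name toy_admm are illustrative assumptions and are not part of solrcmf.

```python
from math import inf


def toy_admm(a: float, b: float, rho: float = 2.0, max_iter: int = 1000,
             rel_tol: float = 1e-12, abs_tol: float = 1e-12):
    """Scaled-form ADMM for min 0.5(x-a)^2 + 0.5(z-b)^2 s.t. x - z = 0."""
    # 1. Initialise state: primal blocks x, z and the scaled dual u.
    x = z = u = 0.0
    obj_old, converged = inf, False
    for i in range(max_iter):
        # 2.1 Update primal blocks in a fixed order (x, then z); each step
        #     minimises the augmented Lagrangian over one block in closed form.
        x = (a + rho * (z - u)) / (1.0 + rho)
        z = (b + rho * (x + u)) / (1.0 + rho)
        # 2.2 Dual ascent on the scaled multiplier.
        r = x - z
        u += r
        # 2.3 Augmented Lagrangian (scaled form) at the new iterates.
        obj = (0.5 * (x - a) ** 2 + 0.5 * (z - b) ** 2
               + 0.5 * rho * (r + u) ** 2 - 0.5 * rho * u ** 2)
        # 2.4 Objective-change test, skipping the first iteration.
        if i > 0 and abs(obj_old - obj) <= max(rel_tol * abs(obj_old), abs_tol):
            converged = True
            break
        obj_old = obj
    return x, z, i + 1, converged


print(toy_admm(1.0, 3.0))  # x and z close to 2.0, converged=True
```

The default rho = 2 is deliberate. With rho = 1, this particular toy reaches the exact optimum for x after two sweeps. From then on the augmented Lagrangian stays exactly constant, so the objective-change test stops after three iterations while z is still 1.9375. An objective-based criterion like the one ADMM uses can therefore stop early on degenerate problems.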

ADMM base class

solrcmf.admm.ADMM

Base class for multi-block ADMM algorithms.

Subclasses implement _setup to construct the Context (blocks, constraints, and parameters) for a specific problem, and override score and transform for evaluation and reconstruction.

The iteration alternates between updating all primal blocks in ctx.block_order, updating all dual (constraint multiplier) variables, and evaluating the augmented Lagrangian objective. Convergence is declared when the absolute change in objective satisfies

|obj_old - obj| <= max(rel_tol * |obj_old|, abs_tol)

skipping the first iteration to avoid spurious convergence from the initialisation.
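The stopping rule can be written as a small predicate. This sketch mirrors the check in fit, and the function name has_converged is hypothetical.

```python
from math import inf


def has_converged(i: int, obj_old: float, obj: float,
                  rel_tol: float = 1e-6, abs_tol: float = 1e-6) -> bool:
    """Objective-change test applied after sweep i (0-based)."""
    if i == 0:
        # obj_old still holds its initial value (inf), so never stop here.
        return False
    # The relative term dominates for large objectives, abs_tol for small ones.
    return abs(obj_old - obj) <= max(rel_tol * abs(obj_old), abs_tol)


print(has_converged(0, inf, 10.0))          # False
print(has_converged(1, 10.0, 10.0 - 5e-6))  # True: 5e-6 <= max(1e-5, 1e-6)
print(has_converged(1, 0.1, 0.1 - 5e-6))    # False: 5e-6 > max(1e-7, 1e-6)
```

The threshold takes the maximum of the two tolerances. As a result, rel_tol governs when |obj_old| is large, and abs_tol acts as a floor as the objective approaches zero, where a purely relative test would become arbitrarily strict.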

Attributes:

  objs_ (list[float]): Objective value at each iteration.
  gaps_ (list[float]): Change in objective (obj_old - obj) at each iteration.
  converged_ (bool): Whether the convergence criterion was met.
  objective_value_ (float): Final objective value.
  n_iter_ (int): Number of iterations performed.
  elapsed_process_time_ (float): CPU time consumed by the iteration loop.
  ctx_ (Context): The context object (only when save_ctx=True).
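Because fit starts from obj_old = inf, the first entry of gaps_ is inf whenever the first objective is finite. Each later entry is the drop objs_[k-1] - objs_[k], which is negative if the objective rose. The sketch below rebuilds gaps_ from objs_ this way; gaps_from_objs is a hypothetical helper and is not part of the package.

```python
from math import inf


def gaps_from_objs(objs: list[float]) -> list[float]:
    """Recompute gaps_ from objs_ the way ADMM.fit fills them in."""
    gaps = []
    obj_old = inf  # fit's initial "previous objective"
    for obj in objs:
        gaps.append(obj_old - obj)
        obj_old = obj
    return gaps


print(gaps_from_objs([5.0, 3.0, 2.5, 2.5]))  # [inf, 2.0, 0.5, 0.0]
```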

Methods:

  __init__: Initialize a new instance of the ADMM algorithm.
  fit: Run the ADMM iteration until convergence or max_iter.
  score: Evaluate the fit quality on X.
  transform: Return the low-rank reconstruction of X.

Source code in src/solrcmf/admm.py
class ADMM[
    BT: DataclassInstance,
    CT: DataclassInstance,
    PT: DataclassInstance,
](BaseEstimator, ABC):
    """Base class for multi-block ADMM algorithms.

    Subclasses implement `_setup` to construct the `Context` (blocks,
    constraints, and parameters) for a specific problem, and override
    `score` and `transform` for evaluation and reconstruction.

    The iteration alternates between updating all primal blocks in
    `ctx.block_order`, updating all dual (constraint multiplier) variables,
    and evaluating the augmented Lagrangian objective. Convergence is
    declared when the absolute change in objective satisfies

        |obj_old - obj| <= max(rel_tol * |obj_old|, abs_tol)

    skipping the first iteration to avoid spurious convergence from the
    initialisation.

    Attributes:
        objs_ (list[float]): Objective value at each iteration.
        gaps_ (list[float]): Change in objective (obj_old - obj) at each
            iteration.
        converged_ (bool): Whether the convergence criterion was met.
        objective_value_ (float): Final objective value.
        n_iter_ (int): Number of iterations performed.
        elapsed_process_time_ (float): CPU time consumed by the iteration loop.
        ctx_ (Context): The context object (only when save_ctx=True).

    """

    _parameter_constraints = {
        "max_iter": [Interval(Integral, 1, None, closed="left")],
        "abs_tol": [Interval(Real, 0, None, closed="neither")],
        "rel_tol": [Interval(Real, 0, None, closed="neither")],
        "save_ctx": ["boolean"],
    }

    def __init__(
        self,
        max_iter: int = 1000,
        abs_tol: float = 1e-6,
        rel_tol: float = 1e-6,
        *,
        save_ctx: bool = False,
    ):
        """Initialize a new instance of the ADMM algorithm.

        Args:
            max_iter: Maximum number of iterations
            abs_tol: Absolute convergence tolerance
            rel_tol: Relative convergence tolerance
            save_ctx: Whether the context object should be kept as `ctx_`
                      after fitting (whether or not it converged).

        """
        self.max_iter = max_iter
        self.abs_tol = abs_tol
        self.rel_tol = rel_tol

        self.save_ctx = save_ctx

    @abstractmethod
    def _setup(self, X, **kwargs) -> Context[BT, CT, PT]:
        """Set up the estimation problem.

        Called after data is available.
        """
        raise NotImplementedError(
            f"_setup method on {self.__class__.__name__} not implemented"
        )

    def fit(self, X, y=None, **kwargs):
        """Run the ADMM iteration until convergence or max_iter.

        Calls `_setup` to build the context, then alternates between primal
        block updates and dual (constraint multiplier) updates. The augmented
        Lagrangian objective is evaluated after each full sweep and checked
        against the convergence criterion.
        """
        # Validate parameters; should check parameters of
        # derived classes as well
        self._validate_params()

        # Setup ADMM context
        ctx = self._setup(X, **kwargs)

        start_time = process_time()

        objs = []
        gaps = []

        converged = False
        obj_old = inf
        for i in range(self.max_iter):
            # Update variable blocks
            for name, idx in ctx.block_order:
                bgroup = getattr(ctx.blocks, name)
                update(bgroup[idx], ctx)

            # Update constraints
            for cnstrnt in fields(ctx.constraints):
                cgroup = getattr(ctx.constraints, cnstrnt.name)
                for c in cgroup.values():
                    update(c, ctx)

            obj = _objective(ctx)
            gap = obj_old - obj

            objs.append(obj)
            gaps.append(gap)

            if i > 0 and abs(gap) <= max(
                self.rel_tol * abs(obj_old), self.abs_tol
            ):
                converged = True
                break

            obj_old = obj

        end_time = process_time()

        self.objs_ = objs
        self.gaps_ = gaps
        self.converged_ = converged
        self.objective_value_ = objs[-1]
        self.n_iter_ = i + 1
        self.elapsed_process_time_ = end_time - start_time

        for k, v in self._extra_attrs(ctx).items():
            setattr(self, k, v)

        if self.save_ctx:
            self.ctx_ = ctx

        return self

    @abstractmethod
    def score(self, X, **kwargs):
        """Evaluate the fit quality on X."""
        pass

    @abstractmethod
    def transform(self, X, y=None, **kwargs):
        """Return the low-rank reconstruction of X."""
        pass

    def _extra_attrs(self, ctx: Context[BT, CT, PT]) -> dict[str, Any]:
        """Return additional attributes to set on the estimator after fitting.

        Subclasses override this to expose problem-specific fitted quantities
        (e.g. factor matrices, structure patterns) without touching `fit`.
        """
        return {}
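The _extra_attrs hook lets a subclass publish fitted quantities without overriding fit. The sketch below reproduces only that contract, using hypothetical dependency-free stand-ins (MiniEstimator, MeanModel). It passes a plain dict where ADMM would pass a Context, and it omits the iteration loop entirely.

```python
from abc import ABC, abstractmethod
from typing import Any


class MiniEstimator(ABC):
    """Hypothetical stand-in for the fit/_extra_attrs contract of ADMM."""

    def fit(self, X):
        ctx = self._setup(X)
        # (the ADMM iteration loop would run here)
        for k, v in self._extra_attrs(ctx).items():
            setattr(self, k, v)
        return self

    @abstractmethod
    def _setup(self, X) -> dict[str, Any]: ...

    def _extra_attrs(self, ctx: dict[str, Any]) -> dict[str, Any]:
        return {}  # base class exposes nothing extra


class MeanModel(MiniEstimator):
    """Toy subclass: stores column means in ctx and exposes them as mean_."""

    def _setup(self, X):
        return {"mean": [sum(col) / len(col) for col in zip(*X)]}

    def _extra_attrs(self, ctx):
        return {"mean_": ctx["mean"]}


model = MeanModel().fit([[1.0, 2.0], [3.0, 4.0]])
print(model.mean_)  # [2.0, 3.0]
```

Because fit stays in the base class, every subclass shares the same loop, timing and convergence bookkeeping and supplies only _setup and, optionally, _extra_attrs.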
__init__(max_iter=1000, abs_tol=1e-06, rel_tol=1e-06, *, save_ctx=False)

Initialize a new instance of the ADMM algorithm.

Parameters:

  max_iter (int): Maximum number of iterations. Default: 1000.
  abs_tol (float): Absolute convergence tolerance. Default: 1e-06.
  rel_tol (float): Relative convergence tolerance. Default: 1e-06.
  save_ctx (bool): Whether the context object should be kept as ctx_ after fitting, whether or not the run converged. Default: False.
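The constraints declared in _parameter_constraints amount to the checks below: max_iter is an integer >= 1, both tolerances are strictly positive reals, and save_ctx is a boolean. This plain-Python function (check_admm_params, hypothetical) is only a sketch of what scikit-learn's parameter validation enforces for these four parameters. It does not reproduce every edge case or the exact exception type that scikit-learn raises.

```python
from numbers import Integral, Real


def check_admm_params(max_iter=1000, abs_tol=1e-6, rel_tol=1e-6,
                      save_ctx=False) -> None:
    """Hypothetical stand-in for ADMM._validate_params on the base parameters."""
    # Interval(Integral, 1, None, closed="left"): integers in [1, inf)
    if not isinstance(max_iter, Integral) or max_iter < 1:
        raise ValueError(f"max_iter must be an int >= 1, got {max_iter!r}")
    # Interval(Real, 0, None, closed="neither"): reals in (0, inf)
    for name, val in (("abs_tol", abs_tol), ("rel_tol", rel_tol)):
        if not isinstance(val, Real) or not val > 0:
            raise ValueError(f"{name} must be a real > 0, got {val!r}")
    # "boolean"
    if not isinstance(save_ctx, bool):
        raise ValueError(f"save_ctx must be a bool, got {save_ctx!r}")


check_admm_params()                 # the defaults pass
try:
    check_admm_params(abs_tol=0.0)  # 0 is excluded: the interval is open
except ValueError as err:
    print(err)                      # abs_tol must be a real > 0, got 0.0
```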
fit(X, y=None, **kwargs)

Run the ADMM iteration until convergence or max_iter.

Calls _setup to build the context, then alternates between primal block updates and dual (constraint multiplier) updates. The augmented Lagrangian objective is evaluated after each full sweep and checked against the convergence criterion.

score(X, **kwargs) abstractmethod

Evaluate the fit quality on X.

transform(X, y=None, **kwargs) abstractmethod

Return the low-rank reconstruction of X.


Building blocks

solrcmf.base.Block dataclass

A single block of ADMM primal variables.

Each block owns one array value and contributes to both the primal update (via update) and the objective function (via objective). It is identified by a name and an index idx. The update and objective singledispatch functions are registered separately for each concrete subclass.

Attributes:

  name (str): Attribute name on the blocks dataclass (e.g. "z", "d", "v").
  idx (IdxT): Key under which this block is stored in the dict on that attribute.
  shape (tuple[int, ...]): Expected shape of value.
  value (NDArray[float64]): The current iterate; initialised by the concrete update.

Source code in src/solrcmf/base.py
@dataclass
class Block[IdxT]:
    """A single block of ADMM primal variables.

    Each block owns one array `value` and contributes to both the primal
    update (via `update`) and the objective function (via `objective`).
    It is identified by a name and an index `idx`. The `update` and
    `objective` singledispatch functions are registered separately for
    each concrete subclass.

    Attributes:
        name: Attribute name on the blocks dataclass (e.g. "z", "d", "v").
        idx: Key under which this block is stored in the dict on that attr.
        shape: Expected shape of `value`.
        value: The current iterate; initialised by the concrete update.

    """

    name: str
    idx: IdxT
    shape: tuple[int, ...]
    value: NDArray[float64] = field(init=False, repr=False)
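value is declared with field(init=False, repr=False) and has no default. As a result, constructing a block records only name, idx and shape. The value attribute does not exist until a concrete update assigns it, and it is left out of the repr. The minimal stand-in below has the same field declarations. It uses typing.Generic instead of the class Block[IdxT] syntax so that it also runs on Python versions before 3.12.

```python
from dataclasses import dataclass, field
from typing import Generic, TypeVar

import numpy as np
from numpy.typing import NDArray

IdxT = TypeVar("IdxT")


@dataclass
class Block(Generic[IdxT]):
    """Stand-in with the same fields as solrcmf.base.Block."""

    name: str
    idx: IdxT
    shape: tuple[int, ...]
    value: NDArray[np.float64] = field(init=False, repr=False)


b = Block("z", 0, (2, 3))
print(b)                    # Block(name='z', idx=0, shape=(2, 3))
print(hasattr(b, "value"))  # False: not set by __init__, no default
b.value = np.zeros(b.shape)  # what a concrete update would do first
```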

solrcmf.base.Constraint dataclass

A block of ADMM dual variables enforcing a multi-affine constraint.

value holds the dual multiplier. update obtains the primal residual from the constraint singledispatch function, caches it in residual and performs the dual ascent step value += residual. objective then computes the augmented Lagrangian penalty term from the cached residual instead of recomputing it.

Attributes:

  residual (NDArray[float64]): The most recently computed primal residual; set by update and consumed by objective within the same iteration.

Source code in src/solrcmf/base.py
@dataclass
class Constraint[IdxT](Block[IdxT]):
    """A block of ADMM dual variables enforcing a multi-affine constraint.

    `value` holds the dual multiplier. `update` performs the dual ascent
    step value += residual. `objective` computes the augmented Lagrangian
    penalty term. Both use the `constraint` singledispatch function to
    obtain the primal residual, which is cached in `residual` after each
    `update` to avoid recomputation in `objective`.

    Attributes:
        residual: The most recently computed primal residual; set by
            `update` and consumed by `objective` within the same iteration.

    """

    residual: NDArray[float64] = field(init=False, repr=False)
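The constraint contract can be sketched on a single toy constraint, x - z = 0. The sketch assumes the scaled form of ADMM (value is u = y / rho) and evaluates the penalty with the already-updated multiplier. ToyConstraint, dual_update and penalty are hypothetical names, and solrcmf's exact penalty expression is not shown on this page.

```python
from dataclasses import dataclass, field

import numpy as np
from numpy.typing import NDArray


@dataclass
class ToyConstraint:
    """Hypothetical scaled dual u for the constraint x - z = 0."""

    shape: tuple[int, ...]
    value: NDArray[np.float64] = field(init=False, repr=False)
    residual: NDArray[np.float64] = field(init=False, repr=False)

    def __post_init__(self):
        self.value = np.zeros(self.shape)


def dual_update(c: ToyConstraint, x, z) -> None:
    # Primal residual of x - z = 0, cached for the penalty below.
    c.residual = x - z
    # Dual ascent step on the scaled multiplier: value += residual.
    c.value += c.residual


def penalty(c: ToyConstraint, rho: float) -> float:
    # Scaled-form augmented Lagrangian term from the cached residual:
    # rho/2 * ||r + u||^2 - rho/2 * ||u||^2 (= rho * <u, r> + rho/2 * ||r||^2)
    r, u = c.residual, c.value
    return 0.5 * rho * (np.sum((r + u) ** 2) - np.sum(u ** 2))


c = ToyConstraint((2,))
dual_update(c, np.array([1.0, 2.0]), np.array([0.5, 2.0]))
print(c.value)          # [0.5 0. ]
print(penalty(c, 2.0))  # 0.75
```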

solrcmf.base.Context dataclass

Shared state passed to every block and constraint update.

Holds the primal blocks, dual constraints, algorithm parameters, observed data, and the ordered list of blocks to update each iteration.

Attributes:

  blocks (BT): Dataclass holding all primal block dicts.
  constraints (CT): Dataclass holding all constraint (dual variable) dicts.
  params (PT): Algorithm parameters (rho, penalties, etc.).
  data (dict[ViewDesc, NDArray[float64]]): Observed data matrices keyed by view descriptor.
  block_order (list[tuple[str, BlockDesc]]): Ordered list of (name, idx) pairs defining the update sequence within each ADMM iteration.

Methods:

  add_block: Instantiate a primal block and register it in the update order.
  add_constraint: Instantiate a constraint (dual variable) block.

Source code in src/solrcmf/base.py
@dataclass
class Context[
    BT: DataclassInstance,
    CT: DataclassInstance,
    PT: DataclassInstance,
]:
    """Shared state passed to every block and constraint update.

    Holds the primal blocks, dual constraints, algorithm parameters,
    observed data, and the ordered list of blocks to update each iteration.

    Attributes:
        blocks: Dataclass holding all primal block dicts.
        constraints: Dataclass holding all constraint (dual variable) dicts.
        params: Algorithm parameters (rho, penalties, etc.).
        data: Observed data matrices keyed by view descriptor.
        block_order: Ordered list of (name, idx) pairs defining the update
            sequence within each ADMM iteration.

    """

    blocks: BT
    constraints: CT
    params: PT
    data: dict[ViewDesc, NDArray[float64]] = field(default_factory=dict)
    block_order: list[tuple[str, BlockDesc]] = field(default_factory=list)

    def add_block(
        self,
        name: str,
        idx: BlockDesc,
        block_type: type[Block],
        shape: tuple[int, ...],
    ):
        """Instantiate a primal block and register it in the update order."""
        self.block_order.append((name, idx))
        getattr(self.blocks, name)[idx] = block_type(name, idx, shape)

    def add_constraint(
        self,
        name: str,
        idx: BlockDesc,
        constraint_type: type[Constraint],
        shape: tuple[int, ...],
    ):
        """Instantiate a constraint (dual variable) block."""
        getattr(self.constraints, name)[idx] = constraint_type(
            name, idx, shape
        )
add_block(name, idx, block_type, shape)

Instantiate a primal block and register it in the update order.

add_constraint(name, idx, constraint_type, shape)

Instantiate a constraint (dual variable) block.
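add_block does two things. It appends (name, idx) to block_order, and it stores a newly constructed block in the dict held by the blocks attribute called name. add_constraint does only the second. fit never consults block_order for constraints; after all primal blocks, it visits every dict on the constraints dataclass in field order. The dependency-free sketch below covers the block side, using hypothetical Blocks, MiniBlock and MiniContext stand-ins.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Blocks:
    """Hypothetical blocks dataclass: one dict of blocks per name."""
    x: dict[int, Any] = field(default_factory=dict)
    z: dict[int, Any] = field(default_factory=dict)


@dataclass
class MiniBlock:
    name: str
    idx: int
    shape: tuple[int, ...]


@dataclass
class MiniContext:
    blocks: Blocks
    block_order: list[tuple[str, int]] = field(default_factory=list)

    def add_block(self, name, idx, block_type, shape):
        # Same two steps as Context.add_block: record the position in the
        # update order, then store the new block in the dict named `name`.
        self.block_order.append((name, idx))
        getattr(self.blocks, name)[idx] = block_type(name, idx, shape)


ctx = MiniContext(Blocks())
ctx.add_block("x", 0, MiniBlock, (3,))
ctx.add_block("z", 0, MiniBlock, (3,))
ctx.add_block("x", 1, MiniBlock, (3,))

# fit resolves each (name, idx) entry the same way before calling update().
visited = [getattr(ctx.blocks, name)[idx] for name, idx in ctx.block_order]
print([(b.name, b.idx) for b in visited])  # [('x', 0), ('z', 0), ('x', 1)]
```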


Single-dispatch functions

The following functions drive the main update logic of the ADMM algorithm. They are defined with the functools.singledispatch decorator. Subclasses of solrcmf.base.Block register their own implementations of update and objective, and subclasses of solrcmf.base.Constraint register an implementation of constraint. The default update and constraint raise NotImplementedError. The default objective returns 0.0, so a block that contributes nothing to the objective does not need to register one.
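A minimal sketch of the registration pattern with functools.singledispatch follows. It uses hypothetical Block and ShrinkBlock stand-ins and a plain dict as the context. The subclass registers its own update and objective, and an unregistered class falls back to the defaults: NotImplementedError for update and 0.0 for objective.

```python
from dataclasses import dataclass
from functools import singledispatch


@dataclass
class Block:
    name: str


@dataclass
class ShrinkBlock(Block):
    value: float = 4.0


@singledispatch
def update(block: Block, ctx) -> None:
    raise NotImplementedError(f"update() not implemented for {type(block)}.")


@singledispatch
def objective(block: Block, ctx) -> float:
    return 0.0


@update.register
def _(block: ShrinkBlock, ctx) -> None:
    block.value *= ctx["factor"]  # toy update: scale the iterate


@objective.register
def _(block: ShrinkBlock, ctx) -> float:
    return block.value ** 2  # toy objective contribution


b = ShrinkBlock("v")
update(b, {"factor": 0.5})
print(b.value, objective(b, {}))  # 2.0 4.0
print(objective(Block("w"), {}))  # 0.0, the default
```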

solrcmf.base.update(block, ctx)

Update the block variables.

Source code in src/solrcmf/base.py
@singledispatch
def update[
    BT: DataclassInstance,
    CT: DataclassInstance,
    PT: DataclassInstance,
](block: Block, ctx: Context[BT, CT, PT]):
    """Update the block variables."""
    raise NotImplementedError(f"update() not implemented for {type(block)}.")

solrcmf.base.objective(block, ctx)

Compute the contribution to the objective.

Source code in src/solrcmf/base.py
@singledispatch
def objective[
    BT: DataclassInstance,
    CT: DataclassInstance,
    PT: DataclassInstance,
](block: Block, ctx: Context[BT, CT, PT]) -> float:
    """Compute the contribution to the objective."""
    return 0.0

solrcmf.base.constraint(block, _ctx)

Return the lhs of a constraint f(x) = 0.

Source code in src/solrcmf/base.py
@singledispatch
def constraint[
    BT: DataclassInstance,
    CT: DataclassInstance,
    PT: DataclassInstance,
](block: Constraint, _ctx: Context[BT, CT, PT]) -> NDArray[float64]:
    """Return the lhs of a constraint f(x) = 0."""
    raise NotImplementedError(
        f"constraint() not implemented for {type(block)}."
    )
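The sketch below combines these pieces. It wires hypothetical stand-ins for Block, Constraint, the blocks and constraints dataclasses, and the context into a complete loop that follows fit. The toy problem is an assumption and is not something solrcmf solves: project a vector a onto the nonnegative orthant by minimising 0.5*||x - a||^2 subject to x = z and z >= 0, with solution x = z = max(a, 0). The sketch makes two further assumptions. Blocks start at zero. The objective is the sum of objective over every block and constraint, since _objective itself is not shown on this page.

```python
from dataclasses import dataclass, field, fields
from functools import singledispatch
from math import inf

import numpy as np

@dataclass
class Block:
    name: str
    idx: int
    shape: tuple[int, ...]
    value: np.ndarray = field(init=False, repr=False)
    def __post_init__(self):
        self.value = np.zeros(self.shape)  # sketch choice: start at zero

@dataclass
class Constraint(Block):
    residual: np.ndarray = field(init=False, repr=False)

class XBlock(Block): ...          # x, carries 0.5 * ||x - a||^2
class ZBlock(Block): ...          # z, restricted to z >= 0
class Consensus(Constraint): ...  # scaled dual u for x - z = 0

@dataclass
class Blocks:
    x: dict = field(default_factory=dict)
    z: dict = field(default_factory=dict)

@dataclass
class Constraints:
    xz: dict = field(default_factory=dict)

@dataclass
class Ctx:
    blocks: Blocks
    constraints: Constraints
    a: np.ndarray
    rho: float
    block_order: list = field(default_factory=list)

@singledispatch
def update(block, ctx):
    raise NotImplementedError(f"update() not implemented for {type(block)}.")
@singledispatch
def objective(block, ctx) -> float:
    return 0.0  # default; also right for z, whose indicator is 0 on z >= 0
@singledispatch
def constraint(block, ctx):
    raise NotImplementedError(f"constraint() not implemented for {type(block)}.")

@update.register
def _(b: XBlock, ctx):
    z, u = ctx.blocks.z[0].value, ctx.constraints.xz[0].value
    b.value = (ctx.a + ctx.rho * (z - u)) / (1.0 + ctx.rho)
@update.register
def _(b: ZBlock, ctx):
    x, u = ctx.blocks.x[0].value, ctx.constraints.xz[0].value
    b.value = np.maximum(x + u, 0.0)  # projection onto z >= 0
@objective.register
def _(b: XBlock, ctx) -> float:
    return 0.5 * float(np.sum((b.value - ctx.a) ** 2))
@constraint.register
def _(c: Consensus, ctx):
    return ctx.blocks.x[0].value - ctx.blocks.z[0].value
@update.register
def _(c: Constraint, ctx):
    c.residual = constraint(c, ctx)  # cached for objective()
    c.value = c.value + c.residual   # dual ascent step
@objective.register
def _(c: Constraint, ctx) -> float:
    r, u = c.residual, c.value
    return 0.5 * ctx.rho * float(np.sum((r + u) ** 2) - np.sum(u ** 2))

def all_items(group):
    return [v for f in fields(group) for v in getattr(group, f.name).values()]

def run(a, rho=2.0, max_iter=500, rel_tol=1e-12, abs_tol=1e-12):
    ctx = Ctx(Blocks(), Constraints(), a, rho)
    for name, cls in (("x", XBlock), ("z", ZBlock)):
        ctx.block_order.append((name, 0))
        getattr(ctx.blocks, name)[0] = cls(name, 0, a.shape)
    ctx.constraints.xz[0] = Consensus("xz", 0, a.shape)
    obj_old, converged = inf, False
    for i in range(max_iter):
        for name, idx in ctx.block_order:     # primal blocks, in order
            update(getattr(ctx.blocks, name)[idx], ctx)
        for c in all_items(ctx.constraints):  # then every dual block
            update(c, ctx)
        obj = sum(objective(b, ctx) for g in (ctx.blocks, ctx.constraints) for b in all_items(g))
        if i > 0 and abs(obj_old - obj) <= max(rel_tol * abs(obj_old), abs_tol):
            converged = True
            break
        obj_old = obj
    return ctx, converged, i + 1

ctx, converged, n_iter = run(np.array([1.0, -2.0, 0.5]))
```

After the call, ctx.blocks.z[0].value should be close to [1.0, 0.0, 0.5]. Because the z block registers no objective, it relies on the default 0.0, which matches the indicator of z >= 0 since the projection keeps z feasible.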