Batching `LinearOperator`: status and benchmarks

Hello all, I come bearing plots to share my work on batching scipy.sparse.linalg.LinearOperator and its potential use in the iterative linear solvers.

Status

The linked PR is ready for review, and includes:

  • batch support for the entire interface[1]
  • tests across various kinds of operators and different kinds of batching
  • a documentation overhaul
  • housekeeping to reject >2-D operators in consuming algorithms for now
  • substantial self-review to aid reviewers
  • green CI :christmas_tree:

I plan to let the PR sit for two to three weeks before returning to make what I hope is the final push to get it merged. Any reviews will be greatly appreciated!

Due diligence on un-batched performance

Firstly, I checked the overhead introduced for the un-batched case via the pre-existing benchmarks[2]:

The dense and spsolve lines don’t use the LinearOperator interface, so they serve as baselines: any additional overhead appears to disappear down to the level of noise by the time we get to n=2^8. The overhead for smaller problems seems sufficiently modest: the highest ratio comes from a <1 ms delta, and the worst absolute point (tfqmr at n=2^7) is still just a delta of <20 ms.

This seems very acceptable to me, although I don’t know exactly where the overhead is coming from, and why it seems to affect tfqmr more than the other solvers — anyone have a clue?

Batched cg performance

Now for the fun part: showing that this is useful. I have a rough implementation of batched conjugate gradient on the go at wip: N-D `LinearOperator` and `cg` by lucascolley · Pull Request #35 · lucascolley/scipy · GitHub which builds on top of the batched LinearOperator PR. This implementation supports batched LHS as well as the array API standard, but for now let me share some plots regarding performance with a single LHS and batched RHS.

The cg implementation is batched in the sense that it processes the entire input batch together, stepping all systems through each iteration simultaneously. Converged systems have to ‘go through the motions’ until every system has passed the convergence check. The advantage comes primarily from batched matvec calls.
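To make the idea concrete, here is a minimal sketch of a batched CG loop in plain NumPy — my own illustration, not the PR's actual implementation — where every system in the batch takes each step together and the stopping check waits for the whole batch:

```python
import numpy as np

def batched_cg(matvec, b, rtol=1e-6, maxiter=1000):
    """Illustrative batched CG. `b` has shape (batch, n) and `matvec`
    maps (batch, n) -> (batch, n), applying each operator to its
    corresponding vector in one call."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = np.sum(r * r, axis=-1)                    # squared residual per system
    tol2 = (rtol * np.linalg.norm(b, axis=-1)) ** 2
    for _ in range(maxiter):
        if np.all(rs <= tol2):                     # stop only once *every* system converged
            break
        Ap = matvec(p)                             # one batched matvec per iteration
        alpha = rs / np.sum(p * Ap, axis=-1)
        x += alpha[:, None] * p
        r -= alpha[:, None] * Ap
        rs_new = np.sum(r * r, axis=-1)
        p = r + (rs_new / rs)[:, None] * p
        rs = rs_new
    return x
```

Systems that converge early keep taking (harmless) steps until the last one passes the check, which is exactly the ‘go through the motions’ behaviour described above.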

n-D sparse.coo_array batched vs un-batched

For this benchmark, we use cg to solve (for \mathbf{x}) many instances of

\mathbf{b} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \qquad \mathbf{P}\mathbf{x} = \mathbf{b}

where \mathbf{P} is Gilbert Strang’s favourite matrix of shape (n^2, n^2).

In the non-batched case, we represent \mathbf{P} as a scipy.sparse.csr_array[3], and solve against a list of batch_size vectors b, one cg call per vector in a Python for loop.

In the batched case, we represent the batch of \mathbf{P} as a 3-D scipy.sparse.coo_array of shape (batch_size, n*n, n*n), and b as a batch of vectors with shape (batch_size, n*n). We pass these to a single call of cg.

I ran the un-batched case on top of main to account for any new overhead. Here are the results[4]:

We see performance improvements of up to 200x for large batches of small systems.

Dense NumPy batched vs un-batched

I also ran a benchmark with dense NumPy arrays. This time, for a matrix \mathbf{M} of shape (n^2, n^2) with normally distributed entries, we define a linear operator A by the function:

def matvec(x):
    # x[..., np.newaxis] promotes the vector to a column so the matmuls
    # broadcast over any batch dimensions; this applies A = MᵀM + 1e-3·I
    return np.squeeze(M.mT @ (M @ x[..., np.newaxis]), axis=-1) + 1e-3 * x

We use cg to solve (for \mathbf{x}) many instances of

\mathbf{b} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \qquad \mathbf{A}\mathbf{x} = \mathbf{b}

We see similar results, with perhaps notably better performance around n=8.

Overhead check

I also checked how much overhead the rough array API standard support implementation adds for the un-batched case:

The overhead is more prominent while we are down at the millisecond level, but it again falls to a seemingly acceptable level for sufficiently large matrices, so I don’t think this needs to worry us yet.

Further work

There is a lot of work left on the table, and I plan to spend the coming months investigating much of it, including:

  • benchmarking with pathological cases in the batch
  • benchmarking of batched LHS
  • benchmarking with pydata/sparse
  • benchmarking of array API standard support and on the GPU
  • work on the other iterative solvers in sparse.linalg
  • investigating more parallelism strategies (dividing batches in sub-batches?[5])
  • investigating possible algorithmic changes, like ‘zeroing-out’ of rows corresponding to converged systems
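As a toy illustration of the ‘zeroing-out’ idea (my speculation only, not implemented anywhere): once a system’s residual passes the tolerance, its search direction could be zeroed so its iterate stops changing while the rest of the batch keeps stepping.

```python
import numpy as np

# residuals per system in the batch, shape (batch, n)
r = np.array([[1e-9, 1e-8],
              [2e-1, 1e-10]])
converged = np.linalg.norm(r, axis=-1) <= 1e-6

# zero the search directions of converged systems so their iterates
# stay fixed while the remaining systems keep iterating
p = np.ones_like(r)
p[converged] = 0.0
```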

Hopefully the results so far show that this is something worth pursuing!


It has been a fun but challenging exercise wrangling this section of the codebase, which had come to need some TLC over the years. I look forward to continuing the effort soon!

With happy holiday wishes :christmas_tree::mx_claus: :snowman: ,
Lucas

P.S. for anyone wondering how on earth I’ve found the time to work on this, I am planning to write up my Master’s dissertation based on this work.


  1. with one caveat related to the behaviour of an undocumented __rmul__ ↩︎

  2. Results generated with pixi run bench -t sparse_linalg_solve.Bench --compare base-bench-23836 on bench · lucascolley/scipy@20a8769 · GitHub and bench · lucascolley/scipy@eb46853 · GitHub. Plot generated with the script at 14/12/2025 Batched LinearOperator overhead plot code · GitHub. ↩︎

  3. This was pre-existing in the benchmarks; I assume it is expected to be the fastest? In any case, I also benchmarked this against using COO arrays in the un-batched case and found very similar results. ↩︎

  4. Results generated with pixi run bench -t sparse_linalg_solve.BatchedCG --compare base-bench-23836 on WIP: benchmarking · lucascolley/scipy@bc57a7c · GitHub and bench · lucascolley/scipy@eb46853 · GitHub. Plot generated with the script at … ↩︎

  5. Performance Portable Batched Sparse Linear Solvers | IEEE Journals & Magazine | IEEE Xplore takes that approach of dishing out mini-batches to ‘teams’, which seems to work harmoniously with the parallelism infrastructure available in Kokkos. ↩︎


Hi all, following approval from @mdhaber, I plan to merge ENH/DEP: sparse.linalg.LinearOperator: n-D batch support by lucascolley · Pull Request #23836 · scipy/scipy · GitHub soon if there is no further feedback. Any further reviews would be greatly appreciated, including just thoughts on individual elements of the change, such as the deprecation, documentation, or new tests.

Copying below the top-post of the PR:

What does this implement/fix?

Support for n-dimensional batches of linear operators in sparse.linalg.LinearOperator.

Deprecation

This PR introduces a FutureWarning for calling matvec/rmatvec on column vectors (shape (n, 1) rather than (n,)). That API does not extend cleanly to batch dimensions, but identical behaviour can be achieved (and extended to batch dimensions) via matmat/rmatmat.
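Concretely, the column-vector behaviour being deprecated is already reproducible through matmat with today’s released API:

```python
import numpy as np
from scipy.sparse.linalg import aslinearoperator

A = aslinearoperator(np.arange(9.0).reshape(3, 3))
x = np.ones((3, 1))  # a column vector

# Deprecated after the PR: A.matvec(x) on a (n, 1) input.
# Equivalent, and batch-friendly: treat the column as an (n, 1) matrix.
y = A.matmat(x)      # shape (3, 1), same values as A.matvec(x) today
```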

Future work

Immediate follow-up (me): ENH: sparse.linalg.LinearOperator: empty batch support · Issue #24562 · scipy/scipy · GitHub

What I plan to work on:

  • array API standard support (including take another look at pydata/sparse)
  • supporting these batched operators in the iterative solvers of this submodule

Potential follow-ups that might be nice (but which I don’t have plans to work on) are collected from comments there and TODOs in the code.

Let’s say I’ll merge on Wednesday if nobody shouts. Happy to wait longer if somebody wants more time to review!

Hi Lucas,

There’s a sparse WG meeting next Monday, so perhaps we should discuss it there before merge. Alternatively, just ensure that @perimosocordiae and @dschult had a proper look.

Hi Stefan, there isn’t really anything specific to sparse arrays for this PR — the linear operator interface in sparse.linalg is pretty separate and abstracts over both sparse and dense arrays.

I would however definitely appreciate thoughts from CJ and Dan regarding the follow-up to add array API standard support and use this in the iterative solvers, which can be viewed at WIP: ENH: sparse.linalg.cg: batched and array API CG solver by lucascolley · Pull Request #24450 · scipy/scipy · GitHub. Before that PR, we can use np for pretty much everything due to duck-typing, only occasionally special-casing sparse arrays. Now that we are generalising to xp, we should figure out the best way to call out to functions in the scipy.sparse and np namespaces at the right time, without needing to carry around too much special-casing.

Hi all, update: we merged the batch support, and the follow-up to add array API standard support now has core dev approval at ENH: sparse.linalg.LinearOperator: array API standard support by lucascolley · Pull Request #24627 · scipy/scipy · GitHub. I’d like to merge it in the next few days, but please do shout if anyone would like more time to review!

The PR adds a new xp parameter to LinearOperator.__new__ and a new _xp attribute to the LinearOperator class.

The one notable backwards-incompatible change is that subclasses of LinearOperator are now expected to call super().__init__ or otherwise set self._xp. This matches what we have done over in scipy.spatial:

In [1]: import scipy, numpy as np

In [2]: class MyRot(scipy.spatial.transform.Rotation):
   ...:     def __init__(self, quat):
   ...:         print("hello :)")
   ...:         self._quat = self.from_quat(quat)
   ...:

In [4]: R = MyRot(np.arange(4))
hello :)

In [5]: R.apply(np.arange(3))
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 1
----> 1 R.apply(np.arange(3))

File ~/ghq/github.com/scipy/scipy/build-install/usr/lib/python3.14/site-packages/scipy/spatial/transform/_rotation.py:1679, in Rotation.apply(self, vectors, inverse)
   1552 @xp_capabilities(
   1553     skip_backends=[
   1554         ("dask.array", "missing linalg.cross/det functions and .mT attribute"),
   (...)   1557 )
   1558 def apply(self, vectors: ArrayLike, inverse: bool = False) -> Array:
   1559     """Apply this rotation to a set of vectors.
   1560
   1561     If the original frame rotates to the final frame by this rotation, then
   (...)   1677
   1678     """
-> 1679     vectors = self._xp.asarray(
   1680         vectors, device=xp_device(self._quat), dtype=self._quat.dtype
   1681     )
   1682     single_vector = vectors.ndim == 1
   1683     # Numpy optimization: The Cython backend typing requires us to have fixed
   1684     # dimensions, so for the Numpy case we always broadcast the vector to 2D.

AttributeError: 'MyRot' object has no attribute '_xp'
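For LinearOperator, the documented subclassing pattern already calls super().__init__ to set shape and dtype, so compliant subclasses should need no changes; after the array API PR, that same call would also set the new _xp attribute. A minimal example with today’s released API:

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator

class Scale(LinearOperator):
    """Operator that multiplies every entry by a scalar."""
    def __init__(self, n, c):
        self.c = c
        # calling super().__init__ sets shape/dtype (and, after the
        # array API PR, would also set the new _xp attribute)
        super().__init__(dtype=np.float64, shape=(n, n))

    def _matvec(self, x):
        return self.c * x

A = Scale(3, 2.0)
y = A @ np.ones(3)  # array([2., 2., 2.])
```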

Cheers,
Lucas
