I introduce a new family of symmetric positive definite sparse matrices (“Nightmare matrices”) \(A\) that defeats all known preconditioners, sparse direct solvers, and Krylov iterative methods for solving \(Ax = b\). I then analyze the performance of Krylov methods for these matrices on GPUs and show how the expander structure leads to uncoalesced memory access in cusparse::csrmv, causing warp stalls and poor GPU utilization per iteration. Next I demonstrate how deflation can improve GPU utilization by adding work the hardware handles efficiently while still accelerating convergence. Finally, I outline how to compute a deflation subspace that uses the GPU effectively, and support these claims with performance counter and profiler results.
These matrices are created in three phases. First, nonzeros are distributed randomly across the matrix (each with a value of 1) to match a prescribed number of nonzeros per row. The identity matrix is then added to ensure a nonzero diagonal, and the result is symmetrized. With high probability this produces an expander graph, which breaks the fill-reducing orderings used by sparse direct solvers. In the second phase the nonzero values (currently all 1s) are replaced with numbers uniformly distributed in [-1, 1], centering the eigenvalue distribution. Finally, I set \(A \leftarrow A^T A\). This increases the number of nonzeros, shrinks the smallest absolute eigenvalue, and concentrates the spectrum at the endpoints. Because the nonzero values are sampled randomly, the eigenvalue distribution is centered near 0 and tapers off, which is a very poor (but not the absolute worst) distribution for unpreconditioned Krylov solvers. The worst case would be Chebyshev points, but it is not trivial to match an arbitrary spectrum to a fixed sparsity pattern.
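To make the three phases concrete, here is a minimal CPU sketch using SciPy (the function name and parameters are illustrative, not the exact code used for the experiments; the appendix contains a GPU generator for the related expander pattern):

```python
import numpy as np
import scipy.sparse as sp

def nightmare_matrix(n: int, nnz_per_row: int, seed: int = 0) -> sp.csr_matrix:
    """Sketch of the three-phase construction described above."""
    rng = np.random.default_rng(seed)

    # Phase 1: random pattern with values 1, add the identity, symmetrize.
    # (Duplicate samples within a row are simply merged; fine for a sketch.)
    rows = np.repeat(np.arange(n), nnz_per_row)
    cols = rng.integers(0, n, size=rows.size)
    A = sp.coo_matrix((np.ones(rows.size), (rows, cols)), shape=(n, n)).tocsr()
    A = A + sp.identity(n, format="csr")
    A = (A + A.T).tocsr()
    A.data[:] = 1.0                      # pattern only; every nonzero is 1

    # Phase 2: replace the nonzero values with uniform samples in [-1, 1].
    A.data[:] = rng.uniform(-1.0, 1.0, size=A.nnz)

    # Phase 3: A <- A^T A. The result is SPD, has more nonzeros, and its
    # spectrum piles up near 0, which is what hurts Krylov methods.
    return (A.T @ A).tocsr()
```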
I will investigate unpreconditioned Krylov methods in more detail later, but to illustrate the poor performance I ran Conjugate Gradients on a Nightmare matrix with m=100000 rows on an NVIDIA T4 GPU, and it never really converged (note that I report relative error here rather than relative residual):
it: 175593 relative error = 0.005438385018617882
real 3m52.127s
user 3m44.893s
sys 0m8.932s
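For reference, a run like this can be reproduced with cupyx.scipy.sparse.linalg.cg. The sketch below assumes the hypothetical nightmare_matrix generator sketched earlier and tracks relative error against a known solution, as in the log above:

```python
import cupy as cp
import cupyx.scipy.sparse as cpxs
from cupyx.scipy.sparse.linalg import cg

# Build on the CPU with the earlier sketch, then move to the GPU as CSR.
A = cpxs.csr_matrix(nightmare_matrix(n=100_000, nnz_per_row=30))
x_true = cp.random.default_rng(0).standard_normal(A.shape[0])
b = A @ x_true

it = 0
def report(xk):
    """Print the relative error (not the residual) every 1000 iterations."""
    global it
    it += 1
    if it % 1000 == 0:
        err = cp.linalg.norm(xk - x_true) / cp.linalg.norm(x_true)
        print(f"it: {it} relative error = {float(err)}")

x, info = cg(A, b, tol=1e-10, maxiter=200_000, callback=report)
```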
Since unpreconditioned Krylov methods fail, we might consider preconditioners or sparse direct solvers, but these approaches also fall short.
Expander graphs ensure that any two vertices in the underlying graph are connected by a short path. The graph can be quite sparse, yet in many important ways it behaves like a dense matrix. In particular, sparse direct methods produce significant fill-in regardless of the fill-reducing ordering, as illustrated below with a banded matrix for reference:
These figures show that sparse direct methods effectively require storing a fully dense version of \(A\) even though \(A\) is extremely sparse. For a matrix with 100000 rows a dense double-precision copy takes roughly 80 GB, far more than a GPU like the T4 can hold. We therefore turn to iterative methods and preconditioning, but those options mostly fail as well.
“Sparsifying” preconditioners here means approaches that do not use the full matrix—for example incomplete LU, algebraic multigrid, or additive Schwarz/Block Jacobi. For Nightmare matrices these approaches fail to approximate the underlying system in important ways. One could design expander matrices where they succeed; for instance, Block Jacobi could work if off-diagonal blocks have low rank, even when the underlying graph is an expander. Nightmare matrices are specifically crafted to avoid such properties. I omit numerical results here to keep the post manageable, though I may explore them later.
But in a nutshell, here is what ends up happening in each case: these techniques may offer some benefit, but not enough to outweigh the additional memory and computational burden they impose.
Given the failure of sparse direct solvers, we turn to iterative methods. Most preconditioners offer little help, leaving unpreconditioned Krylov methods such as Conjugate Gradients. Here we encounter another awful property of Nightmare matrices: evaluating \(Ax\) for any \(x\) resembles pure random access, the worst possible memory-access pattern on the GPU.
Each step of Conjugate Gradient involves a single evaluation of \(Ax\), and this dominates the runtime. For a Nightmare matrix that work makes poor use of the GPU. I profiled a run with m=100000 and ~60 nonzeros per row using CuPy, which in this case mostly calls into cuSPARSE. The table below from an nsys profile confirms this, with about 70% of the time spent in csrmv, the cuSPARSE sparse matrix-vector product for CSR matrices:
Time (%) | Total Time (ns) | Instances | Avg (ns) | Med (ns) | Min (ns) | Max (ns) | StdDev (ns) | Name |
---|---|---|---|---|---|---|---|---|
70.5 | 525524967 | 1003 | 523953.1 | 482678.0 | 117693 | 757360 | 101194.5 | cusparse::csrmv_v3_kernel |
5.7 | 42286307 | 6999 | 6041.8 | 6431.0 | 3519 | 9440 | 1453.2 | cupy_multiply__float_float64_float64 |
4.7 | 35381985 | 5001 | 7075.0 | 6400.0 | 5375 | 11904 | 1741.4 | cub::DeviceReduceKernel |
4.4 | 32736008 | 6001 | 5455.1 | 5184.0 | 3809 | 87518 | 1639.6 | cupy_subtract__float64_float64_float64 |
3.4 | 25017723 | 5001 | 5002.5 | 4544.0 | 3711 | 8800 | 1205.1 | cupy_multiply__float64_float64_float64 |
2.2 | 16511519 | 5001 | 3301.6 | 3008.0 | 2527 | 5600 | 799.0 | cub::DeviceReduceSingleTileKernel |
1.6 | 11606399 | 1003 | 11571.7 | 11073.0 | 8544 | 16160 | 1566.6 | cusparse::csrmv_v3_partition_kernel |
1.1 | 8207993 | 1000 | 8208.0 | 8000.0 | 6880 | 10336 | 689.7 | cupy_add__float64_float64_float64 |
1.0 | 7301007 | 3001 | 2432.9 | 2176.0 | 1919 | 3680 | 544.3 | cupy_sqrt__float64_float64 |
The problem with csrmv is that the matrix comes from an expander graph, so access is nearly random. This leads to significant uncoalesced memory traffic and warp stalls on the GPU, as shown by the performance counters below:
Metric | Throughput % | Avg | Avg Warps per Cycle |
---|---|---|---|
Compute Warps in Flight | 21.0 | 43,237,130 | 14 |
Unallocated Warps in Active SMs | 20.0 | 40,177,614 | 13 |
Vertex/Tess/Geometry Warps in Flight | 0.0 | 0 | 0 |
Pixel Warps in Flight | 0.0 | 0 | 0 |
DRAM Read Bandwidth | 21.0 | — | — |
DRAM Write Bandwidth | 2.0 | — | — |
GR Active | 47.0 | — | — |
Async Compute in Flight | 45.0 | — | — |
Sync Compute in Flight | 0.0 | — | — |
The profile shows poor DRAM bandwidth (21% utilization) and few warps in flight because uncoalesced reads cause stalls. In the next section I show how to recover GPU utilization by doing more work per iteration, but choosing work the GPU executes efficiently.
Deflated methods seek an orthonormal matrix \(V\) that projects out the components slowing Krylov convergence. The challenge is twofold: computing such a \(V\), and integrating it with a Krylov method, since each method incorporates \(V\) slightly differently. The scheme below is due to Kahl and Rittich.
For \( A x = b \), let \( V \in \mathbb{R}^{n \times k} \) have orthonormal columns and define
\[ M = V^\top A V. \]
The deflated system is
\[ (I - A V M^{-1} V^\top)A x_h = (I - A V M^{-1} V^\top)b. \]
After solving for \( x_h \), the reconstruction is
\[ x = V M^{-1} V^\top b + \bigl( x_h - V M^{-1} V^\top A x_h \bigr). \]
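Written out directly, the projection and reconstruction look like this (a minimal dense sketch in CuPy, assuming \(V\) has orthonormal columns and `krylov_solve` is any solver for the projected system; the appendix has a more careful version with preallocated buffers, a Cholesky factorization of \(M\), and MINRES as the inner solver):

```python
import cupy as cp

def deflated_solve(A, b, V, krylov_solve):
    """Solve A x = b using the deflation formulas above.

    A: (n, n) symmetric operator supporting @, V: (n, k) with orthonormal
    columns, krylov_solve(apply, rhs) -> approximate solution of apply(x) = rhs.
    """
    AV = A @ V                    # n x k
    M = V.T @ AV                  # k x k, M = V^T A V

    def project(y):
        # (I - A V M^{-1} V^T) y
        return y - AV @ cp.linalg.solve(M, V.T @ y)

    # Solve the deflated system (I - A V M^{-1} V^T) A x_h = (I - A V M^{-1} V^T) b.
    x_h = krylov_solve(lambda x: project(A @ x), project(b))

    # Reconstruct x = V M^{-1} V^T b + (x_h - V M^{-1} V^T A x_h).
    x = V @ cp.linalg.solve(M, V.T @ b) + x_h - V @ cp.linalg.solve(M, V.T @ (A @ x_h))
    return x
```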
I choose \(V\) to be the matrix of eigenvectors associated with the smallest-magnitude eigenvalues of \(A\) and postpone discussion of how to compute it until the next section. This adds work per iteration on top of evaluating \(Av\), but the extra work is productive for two reasons: (1) it consists of dense matrix-vector products that the GPU executes efficiently, and (2) deflating the smallest eigenvalues sharply reduces the number of iterations needed.
And indeed we see point (1) plainly in the trace, with aggregates below. Almost 80% of the time goes to *gemv-style level-2 BLAS, and the share spent on csrmv drops from about 70% to roughly 9%. The csrmv call itself is unchanged; each iteration now includes work that uses the GPU efficiently, and far fewer iterations are needed to reach the same error (quantified later, after the GPU efficiency discussion).
Time (%) | Total Time (ns) | Instances | Avg (ns) | Med (ns) | Min (ns) | Max (ns) | StdDev (ns) | Name |
---|---|---|---|---|---|---|---|---|
47.0 | 9334022559 | 5800 | 1609314.2 | 1609599.0 | 1601119 | 1617375 | 2385.3 | gemv2N_kernel |
30.9 | 6138250408 | 3867 | 1587341.7 | 1587103.0 | 1580799 | 1599136 | 2410.3 | gemv2T_kernel_val |
8.5 | 1683961456 | 3868 | 435357.1 | 435542.5 | 116638 | 754257 | 11635.0 | cusparse::csrmv_v3_kernel |
5.7 | 1134962046 | 1 | 1134962046 | 1134962046 | 1134962046 | 1134962046 | 0.0 | cusparse::load_balancing_kernel |
1.8 | 357572186 | 3867 | 92467.6 | 92510.0 | 74046 | 117502 | 3270.0 | trsv_ln_exec_up |
1.8 | 357424521 | 3867 | 92429.4 | 92350.0 | 74655 | 117278 | 3253.3 | trsv_ln_exec |
1.5 | 288257525 | 1 | 288257525 | 288257525 | 288257525 | 288257525 | 0.0 | volta_dgemm_64x64_tn |
0.7 | 144908215 | 17405 | 8325.7 | 2240.0 | 1248 | 11954831 | 91079.1 | cupy_copy |
0.4 | 80594963 | 15460 | 5213.1 | 4912.0 | 2976 | 87775 | 1761.2 | cupy_subtract |
Our utilization metrics improve markedly. DRAM bandwidth nearly triples, and many more warps are active because memory access patterns become largely coalesced.
Metric | Throughput % | Avg | Avg Warps per Cycle |
---|---|---|---|
Compute Warps in Flight | 71.0 | 142,502,227 | 46 |
Unallocated Warps in Active SMs | 18.0 | 36,210,158 | 12 |
Vertex/Tess/Geometry Warps in Flight | 0.0 | 0 | 0 |
Pixel Warps in Flight | 0.0 | 0 | 0 |
DRAM Read Bandwidth | 58.0 | — | — |
DRAM Write Bandwidth | 1.0 | — | — |
GR Active | 91.0 | — | — |
Async Compute in Flight | 91.0 | — | — |
Sync Compute in Flight | 0.0 | — | — |
And just for ease of comparison, I include both runs, with and without deflation, below:
Metric | Run 1 (Deflated CG, large deflation space) | Run 2 (CG on expander matrix, no deflation) |
---|---|---|
Compute Warps in Flight | 71% (142M, 46 warps/cycle) | 21% (43M, 14 warps/cycle) |
Unallocated Warps in Active SMs | 18% (36M, 12 warps/cycle) | 20% (40M, 13 warps/cycle) |
Vertex/Tess/Geometry Warps in Flight | 0% | 0% |
Pixel Warps in Flight | 0% | 0% |
DRAM Read Bandwidth | 58% | 21% |
DRAM Write Bandwidth | 1% | 2% |
GR Active | 91% | 47% |
Async Compute in Flight | 91% | 45% |
Sync Compute in Flight | 0% | 0% |
And finally, the payoff we actually care about: much lower error in significantly less wall time:
Run | Iterations | Final Error | Runtime (s) | Iterations / s | log10(Error) / Iteration |
---|---|---|---|---|---|
k = 0 (no deflation) | 175,593 | 5.44e-03 | 232.1 | ~757 | -1.28e-04 |
k = 1024 (deflation) | 1,931 | 3.59e-06 | 28.2 | ~68.5 | -1.70e-03 |
The results from deflated conjugate gradient are impressive, but of course they require computing a suitable \( V \). In a way this takes us back to the beginning: we now need an eigensolver that can produce the smallest-magnitude eigenpairs of \( A \), which is a Nightmare matrix, so most solution techniques are again off the table. The gold standard for Krylov eigensolvers is the Arnoldi or Lanczos iteration, but to resolve small-magnitude eigenvalues adequately these methods must proceed sequentially, one vector at a time, because they maintain orthogonality against previous iterates. I have already established that evaluating \(Ax\) for a single vector uses the GPU poorly. The solution is a block method that advances multiple eigenvector estimates simultaneously. This brings us back to coalesced access: evaluating \(AX\) for a block \(X\), especially when \(X\) is stored row-major, lets each nonzero of \(A\) load a contiguous row of \(X\), so each memory transaction brings in a significant amount of useful data.
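The effect is easy to measure: compare \(k\) separate sparse matrix-vector products against one sparse matrix-matrix product over the same block of vectors. The sketch below uses cupyx.profiler.benchmark and assumes A is a CSR Nightmare matrix already on the GPU; exact timings depend on the GPU and the CuPy/cuSPARSE versions.

```python
import cupy as cp
from cupyx.profiler import benchmark

n, k = A.shape[0], 64
vecs = [cp.random.standard_normal(n) for _ in range(k)]  # k separate vectors
X = cp.stack(vecs, axis=1)                               # the same data as an n x k block

def spmv_loop():
    # k calls to csrmv: nearly random access, one useful value per gather
    for v in vecs:
        A @ v

def spmm():
    # one spmm call: each gathered column index pulls a full row of X
    A @ X

print(benchmark(spmv_loop, n_repeat=10))
print(benchmark(spmm, n_repeat=10))
```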
To accomplish this I make use of the fact that a symmetric matrix can be diagonalized with orthonormal eigenvectors \(W\):
\[ A = W \Lambda W^T \]
and this decomposition shows that any polynomial \( P \) satisfies
\[ P(A) = W P(\Lambda) W^T \]
This means that if we choose \( P(x) \approx \frac{1}{x} \), then \( P(A) \) magnifies the small eigenvalues. This is essentially a restatement of the principle behind Krylov solvers, but given bounds on the spectrum we can also construct such a polynomial statically. That is the idea behind Chebyshev iteration.
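Before getting to the iteration itself, here is a quick numerical check of the \(P(A) = W P(\Lambda) W^T\) identity on a small random symmetric matrix (a throwaway sketch; any polynomial works):

```python
import cupy as cp

rng = cp.random.default_rng(0)
B = rng.standard_normal((50, 50))
A_small = B + B.T                          # small symmetric test matrix
lam, W = cp.linalg.eigh(A_small)           # A_small = W diag(lam) W^T

# P(x) = 3x^2 - x + 1, applied to the matrix and to the eigenvalues.
lhs = 3 * (A_small @ A_small) - A_small + cp.eye(50)
rhs = W @ cp.diag(3 * lam**2 - lam + 1) @ W.T
print(float(cp.abs(lhs - rhs).max()))      # round-off sized difference
```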
For an SPD \(A\) with \(\operatorname{spec}(A)\subset[\lambda_{\min},\lambda_{\max}]\), set \[ \mu = \tfrac{\lambda_{\max}+\lambda_{\min}}{2}, \qquad \nu = \tfrac{\lambda_{\max}-\lambda_{\min}}{2}. \]
Initialize \(x^{(0)}\) and \(r^{(0)} = b - A x^{(0)}\). Let \(\alpha_0 = \tfrac{1}{\mu},\; \beta_0 = 0\) and form \[ x^{(1)} = x^{(0)} + \alpha_0 r^{(0)}. \]
For \(k \ge 1\), \[ \beta_k = \left(\frac{\nu \,\alpha_{k-1}}{2}\right)^2, \qquad \alpha_k = \frac{1}{\mu - \dfrac{\beta_k}{\alpha_{k-1}}}, \] \[ x^{(k+1)} = x^{(k)} + \alpha_k r^{(k)} + \beta_k \big(x^{(k)} - x^{(k-1)}\big), \qquad r^{(k)} = b - A x^{(k)}. \]
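Here is a direct transcription of that recurrence for a single right-hand side (a sketch; the block version in the appendix uses the equivalent formulation of Saad's Algorithm 12.1):

```python
import cupy as cp

def chebyshev_iteration(A, b, lam_min, lam_max, maxiter):
    """Chebyshev iteration in the three-term form written above."""
    mu = 0.5 * (lam_max + lam_min)       # center of the spectral interval
    nu = 0.5 * (lam_max - lam_min)       # half-width of the spectral interval

    x_prev = cp.zeros_like(b)            # x^{(0)}
    r = b - A @ x_prev                   # r^{(0)}
    alpha = 1.0 / mu                     # alpha_0 (beta_0 = 0)
    x = x_prev + alpha * r               # x^{(1)}

    for _ in range(1, maxiter):
        r = b - A @ x                    # r^{(k)}
        beta = (nu * alpha / 2.0) ** 2   # beta_k from alpha_{k-1}
        alpha = 1.0 / (mu - beta / alpha)
        x, x_prev = x + alpha * r + beta * (x - x_prev), x
    return x
```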
The advantage of Chebyshev iteration in this context, compared with another Krylov method, is that advancing multiple vectors simultaneously is simple: the polynomial coefficients are exactly the same for every vector. Conjugate Gradient, by contrast, computes its coefficients from dot products of previous iterates, which would give each vector different coefficients.
Here is the pseudocode for computing \( V \) with all of the above observations in mind:
# Chebyshev–Rayleigh–Ritz for k smallest eigenpairs
# Assume a routine:
# Y = ChebyshevSolve(A, RHS, inner, λ_min, λ_max)
# which does 'inner' Chebyshev iterations to approximately solve A·Y = RHS.
Inputs:
A # symmetric SPD n×n operator (matvec/matmat supported)
k # number of eigenpairs
outer # outer RR iterations
inner # inner Chebyshev iterations
λ_min, λ_max # spectral bounds
Procedure:
V ← RandomOrthonormal(n, k) # e.g., QR of random matrix
Θ ← undefined
for t = 1..outer:
# 1) Block Chebyshev solve: approximate A^{-1}·V
W ← ChebyshevSolve(A, V, inner, λ_min, λ_max) # n×k
# 2) Orthonormalize subspace
Q, _ ← qr(W) # n×k with QᵀQ = I
# 3) Rayleigh–Ritz on span(Q)
T ← Qᵀ A Q # k×k
S, diag(Θ) ← eig(T) # Θ ascending
V ← Q S # n×k (orthonormal)
# (optional) stop if max_j ||A v_j − Θ_j v_j|| ≤ tol
end for
Outputs:
V # Ritz vectors (≈ k smallest eigenvectors)
Θ # Ritz values (≈ k smallest eigenvalues)
I also provide code in the appendices.
From the trace we can see significant time in cusparse::spmm (note the asterisk in the table: I renamed this from the cuSPARSE symbol, which contained a lot of template magic), followed by calls to DGEMM, which is level-3 BLAS, absolutely the best place to be on a GPU.
Time (%) | Total Time (ns) | Instances | Avg (ns) | Med (ns) | Min (ns) | Max (ns) | StdDev (ns) | Name |
---|---|---|---|---|---|---|---|---|
76.0 | 18422615475 | 8 | 2302826934.4 | 2287765279.5 | 2268521089 | 2368998982 | 37822090.5 | cusparse::spmm( * ) |
6.2 | 1508488502 | 102 | 14789103.0 | 3460175.0 | 48991 | 213567593 | 31727229.9 | volta_dgemm_64x64_tn |
4.7 | 1137174777 | 1 | 1137174777.0 | 1137174777.0 | 1137174777 | 1137174777 | 0.0 | volta_dgemm_128x64_tn |
4.0 | 966627174 | 44 | 21968799.4 | 6511025.5 | 1578352 | 164595260 | 34956297.7 | volta_dgemm_128x64_nn |
3.6 | 882372404 | 1 | 882372404.0 | 882372404.0 | 882372404 | 882372404 | 0.0 | volta_dgemm_128x64_tt |
2.1 | 516596067 | 28 | 18449859.5 | 14831409.5 | 14518870 | 39973379 | 8741168.7 | geqr2_gmem_domino |
0.5 | 110681488 | 1 | 110681488.0 | 110681488.0 | 110681488 | 110681488 | 0.0 | cupy_power__float64_float_float64 |
0.4 | 108940450 | 1 | 108940450.0 | 108940450.0 | 108940450 | 108940450 | 0.0 | gen_sequenced |
0.4 | 104990105 | 13 | 8076161.9 | 12932875.0 | 2720 | 16496130 | 7829506.4 | cupy_subtract__float64_float64_float64 |
Here we see extremely high GPU utilization: a large number of warps active per cycle, fully loaded GR and SMs, and high memory bandwidth.
Metric | Throughput % | Avg | Avg Warps per Cycle |
---|---|---|---|
GR Active | 100.0 | — | — |
Async Compute in Flight | 100.0 | — | — |
Sync Compute in Flight | 0.0 | — | — |
SM Active | 100.0 | — | — |
Compute Warps in Flight | 95.0 | 169,023,581 | 61 |
Unallocated Warps in Active SMs | 5.0 | 9,502,742 | 3 |
DRAM Read Bandwidth | 67.0 | — | — |
DRAM Write Bandwidth | 1.0 | — | — |
For the deflated case with m=100000 I was able to compute 512 highly accurate eigenvectors (relative residuals ~1e-12); the accuracy of the vectors beyond that decayed, as is typical with block methods.
I used a GCP instance for these measurements:
Resource | Specification |
---|---|
Machine type | n1-standard-8 (8 vCPUs, 30 GB RAM) |
CPU platform | Intel Haswell |
GPUs | 1 × NVIDIA T4 |
# Received help from GPT-5
# coding: utf-8
from __future__ import annotations
import cupy as cp
from typing import Callable, Optional
from cupyx.scipy.sparse.linalg import LinearOperator, minres
from cupyx.scipy.linalg import solve_triangular
import numpy as np
import cupyx.scipy.sparse as cpxs
def random_expander_like_matrix(n: int, nnz_per_row: int, seed: int | None = None, dtype=cp.float32):
"""
Create an n x n sparse CuPy matrix A with:
- nnz_per_row randomly sampled off-diagonal entries per row (no replacement)
- symmetrized as (A + A.T)
- all nonzero values set to -1.0
Notes
-----
- Random k-regular-ish adjacency from uniform sampling tends to give
expander-like graphs with high probability (not guaranteed).
- After symmetrization, degrees will be between k and 2k in general.
Parameters
----------
n : int
Matrix size.
nnz_per_row : int
Number of (directed) nonzeros per row before symmetrization.
seed : int | None
Seed for reproducibility (host-side sampling).
dtype : cupy dtype
Data type for the values (float32/float64 typically).
Returns
-------
cupyx.scipy.sparse.csr_matrix
Symmetric sparse matrix with values -1.0.
"""
if not (0 <= nnz_per_row < n):
raise ValueError("Require 0 <= nnz_per_row < n (must exclude diagonal).")
rng = np.random.default_rng(seed)
# Host-side index generation (fast and simple):
# Sample from 0..n-2, then shift indices >= row by +1 to skip the diagonal.
rows = np.repeat(np.arange(n), nnz_per_row)
# For each row i, choose nnz_per_row unique columns from {0..n-1} \ {i}
# Use vectorized sampling per row
samples = np.empty((n, nnz_per_row), dtype=np.int64)
base = np.arange(n - 1) # choices without the diagonal
for i in range(n):
cols = rng.choice(base, size=nnz_per_row, replace=False)
# Map chosen index j in [0..n-2] to actual column:
# If j >= i, shift by +1 to skip the diagonal i
cols = cols + (cols >= i)
samples[i, :] = cols
cols = samples.ravel()
# Move to GPU
rows_gpu = cp.asarray(rows, dtype=cp.int64)
cols_gpu = cp.asarray(cols, dtype=cp.int64)
# Initial COO with data -1.0
data_gpu = cp.full(rows_gpu.shape, -1.0, dtype=dtype)
A = cpxs.coo_matrix((data_gpu, (rows_gpu, cols_gpu)), shape=(n, n))
# Canonicalize duplicates in A (shouldn't exist per row, but just in case)
A = A.tocsr()
A.sum_duplicates()
# Symmetrize: B = A + A.T
B = A + A.T
# Remove any duplicates after sum and set all values to -1.0
B.sum_duplicates()
B.data[:] = cp.array(-1.0, dtype=dtype)
# Ensure strictly zero diagonal (should already be zero, but enforce)
diag = B.diagonal()
if cp.any(diag != 0):
# Zero diagonal by subtracting its diagonal
B = B - cpxs.diags(diag, offsets=0, shape=B.shape, dtype=B.dtype)
return B
def block_chebyshev(A: LinearOperator, B: cp.ndarray,
eig_min: float, eig_max: float, maxiter: int) -> cp.ndarray:
"""Block Chebyshev iteration (SPD A, eigenvalues in [eig_min, eig_max])."""
# Spectral interval parameters
theta = 0.5 * (eig_max + eig_min) # center
delta = 0.5 * (eig_max - eig_min) # half-width
sigma1 = theta / delta # > 1 for SPD
X = cp.zeros_like(B)
R = B - A.matmat(X)
# Saad Alg. 12.1
rho = 1.0 / sigma1
D = (1.0 / theta) * R
for _ in range(maxiter):
X = X + D
R = R - A.matmat(D)
rho_next = 1.0 / (2.0 * sigma1 - rho)
D = rho_next * rho * D + (2.0 * rho_next / delta) * R
rho = rho_next
return X
def power_cheb(A: LinearOperator, V: cp.ndarray,
eig_min: float, eig_max: float,
outer_iter: int = 10, inner_iter: int = 10):
"""
Block power+Chebyshev accelerator with Rayleigh–Ritz at each outer step.
Returns (w, V) where w are Ritz values (cp.ndarray shape (k,)) and V the Ritz vectors (n x k).
"""
V = cp.asarray(V)
for _ in range(outer_iter):
n, k = V.shape
# Chebyshev smoothing
V = block_chebyshev(A, V, eig_min, eig_max, maxiter=inner_iter)
# Orthonormalize (economy QR)
Q, _ = cp.linalg.qr(V, mode='reduced')
V = Q
# Rayleigh–Ritz
VAV = V.T @ A.matmat(V) # (k x k), symmetric (Hermitian if complex)
w, W = cp.linalg.eigh(VAV) # ascending order
V = V @ W # Ritz vectors in the big space
# Residuals: ||A V - V diag(w)||_2 / w (vectorized per column)
AV = A.matmat(V) # (n x k)
R = AV - V * w[cp.newaxis, :] # broadcast subtract
norms = cp.sqrt(cp.sum(cp.abs(R)**2, axis=0))
residuals = norms / cp.abs(w)
print(float(residuals.min().item()), float(residuals.max().item()))
return w, V
class DeflatedMinres:
"""
CuPy deflated MINRES with minimal copying.
Assumptions:
- A is symmetric (A.T == A).
- V has orthonormal columns (on GPU).
- A supports @ with (n,) and (n,k) arrays (dense or LinearOperator with matvec/matmat).
"""
def __init__(self, A: cp.ndarray | LinearOperator, V: cp.ndarray) -> None:
self.A = A
# Prefer Fortran order for tall skinny (better GEMM behavior with cuBLAS)
self.V = cp.asfortranarray(V) # (n,k)
n, k = self.V.shape
self.n, self.k = n, k
# Precompute AV = A @ V once (n x k), Fortran-order to match V
AV = self._A_matmat(self.V)
self.AV = cp.asfortranarray(AV)
# VAV = V^T @ (A @ V) (k x k), symmetric SPD if A is SPD on range(V)
VAV = self.V.T @ self.AV
# Cholesky factorization: VAV = L L^T
self.L = cp.linalg.cholesky(VAV) # lower-triangular
# ---- Work buffers to avoid temporaries ----
self._buf_Ax = cp.empty((n,), dtype=self.V.dtype, order='C') # Ax
self._buf_t = cp.empty((k,), dtype=self.V.dtype, order='C') # t = V^T Ax
self._buf_z = cp.empty((k,), dtype=self.V.dtype, order='C') # z = (VAV)^{-1} t
self._buf_y = cp.empty((n,), dtype=self.V.dtype, order='C') # y = Ax - AV z
# For reconstruct/b setup:
self._buf_tb = cp.empty((k,), dtype=self.V.dtype, order='C') # V^T b
self._buf_yb = cp.empty((k,), dtype=self.V.dtype, order='C') # y_b
self._buf_xl = cp.empty((n,), dtype=self.V.dtype, order='C') # xl = V y_b
# For reconstruct(xh):
self._buf_Axh = cp.empty((n,), dtype=self.V.dtype, order='C') # A xh
self._buf_txh = cp.empty((k,), dtype=self.V.dtype, order='C') # V^T A xh
self._buf_zxh = cp.empty((k,), dtype=self.V.dtype, order='C') # z_xh
self._buf_Vz = cp.empty((n,), dtype=self.V.dtype, order='C') # V z (for pi(xh))
# ---- Utility: apply A to vector or matrix without allocating new buffers ----
def _A_matvec(self, x: cp.ndarray, out: cp.ndarray | None = None) -> cp.ndarray:
# A may be ndarray or LinearOperator
y = self.A @ x
if out is None:
return y
cp.copyto(out, y)
return out
def _A_matmat(self, X: cp.ndarray) -> cp.ndarray:
# Returns a fresh array (matmul allocates), but we do this only when needed.
return self.A @ X
def _apply_deflated(self, x: cp.ndarray) -> cp.ndarray:
# 1) Ax -> _buf_Ax
self._A_matvec(x, out=self._buf_Ax)
# 2) t = V^T Ax -> _buf_t (gemv with out= to avoid a tmp)
cp.dot(self.V.T, self._buf_Ax, out=self._buf_t)
# 3) z = (VAV)^{-1} t via Cholesky (two triangular solves, no tmp kept)
y = solve_triangular(self.L, self._buf_t, lower=True, overwrite_b=False, check_finite=False)
z = solve_triangular(self.L.T, y, lower=False, overwrite_b=False, check_finite=False)
cp.copyto(self._buf_z, z)
# 4) y = Ax - AV @ z (RETURN A FRESH VECTOR — no shared buffer!)
y_out = self._buf_Ax.copy() # 1 new vector per matvec
y_out -= self.AV @ self._buf_z # gemv
return y_out
# rhs = (I - P*) b = b - A^T V (VAV)^{-1} V^T b; with A symmetric => A^T V = AV
def _deflated_rhs(self, b: cp.ndarray) -> cp.ndarray:
# tb = V^T b
self._buf_tb = self.V.T @ b
# y_b = (VAV)^{-1} tb
y = solve_triangular(self.L, self._buf_tb, lower=True, overwrite_b=False, check_finite=False)
y = solve_triangular(self.L.T, y, lower=False, overwrite_b=False, check_finite=False)
cp.copyto(self._buf_yb, y)
# rhs = b - AV @ y_b (allocates new vector we return)
rhs = b - (self.AV @ self._buf_yb)
return rhs
# reconstruct x from xh: xl = V y_b, xr = xh - V z_xh with z_xh = (VAV)^{-1} V^T A xh
def _reconstruct(self, b: cp.ndarray, xh: cp.ndarray) -> cp.ndarray:
# xl
cp.copyto(self._buf_xl, self.V @ self._buf_yb) # reuse y_b from _deflated_rhs
# z_xh solve:
self._A_matvec(xh, out=self._buf_Axh) # A xh
self._buf_txh = self.V.T @ self._buf_Axh # V^T A xh
y = solve_triangular(self.L, self._buf_txh, lower=True, overwrite_b=False, check_finite=False)
y = solve_triangular(self.L.T, y, lower=False, overwrite_b=False, check_finite=False)
cp.copyto(self._buf_zxh, y)
# xr = xh - V z_xh (store in-place into _buf_Vz then combine)
cp.copyto(self._buf_Vz, self.V @ self._buf_zxh)
x = xh - self._buf_Vz
x += self._buf_xl
return x
def solve(
self,
b: cp.ndarray,
*,
tol: float = 1e-12,
maxiter: Optional[int] = None,
callback: Optional[Callable[[cp.ndarray], None]] = None,
) -> tuple[cp.ndarray, int]:
n = b.shape[0]
dtype = getattr(self.A, "dtype", None) or b.dtype
def mv(x: cp.ndarray) -> cp.ndarray:
# Return a NEW array each call. Do NOT reuse a class buffer here.
return self._apply_deflated(x)
op = LinearOperator((n, n), matvec=mv, dtype=dtype)
# Deflated RHS (uses AV, L; minimal temporaries)
rhs = self._deflated_rhs(b)
def cb(xh: cp.ndarray) -> None:
if callback is not None:
callback(self._reconstruct(b, xh))
xh, info = minres(op, rhs, tol=tol, maxiter=maxiter, callback=cb)
x = self._reconstruct(b, xh)
return x, info