Lessons Learned Implementing Sparse QR Factorization in Rust

November 23, 2021

I want to cover a little of the work I did in Rust, in particular how to “make the borrow checker happy” in the presence of many references to a common object (a tree), some of which may be mutable. Earlier this year I wrote this simple sparse QR factorization. I cover the mathematical basis for the algorithm in this YouTube video, so I won’t go into great mathematical detail here, but I will briefly outline the algorithm, how to build up the resulting tree in Rust, how to iterate up and down the tree, and a few issues I had using Rust to perform the resulting computations.

How nested dissection builds up a tree, and the basic QR factorization algorithm

Nested dissection of a graph builds up a tree of graph vertices in which each node of the tree contains graph vertices that form a separator: removing them disconnects the graph vertices of that node's two subtrees (its descendants) from each other. I visualize nested dissection for a simple 2D Laplacian sparse matrix and its associated graph below:

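[Figure: nested dissection of a 2D Laplacian at dissection levels 0 (base case) through 4, showing for each level the tree diagram and the corresponding matrix block structure.]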
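
To make the tree-building idea concrete, here is a minimal sketch of nested dissection on a 1D chain graph (vertex i connected to i+1), a simpler stand-in for the 2D Laplacian in the figure. The function name dissect, the min_size cutoff, and the convention that the root is level 0 are all assumptions for illustration; in 2D the separators would be grid lines rather than single vertices.

```rust
/// Hypothetical sketch: nested dissection of the chain graph on vertices lo..hi
/// (edges i -- i+1). Each tree node is recorded as (level, vertices) in preorder;
/// level 0 is assumed to be the root.
fn dissect(lo: usize, hi: usize, level: usize, min_size: usize, out: &mut Vec<(usize, Vec<usize>)>) {
    // A cutoff of at least 2 guarantees that no empty leaf is ever produced.
    assert!(min_size >= 2, "sketch assumes min_size >= 2");
    if hi - lo <= min_size {
        // Base case: the piece is small enough to keep as a leaf.
        out.push((level, (lo..hi).collect()));
        return;
    }
    // The middle vertex is a separator: removing it disconnects the two halves.
    let mid = lo + (hi - lo) / 2;
    out.push((level, vec![mid]));
    // Recurse into the two disconnected pieces, one level further down.
    dissect(lo, mid, level + 1, min_size, out);
    dissect(mid + 1, hi, level + 1, min_size, out);
}

fn main() {
    let mut tree = Vec::new();
    dissect(0, 7, 0, 2, &mut tree);
    for (level, vertices) in &tree {
        println!("level {}: {:?}", level, vertices);
    }
}
```

For seven vertices this yields separator 3 at the root, separators 1 and 5 at level 1, and the single-vertex leaves 0, 2, 4, 6 at level 2.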

A sparse QR factorization may then proceed in the following steps:

  1. Start at the lowest level of the tree
  2. QR factorize all nodes of the current level
  3. Apply the Householder reflectors (Q factors) of each factorized node to its ancestors
  4. Move up one level of the tree, and repeat from step 2 until the root node is reached
  5. QR factorize the root node; the full factorization is now complete
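
The steps above can be sketched as a level-scheduled loop. This is a hypothetical sketch, not the post's implementation: instead of doing any numerics it records the order of operations as strings, and it assumes levels[0] holds the root while the last entry of levels holds the leaves.

```rust
/// Hypothetical sketch of the level-scheduled sweep. It only records which
/// operation happens when, so the schedule itself can be inspected.
fn qr_schedule(levels: &[Vec<usize>], parents: &[Option<usize>]) -> Vec<String> {
    let mut ops = Vec::new();
    // Steps 1 and 4: start at the lowest level and move up one level at a time.
    for level in levels.iter().rev() {
        for &node in level {
            // Step 2 (step 5 when `node` is the root): factorize this node.
            ops.push(format!("factor {}", node));
            // Step 3: apply this node's reflectors to every ancestor.
            let mut mp = parents[node];
            while let Some(p) = mp {
                ops.push(format!("apply {} -> {}", node, p));
                mp = parents[p];
            }
        }
    }
    ops
}

fn main() {
    // Three tree nodes: separator 0 at the root, leaves 1 and 2.
    let levels = vec![vec![0], vec![1, 2]];
    let parents = vec![None, Some(0), Some(0)];
    for op in qr_schedule(&levels, &parents) {
        println!("{}", op);
    }
}
```

All nodes within one level are independent of each other, which is what makes this ordering a natural candidate for parallelism later.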

Encoding the nested dissection tree in Rust (what does and does not work)

Nested dissection produces a binary tree in which each interior (non-leaf) node contains a graph separator that separates the graph vertices of its two child subtrees. If I were implementing this in a functional programming language like OCaml, a very convenient representation would be a recursive datatype. The official Rust blog also suggests this (admittedly a while ago); a simple definition is given below:

enum BinaryTree {
    Leaf(i32),
    Node(Box<BinaryTree>, i32, Box<BinaryTree>)
}

This, however, did not work for me. This representation works well when you can construct the entire tree in a single expression. When the tree has to be built incrementally, however, the Node(left, value, right) variant of the enum above led to borrow checker trouble, because adding children to an existing node is a mutating operation.

What I ended up doing was flattening the tree so that its data is stored contiguously in a Vec. I complemented that data with index-based “pointers” in two arrays, parents and children, and also stored the tree levels in the same structure for level scheduling. The Rust datatype looks like this:

pub struct DissectionTree{
    pub parents  : Vec<Option<usize>>,         //parents[i]: parent of tree node i (None for the root), for iterating up ancestors
    pub children : Vec<Option<(usize,usize)>>, //children[i]: the two children of tree node i (None for a leaf), for iterating down descendants
    pub levels   : Vec<Vec<usize>>,            //levels[l]: tree nodes at level l, for level scheduling
    pub nodes    : Vec<Vec<usize>>             //nodes[i]: sparse matrix nodes (graph vertices) associated with tree node i
}
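
One way to see why the flattened layout avoids the borrow checker trouble is to build it incrementally. In this sketch (the new and add_children helpers are hypothetical, not from the post) a node is "pointed to" only by its index, so growing the vectors never conflicts with an outstanding reference into the tree.

```rust
/// Same fields as the post's `DissectionTree`.
pub struct DissectionTree {
    pub parents: Vec<Option<usize>>,
    pub children: Vec<Option<(usize, usize)>>,
    pub levels: Vec<Vec<usize>>,
    pub nodes: Vec<Vec<usize>>,
}

impl DissectionTree {
    /// Hypothetical constructor: a tree holding only the root (tree node 0) at level 0.
    pub fn new(root_vertices: Vec<usize>) -> Self {
        DissectionTree {
            parents: vec![None],
            children: vec![None],
            levels: vec![vec![0]],
            nodes: vec![root_vertices],
        }
    }

    /// Hypothetical helper: attach two new children to tree node `parent`.
    /// Only indices are stored, so no reference into the tree is held while it grows.
    pub fn add_children(&mut self, parent: usize, left: Vec<usize>, right: Vec<usize>) -> (usize, usize) {
        // Children sit one level below their parent.
        let level = self
            .levels
            .iter()
            .position(|l| l.contains(&parent))
            .expect("parent must already be in the tree")
            + 1;
        let (c1, c2) = (self.nodes.len(), self.nodes.len() + 1);
        for vertices in [left, right] {
            self.nodes.push(vertices);
            self.parents.push(Some(parent));
            self.children.push(None); // new tree nodes start out as leaves
        }
        self.children[parent] = Some((c1, c2));
        if self.levels.len() == level {
            self.levels.push(Vec::new());
        }
        self.levels[level].extend([c1, c2]);
        (c1, c2)
    }
}

fn main() {
    // Chain 0 - 1 - 2 - 3 - 4: vertex 2 separates {0, 1} from {3, 4}.
    let mut tree = DissectionTree::new(vec![2]);
    let (left, right) = tree.add_children(0, vec![0, 1], vec![3, 4]);
    println!("children of root: {} {}", left, right);
    println!("levels: {:?}", tree.levels);
}
```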

Pattern matching to iterate down descendants

To iterate down the descendants of a given node, I maintain a stack of nodes to visit: I pop a node off the stack when I am done with it and push its children onto the stack, until only leaves remain and the stack empties. The line while let Some(n)=stack.pop() is a particularly elegant use of pattern matching that keeps very clean what would be a little messier in C++ for a similar data structure. This single line does the following things:

  1. Checks whether the stack is empty
  2. If it is not, binds the popped value to n
  3. Drives the loop until the stack is empty

The next line, if let Some((c1,c2)) = self.children[n], is another elegant use of pattern matching, which does three things:

  1. Checks whether n is a leaf node
  2. If it is not a leaf node, binds its two children to the variables c1 and c2
  3. Drives the if condition

//Iterate down all descendants of `node`
let mut stack = vec![node];
while let Some(n) = stack.pop() {
    if let Some((c1, c2)) = self.children[n] {
        //Do something with children c1, c2
        stack.push(c1);
        stack.push(c2);
    }
}
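
Here is a self-contained version of the loop above that you can run. It is a sketch under assumptions: the Tree struct keeps only the children field of DissectionTree, and the descendants method (a hypothetical name) simply collects every child it encounters.

```rust
/// Only the `children` field of the post's `DissectionTree` is needed here.
struct Tree {
    children: Vec<Option<(usize, usize)>>,
}

impl Tree {
    /// Collect every descendant of `node` (excluding `node` itself)
    /// using the stack-based loop described above.
    fn descendants(&self, node: usize) -> Vec<usize> {
        let mut out = Vec::new();
        let mut stack = vec![node];
        // Runs until the stack is empty; `n` is the popped tree node.
        while let Some(n) = stack.pop() {
            // Leaves hold `None` and are skipped; interior nodes expose their children.
            if let Some((c1, c2)) = self.children[n] {
                out.push(c1);
                out.push(c2);
                stack.push(c1);
                stack.push(c2);
            }
        }
        out
    }
}

fn main() {
    // Root 0 with children 1 and 2; their children 3, 4, 5, 6 are leaves.
    let tree = Tree {
        children: vec![Some((1, 2)), Some((3, 4)), Some((5, 6)), None, None, None, None],
    };
    println!("{:?}", tree.descendants(0));
}
```

Because c2 is pushed last, the right subtree is explored first; the order does not matter when each descendant just needs to be visited once.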

Enabling iteration up parents

Descendant information alone does not meet the needs of the sparse QR algorithm: step 3 applies each factorized node's Householder reflectors to its ancestors, so you also need the ability to iterate up the ancestors (parent, grandparent, and so on) of a given node.

Pattern matching to iterate up parents

The parents array enables an elegant way to loop over all ancestors of a specified node. In the code below, the pattern while let Some(p) = mp checks whether mp holds a parent. The root's entry in parents is None, so the loop visits every ancestor up to and including the root node and then stops.

//Iterate up all ancestors of `node`
let mut mp=dtree.parents[node];
while let Some(p) = mp{
  //do stuff with ancestor `p` of `node`
  mp=dtree.parents[p];
}
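
A runnable version of this loop is below, written as a free function over the parents array (the names ancestors and ancestors_iter are illustrative, not from the post). The second function shows an equivalent formulation with the standard library's std::iter::successors, which expresses "start at the parent, then keep following parents" as a lazy iterator.

```rust
use std::iter::successors;

/// All ancestors of `node`, nearest first, using the while-let loop from the post.
fn ancestors(parents: &[Option<usize>], node: usize) -> Vec<usize> {
    let mut out = Vec::new();
    let mut mp = parents[node];
    // `mp` becomes `None` only after stepping past the root.
    while let Some(p) = mp {
        out.push(p);
        mp = parents[p];
    }
    out
}

/// The same walk as a lazy iterator: start at the parent, then keep following parents.
fn ancestors_iter(parents: &[Option<usize>], node: usize) -> impl Iterator<Item = usize> + '_ {
    successors(parents[node], move |&p| parents[p])
}

fn main() {
    // Root 0; children 1, 2; grandchildren 3, 4 (under 1) and 5, 6 (under 2).
    let parents = vec![None, Some(0), Some(0), Some(1), Some(1), Some(2), Some(2)];
    println!("{:?}", ancestors(&parents, 4));
    println!("{:?}", ancestors_iter(&parents, 6).collect::<Vec<_>>());
}
```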

Slicing and dicing contiguous arrays in Rust

Solving sparse linear systems involves indexing into big contiguous numeric arrays, and Rust sometimes gets in the way of doing this because it will not allow a read-only reference to such an array to coexist with a read-write (mutable) reference. I have found that the best fix is to make sure all mutable references are very short-lived. Make a habit of putting them in their own scope so that mutability is contained in a small portion of code. Below is a real example from the numeric factorization: it sets up a call into LAPACK and uses several mutable references, but the entire call is contained within curly braces.

//.... (code snipped)
//Now actually apply Q^T
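{
  //Mutable borrows (work, info, and tmp via as_mut_slice) live only inside this block
  let side=b'L';  //apply Q^T from the left
  let trans=b'T'; //apply the transpose, i.e. Q^T rather than Q
  let mut work = vec![F::zero();lwork as usize];
  let mut info = 0;
  F::xmqr(side,trans,nrows as i32,nrhs as i32,tau.len() as i32,qt.as_slice(),nrows as i32,tau.as_slice(),
    tmp.as_mut_slice(),nrows as i32,work.as_mut_slice(),lwork,&mut info);
  assert_eq!(info,0); //LAPACK signals success with info == 0
};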
//.... (code snipped)
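
The same scoping habit in a minimal, dependency-free form is sketched below (the function name is hypothetical). Since non-lexical lifetimes arrived with the 2018 edition, a borrow already ends at its last use, so the explicit block here mainly documents where mutation is allowed; it still helps readers and keeps later edits from accidentally extending the borrow.

```rust
/// Hypothetical example: scale the first `k` entries in place, then read the whole array.
fn scale_prefix_then_sum(data: &mut [f64], k: usize, factor: f64) -> f64 {
    {
        // Short-lived mutable view of the prefix; the borrow ends at the closing brace.
        let prefix = &mut data[..k];
        for x in prefix.iter_mut() {
            *x *= factor;
        }
    }
    // Only shared (read-only) access from here on.
    data.iter().sum()
}

fn main() {
    let mut data = vec![1.0, 2.0, 3.0, 4.0];
    let total = scale_prefix_then_sum(&mut data, 2, 10.0);
    println!("{:?} sums to {}", data, total);
}
```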

Sometimes scoping is not enough: for example, you may need to access two portions of an array simultaneously through mutable references without making a copy. That produces two references into the same array, both capable of reading and writing, which Rust does not allow. Fortunately, the standard library provides methods that enable this pattern safely, for example split_at_mut and the other “split” methods on slices. The ndarray crate also contains many methods for this common need in numerical linear algebra.
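
A sketch of the split approach follows, assuming a column-major array like the ones passed to LAPACK above. The helper axpy_columns is hypothetical: it adds a multiple of one column to a later column of the same array. split_at_mut hands back two non-overlapping slices, so the compiler can verify that reading one column while writing the other is safe, with no copy.

```rust
/// Hypothetical helper: in a column-major array with `nrows` rows,
/// compute col[dst] += alpha * col[src] in place, assuming src < dst.
fn axpy_columns(a: &mut [f64], nrows: usize, src: usize, dst: usize, alpha: f64) {
    assert!(src < dst, "sketch assumes the source column comes first");
    // `head` holds every column before `dst`; `tail` starts at column `dst`.
    let (head, tail) = a.split_at_mut(dst * nrows);
    let x = &head[src * nrows..(src + 1) * nrows];
    let y = &mut tail[..nrows];
    for (yi, xi) in y.iter_mut().zip(x) {
        *yi += alpha * xi;
    }
}

fn main() {
    // A 2x3 matrix stored column by column: [1 3 5; 2 4 6].
    let mut a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    axpy_columns(&mut a, 2, 0, 2, 10.0);
    println!("{:?}", a);
}
```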

Dealing with complex versus real types and “specializing”

Currently (as of writing this post), Rust really isn’t a major tool in technical computing, and the current state of real versus complex datatypes shows this. f32 and f64 implement many important methods common to real numbers (like the absolute value abs), but Complex<f32> and Complex<f64> do not (the complex modulus is exposed as norm instead). This makes it hard to write fully type-generic functions over these types. Fortunately this does appear to be improving over time, and many functions can be reasonably implemented if you bound the type parameter by the Num numeric trait (from the num-traits crate) as follows:

//Type parameter specifies that `F` implements `Num` trait
impl <F : Num+Copy> CSCSparse<F>{ 
  //...Implementation code snipped out
}
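
As an illustration of the generic-impl pattern, here is a hypothetical minimal CSC matrix (not the post's CSCSparse) with a matrix-vector product written once for many scalar types. To keep the sketch free of external crates, the bounds use std::ops traits plus Default as a stand-in for zero; with num-traits you would write F: Num + Copy and use F::zero() instead.

```rust
use std::ops::{Add, Mul};

/// Hypothetical minimal compressed sparse column (CSC) matrix.
struct Csc<F> {
    nrows: usize,
    colptr: Vec<usize>, // column j occupies entries colptr[j]..colptr[j + 1]
    rowind: Vec<usize>, // row index of each stored entry
    values: Vec<F>,     // value of each stored entry
}

impl<F: Copy + Default + Add<Output = F> + Mul<Output = F>> Csc<F> {
    /// y = A x, written once for every scalar type meeting the bounds.
    fn matvec(&self, x: &[F]) -> Vec<F> {
        // Default::default() is zero for the built-in numeric types.
        let mut y = vec![F::default(); self.nrows];
        for (j, &xj) in x.iter().enumerate() {
            for k in self.colptr[j]..self.colptr[j + 1] {
                let i = self.rowind[k];
                y[i] = y[i] + self.values[k] * xj;
            }
        }
        y
    }
}

fn main() {
    // A = [1 2; 0 3] stored column by column.
    let a = Csc { nrows: 2, colptr: vec![0, 1, 3], rowind: vec![0, 0, 1], values: vec![1.0, 2.0, 3.0] };
    println!("{:?}", a.matvec(&[1.0, 1.0]));
}
```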

This covers most numeric use cases but misses a few features of real numbers that complex numbers also have (such as the absolute value I already mentioned). Personally, I think Complex<f32> and Complex<f64> should both implement the Float trait, or there should be another trait, containing virtually everything Float has, that both real and complex numbers implement.
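
Until such a trait exists, one workaround is a small local trait that papers over the naming difference. Everything here is a hypothetical sketch: the Modulus trait is made up, and C64 is a minimal stand-in for num_complex::Complex<f64> so that the example has no dependencies (with num-complex, the impl body would call norm instead).

```rust
/// Hypothetical trait giving real and complex scalars a common absolute value.
trait Modulus {
    fn modulus(self) -> f64;
}

impl Modulus for f64 {
    fn modulus(self) -> f64 {
        self.abs()
    }
}

/// Stand-in for num_complex::Complex<f64>; with num-complex the body would be `self.norm()`.
#[derive(Clone, Copy)]
struct C64 {
    re: f64,
    im: f64,
}

impl Modulus for C64 {
    fn modulus(self) -> f64 {
        // sqrt(re^2 + im^2), computed without intermediate overflow.
        self.re.hypot(self.im)
    }
}

/// Largest entry magnitude, generic over real and complex scalars.
fn max_modulus<T: Modulus + Copy>(xs: &[T]) -> f64 {
    xs.iter().map(|&x| x.modulus()).fold(0.0, f64::max)
}

fn main() {
    println!("{}", max_modulus(&[-3.0, 2.0]));
    println!("{}", max_modulus(&[C64 { re: 3.0, im: 4.0 }, C64 { re: 1.0, im: 0.0 }]));
}
```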

Wrapping real and complex versions of BLAS and LAPACK

I wrote a bare-minimum wrapper library to call into BLAS and LAPACK in a type-generic way (at least for the four key numeric types):

pub trait Lapack{
    type F;
    fn no_nans(xs : &[Self::F])->bool;
    //TODO: `a` actually gets written here, but restored after xmqr completes. This reference needs to be a mutable reference.
    fn xmqr(side : u8,trans : u8,m : i32,n : i32,k : i32,a : &[Self::F],lda : i32,tau : &[Self::F],c : &mut [Self::F],ldc : i32,work : &mut [Self::F],lwork : i32,info : &mut i32);
    fn xgeqrf(m: i32,n: i32,a: &mut [Self::F],lda: i32,tau: &mut [Self::F],work: &mut [Self::F],lwork: i32,info: &mut i32);
    fn xtrtrs(uplo: u8,trans: u8,diag: u8,n: i32,nrhs: i32,a: &[Self::F],lda: i32,b: &mut [Self::F],ldb: i32,info: &mut i32);
    fn xgemm(transa : u8,transb : u8,m : i32,n : i32,k : i32,alpha : Self::F,a : &[Self::F],lda : i32,b : &[Self::F], ldb : i32,beta : Self::F,c : &mut [Self::F], ldc : i32);
}

I was able to implement this trait for all four key datatypes (f32, f64, Complex<f32>, Complex<f64>) by repeating the pattern in the impl block below for each datatype, calling the corresponding BLAS/LAPACK routine each time:

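use lapack::{sgeqrf, sormqr, strtrs}; //assumed: bindings from the `lapack` crate
use blas::sgemm;                      //assumed: bindings from the `blas` crate

impl Lapack for f32{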
    type F=f32;
    fn no_nans(xs : &[Self::F])->bool{
        xs.iter().all(|x| !x.is_nan())
    }
    fn xmqr(side : u8,trans : u8,m : i32,n : i32,k : i32,a : &[Self::F],lda : i32,tau : &[Self::F],c : &mut [Self::F],ldc : i32,work : &mut [Self::F],lwork : i32,info : &mut i32){
        unsafe{sormqr(side,trans,m,n,k,a,lda,tau,c,ldc,work,lwork,info);}
    }
    fn xgeqrf(m: i32,n: i32,a: &mut [Self::F],lda: i32,tau: &mut [Self::F],work: &mut [Self::F],lwork: i32,info: &mut i32){
        unsafe{sgeqrf(m,n,a,lda,tau,work,lwork,info);}
    }
    fn xtrtrs(uplo: u8,trans: u8,diag: u8,n: i32,nrhs: i32,a: &[Self::F],lda: i32,b: &mut [Self::F],ldb: i32,info: &mut i32){
        unsafe{strtrs(uplo,trans,diag,n,nrhs,a,lda,b,ldb,info);}
    }
    fn xgemm(transa : u8,transb : u8,m : i32,n : i32,k : i32,alpha : Self::F,a : &[Self::F],lda : i32,b : &[Self::F], ldb : i32,beta : Self::F,c : &mut [Self::F], ldc : i32){
        unsafe{sgemm(transa,transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc)}
    }
}
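
Repeating the same impl block four times invites copy-paste errors, and a declarative macro is one common way to generate it. The sketch below is hypothetical and cut down to the no_nans method so it compiles without BLAS/LAPACK bindings; a fuller version could also take the routine names as macro arguments (for example $mqr:ident) and call them inside unsafe blocks. The generic check function shows how callers can tie the associated type back to the scalar with T: Lapack<F = T>.

```rust
/// Cut-down, hypothetical version of the `Lapack` trait keeping only `no_nans`.
trait Lapack {
    type F;
    fn no_nans(xs: &[Self::F]) -> bool;
}

/// Generate the repeated impl block once per type instead of copy-pasting it.
macro_rules! impl_lapack_real {
    ($t:ty) => {
        impl Lapack for $t {
            type F = $t;
            fn no_nans(xs: &[Self::F]) -> bool {
                xs.iter().all(|x| !x.is_nan())
            }
        }
    };
}

impl_lapack_real!(f32);
impl_lapack_real!(f64);

/// Generic caller: `Lapack<F = T>` makes the associated type equal to `T` itself.
fn check<T: Lapack<F = T>>(xs: &[T]) -> bool {
    T::no_nans(xs)
}

fn main() {
    println!("{}", check(&[1.0f64, 2.0]));
    println!("{}", check(&[1.0f32, f32::NAN]));
}
```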

Conclusions and Observations

Although the borrow checker can initially seem prohibitively risk-averse for numerical code, I have found the various workarounds perfectly adequate, and they in fact result in better code that is less likely to suffer from data races, which makes adding parallelism later less painful. Keeping track of mutable versus non-mutable references in C++ can become overwhelming and leads to a lot of bugs, but Rust effectively removes that mental effort entirely, and I find I can rely heavily on its requirements here.

I also found that pattern matching significantly cleaned up my tree iteration code compared to similar code in C++. In just two very readable lines of Rust I could remove a node from the iteration stack, test whether it was a leaf, and, if it was not, get its children.

The main thing lacking was numeric support (in particular, writing generic code over real and complex types), which is still serviceable and does appear to be improving.

Overall I found Rust to be a very elegant programming language for this nontrivial computational task, and I look forward to seeing it improve further. I would really like to see Rust able to hook into the upcoming HPC ecosystems being developed by hardware vendors, such as oneAPI Level Zero and ROCm. I’m also excited about the GCC frontend for Rust, because I believe that when multiple teams implement the same language they uncover ambiguities and keep each other honest.