PyTorch for Scientific Computing – Quantum Mechanics Example Part 4) Full Code Optimizations — 16000 times faster on a Titan V GPU



In this post I’ll go through the code optimizations that led me to make the following comment in Part 2) of this series,

“Yes!, from 2768 seconds down to 0.1683 seconds! That’s 16452 Times Faster!! Yes, really, that much. Some of the simple code optimizations will be surprising by how much difference they make.”

I’ll make good use of the batched tensor Cholesky decomposition presented in Part 3). The basic idea of transforming matrix operations into tensor operations will be used in other parts of the code as well. The big performance gains will come from eliminating loops as much as possible, and that loop elimination comes from batched tensor operations and “broadcasting” operations across tensor data structures. This is all pretty easy to do in PyTorch and it makes a huge difference for performance.

This is the 4th post in this series on using PyTorch for scientific computing, by example, using the Quantum Mechanics problem I presented in the first post. Here are the background posts.

Part 1) described the problem of minimizing the energy of a small atomic system using a basis set of correlated Gaussian functions. It contains most of the math describing the problem. “Doing Quantum Mechanics with a Machine Learning Framework: PyTorch and Correlated Gaussian Wavefunctions: Part 1) Introduction”

Part 2) presented a full working implementation of the problem. That code is a straightforward implementation of the math and not optimal for performance. “PyTorch for Scientific Computing – Quantum Mechanics Example Part 2) Program Before Code Optimizations”

Part 3) showcased code for what I consider the most incredible scaling performance I’ve ever seen. It used batched tensor operations on an NVIDIA Titan V to achieve constant-time scaling! “PyTorch for Scientific Computing – Quantum Mechanics Example Part 3) Code Optimizations – Batched Matrix Operations, Cholesky Decomposition and Inverse”


Code Optimization: From loops and matrix operations to batched tensor operations

I’ll go through some of the code optimizations by doing “From-this — To-this” on blocks of code from the implementation in Part 2). That will be followed by the complete code in a form that could be cut and pasted into a Jupyter notebook if you want to play with it yourself.

Performance timing and profiling

I did some profiling and a lot of code block timing when I was optimizing the code. I’m not going to include all of that. However, I would like to mention a couple of things.

cProfile

cProfile is a useful profiling tool for Python. It’s easy to use and gives detailed information on operation counts and function calls (built-ins and your own). You can use it to get an overview of what the code is doing and an idea of where the execution time is going. Just call it on the function used for running the program, like,

import cProfile
...
...
cProfile.run('xprof = opt_energyrc()')

I recommend trying cProfile if you haven’t used it.
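If the flat cProfile output gets long, the stats can also be written to a file and sorted with the standard pstats module. Here's a minimal sketch (the run() function is just a stand-in workload and the file name 'profile.out' is arbitrary):

import cProfile
import pstats

def run():
    # stand-in workload; profile your real entry point (e.g. opt_energyrc) instead
    return sum(i * i for i in range(100000))

# write the profile data to a file instead of printing it
cProfile.run('xprof = run()', 'profile.out')

# load the stats and show the top 15 calls sorted by cumulative time
stats = pstats.Stats('profile.out')
stats.sort_stats('cumulative').print_stats(15)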

Printing timing blocks

Another useful way to get timing information is to wrap code blocks with “time()” calls. This is old-school but very effective for seeing how much time chunks of code take during a test run. I made a copy of the code with about 15 or so blocks that looked like this,

# Timing
        start_time = time.time()        
        # get the Cholesky decomp of all Akl matrices
        cholAKL = cholesky(AKL)
# End-Timing    
        print(" {:20} {:.6f} seconds ".format('cholAKL took', time.time() - start_time))

This was really useful. There were some code blocks that looked innocent but turned out to be using a lot of the overall time! Here’s an example:

A surprise code block that used a lot of run time

This is a really simple chunk of code. It just copies the upper triangle of H and S to the lower part. H and S are symmetric and I was only computing half of each matrix. The whole matrices are needed for the energy term. It’s only executed one time before the energy is computed, but look at the time it took! (on the GPU)

# Timing
    start_time = time.time()    
    # complete lower triangle of H and S
    for i in range(0,nb):
        for j in range(i+1,nb):
            H[j,i] = H[i,j]
            S[j,i] = S[i,j]
# End-Timing    
    print(" {:20} {:.6f} seconds ".format('Complete H', time.time() - start_time))
Complete H and S, took 5.018198 seconds

Now the same thing but using the PyTorch triu() function to get rid of the loops,

# Timing
    start_time = time.time()    
    # complete lower triangle of H and S
    H = th.triu(H,1)+th.t(th.triu(H))
    S = th.triu(S,1)+th.t(th.triu(S))
# End-Timing    
    print(" {:20} {:.6f} seconds ".format('Complete H', time.time() - start_time))
Complete H and S, took 0.000117 seconds

That's 42888 times faster!

That was the biggest surprise of all the code optimizations I did. This is on the GPU … the conclusion: avoid unnecessary loops on the GPU!
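As a quick sanity check, here's a small self-contained sketch (not part of the program above) showing that the loop-free triu() construction matches the double loop on a little random matrix. It uses the same "th" alias for torch as the rest of this post:

import torch as th

device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
dtype = th.float64

nb = 6
# a matrix with only the upper triangle (including the diagonal) filled in
H = th.triu(th.rand((nb, nb), device=device, dtype=dtype))

# loop version: copy the upper triangle to the lower part
H_loop = H.clone()
for i in range(0, nb):
    for j in range(i + 1, nb):
        H_loop[j, i] = H_loop[i, j]

# loop-free version used in the optimized code
H_fast = th.triu(H, 1) + th.t(th.triu(H))

print(th.allclose(H_loop, H_fast))   # True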

Restructuring the code (inner loop elimination)

The code in Part 2) was mostly contained in two functions: one that computed “matrix elements over basis functions” and another that looped over that “matrix element” code to generate the larger matrices H and S used to compute the energy of the system. The “matrix_elements()” function will be rolled into the “energy()” function, and some of the loops in energy() will be eliminated.

I’ll strip out some of the code and comments. The matrix_elements() function will be removed and rolled into energy() like this,

From-this: (loop over matrix terms)

def energy(x,n,nb,Mass,Charge,Sym,symc):
    ...
    # outer loop is over symmetry terms
    for k in range(0,nsym):

        for j in range(0,nb):
            for i in range(j,nb):

                vechLi = X[i,:]
                vechLj = X[j,:]

                matels = matrix_elements(n,vechLi,vechLj,Sym[:,:,k],Mass,Charge);

                S[i,j] += symc[k]*matels['skl'];
                T[i,j] += symc[k]*matels['tkl'];
                V[i,j] += symc[k]*matels['vkl'];

    H = T + V

    # The energy from the Rayleigh quotient

    cHc = c@H@c;
    cSc = c@S@c;
    eng = cHc/cSc;

    return eng

To-this: (compute matrices all-at-once, removing 2 inner loops)

def b_energyrc(x,n,nb,Mass,Qmat,Sym,symc):
    ...
    # outer loop is over symmetry terms, the matrices are summed over these sym terms
    for k in range(0,nsym):

        # MATRIX ELEMENTS

        # Overlap: (normalized)
        # Skl = 2^(3n/2) * (|Lk| |Ll| / |Akl|)^(3/2)
        SKL = 2**(n*1.5) * th.sqrt( th.pow(th.ger(detL, detL)/detAKL ,3) );

        # Kinetic energy
        #TKL = SKL*(6*th.trace(Mass@Ak@invAkl@Al)) = skl*(6*th.sum(Mass*(Ak@invAkl@Al)))

        Tmat = th.zeros_like(invAKL)
        Tmat = th.matmul(th.transpose(AK.repeat((nb,1,1,1)), 0,1), th.matmul(invAKL,AL))
        TKL = 6*SKL*th.sum(Mass*Tmat, dim=(-2,-1))

        # potential energy
        TWOoSqrtPI = 1.1283791670955126 # 2/sqrt(pi)

        VKL = TWOoSqrtPI*SKL*th.sum(RIJ*Qmat, dim=(-2,-1))

        # accumulate matrices over sym terms
        S = S + symc[k]*SKL
        T = T + symc[k]*TKL
        V = V + symc[k]*VKL

    # Hamiltonian
    H = T + V

    # compute the Rayleigh quotient (it is the smallest energy eigenvalue when minimized over c)
    cHc = c@H@c;
    cSc = c@S@c;
    eng = cHc/cSc;

    return eng           

The big change here is that the matrix_elements() code computed individual matrix elements for H and S (and T, V). The new code uses tensors to compute these matrices all-at-once. No more loops for that. I am keeping the loop over symmetry terms (that could be a good place to code up multi-GPU or multi-node parallelism; I used MPI for parallelism there in my old Fortran versions of the code).

Optimized code snippets

We won't go over every change made in the code, but we will look at a few code blocks to illustrate the important ideas.

From-this: (individual arrays of parameters and formula terms)

for j in range(0,nb):
    for i in range(j,nb):

        vechLi = X[i,:]
        vechLj = X[j,:]

 --> to matrix_element()
        # build Lk and Ll

        Lk = vech2L(vechLk,n);
        Ll = vech2L(vechLl,n);

        # apply symmetry projection on Ll

        # th.t() is shorthand for th.transpose(X, 0,1)
        PLl = th.t(Sym) @ Ll;

        # build Ak, Al, Akl, invAkl

        Ak = [email protected](Lk);
        Al = [email protected](PLl);
        Akl = Ak+Al
        invAkl = th.inverse(Akl);

To-this: (the complete set of parameter structures and formula terms)

# generate tensor of lower triangular matrices from X
# these are the non-linear parameters of the basis set
L = th.zeros((nb,n,n), device=device, dtype=dtype)
L = bvech2L(X,nb,n)

Note the use of “bvech2L()”, which is a batched version of vech2L(). It generates all nb (number of basis functions) lower-triangle matrices Lk and puts them in L, an (nb x n x n) tensor, i.e. a stack of nb n x n matrices.
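To make the layout concrete, here's a tiny illustration (mine, not from the program) of the packing bvech2L() does for nb=2 and n=2; the full bvech2L() in the complete code below also adds the identity to the diagonal for stability, which is omitted here:

import torch as th

device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
dtype = th.float64

# nb=2 basis functions, n=2 particles, so each row of X holds the
# n(n+1)/2 = 3 lower-triangle elements of one Lk
nb, n = 2, 2
X = th.tensor([[1., 2., 3.],
               [4., 5., 6.]], device=device, dtype=dtype)

L = th.zeros((nb, n, n), device=device, dtype=dtype)
count = 0
for j in range(n):
    for i in range(j, n):
        L[..., i, j] = X[..., count]   # fills position (i,j) in every matrix of the stack at once
        count += 1

print(L[0])   # [[1., 0.], [2., 3.]]
print(L[1])   # [[4., 0.], [5., 6.]]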

# get the determinants of L; |L| is the product of the diagonal elements
detL = th.abs(th.prod(th.diagonal(L, offset=0, dim1=-1, dim2=-2),1))

Get the determinants of all Lk in one tensor operation. These are NumPy-like broadcasting operations, with the reduction axes specified by the “dim” arguments.
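Here's a quick self-contained check (again mine, not from the program) that the diagonal product matches a per-matrix determinant for a small stack of lower-triangular matrices:

import torch as th

device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
dtype = th.float64

nb, n = 4, 3
# a small stack of lower-triangular matrices with a boosted diagonal
L = th.tril(th.rand((nb, n, n), device=device, dtype=dtype)) + th.eye(n, device=device, dtype=dtype)

# batched: |Lk| is the product of the diagonal elements, one call for the whole stack
detL = th.abs(th.prod(th.diagonal(L, offset=0, dim1=-1, dim2=-2), 1))

# reference: one determinant at a time
detL_ref = th.stack([th.det(L[k]).abs() for k in range(nb)])

print(th.allclose(detL, detL_ref))   # True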

# create the tensor of matrix products of the L matrices AKL = L x Ltranspose
AK = th.matmul(L,th.transpose(L, 1, 2))

Take the matrix product of all the Lk matrices at once to make the Ak matrices. The tensor AK has the same shape as L.

--> to k (symmetry) loop
    P = Sym[k,:,:]
    # symmetry projection is applied only to the "ket"; this constructs AL
    AL = th.matmul(th.t(P), th.matmul(AK,P))

Multiply all of the matrices in AK by the symmetry term P at the same time to make AL, i.e. AL = P'AKP for each AK. This is NumPy-like broadcasting again.
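A tiny self-contained sketch (mine) of that broadcast: a single permutation matrix stands in for one of the Sym terms and is applied to a whole stack in one matmul, matching the explicit loop:

import torch as th

device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
dtype = th.float64

nb, n = 4, 3
AK = th.rand((nb, n, n), device=device, dtype=dtype)
# a permutation matrix standing in for one of the symmetry terms
P = th.tensor([[0, 1, 0], [1, 0, 0], [0, 0, 1]], device=device, dtype=dtype)

# broadcast version: the single (n x n) matrix P is applied to every matrix in the stack
AL = th.matmul(th.t(P), th.matmul(AK, P))

# loop version for comparison
AL_ref = th.stack([th.t(P) @ AK[k] @ P for k in range(nb)])

print(th.allclose(AL, AL_ref))   # True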

    # Akl = Ak + Al
    AKL = th.zeros((nb,nb,n,n), device=device, dtype=dtype)
    #for i in range(nb):
    #    for j in range(nb):
    #        #AKL[i,j] =  th.add(AK[i], AL[j])
    #        AKL[i,j] =  AK[i] + AL[j]
    AKL = AL.repeat((nb,1,1,1)) + th.transpose(AK.repeat((nb,1,1,1)), 0,1)

This bit of code is a little tricky, so I left my first tensor-code attempt commented out since it is easier to see what is happening there. AKL is an “outer sum” of all combinations of the elements in AK and AL; you can see that in the loops. The code I ended up with eliminates those loops by essentially repeating all of the AL and AK terms in the shape of nb “vectors of matrices” and then summing them to get the “outer sum”. The result is the tensor AKL of dimension (nb x nb x n x n). This “tiles” all of the n x n matrices as sums of each other. Think of AKL as an nb x nb matrix of n x n matrices! This tensor structure works seamlessly with my batched algorithms for Cholesky decomposition and inverse below. That “repeat” construct is a powerful idea for eliminating loops.
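Here's a short self-contained sketch (mine) of the repeat construct reproducing the commented-out double loop, plus an equivalent unsqueeze/broadcast form that gets the same outer sum without materializing the intermediate repeats:

import torch as th

device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
dtype = th.float64

nb, n = 4, 3
AK = th.rand((nb, n, n), device=device, dtype=dtype)
AL = th.rand((nb, n, n), device=device, dtype=dtype)

# loop version: AKL[i,j] = AK[i] + AL[j]
AKL_ref = th.zeros((nb, nb, n, n), device=device, dtype=dtype)
for i in range(nb):
    for j in range(nb):
        AKL_ref[i, j] = AK[i] + AL[j]

# repeat version: tile AL along a new leading axis, tile AK and swap its first two axes,
# then one add produces the whole (nb x nb x n x n) outer sum
AKL = AL.repeat((nb, 1, 1, 1)) + th.transpose(AK.repeat((nb, 1, 1, 1)), 0, 1)
print(th.allclose(AKL, AKL_ref))   # True

# equivalent broadcast form that skips the explicit repeats
AKL_b = AK.unsqueeze(1) + AL.unsqueeze(0)
print(th.allclose(AKL_b, AKL_ref))   # True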

    # get the Cholesky decomp of all Akl matrices
    cholAKL = cholesky(AKL)

    # get determinants of AKL from the diagonals of the Cholesky factors; |Akl| = |cholAkl|**2
    detAKL = th.prod(th.diagonal(cholAKL, offset=0, dim1=-1, dim2=-2),-1)**2

    # compute inverses of the lower triangular matrices in cholAKL
    invLKL = inverseL(cholAKL)

    # inverses: Akl^-1 = invLkl' x invLkl
    invAKL = th.matmul(th.transpose(invLKL, dim0=-1, dim1=-2),invLKL)

The functions above compute all of their respective terms all-at-once! For real job runs there can be on the order of a million terms, i.e. (nb x nb) = (1000 x 1000), so this gives a huge speedup.

The big magic is that on the Titan V GPU, with batched tensor algorithms, those million terms are all computed in the same time it would take to compute 1!!!
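As a side note, recent PyTorch releases (roughly 1.8 and later) ship a natively batched torch.linalg.cholesky, so on a current install the same stack-at-once factorization can be done without the hand-rolled routine from Part 3. A small sketch (mine, with arbitrary sizes) that factors a whole (nb x nb) grid of small SPD matrices in one call and checks the reconstruction:

import time
import torch as th

device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
dtype = th.float64

# an (nb x nb) grid of small SPD matrices, shaped like AKL in this post
nb, n = 64, 3
L = th.tril(th.rand((nb, nb, n, n), device=device, dtype=dtype)) + th.eye(n, device=device, dtype=dtype)
A = th.matmul(L, th.transpose(L, -1, -2))   # SPD by construction

start_time = time.time()
cholA = th.linalg.cholesky(A)               # one call factors all nb*nb matrices
if device.type == "cuda":
    th.cuda.synchronize()                   # wait for the GPU before reading the clock
print(" {:20} {:.6f} seconds ".format('batched cholesky took', time.time() - start_time))

# reconstruction check: cholA @ cholA' should give back A
print(th.allclose(th.matmul(cholA, th.transpose(cholA, -1, -2)), A))   # True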

The rest of the code exploits the same ideas: batching tensor operations and doing “outer” products and sums over tensors to eliminate loops. That replaces a lot of looping code with highly efficient linear/matrix algebra structures that load into memory efficiently and keep the hardware fully utilized with very little branching overhead. This is how I was able to take a completely reasonable-looking implementation of the code and make it 16000 times faster on the Titan V GPU!

Program speedup (16452 Times Faster)

Remember the “teaser” I gave in Part 2)? Here it is again. In the next section I will present the code that produced that huge speedup.

start_time = time.time()
xrestart = opt_energy(steps=1, num_basis=512)
print(" took {} seconds ".format(time.time() - start_time))
step:     0  f: -0.720436686340  gradNorm: 4.060212842
 took 2768.9312148094177 seconds

2768 seconds for one step is too long!

(Note: these results are from random starting points, so they differ greatly.)

start_time = time.time()
xrestart = opt_energyrc(steps=1,num_basis=512)
print(" took {:.4f} seconds ".format(time.time() - start_time))
step:     0  f: 82.210051671516  gradNorm: 495.193096479
 took 0.1683 seconds

Yes!, from 2768 seconds down to 0.1683 seconds! That’s 16452 Times Faster!!


The Complete Optimized Code

I have been running this code in a Jupyter notebook. Below is the code from the notebook cells.

import numpy as np
import torch as th

import time
dtype = th.float64

gpuid = 0
device = th.device("cuda:"+ str(gpuid))
#device = th.device("cpu")

print("Execution device: ",device)
print("PyTorch version: ", th.__version__ )
print("CUDA available: ", th.cuda.is_available())
print("CUDA version: ", th.version.cuda)
print("CUDA device:", th.cuda.get_device_name(gpuid))
# Utility functions
def vech2L(v,n):
    count = 0
    L = th.zeros((n,n), device=device, dtype=dtype)
    for j in range(n):
        for i in range(j,n):
            L[i,j]=v[count]
            count = count + 1
    return L

# batched vech2L input is V nb x n(n+1)/2 (adding 1 to diag for stability)
def bvech2L(V,nb,n):
    count = 0
    L = th.zeros((nb,n,n), device=device, dtype=dtype)
    for j in range(n):
        for i in range(j,n):
            L[...,i,j]=V[...,count]
            count = count + 1
    return L + th.eye(n, device=device, dtype=dtype)

# Batched Cholesky decomp
def cholesky(A):
    L = th.zeros_like(A)

    for i in range(A.shape[-1]):
        for j in range(i+1):
            s = 0.0
            for k in range(j):
                s = s + L[...,i,k].clone() * L[...,j,k].clone()

            L[...,i,j] = th.sqrt(A[...,i,i] - s) if (i == j) else \
                      (1.0 / L[...,j,j].clone() * (A[...,i,j] - s))
    return L

# Batched inverse of lower triangular matrices
def inverseL(L):
    n = L.shape[-1]
    invL = th.zeros_like(L)
    for j in range(0,n):
        invL[...,j,j] = 1.0/L[...,j,j]
        for i in range(j+1,n):
            S = 0.0
            for k in range(i+1):
                S = S - L[...,i,k]*invL[...,k,j].clone()
            invL[...,i,j] = S/L[...,i,i]

    return invL
def b_energyrc(x,n,nb,Mass,Qmat,Sym,symc):

    nx = len(x);
    nn = int(n*(n+1)/2);
    nsym = len(symc);

    # extract linear coefs "eigen vector"
    c = x[-nb:];
    # reshape non-linear variables for easier indexing
    X = th.reshape(x[:nb*nn], (nb,nn))

    # generate tensor of lower triangular matrices from X
    # these are the non-linear parameters of the basis set
    L = th.zeros((nb,n,n), device=device, dtype=dtype)
    L = bvech2L(X,nb,n)

    # get the determinants of L; |L| is the product of the diagonal elements
    detL = th.abs(th.prod(th.diagonal(L, offset=0, dim1=-1, dim2=-2),1))

    # create the tensor of matrix products of the L matrices AKL = L x Ltranspose
    AK = th.matmul(L,th.transpose(L, 1, 2))


    # Initialize H T V and S matrices
    # H = T + V, we are solving (H-ES)c = 0 for E (energy)
    H = th.zeros((nb,nb), device=device, dtype=dtype);
    S = th.zeros((nb,nb), device=device, dtype=dtype);
    T = th.zeros((nb,nb), device=device, dtype=dtype);
    V = th.zeros((nb,nb), device=device, dtype=dtype);


    # outer loop is over symmetry terms, the matrices are summed over these sym terms
    for k in range(0,nsym):

        P = Sym[k,:,:]
        # symmetry projection is applied only to the "ket"; this constructs AL
        AL = th.matmul(th.t(P), th.matmul(AK,P))

        # Akl = Ak + Al
        AKL = th.zeros((nb,nb,n,n), device=device, dtype=dtype)
        #for i in range(nb):
        #    for j in range(nb):
        #        #AKL[i,j] =  th.add(AK[i], AL[j])
        #        AKL[i,j] =  AK[i] + AL[j]
        AKL = AL.repeat((nb,1,1,1)) + th.transpose(AK.repeat((nb,1,1,1)), 0,1)

        # get the Cholesky decomp of all Akl matrices
        cholAKL = cholesky(AKL)

        # get determinants of AKL from the diagonals of the Cholesky factors; |Akl| = |cholAkl|**2
        detAKL = th.prod(th.diagonal(cholAKL, offset=0, dim1=-1, dim2=-2),-1)**2

        # compute inverses of the lower triangular matrices in cholAKL
        invLKL = inverseL(cholAKL)

        # inverses: Akl^-1 = invLkl' x invLkl
        invAKL = th.matmul(th.transpose(invLKL, dim0=-1, dim1=-2),invLKL)

        # get terms needed for potential energy V
        RIJ = th.zeros_like(invAKL, device=device, dtype=dtype);
        # 1/rij i~=j
        for j in range(0,n-1):
            for i in range(j+1,n):
                tmp2 = invAKL[...,i,i] + invAKL[...,j,j] - 2*invAKL[...,i,j];
                RIJ[...,i,j] = th.rsqrt(tmp2)

        # 1/rij i=j
        for i in range(0,n):
            RIJ[...,i,i] = th.rsqrt(invAKL[...,i,i])    

        # MATRIX ELEMENTS

        # Overlap: (normalized)
        # Skl = 2^(3n/2) * (|Lk| |Ll| / |Akl|)^(3/2)
        SKL = 2**(n*1.5) * th.sqrt( th.pow(th.ger(detL, detL)/detAKL ,3) );

        # Kinetic energy
        #TKL = SKL*(6*th.trace(Mass@Ak@invAkl@Al)) = skl*(6*th.sum(Mass*(Ak@invAkl@Al)))

        Tmat = th.zeros_like(invAKL)
        #for i in range(nb):
        #    for j in range(nb):
        #        Tmat[i,j] = (AK[i]@invAKL[i,j]@AL[j])
        Tmat = th.matmul(th.transpose(AK.repeat((nb,1,1,1)), 0,1), th.matmul(invAKL,AL))
        TKL = 6*SKL*th.sum(Mass*Tmat, dim=(-2,-1))

        # potential energy
        TWOoSqrtPI = 1.1283791670955126 # 2/sqrt(pi)

        VKL = TWOoSqrtPI*SKL*th.sum(RIJ*Qmat, dim=(-2,-1))

        # accumulate matrices over sym terms
        S = S + symc[k]*SKL
        T = T + symc[k]*TKL
        V = V + symc[k]*VKL

    # Hamiltonian
    H = T + V

    # complete lower triangle of H and S
    H = th.triu(H,1)+th.t(th.triu(H))
    S = th.triu(S,1)+th.t(th.triu(S))

    # compute the Rayleigh quotient (it is the smallest energy eigenvalue when minimized over c)
    cHc = c@H@c;
    cSc = c@S@c;
    eng = cHc/cSc;

    return eng
def opt_energyrc(steps=1, num_basis=8, restart=True):

    #
    # Li BO setup
    #
    n=3;

    Mass = th.tensor([[0.5, 0.0, 0.0],
                     [0.0, 0.5, 0.0],
                     [0.0, 0.0, 0.5]], device=device, dtype=dtype);

    Charge = th.tensor([-3, 1, 1, -3, 1, -3], device=device, dtype=dtype);
    Charge = vech2L(Charge,n)

    # symmetry projection terms
    Sym = th.zeros((6,3,3), device=device, dtype=dtype)
    # (1)(2)(3)
    Sym[0,:,:] = th.tensor([[1,0,0],[0,1,0],[0,0,1]], device=device, dtype=dtype);
    # (12)
    Sym[1,:,:] = th.tensor([[0,1,0],[1,0,0],[0,0,1]], device=device, dtype=dtype);
    # (13)
    Sym[2,:,:] = th.tensor([[0,0,1],[0,1,0],[1,0,0]], device=device, dtype=dtype);
    # (23)
    Sym[3,:,:] = th.tensor([[1,0,0],[0,0,1],[0,1,0]], device=device, dtype=dtype);
    # (123)
    Sym[4,:,:] = th.tensor([[0,1,0],[0,0,1],[1,0,0]], device=device, dtype=dtype);
    # (132)
    Sym[5,:,:] = th.tensor([[0,0,1],[1,0,0],[0,1,0]], device=device, dtype=dtype);

    # coeff's
    symc = th.tensor([4.0,4.0,-2.0,-2.0,-2.0,-2.0], device=device, dtype=dtype);

    # Sample parameters should return energy of -7.3615
    xvechL=th.tensor([
         1.6210e+00, -2.1504e-01,  9.0755e-01,  9.7866e-01, -2.8418e-01,
        -3.5286e+00, -3.3045e+00, -4.5036e+00, -3.2116e-01, -7.1901e-02,
         1.5167e+00, -8.4489e-01, -2.1377e-01, -3.6127e-03, -5.3774e-03,
        -2.1263e+00, -2.5191e-01,  2.1235e+00, -2.1396e-01, -1.4084e-03,
        -1.0092e-02,  4.5349e+00,  9.4837e-03,  1.1225e+00, -2.1315e-01,
         5.8451e-02, -4.9410e-03,  5.0853e+00,  7.3332e-01,  5.0672e+00,
        -2.1589e-01, -6.8986e-03, -1.4310e-02,  1.5979e+00,  3.3946e-02,
        -8.7965e-01, -1.1121e+00, -2.1903e-03, -4.6925e-02,  2.1457e-01,
         3.3045e-03,  4.5120e+00, -2.1423e-01, -1.6493e-02, -2.3429e-03,
        -8.6715e-01, -6.7070e-02,  1.5998e+00
     ], device=device, dtype=dtype, requires_grad=False)

    evec = th.tensor([
      -6.0460e-02,  7.7708e-05, 1.6152e+00,  9.5443e-01,  
      1.1771e-01,  3.2196e+00,  9.6344e-01, 3.1398e+00
    ], device=device, dtype=dtype, requires_grad=False)


    # uncomment following lines to test above
    #nb=8
    #x1 = th.tensor(th.cat((xvechL,evec)), device=device, dtype=dtype, requires_grad=True)
    #energy = b_energyrc(x1,n,nb,Mass,Charge,Sym,symc)
    #print(energy) # should be -7.3615
    #return x1

    if restart:
        nb=num_basis
        x1 = xrestart
    else:
        # random start point
        nb=num_basis
        #th.manual_seed(42)
        x1 = th.empty(int(nb*n*(n+1)/2 + nb), device=device, dtype=dtype, requires_grad=True)
        th.nn.init.uniform_(x1, a=-0.8, b=0.8)

    # start from a restart value
    #x1 = xrestart
    #print(energy)
    #return x1

    # Do the Optimization (take your pick, but Rprop is great for this code)
    #optimizer = th.optim.LBFGS([x1])
    #optimizer = th.optim.Adadelta([x1], lr=1.0)
    #optimizer = th.optim.Adam([x1], lr=0.005)
    optimizer = th.optim.Rprop([x1], lr=0.0001, etas=(0.5, 1.2), step_sizes=(1e-07, 50))

    # this is not needed with Rprop but useful for adjusting learning rate for others
    #scheduler =   th.optim.lr_scheduler.ReduceLROnPlateau(optimizer,threshold=0.00001,cooldown=3, verbose=True,patience=2, factor=0.5)

    for i in range(steps):
        optimizer.zero_grad()
        loss = b_energyrc(x1,n,nb,Mass,Charge,Sym,symc)
        loss.backward()
        # a closure is only needed for LBFGS opt
        #def closure():
        #    return b_energyrc(x1,n,nb,Mass,Charge,Sym,symc)
        #optimizer.step(closure)
        optimizer.step()
        #scheduler.step(loss)

        if (i<20 or not i%100):print('step: {:5}  f: {:4.12f}  gradNorm: {:.9f}'.format(i, loss, th.norm(x1.grad)))
    # print last value
    print('step: {:5}  f: {:4.12f}  gradNorm: {:.9f}'.format(i, loss, th.norm(x1.grad)))
    return x1
start_time = time.time()
for i in range(10):
    print("Optimization restart: {}".format(i))
    xrestart = opt_energyrc(steps=5000,num_basis=512, restart=False)
print(" took {:.4f} seconds ".format(time.time() - start_time))

Some notes on the code and the science...

This is "research" quality code. I wrote the first code implementation as an exercise to play with PyTorch. I did all of this in my "spare time" over a few weeks. It is based on science I did many years ago as a graduate student at Washington State University. I later spent several enjoyable years doing research with more advanced forms of the basis set at the University of Arizona. I had a reference implementation in MATLAB to validate against. I have old fortran code too, but haven't tried to recompile it. It was really fun to do this, then and now!

I learned a lot about PyTorch while writing this code and even more about coding for the GPU. I'm still just stunned at how well some of these code optimizations worked on the Titan V.

I have been running my jobs with FP64 on the Titan V (which is fantastic for that!). I made a slight modification to the code to stabilize it enough to use FP32 single precision but that is really not enough precision for this kind of work. The energy lowering with FP32 stalls out sooner than I would like. The code is capable of achieving nano-hartree energy accuracy so double precision is really a must.

The memory management in the code could be improved. I'm sure I'm using up more memory than I need to. It looks like there are places where I could release some memory manually or wrap code in function definitions so the garbage collector could remove it. I just haven't spent the time to do all that.

The job input setup in the code above is for the Lithium atom in the Born-Oppenheimer approximation, i.e. infinite nuclear mass. (That's why the "Mass" input matrix is diagonal.) I did this so I can compare against published results. The code can actually be used to compute optimized wavefunctions using actual isotopic masses. That was part of the motivation for the original research, since there is very little code available to do that. It's possible to use this code for exotic systems like matter--anti-matter (e+ e+ e- e-), i.e. 2 positrons and 2 electrons. Yes, that is bound, even though it has a short lifetime. Because of the explicit symmetry projection, the code is only usable for very small systems of particles. I really don't expect anyone to use this code. [I do know a couple of people that might use it though.] If I feel motivated and have enough time I may write code for a better form of the basis functions that I did the math for years ago but never implemented. We'll see!

Happy computing! --dbk