PEANUT model architecture

This section gives an overview of the PEANUT model architecture and its components. Let \(N_{\mathit{atoms}}\) denote the number of atoms in the system, \(N_{\mathit{elem}}\) the number of distinct chemical elements in the dataset, \(D\) the dimension of the atom embedding vectors, \(B\) the number of radial basis functions, \(L\) the maximum degree of the spherical harmonics, and \(T\) the number of message-passing iterations. For each atom \(i\), let \(\mathcal{N}(i)\) denote the set of neighboring atoms within the cutoff radius \(r_c\).

Input for the model

Cartesian coordinates

\[r = (r_i)_{i=1}^{N_{\mathit{atoms}}} \in \mathbb{R}^{3 \times N_{\mathit{atoms}}}\]

Atom types

\[Z = (Z_i)_{i=1}^{N_{\mathit{atoms}}} \in \{1, \dots, N_{\mathit{elem}}\}^{N_{\mathit{atoms}}}\]

Construct an embedding vector for each atom:

\[h_i^{(0)} = e(Z_i) \in \mathbb{R}^D\]

where \(e\) is a function mapping the element type \(Z_i\) of atom \(i\) to an initial embedding vector \(h_i^{(0)}\) of length \(D\). For each atom pair \((i, j)\), the displacement vector \(d_{ij}\) and its Euclidean norm \(r_{ij}\) are computed:

\[d_{ij} = r_j - r_i, \quad r_{ij} = \lVert d_{ij} \rVert_2\]
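The input steps above can be sketched in NumPy by treating the embedding \(e\) as a lookup table and computing the pair geometry by broadcasting. This is a minimal sketch only: the table `embedding_table`, the 0-based element indices, the row-per-atom coordinate layout, and the cutoff value are illustrative assumptions and are not taken from the PEANUT implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

n_elem, D = 3, 4                                 # assumed sizes, for illustration only
embedding_table = rng.normal(size=(n_elem, D))   # hypothetical stand-in for e(Z)

# Three atoms; rows are atoms here (the transpose of the 3 x N_atoms layout).
r = np.array([[0.0, 0.0, 0.0],
              [1.0, 0.0, 0.0],
              [0.0, 3.0, 0.0]])
Z = np.array([0, 1, 0])                          # element indices, 0-based in code

h0 = embedding_table[Z]                          # h_i^(0), shape (N_atoms, D)

# d[i, j] = r_j - r_i and dist[i, j] = ||d_ij||_2
d = r[None, :, :] - r[:, None, :]                # shape (N_atoms, N_atoms, 3)
dist = np.linalg.norm(d, axis=-1)                # shape (N_atoms, N_atoms)

# Neighbor mask for N(i): within the cutoff, excluding the atom itself.
r_c = 2.0                                        # assumed cutoff radius
neighbors = (dist < r_c) & ~np.eye(len(r), dtype=bool)
```

With these example coordinates only atoms 0 and 1 are neighbors of each other; atom 2 lies outside the assumed cutoff of both.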

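Radial features are encoded using \(B\) Behler-Parrinello radial symmetry functions with centers \(R_b \in \mathbb{R}_+\) and width parameters \(\eta_b \in \mathbb{R}_+\), indexed by \(b = 1, \dots, B\). Additionally, a cutoff function \(f_c\) is applied to ensure smoothness at the cutoff radius \(r_c\). Here \(R_b\) without an argument denotes the center, while \(R_b(\cdot)\) denotes the corresponding basis function:

\[R_b(r_{ij}) = e^{-\eta_b (r_{ij} - R_b)^2} f_c(r_{ij}), \quad b = 1, \dots, B\]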

Each distance \(r_{ij}\) is mapped to a \(B\)-dimensional radial feature vector:

\[\begin{split}R_{ij} = \begin{pmatrix} R_1(r_{ij}) \\ R_2(r_{ij}) \\ \vdots \\ R_B(r_{ij}) \end{pmatrix} \in \mathbb{R}^B\end{split}\]
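The radial encoding can be illustrated with a short NumPy sketch. The form of \(f_c\) is not specified here, so the cosine cutoff below is an assumption (it is a common choice for Behler-Parrinello-type functions). The evenly spaced centers and the shared width value are likewise placeholders.

```python
import numpy as np

def cosine_cutoff(r, r_c):
    """Assumed cutoff f_c: decays smoothly to zero at r_c, zero beyond."""
    return np.where(r < r_c, 0.5 * (np.cos(np.pi * r / r_c) + 1.0), 0.0)

def radial_features(r_ij, centers, etas, r_c):
    """Map distances of shape (...,) to radial features of shape (..., B)."""
    r_ij = np.asarray(r_ij, dtype=float)[..., None]   # add the basis axis
    gauss = np.exp(-etas * (r_ij - centers) ** 2)     # Gaussian per center
    return gauss * cosine_cutoff(r_ij, r_c)

B, r_c = 4, 5.0
centers = np.linspace(0.0, r_c, B)    # assumed evenly spaced centers
etas = np.full(B, 2.0)                # assumed shared width parameter

R_ij = radial_features(1.5, centers, etas, r_c)       # shape (B,)
```

The function broadcasts over leading axes, so a full distance matrix of shape \((N_{\mathit{atoms}}, N_{\mathit{atoms}})\) yields features of shape \((N_{\mathit{atoms}}, N_{\mathit{atoms}}, B)\).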

Spherical features are encoded using spherical harmonics:

\[\hat{d}_{ij} = \frac{d_{ij}}{\lVert d_{ij} \rVert_2}\]
\[Y_l^m(\hat{d}_{ij}), \quad m=-l,\dots,l, \quad l=0,\dots,L\]

Total number of spherical harmonics per pair:

\[|Y| = \sum_{l=0}^{L} (2l+1) = (L+1)^2\]
\[Y_{ij} = \bigl(Y_l^m(\hat{d}_{ij})\bigr)_{l=0,\dots,L}^{m=-l,\dots,l} \in \mathbb{R}^{(L+1)^2}\]
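As a concrete illustration of the angular encoding, the sketch below evaluates real spherical harmonics up to \(L = 2\) from their closed-form Cartesian expressions. The use of real (rather than complex) harmonics, the normalization, and the ordering of components are assumptions made for this example; the actual model may use a different convention or a library routine.

```python
import numpy as np

def real_sph_harm_l2(d_ij):
    """Real spherical harmonics up to L = 2 for one displacement vector.

    Returns (L + 1)^2 = 9 values ordered by l, then m = -l, ..., l.
    """
    x, y, z = np.asarray(d_ij, dtype=float) / np.linalg.norm(d_ij)
    c0 = 0.5 * np.sqrt(1.0 / np.pi)
    c1 = np.sqrt(3.0 / (4.0 * np.pi))
    c2 = 0.5 * np.sqrt(15.0 / np.pi)
    return np.array([
        c0,                                             # l = 0
        c1 * y, c1 * z, c1 * x,                         # l = 1, m = -1, 0, 1
        c2 * x * y, c2 * y * z,                         # l = 2, m = -2, -1
        0.25 * np.sqrt(5.0 / np.pi) * (3 * z**2 - 1),   # l = 2, m = 0
        c2 * x * z, 0.5 * c2 * (x**2 - y**2),           # l = 2, m = 1, 2
    ])

Y_ij = real_sph_harm_l2([1.0, 2.0, 2.0])                # shape (9,)
```

A useful sanity check is that, for any direction, the squared components of degree \(l\) sum to \((2l+1)/(4\pi)\).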

Radial and spherical features are combined:

\[C_{ij} = R_{ij} \otimes Y_{ij} \in \mathbb{R}^{B \times (L+1)^2}\]
\[F_{ij} = \mathit{flatten}(C_{ij}) \in \mathbb{R}^{B (L+1)^2}\]
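The combination step amounts to an outer product followed by a flatten. The sketch below uses made-up feature values and assumes row-major flattening; the ordering used in the actual model is not stated here.

```python
import numpy as np

B, L = 3, 1
R_ij = np.array([0.9, 0.4, 0.1])           # example radial features, length B
Y_ij = np.array([0.28, 0.0, 0.49, 0.0])    # example angular features, length (L+1)^2

C_ij = np.outer(R_ij, Y_ij)                # shape (B, (L+1)^2)
F_ij = C_ij.reshape(-1)                    # row-major flatten, length B * (L+1)^2
```

Under this ordering, entry `F_ij[b * (L + 1)**2 + k]` holds the product of radial component `b` and angular component `k`.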

Steps per atom \(i\) and message-passing iteration \(t = 1, \dots, T\)

Message passing: Construct message for each neighbor \(j\) of atom \(i\) using edge features \(F_{ij}\) and sender node features \(h_j^{(t-1)}\):

\[m_{ij}^{(t)} = f_{\mathit{message}}(F_{ij}, h_j^{(t-1)}), \quad j \in \mathcal{N}(i)\]

Aggregation: Average the incoming messages from all neighbors \(j \in \mathcal{N}(i)\) of atom \(i\). The result \(a_i^{(t)}\) is then used to update the node embedding \(h_i^{(t)}\) at iteration \(t\).

\[a_i^{(t)} = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} m_{ij}^{(t)}\]

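Node update: The previous embedding and the aggregated message are passed through an MLP, and the result is added to the previous embedding as a residual connection:

\[h_i^{(t)} = \mathit{MLP}(h_i^{(t-1)}, a_i^{(t)}) + h_i^{(t-1)}\]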
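One message-passing iteration (message, mean aggregation, residual update) might be sketched as follows. The forms of \(f_{\mathit{message}}\) and the MLP are not specified above, so this example assumes a single tanh layer on the concatenation \([F_{ij}, h_j]\) for the message and a two-layer network on \([h_i, a_i]\) for the update. The weight matrices `W_msg`, `W1`, `W2` and the dense neighbor mask are hypothetical.

```python
import numpy as np

def message_passing_step(h, F, neighbors, W_msg, W1, W2):
    """One iteration t: messages, mean aggregation, residual node update.

    h:         (N, D)       node embeddings h^(t-1)
    F:         (N, N, E)    edge features F_ij
    neighbors: (N, N) bool  neighbors[i, j] is True if j is in N(i)
    """
    N, D = h.shape
    # Message m_ij from edge features F_ij and sender features h_j.
    h_send = np.broadcast_to(h[None, :, :], (N, N, D))
    m = np.tanh(np.concatenate([F, h_send], axis=-1) @ W_msg)   # (N, N, D)

    # Mean over neighbors; atoms without neighbors receive a zero vector.
    mask = neighbors[:, :, None]
    count = np.maximum(neighbors.sum(axis=1, keepdims=True), 1)
    a = (m * mask).sum(axis=1) / count                          # (N, D)

    # Residual update: h^(t) = MLP(h^(t-1), a^(t)) + h^(t-1).
    hidden = np.tanh(np.concatenate([h, a], axis=-1) @ W1)
    return hidden @ W2 + h

# Example call with random placeholder data.
rng = np.random.default_rng(0)
N, D, E, H = 3, 4, 12, 8
h = rng.normal(size=(N, D))
F = rng.normal(size=(N, N, E))
neighbors = ~np.eye(N, dtype=bool)        # every other atom is a neighbor
W_msg = rng.normal(size=(E + D, D))
W1 = rng.normal(size=(2 * D, H))
W2 = rng.normal(size=(H, D))

h_next = message_passing_step(h, F, neighbors, W_msg, W1, W2)   # shape (N, D)
```

Applying the function \(T\) times, feeding each output back in as `h`, yields the final embeddings \(h_i^{(T)}\). A dense \((N, N)\) mask keeps the sketch short; for large systems a sparse neighbor list would typically be preferred.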

Energy prediction

\[E_i = \mathit{MLP}_{\mathit{energy}}(h_i^{(T)}) \in \mathbb{R}\]

Total energy per batch:

\[E^{(b)} = \sum_{i \in b} E_i\]
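The readout and summation can be sketched as below. The two-layer form of the energy MLP and the `batch_index` array are assumptions for illustration, and \(b\) is read here as indexing one structure within a batch, so that each atom contributes to exactly one total.

```python
import numpy as np

def atomic_energies(h_T, W1, w2):
    """Assumed two-layer readout mapping each h_i^(T) to a scalar E_i."""
    return np.tanh(h_T @ W1) @ w2             # shape (N_atoms,)

def total_energies(E_atom, batch_index, n_structures):
    """Sum atomic energies over the atoms belonging to each structure b."""
    return np.bincount(batch_index, weights=E_atom, minlength=n_structures)

E_atom = np.array([-1.0, -2.0, -0.5, -0.25])  # example atomic energies
batch_index = np.array([0, 0, 1, 1])          # atom i belongs to structure batch_index[i]
E_total = total_energies(E_atom, batch_index, 2)
```

For the example values, the first structure sums atoms 0 and 1 and the second sums atoms 2 and 3.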