Representation Learning

Representation learning describes how the model transforms raw input data into meaningful internal representations.

Building blocks for representation learning

Node features: learned atom-type embeddings h_i, the basis for all subsequent calculations.
Edge features: messages m_ij, constructed from the learned node features.
Message passing / graph NN: aggregates neighbor information, possibly with learned weights depending on distances/directions (attention layer) -> the network learns chemical interactions.
Update function: updates node features so that information can propagate through the graph.
Readout / pooling: converts node embeddings to a molecular energy; can be a sum, mean, or learned aggregation.
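The building blocks above can be sketched end-to-end on a toy graph. This is a minimal illustration only; all names, sizes, and the tanh update are illustrative assumptions, not the exact architecture described here.

```python
# Minimal sketch of one message-passing step plus readout on a toy
# molecule (all shapes/parameters are illustrative assumptions).
import numpy as np

rng = np.random.default_rng(0)
n_atoms, dim = 4, 8
atom_types = np.array([0, 1, 1, 0])
embed = rng.normal(size=(2, dim))
h = embed[atom_types]                          # node features h_i

# edges as (receiver, sender) pairs of a small chain molecule
edges = np.array([(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)])
recv, send = edges[:, 0], edges[:, 1]

m = h[send]                                    # messages (here: raw sender features)
agg = np.zeros_like(h)
np.add.at(agg, recv, m)                        # aggregate neighbor information

W = rng.normal(size=(2 * dim, dim)) * 0.1
h = np.tanh(np.concatenate([h, agg], axis=1) @ W)   # update function

E_i = h.sum(axis=1)                            # readout: per-atom scalar
E = E_i.sum()                                  # pooling to a molecular energy
print(h.shape, E_i.shape)
```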

Toolbox

Radial basis b_{ij}: fixed at initialization; maps each distance r_{ij} to a higher-dimensional embedding b_{ij}.
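One common choice for such a fixed radial basis is a set of Gaussians centered on a distance grid; the parameter names (`n_rbf`, `cutoff`, `gamma`) and the Gaussian form are illustrative assumptions, not prescribed by the text.

```python
# Sketch: fixed Gaussian radial basis mapping distances to embeddings.
import numpy as np

def radial_basis(r_ij, n_rbf=16, cutoff=5.0, gamma=10.0):
    """Map distances r_ij (shape [n_edges]) to embeddings [n_edges, n_rbf]."""
    centers = np.linspace(0.0, cutoff, n_rbf)   # fixed at initialization
    return np.exp(-gamma * (r_ij[:, None] - centers[None, :]) ** 2)

b = radial_basis(np.array([1.0, 2.5, 4.0]))
print(b.shape)  # (3, 16)
```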

Directional basis Y_{ij}: e.g., spherical harmonics of rank l evaluated on the edge direction, Y_{ij} = Y_l(\hat{r}_{ij}).
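For rank l = 1, the spherical harmonics are (up to normalization) just the components of the unit direction vector, so a minimal directional basis can be sketched as normalized displacement vectors:

```python
# Sketch: rank-1 directional basis = normalized edge displacement
# vectors (higher ranks would need proper spherical harmonics).
import numpy as np

def directional_basis(vec_ij):
    """vec_ij: [n_edges, 3] displacement vectors -> unit directions."""
    norm = np.linalg.norm(vec_ij, axis=1, keepdims=True)
    return vec_ij / np.clip(norm, 1e-12, None)

Y = directional_basis(np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]))
print(Y.shape)  # (2, 3)
```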

Learned features: a small MLP maps the static radial and directional bases to learned features, RY^{learned}_{ij} = MLP(b_{ij}, Y_{ij}).
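A sketch of this mapping, assuming a single hidden layer with ReLU; the hidden size and weight initialization are illustrative:

```python
# Sketch: small MLP mapping static bases to learned edge features.
import numpy as np

rng = np.random.default_rng(0)
n_edges, n_rbf, n_dir, dim = 5, 16, 3, 8
b_ij = rng.random((n_edges, n_rbf))        # radial basis values
Y_ij = rng.normal(size=(n_edges, n_dir))   # directional basis values

W1 = rng.normal(size=(n_rbf + n_dir, 32)) * 0.1
W2 = rng.normal(size=(32, dim)) * 0.1
x = np.concatenate([b_ij, Y_ij], axis=1)
RY_learned = np.maximum(x @ W1, 0.0) @ W2  # RY_learned_ij = MLP(b_ij, Y_ij)
print(RY_learned.shape)  # (5, 8)
```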

Edge MLP: concatenates the sender and receiver node embeddings (initially the atom-type embedding vectors) with the radial + angular features and outputs a learned message embedding for each edge: m_{ij} = MLP(h_i, h_j, RY^{learned}_{ij}).
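A sketch of the edge MLP, assuming a one-hidden-layer ReLU MLP; hidden size and weights are illustrative:

```python
# Sketch: edge MLP producing one message embedding per edge from
# sender features, receiver features, and learned edge features.
import numpy as np

rng = np.random.default_rng(1)
n_edges, dim, feat = 6, 8, 8
h = rng.normal(size=(4, dim))                 # node embeddings
edges = np.array([(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)])
recv, send = edges[:, 0], edges[:, 1]
RY_learned = rng.normal(size=(n_edges, feat)) # from the previous step

x = np.concatenate([h[send], h[recv], RY_learned], axis=1)
W1 = rng.normal(size=(2 * dim + feat, 32)) * 0.1
W2 = rng.normal(size=(32, dim)) * 0.1
m = np.maximum(x @ W1, 0.0) @ W2              # m_ij = MLP(h_i, h_j, RY_ij)
print(m.shape)  # (6, 8)
```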

Attention \alpha_{ij} (optional): simple attention weight per message (sigmoid or softmax); could be replaced by a per-node softmax over incoming messages if desired.
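The sigmoid variant gates each message independently; the projection vector `w` below is an illustrative learned parameter:

```python
# Sketch: per-message sigmoid attention gate in (0, 1).
import numpy as np

rng = np.random.default_rng(2)
m = rng.normal(size=(6, 8))               # messages m_ij
w = rng.normal(size=(8,)) * 0.1           # learned projection (illustrative)
alpha = 1.0 / (1.0 + np.exp(-(m @ w)))    # sigmoid attention weights
m_weighted = alpha[:, None] * m           # m'_ij = alpha_ij * m_ij
print(m_weighted.shape)
```

A per-node softmax would instead normalize alpha over each receiver's incoming edges, so the weights on each node sum to one.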

Node update: sums the (weighted) messages m'_{ij} = \alpha_{ij} \cdot m_{ij} from the neighbors j and passes the result through a small MLP to obtain the updated node vectors: h^{n+1}_i = MLP(h^{n}_i, \sum_j m'_{ij}).
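A sketch of the node update, using a scatter-add for the per-receiver sum and a one-hidden-layer MLP on the concatenation (sizes illustrative):

```python
# Sketch: sum weighted messages per receiving node, then update the
# node embedding with a small MLP on [h_i, sum_j m'_ij].
import numpy as np

rng = np.random.default_rng(3)
n_atoms, dim = 4, 8
h = rng.normal(size=(n_atoms, dim))
edges = np.array([(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)])
recv = edges[:, 0]
m_weighted = rng.normal(size=(len(edges), dim))  # alpha_ij * m_ij

agg = np.zeros((n_atoms, dim))
np.add.at(agg, recv, m_weighted)                 # sum_j m'_ij per node i

W1 = rng.normal(size=(2 * dim, 32)) * 0.1
W2 = rng.normal(size=(32, dim)) * 0.1
h_next = np.maximum(np.concatenate([h, agg], axis=1) @ W1, 0.0) @ W2
print(h_next.shape)  # (4, 8)
```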

Multi-scale (optional): using 2-3 different edge MLPs allows neighboring atoms to be treated differently depending on distance. Can be implemented by calling the layer separately on different neighbor lists, then summing the messages before the node MLP.
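A sketch of the multi-scale idea with two neighbor lists; the cutoff split, the single-matrix "edge MLPs", and the weights are illustrative simplifications:

```python
# Sketch: separate (toy) edge transforms for short- and long-range
# neighbor lists, with messages summed per node before the node MLP.
import numpy as np

rng = np.random.default_rng(4)
n_atoms, dim = 4, 8
h = rng.normal(size=(n_atoms, dim))

# two neighbor lists from two distance cutoffs (illustrative)
edges_short = np.array([(0, 1), (1, 0), (1, 2), (2, 1)])
edges_long = np.array([(0, 2), (2, 0), (1, 3), (3, 1)])
W_short = rng.normal(size=(dim, dim)) * 0.1
W_long = rng.normal(size=(dim, dim)) * 0.1

agg = np.zeros((n_atoms, dim))
for edges, W in [(edges_short, W_short), (edges_long, W_long)]:
    recv, send = edges[:, 0], edges[:, 1]
    np.add.at(agg, recv, h[send] @ W)   # per-scale messages, summed
print(agg.shape)
```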

Final MLP pass: the final node vectors are passed through a feed-forward MLP to predict per-atom potential energies. The overall energy is the sum of all atomic contributions, E = \sum_i E_i.
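The readout can be sketched as a per-atom MLP followed by a sum; hidden size and weights are illustrative:

```python
# Sketch: per-atom energy head E_i = MLP(h_i), molecular energy
# E = sum_i E_i.
import numpy as np

rng = np.random.default_rng(5)
n_atoms, dim = 4, 8
h_final = rng.normal(size=(n_atoms, dim))   # final node vectors

W1 = rng.normal(size=(dim, 16)) * 0.1
w2 = rng.normal(size=(16,)) * 0.1
E_i = np.maximum(h_final @ W1, 0.0) @ w2    # per-atom energy contributions
E = E_i.sum()                               # E = sum_i E_i
print(E_i.shape)  # (4,)
```

Summing per-atom contributions makes the predicted energy extensive: adding atoms adds terms to the sum rather than changing a fixed-size output.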