Representation Learning
Representation learning describes how the model transforms raw input data into meaningful internal representations.
Building blocks for representation learning
| Component | Explanation |
|---|---|
| Node features | Learned atom-type embeddings h_i; the basis for all subsequent calculations |
| Edge features | Messages m_ij, constructed from the learned node features |
| Message passing / graph NN | Nodes exchange messages m_ij with their neighbors along the edges of the graph |
| Update function | Updates the node features so that information can propagate through the graph |
| Readout / pooling | Aggregates the per-node features into a single output for the whole structure, e.g., by summation |
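A minimal sketch of these building blocks in PyTorch (the class name SimpleMPLayer and the exact MLP shapes are illustrative assumptions, not taken from any particular library):

```python
import torch
import torch.nn as nn

class SimpleMPLayer(nn.Module):
    """One round of message passing: build edge messages, aggregate, update."""
    def __init__(self, num_atom_types: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(num_atom_types, dim)   # node features h_i
        self.edge_mlp = nn.Sequential(                   # edge features / messages m_ij
            nn.Linear(2 * dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.update = nn.Sequential(                     # update function
            nn.Linear(2 * dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, atom_types, senders, receivers):
        h = self.embed(atom_types)                       # (N, dim) node features
        # Messages are constructed from sender and receiver node features.
        m = self.edge_mlp(torch.cat([h[senders], h[receivers]], dim=-1))
        # Aggregate incoming messages per receiving node.
        agg = torch.zeros_like(h).index_add_(0, receivers, m)
        # Update node features so information propagates through the graph.
        return self.update(torch.cat([h, agg], dim=-1))

# Usage on a toy 3-atom graph:
layer = SimpleMPLayer(num_atom_types=10, dim=32)
atom_types = torch.tensor([0, 1, 1])
senders    = torch.tensor([0, 1, 0, 2])      # directed edges i -> j
receivers  = torch.tensor([1, 0, 2, 0])
h = layer(atom_types, senders, receivers)    # (3, 32) updated node features
graph_vec = h.sum(dim=0)                     # readout / pooling by summation
```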
Tool box
| Component | Explanation |
|---|---|
| Radial basis | RadialBasis is fixed at initialization; it maps interatomic distances to a set of basis values |
| Directional basis | DirectionalBasis, e.g., spherical harmonics of a given rank evaluated on the edge direction vectors |
| Learned features | Small MLPs map the static radial and directional bases to learned features |
| Edge MLP | Concatenates sender-node features (initially these are the embedding vectors), receiver-node features, and the radial + angular features; outputs a learned message embedding for each edge |
| Attention | Simple attention on messages (sigmoid or softmax); could be replaced by a per-node softmax if desired |
| Node update | Sums the (attention-weighted) messages for each node, then updates the node features with a node MLP |
| Multi-scale | Optionally, 2-3 different edge MLPs allow neighboring atoms to be treated differently depending on distance; can be implemented by calling this layer separately on different neighbor lists, then summing the messages before the node MLP |
| Final MLP pass | The final node vectors are passed through a feed-forward MLP to predict per-atom energies; the total potential energy is the sum of all atomic contributions |
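A sketch of a fixed radial basis as described in the first row. The table does not specify the basis type, so Gaussians are assumed here as one common choice; GaussianRadialBasis and its default parameters are hypothetical:

```python
import torch
import torch.nn as nn

class GaussianRadialBasis(nn.Module):
    """Expands interatomic distances into Gaussian basis values.
    Centers and width are fixed at initialization, i.e., not trained."""
    def __init__(self, cutoff: float = 5.0, num_basis: int = 16):
        super().__init__()
        centers = torch.linspace(0.0, cutoff, num_basis)
        self.register_buffer("centers", centers)      # fixed, not a parameter
        self.width = (cutoff / num_basis) ** -2       # fixed Gaussian width

    def forward(self, distances):                     # (E,) -> (E, num_basis)
        diff = distances.unsqueeze(-1) - self.centers
        return torch.exp(-self.width * diff ** 2)
```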
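And a hedged sketch that chains the remaining tool-box pieces: learned features from the static basis, an edge MLP over concatenated sender/receiver/geometry features, a sigmoid attention gate on messages, a summed node update, and a final per-atom energy MLP. The names GeometricMessageLayer and EnergyHead are illustrative, and the directional basis is omitted for brevity:

```python
import torch
import torch.nn as nn

class GeometricMessageLayer(nn.Module):
    def __init__(self, dim: int, num_radial: int = 16):
        super().__init__()
        self.radial_mlp = nn.Sequential(       # learned features from static basis
            nn.Linear(num_radial, dim), nn.SiLU())
        self.edge_mlp = nn.Sequential(         # message embedding per edge
            nn.Linear(3 * dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.gate = nn.Linear(dim, 1)          # simple sigmoid attention
        self.node_mlp = nn.Sequential(         # node update after summing messages
            nn.Linear(2 * dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, h, rbf, senders, receivers):
        r = self.radial_mlp(rbf)                               # (E, dim)
        m = self.edge_mlp(torch.cat([h[senders], h[receivers], r], dim=-1))
        m = torch.sigmoid(self.gate(m)) * m                    # weight messages
        agg = torch.zeros_like(h).index_add_(0, receivers, m)  # sum per node
        return self.node_mlp(torch.cat([h, agg], dim=-1))

class EnergyHead(nn.Module):
    """Final MLP pass: per-atom energies summed to a total energy."""
    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, 1))

    def forward(self, h):
        return self.mlp(h).sum()   # total energy = sum of atomic contributions
```

For the multi-scale variant, the same layer could be instantiated 2-3 times and called once per neighbor list (e.g., short- and long-range), with the aggregated messages summed before node_mlp, matching the table's description.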