
Add model training with mini-batches#59

Draft
JMorado wants to merge 8 commits into chemle:main from JMorado:feature_batching

Conversation

@JMorado (Contributor) commented Jul 23, 2025

This PR adds support for training EMLE models using mini-batches. As training datasets continue to grow in size, this feature becomes essential to avoid memory issues: unlike the QM7 dataset, larger datasets can no longer be held entirely in memory.

The implementation introduces three flags: --use-minibatch, which enables/disables mini-batch training; --batch-size, which specifies the size of each mini-batch; and --shuffle, which shuffles the training data. By default, training still uses the original full-batch optimization.
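The behaviour the three flags describe can be sketched as follows. This is a minimal illustration in PyTorch, not the PR's actual code: the `train` function, model, and loss are hypothetical placeholders, and only the flag semantics (mini-batch on/off, batch size, shuffling, full-batch default) are taken from the description above.

```python
# Sketch of mini-batch vs. full-batch optimisation, mirroring the PR's flags.
# The model and loss here are placeholders, not the EMLE training code.
import torch
from torch.utils.data import DataLoader, TensorDataset


def train(model, X, y, use_minibatch=False, batch_size=64, shuffle=False, epochs=5):
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    loss_fn = torch.nn.MSELoss()
    if use_minibatch:
        # --use-minibatch: iterate over mini-batches of size --batch-size,
        # optionally shuffled each epoch (--shuffle).
        loader = DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=shuffle)
    else:
        # Default: a single full batch, i.e. the original optimisation scheme.
        loader = [(X, y)]
    for _ in range(epochs):
        for xb, yb in loader:
            opt.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            opt.step()
    return model
```

With `use_minibatch=False` the loop degenerates to one gradient step per epoch over the whole dataset, which is why the default preserves the original full-batch behaviour.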

@lohedges (Contributor)

Thanks for this. Ignore the test failures. This is because sqm is currently completely broken with recent versions of ambertools. (There are glibc issues, so the package will likely need to be rebuilt.)

@JMorado (Contributor, Author) commented Jul 23, 2025

I've added a proof-of-concept implementation that performs the IVM and AEV calculations and makes the valence-width training step "lazy": batches of masked AEVs are written to disk and loaded on the fly as needed. This is necessary because large datasets otherwise cannot be loaded into memory (training does not get past the AEV computation, as it is impossible to store the aev_mols tensor in memory). I've been testing this on a dataset of ca. 0.5 M configurations, and so far it looks like a viable, if not the most performant, solution. I'm keen to improve the implementation, so any suggestions are welcome!
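The lazy scheme described above can be sketched like this. Everything here is an assumption for illustration: the function names (`cache_batches`, `iter_batches`), the per-batch file layout, and the use of `torch.save`/`torch.load` are not taken from the PR, which only states that batches of masked AEVs are written to disk and loaded on the fly.

```python
# Hypothetical sketch of the "lazy" AEV scheme: split a large tensor into
# per-batch files on disk, then stream the files back one at a time so peak
# memory is a single batch rather than the full aev_mols tensor.
import os
import torch


def cache_batches(aevs, batch_size, cache_dir):
    """Write contiguous slices of `aevs` to disk; return the file paths."""
    os.makedirs(cache_dir, exist_ok=True)
    paths = []
    for i, start in enumerate(range(0, aevs.shape[0], batch_size)):
        path = os.path.join(cache_dir, f"aev_batch_{i:05d}.pt")
        # .clone() so each file holds only its slice, not a view of the whole tensor.
        torch.save(aevs[start:start + batch_size].clone(), path)
        paths.append(path)
    return paths


def iter_batches(paths):
    """Yield one cached batch at a time, loading each file on demand."""
    for path in paths:
        yield torch.load(path)
```

In a real training loop the generator from `iter_batches` would feed the valence-width optimisation directly, so no more than one batch of masked AEVs is ever resident in memory; the trade-off, as noted above, is the extra disk I/O per epoch.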
