PRDNet: Pseudo-particle Ray Diffraction Network

License: MIT ICLR 2026

PRDNet is a physics-informed graph neural network for crystal property prediction that combines:

  • Graph Neural Networks for crystal structure representation
  • Pseudo-particle Ray Diffraction physics integration
  • Multi-head Attention mechanisms

Quick Start

1. Installation

# Clone repository
git clone https://github.com/Bin-Cao/PRDNet.git
cd PRDNet

# Install dependencies
pip install -r requirements.txt

# Verify installation
python -c "import prdnet; print('PRDNet installed successfully!')"

2. Prepare Data

Your data should be in ASE database format. Here's how to create one:

from ase.db import connect
from ase import Atoms
from ase.build import bulk

# Create database
db = connect("my_data.db")

# Example: Add some crystal structures
# Silicon crystal
si = bulk('Si', 'diamond', a=5.43)
db.write(si, formation_energy=-5.42, band_gap=1.12)

# NaCl crystal
nacl = bulk('NaCl', 'rocksalt', a=5.64)
db.write(nacl, formation_energy=-8.23, band_gap=8.5)

print(f"Database created with {len(db)} structures")

Using existing datasets: see the Datasets section below for supported sources.

3. Train Model

Option A: Edit trainer.py (Recommended)

  1. Open trainer.py and modify the database paths in the main() function:
config = create_trainer_config(
    train_db_path="your_train.db",      # ← Change this
    val_db_path="your_val.db",          # ← Change this
    test_db_path="your_test.db",        # ← Change this
    target_property="formation_energy", # ← Change if needed
    # ... other settings
)
  2. Run training:
python trainer.py
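
Option A expects separate train, validation, and test databases, while the data preparation step produces a single file. The following is a minimal sketch of one way to split it, not part of PRDNet itself: `split_indices` and `write_splits` are hypothetical helper names, and the 80/10/10 default and the `your_*.db` output names are assumptions chosen to match the placeholders above.

```python
import random


def split_indices(n, val_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle range(n) and cut it into train/val/test index lists."""
    if n < 0 or val_frac < 0 or test_frac < 0 or val_frac + test_frac >= 1:
        raise ValueError("fractions must be non-negative and sum to less than 1")
    order = list(range(n))
    random.Random(seed).shuffle(order)  # seeded, so the split is reproducible
    n_val = int(n * val_frac)
    n_test = int(n * test_frac)
    return {
        "val": order[:n_val],
        "test": order[n_val:n_val + n_test],
        "train": order[n_val + n_test:],
    }


def write_splits(src_path, prefix="your", **split_kwargs):
    """Copy rows of one ASE database into <prefix>_{train,val,test}.db."""
    from ase.db import connect  # imported here so split_indices works without ASE

    rows = list(connect(src_path).select())
    for name, idx in split_indices(len(rows), **split_kwargs).items():
        out = connect(f"{prefix}_{name}.db")
        for i in idx:
            row = rows[i]
            # Carry over the structure and its stored properties (e.g. formation_energy)
            out.write(row.toatoms(), key_value_pairs=row.key_value_pairs)
```

Calling `write_splits("my_data.db")` would then produce `your_train.db`, `your_val.db`, and `your_test.db`. ASE appends to an existing database file, so remove old split files before re-running.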

Option B: Python API

from trainer import PrdnetTrainer, create_trainer_config

config = create_trainer_config(
    train_db_path="my_data.db",
    target_property="formation_energy",
    epochs=100,
    batch_size=32
)

trainer = PrdnetTrainer(config)
results = trainer.train()

Configuration

Key Parameters

Edit these parameters in trainer.py:

config = create_trainer_config(
    # Data paths
    train_db_path="path/to/train.db",
    val_db_path="path/to/val.db",      # Optional
    target_property="formation_energy", # Property to predict

    # Training settings
    epochs=500,
    batch_size=32,
    learning_rate=0.0005,

    # Model architecture
    model_config={
        "conv_layers": 6,           # Graph convolution layers
        "node_features": 256,       # Node embedding dimension
        "use_diffraction": True,    # Enable physics integration
        "diffraction_max_hkl": 5,   # Miller index range
        "node_layer_head": 8,       # Attention heads
    }
)
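
To get a feel for what `diffraction_max_hkl` controls, the sketch below enumerates Miller indices up to a given bound and computes an interplanar spacing for a cubic cell. It is illustrative only: the assumption that the bound applies symmetrically to each of h, k, and l (and that (0, 0, 0) is excluded) is ours, not confirmed by the PRDNet source, and `miller_indices` and `cubic_d_spacing` are hypothetical helper names.

```python
import itertools
import math


def miller_indices(max_hkl):
    """All (h, k, l) with each index in [-max_hkl, max_hkl], excluding (0, 0, 0)."""
    rng = range(-max_hkl, max_hkl + 1)
    return [hkl for hkl in itertools.product(rng, repeat=3) if any(hkl)]


def cubic_d_spacing(a, hkl):
    """Interplanar spacing d = a / sqrt(h^2 + k^2 + l^2) for a cubic lattice."""
    h, k, l = hkl
    return a / math.sqrt(h * h + k * k + l * l)


planes = miller_indices(5)
print(len(planes))  # 11**3 - 1 = 1330 candidate reflections
print(round(cubic_d_spacing(5.43, (1, 1, 1)), 3))  # Si (111) spacing in angstrom: 3.135
```

Under this assumption the number of candidate reflections grows cubically with the bound, so raising `diffraction_max_hkl` is likely to increase preprocessing cost noticeably.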

Advanced Usage

Distributed Training

# Multi-GPU training
torchrun --nproc_per_node=4 trainer.py

# Custom parameters
torchrun --nproc_per_node=4 trainer.py \
    --epochs 500 \
    --batch_size 96
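
`torchrun` starts one process per GPU and passes each its position through the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables. The sketch below reads them and estimates the number of samples per optimizer step. Whether `--batch_size` in `trainer.py` is per process or global is not stated here, so the per-process interpretation in `global_batch_size` is an assumption, and both function names are hypothetical.

```python
import os


def torchrun_env(environ=None):
    """Read the process layout that torchrun exports; defaults describe a single process."""
    env = os.environ if environ is None else environ
    return {
        "rank": int(env.get("RANK", 0)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
    }


def global_batch_size(per_process_batch, environ=None):
    """Samples per optimizer step if each process loads per_process_batch samples."""
    return per_process_batch * torchrun_env(environ)["world_size"]


# Example: 4 processes, each with a batch of 96 (assumed per-process semantics)
print(global_batch_size(96, {"WORLD_SIZE": "4", "RANK": "0", "LOCAL_RANK": "0"}))  # 384
```

If the batch size is per process, a learning rate tuned for a single GPU may need revisiting once the effective batch grows.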

Data Caching

PRDNet automatically caches preprocessed data for faster training:

config = create_trainer_config(
    cache_dir="./prdnet_cache",  # Cache directory
    use_cache=True,              # Enable caching
    force_cache_rebuild=False    # Set to True to discard and rebuild the cache
)
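
A cache of preprocessed data is only valid while the inputs that shaped it stay the same. The sketch below shows one generic way to derive a cache key from those inputs. It is not PRDNet's caching code, and `cache_key` and its arguments are hypothetical; it is meant only to illustrate when a rebuild could be needed.

```python
import hashlib
import json


def cache_key(db_path, target_property, preprocessing):
    """Hash the settings that influence preprocessing into a short, stable key."""
    payload = json.dumps(
        {"db": db_path, "target": target_property, "prep": preprocessing},
        sort_keys=True,  # key order must not change the hash
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]


key_a = cache_key("train.db", "formation_energy", {"diffraction_max_hkl": 5})
key_b = cache_key("train.db", "formation_energy", {"diffraction_max_hkl": 7})
print(key_a != key_b)  # True: a changed setting yields a different key
```

If you change a preprocessing-related setting and are unsure whether the existing cache reflects it, setting `force_cache_rebuild=True` for one run is the safe choice.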

Monitoring with WandB

# Install WandB (optional)
pip install wandb
wandb login

# Training with logging
python trainer.py  # Metrics automatically logged

Supported Properties

PRDNet can predict various materials properties:

  • Formation energy (formation_energy)
  • Band gap (band_gap)
  • Bulk modulus (bulk_modulus)
  • Shear modulus (shear_modulus)
  • Custom properties (any numeric property in your database)

Datasets

  • Materials Project - Formation energies, band gaps, elastic properties
  • JARVIS-DFT - Comprehensive DFT calculations
  • Custom databases - ASE database format

Troubleshooting

Common Issues

CUDA out of memory:

# Reduce batch size and model size
config = create_trainer_config(
    batch_size=16,
    model_config={"node_features": 128, "conv_layers": 4},
)

Database format errors:

# Check your database
python -c "from ase.db import connect; print(len(connect('your_data.db')))"
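
The one-liner above only confirms that the file opens. A common follow-up problem is rows that lack the target property. The sketch below, assuming rows expose `id` and `key_value_pairs` as ASE database rows do, lists the offending row ids; `rows_missing_target` is a hypothetical helper, and the stand-in rows keep the example runnable without a database file.

```python
from numbers import Real
from types import SimpleNamespace


def rows_missing_target(rows, target_property):
    """Return ids of rows whose key-value pairs lack a numeric target_property."""
    bad = []
    for row in rows:
        value = row.key_value_pairs.get(target_property)
        # bool counts as a number in Python, so exclude it explicitly
        if isinstance(value, bool) or not isinstance(value, Real):
            bad.append(row.id)
    return bad


# Stand-in rows for demonstration; real rows would come from an ASE database
demo_rows = [
    SimpleNamespace(id=1, key_value_pairs={"formation_energy": -5.42}),
    SimpleNamespace(id=2, key_value_pairs={"band_gap": 1.12}),
]
print(rows_missing_target(demo_rows, "formation_energy"))  # [2]
```

With a real database, the same check would read `rows_missing_target(connect("your_data.db").select(), "formation_energy")` after `from ase.db import connect`.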

Missing dependencies:

# Install PyTorch first
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Then install other dependencies
pip install -r requirements.txt

License

MIT License - see LICENSE file for details.

Citation

@inproceedings{cao2025beyond,
  title={Beyond Structure: Invariant Crystal Property Prediction with Pseudo-Particle Ray Diffraction},
  author={Cao, Bin and Liu, Yang and Zhang, Longhan and Wu, Yifan and Li, Zhixun and Luo, Yuyu and Cheng, Hong and Ren, Yang and Zhang, Tong-Yi},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026}
}

Contributing

We welcome contributions! Please see CONTRIBUTING.md for guidelines.
