diff --git a/docs/api/models/pyhealth.models.AdaCare.rst b/docs/api/models/pyhealth.models.AdaCare.rst index 00aeaf4f0..4d988ba78 100644 --- a/docs/api/models/pyhealth.models.AdaCare.rst +++ b/docs/api/models/pyhealth.models.AdaCare.rst @@ -9,6 +9,11 @@ The separate callable AdaCareLayer and the complete AdaCare model. :show-inheritance: .. autoclass:: pyhealth.models.AdaCare + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.MultimodalAdaCare :members: :undoc-members: :show-inheritance: \ No newline at end of file diff --git a/docs/api/models/pyhealth.models.RETAIN.rst b/docs/api/models/pyhealth.models.RETAIN.rst index 88899363e..ac8bbad06 100644 --- a/docs/api/models/pyhealth.models.RETAIN.rst +++ b/docs/api/models/pyhealth.models.RETAIN.rst @@ -9,6 +9,11 @@ The separate callable RETAINLayer and the complete RETAIN model. :show-inheritance: .. autoclass:: pyhealth.models.RETAIN + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.MultimodalRETAIN :members: :undoc-members: :show-inheritance: \ No newline at end of file diff --git a/examples/drug_recommendation/drug_recommendation_mimic4_adacare.py b/examples/drug_recommendation/drug_recommendation_mimic4_adacare.py new file mode 100644 index 000000000..9cddad01b --- /dev/null +++ b/examples/drug_recommendation/drug_recommendation_mimic4_adacare.py @@ -0,0 +1,121 @@ +""" +Example of using AdaCare for drug recommendation on MIMIC-IV. + +This example demonstrates: +1. Loading MIMIC-IV data +2. Applying the DrugRecommendationMIMIC4 task +3. Creating a SampleDataset with nested sequence processors +4. Training an AdaCare model +""" + +import torch + +from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import AdaCare +from pyhealth.tasks import DrugRecommendationMIMIC4 +from pyhealth.trainer import Trainer + +if __name__ == "__main__": + # STEP 1: Load MIMIC-IV base dataset + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + cache_dir="/shared/eng/pyhealth_agent/baselines", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "prescriptions", + ], + ) + + # STEP 2: Apply drug recommendation task + sample_dataset = base_dataset.set_task( + DrugRecommendationMIMIC4(), + num_workers=4, + ) + + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # Inspect a sample + sample = sample_dataset[0] + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f" Visit ID: {sample['visit_id']}") + print(f" Conditions (history): {len(sample['conditions'])} visits") + print(f" Procedures (history): {len(sample['procedures'])} visits") + print(f" Drugs history: {len(sample['drugs_hist'])} visits") + print(f" Target drugs: {len(sample['drugs'])} drugs") + print(f"\n First visit conditions: {sample['conditions'][0][:5]}...") + print(f" Target drugs sample: {sample['drugs'][:5]}...") + + # STEP 3: Split dataset + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + print("\nDataset split:") + print(f" Train: {len(train_dataset)} samples") + print(f" Validation: {len(val_dataset)} samples") + print(f" Test: {len(test_dataset)} samples") + + # Create dataloaders + train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False) + + # STEP 4: Initialize AdaCare model + model = AdaCare( + dataset=sample_dataset, + embedding_dim=128, + hidden_dim=128, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params:,} parameters") + print(f"Feature keys: {model.feature_keys}") + print(f"Label key: {model.label_keys[0]}") + + # STEP 5: Train the model + trainer = Trainer( + model=model, + device="cuda:1", # or "cpu" + metrics=["pr_auc_samples", "f1_samples", "jaccard_samples"], + ) + + print("\nStarting training...") + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=50, + monitor="pr_auc_samples", + optimizer_params={"lr": 1e-4}, + optimizer_class=torch.optim.AdamW, + ) + + # STEP 6: Evaluate on test set + print("\nEvaluating on test set...") + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # STEP 7: Inspect model predictions + print("\nSample predictions:") + sample_batch = next(iter(test_loader)) + + with torch.no_grad(): + output = model(**sample_batch) + + print(f" Batch size: {output['y_prob'].shape[0]}") + print(f" Number of drug classes: {output['y_prob'].shape[1]}") + print(" Predicted probabilities (first 5 drugs of first patient):") + print(f" {output['y_prob'][0, :5].cpu().numpy()}") + print(" True labels (first 5 drugs of first patient):") + print(f" {output['y_true'][0, :5].cpu().numpy()}") diff --git a/examples/drug_recommendation/drug_recommendation_mimic4_adacare_optuna.py b/examples/drug_recommendation/drug_recommendation_mimic4_adacare_optuna.py new file mode 100644 index 000000000..ee618bfb7 --- /dev/null +++ b/examples/drug_recommendation/drug_recommendation_mimic4_adacare_optuna.py @@ -0,0 +1,201 @@ +""" +Optuna hyperparameter tuning for AdaCare on drug recommendation with MIMIC-IV. + +This example demonstrates: +1. Loading MIMIC-IV data and applying the DrugRecommendationMIMIC4 task +2. Defining an Optuna objective that tunes AdaCare-specific hyperparameters +3. Running 10 Optuna trials to find the best configuration +4. Training a final model with the best hyperparameters + +Tuned hyperparameters: + - embedding_dim: embedding size for code tokens + - hidden_dim: GRU hidden state size inside AdaCare + - lr: learning rate for AdamW + - weight_decay: L2 regularization coefficient for AdamW + +Note: + AdaCare.__init__ forwards **kwargs to BaseModel.__init__, which only + accepts `dataset`. Layer-specific parameters (kernel_size, kernel_num, + r_v, r_c, activation, dropout) must not be passed to AdaCare() directly. + The tunable surface here covers the explicit named parameters + (embedding_dim, hidden_dim) and the optimizer settings. +""" + +import torch +import optuna + +from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import AdaCare +from pyhealth.tasks import DrugRecommendationMIMIC4 +from pyhealth.trainer import Trainer + +if __name__ == "__main__": + # --------------------------------------------------------------------------- + # STEP 1: Load MIMIC-IV base dataset + # --------------------------------------------------------------------------- + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + cache_dir="/shared/eng/pyhealth_agent/baselines", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "prescriptions", + ], + ) + + # STEP 2: Apply drug recommendation task + sample_dataset = base_dataset.set_task( + DrugRecommendationMIMIC4(), + num_workers=4, + ) + + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # STEP 3: Split dataset (fixed split so all trials see the same data) + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + print(f"\nDataset split — Train: {len(train_dataset)} " + f"Val: {len(val_dataset)} Test: {len(test_dataset)}") + + # --------------------------------------------------------------------------- + # STEP 4: Define Optuna objective + # --------------------------------------------------------------------------- + DEVICE = "cuda:3" # or "cpu" + TUNE_EPOCHS = 10 # lightweight training per trial + N_TRIALS = 10 + + def objective(trial: optuna.Trial) -> float: + """Suggest hyperparameters and return val pr_auc_samples.""" + + # --- Suggest hyperparameters --------------------------------------- + embedding_dim = trial.suggest_categorical( + "embedding_dim", [64, 128, 256] + ) + hidden_dim = trial.suggest_categorical("hidden_dim", [64, 128, 256]) + lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True) + weight_decay = trial.suggest_float( + "weight_decay", 1e-6, 1e-2, log=True + ) + batch_size = trial.suggest_categorical("batch_size", [32, 64, 128]) + + # --- Build dataloaders --------------------------------------------- + train_loader = get_dataloader( + train_dataset, batch_size=batch_size, shuffle=True + ) + val_loader = get_dataloader( + val_dataset, batch_size=batch_size, shuffle=False + ) + + # --- Build model --------------------------------------------------- + model = AdaCare( + dataset=sample_dataset, + embedding_dim=embedding_dim, + hidden_dim=hidden_dim, + ) + + # --- Train --------------------------------------------------------- + trainer = Trainer( + model=model, + device=DEVICE, + metrics=["pr_auc_samples"], + ) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=TUNE_EPOCHS, + monitor="pr_auc_samples", + optimizer_class=torch.optim.AdamW, + optimizer_params={"lr": lr}, + weight_decay=weight_decay, + ) + + # --- Evaluate on validation set ------------------------------------ + scores = trainer.evaluate(val_loader) + return scores["pr_auc_samples"] + + # --------------------------------------------------------------------------- + # STEP 5: Run Optuna study + # --------------------------------------------------------------------------- + print( + f"\nStarting Optuna search: " + f"{N_TRIALS} trials, {TUNE_EPOCHS} epochs each..." + ) + + study = optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=N_TRIALS) + + best_params = study.best_params + print("\nBest hyperparameters found:") + for k, v in best_params.items(): + print(f" {k}: {v}") + print(f"Best validation pr_auc_samples: {study.best_value:.4f}") + + # --------------------------------------------------------------------------- + # STEP 6: Train final model with best hyperparameters + # --------------------------------------------------------------------------- + print("\nTraining final model with best hyperparameters...") + + train_loader = get_dataloader( + train_dataset, batch_size=best_params["batch_size"], shuffle=True + ) + val_loader = get_dataloader( + val_dataset, batch_size=best_params["batch_size"], shuffle=False + ) + test_loader = get_dataloader( + test_dataset, batch_size=best_params["batch_size"], shuffle=False + ) + + final_model = AdaCare( + dataset=sample_dataset, + embedding_dim=best_params["embedding_dim"], + hidden_dim=best_params["hidden_dim"], + ) + + num_params = sum(p.numel() for p in final_model.parameters()) + print(f"Final model: {num_params:,} parameters") + + final_trainer = Trainer( + model=final_model, + device=DEVICE, + metrics=["pr_auc_samples", "f1_samples", "jaccard_samples"], + ) + final_trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=50, + monitor="pr_auc_samples", + optimizer_class=torch.optim.AdamW, + optimizer_params={"lr": best_params["lr"]}, + weight_decay=best_params["weight_decay"], + ) + + # STEP 7: Evaluate on test set + print("\nEvaluating on test set...") + results = final_trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # STEP 8: Inspect model predictions + print("\nSample predictions:") + sample_batch = next(iter(test_loader)) + + with torch.no_grad(): + output = final_model(**sample_batch) + + print(f" Batch size: {output['y_prob'].shape[0]}") + print(f" Number of drug classes: {output['y_prob'].shape[1]}") + print(" Predicted probabilities (first 5 drugs of first patient):") + print(f" {output['y_prob'][0, :5].cpu().numpy()}") + print(" True labels (first 5 drugs of first patient):") + print(f" {output['y_true'][0, :5].cpu().numpy()}") diff --git a/examples/drug_recommendation/drug_recommendation_mimic4_multimodal_retain.py b/examples/drug_recommendation/drug_recommendation_mimic4_multimodal_retain.py new file mode 100644 index 000000000..a47cce1f0 --- /dev/null +++ b/examples/drug_recommendation/drug_recommendation_mimic4_multimodal_retain.py @@ -0,0 +1,210 @@ +""" +Drug Recommendation on MIMIC-IV with MultimodalRETAIN + +This example demonstrates how to use the MultimodalRETAIN model with mixed +input modalities for drug recommendation on MIMIC-IV. + +The MultimodalRETAIN model can handle: +- Sequential features (visit histories with diagnoses, procedures) → RETAIN processing + with reverse time attention mechanism +- Non-sequential features (demographics, static measurements) → Direct embedding + +This example shows: +1. Loading MIMIC-IV data with mixed feature types +2. Applying a drug recommendation task +3. Training a MultimodalRETAIN model with both sequential and non-sequential inputs +4. Evaluating the model performance +5. Comparing to vanilla RETAIN (sequential only) +""" + +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.datasets import split_by_patient, get_dataloader +from pyhealth.models import MultimodalRETAIN +from pyhealth.tasks import DrugRecommendationMIMIC4 +from pyhealth.trainer import Trainer + + +if __name__ == "__main__": + # STEP 1: Load MIMIC-IV base dataset + print("=" * 60) + print("STEP 1: Loading MIMIC-IV Dataset") + print("=" * 60) + + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + dev=True, # Use development mode for faster testing + num_workers=4, + ) + base_dataset.stats() + + # STEP 2: Apply drug recommendation task with multimodal features + print("\n" + "=" * 60) + print("STEP 2: Setting Drug Recommendation Task") + print("=" * 60) + + # Use the DrugRecommendationMIMIC4 task + # This task creates visit-level nested sequences from diagnoses/procedures + # and recommends drugs for the current visit + task = DrugRecommendationMIMIC4() + sample_dataset = base_dataset.set_task( + task, + num_workers=4, + ) + + print(f"\nTotal samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # Inspect a sample + if len(sample_dataset) > 0: + sample = sample_dataset[0] + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + for key in sample_dataset.input_schema.keys(): + if key in sample: + if isinstance(sample[key], (list, tuple)): + if sample[key] and isinstance(sample[key][0], (list, tuple)): + print(f" {key}: {len(sample[key])} visits") + else: + print(f" {key}: length {len(sample[key])}") + else: + print(f" {key}: {type(sample[key])}") + # Show drugs key from output + if 'drugs' in sample: + print(f" drugs (target): {len(sample['drugs'])} prescriptions") + + # STEP 3: Split dataset + print("\n" + "=" * 60) + print("STEP 3: Splitting Dataset") + print("=" * 60) + + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + print(f"Train samples: {len(train_dataset)}") + print(f"Val samples: {len(val_dataset)}") + print(f"Test samples: {len(test_dataset)}") + + # Create dataloaders + train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False) + + # STEP 4: Initialize MultimodalRETAIN model + print("\n" + "=" * 60) + print("STEP 4: Initializing MultimodalRETAIN Model") + print("=" * 60) + + model = MultimodalRETAIN( + dataset=sample_dataset, + embedding_dim=128, + dropout=0.5, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"Model initialized with {num_params:,} parameters") + + # Print feature classification + print(f"\nSequential features (RETAIN processing): {model.sequential_features}") + print(f"Non-sequential features (direct embedding): {model.non_sequential_features}") + + # Calculate expected embedding dimensions + total_dim = len(model.feature_keys) * model.embedding_dim + print(f"\nPatient representation dimension: {total_dim}") + + # STEP 5: Train the model + print("\n" + "=" * 60) + print("STEP 5: Training Model") + print("=" * 60) + + trainer = Trainer( + model=model, + device="cuda:0", # Change to "cpu" if no GPU available + metrics=["pr_auc_samples", "roc_auc_samples", "jaccard_samples", "f1_samples"], + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=10, + monitor="jaccard_samples", + optimizer_params={"lr": 1e-3}, + ) + + # STEP 6: Evaluate on test set + print("\n" + "=" * 60) + print("STEP 6: Evaluating on Test Set") + print("=" * 60) + + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # STEP 7: Demonstrate model predictions + print("\n" + "=" * 60) + print("STEP 7: Sample Predictions") + print("=" * 60) + + import torch + + sample_batch = next(iter(test_loader)) + with torch.no_grad(): + output = model(**sample_batch) + + print(f"\nBatch size: {output['y_prob'].shape[0]}") + print(f"Output shape: {output['y_prob'].shape}") + print(f"(batch_size, num_drug_types)") + + # Show first patient predictions + print(f"\nFirst patient top-5 drug recommendations:") + first_patient_probs = output['y_prob'][0] + top5_drugs = torch.topk(first_patient_probs, k=min(5, len(first_patient_probs))) + for i, (drug_idx, prob) in enumerate(zip(top5_drugs.indices, top5_drugs.values)): + print(f" {i+1}. Drug index {drug_idx.item()}: probability {prob.item():.4f}") + + # Show ground truth for first patient + print(f"\nFirst patient ground truth drugs:") + first_patient_true = output['y_true'][0] + true_drug_indices = torch.where(first_patient_true > 0)[0] + print(f" Number of prescribed drugs: {len(true_drug_indices)}") + if len(true_drug_indices) > 0: + print(f" Drug indices: {true_drug_indices.tolist()[:10]}...") + + # STEP 8: Compare with vanilla RETAIN (if applicable) + print("\n" + "=" * 60) + print("STEP 8: Model Architecture Comparison") + print("=" * 60) + + print("\nMultimodalRETAIN vs. Vanilla RETAIN:") + print(" Vanilla RETAIN:") + print(" - Only handles sequential (visit-level) features") + print(" - Processes all features through reverse time attention") + print(" ") + print(" MultimodalRETAIN:") + print(" - Handles both sequential and non-sequential features") + print(f" - Sequential features ({len(model.sequential_features)}): " + f"{model.sequential_features}") + print(f" - Non-sequential features ({len(model.non_sequential_features)}): " + f"{model.non_sequential_features}") + print(" - More flexible for heterogeneous EHR data") + + # Summary + print("\n" + "=" * 60) + print("SUMMARY: MultimodalRETAIN Training Complete") + print("=" * 60) + print(f"Model: MultimodalRETAIN") + print(f"Dataset: MIMIC-IV") + print(f"Task: Drug Recommendation") + print(f"Sequential features: {len(model.sequential_features)}") + print(f"Non-sequential features: {len(model.non_sequential_features)}") + print(f"Best validation Jaccard: {results.get('jaccard_samples', 0):.4f}") + print("\nRETAIN advantages:") + print(" - Reverse time attention for interpretability") + print(" - Visit-level attention weights (alpha)") + print(" - Variable-level attention weights (beta)") + print(" - Multimodal extension allows richer feature sets") + print("=" * 60) + diff --git a/examples/drug_recommendation/drug_recommendation_mimic4_retain.py b/examples/drug_recommendation/drug_recommendation_mimic4_retain.py index 496edc7e3..d39c46042 100644 --- a/examples/drug_recommendation/drug_recommendation_mimic4_retain.py +++ b/examples/drug_recommendation/drug_recommendation_mimic4_retain.py @@ -19,100 +19,102 @@ from pyhealth.tasks import DrugRecommendationMIMIC4 from pyhealth.trainer import Trainer -# STEP 1: Load MIMIC-IV base dataset -base_dataset = MIMIC4Dataset( - ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", - ehr_tables=[ - "patients", - "admissions", - "diagnoses_icd", - "procedures_icd", - "prescriptions", - ], -) - -# STEP 2: Apply drug recommendation task -sample_dataset = base_dataset.set_task( - DrugRecommendationMIMIC4(), - num_workers=4, -) - -print(f"Total samples: {len(sample_dataset)}") -print(f"Input schema: {sample_dataset.input_schema}") -print(f"Output schema: {sample_dataset.output_schema}") - -# Inspect a sample -sample = sample_dataset.samples[0] -print("\nSample structure:") -print(f" Patient ID: {sample['patient_id']}") -print(f" Visit ID: {sample['visit_id']}") -print(f" Conditions (history): {len(sample['conditions'])} visits") -print(f" Procedures (history): {len(sample['procedures'])} visits") -print(f" Drugs history: {len(sample['drugs_hist'])} visits") -print(f" Target drugs: {len(sample['drugs'])} drugs") -print(f"\n First visit conditions: {sample['conditions'][0][:5]}...") -print(f" Target drugs sample: {sample['drugs'][:5]}...") - -# STEP 3: Split dataset -train_dataset, val_dataset, test_dataset = split_by_patient( - sample_dataset, [0.8, 0.1, 0.1] -) - -print("\nDataset split:") -print(f" Train: {len(train_dataset)} samples") -print(f" Validation: {len(val_dataset)} samples") -print(f" Test: {len(test_dataset)} samples") - -# Create dataloaders -train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True) -val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False) -test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False) - -# STEP 4: Initialize RETAIN model -model = RETAIN( - dataset=sample_dataset, - embedding_dim=128, - dropout=0.5, -) - -num_params = sum(p.numel() for p in model.parameters()) -print(f"\nModel initialized with {num_params:,} parameters") -print(f"Feature keys: {model.feature_keys}") -print(f"Label key: {model.label_key}") - -# STEP 5: Train the model -trainer = Trainer( - model=model, - device="cuda:4", # or "cpu" - metrics=["pr_auc_samples", "f1_samples", "jaccard_samples"], -) - -print("\nStarting training...") -trainer.train( - train_dataloader=train_loader, - val_dataloader=val_loader, - epochs=50, - monitor="pr_auc_samples", - optimizer_params={"lr": 1e-3}, -) - -# STEP 6: Evaluate on test set -print("\nEvaluating on test set...") -results = trainer.evaluate(test_loader) -print("\nTest Results:") -for metric, value in results.items(): - print(f" {metric}: {value:.4f}") - -# STEP 7: Inspect model predictions -print("\nSample predictions:") -sample_batch = next(iter(test_loader)) - -with torch.no_grad(): - output = model(**sample_batch) - -print(f" Batch size: {output['y_prob'].shape[0]}") -print(f" Number of drug classes: {output['y_prob'].shape[1]}") -print(" Predicted probabilities (first 5 drugs of first patient):") -print(f" {output['y_prob'][0, :5].cpu().numpy()}") -print(" True labels (first 5 drugs of first patient):") -print(f" {output['y_true'][0, :5].cpu().numpy()}") +if __name__ == "__main__": + # STEP 1: Load MIMIC-IV base dataset + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + cache_dir="/shared/eng/pyhealth_agent/baselines", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "prescriptions", + ], + ) + + # STEP 2: Apply drug recommendation task + sample_dataset = base_dataset.set_task( + DrugRecommendationMIMIC4(), + num_workers=4, + ) + + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # Inspect a sample + sample = sample_dataset[0] + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f" Visit ID: {sample['visit_id']}") + print(f" Conditions (history): {len(sample['conditions'])} visits") + print(f" Procedures (history): {len(sample['procedures'])} visits") + print(f" Drugs history: {len(sample['drugs_hist'])} visits") + print(f" Target drugs: {len(sample['drugs'])} drugs") + print(f"\n First visit conditions: {sample['conditions'][0][:5]}...") + print(f" Target drugs sample: {sample['drugs'][:5]}...") + + # STEP 3: Split dataset + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + print("\nDataset split:") + print(f" Train: {len(train_dataset)} samples") + print(f" Validation: {len(val_dataset)} samples") + print(f" Test: {len(test_dataset)} samples") + + # Create dataloaders + train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False) + + # STEP 4: Initialize RETAIN model + model = RETAIN( + dataset=sample_dataset, + embedding_dim=128, + dropout=0.5, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params:,} parameters") + print(f"Feature keys: {model.feature_keys}") + print(f"Label key: {model.label_key}") + + # STEP 5: Train the model + trainer = Trainer( + model=model, + device="cuda:4", # or "cpu" + metrics=["pr_auc_samples", "f1_samples", "jaccard_samples"], + ) + + print("\nStarting training...") + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=50, + monitor="pr_auc_samples", + optimizer_params={"lr": 1e-3}, + ) + + # STEP 6: Evaluate on test set + print("\nEvaluating on test set...") + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # STEP 7: Inspect model predictions + print("\nSample predictions:") + sample_batch = next(iter(test_loader)) + + with torch.no_grad(): + output = model(**sample_batch) + + print(f" Batch size: {output['y_prob'].shape[0]}") + print(f" Number of drug classes: {output['y_prob'].shape[1]}") + print(" Predicted probabilities (first 5 drugs of first patient):") + print(f" {output['y_prob'][0, :5].cpu().numpy()}") + print(" True labels (first 5 drugs of first patient):") + print(f" {output['y_true'][0, :5].cpu().numpy()}") diff --git a/examples/drug_recommendation/drug_recommendation_mimic4_rnn.py b/examples/drug_recommendation/drug_recommendation_mimic4_rnn.py new file mode 100644 index 000000000..5f8f38e08 --- /dev/null +++ b/examples/drug_recommendation/drug_recommendation_mimic4_rnn.py @@ -0,0 +1,123 @@ +""" +Example of using RNN for drug recommendation on MIMIC-IV. + +This example demonstrates: +1. Loading MIMIC-IV data +2. Applying the DrugRecommendationMIMIC4 task +3. Creating a SampleDataset with nested sequence processors +4. Training an RNN model +""" + +import torch + +from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import RNN +from pyhealth.tasks import DrugRecommendationMIMIC4 +from pyhealth.trainer import Trainer + +if __name__ == "__main__": + # STEP 1: Load MIMIC-IV base dataset + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + cache_dir="/shared/eng/pyhealth_agent/baselines", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "prescriptions", + ], + ) + + # STEP 2: Apply drug recommendation task + sample_dataset = base_dataset.set_task( + DrugRecommendationMIMIC4(), + num_workers=4, + ) + + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # Inspect a sample + sample = sample_dataset[0] + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f" Visit ID: {sample['visit_id']}") + print(f" Conditions (history): {len(sample['conditions'])} visits") + print(f" Procedures (history): {len(sample['procedures'])} visits") + print(f" Drugs history: {len(sample['drugs_hist'])} visits") + print(f" Target drugs: {len(sample['drugs'])} drugs") + print(f"\n First visit conditions: {sample['conditions'][0][:5]}...") + print(f" Target drugs sample: {sample['drugs'][:5]}...") + + # STEP 3: Split dataset + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + print("\nDataset split:") + print(f" Train: {len(train_dataset)} samples") + print(f" Validation: {len(val_dataset)} samples") + print(f" Test: {len(test_dataset)} samples") + + # Create dataloaders + train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False) + + # STEP 4: Initialize RNN model + model = RNN( + dataset=sample_dataset, + embedding_dim=128, + hidden_dim=128, + rnn_type="GRU", + dropout=0.5, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params:,} parameters") + print(f"Feature keys: {model.feature_keys}") + print(f"Label key: {model.label_key}") + + # STEP 5: Train the model + trainer = Trainer( + model=model, + device="cuda:0", # or "cpu" + metrics=["pr_auc_samples", "f1_samples", "jaccard_samples"], + ) + + print("\nStarting training...") + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=50, + monitor="pr_auc_samples", + optimizer_params={"lr": 1e-4}, + optimizer_class=torch.optim.AdamW, + ) + + # STEP 6: Evaluate on test set + print("\nEvaluating on test set...") + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # STEP 7: Inspect model predictions + print("\nSample predictions:") + sample_batch = next(iter(test_loader)) + + with torch.no_grad(): + output = model(**sample_batch) + + print(f" Batch size: {output['y_prob'].shape[0]}") + print(f" Number of drug classes: {output['y_prob'].shape[1]}") + print(" Predicted probabilities (first 5 drugs of first patient):") + print(f" {output['y_prob'][0, :5].cpu().numpy()}") + print(" True labels (first 5 drugs of first patient):") + print(f" {output['y_true'][0, :5].cpu().numpy()}") diff --git a/examples/drug_recommendation/drug_recommendation_mimic4_rnn_optuna.py b/examples/drug_recommendation/drug_recommendation_mimic4_rnn_optuna.py new file mode 100644 index 000000000..4c3fb2a9f --- /dev/null +++ b/examples/drug_recommendation/drug_recommendation_mimic4_rnn_optuna.py @@ -0,0 +1,205 @@ +""" +Optuna hyperparameter tuning for RNN on drug recommendation with MIMIC-IV. + +This example demonstrates: +1. Loading MIMIC-IV data and applying the DrugRecommendationMIMIC4 task +2. Defining an Optuna objective that tunes RNN-specific hyperparameters +3. Running 10 Optuna trials to find the best configuration +4. Training a final model with the best hyperparameters + +Tuned hyperparameters: + - embedding_dim: embedding size for code tokens + - hidden_dim: GRU/LSTM/RNN hidden state size + - rnn_type: recurrent cell type (GRU, LSTM, RNN) + - num_layers: number of stacked recurrent layers + - dropout: dropout rate applied before each recurrent layer + - lr: learning rate for AdamW + - weight_decay: L2 regularization coefficient for AdamW +""" + +import torch +import optuna + +from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import RNN +from pyhealth.tasks import DrugRecommendationMIMIC4 +from pyhealth.trainer import Trainer + +if __name__ == "__main__": + # --------------------------------------------------------------------------- + # STEP 1: Load MIMIC-IV base dataset + # --------------------------------------------------------------------------- + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + cache_dir="/shared/eng/pyhealth_agent/baselines", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "prescriptions", + ], + ) + + # STEP 2: Apply drug recommendation task + sample_dataset = base_dataset.set_task( + DrugRecommendationMIMIC4(), + num_workers=4, + ) + + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # STEP 3: Split dataset (fixed split so all trials see the same data) + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + print(f"\nDataset split — Train: {len(train_dataset)} " + f"Val: {len(val_dataset)} Test: {len(test_dataset)}") + + # --------------------------------------------------------------------------- + # STEP 4: Define Optuna objective + # --------------------------------------------------------------------------- + DEVICE = "cuda:2" # or "cpu" + TUNE_EPOCHS = 10 # lightweight training per trial + N_TRIALS = 10 + + def objective(trial: optuna.Trial) -> float: + """Return validation pr_auc_samples for a sampled RNN configuration.""" + + # --- Suggest hyperparameters ------------------------------------------- + embedding_dim = trial.suggest_categorical( + "embedding_dim", [64, 128, 256] + ) + hidden_dim = trial.suggest_categorical("hidden_dim", [64, 128, 256]) + rnn_type = trial.suggest_categorical( + "rnn_type", ["GRU", "LSTM", "RNN"] + ) + num_layers = trial.suggest_int("num_layers", 1, 3) + dropout = trial.suggest_float("dropout", 0.1, 0.7) + lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True) + weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True) + batch_size = trial.suggest_categorical("batch_size", [32, 64, 128]) + + # --- Build dataloaders ------------------------------------------------- + train_loader = get_dataloader( + train_dataset, batch_size=batch_size, shuffle=True + ) + val_loader = get_dataloader( + val_dataset, batch_size=batch_size, shuffle=False + ) + + # --- Build model ------------------------------------------------------- + model = RNN( + dataset=sample_dataset, + embedding_dim=embedding_dim, + hidden_dim=hidden_dim, + rnn_type=rnn_type, + num_layers=num_layers, + dropout=dropout, + ) + + # --- Train ------------------------------------------------------------- + trainer = Trainer( + model=model, + device=DEVICE, + metrics=["pr_auc_samples"], + ) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=TUNE_EPOCHS, + monitor="pr_auc_samples", + optimizer_class=torch.optim.AdamW, + optimizer_params={"lr": lr}, + weight_decay=weight_decay, + ) + + # --- Evaluate on validation set ---------------------------------------- + scores = trainer.evaluate(val_loader) + return scores["pr_auc_samples"] + + # --------------------------------------------------------------------------- + # STEP 5: Run Optuna study + # --------------------------------------------------------------------------- + print( + f"\nStarting Optuna search ({N_TRIALS} trials, {TUNE_EPOCHS} epochs each)..." + ) + + study = optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=N_TRIALS) + + best_params = study.best_params + print("\nBest hyperparameters found:") + for k, v in best_params.items(): + print(f" {k}: {v}") + print(f"Best validation pr_auc_samples: {study.best_value:.4f}") + + # --------------------------------------------------------------------------- + # STEP 6: Train final model with best hyperparameters + # --------------------------------------------------------------------------- + print("\nTraining final model with best hyperparameters...") + + train_loader = get_dataloader( + train_dataset, batch_size=best_params["batch_size"], shuffle=True + ) + val_loader = get_dataloader( + val_dataset, batch_size=best_params["batch_size"], shuffle=False + ) + test_loader = get_dataloader( + test_dataset, batch_size=best_params["batch_size"], shuffle=False + ) + + final_model = RNN( + dataset=sample_dataset, + embedding_dim=best_params["embedding_dim"], + hidden_dim=best_params["hidden_dim"], + rnn_type=best_params["rnn_type"], + num_layers=best_params["num_layers"], + dropout=best_params["dropout"], + ) + + num_params = sum(p.numel() for p in final_model.parameters()) + print(f"Final model: {num_params:,} parameters") + + final_trainer = Trainer( + model=final_model, + device=DEVICE, + metrics=["pr_auc_samples", "f1_samples", "jaccard_samples"], + ) + final_trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=50, + monitor="pr_auc_samples", + optimizer_class=torch.optim.AdamW, + optimizer_params={"lr": best_params["lr"]}, + weight_decay=best_params["weight_decay"], + ) + + # STEP 7: Evaluate on test set + print("\nEvaluating on test set...") + results = final_trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # STEP 8: Inspect model predictions + print("\nSample predictions:") + sample_batch = next(iter(test_loader)) + + with torch.no_grad(): + output = final_model(**sample_batch) + + print(f" Batch size: {output['y_prob'].shape[0]}") + print(f" Number of drug classes: {output['y_prob'].shape[1]}") + print(" Predicted probabilities (first 5 drugs of first patient):") + print(f" {output['y_prob'][0, :5].cpu().numpy()}") + print(" True labels (first 5 drugs of first patient):") + print(f" {output['y_true'][0, :5].cpu().numpy()}") diff --git a/examples/mortality_prediction/mortality_mimic4_multimodal_adacare.py b/examples/mortality_prediction/mortality_mimic4_multimodal_adacare.py new file mode 100644 index 000000000..8e973412b --- /dev/null +++ b/examples/mortality_prediction/mortality_mimic4_multimodal_adacare.py @@ -0,0 +1,197 @@ +""" +Mortality Prediction on MIMIC-IV with MultimodalAdaCare + +This example demonstrates how to use the MultimodalAdaCare model with mixed +input modalities for in-hospital mortality prediction on MIMIC-IV. + +The MultimodalAdaCare model can handle: +- Sequential features (diagnoses, procedures) → AdaCare processing with scale-adaptive + feature extraction and recalibration +- Non-sequential features (demographics, static measurements) → Direct embedding + +This example shows: +1. Loading MIMIC-IV data with mixed feature types +2. Applying a mortality prediction task +3. Training a MultimodalAdaCare model with both sequential and non-sequential inputs +4. Evaluating the model performance +5. Analyzing feature importance (from AdaCareLayer for sequential features) +""" + +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.datasets import split_by_patient, get_dataloader +from pyhealth.models import MultimodalAdaCare +from pyhealth.tasks import InHospitalMortalityMIMIC4 +from pyhealth.trainer import Trainer + + +if __name__ == "__main__": + # STEP 1: Load MIMIC-IV base dataset + print("=" * 60) + print("STEP 1: Loading MIMIC-IV Dataset") + print("=" * 60) + + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=["diagnoses_icd", "procedures_icd"], + dev=True, # Use development mode for faster testing + num_workers=4, + ) + base_dataset.stats() + + # STEP 2: Apply mortality prediction task with multimodal features + print("\n" + "=" * 60) + print("STEP 2: Setting Mortality Prediction Task") + print("=" * 60) + + # Use the InHospitalMortalityMIMIC4 task + # This task will create sequential features from diagnoses and procedures + task = InHospitalMortalityMIMIC4() + sample_dataset = base_dataset.set_task( + task, + num_workers=4, + ) + + print(f"\nTotal samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # Inspect a sample + if len(sample_dataset) > 0: + sample = sample_dataset[0] + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + for key in sample_dataset.input_schema.keys(): + if key in sample: + if isinstance(sample[key], (list, tuple)): + print(f" {key}: length {len(sample[key])}") + else: + print(f" {key}: {type(sample[key])}") + print(f" Mortality: {sample.get('mortality', 'N/A')}") + + # STEP 3: Split dataset + print("\n" + "=" * 60) + print("STEP 3: Splitting Dataset") + print("=" * 60) + + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + print(f"Train samples: {len(train_dataset)}") + print(f"Val samples: {len(val_dataset)}") + print(f"Test samples: {len(test_dataset)}") + + # Create dataloaders + train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False) + + # STEP 4: Initialize MultimodalAdaCare model + print("\n" + "=" * 60) + print("STEP 4: Initializing MultimodalAdaCare Model") + print("=" * 60) + + model = MultimodalAdaCare( + dataset=sample_dataset, + embedding_dim=128, + hidden_dim=128, + kernel_size=2, + kernel_num=64, + r_v=4, + r_c=4, + activation="sigmoid", + rnn_type="gru", + dropout=0.3, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"Model initialized with {num_params:,} parameters") + + # Print feature classification + print(f"\nSequential features (AdaCare processing): {model.sequential_features}") + print(f"Non-sequential features (direct embedding): {model.non_sequential_features}") + + # Calculate expected embedding dimensions + seq_dim = len(model.sequential_features) * model.hidden_dim + non_seq_dim = len(model.non_sequential_features) * model.embedding_dim + total_dim = seq_dim + non_seq_dim + print(f"\nPatient representation dimension:") + print(f" Sequential contribution: {seq_dim}") + print(f" Non-sequential contribution: {non_seq_dim}") + print(f" Total: {total_dim}") + + # STEP 5: Train the model + print("\n" + "=" * 60) + print("STEP 5: Training Model") + print("=" * 60) + + trainer = Trainer( + model=model, + device="cuda:0", # Change to "cpu" if no GPU available + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=10, + monitor="roc_auc", + optimizer_params={"lr": 1e-3}, + ) + + # STEP 6: Evaluate on test set + print("\n" + "=" * 60) + print("STEP 6: Evaluating on Test Set") + print("=" * 60) + + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # STEP 7: Demonstrate model predictions and feature importance + print("\n" + "=" * 60) + print("STEP 7: Sample Predictions and Feature Importance") + print("=" * 60) + + import torch + + sample_batch = next(iter(test_loader)) + with torch.no_grad(): + output = model(**sample_batch) + + print(f"\nBatch size: {output['y_prob'].shape[0]}") + print(f"First 10 predicted probabilities:") + for i, (prob, true_label) in enumerate( + zip(output['y_prob'][:10], output['y_true'][:10]) + ): + print(f" Sample {i+1}: prob={prob.item():.4f}, true={int(true_label.item())}") + + # Display feature importance information + print(f"\nFeature Importance outputs:") + print(f" Number of sequential features with importance: {len(output['feature_importance'])}") + print(f" Number of sequential features with conv importance: {len(output['conv_feature_importance'])}") + + if len(output['feature_importance']) > 0: + for i, feat_key in enumerate(model.sequential_features): + feat_imp = output['feature_importance'][i] + conv_imp = output['conv_feature_importance'][i] + print(f"\n Feature '{feat_key}':") + print(f" Input feature importance shape: {feat_imp.shape}") + print(f" Conv feature importance shape: {conv_imp.shape}") + + # Summary + print("\n" + "=" * 60) + print("SUMMARY: MultimodalAdaCare Training Complete") + print("=" * 60) + print(f"Model: MultimodalAdaCare") + print(f"Dataset: MIMIC-IV") + print(f"Task: In-Hospital Mortality Prediction") + print(f"Sequential features: {len(model.sequential_features)}") + print(f"Non-sequential features: {len(model.non_sequential_features)}") + print(f"Best validation ROC-AUC: {max(results.get('roc_auc', 0), 0):.4f}") + print("\nAdaCare provides interpretability through:") + print(" - Input feature importance (original features)") + print(" - Convolutional feature importance (scale-adaptive features)") + print("=" * 60) + diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 778a6f362..1614b6f7b 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -15,6 +15,7 @@ import multiprocessing.queues import shutil +from filelock import FileLock import litdata from litdata.streaming.item_loader import ParquetLoader from litdata.processing.data_processor import in_notebook @@ -27,7 +28,11 @@ import requests from tqdm import tqdm import dask.dataframe as dd -from dask.distributed import Client as DaskClient, LocalCluster as DaskCluster, progress as dask_progress +from dask.distributed import ( + Client as DaskClient, + LocalCluster as DaskCluster, + progress as dask_progress, +) import narwhals as nw import itertools import numpy as np @@ -46,6 +51,7 @@ # Remove LitData version check to avoid unnecessary warnings os.environ["LITDATA_DISABLE_VERSION_CHECK"] = "1" + def is_url(path: str) -> bool: """URL detection.""" result = urlparse(path) @@ -124,6 +130,7 @@ def _litdata_merge(cache_dir: Path) -> None: cache_dir (Path): The cache directory containing LitData binary writer files. """ from litdata.streaming.writer import _INDEX_FILENAME + files = os.listdir(cache_dir) # Return if the index already exists @@ -134,13 +141,19 @@ def _litdata_merge(cache_dir: Path) -> None: # Return if there are no index files to merge if len(index_files) == 0: - raise ValueError("There are zero samples in the dataset, please check the task and processors.") + raise ValueError( + "There are zero samples in the dataset, please check the task and processors." + ) - BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB").merge(num_workers=len(index_files)) + BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB").merge( + num_workers=len(index_files) + ) class _ProgressContext: - def __init__(self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs): + def __init__( + self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs + ): """ :param queue: An existing queue (e.g., from multiprocessing). If provided, this class acts as a passthrough. @@ -167,8 +180,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): if self.progress: self.progress.close() + _task_transform_progress: multiprocessing.queues.Queue | None = None + def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: """ Initializer for worker processes to set up a global queue. @@ -179,7 +194,10 @@ def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: global _task_transform_progress _task_transform_progress = queue -def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None: + +def _task_transform_fn( + args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path], +) -> None: """ Worker function to apply task transformation on a chunk of patients. @@ -191,14 +209,16 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P global_event_df (pl.LazyFrame): The global event dataframe. output_dir (Path): The output directory to save results. """ - BATCH_SIZE = 128 # Use a batch size 128 can reduce runtime by 30%. + BATCH_SIZE = 128 # Use a batch size 128 can reduce runtime by 30%. worker_id, task, patient_ids, global_event_df, output_dir = args total_patients = len(list(patient_ids)) - logger.info(f"Worker {worker_id} started processing {total_patients} patients. (Polars threads: {pl.thread_pool_size()})") + logger.info( + f"Worker {worker_id} started processing {total_patients} patients. (Polars threads: {pl.thread_pool_size()})" + ) with ( set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)), - _ProgressContext(_task_transform_progress, total=total_patients) as progress + _ProgressContext(_task_transform_progress, total=total_patients) as progress, ): writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") @@ -208,8 +228,8 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P complete = 0 patients = ( global_event_df.filter(pl.col("patient_id").is_in(batch)) - .collect(engine="streaming") - .partition_by("patient_id", as_dict=True) + .collect(engine="streaming") + .partition_by("patient_id", as_dict=True) ) for patient_id, patient_df in patients.items(): patient_id = patient_id[0] # Extract string from single-element list @@ -223,8 +243,10 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P logger.info(f"Worker {worker_id} finished processing patients.") + _proc_transform_progress: multiprocessing.queues.Queue | None = None + def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None: """ Initializer for worker processes to set up a global queue. @@ -235,6 +257,7 @@ def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None: global _proc_transform_progress _proc_transform_progress = queue + def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: """ Worker function to apply processors on a chunk of samples. @@ -250,11 +273,13 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: BATCH_SIZE = 128 worker_id, task_df, start_idx, end_idx, output_dir = args total_samples = end_idx - start_idx - logger.info(f"Worker {worker_id} started processing {total_samples} samples. ({start_idx} to {end_idx})") + logger.info( + f"Worker {worker_id} started processing {total_samples} samples. ({start_idx} to {end_idx})" + ) with ( set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)), - _ProgressContext(_proc_transform_progress, total=total_samples) as progress + _ProgressContext(_proc_transform_progress, total=total_samples) as progress, ): writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") @@ -400,9 +425,7 @@ def clean_tmpdir(self) -> None: if tmp_dir.exists(): shutil.rmtree(tmp_dir) - def _scan_csv_tsv_gz( - self, source_path: str - ) -> dd.DataFrame: + def _scan_csv_tsv_gz(self, source_path: str) -> dd.DataFrame: """Scans a CSV/TSV file (possibly gzipped) and returns a Dask DataFrame. If the cached Parquet file does not exist, it converts the source CSV/TSV file @@ -501,7 +524,9 @@ def _event_transform(self, output_dir: Path) -> None: handle.result() # type: ignore except Exception as e: if output_dir.exists(): - logger.error(f"Error during caching, removing incomplete file {output_dir}") + logger.error( + f"Error during caching, removing incomplete file {output_dir}" + ) shutil.rmtree(output_dir) raise e finally: @@ -515,14 +540,16 @@ def global_event_df(self) -> pl.LazyFrame: Returns: Path: The path to the cached event dataframe. """ - self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore + self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore if self._global_event_df is None: ret_path = self.cache_dir / "global_event_df.parquet" cache_valid = ret_path.is_dir() and any(ret_path.glob("*.parquet")) if not cache_valid: if ret_path.exists(): - logger.warning(f"Incomplete parquet cache at {ret_path} (directory exists but contains no parquet files). Removing and rebuilding.") + logger.warning( + f"Incomplete parquet cache at {ret_path} (directory exists but contains no parquet files). Removing and rebuilding." + ) shutil.rmtree(ret_path) logger.info(f"No cached event dataframe found. Creating: {ret_path}") self._event_transform(ret_path) @@ -722,10 +749,14 @@ def default_task(self) -> Optional[BaseTask]: """ return None - def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> None: + def _task_transform( + self, task: BaseTask, output_dir: Path, num_workers: int + ) -> None: self._main_guard(self._task_transform.__name__) - logger.info(f"Applying task transformations on data with {num_workers} workers...") + logger.info( + f"Applying task transformations on data with {num_workers} workers..." + ) global_event_df = task.pre_filter(self.global_event_df) patient_ids = ( global_event_df.select("patient_id") @@ -737,34 +768,52 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> ) if in_notebook(): - logger.info("Detected Jupyter notebook environment, setting num_workers to 1") + logger.info( + "Detected Jupyter notebook environment, setting num_workers to 1" + ) num_workers = 1 - num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers + num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers # This ensures worker's polars threads are limited to avoid oversubscription, # which can lead to additional 75% speedup when num_workers is large. threads_per_worker = max(1, (os.cpu_count() or 1) // num_workers) try: - with set_env(POLARS_MAX_THREADS=str(threads_per_worker), DATA_OPTIMIZER_NUM_WORKERS=str(num_workers)): + with set_env( + POLARS_MAX_THREADS=str(threads_per_worker), + DATA_OPTIMIZER_NUM_WORKERS=str(num_workers), + ): if num_workers == 1: logger.info("Single worker mode, processing sequentially") - _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) + _task_transform_fn( + (0, task, patient_ids, global_event_df, output_dir) + ) _litdata_merge(output_dir) return # spwan is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary ctx = multiprocessing.get_context("spawn") queue = ctx.Queue() - args_list = [( - worker_id, - task, - pids, - global_event_df, - output_dir, - ) for worker_id, pids in enumerate(itertools.batched(patient_ids, len(patient_ids) // num_workers + 1))] - with ctx.Pool(processes=num_workers, initializer=_task_transform_init, initargs=(queue,)) as pool: - result = pool.map_async(_task_transform_fn, args_list) # type: ignore + args_list = [ + ( + worker_id, + task, + pids, + global_event_df, + output_dir, + ) + for worker_id, pids in enumerate( + itertools.batched( + patient_ids, len(patient_ids) // num_workers + 1 + ) + ) + ] + with ctx.Pool( + processes=num_workers, + initializer=_task_transform_init, + initargs=(queue,), + ) as pool: + result = pool.map_async(_task_transform_fn, args_list) # type: ignore with tqdm(total=len(patient_ids)) as progress: while not result.ready(): try: @@ -775,26 +824,32 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> # remaining items while not queue.empty(): progress.update(queue.get()) - result.get() # ensure exceptions are raised + result.get() # ensure exceptions are raised _litdata_merge(output_dir) logger.info(f"Task transformation completed and saved to {output_dir}") except Exception as e: - logger.error(f"Error during task transformation, cleaning up output directory: {output_dir}") + logger.error( + f"Error during task transformation, cleaning up output directory: {output_dir}" + ) shutil.rmtree(output_dir) raise e - def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> None: + def _proc_transform( + self, task_df: Path, output_dir: Path, num_workers: int + ) -> None: self._main_guard(self._proc_transform.__name__) logger.info(f"Applying processors on data with {num_workers} workers...") num_samples = len(litdata.StreamingDataset(str(task_df))) if in_notebook(): - logger.info("Detected Jupyter notebook environment, setting num_workers to 1") + logger.info( + "Detected Jupyter notebook environment, setting num_workers to 1" + ) num_workers = 1 - num_workers = min(num_workers, num_samples) # Avoid spawning empty workers + num_workers = min(num_workers, num_samples) # Avoid spawning empty workers try: with set_env(DATA_OPTIMIZER_NUM_WORKERS=str(num_workers)): if num_workers == 1: @@ -805,16 +860,25 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> ctx = multiprocessing.get_context("spawn") queue = ctx.Queue() - linspace = more_itertools.sliding_window(np.linspace(0, num_samples, num_workers + 1, dtype=int), 2) - args_list = [( - worker_id, - task_df, - start, - end, - output_dir, - ) for worker_id, (start, end) in enumerate(linspace)] - with ctx.Pool(processes=num_workers, initializer=_proc_transform_init, initargs=(queue,)) as pool: - result = pool.map_async(_proc_transform_fn, args_list) # type: ignore + linspace = more_itertools.sliding_window( + np.linspace(0, num_samples, num_workers + 1, dtype=int), 2 + ) + args_list = [ + ( + worker_id, + task_df, + start, + end, + output_dir, + ) + for worker_id, (start, end) in enumerate(linspace) + ] + with ctx.Pool( + processes=num_workers, + initializer=_proc_transform_init, + initargs=(queue,), + ) as pool: + result = pool.map_async(_proc_transform_fn, args_list) # type: ignore with tqdm(total=num_samples) as progress: while not result.ready(): try: @@ -825,10 +889,12 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> # remaining items while not queue.empty(): progress.update(queue.get()) - result.get() # ensure exceptions are raised + result.get() # ensure exceptions are raised _litdata_merge(output_dir) - logger.info(f"Processor transformation completed and saved to {output_dir}") + logger.info( + f"Processor transformation completed and saved to {output_dir}" + ) except Exception as e: logger.error(f"Error during processor transformation.") shutil.rmtree(output_dir) @@ -888,10 +954,14 @@ def set_task( "output_schema": task.output_schema, }, sort_keys=True, - default=str + default=str, ) - cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + cache_dir = ( + self.cache_dir + / "tasks" + / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + ) cache_dir.mkdir(parents=True, exist_ok=True) proc_params = json.dumps( @@ -914,54 +984,93 @@ def set_task( ), }, sort_keys=True, - default=str + default=str, ) task_df_path = Path(cache_dir) / "task_df.ld" - samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld" + samples_path = ( + Path(cache_dir) + / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld" + ) logger.info(f"Task cache paths: task_df={task_df_path}, samples={samples_path}") task_df_path.mkdir(parents=True, exist_ok=True) samples_path.mkdir(parents=True, exist_ok=True) - if not (samples_path / "index.json").exists(): - # Check if index.json exists to verify cache integrity, this - # is the standard file for litdata.StreamingDataset - if not (task_df_path / "index.json").exists(): - self._task_transform( - task, - task_df_path, - num_workers, - ) - else: - logger.info(f"Found cached task dataframe at {task_df_path}, skipping task transformation.") + def _is_valid_litdata_cache(path: Path) -> bool: + """Return True only if index.json exists and at least one .bin chunk file + is present alongside it. Checking for .bin files is O(1) and avoids + reading/parsing index.json which can be large.""" + return (path / "index.json").exists() and any(path.glob("*.bin")) - # Build processors and fit on the dataset - logger.info(f"Fitting processors on the dataset...") - dataset = litdata.StreamingDataset( - str(task_df_path), - transform=lambda x: pickle.loads(x["sample"]), - ) - builder = SampleBuilder( - input_schema=task.input_schema, # type: ignore - output_schema=task.output_schema, # type: ignore - input_processors=input_processors, - output_processors=output_processors, - ) - builder.fit(dataset) - builder.save(str(samples_path / "schema.pkl")) - - # Apply processors and save final samples to cache_dir - logger.info(f"Processing samples and saving to {samples_path}...") - self._proc_transform( - task_df_path, - samples_path, - num_workers, + def _invalidate_cache(path: Path, label: str) -> None: + """Remove a corrupt or empty litdata cache directory and recreate it.""" + logger.warning( + f"Corrupt or empty {label} cache at {path}; removing and rebuilding." ) - logger.info(f"Cached processed samples to {samples_path}") + shutil.rmtree(path, ignore_errors=True) + path.mkdir(parents=True, exist_ok=True) + + # Fast path: cache already valid, no lock needed (reads are always safe). + # Slow path: acquire a per-cache-dir file lock so that concurrent processes + # (e.g. parallel hparam jobs) don't race to build the same litdata cache. + # The double-checked pattern inside the lock means the winner builds it + # once; all others wait, re-check, and skip. + if not _is_valid_litdata_cache(samples_path): + lock_path = Path(cache_dir) / "build.lock" + with FileLock(str(lock_path), timeout=7200): + # Re-check inside the lock — another process may have built it + # while we were waiting. + if _is_valid_litdata_cache(samples_path): + logger.info( + f"Found cached processed samples at {samples_path} (built by another process)." + ) + else: + if (samples_path / "index.json").exists(): + _invalidate_cache(samples_path, "samples") + + # Check if index.json exists and is non-empty to verify cache integrity + if not _is_valid_litdata_cache(task_df_path): + if (task_df_path / "index.json").exists(): + _invalidate_cache(task_df_path, "task_df") + self._task_transform( + task, + task_df_path, + num_workers, + ) + else: + logger.info( + f"Found cached task dataframe at {task_df_path}, skipping task transformation." + ) + + # Build processors and fit on the dataset + logger.info(f"Fitting processors on the dataset...") + dataset = litdata.StreamingDataset( + str(task_df_path), + transform=lambda x: pickle.loads(x["sample"]), + ) + builder = SampleBuilder( + input_schema=task.input_schema, # type: ignore + output_schema=task.output_schema, # type: ignore + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(dataset) + builder.save(str(samples_path / "schema.pkl")) + + # Apply processors and save final samples to cache_dir + logger.info(f"Processing samples and saving to {samples_path}...") + self._proc_transform( + task_df_path, + samples_path, + num_workers, + ) + logger.info(f"Cached processed samples to {samples_path}") else: - logger.info(f"Found cached processed samples at {samples_path}, skipping processing.") + logger.info( + f"Found cached processed samples at {samples_path}, skipping processing." + ) return SampleDataset( path=str(samples_path), diff --git a/pyhealth/datasets/configs/mimic4_ehr.yaml b/pyhealth/datasets/configs/mimic4_ehr.yaml index 3dfb0e5e5..84c570bb9 100644 --- a/pyhealth/datasets/configs/mimic4_ehr.yaml +++ b/pyhealth/datasets/configs/mimic4_ehr.yaml @@ -113,6 +113,7 @@ tables: patient_id: "subject_id" timestamp: "chartdate" attributes: + - "hadm_id" - "hcpcs_cd" - "seq_num" - "short_description" diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index a13b18a51..147fd8f4b 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,4 +1,4 @@ -from .adacare import AdaCare, AdaCareLayer +from .adacare import AdaCare, AdaCareLayer, MultimodalAdaCare from .agent import Agent, AgentLayer from .base_model import BaseModel from .biot import BIOT @@ -18,7 +18,7 @@ from .micron import MICRON, MICRONLayer from .mlp import MLP from .molerec import MoleRec, MoleRecLayer -from .retain import RETAIN, RETAINLayer +from .retain import MultimodalRETAIN, RETAIN, RETAINLayer from .rnn import MultimodalRNN, RNN, RNNLayer from .safedrug import SafeDrug, SafeDrugLayer from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer diff --git a/pyhealth/models/adacare.py b/pyhealth/models/adacare.py index 5e4ee806b..ab8602b77 100644 --- a/pyhealth/models/adacare.py +++ b/pyhealth/models/adacare.py @@ -5,10 +5,14 @@ from pyhealth.datasets import SampleDataset from pyhealth.processors import ( - SequenceProcessor, - NestedSequenceProcessor, - NestedFloatsProcessor, DeepNestedFloatsProcessor, + DeepNestedSequenceProcessor, + MultiHotProcessor, + NestedFloatsProcessor, + NestedSequenceProcessor, + SequenceProcessor, + TensorProcessor, + TimeseriesProcessor, ) from .base_model import BaseModel from .embedding import EmbeddingModel @@ -362,10 +366,7 @@ def __init__( hidden_dim: int = 128, **kwargs, ): - super().__init__( - dataset=dataset, - **kwargs, - ) + super().__init__(dataset=dataset) self.embedding_dim = embedding_dim self.hidden_dim = hidden_dim @@ -374,7 +375,7 @@ def __init__( raise ValueError("input_dim is automatically determined") assert len(self.label_keys) == 1, "Only one label key is supported" - + # Use EmbeddingModel for unified embedding handling self.embedding_model = EmbeddingModel(dataset, embedding_dim) # AdaCare layers for each feature @@ -390,11 +391,12 @@ def __init__( NestedSequenceProcessor, NestedFloatsProcessor, DeepNestedFloatsProcessor, + TimeseriesProcessor, ), ): raise ValueError( """AdaCare only supports SequenceProcessor, NestedSequenceProcessor, - NestedFloatsProcessor, DeepNestedFloatsProcessor.""" + NestedFloatsProcessor, DeepNestedFloatsProcessor, TimeseriesProcessor.""" ) self.adacare[feature_key] = AdaCareLayer( @@ -429,19 +431,22 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: embedded, masks = self.embedding_model(kwargs, output_mask=True) feature_importance = [] conv_feature_importance = [] - + for _, feature_key in enumerate(self.feature_keys): embeds = embedded[feature_key] mask = masks[feature_key] processor = self.dataset.input_processors[feature_key] - + if embeds.dim() == 3: - if isinstance(processor, NestedFloatsProcessor): + if isinstance(processor, (NestedFloatsProcessor, TimeseriesProcessor)): + # Both produce [batch, seq_len, num_features] masks — reduce to [batch, seq_len] mask = torch.any(mask, dim=2) elif isinstance(processor, SequenceProcessor): - pass + pass # mask already [batch, seq_len] else: - raise ValueError(f"Expected NestedFloatsProcessor or SequenceProcessor for 3D input, got {type(processor)}") + raise ValueError( + f"Expected NestedFloatsProcessor, TimeseriesProcessor, or SequenceProcessor for 3D input, got {type(processor)}" + ) elif embeds.dim() == 4: if isinstance(processor, NestedSequenceProcessor): embeds = torch.sum(embeds, dim=2) @@ -450,10 +455,14 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: embeds = torch.sum(embeds, dim=2) mask = torch.any(mask, dim=(2, 3)) else: - raise ValueError(f"Expected NestedSequenceProcessor or DeepNestedFloatsProcessor for 4D input, got {type(processor)}") + raise ValueError( + f"Expected NestedSequenceProcessor or DeepNestedFloatsProcessor for 4D input, got {type(processor)}" + ) else: - raise NotImplementedError(f"Unsupported input dimension {feature_key}: {embeds.dim()} for AdaCare") - + raise NotImplementedError( + f"Unsupported input dimension {feature_key}: {embeds.dim()} for AdaCare" + ) + embeds, _, inputatt, convatt = self.adacare[feature_key](embeds, mask) feature_importance.append(inputatt) conv_feature_importance.append(convatt) @@ -476,4 +485,259 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: } if kwargs.get("embed", False): results["embed"] = patient_emb - return results \ No newline at end of file + return results + + +class MultimodalAdaCare(BaseModel): + """Multimodal AdaCare model for mixed sequential and non-sequential features. + + This model extends AdaCare to support mixed input modalities: + - Sequential features (sequences, timeseries) go through AdaCareLayer + - Non-sequential features (multi-hot, tensor) bypass AdaCareLayer, use embeddings + + The model automatically classifies input features based on their processor types: + - Sequential processors (apply AdaCareLayer): SequenceProcessor, + NestedSequenceProcessor, DeepNestedSequenceProcessor, NestedFloatsProcessor, + DeepNestedFloatsProcessor, TimeseriesProcessor + - Non-sequential processors (embeddings only): MultiHotProcessor, TensorProcessor + + For sequential features, the model: + 1. Embeds the input using EmbeddingModel + 2. Applies AdaCareLayer with scale-adaptive feature extraction and recalibration + 3. Extracts the patient representation + + For non-sequential features, the model: + 1. Embeds the input using EmbeddingModel + 2. Applies mean pooling if needed to reduce to 2D + 3. Uses the embedding directly + + All feature representations are concatenated and passed through a final + fully connected layer for predictions. Feature importance outputs from + AdaCareLayer are preserved for sequential features. + + Args: + dataset (SampleDataset): the dataset to train the model. It is used to query + certain information such as the set of all tokens and processor types. + embedding_dim (int): the embedding dimension. Default is 128. + hidden_dim (int): the hidden dimension for AdaCare layers. Default is 128. + **kwargs: other parameters for the AdaCareLayer (e.g., kernel_size, kernel_num, + r_v, r_c, activation, rnn_type, dropout). + + Examples: + >>> from pyhealth.datasets import create_sample_dataset + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "conditions": ["cond-33", "cond-86"], # sequential + ... "demographics": ["asian", "male"], # multi-hot + ... "vitals": [120.0, 80.0, 98.6], # tensor + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-1", + ... "conditions": ["cond-12", "cond-52"], # sequential + ... "demographics": ["white", "female"], # multi-hot + ... "vitals": [110.0, 75.0, 98.2], # tensor + ... "label": 0, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={ + ... "conditions": "sequence", + ... "demographics": "multi_hot", + ... "vitals": "tensor", + ... }, + ... output_schema={"label": "binary"}, + ... dataset_name="test" + ... ) + >>> + >>> from pyhealth.datasets import get_dataloader + >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + >>> + >>> model = MultimodalAdaCare(dataset=dataset, hidden_dim=64) + >>> + >>> data_batch = next(iter(train_loader)) + >>> + >>> ret = model(**data_batch) + >>> print(ret) + { + 'loss': tensor(...), + 'y_prob': tensor(...), + 'y_true': tensor(...), + 'logit': tensor(...), + 'feature_importance': [...], + 'conv_feature_importance': [...] + } + """ + + def __init__( + self, + dataset: SampleDataset, + embedding_dim: int = 128, + hidden_dim: int = 128, + **kwargs, + ): + super(MultimodalAdaCare, self).__init__(dataset=dataset) + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + + # validate kwargs + if "input_dim" in kwargs: + raise ValueError("input_dim is determined by embedding_dim") + + assert len(self.label_keys) == 1, "Only one label key is supported" + + self.embedding_model = EmbeddingModel(dataset, embedding_dim) + + # Classify features as sequential or non-sequential + self.sequential_features = [] + self.non_sequential_features = [] + + self.adacare = nn.ModuleDict() + for feature_key in self.feature_keys: + processor = dataset.input_processors[feature_key] + if self._is_sequential_processor(processor): + self.sequential_features.append(feature_key) + self.adacare[feature_key] = AdaCareLayer( + input_dim=embedding_dim, hidden_dim=hidden_dim, **kwargs + ) + else: + self.non_sequential_features.append(feature_key) + + # Calculate final concatenated dimension + final_dim = ( + len(self.sequential_features) * hidden_dim + + len(self.non_sequential_features) * embedding_dim + ) + output_size = self.get_output_size() + self.fc = nn.Linear(final_dim, output_size) + + def _is_sequential_processor(self, processor) -> bool: + """Check if processor represents sequential data. + + Sequential processors are those that benefit from AdaCare processing, + including sequences of codes and timeseries data. + + Note: + StageNetProcessor and StageNetTensorProcessor are excluded as they + are specialized for the StageNet model architecture and should be + treated as non-sequential for standard AdaCare processing. + + Args: + processor: The processor instance to check. + + Returns: + bool: True if processor is sequential, False otherwise. + """ + return isinstance( + processor, + ( + SequenceProcessor, + NestedSequenceProcessor, + DeepNestedSequenceProcessor, + NestedFloatsProcessor, + DeepNestedFloatsProcessor, + TimeseriesProcessor, + ), + ) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation handling mixed modalities. + + Args: + **kwargs: keyword arguments for the model. The keys must contain + all the feature keys and the label key. + + Returns: + Dict[str, torch.Tensor]: A dictionary with the following keys: + - loss: a scalar tensor representing the loss. + - y_prob: a tensor representing the predicted probabilities. + - y_true: a tensor representing the true labels. + - logit: a tensor representing the logits. + - feature_importance: list of tensors representing input + feature importance for sequential features. + - conv_feature_importance: list of tensors representing + convolutional feature importance for sequential features. + - embed (optional): a tensor representing the patient + embeddings if requested. + """ + patient_emb = [] + embedded, masks = self.embedding_model(kwargs, output_mask=True) + feature_importance = [] + conv_feature_importance = [] + + # Process sequential features through AdaCare + for feature_key in self.sequential_features: + embeds = embedded[feature_key] + mask = masks[feature_key] + processor = self.dataset.input_processors[feature_key] + + # Handle different dimensions + if embeds.dim() == 3: + if isinstance(processor, (NestedFloatsProcessor, TimeseriesProcessor)): + # Both produce [batch, seq_len, num_features] masks — reduce to [batch, seq_len] + mask = torch.any(mask, dim=2) + elif isinstance(processor, SequenceProcessor): + pass # mask already [batch, seq_len] + else: + raise ValueError( + f"Expected NestedFloatsProcessor, TimeseriesProcessor, or " + f"SequenceProcessor for 3D input, got {type(processor)}" + ) + elif embeds.dim() == 4: + if isinstance(processor, NestedSequenceProcessor): + embeds = torch.sum(embeds, dim=2) + mask = torch.any(mask, dim=2) + elif isinstance(processor, DeepNestedFloatsProcessor): + embeds = torch.sum(embeds, dim=2) + mask = torch.any(mask, dim=(2, 3)) + else: + raise ValueError( + f"Expected NestedSequenceProcessor or " + f"DeepNestedFloatsProcessor for 4D input, " + f"got {type(processor)}" + ) + else: + raise NotImplementedError( + f"Unsupported input dimension {feature_key}: " + f"{embeds.dim()} for AdaCare" + ) + + # Apply AdaCare layer + embeds, _, inputatt, convatt = self.adacare[feature_key](embeds, mask) + feature_importance.append(inputatt) + conv_feature_importance.append(convatt) + patient_emb.append(embeds) + + # Process non-sequential features (use embeddings directly) + for feature_key in self.non_sequential_features: + x = embedded[feature_key] + # If multi-dimensional, aggregate (mean pooling) + while x.dim() > 2: + x = x.mean(dim=1) + patient_emb.append(x) + + # Concatenate all representations + patient_emb = torch.cat(patient_emb, dim=1) + # (patient, label_size) + logits = self.fc(patient_emb) + + # Calculate loss and predictions + y_true = kwargs[self.label_keys[0]].to(self.device) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + + results = { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + "feature_importance": feature_importance, + "conv_feature_importance": conv_feature_importance, + } + if kwargs.get("embed", False): + results["embed"] = patient_emb + return results diff --git a/pyhealth/models/embedding.py b/pyhealth/models/embedding.py index 89b3190d1..83a3a78c0 100644 --- a/pyhealth/models/embedding.py +++ b/pyhealth/models/embedding.py @@ -109,6 +109,7 @@ def init_embedding_with_pretrained( return loaded + class EmbeddingModel(BaseModel): """ EmbeddingModel is responsible for creating embedding layers for different types of input data. @@ -159,14 +160,16 @@ def __init__( SequenceProcessor, StageNetProcessor, NestedSequenceProcessor, - DeepNestedSequenceProcessor + DeepNestedSequenceProcessor, ), ): vocab_size = len(processor.code_vocab) # For NestedSequenceProcessor and DeepNestedSequenceProcessor, don't use padding_idx # because empty visits/groups need non-zero embeddings. - if isinstance(processor, (NestedSequenceProcessor, DeepNestedSequenceProcessor)): + if isinstance( + processor, (NestedSequenceProcessor, DeepNestedSequenceProcessor) + ): self.embedding_layers[field_name] = nn.Embedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, @@ -231,22 +234,29 @@ def __init__( self.embedding_layers[field_name] = nn.Linear( in_features=num_categories, out_features=embedding_dim ) - + # Smart Processor (Token-based) -> Transformers elif hasattr(processor, "is_token") and processor.is_token(): try: from transformers import AutoModel except ImportError: - raise ImportError("Please install `transformers` to use token-based processors.") - + raise ImportError( + "Please install `transformers` to use token-based processors." + ) + # Load the model - self.embedding_layers[field_name] = AutoModel.from_pretrained(processor.tokenizer_model) - + self.embedding_layers[field_name] = AutoModel.from_pretrained( + processor.tokenizer_model + ) + # Check if we need projection - if self.embedding_layers[field_name].config.hidden_size != self.embedding_dim: + if ( + self.embedding_layers[field_name].config.hidden_size + != self.embedding_dim + ): self.embedding_layers[f"{field_name}_proj"] = nn.Linear( - self.embedding_layers[field_name].config.hidden_size, - self.embedding_dim + self.embedding_layers[field_name].config.hidden_size, + self.embedding_dim, ) else: @@ -255,91 +265,95 @@ def __init__( field_name, ) - def forward(self, - inputs: Dict[str, torch.Tensor], - masks: Dict[str, torch.Tensor] = None, - output_mask: bool = False - ) -> Dict[str, torch.Tensor] | tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: - + def forward( + self, + inputs: Dict[str, torch.Tensor], + masks: Dict[str, torch.Tensor] = None, + output_mask: bool = False, + ) -> ( + Dict[str, torch.Tensor] + | tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]] + ): + embedded: Dict[str, torch.Tensor] = {} out_masks: Dict[str, torch.Tensor] = {} if output_mask else None - + for field_name, tensor in inputs.items(): processor = self.dataset.input_processors.get(field_name, None) - + if field_name not in self.embedding_layers: # No embedding layer -> passthrough embedded[field_name] = tensor continue - + # Check if it's a transformer model layer = self.embedding_layers[field_name] - + # Check for transformers.PreTrainedModel (but without importing if possible, use class name check) # or check if it has 'config' attribute - if hasattr(layer, "config") and hasattr(layer, "forward"): + if hasattr(layer, "config") and hasattr(layer, "forward"): # It's likely a transformer - tensor = tensor.to(self.device).long() # Ensure LongTensor for IDs - + tensor = tensor.to(self.device).long() # Ensure LongTensor for IDs + mask = None if masks is not None and field_name in masks: mask = masks[field_name].to(self.device) - + # Handle 3D input (Batch, Num_Notes, Seq_Len) - is_3d = (inputs[field_name].dim() == 3) - + is_3d = inputs[field_name].dim() == 3 + if is_3d: - b, n, l = inputs[field_name].shape - tensor = tensor.view(b * n, l) - if mask is not None: - mask = mask.view(b * n, l) - + b, n, l = inputs[field_name].shape + tensor = tensor.view(b * n, l) + if mask is not None: + mask = mask.view(b * n, l) + # Forward pass through transformer output = layer(input_ids=tensor, attention_mask=mask) - x = output.last_hidden_state # (Batch, Seq, Hidden) - + x = output.last_hidden_state # (Batch, Seq, Hidden) + if is_3d: - # If we had 3D input, we MUST pool the sequence dim (L) to get one vector per note - # Resulting shape: (B, N, H) - - # Pool L dim -> (B*N, H) using CLS token (index 0) - x = x[:, 0, :] - - # Check projections - if f"{field_name}_proj" in self.embedding_layers: + # If we had 3D input, we MUST pool the sequence dim (L) to get one vector per note + # Resulting shape: (B, N, H) + + # Pool L dim -> (B*N, H) using CLS token (index 0) + x = x[:, 0, :] + + # Check projections + if f"{field_name}_proj" in self.embedding_layers: x = self.embedding_layers[f"{field_name}_proj"](x) - - x = x.view(b, n, -1) - - else: - # 2D input (Batch, Seq) -> (Batch, Seq, Hidden) - # No pooling, treating as sequence of tokens (word embeddings) - if f"{field_name}_proj" in self.embedding_layers: + + x = x.view(b, n, -1) + + else: + # 2D input (Batch, Seq) -> (Batch, Seq, Hidden) + # No pooling, treating as sequence of tokens (word embeddings) + if f"{field_name}_proj" in self.embedding_layers: x = self.embedding_layers[f"{field_name}_proj"](x) - + embedded[field_name] = x - - else: + + else: # Standard layers tensor = tensor.to(self.device) embedded[field_name] = layer(tensor) - + if output_mask: # Generate a mask for this field # For transformers, we might already have a mask, or use pad token if masks is not None and field_name in masks: - out_masks[field_name] = masks[field_name].to(self.device) + out_masks[field_name] = masks[field_name].to(self.device) elif hasattr(processor, "code_vocab"): pad_idx = processor.code_vocab.get("", 0) - out_masks[field_name] = (tensor != pad_idx) + out_masks[field_name] = tensor != pad_idx else: - # Default mask generation (e.g. for simple linear layers where 0 might be padding?) - # Be careful changing this behavior. + # Default mask generation (e.g. for simple linear layers where 0 might be padding?) + # Be careful changing this behavior. # Previous code: # masks[field_name] = (tensor != pad_idx) -> where pad_idx was 0 default pad_idx = 0 - out_masks[field_name] = (tensor != pad_idx) - + out_masks[field_name] = tensor != pad_idx + if output_mask: return embedded, out_masks else: diff --git a/pyhealth/models/retain.py b/pyhealth/models/retain.py index b272995ec..e871c88d0 100644 --- a/pyhealth/models/retain.py +++ b/pyhealth/models/retain.py @@ -6,6 +6,16 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel +from pyhealth.processors import ( + DeepNestedFloatsProcessor, + DeepNestedSequenceProcessor, + MultiHotProcessor, + NestedFloatsProcessor, + NestedSequenceProcessor, + SequenceProcessor, + TensorProcessor, + TimeseriesProcessor, +) from .embedding import EmbeddingModel @@ -56,23 +66,27 @@ def reverse_x(input, lengths): reversed_input[i, :length] = input[i, :length].flip(dims=[0]) return reversed_input - def compute_alpha(self, rx, lengths): + def compute_alpha(self, rx, lengths, total_length: int): """Computes alpha attention.""" rx = rnn_utils.pack_padded_sequence( rx, lengths, batch_first=True, enforce_sorted=False ) g, _ = self.alpha_gru(rx) - g, _ = rnn_utils.pad_packed_sequence(g, batch_first=True) + g, _ = rnn_utils.pad_packed_sequence( + g, batch_first=True, total_length=total_length + ) attn_alpha = torch.softmax(self.alpha_li(g), dim=1) return attn_alpha - def compute_beta(self, rx, lengths): + def compute_beta(self, rx, lengths, total_length: int): """Computes beta attention.""" rx = rnn_utils.pack_padded_sequence( rx, lengths, batch_first=True, enforce_sorted=False ) h, _ = self.beta_gru(rx) - h, _ = rnn_utils.pad_packed_sequence(h, batch_first=True) + h, _ = rnn_utils.pad_packed_sequence( + h, batch_first=True, total_length=total_length + ) attn_beta = torch.tanh(self.beta_li(h)) return attn_beta @@ -95,15 +109,17 @@ def forward( # rnn will only apply dropout between layers x = self.dropout_layer(x) batch_size = x.size(0) + total_length = x.size(1) # capture before packing so pad_packed restores it if mask is None: lengths = torch.full( - size=(batch_size,), fill_value=x.size(1), dtype=torch.int64 + size=(batch_size,), fill_value=total_length, dtype=torch.int64 ) else: lengths = torch.sum(mask.int(), dim=-1).cpu() + lengths = lengths.clamp(min=1) # prevent zero-length crash in GRU rx = self.reverse_x(x, lengths) - attn_alpha = self.compute_alpha(rx, lengths) - attn_beta = self.compute_beta(rx, lengths) + attn_alpha = self.compute_alpha(rx, lengths, total_length) + attn_beta = self.compute_beta(rx, lengths, total_length) c = attn_alpha * attn_beta * x # (patient, sequence len, feature_size) c = torch.sum(c, dim=1) # (patient, feature_size) return c @@ -201,15 +217,11 @@ def __init__( # Create RETAIN layers for each feature self.retain = nn.ModuleDict() for feature_key in self.feature_keys: - self.retain[feature_key] = RETAINLayer( - feature_size=embedding_dim, **kwargs - ) + self.retain[feature_key] = RETAINLayer(feature_size=embedding_dim, **kwargs) output_size = self.get_output_size() num_features = len(self.feature_keys) - self.fc = nn.Linear( - num_features * self.embedding_dim, output_size - ) + self.fc = nn.Linear(num_features * self.embedding_dim, output_size) def forward(self, **kwargs) -> Dict[str, torch.Tensor]: """Forward propagation. @@ -228,30 +240,29 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: """ patient_emb = [] embedded = self.embedding_model(kwargs) - + for feature_key in self.feature_keys: x = embedded[feature_key] - + # Handle different input dimensions # Case 1: 4D tensor from NestedSequenceProcessor # (batch, visits, events, embedding_dim) # Need to sum across events to get (batch, visits, embedding_dim) if len(x.shape) == 4: x = torch.sum(x, dim=2) # Sum across events within visit - + # Case 2: 3D tensor from SequenceProcessor or after summing # (batch, seq_len, embedding_dim) - already correct format elif len(x.shape) == 3: pass # Already correct format - + # Case 3: 2D tensor - shouldn't happen for RETAIN but handle it elif len(x.shape) == 2: x = x.unsqueeze(1) # Add seq dim: (batch, 1, embedding_dim) - + else: raise ValueError( - f"Unexpected tensor shape {x.shape} for feature " - f"{feature_key}" + f"Unexpected tensor shape {x.shape} for feature " f"{feature_key}" ) # Create mask: non-padding entries are valid @@ -331,3 +342,233 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # try loss backward ret["loss"].backward() + + +class MultimodalRETAIN(BaseModel): + """Multimodal RETAIN model for mixed sequential and non-sequential features. + + This model extends RETAIN to support mixed input modalities: + - Sequential features (sequences, timeseries) go through RETAINLayer + - Non-sequential features (multi-hot, tensor) bypass RETAIN, use embeddings directly + + The model automatically classifies input features based on their processor types: + - Sequential processors (apply RETAINLayer): SequenceProcessor, + NestedSequenceProcessor, DeepNestedSequenceProcessor, NestedFloatsProcessor, + DeepNestedFloatsProcessor, TimeseriesProcessor + - Non-sequential processors (embeddings only): MultiHotProcessor, TensorProcessor + + For sequential features, the model: + 1. Embeds the input using EmbeddingModel + 2. Applies RETAINLayer with reverse time attention mechanism + 3. Extracts the patient representation + + For non-sequential features, the model: + 1. Embeds the input using EmbeddingModel + 2. Applies mean pooling if needed to reduce to 2D + 3. Uses the embedding directly + + All feature representations are concatenated and passed through a final + fully connected layer for predictions. + + Args: + dataset (SampleDataset): the dataset to train the model. It is used to query + certain information such as the set of all tokens and processor types. + embedding_dim (int): the embedding dimension. Default is 128. + **kwargs: other parameters for the RETAIN layer (e.g., dropout). + + Examples: + >>> from pyhealth.datasets import create_sample_dataset + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "conditions": [["A", "B"], ["C"]], # nested sequence + ... "demographics": ["asian", "male"], # multi-hot + ... "vitals": [110.0, 75.0, 98.2], # tensor + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-1", + ... "conditions": [["D"], ["E", "F"]], # nested sequence + ... "demographics": ["white", "female"], # multi-hot + ... "vitals": [120.0, 80.0, 98.6], # tensor + ... "label": 0, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={ + ... "conditions": "nested_sequence", + ... "demographics": "multi_hot", + ... "vitals": "tensor", + ... }, + ... output_schema={"label": "binary"}, + ... dataset_name="test" + ... ) + >>> + >>> from pyhealth.datasets import get_dataloader + >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + >>> + >>> model = MultimodalRETAIN(dataset=dataset) + >>> + >>> data_batch = next(iter(train_loader)) + >>> + >>> ret = model(**data_batch) + >>> print(ret) + { + 'loss': tensor(...), + 'y_prob': tensor(...), + 'y_true': tensor(...), + 'logit': tensor(...) + } + """ + + def __init__(self, dataset: SampleDataset, embedding_dim: int = 128, **kwargs): + super(MultimodalRETAIN, self).__init__(dataset=dataset) + self.embedding_dim = embedding_dim + + # validate kwargs for RETAIN layer + if "feature_size" in kwargs: + raise ValueError("feature_size is determined by embedding_dim") + + assert len(self.label_keys) == 1, "Only one label key is supported" + self.label_key = self.label_keys[0] + self.mode = self.dataset.output_schema[self.label_key] + + self.embedding_model = EmbeddingModel(dataset, embedding_dim) + + # Classify features as sequential or non-sequential + self.sequential_features = [] + self.non_sequential_features = [] + + self.retain = nn.ModuleDict() + for feature_key in self.feature_keys: + processor = dataset.input_processors[feature_key] + if self._is_sequential_processor(processor): + self.sequential_features.append(feature_key) + # Create RETAIN layer for this feature + self.retain[feature_key] = RETAINLayer( + feature_size=embedding_dim, **kwargs + ) + else: + self.non_sequential_features.append(feature_key) + + # Calculate final concatenated dimension + final_dim = ( + len(self.sequential_features) * embedding_dim + + len(self.non_sequential_features) * embedding_dim + ) + output_size = self.get_output_size() + self.fc = nn.Linear(final_dim, output_size) + + def _is_sequential_processor(self, processor) -> bool: + """Check if processor represents sequential data. + + Sequential processors are those that benefit from RETAIN processing, + including sequences of codes and timeseries data. + + Note: + StageNetProcessor and StageNetTensorProcessor are excluded as they + are specialized for the StageNet model architecture and should be + treated as non-sequential for standard RETAIN processing. + + Args: + processor: The processor instance to check. + + Returns: + bool: True if processor is sequential, False otherwise. + """ + return isinstance( + processor, + ( + SequenceProcessor, + NestedSequenceProcessor, + DeepNestedSequenceProcessor, + NestedFloatsProcessor, + DeepNestedFloatsProcessor, + TimeseriesProcessor, + ), + ) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation handling mixed modalities. + + Args: + **kwargs: keyword arguments for the model. The keys must contain + all the feature keys and the label key. + + Returns: + Dict[str, torch.Tensor]: A dictionary with the following keys: + - loss: a scalar tensor representing the loss. + - y_prob: a tensor representing the predicted probabilities. + - y_true: a tensor representing the true labels. + - logit: a tensor representing the logits. + - embed (optional): a tensor representing the patient + embeddings if requested. + """ + patient_emb = [] + embedded, emb_masks = self.embedding_model(kwargs, output_mask=True) + + # Process sequential features through RETAIN + for feature_key in self.sequential_features: + x = embedded[feature_key] + + # Handle different input dimensions + # Case 1: 4D tensor from NestedSequenceProcessor + # (batch, visits, events, embedding_dim) + # Need to sum across events to get (batch, visits, embedding_dim) + if x.dim() == 4: + x = torch.sum(x, dim=2) # Sum across events within visit + + # Case 2: 3D tensor from SequenceProcessor or after summing + # (batch, seq_len, embedding_dim) - already correct format + elif x.dim() == 3: + pass # Already correct format + + # Case 3: 2D tensor - shouldn't happen for RETAIN but handle it + elif x.dim() == 2: + x = x.unsqueeze(1) # Add seq dim: (batch, 1, embedding_dim) + + else: + raise ValueError( + f"Unexpected tensor shape {x.shape} for feature {feature_key}" + ) + + # Use mask from EmbeddingModel (derived from original unembedded tensor) + mask = emb_masks.get(feature_key) + if mask is not None: + # Ensure 2D (batch, seq_len) — reduce any extra dims + while mask.dim() > 2: + mask = mask.any(dim=-1) + mask = mask.float() + x = self.retain[feature_key](x, mask) + patient_emb.append(x) + + # Process non-sequential features (use embeddings directly) + for feature_key in self.non_sequential_features: + x = embedded[feature_key] + # If multi-dimensional, aggregate (mean pooling) + while x.dim() > 2: + x = x.mean(dim=1) + patient_emb.append(x) + + # Concatenate all representations + patient_emb = torch.cat(patient_emb, dim=1) + # (patient, label_size) + logits = self.fc(patient_emb) + + # Calculate loss and predictions + y_true = kwargs[self.label_key].to(self.device) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + + results = { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + } + if kwargs.get("embed", False): + results["embed"] = patient_emb + return results diff --git a/pyhealth/models/rnn.py b/pyhealth/models/rnn.py index bc01d166c..b87d5ce09 100644 --- a/pyhealth/models/rnn.py +++ b/pyhealth/models/rnn.py @@ -109,11 +109,16 @@ def forward( ) else: lengths = torch.sum(mask.int(), dim=-1).cpu() + # Ensure tensor is contiguous for cuDNN compatibility + x = x.contiguous() x = rnn_utils.pack_padded_sequence( x, lengths, batch_first=True, enforce_sorted=False ) outputs, _ = self.rnn(x) outputs, _ = rnn_utils.pad_packed_sequence(outputs, batch_first=True) + # Ensure outputs are contiguous after unpacking + outputs = outputs.contiguous() + if not self.bidirectional: last_outputs = outputs[torch.arange(batch_size), (lengths - 1), :] return outputs, last_outputs @@ -122,7 +127,8 @@ def forward( f_last_outputs = outputs[torch.arange(batch_size), (lengths - 1), 0, :] b_last_outputs = outputs[:, 0, 1, :] last_outputs = torch.cat([f_last_outputs, b_last_outputs], dim=-1) - outputs = outputs.view(batch_size, outputs.shape[1], -1) + # Ensure view result is contiguous for cuDNN + outputs = outputs.view(batch_size, outputs.shape[1], -1).contiguous() last_outputs = self.down_projection(last_outputs) outputs = self.down_projection(outputs) return outputs, last_outputs @@ -270,22 +276,28 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: for feature_key in self.feature_keys: x = embedded[feature_key] - # Use abs() before sum to catch edge cases where embeddings sum to 0 - # @TODO bug with 0 embedding sum can still persist if the embedding is all 0s but the mask is not all 0s. - # despite being valid values (e.g., [1.0, -1.0]) - - # If we have an explicit mask, use it - if feature_key in masks: - mask = masks[feature_key].to(self.device).int() - # Token-level mask (B, N_notes, L): reduce to note-level (B, N_notes) - # by checking whether each note has at least one valid token. - # This is needed when TupleTimeTextProcessor returns 3D token masks that - # EmbeddingModel has already pooled down to (B, N_notes, H). - if mask.dim() == 3: - mask = (mask.sum(dim=-1) > 0).int() # (B, N_notes) + + x_dim_orig = x.dim() + if x_dim_orig == 4: + # nested_sequence: (B, num_visits, num_codes, D) + # @TODO: sum-pooling across codes is a simple baseline. May need to investigate better embeddings for nested codes. + x = x.sum(dim=2) # (B, num_visits, D) + if feature_key in masks: + mask = (masks[feature_key].to(self.device).sum(dim=-1) > 0).int() # (B, V) + else: + mask = (torch.abs(x).sum(dim=-1) != 0).int() + elif x_dim_orig == 2: + x = x.unsqueeze(1) + mask = None else: - mask = (torch.abs(x).sum(dim=-1) != 0).int() - + # 3D: already (B, T, D) + if feature_key in masks: + mask = masks[feature_key].to(self.device).int() + if mask.dim() == 3: + mask = (mask.sum(dim=-1) > 0).int() + else: + mask = (torch.abs(x).sum(dim=-1) != 0).int() + _, x = self.rnn[feature_key](x, mask) patient_emb.append(x) @@ -502,6 +514,22 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: for feature_key in self.sequential_features: x = embedded[feature_key] m = mask[feature_key] + + x_dim_orig = x.dim() + # Pool across events if needed + if x_dim_orig == 4: + b, v, e, d = x.shape + x = x.view(b, v * e, d) + elif x_dim_orig == 2: + x = x.unsqueeze(1) + + if x_dim_orig == 4 and m.dim() == 3: + m = m.view(b, v * e) + elif m.dim() == 3: + m = (m.sum(dim=-1) > 0).int() + elif m.dim() == 1: + m = m.unsqueeze(1) + _, last_hidden = self.rnn[feature_key](x, m) patient_emb.append(last_hidden) diff --git a/tests/core/test_adacare.py b/tests/core/test_adacare.py index daf60ecee..2e0ba7daf 100644 --- a/tests/core/test_adacare.py +++ b/tests/core/test_adacare.py @@ -2,7 +2,7 @@ import torch from pyhealth.datasets import create_sample_dataset, get_dataloader -from pyhealth.models import AdaCare +from pyhealth.models import AdaCare, MultimodalAdaCare class TestAdaCare(unittest.TestCase): @@ -159,5 +159,145 @@ def test_model_with_embedding(self): self.assertEqual(ret["embed"].shape[1], expected_embed_dim) +class TestMultimodalAdaCare(unittest.TestCase): + """Test cases for the MultimodalAdaCare model.""" + + def setUp(self): + """Set up test data and model with mixed feature types.""" + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86"], # sequential + "procedures": ["proc-12", "proc-45"], # sequential + "demographics": ["asian", "male"], # multi-hot + "vitals": [120.0, 80.0, 98.6], # tensor + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-12", "cond-52"], # sequential + "procedures": ["proc-23"], # sequential + "demographics": ["white", "female"], # multi-hot + "vitals": [110.0, 75.0, 98.2], # tensor + "label": 0, + }, + ] + + self.input_schema = { + "conditions": "sequence", # sequential + "procedures": "sequence", # sequential + "demographics": "multi_hot", # non-sequential + "vitals": "tensor", # non-sequential + } + self.output_schema = {"label": "binary"} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + + self.model = MultimodalAdaCare(dataset=self.dataset, hidden_dim=64) + + def test_model_initialization(self): + """Test that MultimodalAdaCare initializes correctly.""" + self.assertIsInstance(self.model, MultimodalAdaCare) + self.assertEqual(self.model.embedding_dim, 128) + self.assertEqual(self.model.hidden_dim, 64) + self.assertEqual(len(self.model.feature_keys), 4) + + self.assertIn("conditions", self.model.sequential_features) + self.assertIn("procedures", self.model.sequential_features) + self.assertIn("demographics", self.model.non_sequential_features) + self.assertIn("vitals", self.model.non_sequential_features) + + self.assertIn("conditions", self.model.adacare) + self.assertIn("procedures", self.model.adacare) + self.assertNotIn("demographics", self.model.adacare) + self.assertNotIn("vitals", self.model.adacare) + + def test_model_forward(self): + """Test that MultimodalAdaCare forward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + self.assertIn("feature_importance", ret) + self.assertIn("conv_feature_importance", ret) + + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_true"].shape[0], 2) + self.assertEqual(ret["loss"].dim(), 0) + + # 2 sequential features produce importance outputs + self.assertEqual(len(ret["feature_importance"]), 2) + self.assertEqual(len(ret["conv_feature_importance"]), 2) + + def test_model_backward(self): + """Test that MultimodalAdaCare backward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + ret = self.model(**data_batch) + ret["loss"].backward() + + has_gradient = any( + param.requires_grad and param.grad is not None + for param in self.model.parameters() + ) + self.assertTrue( + has_gradient, "No parameters have gradients after backward pass" + ) + + def test_output_shapes(self): + """Test that output shapes are correct for multimodal inputs.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertEqual(ret["y_prob"].shape, (2, 1)) + self.assertEqual(ret["y_true"].shape, (2, 1)) + self.assertEqual(ret["logit"].shape, (2, 1)) + + def test_model_with_embedding(self): + """Test that MultimodalAdaCare returns embeddings when requested.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + data_batch["embed"] = True + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape[0], 2) + + expected_embed_dim = ( + len(self.model.sequential_features) * self.model.hidden_dim + + len(self.model.non_sequential_features) * self.model.embedding_dim + ) + self.assertEqual(ret["embed"].shape[1], expected_embed_dim) + + def test_loss_is_finite(self): + """Test that the loss is finite.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertTrue(torch.isfinite(ret["loss"]).all()) + + if __name__ == "__main__": unittest.main() diff --git a/tests/core/test_multimodal_adacare.py b/tests/core/test_multimodal_adacare.py new file mode 100644 index 000000000..e1718556d --- /dev/null +++ b/tests/core/test_multimodal_adacare.py @@ -0,0 +1,289 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import MultimodalAdaCare + + +class TestMultimodalAdaCare(unittest.TestCase): + """Test cases for the MultimodalAdaCare model.""" + + def setUp(self): + """Set up test data and model with mixed feature types.""" + # Samples with mixed sequential and non-sequential features + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86", "cond-80"], # sequential + "procedures": ["proc-12", "proc-45"], # sequential + "demographics": ["asian", "male", "smoker"], # multi-hot + "vitals": [120.0, 80.0, 98.6, 16.0], # tensor + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-12", "cond-52"], # sequential + "procedures": ["proc-23"], # sequential + "demographics": ["white", "female"], # multi-hot + "vitals": [110.0, 75.0, 98.2, 18.0], # tensor + "label": 0, + }, + ] + + # Define input and output schemas with mixed types + self.input_schema = { + "conditions": "sequence", # sequential + "procedures": "sequence", # sequential + "demographics": "multi_hot", # non-sequential + "vitals": "tensor", # non-sequential + } + self.output_schema = {"label": "binary"} + + # Create dataset + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + + # Create model + self.model = MultimodalAdaCare(dataset=self.dataset) + + def test_model_initialization(self): + """Test that the MultimodalAdaCare model initializes correctly.""" + self.assertIsInstance(self.model, MultimodalAdaCare) + self.assertEqual(self.model.embedding_dim, 128) + self.assertEqual(self.model.hidden_dim, 128) + self.assertEqual(len(self.model.feature_keys), 4) + + # Check that features are correctly classified + self.assertIn("conditions", self.model.sequential_features) + self.assertIn("procedures", self.model.sequential_features) + self.assertIn("demographics", self.model.non_sequential_features) + self.assertIn("vitals", self.model.non_sequential_features) + + # Check that AdaCare layers are only created for sequential features + self.assertIn("conditions", self.model.adacare) + self.assertIn("procedures", self.model.adacare) + self.assertNotIn("demographics", self.model.adacare) + self.assertNotIn("vitals", self.model.adacare) + + def test_model_forward(self): + """Test that the MultimodalAdaCare model forward pass works correctly.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + # Forward pass + with torch.no_grad(): + ret = self.model(**data_batch) + + # Check output structure + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + self.assertIn("feature_importance", ret) + self.assertIn("conv_feature_importance", ret) + + # Check tensor shapes + self.assertEqual(ret["y_prob"].shape[0], 2) # batch size + self.assertEqual(ret["y_true"].shape[0], 2) # batch size + self.assertEqual(ret["logit"].shape[0], 2) # batch size + + # Check that loss is a scalar + self.assertEqual(ret["loss"].dim(), 0) + + # Check feature importance outputs + self.assertEqual(len(ret["feature_importance"]), 2) # 2 sequential features + self.assertEqual(len(ret["conv_feature_importance"]), 2) # 2 sequential + + def test_model_backward(self): + """Test that the MultimodalAdaCare model backward pass works correctly.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + # Forward pass + ret = self.model(**data_batch) + + # Backward pass + ret["loss"].backward() + + # Check that at least one parameter has gradients + has_gradient = False + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + has_gradient = True + break + self.assertTrue( + has_gradient, "No parameters have gradients after backward pass" + ) + + def test_model_with_embedding(self): + """Test that the MultimodalAdaCare model returns embeddings when requested.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + data_batch["embed"] = True + + # Forward pass + with torch.no_grad(): + ret = self.model(**data_batch) + + # Check that embeddings are returned + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape[0], 2) # batch size + + # Check embedding dimension + # 2 sequential features * hidden_dim + 2 non-sequential features * embedding_dim + expected_embed_dim = ( + len(self.model.sequential_features) * self.model.hidden_dim + + len(self.model.non_sequential_features) * self.model.embedding_dim + ) + self.assertEqual(ret["embed"].shape[1], expected_embed_dim) + + def test_custom_hyperparameters(self): + """Test MultimodalAdaCare model with custom hyperparameters.""" + model = MultimodalAdaCare( + dataset=self.dataset, + embedding_dim=64, + hidden_dim=32, + kernel_size=2, + kernel_num=32, + r_v=2, + r_c=2, + activation="sparsemax", + rnn_type="lstm", + dropout=0.3, + ) + + self.assertEqual(model.embedding_dim, 64) + self.assertEqual(model.hidden_dim, 32) + + # Test forward pass + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("feature_importance", ret) + + def test_only_sequential_features(self): + """Test MultimodalAdaCare with only sequential features.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86"], + "procedures": ["proc-12", "proc-45"], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-12"], + "procedures": ["proc-23"], + "label": 0, + }, + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={"conditions": "sequence", "procedures": "sequence"}, + output_schema={"label": "binary"}, + dataset_name="test_seq_only", + ) + + model = MultimodalAdaCare(dataset=dataset, hidden_dim=64) + + # Check that all features are sequential + self.assertEqual(len(model.sequential_features), 2) + self.assertEqual(len(model.non_sequential_features), 0) + + # Test forward pass + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("feature_importance", ret) + + def test_only_non_sequential_features(self): + """Test MultimodalAdaCare with only non-sequential features.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "demographics": ["asian", "male", "smoker"], + "vitals": [120.0, 80.0, 98.6, 16.0], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "demographics": ["white", "female"], + "vitals": [110.0, 75.0, 98.2, 18.0], + "label": 0, + }, + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={"demographics": "multi_hot", "vitals": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="test_non_seq_only", + ) + + model = MultimodalAdaCare(dataset=dataset, hidden_dim=64) + + # Check that all features are non-sequential + self.assertEqual(len(model.sequential_features), 0) + self.assertEqual(len(model.non_sequential_features), 2) + + # Test forward pass + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + # No feature importance when no sequential features + self.assertEqual(len(ret["feature_importance"]), 0) + self.assertEqual(len(ret["conv_feature_importance"]), 0) + + def test_sequential_processor_classification(self): + """Test that _is_sequential_processor correctly identifies processor types.""" + from pyhealth.processors import ( + MultiHotProcessor, + SequenceProcessor, + TensorProcessor, + ) + + # Test with actual processor instances + seq_proc = SequenceProcessor() + self.assertTrue(self.model._is_sequential_processor(seq_proc)) + + # Create simple multi-hot processor + multihot_proc = MultiHotProcessor() + self.assertFalse(self.model._is_sequential_processor(multihot_proc)) + + # Tensor processor + tensor_proc = TensorProcessor() + self.assertFalse(self.model._is_sequential_processor(tensor_proc)) + + +if __name__ == "__main__": + unittest.main() + diff --git a/tests/core/test_multimodal_retain.py b/tests/core/test_multimodal_retain.py new file mode 100644 index 000000000..2b7b1c99d --- /dev/null +++ b/tests/core/test_multimodal_retain.py @@ -0,0 +1,312 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import MultimodalRETAIN + + +class TestMultimodalRETAIN(unittest.TestCase): + """Test cases for the MultimodalRETAIN model.""" + + def setUp(self): + """Set up test data and model with mixed feature types.""" + # Samples with mixed sequential and non-sequential features + # RETAIN typically works with visit-level nested sequences + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": [["A", "B"], ["C"]], # nested sequence + "procedures": [["P1"], ["P2", "P3"]], # nested sequence + "demographics": ["asian", "male"], # multi-hot + "vitals": [120.0, 80.0, 98.6], # tensor + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": [["D"], ["E", "F"]], # nested sequence + "procedures": [["P4"]], # nested sequence + "demographics": ["white", "female"], # multi-hot + "vitals": [110.0, 75.0, 98.2], # tensor + "label": 0, + }, + ] + + # Define input and output schemas with mixed types + self.input_schema = { + "conditions": "nested_sequence", # sequential + "procedures": "nested_sequence", # sequential + "demographics": "multi_hot", # non-sequential + "vitals": "tensor", # non-sequential + } + self.output_schema = {"label": "binary"} + + # Create dataset + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + + # Create model + self.model = MultimodalRETAIN(dataset=self.dataset) + + def test_model_initialization(self): + """Test that the MultimodalRETAIN model initializes correctly.""" + self.assertIsInstance(self.model, MultimodalRETAIN) + self.assertEqual(self.model.embedding_dim, 128) + self.assertEqual(len(self.model.feature_keys), 4) + + # Check that features are correctly classified + self.assertIn("conditions", self.model.sequential_features) + self.assertIn("procedures", self.model.sequential_features) + self.assertIn("demographics", self.model.non_sequential_features) + self.assertIn("vitals", self.model.non_sequential_features) + + # Check that RETAIN layers are only created for sequential features + self.assertIn("conditions", self.model.retain) + self.assertIn("procedures", self.model.retain) + self.assertNotIn("demographics", self.model.retain) + self.assertNotIn("vitals", self.model.retain) + + def test_model_forward(self): + """Test that the MultimodalRETAIN model forward pass works correctly.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + # Forward pass + with torch.no_grad(): + ret = self.model(**data_batch) + + # Check output structure + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + + # Check tensor shapes + self.assertEqual(ret["y_prob"].shape[0], 2) # batch size + self.assertEqual(ret["y_true"].shape[0], 2) # batch size + self.assertEqual(ret["logit"].shape[0], 2) # batch size + + # Check that loss is a scalar + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + """Test that the MultimodalRETAIN model backward pass works correctly.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + # Forward pass + ret = self.model(**data_batch) + + # Backward pass + ret["loss"].backward() + + # Check that at least one parameter has gradients + has_gradient = False + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + has_gradient = True + break + self.assertTrue( + has_gradient, "No parameters have gradients after backward pass" + ) + + def test_model_with_embedding(self): + """Test that the MultimodalRETAIN model returns embeddings when requested.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + data_batch["embed"] = True + + # Forward pass + with torch.no_grad(): + ret = self.model(**data_batch) + + # Check that embeddings are returned + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape[0], 2) # batch size + + # Check embedding dimension + # All features contribute embedding_dim + expected_embed_dim = len(self.model.feature_keys) * self.model.embedding_dim + self.assertEqual(ret["embed"].shape[1], expected_embed_dim) + + def test_custom_hyperparameters(self): + """Test MultimodalRETAIN model with custom hyperparameters.""" + model = MultimodalRETAIN( + dataset=self.dataset, + embedding_dim=64, + dropout=0.3, + ) + + self.assertEqual(model.embedding_dim, 64) + + # Test forward pass + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_only_sequential_features(self): + """Test MultimodalRETAIN with only sequential features.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": [["A", "B"], ["C"]], + "procedures": [["P1"], ["P2"]], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": [["D"], ["E"]], + "procedures": [["P3"]], + "label": 0, + }, + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={ + "conditions": "nested_sequence", + "procedures": "nested_sequence" + }, + output_schema={"label": "binary"}, + dataset_name="test_seq_only", + ) + + model = MultimodalRETAIN(dataset=dataset) + + # Check that all features are sequential + self.assertEqual(len(model.sequential_features), 2) + self.assertEqual(len(model.non_sequential_features), 0) + + # Test forward pass + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_only_non_sequential_features(self): + """Test MultimodalRETAIN with only non-sequential features.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "demographics": ["asian", "male"], + "vitals": [120.0, 80.0, 98.6], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "demographics": ["white", "female"], + "vitals": [110.0, 75.0, 98.2], + "label": 0, + }, + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={"demographics": "multi_hot", "vitals": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="test_non_seq_only", + ) + + model = MultimodalRETAIN(dataset=dataset) + + # Check that all features are non-sequential + self.assertEqual(len(model.sequential_features), 0) + self.assertEqual(len(model.non_sequential_features), 2) + + # Test forward pass + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_sequential_processor_classification(self): + """Test that _is_sequential_processor correctly identifies processor types.""" + from pyhealth.processors import ( + MultiHotProcessor, + NestedSequenceProcessor, + TensorProcessor, + ) + + # Test with actual processor instances + nested_seq_proc = NestedSequenceProcessor() + self.assertTrue(self.model._is_sequential_processor(nested_seq_proc)) + + # Create simple multi-hot processor + multihot_proc = MultiHotProcessor() + self.assertFalse(self.model._is_sequential_processor(multihot_proc)) + + # Tensor processor + tensor_proc = TensorProcessor() + self.assertFalse(self.model._is_sequential_processor(tensor_proc)) + + def test_with_simple_sequences(self): + """Test MultimodalRETAIN with simple (non-nested) sequences.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "codes": ["A", "B", "C"], # simple sequence + "demographics": ["asian", "male"], # multi-hot + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "codes": ["D", "E"], # simple sequence + "demographics": ["white", "female"], # multi-hot + "label": 0, + }, + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={"codes": "sequence", "demographics": "multi_hot"}, + output_schema={"label": "binary"}, + dataset_name="test_simple_seq", + ) + + model = MultimodalRETAIN(dataset=dataset) + + # Check that codes is sequential + self.assertIn("codes", model.sequential_features) + self.assertIn("demographics", model.non_sequential_features) + + # Test forward pass + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + +if __name__ == "__main__": + unittest.main() + diff --git a/tests/core/test_retain.py b/tests/core/test_retain.py new file mode 100644 index 000000000..78cc94e61 --- /dev/null +++ b/tests/core/test_retain.py @@ -0,0 +1,420 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import RETAIN, MultimodalRETAIN + + +class TestRETAIN(unittest.TestCase): + """Test cases for the RETAIN model.""" + + def setUp(self): + """Set up test data and model.""" + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": [["A", "B"], ["C", "D", "E"]], + "procedures": [["P1"], ["P2", "P3"]], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": [["F"], ["G", "H"]], + "procedures": [["P4", "P5"], ["P6"]], + "label": 0, + }, + ] + + self.input_schema = { + "conditions": "nested_sequence", + "procedures": "nested_sequence", + } + self.output_schema = {"label": "binary"} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + + self.model = RETAIN(dataset=self.dataset) + + def test_model_initialization(self): + """Test that the RETAIN model initializes correctly.""" + self.assertIsInstance(self.model, RETAIN) + self.assertEqual(self.model.embedding_dim, 128) + self.assertEqual(len(self.model.feature_keys), 2) + self.assertIn("conditions", self.model.feature_keys) + self.assertIn("procedures", self.model.feature_keys) + self.assertEqual(self.model.label_keys[0], "label") + + def test_forward_input_format(self): + """Test that the dataloader provides tensor inputs.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + data_batch = next(iter(train_loader)) + + self.assertIsInstance(data_batch["conditions"], torch.Tensor) + self.assertIsInstance(data_batch["procedures"], torch.Tensor) + self.assertIsInstance(data_batch["label"], torch.Tensor) + + def test_model_forward(self): + """Test that the RETAIN model forward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_true"].shape[0], 2) + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + """Test that the RETAIN model backward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + ret = self.model(**data_batch) + ret["loss"].backward() + + has_gradient = any( + param.requires_grad and param.grad is not None + for param in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters have gradients after backward pass") + + def test_loss_is_finite(self): + """Test that the loss is finite.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertTrue(torch.isfinite(ret["loss"]).all()) + + def test_output_shapes(self): + """Test that output shapes are correct.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + batch_size = 2 + num_labels = 1 # binary classification + + self.assertEqual(ret["y_prob"].shape, (batch_size, num_labels)) + self.assertEqual(ret["y_true"].shape, (batch_size, num_labels)) + self.assertEqual(ret["logit"].shape, (batch_size, num_labels)) + + def test_model_with_embedding(self): + """Test that the RETAIN model returns embeddings when requested.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + data_batch["embed"] = True + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape[0], 2) # batch size + expected_embed_dim = len(self.model.feature_keys) * self.model.embedding_dim + self.assertEqual(ret["embed"].shape[1], expected_embed_dim) + + +class TestMultimodalRETAIN(unittest.TestCase): + """Test cases for the MultimodalRETAIN model.""" + + def setUp(self): + """Set up test data and model with mixed feature types.""" + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": [["A", "B"], ["C"]], # nested sequence (sequential) + "procedures": [["P1"], ["P2", "P3"]], # nested sequence (sequential) + "demographics": ["asian", "male"], # multi-hot (non-sequential) + "vitals": [120.0, 80.0, 98.6], # tensor (non-sequential) + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": [["D"], ["E", "F"]], # nested sequence (sequential) + "procedures": [["P4"]], # nested sequence (sequential) + "demographics": ["white", "female"], # multi-hot (non-sequential) + "vitals": [110.0, 75.0, 98.2], # tensor (non-sequential) + "label": 0, + }, + ] + + self.input_schema = { + "conditions": "nested_sequence", # sequential + "procedures": "nested_sequence", # sequential + "demographics": "multi_hot", # non-sequential + "vitals": "tensor", # non-sequential + } + self.output_schema = {"label": "binary"} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + + self.model = MultimodalRETAIN(dataset=self.dataset) + + def test_model_initialization(self): + """Test that the MultimodalRETAIN model initializes correctly.""" + self.assertIsInstance(self.model, MultimodalRETAIN) + self.assertEqual(self.model.embedding_dim, 128) + self.assertEqual(len(self.model.feature_keys), 4) + + # Check that features are correctly classified + self.assertIn("conditions", self.model.sequential_features) + self.assertIn("procedures", self.model.sequential_features) + self.assertIn("demographics", self.model.non_sequential_features) + self.assertIn("vitals", self.model.non_sequential_features) + + # Check that RETAIN layers are only created for sequential features + self.assertIn("conditions", self.model.retain) + self.assertIn("procedures", self.model.retain) + self.assertNotIn("demographics", self.model.retain) + self.assertNotIn("vitals", self.model.retain) + + def test_model_forward(self): + """Test that the MultimodalRETAIN forward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_true"].shape[0], 2) + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + """Test that the MultimodalRETAIN backward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + ret = self.model(**data_batch) + ret["loss"].backward() + + has_gradient = any( + param.requires_grad and param.grad is not None + for param in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters have gradients after backward pass") + + def test_model_with_embedding(self): + """Test that the MultimodalRETAIN model returns embeddings when requested.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + data_batch["embed"] = True + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape[0], 2) # batch size + + # All features contribute embedding_dim + expected_embed_dim = len(self.model.feature_keys) * self.model.embedding_dim + self.assertEqual(ret["embed"].shape[1], expected_embed_dim) + + def test_output_shapes(self): + """Test that output shapes are correct for multimodal inputs.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertEqual(ret["y_prob"].shape, (2, 1)) + self.assertEqual(ret["y_true"].shape, (2, 1)) + self.assertEqual(ret["logit"].shape, (2, 1)) + + def test_loss_is_finite(self): + """Test that the loss is finite.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertTrue(torch.isfinite(ret["loss"]).all()) + + def test_custom_hyperparameters(self): + """Test MultimodalRETAIN with custom hyperparameters.""" + model = MultimodalRETAIN( + dataset=self.dataset, + embedding_dim=64, + dropout=0.3, + ) + + self.assertEqual(model.embedding_dim, 64) + + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_only_sequential_features(self): + """Test MultimodalRETAIN with only sequential features.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": [["A", "B"], ["C"]], + "procedures": [["P1"], ["P2"]], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": [["D"], ["E"]], + "procedures": [["P3"]], + "label": 0, + }, + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={"conditions": "nested_sequence", "procedures": "nested_sequence"}, + output_schema={"label": "binary"}, + dataset_name="test_seq_only", + ) + + model = MultimodalRETAIN(dataset=dataset) + self.assertEqual(len(model.sequential_features), 2) + self.assertEqual(len(model.non_sequential_features), 0) + + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_only_non_sequential_features(self): + """Test MultimodalRETAIN with only non-sequential features.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "demographics": ["asian", "male"], + "vitals": [120.0, 80.0, 98.6], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "demographics": ["white", "female"], + "vitals": [110.0, 75.0, 98.2], + "label": 0, + }, + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={"demographics": "multi_hot", "vitals": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="test_non_seq_only", + ) + + model = MultimodalRETAIN(dataset=dataset) + self.assertEqual(len(model.sequential_features), 0) + self.assertEqual(len(model.non_sequential_features), 2) + + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_sequential_processor_classification(self): + """Test that _is_sequential_processor correctly identifies processor types.""" + from pyhealth.processors import ( + MultiHotProcessor, + NestedSequenceProcessor, + TensorProcessor, + ) + + nested_seq_proc = NestedSequenceProcessor() + self.assertTrue(self.model._is_sequential_processor(nested_seq_proc)) + + multihot_proc = MultiHotProcessor() + self.assertFalse(self.model._is_sequential_processor(multihot_proc)) + + tensor_proc = TensorProcessor() + self.assertFalse(self.model._is_sequential_processor(tensor_proc)) + + def test_with_simple_sequences(self): + """Test MultimodalRETAIN with simple (non-nested) sequences mixed with non-sequential.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "codes": ["A", "B", "C"], # simple sequence + "demographics": ["asian", "male"], # multi-hot + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "codes": ["D", "E"], # simple sequence + "demographics": ["white", "female"], # multi-hot + "label": 0, + }, + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={"codes": "sequence", "demographics": "multi_hot"}, + output_schema={"label": "binary"}, + dataset_name="test_simple_seq", + ) + + model = MultimodalRETAIN(dataset=dataset) + self.assertIn("codes", model.sequential_features) + self.assertIn("demographics", model.non_sequential_features) + + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + +if __name__ == "__main__": + unittest.main()