Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ list(APPEND
"${CMAKE_CURRENT_SOURCE_DIR}/cmake"
)
find_package(CUDAToolkit REQUIRED)
find_package(MPI 3.0 REQUIRED COMPONENTS CXX)
find_package(Torch 2.6 REQUIRED CONFIG)

# Also, torch_python!
Expand All @@ -79,6 +78,7 @@ find_library(TORCH_PYTHON_LIBRARY
find_library(TORCH_PYTHON_LIBRARY torch_python REQUIRED)

if (DGRAPH_ENABLE_NVSHMEM)
find_package(MPI 3.0 REQUIRED COMPONENTS CXX)
find_package(NVSHMEM 2.5 REQUIRED MODULE)
endif ()

Expand Down
24 changes: 24 additions & 0 deletions DGraph/Communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from DGraph.distributed.nccl import NCCLBackendEngine

from DGraph.CommunicatorBase import CommunicatorBase
from typing import Tuple, Optional

SUPPORTED_BACKENDS = ["nccl", "mpi", "nvshmem"]

Expand Down Expand Up @@ -95,6 +96,13 @@ def get_local_tensor(

return masked_tensor

def alloc_buffer(
    self, size: Tuple[int, ...], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Allocate a buffer suitable for this backend's communication model.

    Default backends hand back a plain ``torch.empty`` tensor; the NVSHMEM
    backend overrides allocation to return symmetric memory.

    Args:
        size: Shape of the buffer to allocate.
        dtype: Element dtype of the buffer.
        device: Device on which to allocate the buffer.

    Returns:
        A tensor usable with this communicator's data-movement operations.
    """
    # Consistent with scatter/gather/barrier: fail fast if the backend
    # engine has not been initialized yet.
    self.__check_init()
    return self.__backend_engine.allocate_buffer(size, dtype, device)

def scatter(self, *args, **kwargs) -> torch.Tensor:
    """Scatter data via the active backend engine (nccl / mpi / nvshmem)."""
    # Fail fast if the communicator has not been initialized.
    self.__check_init()
    return self.__backend_engine.scatter(*args, **kwargs)
Expand All @@ -103,6 +111,22 @@ def gather(self, *args, **kwargs) -> torch.Tensor:
self.__check_init()
return self.__backend_engine.gather(*args, **kwargs)

def put(
    self,
    send_buffer: torch.Tensor,
    recv_buffer: torch.Tensor,
    send_offsets: torch.Tensor,
    recv_offsets: torch.Tensor,
    remote_offsets: Optional[torch.Tensor] = None,
) -> None:
    """Exchange chunks of ``send_buffer`` into every rank's ``recv_buffer``.

    ``send_offsets`` / ``recv_offsets`` describe the per-rank chunking of the
    two buffers. ``remote_offsets`` is only meaningful to one-sided backends,
    which use it as the write position into the remote ``recv_buffer``;
    two-sided backends ignore it.
    """
    # Consistent with scatter/gather/barrier: fail fast when the backend
    # engine has not been initialized.
    self.__check_init()
    return self.__backend_engine.put(
        send_buffer,
        recv_buffer,
        send_offsets,
        recv_offsets,
        remote_offsets=remote_offsets,
    )

def barrier(self) -> None:
    """Synchronization barrier; delegates to the active backend engine."""
    # Fail fast if the communicator has not been initialized.
    self.__check_init()
    self.__backend_engine.barrier()
Expand Down
119 changes: 5 additions & 114 deletions DGraph/data/ogbn_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from ogb.nodeproppred import NodePropPredDataset
from DGraph.data.graph import DistributedGraph
from DGraph.data.graph import get_round_robin_node_rank_map
import numpy as np
from DGraph.data.preprocess import process_homogenous_data
import os
import torch.distributed as dist


SUPPORTED_DATASETS = [
"ogbn-arxiv",
Expand All @@ -37,117 +37,6 @@
}


def node_renumbering(node_rank_placement) -> Tuple[torch.Tensor, torch.Tensor]:
"""The nodes are renumbered based on the rank mappings so the node features and
numbers are contiguous."""

contiguous_rank_mapping, renumbered_nodes = torch.sort(node_rank_placement)
return renumbered_nodes, contiguous_rank_mapping


def edge_renumbering(
edge_indices, renumbered_nodes, vertex_mapping, edge_features=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
src_indices = edge_indices[0, :]
dst_indices = edge_indices[1, :]
src_indices = renumbered_nodes[src_indices]
dst_indices = renumbered_nodes[dst_indices]

edge_src_rank_mapping = vertex_mapping[src_indices]
edge_dest_rank_mapping = vertex_mapping[dst_indices]

sorted_src_rank_mapping, sorted_indices = torch.sort(edge_src_rank_mapping)
dst_indices = dst_indices[sorted_indices]
src_indices = src_indices[sorted_indices]

sorted_dest_rank_mapping = edge_dest_rank_mapping[sorted_indices]

if edge_features is not None:
# Sort the edge features based on the sorted indices
edge_features = edge_features[sorted_indices]

return (
torch.stack([src_indices, dst_indices], dim=0),
sorted_src_rank_mapping,
sorted_dest_rank_mapping,
edge_features,
)


def process_homogenous_data(
graph_data,
labels,
rank: int,
world_Size: int,
split_idx: dict,
node_rank_placement: torch.Tensor,
*args,
**kwargs,
) -> DistributedGraph:
"""For processing homogenous graph with node features, edge index and labels"""
assert "node_feat" in graph_data, "Node features not found"
assert "edge_index" in graph_data, "Edge index not found"
assert "num_nodes" in graph_data, "Number of nodes not found"
assert graph_data["edge_feat"] is None, "Edge features not supported"

node_features = torch.Tensor(graph_data["node_feat"]).float()
edge_index = torch.Tensor(graph_data["edge_index"]).long()
num_nodes = graph_data["num_nodes"]
labels = torch.Tensor(labels).long()
# For bidirectional graphs the number of edges are double counted
num_edges = edge_index.shape[1]

assert node_rank_placement.shape[0] == num_nodes, "Node mapping mismatch"
assert "train" in split_idx, "Train mask not found"
assert "valid" in split_idx, "Validation mask not found"
assert "test" in split_idx, "Test mask not found"

train_nodes = torch.from_numpy(split_idx["train"])
valid_nodes = torch.from_numpy(split_idx["valid"])
test_nodes = torch.from_numpy(split_idx["test"])

# Renumber the nodes and edges to make them contiguous
renumbered_nodes, contiguous_rank_mapping = node_renumbering(node_rank_placement)
node_features = node_features[renumbered_nodes]

# Sanity check to make sure we placed the nodes in the correct spots

assert torch.all(node_rank_placement[renumbered_nodes] == contiguous_rank_mapping)

# First renumber the edges
# Then we calculate the location of the source and destination vertex of each edge
# based on the rank mapping
# Then we sort the edges based on the source vertex rank mapping
# When determining the location of the edge, we use the rank of the source vertex
# as the location of the edge

edge_index, edge_rank_mapping, edge_dest_rank_mapping, _ = edge_renumbering(
edge_index, renumbered_nodes, contiguous_rank_mapping, edge_features=None
)

train_nodes = renumbered_nodes[train_nodes]
valid_nodes = renumbered_nodes[valid_nodes]
test_nodes = renumbered_nodes[test_nodes]

labels = labels[renumbered_nodes]

graph_obj = DistributedGraph(
node_features=node_features,
edge_index=edge_index,
num_nodes=num_nodes,
num_edges=num_edges,
node_loc=contiguous_rank_mapping.long(),
edge_loc=edge_rank_mapping.long(),
edge_dest_rank_mapping=edge_dest_rank_mapping.long(),
world_size=world_Size,
labels=labels,
train_mask=train_nodes,
val_mask=valid_nodes,
test_mask=test_nodes,
)
return graph_obj


class DistributedOGBWrapper(Dataset):
def __init__(
self,
Expand Down Expand Up @@ -211,7 +100,9 @@ def __init__(
else:
if node_rank_placement is None:
if self._rank == 0:
print(f"Node rank placement not provided, generating a round robin placement")
print(
f"Node rank placement not provided, generating a round robin placement"
)
node_rank_placement = get_round_robin_node_rank_map(
graph_data["num_nodes"], self._world_size
)
Expand Down
114 changes: 114 additions & 0 deletions DGraph/data/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import torch
from typing import Optional, Tuple
from DGraph.data.graph import DistributedGraph


def node_renumbering(node_rank_placement) -> Tuple[torch.Tensor, torch.Tensor]:
    """Renumber nodes so nodes owned by the same rank become contiguous.

    Returns the permutation of node indices ordered by owning rank, and the
    rank mapping in its sorted (contiguous) order.
    """
    sorted_ranks, permutation = torch.sort(node_rank_placement)
    return permutation, sorted_ranks


def edge_renumbering(
    edge_indices, renumbered_nodes, vertex_mapping, edge_features=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Relabel edge endpoints through ``renumbered_nodes`` and order edges by
    the rank of their source vertex.

    Returns the (2, E) relabelled edge index, the sorted source-rank mapping,
    the destination-rank mapping permuted to match, and the edge features in
    the same order (``None`` when no features were supplied).
    """
    # Relabel both endpoints through the node permutation.
    new_src = renumbered_nodes[edge_indices[0, :]]
    new_dst = renumbered_nodes[edge_indices[1, :]]

    # Owning rank of each endpoint after relabelling.
    src_ranks = vertex_mapping[new_src]
    dst_ranks = vertex_mapping[new_dst]

    # Order all per-edge arrays by source rank.
    src_ranks, order = torch.sort(src_ranks)
    new_src = new_src[order]
    new_dst = new_dst[order]
    dst_ranks = dst_ranks[order]

    if edge_features is not None:
        # Keep features aligned with the reordered edges.
        edge_features = edge_features[order]

    return (
        torch.stack([new_src, new_dst], dim=0),
        src_ranks,
        dst_ranks,
        edge_features,
    )


def process_homogenous_data(
    graph_data,
    labels,
    rank: int,
    world_Size: int,
    split_idx: dict,
    node_rank_placement: torch.Tensor,
    *args,
    **kwargs,
) -> DistributedGraph:
    """Build a DistributedGraph from a homogeneous graph with node features,
    an edge index, and per-node labels.

    Args:
        graph_data: Mapping with keys "node_feat", "edge_index", "num_nodes",
            and "edge_feat" (which must be None — edge features unsupported).
        labels: Per-node labels, indexable by node id.
        rank: Calling process's rank (not used in the function body).
        world_Size: Number of ranks the graph is partitioned across.
        split_idx: Mapping with "train"/"valid"/"test" numpy index arrays.
        node_rank_placement: (num_nodes,) tensor of the owning rank per node.

    Returns:
        A DistributedGraph whose nodes are renumbered so each rank's nodes
        are contiguous and whose edges are sorted by source-vertex rank.
    """
    assert "node_feat" in graph_data, "Node features not found"
    assert "edge_index" in graph_data, "Edge index not found"
    assert "num_nodes" in graph_data, "Number of nodes not found"
    assert graph_data["edge_feat"] is None, "Edge features not supported"

    node_features = torch.Tensor(graph_data["node_feat"]).float()
    edge_index = torch.Tensor(graph_data["edge_index"]).long()
    num_nodes = graph_data["num_nodes"]
    labels = torch.Tensor(labels).long()
    # For bidirectional graphs the number of edges are double counted
    num_edges = edge_index.shape[1]

    assert node_rank_placement.shape[0] == num_nodes, "Node mapping mismatch"
    assert "train" in split_idx, "Train mask not found"
    assert "valid" in split_idx, "Validation mask not found"
    assert "test" in split_idx, "Test mask not found"

    # split_idx entries are assumed to be numpy arrays of original node ids.
    train_nodes = torch.from_numpy(split_idx["train"])
    valid_nodes = torch.from_numpy(split_idx["valid"])
    test_nodes = torch.from_numpy(split_idx["test"])

    # Renumber the nodes and edges to make them contiguous.
    # renumbered_nodes maps NEW position -> ORIGINAL node id (sort indices).
    renumbered_nodes, contiguous_rank_mapping = node_renumbering(node_rank_placement)
    node_features = node_features[renumbered_nodes]

    # Sanity check to make sure we placed the nodes in the correct spots

    assert torch.all(node_rank_placement[renumbered_nodes] == contiguous_rank_mapping)

    # First renumber the edges
    # Then we calculate the location of the source and destination vertex of each edge
    # based on the rank mapping
    # Then we sort the edges based on the source vertex rank mapping
    # When determining the location of the edge, we use the rank of the source vertex
    # as the location of the edge
    # NOTE(review): edge_renumbering indexes renumbered_nodes by ORIGINAL ids,
    # i.e. it treats the permutation as original -> new, but renumbered_nodes is
    # the new -> original sort permutation; verify whether the inverse
    # permutation (argsort of renumbered_nodes) is intended here.

    edge_index, edge_rank_mapping, edge_dest_rank_mapping, _ = edge_renumbering(
        edge_index, renumbered_nodes, contiguous_rank_mapping, edge_features=None
    )

    # NOTE(review): same concern as above — indexing the split arrays and
    # labels with renumbered_nodes applies new -> original, not original ->
    # new; confirm against how DistributedGraph consumes these masks.
    train_nodes = renumbered_nodes[train_nodes]
    valid_nodes = renumbered_nodes[valid_nodes]
    test_nodes = renumbered_nodes[test_nodes]

    labels = labels[renumbered_nodes]

    graph_obj = DistributedGraph(
        node_features=node_features,
        edge_index=edge_index,
        num_nodes=num_nodes,
        num_edges=num_edges,
        node_loc=contiguous_rank_mapping.long(),
        edge_loc=edge_rank_mapping.long(),
        edge_dest_rank_mapping=edge_dest_rank_mapping.long(),
        world_size=world_Size,
        labels=labels,
        train_mask=train_nodes,
        val_mask=valid_nodes,
        test_mask=test_nodes,
    )
    return graph_obj
37 changes: 36 additions & 1 deletion DGraph/distributed/Engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#
# SPDX-License-Identifier: (Apache-2.0)
import torch
from typing import Optional, Union
from typing import Optional, Union, Tuple


class BackendEngine(object):
Expand Down Expand Up @@ -64,6 +64,41 @@ def gather(
) -> torch.Tensor:
raise NotImplementedError

def put(
    self,
    send_buffer: torch.Tensor,
    recv_buffer: torch.Tensor,
    send_offsets: torch.Tensor,
    recv_offsets: torch.Tensor,
    remote_offsets: Optional[torch.Tensor] = None,
) -> None:
    """Synchronously exchange buffer chunks between all ranks.

    ``send_buffer`` is split according to ``send_offsets`` and each chunk is
    delivered into the corresponding rank's ``recv_buffer``. The call must be
    synchronous: when it returns, ``recv_buffer`` is fully populated and safe
    to read.

    ``remote_offsets`` is only meaningful for one-sided backends, where entry
    ``i`` is the write position into rank ``i``'s ``recv_buffer``; two-sided
    backends ignore it.

    Raises:
        NotImplementedError: always; concrete backend engines must override.
    """
    raise NotImplementedError

def allocate_buffer(
    self,
    size: Tuple[int, ...],
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Allocate an uninitialized communication buffer.

    The base implementation just delegates to ``torch.empty``; one-sided
    backends override this to hand back symmetric / registered memory.
    """
    buffer = torch.empty(size, dtype=dtype, device=device)
    return buffer

def finalize(self) -> None:
    """Tear down backend resources; concrete engines must override."""
    raise NotImplementedError

Expand Down
14 changes: 14 additions & 0 deletions DGraph/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,18 @@
Modules exported by this package:
- `Engine`: The DGraph communication engine used by the Communicator.
- `BackendEngine`: The abstract DGraph communication engine used by the Communicator.
- `HaloExchange`: Halo exchange class for communicating remote vertices
- `CommunicationPattern`: Dataclass for holding communication pattern information
"""
from DGraph.distributed.haloExchange import HaloExchange, DGraphMessagePassing
from DGraph.distributed.commInfo import (
CommunicationPattern,
build_communication_pattern,
)

__all__ = [
"HaloExchange",
"DGraphMessagePassing",
"CommunicationPattern",
"build_communication_pattern",
]
Loading