diff --git a/CMakeLists.txt b/CMakeLists.txt index 1125042..1b8f04c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,7 +60,6 @@ list(APPEND "${CMAKE_CURRENT_SOURCE_DIR}/cmake" ) find_package(CUDAToolkit REQUIRED) -find_package(MPI 3.0 REQUIRED COMPONENTS CXX) find_package(Torch 2.6 REQUIRED CONFIG) # Also, torch_python! @@ -79,6 +78,7 @@ find_library(TORCH_PYTHON_LIBRARY find_library(TORCH_PYTHON_LIBRARY torch_python REQUIRED) if (DGRAPH_ENABLE_NVSHMEM) + find_package(MPI 3.0 REQUIRED COMPONENTS CXX) find_package(NVSHMEM 2.5 REQUIRED MODULE) endif () diff --git a/DGraph/Communicator.py b/DGraph/Communicator.py index a9d2f93..d7355c6 100644 --- a/DGraph/Communicator.py +++ b/DGraph/Communicator.py @@ -16,6 +16,7 @@ from DGraph.distributed.nccl import NCCLBackendEngine from DGraph.CommunicatorBase import CommunicatorBase +from typing import Tuple, Optional SUPPORTED_BACKENDS = ["nccl", "mpi", "nvshmem"] @@ -95,6 +96,13 @@ def get_local_tensor( return masked_tensor + def alloc_buffer( + self, size: Tuple[int, ...], dtype: torch.dtype, device: torch.device + ) -> torch.Tensor: + """Allocate a buffer suitable for this backend's communication model. + Default: torch.empty. 
NVSHMEM overrides with symmetric allocation.""" + return self.__backend_engine.allocate_buffer(size, dtype, device) + def scatter(self, *args, **kwargs) -> torch.Tensor: self.__check_init() return self.__backend_engine.scatter(*args, **kwargs) @@ -103,6 +111,22 @@ def gather(self, *args, **kwargs) -> torch.Tensor: self.__check_init() return self.__backend_engine.gather(*args, **kwargs) + def put( + self, + send_buffer: torch.Tensor, + recv_buffer: torch.Tensor, + send_offsets: torch.Tensor, + recv_offsets: torch.Tensor, + remote_offsets: Optional[torch.Tensor] = None, + ) -> None: + return self.__backend_engine.put( + send_buffer, + recv_buffer, + send_offsets, + recv_offsets, + remote_offsets=remote_offsets, + ) + def barrier(self) -> None: self.__check_init() self.__backend_engine.barrier() diff --git a/DGraph/data/ogbn_datasets.py b/DGraph/data/ogbn_datasets.py index deb1383..2416d82 100644 --- a/DGraph/data/ogbn_datasets.py +++ b/DGraph/data/ogbn_datasets.py @@ -18,9 +18,9 @@ from ogb.nodeproppred import NodePropPredDataset from DGraph.data.graph import DistributedGraph from DGraph.data.graph import get_round_robin_node_rank_map -import numpy as np +from DGraph.data.preprocess import process_homogenous_data import os -import torch.distributed as dist + SUPPORTED_DATASETS = [ "ogbn-arxiv", @@ -37,117 +37,6 @@ } -def node_renumbering(node_rank_placement) -> Tuple[torch.Tensor, torch.Tensor]: - """The nodes are renumbered based on the rank mappings so the node features and - numbers are contiguous.""" - - contiguous_rank_mapping, renumbered_nodes = torch.sort(node_rank_placement) - return renumbered_nodes, contiguous_rank_mapping - - -def edge_renumbering( - edge_indices, renumbered_nodes, vertex_mapping, edge_features=None -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - src_indices = edge_indices[0, :] - dst_indices = edge_indices[1, :] - src_indices = renumbered_nodes[src_indices] - dst_indices = renumbered_nodes[dst_indices] - - 
edge_src_rank_mapping = vertex_mapping[src_indices] - edge_dest_rank_mapping = vertex_mapping[dst_indices] - - sorted_src_rank_mapping, sorted_indices = torch.sort(edge_src_rank_mapping) - dst_indices = dst_indices[sorted_indices] - src_indices = src_indices[sorted_indices] - - sorted_dest_rank_mapping = edge_dest_rank_mapping[sorted_indices] - - if edge_features is not None: - # Sort the edge features based on the sorted indices - edge_features = edge_features[sorted_indices] - - return ( - torch.stack([src_indices, dst_indices], dim=0), - sorted_src_rank_mapping, - sorted_dest_rank_mapping, - edge_features, - ) - - -def process_homogenous_data( - graph_data, - labels, - rank: int, - world_Size: int, - split_idx: dict, - node_rank_placement: torch.Tensor, - *args, - **kwargs, -) -> DistributedGraph: - """For processing homogenous graph with node features, edge index and labels""" - assert "node_feat" in graph_data, "Node features not found" - assert "edge_index" in graph_data, "Edge index not found" - assert "num_nodes" in graph_data, "Number of nodes not found" - assert graph_data["edge_feat"] is None, "Edge features not supported" - - node_features = torch.Tensor(graph_data["node_feat"]).float() - edge_index = torch.Tensor(graph_data["edge_index"]).long() - num_nodes = graph_data["num_nodes"] - labels = torch.Tensor(labels).long() - # For bidirectional graphs the number of edges are double counted - num_edges = edge_index.shape[1] - - assert node_rank_placement.shape[0] == num_nodes, "Node mapping mismatch" - assert "train" in split_idx, "Train mask not found" - assert "valid" in split_idx, "Validation mask not found" - assert "test" in split_idx, "Test mask not found" - - train_nodes = torch.from_numpy(split_idx["train"]) - valid_nodes = torch.from_numpy(split_idx["valid"]) - test_nodes = torch.from_numpy(split_idx["test"]) - - # Renumber the nodes and edges to make them contiguous - renumbered_nodes, contiguous_rank_mapping = 
node_renumbering(node_rank_placement) - node_features = node_features[renumbered_nodes] - - # Sanity check to make sure we placed the nodes in the correct spots - - assert torch.all(node_rank_placement[renumbered_nodes] == contiguous_rank_mapping) - - # First renumber the edges - # Then we calculate the location of the source and destination vertex of each edge - # based on the rank mapping - # Then we sort the edges based on the source vertex rank mapping - # When determining the location of the edge, we use the rank of the source vertex - # as the location of the edge - - edge_index, edge_rank_mapping, edge_dest_rank_mapping, _ = edge_renumbering( - edge_index, renumbered_nodes, contiguous_rank_mapping, edge_features=None - ) - - train_nodes = renumbered_nodes[train_nodes] - valid_nodes = renumbered_nodes[valid_nodes] - test_nodes = renumbered_nodes[test_nodes] - - labels = labels[renumbered_nodes] - - graph_obj = DistributedGraph( - node_features=node_features, - edge_index=edge_index, - num_nodes=num_nodes, - num_edges=num_edges, - node_loc=contiguous_rank_mapping.long(), - edge_loc=edge_rank_mapping.long(), - edge_dest_rank_mapping=edge_dest_rank_mapping.long(), - world_size=world_Size, - labels=labels, - train_mask=train_nodes, - val_mask=valid_nodes, - test_mask=test_nodes, - ) - return graph_obj - - class DistributedOGBWrapper(Dataset): def __init__( self, @@ -211,7 +100,9 @@ def __init__( else: if node_rank_placement is None: if self._rank == 0: - print(f"Node rank placement not provided, generating a round robin placement") + print( + f"Node rank placement not provided, generating a round robin placement" + ) node_rank_placement = get_round_robin_node_rank_map( graph_data["num_nodes"], self._world_size ) diff --git a/DGraph/data/preprocess.py b/DGraph/data/preprocess.py new file mode 100644 index 0000000..71707ea --- /dev/null +++ b/DGraph/data/preprocess.py @@ -0,0 +1,114 @@ +import torch +from typing import Optional, Tuple +from DGraph.data.graph import 
DistributedGraph + + +def node_renumbering(node_rank_placement) -> Tuple[torch.Tensor, torch.Tensor]: + """The nodes are renumbered based on the rank mappings so the node features and + numbers are contiguous.""" + + contiguous_rank_mapping, renumbered_nodes = torch.sort(node_rank_placement) + return renumbered_nodes, contiguous_rank_mapping + + +def edge_renumbering( + edge_indices, renumbered_nodes, vertex_mapping, edge_features=None +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + src_indices = edge_indices[0, :] + dst_indices = edge_indices[1, :] + src_indices = renumbered_nodes[src_indices] + dst_indices = renumbered_nodes[dst_indices] + + edge_src_rank_mapping = vertex_mapping[src_indices] + edge_dest_rank_mapping = vertex_mapping[dst_indices] + + sorted_src_rank_mapping, sorted_indices = torch.sort(edge_src_rank_mapping) + dst_indices = dst_indices[sorted_indices] + src_indices = src_indices[sorted_indices] + + sorted_dest_rank_mapping = edge_dest_rank_mapping[sorted_indices] + + if edge_features is not None: + # Sort the edge features based on the sorted indices + edge_features = edge_features[sorted_indices] + + return ( + torch.stack([src_indices, dst_indices], dim=0), + sorted_src_rank_mapping, + sorted_dest_rank_mapping, + edge_features, + ) + + +def process_homogenous_data( + graph_data, + labels, + rank: int, + world_Size: int, + split_idx: dict, + node_rank_placement: torch.Tensor, + *args, + **kwargs, +) -> DistributedGraph: + """For processing homogenous graph with node features, edge index and labels""" + assert "node_feat" in graph_data, "Node features not found" + assert "edge_index" in graph_data, "Edge index not found" + assert "num_nodes" in graph_data, "Number of nodes not found" + assert graph_data["edge_feat"] is None, "Edge features not supported" + + node_features = torch.Tensor(graph_data["node_feat"]).float() + edge_index = torch.Tensor(graph_data["edge_index"]).long() + num_nodes = 
graph_data["num_nodes"] + labels = torch.Tensor(labels).long() + # For bidirectional graphs the number of edges are double counted + num_edges = edge_index.shape[1] + + assert node_rank_placement.shape[0] == num_nodes, "Node mapping mismatch" + assert "train" in split_idx, "Train mask not found" + assert "valid" in split_idx, "Validation mask not found" + assert "test" in split_idx, "Test mask not found" + + train_nodes = torch.from_numpy(split_idx["train"]) + valid_nodes = torch.from_numpy(split_idx["valid"]) + test_nodes = torch.from_numpy(split_idx["test"]) + + # Renumber the nodes and edges to make them contiguous + renumbered_nodes, contiguous_rank_mapping = node_renumbering(node_rank_placement) + node_features = node_features[renumbered_nodes] + + # Sanity check to make sure we placed the nodes in the correct spots + + assert torch.all(node_rank_placement[renumbered_nodes] == contiguous_rank_mapping) + + # First renumber the edges + # Then we calculate the location of the source and destination vertex of each edge + # based on the rank mapping + # Then we sort the edges based on the source vertex rank mapping + # When determining the location of the edge, we use the rank of the source vertex + # as the location of the edge + + edge_index, edge_rank_mapping, edge_dest_rank_mapping, _ = edge_renumbering( + edge_index, renumbered_nodes, contiguous_rank_mapping, edge_features=None + ) + + train_nodes = renumbered_nodes[train_nodes] + valid_nodes = renumbered_nodes[valid_nodes] + test_nodes = renumbered_nodes[test_nodes] + + labels = labels[renumbered_nodes] + + graph_obj = DistributedGraph( + node_features=node_features, + edge_index=edge_index, + num_nodes=num_nodes, + num_edges=num_edges, + node_loc=contiguous_rank_mapping.long(), + edge_loc=edge_rank_mapping.long(), + edge_dest_rank_mapping=edge_dest_rank_mapping.long(), + world_size=world_Size, + labels=labels, + train_mask=train_nodes, + val_mask=valid_nodes, + test_mask=test_nodes, + ) + return graph_obj 
diff --git a/DGraph/distributed/Engine.py b/DGraph/distributed/Engine.py index 547aada..c0a129c 100644 --- a/DGraph/distributed/Engine.py +++ b/DGraph/distributed/Engine.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: (Apache-2.0) import torch -from typing import Optional, Union +from typing import Optional, Union, Tuple class BackendEngine(object): @@ -64,6 +64,41 @@ def gather( ) -> torch.Tensor: raise NotImplementedError + def put( + self, + send_buffer: torch.Tensor, + recv_buffer: torch.Tensor, + send_offsets: torch.Tensor, + recv_offsets: torch.Tensor, + remote_offsets: Optional[torch.Tensor] = None, + ) -> None: + """ + Exchange data between all ranks. + + Chunks send_buffer by send_offsets, delivers each chunk to the + corresponding rank's recv_buffer. Must be synchronous: when this + method returns, recv_buffer is fully populated and safe to read. + + Two-sided backends ignore remote_offsets. + One-sided backends use remote_offsets[i] as the write position + into rank i's recv_buffer. + """ + raise NotImplementedError + + def allocate_buffer( + self, + size: Tuple[int, ...], + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + """ + Allocate a communication buffer. + + Default: torch.empty. One-sided backends override this to + return symmetric / registered memory. + """ + return torch.empty(size, dtype=dtype, device=device) + def finalize(self) -> None: raise NotImplementedError diff --git a/DGraph/distributed/__init__.py b/DGraph/distributed/__init__.py index aaee1db..81ed157 100644 --- a/DGraph/distributed/__init__.py +++ b/DGraph/distributed/__init__.py @@ -17,4 +17,18 @@ Modules exported by this package: - `Engine`: The DGraph communication engine used by the Communicator. - `BackendEngine`: The abstract DGraph communication engine used by the Communicator. 
+- `HaloExchange`: Halo exchange class for communicating remote vertices +- `CommunicationPattern`: Dataclass for holding communication pattern information """ +from DGraph.distributed.haloExchange import HaloExchange, DGraphMessagePassing +from DGraph.distributed.commInfo import ( + CommunicationPattern, + build_communication_pattern, +) + +__all__ = [ + "HaloExchange", + "DGraphMessagePassing", + "CommunicationPattern", + "build_communication_pattern", +] diff --git a/DGraph/distributed/commInfo.py b/DGraph/distributed/commInfo.py new file mode 100644 index 0000000..0d5bd7e --- /dev/null +++ b/DGraph/distributed/commInfo.py @@ -0,0 +1,207 @@ +import torch +from dataclasses import dataclass +from typing import Optional +import torch.distributed as dist + + +@dataclass +class CommunicationPattern: + # --- Identity --- + rank: int + world_size: int + + # --- Vertex Counts --- + num_local_vertices: int + num_halo_vertices: int + + # --- Local Subgraph --- + local_edge_list: torch.Tensor # [num_local_edges, 2] + + # --- Send Indexing --- + send_local_idx: torch.Tensor # [total_sends] + send_offset: torch.Tensor # [world_size + 1] + + # --- Receive Indexing --- + recv_offset: torch.Tensor # [world_size + 1] + + # --- Communication Map --- + comm_map: torch.Tensor # [world_size, world_size] + + # --- One-Sided RMA Offsets --- + put_forward_remote_offset: torch.Tensor # [world_size] + put_backward_remote_offset: torch.Tensor # [world_size] + + +def compute_local_vertices(partitioning: torch.Tensor, rank: int) -> torch.Tensor: + """Returns local_vertices_global: [num_local] global IDs owned by this rank""" + + return torch.where(partitioning == rank)[0] + + +def compute_halo_vertices( + edge_list: torch.Tensor, + src_partitioning: torch.Tensor, + rank: int, + dst_partitioning: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Computes halo vertices. Supports both homogeneous and bipartite/heterogeneous relations. 
+ """ + # Fallback for homogeneous graphs + if dst_partitioning is None: + dst_partitioning = src_partitioning + + src_rank = src_partitioning[edge_list[:, 0]] + dst_rank = dst_partitioning[edge_list[:, 1]] + + # Cross-rank mask: source is local, destination is remote + cross_mask = (src_rank == rank) & (dst_rank != rank) + + # Return unique destination vertex IDs from those edges + return torch.unique(edge_list[cross_mask, 1]) + + +def compute_local_edge_list( + global_edge_list: torch.Tensor, # [E, 2] + partitioning: torch.Tensor, # [V] + local_vertices_global: torch.Tensor, # [num_local] + halo_vertices_global: torch.Tensor, # [num_halo] + rank: int, +) -> torch.Tensor: + num_local = local_vertices_global.size(0) + num_halo = halo_vertices_global.size(0) + num_global = partitioning.size(0) + + # Filter edges owned by this rank + local_edge_mask = partitioning[global_edge_list[:, 0]] == rank + local_edges_global = global_edge_list[local_edge_mask] + + # Build inverse map: global_id -> local_idx via scatter + g2l = torch.empty(num_global, dtype=torch.long, device=global_edge_list.device) + g2l.scatter_(0, local_vertices_global, torch.arange(num_local, device=g2l.device)) + g2l.scatter_( + 0, + halo_vertices_global, + torch.arange(num_local, num_local + num_halo, device=g2l.device), + ) + + # Remap to local numbering + local_edge_list = g2l[local_edges_global] + + return local_edge_list + + +def compute_boundary_vertices( + edge_list: torch.Tensor, + src_partitioning: torch.Tensor, + src_local_vertices_global: torch.Tensor, + rank: int, + num_ranks: int, + dst_partitioning: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Computes boundary vertices and CSR send offsets. + """ + # Fallback for homogeneous graphs + if dst_partitioning is None: + dst_partitioning = src_partitioning + + # 1. 
Filter cross-rank edges + src_rank = src_partitioning[edge_list[:, 0]] + dst_rank = dst_partitioning[edge_list[:, 1]] + cross_mask = (src_rank == rank) & (dst_rank != rank) + cross_edges = edge_list[cross_mask] + + # 2. Deduplicate (src, dst_rank) pairs + src_global = cross_edges[:, 0] + target_ranks = dst_partitioning[cross_edges[:, 1]] + + # v_src_total acts as V for homogeneous, or V_src for heterogeneous + v_src_total = src_partitioning.size(0) + encoded = target_ranks * v_src_total + src_global + unique_encoded = torch.unique(encoded) + + target_ranks_unique = unique_encoded // v_src_total + src_global_unique = unique_encoded % v_src_total + + # 3. Sort by target rank + sort_idx = torch.argsort(target_ranks_unique) + target_ranks_sorted = target_ranks_unique[sort_idx] + src_global_sorted = src_global_unique[sort_idx] + + # 4. Remap to local indices + num_local = src_local_vertices_global.size(0) + g2l = torch.empty(v_src_total, dtype=torch.long, device=edge_list.device) + g2l[src_local_vertices_global] = torch.arange(num_local, device=edge_list.device) + send_local_idx = g2l[src_global_sorted] + + # 5. 
Build CSR offsets + send_offset = torch.zeros(num_ranks + 1, dtype=torch.long, device=edge_list.device) + send_offset.scatter_add_( + 0, target_ranks_sorted + 1, torch.ones_like(target_ranks_sorted) + ) + send_offset = send_offset.cumsum(0) + + return send_local_idx, send_offset + + +def compute_comm_map(send_offset, world_size) -> torch.Tensor: + """All-gathers send counts to build comm_map: [world_size, world_size]""" + send_counts = send_offset[1:] - send_offset[:-1] + comm_map_list = [torch.zeros(world_size).long().cuda() for _ in range(world_size)] + dist.all_gather(comm_map_list, send_counts.cuda()) + comm_map = torch.stack(comm_map_list) + return comm_map + + +def compute_recv_offsets(comm_map, rank) -> tuple[torch.Tensor, torch.Tensor]: + """Returns (recv_offset, recv_backward_offset) from comm_map""" + recv_counts = comm_map[:, rank] + recv_offset = torch.zeros(comm_map.size(0) + 1, dtype=torch.long) + recv_offset[1:] = recv_counts.cpu().cumsum(0) + + recv_backward_offset = comm_map[:rank, :].sum(0) + return recv_offset, recv_backward_offset + + +def build_communication_pattern( + global_edge_list: torch.Tensor, + partitioning: torch.Tensor, + rank: int, + world_size: int, +) -> CommunicationPattern: + """ + + Args: + global_edge_list (torch.Tensor)): A tensor of shape [E, 2] + partitioning (torch.Tensor): A tensor of shape [V] + rank (int): Rank of this process + world_size (int): Total number of processes + + Returns: + CommunicationPattern + """ + local_verts = compute_local_vertices(partitioning, rank) + halo_verts = compute_halo_vertices(global_edge_list, partitioning, rank) + local_edges = compute_local_edge_list( + global_edge_list, partitioning, local_verts, halo_verts, rank + ) + send_idx, send_off = compute_boundary_vertices( + global_edge_list, partitioning, local_verts, rank, world_size + ) + comm = compute_comm_map(send_off, world_size) + recv_off, recv_back_off = compute_recv_offsets(comm, rank) + + return CommunicationPattern( + rank=rank, + 
world_size=world_size, + num_local_vertices=local_verts.size(0), + num_halo_vertices=halo_verts.size(0), + comm_map=comm, + send_local_idx=send_idx, + send_offset=send_off, + recv_offset=recv_off, + local_edge_list=local_edges, + put_forward_remote_offset=comm[:rank, :].sum(0), + put_backward_remote_offset=comm[:, :rank].sum(1), + ) diff --git a/DGraph/distributed/csrc/local_data_kernels.cuh b/DGraph/distributed/csrc/local_data_kernels.cuh index e4f58bc..65805e7 100644 --- a/DGraph/distributed/csrc/local_data_kernels.cuh +++ b/DGraph/distributed/csrc/local_data_kernels.cuh @@ -263,6 +263,23 @@ namespace Local } }; + // Add specialization + template <> + struct FloatAtomicAddOp + { + __device__ __forceinline__ void operator()(float4 *cur_addr, const float4 new_val) + { + // Cast the float4 pointer to a standard float pointer + float *addr_as_float = reinterpret_cast(cur_addr); + + // Atomically add each component individually + atomicAdd(&addr_as_float[0], new_val.x); + atomicAdd(&addr_as_float[1], new_val.y); + atomicAdd(&addr_as_float[2], new_val.z); + atomicAdd(&addr_as_float[3], new_val.w); + } + }; + template struct FloatSetOp { diff --git a/DGraph/distributed/haloExchange.py b/DGraph/distributed/haloExchange.py new file mode 100644 index 0000000..d8125cf --- /dev/null +++ b/DGraph/distributed/haloExchange.py @@ -0,0 +1,223 @@ +import torch +from typing import Optional +import torch.nn as nn +from DGraph import Communicator +from torch.autograd import Function +from DGraph.distributed.commInfo import CommunicationPattern + + +class HaloExchangeImpl(Function): + """Backend-agnostic autograd function for halo vertex feature exchange. + + Performs the inter-rank communication portion of a halo exchange. The + gather step (indexing ``x_local`` by ``send_local_idx``) is intentionally + kept *outside* this class so that PyTorch's built-in autograd handles + gradient accumulation for vertices sent to multiple ranks via + ``scatter_add_`` automatically. 
+ + Forward: + Allocates a receive buffer, then calls ``comm.put`` to deliver each + per-rank slice of ``send_buffer`` into the corresponding slot of the + remote rank's receive buffer. Uses ``comm_pattern.put_forward_remote_offset`` + as the one-sided write offset (ignored by two-sided backends). + + Backward: + Transposes the forward communication: ``send_offset`` and ``recv_offset`` + swap roles, and ``put_backward_remote_offset`` is used as the write + offset. Returns ``grad_send_buffer``; gradients for ``comm`` and + ``comm_pattern`` are ``None`` (non-differentiable). + + See Also: + ``HaloExchange`` for the user-facing wrapper. + ``DGraphMessagePassing`` for an end-to-end usage example. + ``DGraph/distributed/HaloExchangeDocument.md`` for full design details. + """ + + @staticmethod + def forward( + ctx, send_buffer, comm, comm_pattern: CommunicationPattern + ) -> torch.Tensor: + # Allocate the receive buffer + feature_dim = send_buffer.shape[1] if send_buffer.ndim == 2 else 1 + total_recv = int(comm_pattern.recv_offset[-1].item()) + + ctx.comm_pattern = comm_pattern + ctx.feature_dim = feature_dim + ctx.comm = comm + recv_buffer = comm.alloc_buffer( + (total_recv, feature_dim) if send_buffer.ndim == 2 else (total_recv,), + dtype=send_buffer.dtype, + device=send_buffer.device, + ) + # TODO: complete + send_offsets = comm_pattern.send_offset + recv_offsets = comm_pattern.recv_offset + put_forward_remote_offset = comm_pattern.put_forward_remote_offset + comm.put( + send_buffer, + recv_buffer, + send_offsets, + recv_offsets, + remote_offsets=put_forward_remote_offset, + ) + + return recv_buffer + + @staticmethod + def backward(ctx, grad_recv_buffer): + total_sent = ctx.comm_pattern.send_offset[-1].item() + feature_dim = ctx.feature_dim + comm = ctx.comm + + grad_input_tensor = comm.alloc_buffer( + (total_sent, feature_dim) if grad_recv_buffer.ndim == 2 else (total_sent,), + dtype=grad_recv_buffer.dtype, + device=grad_recv_buffer.device, + ) + send_offsets = 
ctx.comm_pattern.send_offset + recv_offsets = ctx.comm_pattern.recv_offset + put_backward_remote_offset = ctx.comm_pattern.put_backward_remote_offset + comm.put( + grad_recv_buffer, + grad_input_tensor, + recv_offsets, + send_offsets, + remote_offsets=put_backward_remote_offset, + ) + + return grad_input_tensor, None, None + + +class HaloExchange: + """Exchanges halo vertex features between ranks for distributed GNN computation. + + Wraps the full gather → communicate → return pipeline behind a single + callable. The result is autograd-compatible: gradients flow back through + the communication and through the gather indexing step automatically. + + Usage:: + + exchanger = HaloExchange(communicator) + halo_features = exchanger(local_node_features, comm_pattern) + # halo_features: [num_halo, F] — features of remote neighbour vertices + + The returned ``halo_features`` can be concatenated with ``local_node_features`` + to form the augmented subgraph consumed by a GNN layer. See + ``DGraphMessagePassing`` for a complete example. + + Args: + comm (Communicator): Initialised communicator backed by any supported + engine (NCCL, MPI, NVSHMEM). + + See Also: + ``DGraphMessagePassing`` — end-to-end message-passing wrapper. + ``commInfo.build_communication_pattern`` — how to build ``comm_pattern``. + ``DGraph/distributed/HaloExchangeDocument.md`` — full design details. + """ + + def __init__(self, comm: Communicator): + self.comm = comm + + def __call__( + self, x_local: torch.Tensor, comm_pattern: CommunicationPattern + ) -> torch.Tensor: + """Exchange boundary vertex features with neighbouring ranks. + + Args: + x_local (torch.Tensor): Local node feature matrix of shape + ``[num_local, F]``. + comm_pattern (CommunicationPattern): Precomputed communication + pattern for this rank (see ``commInfo.build_communication_pattern``). 
+ + Returns: + torch.Tensor: Halo node features of shape ``[num_halo, F]``, + one row per remote neighbour vertex ordered by + ``comm_pattern.recv_offset``. + """ + send_buffer = x_local[comm_pattern.send_local_idx] + recv_buffer = HaloExchangeImpl.apply(send_buffer, self.comm, comm_pattern) + return recv_buffer # type: ignore + + +class DGraphMessagePassing(nn.Module): + """Distributed GNN message-passing layer wrapper. + + Combines halo exchange with a user-supplied message-passing layer into a + single ``nn.Module``. This is the recommended way to build a distributed + GNN layer with DGraph. + + The forward pass performs three steps: + + 1. **Halo exchange** — fetch features of remote neighbour vertices from + other ranks via ``HaloExchange``. + 2. **Augment** — concatenate local and halo features into a single tensor + indexed by ``comm_pattern.local_edge_list``. + 3. **Message passing** — run the wrapped ``message_passing_layer`` on the + augmented local subgraph. + + The ``message_passing_layer`` must accept the following positional arguments:: + + output = message_passing_layer( + node_features, # [num_local + num_halo, F] + edge_index, # [num_local_edges, 2] (local integer indices) + edge_features, # [num_local_edges, E] or None + ) + + and must return updated features only for the *local* vertices (i.e. + ``output`` has shape ``[num_local, F_out]``). + + Usage:: + + conv = MyGNNConv(in_channels=64, out_channels=64) + layer = DGraphMessagePassing(exchanger=HaloExchange(comm), message_passing_layer=conv) + + # Inside the training loop (all tensors already on the correct device): + updated = layer(local_node_features, comm_pattern, local_edge_features) + + Args: + exchanger (HaloExchange): Initialised ``HaloExchange`` instance. + message_passing_layer (nn.Module): Any GNN conv / message-passing + module that matches the calling convention above. + + See Also: + ``HaloExchange`` — the underlying exchange primitive. 
+ ``commInfo.build_communication_pattern`` — how to build ``comm_pattern``. + """ + + def __init__(self, exchanger: HaloExchange, message_passing_layer: nn.Module): + super(DGraphMessagePassing, self).__init__() + self.message_passing_layer = message_passing_layer + self.exchanger = exchanger + + def forward( + self, + local_node_features: torch.Tensor, + comm_pattern: CommunicationPattern, + local_edge_features: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run halo exchange then message passing on the augmented local subgraph. + + Args: + local_node_features (torch.Tensor): Feature matrix for locally-owned + vertices, shape ``[num_local, F]``. + comm_pattern (CommunicationPattern): Precomputed communication pattern + for this rank. + local_edge_features (Optional[torch.Tensor]): Edge feature matrix for + local edges, shape ``[num_local_edges, E]``, or ``None`` if the + underlying layer does not use edge features. + + Returns: + torch.Tensor: Updated feature matrix for locally-owned vertices, + shape ``[num_local, F_out]``, where ``F_out`` is determined by + ``message_passing_layer``. 
+ """ + + halo_node_features = self.exchanger(local_node_features, comm_pattern) + + local_subgraph = torch.cat([local_node_features, halo_node_features], dim=0) + + local_updated_node_features = self.message_passing_layer( + local_subgraph, comm_pattern.local_edge_list, local_edge_features + ) + + return local_updated_node_features diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index 2e02a82..f69c1a5 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -26,7 +26,7 @@ from torch.autograd import Function from DGraph.utils import largest_split -from typing import overload +from typing import overload, List TIMINGS = {"Gather_Index_Forward": [], "Gather_Forward_Local": []} @@ -254,6 +254,57 @@ def destroy(self) -> None: # dist.destroy_process_group() NCCLBackendEngine._is_initialized = False + def _get_splits(self, send_offset, recv_offset) -> tuple[List[int], List[int]]: + """ + Return (send_splits, recv_splits) as plain Python lists of ints. + + send_splits[i] = number of *vertices* this rank sends to rank i. + recv_splits[i] = number of *vertices* this rank receives from rank i. + + These are in vertex units; the caller must multiply by feature_dim + before passing to all_to_all_single. 
+ """ + + send_off = send_offset + recv_off = recv_offset + + send_splits = (send_off[1:] - send_off[:-1]).tolist() + recv_splits = (recv_off[1:] - recv_off[:-1]).tolist() + + return (send_splits, recv_splits) + + @staticmethod + def _scale_splits(splits: List[int], factor: int) -> List[int]: + """Multiply each split count by a scalar (feature dimension).""" + return [s * factor for s in splits] + + def put( + self, + send_buffer: torch.Tensor, + recv_buffer: torch.Tensor, + send_offsets: torch.Tensor, + recv_offsets: torch.Tensor, + remote_offsets: torch.Tensor | None = None, + ) -> None: + _ = remote_offsets # remote_offsets not needed in 2-sided semantices + + send_splits, recv_splits = self._get_splits( + send_offset=send_offsets, recv_offset=recv_offsets + ) + feature_dim = send_buffer.shape[1] if send_buffer.ndim == 2 else 1 + + # all_to_all_single operates on flat views; split sizes are in + # element counts, so scale vertex counts by feature_dim. + send_flat = send_buffer.contiguous().view(-1) + recv_flat = recv_buffer.contiguous().view(-1) + + dist.all_to_all_single( + output=recv_flat, + input=send_flat, + output_split_sizes=self._scale_splits(recv_splits, feature_dim), + input_split_sizes=self._scale_splits(send_splits, feature_dim), + ) + def finalize(self) -> None: if NCCLBackendEngine._is_initialized: dist.barrier() diff --git a/DGraph/utils/TimingReport.py b/DGraph/utils/TimingReport.py index 2f8258f..48b94a8 100644 --- a/DGraph/utils/TimingReport.py +++ b/DGraph/utils/TimingReport.py @@ -1,6 +1,5 @@ import torch -import torch.distributed as dist -from DGraph.Communicator import Communicator +from typing import Optional class TimingReport: @@ -10,13 +9,30 @@ class TimingReport: _communicator = None _is_initialized = False - def __init__( - self, - ): - pass + def __init__(self, name: Optional[str] = None): + """ + Initialize the instance, optionally with a name for use as a context manager. 
+ """ + self.name = name + + def __enter__(self): + """Start the timer when entering the 'with' block.""" + if self.name is None: + raise ValueError( + "A name must be provided to use TimingReport as a context manager." + ) + self.start(self.name) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop the timer when exiting the 'with' block.""" + if self.name is not None: + self.stop(self.name) + # Returning False ensures any exceptions raised inside the 'with' block are not suppressed. + return False @staticmethod - def init(communicator: Communicator): + def init(communicator): """Initialize the TimingReport with a communicator.""" if TimingReport._is_initialized: raise RuntimeError("TimingReport is already initialized.") diff --git a/experiments/OGB/GCN.py b/experiments/OGB/GCN.py index d14bcdd..2b3cab4 100644 --- a/experiments/OGB/GCN.py +++ b/experiments/OGB/GCN.py @@ -13,20 +13,58 @@ # SPDX-License-Identifier: (Apache-2.0) import torch import torch.nn as nn -import torch.distributed as dist + +from DGraph.distributed import HaloExchange, CommunicationPattern from DGraph.utils.TimingReport import TimingReport +from DGraph import Communicator +import torch.distributed as dist +import sys -class ConvLayer(nn.Module): - def __init__(self, in_channels, out_channels): - super(ConvLayer, self).__init__() - self.conv = nn.Linear(in_channels, out_channels) +import torch +import torch.nn as nn + + +class GraphConvLayer(nn.Module): + def __init__(self, message_dim, out_channels): + super(GraphConvLayer, self).__init__() + self.conv = nn.Linear(message_dim, out_channels) self.act = nn.ReLU(inplace=True) - def forward(self, x): - x = self.conv(x) - x = self.act(x) - return x + def forward(self, x, edge_index, num_local_nodes, edge_features=None): + source_vertices = edge_index[:, 0] + target_vertices = edge_index[:, 1] + + assert (source_vertices < num_local_nodes).all(), ( + f"Graph routing error: Found source_vertices >= num_local_nodes 
({num_local_nodes}). " + "Boundary nodes must only act as targets (x_j) in this aggregation scheme!" + ) + + x_i = x[source_vertices, :] + x_j = x[target_vertices, :] + + if edge_features is not None: + x_ij = torch.cat([x_i, x_j, edge_features], dim=1) + else: + x_ij = torch.cat([x_i, x_j], dim=1) + + m_ij = self.conv(x_ij) + m_ij = self.act(m_ij) + + out_channels = m_ij.size(1) + + # 1. Allocate ONLY for the local nodes (the sources we are aggregating to) + out = torch.zeros(num_local_nodes, out_channels, dtype=x.dtype, device=x.device) + + # 2. Scatter messages back to the SOURCE vertices + scatter_index = ( + source_vertices.unsqueeze(-1).expand(-1, out_channels).to(x.device) + ) + + # 3. Perform the aggregation + out = out.scatter_add(0, scatter_index, m_ij) + + return out class CommAwareGCN(nn.Module): @@ -35,61 +73,46 @@ class CommAwareGCN(nn.Module): but good enough for the purpose of testing. """ - def __init__(self, in_channels, hidden_dims, num_classes, comm): + def __init__( + self, + in_channels: int, + hidden_dims: int, + num_classes: int, + halo_exchanger: HaloExchange, + comm: Communicator, + ): super(CommAwareGCN, self).__init__() + self.halo_exchanger = halo_exchanger + + self.conv1 = GraphConvLayer(2 * in_channels, hidden_dims) + + self.conv2 = GraphConvLayer(2 * hidden_dims, hidden_dims) - self.conv1 = ConvLayer(in_channels, hidden_dims) - self.conv2 = ConvLayer(hidden_dims, hidden_dims) self.fc = nn.Linear(hidden_dims, num_classes) self.softmax = nn.Softmax(dim=1) self.comm = comm def forward( - self, - node_features, - edge_index, - rank_mapping, - gather_cache=None, - scatter_cache=None, + self, local_node_features: torch.Tensor, comm_pattern: CommunicationPattern ): - num_local_nodes = node_features.size(1) - _src_indices = edge_index[:, 0, :] - _dst_indices = edge_index[:, 1, :] - TimingReport.start("pre-processing") - _src_rank_mappings = torch.cat( - [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0 - ) - 
_dst_rank_mappings = torch.cat( - [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0 - ) - TimingReport.stop("pre-processing") - TimingReport.start("Gather_1") - x = self.comm.gather( - node_features, _dst_indices, _dst_rank_mappings, cache=gather_cache - ) - TimingReport.stop("Gather_1") - TimingReport.start("Conv_1") - x = self.conv1(x) - TimingReport.stop("Conv_1") - TimingReport.start("Scatter_1") - x = self.comm.scatter( - x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache - ) - TimingReport.stop("Scatter_1") - TimingReport.start("Gather_2") - x = self.comm.gather(x, _dst_indices, _dst_rank_mappings, cache=gather_cache) - TimingReport.stop("Gather_2") - TimingReport.start("Conv_2") - x = self.conv2(x) - TimingReport.stop("Conv_2") - TimingReport.start("Scatter_2") - x = self.comm.scatter( - x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache - ) - TimingReport.stop("Scatter_2") - TimingReport.start("Final_FC") - x = self.fc(x) - TimingReport.stop("Final_FC") + num_local_nodes = local_node_features.shape[0] + + with TimingReport("feature-exchange-1"): + boundary_features = self.halo_exchanger(local_node_features, comm_pattern) + + with TimingReport("process-1"): + x = torch.cat([local_node_features, boundary_features], dim=0) + x = self.conv1(x, comm_pattern.local_edge_list, num_local_nodes) + + with TimingReport("feature-exchange-2"): + boundary_features = self.halo_exchanger(x, comm_pattern) + + with TimingReport("process-2"): + x = torch.cat([x, boundary_features], dim=0) + x = self.conv2(x, comm_pattern.local_edge_list, num_local_nodes) + + with TimingReport("final-fc"): + x = self.fc(x) # x = self.softmax(x) return x diff --git a/experiments/OGB/main.py b/experiments/OGB/main.py index 8deb383..5535c16 100644 --- a/experiments/OGB/main.py +++ b/experiments/OGB/main.py @@ -11,312 +11,214 @@ # https://github.com/LBANN and https://github.com/LLNL/LBANN. 
# # SPDX-License-Identifier: (Apache-2.0) -import sys -from time import perf_counter +""" +Distributed GCN benchmark on OGB node-property-prediction datasets. +""" +import os +import json from typing import Optional -from DGraph.data.datasets import DistributedOGBWrapper -from DGraph.Communicator import CommunicatorBase, Communicator -from DGraph.distributed.nccl._nccl_cache import ( - NCCLGatherCacheGenerator, - NCCLScatterCacheGenerator, -) import fire +import numpy as np import torch -import torch.optim as optim import torch.distributed as dist +import torch.optim as optim from torch.nn.parallel import DistributedDataParallel as DDP + +from DGraph.Communicator import Communicator +from DGraph.distributed import HaloExchange from DGraph.utils.TimingReport import TimingReport + from GCN import CommAwareGCN as GCN +from ogb_comm_dataset import DGraphOGBDataset from utils import ( + calculate_accuracy, + cleanup, dist_print_ephemeral, make_experiment_log, - write_experiment_log, - cleanup, - visualize_trajectories, safe_create_dir, - calculate_accuracy, + visualize_trajectories, + write_experiment_log, ) -import numpy as np -import os -import json -class SingleProcessDummyCommunicator(CommunicatorBase): - def __init__(self): - super().__init__() - self._rank = 0 - self._world_size = 1 - self._is_initialized = True - self.backend = "single" - - def get_rank(self): - return self._rank - - def get_world_size(self): - return self._world_size - - def scatter( - self, - tensor: torch.Tensor, - src: torch.Tensor, - rank_mappings, - num_local_nodes, - **kwargs, - ): - # TODO: Wrap this in the datawrapper class - src = src.unsqueeze(-1).expand(1, -1, tensor.shape[-1]) - out = torch.zeros(1, num_local_nodes, tensor.shape[-1]).to(tensor.device) - out.scatter_add(1, src, tensor) - return out - - def gather(self, tensor, dst, rank_mappings, **kwargs): - # TODO: Wrap this in the datawrapper class - dst = dst.unsqueeze(-1).expand(1, -1, tensor.shape[-1]) - out = 
torch.gather(tensor, 1, dst)
-        return out
-
-    def __str__(self) -> str:
-        return self.backend
-
-    def rank_cuda_device(self):
-        device = torch.cuda.current_device()
-        return device
-
-    def barrier(self):
-        # No-op for single process
-        pass
+# ---------------------------------------------------------------------------
+# Training / evaluation loop
+# ---------------------------------------------------------------------------
 
 
 def _run_experiment(
-    dataset,
-    comm,
+    dataset: DGraphOGBDataset,
+    comm: Communicator,
     lr: float,
     epochs: int,
     log_prefix: str,
-    in_dim: int = 128,
-    hidden_dims: int = 128,
+    hidden_dims: int = 256,
     num_classes: int = 40,
-    use_cache: bool = False,
-    dset_name: str = "arxiv",
+    device: str | torch.device = "cuda",
+    rank: int = 0,
+    local_rank: int = 0,
 ):
-    local_rank = comm.get_rank() % torch.cuda.device_count()
-    print(f"Rank: {local_rank} Local Rank: {local_rank}")
-    torch.cuda.set_device(local_rank)
-    device = torch.cuda.current_device()
-    model = GCN(
-        in_channels=in_dim, hidden_dims=hidden_dims, num_classes=num_classes, comm=comm
-    )
-    rank = comm.get_rank()
-    model = model.to(device)
+    """Run one full training + validation + test experiment.
 
-    model = (
-        DDP(model, device_ids=[local_rank], output_device=local_rank)
-        if comm.get_world_size() > 1
-        else model
-    )
-    optimizer = optim.Adam(model.parameters(), lr=lr)
+    Args:
+        dataset: Loaded ``DGraphOGBDataset`` for this rank.
+        device: Device to run the model on (e.g. ``cuda:0``).
+        comm: Initialised communicator.
+        lr: Adam learning rate.
+        epochs: Number of training epochs.
+        log_prefix: Path prefix for all output log files.
+        hidden_dims: Hidden layer width for the GCN.
+        num_classes: Number of output classes.
 
-    stream = torch.cuda.Stream()
+    Returns:
+        Tuple of (training_loss, validation_loss, validation_accuracy) numpy arrays,
+        each of length ``epochs``.
+ """ + + # ---- Extract local data from the dataset -------------------------------- + comm_pattern = dataset.comm_pattern + + local_node_features, local_labels, comm_pattern = dataset[0] + local_node_features = local_node_features.to(device) + local_labels = local_labels.to(device) - node_features, edge_indices, rank_mappings, labels = dataset[0] + in_dim = local_node_features.shape[1] - node_features = node_features.to(device).unsqueeze(0) - edge_indices = edge_indices.to(device).unsqueeze(0) - labels = labels.to(device).unsqueeze(0) - rank_mappings = rank_mappings + local_masks = dataset.get_masks() + train_mask = local_masks["train_mask"].to(device) + val_mask = local_masks["val_mask"].to(device) + test_mask = local_masks["test_mask"].to(device) if rank == 0: - print("*" * 80) - for i in range(comm.get_world_size()): - if i == rank: - print(f"Rank: {rank} Mapping: {rank_mappings.shape}") - print(f"Rank: {rank} Node Features: {node_features.shape}") - print(f"Rank: {rank} Edge Indices: {edge_indices.shape}") + print( + f"Dataset loaded — " + f"local nodes: {local_node_features.shape[0]}, " + f"in_dim: {in_dim}, " + f"num_classes: {num_classes}" + ) - comm.barrier() + # ---- Model setup -------------------------------------------------------- + halo_exchanger = HaloExchange(comm) + model = GCN( + in_channels=in_dim, + hidden_dims=hidden_dims, + num_classes=num_classes, + halo_exchanger=halo_exchanger, + comm=comm, + ).to(device) + + if comm.get_world_size() > 1: + model = DDP(model, device_ids=[local_rank], output_device=local_rank) + + optimizer = optim.Adam(model.parameters(), lr=lr) criterion = torch.nn.CrossEntropyLoss() - train_mask = dataset.graph_obj.get_local_mask("train", rank) - validation_mask = dataset.graph_obj.get_local_mask("val", rank) + stream = torch.cuda.Stream() + + # ---- Training loop ------------------------------------------------------ training_loss_scores = [] validation_loss_scores = [] validation_accuracy_scores = [] + training_times 
= [] - world_size = comm.get_world_size() - - print(f"Rank: {rank} training_mask: {train_mask.shape}") - print(f"Rank: {rank} validation_mask: {validation_mask.shape}") - - gather_cache = None - scatter_cache = None - - if use_cache: - print(f"Rank: {rank} Using Cache. Generating Cache") - start_time = perf_counter() - src_indices = edge_indices[:, 0, :] - dst_indices = edge_indices[:, 1, :] - - # This says where the edges are located - edge_placement = rank_mappings[0] - - cache_prefix = f"cache/{dset_name}" - scatter_cache_file = f"{cache_prefix}_scatter_cache_{world_size}_{rank}.pt" - gather_cache_file = f"{cache_prefix}_gather_cache_{world_size}_{rank}.pt" - - if os.path.exists(gather_cache_file): - gather_cache = torch.load(gather_cache_file, weights_only=False) - - if os.path.exists(scatter_cache_file): - scatter_cache = torch.load(scatter_cache_file, weight_only=False) - - # These say where the source and destination nodes are located - edge_src_placement = rank_mappings[ - 0 - ] # Redundant but making explicit for clarity - edge_dest_placement = rank_mappings[1] - - num_input_rows = node_features.size(1) - local_num_edges = (edge_placement == rank).sum().item() - - if gather_cache is None: - gather_cache = NCCLGatherCacheGenerator( - dst_indices, - edge_placement, - edge_dest_placement, - num_input_rows, - rank, - world_size, - ) - with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f: - torch.save(gather_cache, f) - - if scatter_cache is None: - nodes_per_rank = dataset.graph_obj.get_nodes_per_rank() - - scatter_cache = NCCLScatterCacheGenerator( - src_indices, - edge_placement, - edge_src_placement, - nodes_per_rank[rank], - rank, - world_size, - ) - with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f: - torch.save(scatter_cache, f) - - # Sanity checks for the cache - for key, value in gather_cache.gather_send_local_placement.items(): - assert value.max().item() < num_input_rows - assert key < world_size - assert 
key != rank - assert value.shape[0] == gather_cache.gather_send_comm_vector[key] - - for key, value in gather_cache.gather_recv_local_placement.items(): - assert value.max().item() < local_num_edges - assert key < world_size - assert key != rank - assert value.shape[0] == gather_cache.gather_recv_comm_vector[key] - - for rank, value in scatter_cache.gather_send_local_placement.items(): - assert value.max().item() < local_num_edges - assert rank < world_size - assert rank != rank - assert value.shape[0] == scatter_cache.gather_send_comm_vector - - for rank, value in scatter_cache.gather_recv_local_placement.items(): - assert value.max().item() < num_input_rows - assert rank < world_size - assert rank != rank - assert value.shape[0] == scatter_cache.gather_recv_comm_vector - end_time = perf_counter() - print(f"Rank: {rank} Cache Generation Time: {end_time - start_time:.4f} s") - - # with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f: - # torch.save(gather_cache, f) - # with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f: - # torch.save(scatter_cache, f) - # print(f"Rank: {rank} Cache Generated") + make_experiment_log(f"{log_prefix}_training_loss.log", rank) + make_experiment_log(f"{log_prefix}_validation_loss.log", rank) + make_experiment_log(f"{log_prefix}_validation_accuracy.log", rank) - training_times = [] - for i in range(epochs): + for epoch in range(epochs): + model.train() comm.barrier() torch.cuda.synchronize() + start_time = torch.cuda.Event(enable_timing=True) end_time = torch.cuda.Event(enable_timing=True) start_time.record(stream) + optimizer.zero_grad() - _output = model( - node_features, edge_indices, rank_mappings, gather_cache, scatter_cache - ) - # Must flatten along the batch dimension for the loss function - output = _output[:, train_mask].view(-1, num_classes) - gt = labels[:, train_mask].view(-1) - loss = criterion(output, gt) - loss.backward() - dist_print_ephemeral(f"Epoch {i} \t Loss: 
{loss.item()}", rank) + + with TimingReport("forward"): + output = model(local_node_features, comm_pattern) + + train_output = output[train_mask] + train_labels = local_labels[train_mask] + loss = criterion(train_output, train_labels.reshape(-1)) + + with TimingReport("backward"): + loss.backward() + optimizer.step() comm.barrier() end_time.record(stream) torch.cuda.synchronize() - training_times.append(start_time.elapsed_time(end_time)) + + elapsed_ms = start_time.elapsed_time(end_time) + training_times.append(elapsed_ms) training_loss_scores.append(loss.item()) + + dist_print_ephemeral( + f"Epoch {epoch:4d} | loss: {loss.item():.4f} | {elapsed_ms:.1f} ms", + rank, + ) write_experiment_log(str(loss.item()), f"{log_prefix}_training_loss.log", rank) + # ---- Validation ----------------------------------------------------- model.eval() with torch.no_grad(): - validation_preds = _output[:, validation_mask].view(-1, num_classes) - label_validation = labels[:, validation_mask].view(-1) - validation_score = criterion( - validation_preds, - label_validation, - ) - write_experiment_log( - str(validation_score.item()), f"{log_prefix}_validation_loss.log", rank - ) - - validation_loss_scores.append(validation_score.item()) - - val_pred = torch.log_softmax(validation_preds, dim=1) - accuracy = calculate_accuracy(val_pred, label_validation) - validation_accuracy_scores.append(accuracy) - write_experiment_log( - f"Validation Accuracy: {accuracy:.2f}", - f"{log_prefix}_validation_accuracy.log", - rank, - ) - model.train() + val_output = output[val_mask] + val_labels = local_labels[val_mask] + val_loss = criterion(val_output, val_labels.reshape(-1)) + val_preds = torch.log_softmax(val_output, dim=1) + val_accuracy = calculate_accuracy(val_preds, val_labels) + + validation_loss_scores.append(val_loss.item()) + validation_accuracy_scores.append(val_accuracy) + write_experiment_log( + str(val_loss.item()), f"{log_prefix}_validation_loss.log", rank + ) + write_experiment_log( + 
f"Validation Accuracy: {val_accuracy:.2f}", + f"{log_prefix}_validation_accuracy.log", + rank, + ) torch.cuda.synchronize() + # ---- Test evaluation ---------------------------------------------------- model.eval() - with torch.no_grad(): - test_idx = dataset.graph_obj.get_local_mask("test", rank) - test_labels = labels[:, test_idx].view(-1) - test_preds = model(node_features, edge_indices, rank_mappings)[:, test_idx] - test_preds = test_preds.view(-1, num_classes) - test_loss = criterion(test_preds, test_labels) - test_preds = torch.log_softmax(test_preds, dim=1) - test_accuracy = calculate_accuracy(test_preds, test_labels) - test_log_file = f"{log_prefix}_test_results.log" - write_experiment_log( - "loss,accuracy", - test_log_file, - rank, + test_output = model(local_node_features, comm_pattern) + test_preds = test_output[test_mask] + test_labels = local_labels[test_mask] + test_loss = criterion(test_preds, test_labels.reshape(-1)) + test_preds_log = torch.log_softmax(test_preds, dim=1) + test_accuracy = calculate_accuracy(test_preds_log, test_labels) + + test_log_file = f"{log_prefix}_test_results.log" + make_experiment_log(test_log_file, rank) + write_experiment_log("loss,accuracy", test_log_file, rank) + write_experiment_log(f"{test_loss.item()},{test_accuracy}", test_log_file, rank) + + if rank == 0: + print( + f"\nTest | loss: {test_loss.item():.4f} | accuracy: {test_accuracy:.2f}%" ) - write_experiment_log(f"{test_loss.item()},{test_accuracy}", test_log_file, rank) + # ---- Timing summary ----------------------------------------------------- make_experiment_log(f"{log_prefix}_training_times.log", rank) - make_experiment_log(f"{log_prefix}_runtime_experiment.log", rank) + for t in training_times: + write_experiment_log(str(t), f"{log_prefix}_training_times.log", rank) - for times in training_times: - write_experiment_log(str(times), f"{log_prefix}_training_times.log", rank) - - average_time = np.mean(training_times[1:]) - log_str = f"Average time per epoch: 
{average_time:.4f} ms" - write_experiment_log(log_str, f"{log_prefix}_runtime_experiment.log", rank) + average_time = ( + np.mean(training_times[1:]) if len(training_times) > 1 else training_times[0] + ) + make_experiment_log(f"{log_prefix}_runtime_experiment.log", rank) + write_experiment_log( + f"Average time per epoch (excl. first): {average_time:.4f} ms", + f"{log_prefix}_runtime_experiment.log", + rank, + ) return ( np.array(training_loss_scores), @@ -325,103 +227,136 @@ def _run_experiment( ) +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + def main( - backend: str = "single", + backend: str = "nccl", dataset: str = "arxiv", - epochs: int = 3, + epochs: int = 10, lr: float = 0.001, runs: int = 1, + hidden_dims: int = 256, log_dir: str = "logs", node_rank_placement_file: Optional[str] = None, - use_cache: bool = False, + root_dir: Optional[str] = None, ): - _communicator = backend.lower() - dset_name = dataset - assert _communicator.lower() in [ - "single", + """Distributed GCN benchmark on OGB node-property-prediction datasets. + + Args: + backend: Communication backend — one of ``nccl``, ``mpi``, ``nvshmem``. + dataset: OGB dataset name — one of ``arxiv``, ``products``. + epochs: Number of training epochs per run. + lr: Adam learning rate. + runs: Number of independent runs (for mean/std reporting). + hidden_dims: Hidden layer width for the GCN. + log_dir: Directory to write log files and plots. + node_rank_placement_file: Path to a ``.pt`` file containing a [V] + int64 tensor mapping each global vertex to its assigned rank. + Required for all distributed backends. 
+ """ + assert backend.lower() in ( "nccl", - "nvshmem", "mpi", - ], "Invalid backend" - - in_dims = {"arxiv": 128, "products": 100} - - assert dataset in ["arxiv", "products"], "Invalid dataset" - - node_rank_placement = None - if _communicator.lower() == "single": - # Dummy communicator for single process testing - comm = SingleProcessDummyCommunicator() - - else: - if not dist.is_initialized(): - dist.init_process_group(backend="nccl") - comm = Communicator.init_process_group(_communicator) - - # Must pass the node rank placement file the first time - if node_rank_placement_file is not None: - assert os.path.exists( - node_rank_placement_file - ), "Node rank placement file not found" - node_rank_placement = torch.load( - node_rank_placement_file, weights_only=False - ) + "nvshmem", + ), f"Unsupported backend '{backend}'. Choose from: nccl, mpi, nvshmem." + assert dataset in ( + "arxiv", + "products", + ), f"Unsupported dataset '{dataset}'. Choose from: arxiv, products." + + num_classes = {"arxiv": 40, "products": 47}[dataset] + + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + comm = Communicator.init_process_group(backend.lower()) + + rank = comm.get_rank() + world_size = comm.get_world_size() + + rank = comm.get_rank() + local_rank = rank % torch.cuda.device_count() + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + TimingReport.init(comm) - safe_create_dir(log_dir, comm.get_rank()) - training_dataset = DistributedOGBWrapper( - f"ogbn-{dataset}", - comm, + safe_create_dir(log_dir, rank) + + # ---- Node rank placement ------------------------------------------------ + assert node_rank_placement_file is not None, ( + "--node_rank_placement_file is required. " + "Generate one with preprocess.py before running this script." 
+ ) + assert os.path.exists( + node_rank_placement_file + ), f"Node rank placement file not found: {node_rank_placement_file}" + node_rank_placement = torch.load(node_rank_placement_file, weights_only=False) + + # ---- Dataset ------------------------------------------------------------ + training_dataset = DGraphOGBDataset( + dname=f"ogbn-{dataset}", + comm=comm, node_rank_placement=node_rank_placement, - force_reprocess=True, + root_dir=root_dir, ) - num_classes = training_dataset.num_classes - - training_trajectores = np.zeros((runs, epochs)) - validation_trajectores = np.zeros((runs, epochs)) + # ---- Runs --------------------------------------------------------------- + training_trajectories = np.zeros((runs, epochs)) + validation_trajectories = np.zeros((runs, epochs)) validation_accuracies = np.zeros((runs, epochs)) - world_size = comm.get_world_size() - for i in range(runs): - log_prefix = f"{log_dir}/{dataset}_{world_size}_cache={use_cache}_run_{i}" - training_traj, val_traj, val_accuracy = _run_experiment( - training_dataset, - comm, - lr, - epochs, - log_prefix, - use_cache=use_cache, + + for run in range(runs): + if rank == 0: + print(f"\n{'='*60}") + print(f"Run {run + 1}/{runs}") + print(f"{'='*60}") + + log_prefix = f"{log_dir}/{dataset}_world{world_size}_run{run}" + train_loss, val_loss, val_acc = _run_experiment( + dataset=training_dataset, + comm=comm, + lr=lr, + epochs=epochs, + log_prefix=log_prefix, + hidden_dims=hidden_dims, num_classes=num_classes, - dset_name=dset_name, - in_dim=in_dims[dset_name], + device=device, + rank=rank, + local_rank=local_rank, ) - training_trajectores[i] = training_traj - validation_trajectores[i] = val_traj - validation_accuracies[i] = val_accuracy + training_trajectories[run] = train_loss + validation_trajectories[run] = val_loss + validation_accuracies[run] = val_acc + # ---- Timing report ------------------------------------------------------ write_experiment_log( json.dumps(TimingReport._timers), - 
f"{log_dir}/{dset_name}_timing_report_world_size_{world_size}_cache_{use_cache}.json", - comm.get_rank(), + f"{log_dir}/{dataset}_timing_report_world{world_size}.json", + rank, ) + # ---- Plots -------------------------------------------------------------- visualize_trajectories( - training_trajectores, + training_trajectories, "Training Loss", f"{log_dir}/training_loss.png", - comm.get_rank(), + rank, ) visualize_trajectories( - validation_trajectores, + validation_trajectories, "Validation Loss", f"{log_dir}/validation_loss.png", - comm.get_rank(), + rank, ) visualize_trajectories( validation_accuracies, "Validation Accuracy", f"{log_dir}/validation_accuracy.png", - comm.get_rank(), + rank, ) + cleanup() diff --git a/experiments/OGB/ogb_comm_dataset.py b/experiments/OGB/ogb_comm_dataset.py new file mode 100644 index 0000000..f739a49 --- /dev/null +++ b/experiments/OGB/ogb_comm_dataset.py @@ -0,0 +1,174 @@ +from typing import Optional, Tuple +import torch +from torch.utils.data import Dataset +from ogb.nodeproppred import NodePropPredDataset +from DGraph.Communicator import CommunicatorBase +from DGraph.distributed import CommunicationPattern, build_communication_pattern +from DGraph.data.graph import get_round_robin_node_rank_map + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_local_split_masks( + node_rank_placement: torch.Tensor, + split_idx: dict, + rank: int, +) -> dict[str, torch.Tensor]: + """Convert global OGB split indices into boolean masks over *local* nodes. + + Args: + node_rank_placement: [V] tensor mapping each global vertex to its rank. + split_idx: dict with keys 'train', 'valid', 'test', each a 1-D tensor + of global node indices (as returned by ogb's ``get_idx_split``). + rank: this process's rank. 
+ + Returns: + Dict with keys 'train', 'valid', 'test', each a boolean tensor of + shape [num_local] that is True for local nodes belonging to that split. + """ + V = node_rank_placement.shape[0] + local_node_global_ids = torch.where(node_rank_placement == rank)[0] + + masks = {} + for split_name, global_ids in split_idx.items(): + global_mask = torch.zeros(V, dtype=torch.bool) + global_mask[global_ids] = True + masks[split_name] = global_mask[local_node_global_ids] + return masks + + +def generate_communication_pattern( + edge_index: torch.Tensor, + node_rank_placement: torch.Tensor, + rank: int, + world_size: int, +) -> CommunicationPattern: + comm_pattern = build_communication_pattern( + edge_index, node_rank_placement, rank, world_size + ) + return comm_pattern + + +class DGraphOGBDataset(Dataset): + def __init__( + self, + dname: str, + comm: CommunicatorBase, + node_rank_placement: Optional[torch.Tensor] = None, + root_dir: Optional[str] = None, + *args, + **kwargs, + ) -> None: + """ + Args: + dname (str): Name of the dataset + comm (CommunicatorBase): Communicator object + node_rank_placement (torch.Tensor): Node rank placement, where node_rank_placement[i] is the rank of the node i + *args: + **kwargs: + + """ + super().__init__() + self.comm_object = comm + self.rank = comm.get_rank() + self.world_size = comm.get_world_size() + + comm.barrier() + + if self.rank == 0: + # Load dataset on rank 0 first + self.dataset = NodePropPredDataset( + name=dname, root=root_dir if root_dir else "dataset" + ) + + comm.barrier() + + # Load dataset on all other ranks + if self.rank != 0: + self.dataset = NodePropPredDataset( + name=dname, root=root_dir if root_dir else "dataset" + ) + + comm.barrier() + + graph_data, labels = self.dataset[0] + split_idx = self.dataset.get_idx_split() + + num_nodes = graph_data["num_nodes"] + node_features = torch.from_numpy(graph_data["node_feat"]).float() + edge_index = torch.from_numpy(graph_data["edge_index"]).long().T + labels = 
torch.from_numpy(labels).long() + + if node_rank_placement is None: + node_rank_placement = get_round_robin_node_rank_map( + num_nodes, self.world_size + ) + + self.comm_pattern = generate_communication_pattern( + edge_index, node_rank_placement, self.rank, self.world_size + ) + + local_nodes = node_rank_placement == self.rank + local_node_features = node_features[local_nodes, :] + local_labels = labels[local_nodes] + self.local_node_features = local_node_features + self.local_labels = local_labels + + rank = comm.get_rank() + assert split_idx is not None + + local_masks = _build_local_split_masks(node_rank_placement, split_idx, rank) + self.train_mask = local_masks["train"] + self.val_mask = local_masks["valid"] + self.test_mask = local_masks["test"] + + def get_masks(self): + local_masks = { + "train_mask": self.train_mask, + "val_mask": self.val_mask, + "test_mask": self.test_mask, + } + return local_masks + + def __len__(self) -> int: + return 1 + + def __getitem__( + self, index + ) -> Tuple[torch.Tensor, torch.Tensor, CommunicationPattern]: + return ( + self.local_node_features, + self.local_labels, + self.comm_pattern, + ) + + +if __name__ == "__main__": + from DGraph.Communicator import Communicator + + comm = Communicator.init_process_group("nccl") + + rank = comm.get_rank() + local_rank = rank % torch.cuda.device_count() + world_size = comm.get_world_size() + torch.cuda.set_device(local_rank) + + node_rank_placement = torch.load( + f"/p/vast1/zaman2/matrix/DGraph/experiments/OGB/ogbn-arxiv-mappings/ogbn-arxiv_vertex_rank_mapping_{world_size}.pt" + ) + dataset = DGraphOGBDataset( + dname="ogbn-arxiv", comm=comm, node_rank_placement=node_rank_placement + ) + + data, labels, comm_pattern = dataset[0] + if rank == 0: + import os + + print(comm_pattern.comm_map) + file_path = os.path.abspath(__file__) + # Get the directory containing the current script + file_dir = os.path.dirname(file_path) + print(f"Saving to {file_dir}/comm_map_{world_size}.pt") + 
torch.save(comm_pattern.comm_map, f"{file_dir}/comm_map_{world_size}.pt") diff --git a/tests/test_comm_info.py b/tests/test_comm_info.py new file mode 100644 index 0000000..06ebe43 --- /dev/null +++ b/tests/test_comm_info.py @@ -0,0 +1,689 @@ +# Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +""" +Tests for DGraph.distributed.commInfo + +Single-process tests (no dist required): + Run with: python -m pytest tests/test_comm_info.py + +Distributed tests (require 2 GPUs): + Run with: torchrun --nnodes 1 --nproc-per-node 2 -m pytest tests/test_comm_info.py + +Test graphs +----------- +Homogeneous (4 vertices, 2 ranks): + + Vertices : 0, 1, 2, 3 + Partitioning: [0, 0, 1, 1] (rank 0 → {0,1}, rank 1 → {2,3}) + Edges (undirected → stored as directed pairs): + 0↔1 local on rank 0 + 0↔2 cross rank + 1↔3 cross rank + 2↔3 local on rank 1 + + Rank 0 expected: + local verts : [0, 1] + halo verts : [2, 3] + local edges : (0,1),(1,0),(0,2),(1,3) in local numbering + send_local_idx : [0, 1] (verts 0 and 1 sent to rank 1) + send_offset : [0, 0, 2] + recv_offset : [0, 0, 2] + + Rank 1 expected: + local verts : [2, 3] + halo verts : [0, 1] + local edges : (0,1),(1,0),(0,2),(1,3) in local numbering (2→0,3→1,0→2,1→3) + send_local_idx : [0, 1] (local indices of verts 2 and 3, sent to rank 0) + send_offset : [0, 2, 2] + recv_offset : [0, 2, 2] + + comm_map = [[0, 2], + [2, 0]] + +Heterogeneous (V_src=3, V_dst=4, 2 ranks): + + src_partitioning = [0, 0, 1] (rank 0 → src{0,1}, rank 1 → 
src{2}) + dst_partitioning = [0, 0, 1, 1] (rank 0 → dst{0,1}, rank 1 → dst{2,3}) + Edges (src_class → dst_class): + (0,0), (0,2), (1,1), (1,3), (2,0), (2,2) + + Rank 0: + halo dst verts : [2, 3] (cross edges (0,2),(1,3)) + boundary src verts: {0,1} → send_local_idx=[0,1], send_offset=[0,0,2] + Rank 1: + halo dst verts : [0] (cross edge (2,0)) + boundary src verts: {2} → send_local_idx=[0], send_offset=[0,1,1] +""" + +import pytest +import torch +import torch.distributed as dist + +from DGraph.distributed.commInfo import ( + CommunicationPattern, + compute_local_vertices, + compute_halo_vertices, + compute_local_edge_list, + compute_boundary_vertices, + compute_comm_map, + compute_recv_offsets, + build_communication_pattern, +) + +# --------------------------------------------------------------------------- +# Shared graph tensors (CPU, used by single-process tests) +# --------------------------------------------------------------------------- + +# fmt: off +HOMO_EDGE_LIST = torch.tensor([ + [0, 1], [1, 0], # local on rank 0 + [0, 2], [2, 0], # cross + [1, 3], [3, 1], # cross + [2, 3], [3, 2], # local on rank 1 +], dtype=torch.long) + +HOMO_PARTITIONING = torch.tensor([0, 0, 1, 1], dtype=torch.long) + +HETERO_EDGE_LIST = torch.tensor([ + [0, 0], [0, 2], + [1, 1], [1, 3], + [2, 0], [2, 2], +], dtype=torch.long) + +HETERO_SRC_PARTITIONING = torch.tensor([0, 0, 1], dtype=torch.long) +HETERO_DST_PARTITIONING = torch.tensor([0, 0, 1, 1], dtype=torch.long) + +# comm_map for the homogeneous graph (known analytically) +HOMO_COMM_MAP = torch.tensor([[0., 2.], [2., 0.]]) +# fmt: on + + +# =========================================================================== +# compute_local_vertices +# =========================================================================== + + +@pytest.mark.parametrize( + "rank, expected", + [ + (0, torch.tensor([0, 1])), + (1, torch.tensor([2, 3])), + ], +) +def test_compute_local_vertices_correct_ids(rank, expected): + result = 
compute_local_vertices(HOMO_PARTITIONING, rank) + assert torch.equal(result, expected) + + +def test_compute_local_vertices_is_1d(): + result = compute_local_vertices(HOMO_PARTITIONING, rank=0) + assert result.ndim == 1, "Result must be a 1-D tensor" + + +def test_compute_local_vertices_covers_all_ranks(): + all_local = torch.cat( + [compute_local_vertices(HOMO_PARTITIONING, r) for r in range(2)] + ).sort()[0] + all_verts = torch.arange(HOMO_PARTITIONING.size(0)) + assert torch.equal(all_local, all_verts), "Union of local verts must cover all vertices" + + +# =========================================================================== +# compute_halo_vertices — homogeneous +# =========================================================================== + + +@pytest.mark.parametrize( + "rank, expected_halo", + [ + (0, torch.tensor([2, 3])), + (1, torch.tensor([0, 1])), + ], +) +def test_compute_halo_vertices_homogeneous(rank, expected_halo): + result = compute_halo_vertices(HOMO_EDGE_LIST, HOMO_PARTITIONING, rank) + assert result.ndim == 1 + assert torch.equal(result, expected_halo) + + +def test_compute_halo_vertices_no_cross_edges_returns_empty(): + edge_list = torch.tensor([[0, 1], [1, 0]], dtype=torch.long) + partitioning = torch.tensor([0, 0], dtype=torch.long) + result = compute_halo_vertices(edge_list, partitioning, rank=0) + assert result.numel() == 0 + + +def test_compute_halo_vertices_unique(): + # Multiple edges to the same remote vertex should deduplicate + edge_list = torch.tensor([[0, 2], [0, 2], [1, 2]], dtype=torch.long) + partitioning = torch.tensor([0, 0, 1, 1], dtype=torch.long) + result = compute_halo_vertices(edge_list, partitioning, rank=0) + assert result.tolist() == [2], f"Expected [2], got {result.tolist()}" + + +def test_compute_halo_vertices_dst_none_equals_src(): + """Passing dst_partitioning=None must be identical to passing src_partitioning.""" + r0_implicit = compute_halo_vertices(HOMO_EDGE_LIST, HOMO_PARTITIONING, rank=0) + r0_explicit = 
compute_halo_vertices( + HOMO_EDGE_LIST, HOMO_PARTITIONING, rank=0, dst_partitioning=HOMO_PARTITIONING + ) + assert torch.equal(r0_implicit, r0_explicit) + + +# =========================================================================== +# compute_halo_vertices — heterogeneous +# =========================================================================== + + +@pytest.mark.parametrize( + "rank, expected_halo", + [ + (0, torch.tensor([2, 3])), # cross edges (0,2),(1,3) + (1, torch.tensor([0])), # cross edge (2,0) + ], +) +def test_compute_halo_vertices_heterogeneous(rank, expected_halo): + result = compute_halo_vertices( + HETERO_EDGE_LIST, + HETERO_SRC_PARTITIONING, + rank, + dst_partitioning=HETERO_DST_PARTITIONING, + ) + assert result.ndim == 1 + assert torch.equal(result, expected_halo) + + +def test_compute_halo_vertices_hetero_disjoint_from_local_dst(): + """Halo dst vertices must not be owned by this rank.""" + for rank in [0, 1]: + halo = compute_halo_vertices( + HETERO_EDGE_LIST, + HETERO_SRC_PARTITIONING, + rank, + dst_partitioning=HETERO_DST_PARTITIONING, + ) + halo_ranks = HETERO_DST_PARTITIONING[halo] + assert (halo_ranks != rank).all(), ( + f"Rank {rank}: halo contains a locally-owned dst vertex" + ) + + +# =========================================================================== +# compute_local_edge_list +# =========================================================================== + + +@pytest.mark.parametrize( + "rank, expected_edges", + [ + # Rank 0: edges with src ∈ {0,1}; g2l: 0→0,1→1,2→2,3→3 + (0, torch.tensor([[0, 1], [1, 0], [0, 2], [1, 3]])), + # Rank 1: edges with src ∈ {2,3}; g2l: 2→0,3→1,0→2,1→3 + (1, torch.tensor([[0, 2], [0, 1], [1, 3], [1, 0]])), + ], +) +def test_compute_local_edge_list_correct_remapping(rank, expected_edges): + local_verts = compute_local_vertices(HOMO_PARTITIONING, rank) + halo_verts = compute_halo_vertices(HOMO_EDGE_LIST, HOMO_PARTITIONING, rank) + result = compute_local_edge_list( + HOMO_EDGE_LIST, 
HOMO_PARTITIONING, local_verts, halo_verts, rank + ) + # Order of rows may differ; compare as sets of edge tuples + result_set = set(map(tuple, result.tolist())) + expected_set = set(map(tuple, expected_edges.tolist())) + assert result_set == expected_set, ( + f"Rank {rank}: edge sets differ.\nGot: {result_set}\nExpected: {expected_set}" + ) + + +@pytest.mark.parametrize("rank", [0, 1]) +def test_compute_local_edge_list_source_always_local(rank): + local_verts = compute_local_vertices(HOMO_PARTITIONING, rank) + halo_verts = compute_halo_vertices(HOMO_EDGE_LIST, HOMO_PARTITIONING, rank) + result = compute_local_edge_list( + HOMO_EDGE_LIST, HOMO_PARTITIONING, local_verts, halo_verts, rank + ) + num_local = local_verts.size(0) + assert (result[:, 0] < num_local).all(), "All source indices must be in [0, num_local)" + + +@pytest.mark.parametrize("rank", [0, 1]) +def test_compute_local_edge_list_all_indices_in_bounds(rank): + local_verts = compute_local_vertices(HOMO_PARTITIONING, rank) + halo_verts = compute_halo_vertices(HOMO_EDGE_LIST, HOMO_PARTITIONING, rank) + result = compute_local_edge_list( + HOMO_EDGE_LIST, HOMO_PARTITIONING, local_verts, halo_verts, rank + ) + total = local_verts.size(0) + halo_verts.size(0) + assert (result >= 0).all() + assert (result < total).all() + + +# =========================================================================== +# compute_boundary_vertices — homogeneous +# =========================================================================== + + +@pytest.mark.parametrize( + "rank, expected_send_offset", + [ + (0, torch.tensor([0, 0, 2])), # nothing to rank 0 (self), 2 to rank 1 + (1, torch.tensor([0, 2, 2])), # 2 to rank 0, nothing to rank 1 (self) + ], +) +def test_compute_boundary_vertices_send_offset_homogeneous(rank, expected_send_offset): + local_verts = compute_local_vertices(HOMO_PARTITIONING, rank) + _, send_offset = compute_boundary_vertices( + HOMO_EDGE_LIST, HOMO_PARTITIONING, local_verts, rank, num_ranks=2 + ) + assert 
torch.equal(send_offset, expected_send_offset), ( + f"Rank {rank}: send_offset {send_offset.tolist()} != {expected_send_offset.tolist()}" + ) + + +@pytest.mark.parametrize("rank", [0, 1]) +def test_compute_boundary_vertices_send_idx_are_local_indices(rank): + local_verts = compute_local_vertices(HOMO_PARTITIONING, rank) + send_local_idx, _ = compute_boundary_vertices( + HOMO_EDGE_LIST, HOMO_PARTITIONING, local_verts, rank, num_ranks=2 + ) + num_local = local_verts.size(0) + assert (send_local_idx >= 0).all() + assert (send_local_idx < num_local).all() + + +@pytest.mark.parametrize("rank", [0, 1]) +def test_compute_boundary_vertices_unique_per_dest_rank(rank): + local_verts = compute_local_vertices(HOMO_PARTITIONING, rank) + send_local_idx, send_offset = compute_boundary_vertices( + HOMO_EDGE_LIST, HOMO_PARTITIONING, local_verts, rank, num_ranks=2 + ) + for r in range(2): + segment = send_local_idx[send_offset[r] : send_offset[r + 1]] + assert segment.unique().size(0) == segment.size(0), ( + f"Rank {rank}: duplicate send indices for dest rank {r}" + ) + + +def test_compute_boundary_vertices_self_send_is_zero(): + """The segment for this rank's own index in send_offset must always be empty.""" + for rank in [0, 1]: + local_verts = compute_local_vertices(HOMO_PARTITIONING, rank) + _, send_offset = compute_boundary_vertices( + HOMO_EDGE_LIST, HOMO_PARTITIONING, local_verts, rank, num_ranks=2 + ) + assert send_offset[rank + 1] == send_offset[rank], ( + f"Rank {rank}: non-zero self-send segment" + ) + + +def test_compute_boundary_vertices_no_cross_edges_empty(): + edge_list = torch.tensor([[0, 1], [1, 0]], dtype=torch.long) + partitioning = torch.tensor([0, 0], dtype=torch.long) + local_verts = torch.tensor([0, 1]) + send_local_idx, send_offset = compute_boundary_vertices( + edge_list, partitioning, local_verts, rank=0, num_ranks=2 + ) + assert send_local_idx.numel() == 0 + assert torch.equal(send_offset, torch.zeros(3, dtype=torch.long)) + + +def 
test_compute_boundary_vertices_duplicate_edges_deduplicated(): + """A vertex connected by multiple edges to the same remote rank is sent only once.""" + edge_list = torch.tensor([[0, 2], [0, 2], [0, 3]], dtype=torch.long) + partitioning = torch.tensor([0, 0, 1, 1], dtype=torch.long) + local_verts = torch.tensor([0, 1]) + send_local_idx, send_offset = compute_boundary_vertices( + edge_list, partitioning, local_verts, rank=0, num_ranks=2 + ) + segment = send_local_idx[send_offset[1] : send_offset[2]] + assert segment.unique().size(0) == segment.size(0) + # vertex 0 should appear exactly once for rank 1 + assert (segment == 0).sum().item() == 1 + + +def test_compute_boundary_vertices_dst_none_equals_src(): + local_verts = compute_local_vertices(HOMO_PARTITIONING, rank=0) + idx_implicit, off_implicit = compute_boundary_vertices( + HOMO_EDGE_LIST, HOMO_PARTITIONING, local_verts, rank=0, num_ranks=2 + ) + idx_explicit, off_explicit = compute_boundary_vertices( + HOMO_EDGE_LIST, + HOMO_PARTITIONING, + local_verts, + rank=0, + num_ranks=2, + dst_partitioning=HOMO_PARTITIONING, + ) + assert torch.equal(idx_implicit, idx_explicit) + assert torch.equal(off_implicit, off_explicit) + + +# =========================================================================== +# compute_boundary_vertices — heterogeneous +# =========================================================================== + + +@pytest.mark.parametrize( + "rank, expected_send_offset, expected_num_sends", + [ + (0, torch.tensor([0, 0, 2]), 2), # send src{0,1} to rank 1 + (1, torch.tensor([0, 1, 1]), 1), # send src{2} to rank 0 + ], +) +def test_compute_boundary_vertices_heterogeneous( + rank, expected_send_offset, expected_num_sends +): + local_verts = compute_local_vertices(HETERO_SRC_PARTITIONING, rank) + send_local_idx, send_offset = compute_boundary_vertices( + HETERO_EDGE_LIST, + HETERO_SRC_PARTITIONING, + local_verts, + rank, + num_ranks=2, + dst_partitioning=HETERO_DST_PARTITIONING, + ) + assert 
torch.equal(send_offset, expected_send_offset) + assert send_local_idx.numel() == expected_num_sends + num_local = local_verts.size(0) + assert (send_local_idx >= 0).all() + assert (send_local_idx < num_local).all() + + +# =========================================================================== +# compute_recv_offsets (no dist required — pure tensor arithmetic) +# =========================================================================== + + +@pytest.mark.parametrize( + "rank, expected_recv_offset, expected_recv_bwd", + [ + (0, torch.tensor([0, 0, 2]), torch.tensor([0., 0.])), + (1, torch.tensor([0, 2, 2]), torch.tensor([0., 2.])), + ], +) +def test_compute_recv_offsets(rank, expected_recv_offset, expected_recv_bwd): + recv_offset, recv_bwd = compute_recv_offsets(HOMO_COMM_MAP, rank) + assert torch.equal(recv_offset, expected_recv_offset) + assert torch.equal(recv_bwd, expected_recv_bwd) + + +@pytest.mark.parametrize("rank", [0, 1]) +def test_compute_recv_offsets_total_matches_comm_map_col(rank): + recv_offset, _ = compute_recv_offsets(HOMO_COMM_MAP, rank) + expected_total = int(HOMO_COMM_MAP[:, rank].sum().item()) + assert recv_offset[-1].item() == expected_total + + +def test_compute_recv_offsets_is_non_decreasing(): + for rank in [0, 1]: + recv_offset, _ = compute_recv_offsets(HOMO_COMM_MAP, rank) + assert (recv_offset[1:] >= recv_offset[:-1]).all() + + +# =========================================================================== +# Distributed fixture +# (tests below require: torchrun --nnodes 1 --nproc-per-node 2) +# =========================================================================== + + +@pytest.fixture(scope="module") +def dist_setup(): + """Initialize NCCL process group and set the per-rank CUDA device.""" + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + torch.cuda.set_device(device) + yield rank, world_size, device + + 
def _homo_tensors(device):
    """Return the homogeneous test graph (edge list, partitioning) on *device*."""
    return HOMO_EDGE_LIST.to(device), HOMO_PARTITIONING.to(device)


def _hetero_tensors(device):
    """Return the heterogeneous test graph tensors on *device*."""
    return (
        HETERO_EDGE_LIST.to(device),
        HETERO_SRC_PARTITIONING.to(device),
        HETERO_DST_PARTITIONING.to(device),
    )


# ===========================================================================
# compute_comm_map (distributed)
# ===========================================================================


def test_compute_comm_map_correct_values(dist_setup):
    """comm_map must match the analytically-known matrix for the homo graph."""
    rank, world_size, device = dist_setup
    edge_list, partitioning = _homo_tensors(device)

    local_verts = compute_local_vertices(partitioning, rank)
    _, send_off = compute_boundary_vertices(
        edge_list, partitioning, local_verts, rank, world_size
    )
    comm_map = compute_comm_map(send_off, world_size)

    expected = torch.tensor([[0., 2.], [2., 0.]], device=device)
    assert torch.equal(comm_map, expected), (
        f"Rank {rank}: comm_map {comm_map.tolist()} != {expected.tolist()}"
    )


def test_compute_comm_map_row_matches_local_send_counts(dist_setup):
    """Row `rank` of the gathered comm_map equals this rank's own send counts."""
    rank, world_size, device = dist_setup
    edge_list, partitioning = _homo_tensors(device)

    local_verts = compute_local_vertices(partitioning, rank)
    _, send_off = compute_boundary_vertices(
        edge_list, partitioning, local_verts, rank, world_size
    )
    comm_map = compute_comm_map(send_off, world_size)

    # Per-destination send counts are the diffs of consecutive send offsets.
    send_counts = (send_off[1:] - send_off[:-1]).float()
    assert torch.equal(comm_map[rank], send_counts), (
        f"Rank {rank}: comm_map row {comm_map[rank].tolist()} != send_counts {send_counts.tolist()}"
    )


def test_compute_comm_map_diagonal_zero(dist_setup):
    """A rank never sends to itself: the diagonal of comm_map is zero."""
    rank, world_size, device = dist_setup
    edge_list, partitioning = _homo_tensors(device)

    local_verts = compute_local_vertices(partitioning, rank)
    _, send_off = compute_boundary_vertices(
        edge_list, partitioning, local_verts, rank, world_size
    )
    comm_map = compute_comm_map(send_off, world_size)

    assert comm_map[rank, rank].item() == 0.0, "Self-send entry must be zero"


# ===========================================================================
# build_communication_pattern (distributed, homogeneous)
# ===========================================================================


def test_build_communication_pattern_vertex_counts(dist_setup):
    rank, world_size, device = dist_setup
    edge_list, partitioning = _homo_tensors(device)
    cp = build_communication_pattern(edge_list, partitioning, rank, world_size)

    assert cp.rank == rank
    assert cp.world_size == world_size
    assert cp.num_local_vertices == 2, f"Rank {rank}: expected 2 local verts"
    assert cp.num_halo_vertices == 2, f"Rank {rank}: expected 2 halo verts"


def test_build_communication_pattern_send_offset(dist_setup):
    rank, world_size, device = dist_setup
    edge_list, partitioning = _homo_tensors(device)
    cp = build_communication_pattern(edge_list, partitioning, rank, world_size)

    expected = {0: torch.tensor([0, 0, 2]), 1: torch.tensor([0, 2, 2])}
    assert torch.equal(cp.send_offset.cpu(), expected[rank])


def test_build_communication_pattern_recv_offset(dist_setup):
    rank, world_size, device = dist_setup
    edge_list, partitioning = _homo_tensors(device)
    cp = build_communication_pattern(edge_list, partitioning, rank, world_size)

    expected = {0: torch.tensor([0, 0, 2]), 1: torch.tensor([0, 2, 2])}
    assert torch.equal(cp.recv_offset.cpu(), expected[rank])


def test_build_communication_pattern_put_forward_remote_offset(dist_setup):
    rank, world_size, device = dist_setup
    edge_list, partitioning = _homo_tensors(device)
    cp = build_communication_pattern(edge_list, partitioning, rank, world_size)

    # put_forward_remote_offset[i] = number of vertices lower-ranked ranks
    # collectively send to rank i — tells this rank where to write in rank i's
    # recv buffer.
    expected = cp.comm_map[:rank, :].sum(0)
    assert torch.equal(cp.put_forward_remote_offset, expected)


def test_build_communication_pattern_put_backward_remote_offset(dist_setup):
    rank, world_size, device = dist_setup
    edge_list, partitioning = _homo_tensors(device)
    cp = build_communication_pattern(edge_list, partitioning, rank, world_size)

    expected = cp.comm_map[:, :rank].sum(1)
    assert torch.equal(cp.put_backward_remote_offset, expected)


# ===========================================================================
# Full invariants (from CommInfoDesignDocument.md)
# ===========================================================================


def test_communication_pattern_invariants(dist_setup):
    rank, world_size, device = dist_setup
    edge_list, partitioning = _homo_tensors(device)
    cp = build_communication_pattern(edge_list, partitioning, rank, world_size)

    # Local edge list bounds
    assert (cp.local_edge_list >= 0).all()
    assert (cp.local_edge_list < cp.num_local_vertices + cp.num_halo_vertices).all()

    # Source vertices are always local
    assert (cp.local_edge_list[:, 0] < cp.num_local_vertices).all()

    # Send indices are in the local range
    assert (cp.send_local_idx >= 0).all()
    assert (cp.send_local_idx < cp.num_local_vertices).all()

    # send_offset is non-decreasing
    assert (cp.send_offset[1:] >= cp.send_offset[:-1]).all()

    # comm_map row sum == total sends
    assert cp.comm_map[rank].long().sum().item() == cp.send_offset[-1].item(), (
        "comm_map row sum must equal total sends"
    )

    # comm_map column sum == total recvs
    assert cp.comm_map[:, rank].long().sum().item() == cp.recv_offset[-1].item(), (
        "comm_map col sum must equal total recvs"
    )

    # Self-send is zero
    assert cp.comm_map[rank, rank].item() == 0.0

    # Uniqueness within each destination rank segment
    for r in range(world_size):
        segment = cp.send_local_idx[cp.send_offset[r] : cp.send_offset[r + 1]]
        assert segment.unique().size(0) == segment.size(0), (
            f"Rank {rank}: duplicate send indices in segment for dest rank {r}"
        )

    # put offset derivations
    assert torch.equal(cp.put_forward_remote_offset, cp.comm_map[:rank, :].sum(0))
    assert torch.equal(cp.put_backward_remote_offset, cp.comm_map[:, :rank].sum(1))


def test_communication_pattern_send_recv_symmetry(dist_setup):
    """
    For any pair (A, B): comm_map[A, B] == comm_map[B, A] holds for undirected
    graphs (our test graph is undirected). Verify via reconstructing comm_map
    from a fresh all-gather and comparing it to the stored one.
    """
    rank, world_size, device = dist_setup
    edge_list, partitioning = _homo_tensors(device)
    cp = build_communication_pattern(edge_list, partitioning, rank, world_size)

    send_counts = (cp.send_offset[1:] - cp.send_offset[:-1]).float().to(device)
    gathered = [torch.zeros(world_size, device=device) for _ in range(world_size)]
    dist.all_gather(gathered, send_counts)
    reconstructed = torch.stack(gathered)

    assert torch.equal(reconstructed, cp.comm_map), (
        f"Rank {rank}: reconstructed comm_map differs from stored comm_map"
    )

    # Fix: the docstring promises pairwise symmetry but it was never asserted —
    # check it directly (valid because the test graph is undirected).
    assert torch.equal(cp.comm_map, cp.comm_map.t()), (
        f"Rank {rank}: comm_map is not symmetric for an undirected graph"
    )


# ===========================================================================
# Heterogeneous (distributed) — test compute_halo_vertices +
# compute_boundary_vertices with separate partitionings, then build a
# CommunicationPattern manually.
# ===========================================================================


def test_hetero_halo_and_boundary_distributed(dist_setup):
    """End-to-end check of the low-level helpers on the heterogeneous graph,
    with distinct src/dst partitionings (see the module docstring for the
    analytically-derived expectations)."""
    rank, world_size, device = dist_setup
    edge_list, src_part, dst_part = _hetero_tensors(device)

    # Assemble each piece of the communication pattern by hand.
    src_local_verts = compute_local_vertices(src_part, rank)
    halo_verts = compute_halo_vertices(
        edge_list, src_part, rank, dst_partitioning=dst_part
    )
    send_local_idx, send_off = compute_boundary_vertices(
        edge_list,
        src_part,
        src_local_verts,
        rank,
        world_size,
        dst_partitioning=dst_part,
    )
    comm = compute_comm_map(send_off, world_size)
    recv_off, _ = compute_recv_offsets(comm, rank)

    # Per-rank analytic expectations, indexed by this process's rank.
    n_local_expected = (2, 1)[rank]
    n_halo_expected = (2, 1)[rank]
    send_offset_expected = (
        torch.tensor([0, 0, 2]),  # rank 0 sends src{0,1} to rank 1
        torch.tensor([0, 1, 1]),  # rank 1 sends src{2} to rank 0
    )[rank]
    recv_offset_expected = (
        torch.tensor([0, 0, 1]),  # rank 0 receives 1 vert from rank 1
        torch.tensor([0, 2, 2]),  # rank 1 receives 2 verts from rank 0
    )[rank]
    expected_comm_map = torch.tensor([[0., 2.], [1., 0.]], device=device)

    assert src_local_verts.size(0) == n_local_expected
    assert halo_verts.size(0) == n_halo_expected
    assert torch.equal(send_off.cpu(), send_offset_expected)
    assert torch.equal(recv_off.cpu(), recv_offset_expected)
    assert torch.equal(comm, expected_comm_map), (
        f"Rank {rank}: comm_map {comm.tolist()} != {expected_comm_map.tolist()}"
    )

    # Every halo vertex must be owned by some remote rank.
    assert (dst_part[halo_verts] != rank).all()

    # Send indices must address local source vertices only: [0, num_local_src).
    n_local_src = src_local_verts.size(0)
    assert ((send_local_idx >= 0) & (send_local_idx < n_local_src)).all()

    # A rank never sends to itself.
    assert comm[rank, rank].item() == 0.0