Skip to content
Snippets Groups Projects
Commit 50207328 authored by Kaleb Phipps's avatar Kaleb Phipps
Browse files

fix comment

parent 8fc0881d
Branches
No related tags found
No related merge requests found
...@@ -63,9 +63,10 @@ def dist_symmetric(x: torch.Tensor, y: torch.Tensor, comm: MPI.Comm = MPI.COMM_W ...@@ -63,9 +63,10 @@ def dist_symmetric(x: torch.Tensor, y: torch.Tensor, comm: MPI.Comm = MPI.COMM_W
# --- Ring Communication Pattern --- # --- Ring Communication Pattern ---
# Calculate distances in a "ring" pattern. Each process calculates distances for its local x chunk against its # Calculate distances in a "ring" pattern. Each process calculates distances for its local x chunk against its
# local y chunk (diagonal calculation). Then, through `size - 1` iterations, each process sends its y chunk to # local y chunk (diagonal calculation). Then, through `(size + 1) // 2` iterations, each process sends its y
# the next process in the "ring" while receiving a new y chunk from the previous process. This continues until # chunk to the next process in the "ring" while receiving a new y chunk from the previous process. This
# each process has calculated distances between its x chunk and all chunks of y across all processes. # continues until each process has calculated distances between its x chunk and all chunks of y across all
# processes.
x_ = x x_ = x
stationary = y stationary = y
...@@ -76,7 +77,7 @@ def dist_symmetric(x: torch.Tensor, y: torch.Tensor, comm: MPI.Comm = MPI.COMM_W ...@@ -76,7 +77,7 @@ def dist_symmetric(x: torch.Tensor, y: torch.Tensor, comm: MPI.Comm = MPI.COMM_W
local_distances[:, cols[0]: cols[1]] = d_ij local_distances[:, cols[0]: cols[1]] = d_ij
print(f"Rank [{rank}/{size}]: Start tile-wise ring communication...") print(f"Rank [{rank}/{size}]: Start tile-wise ring communication...")
# Remaining `(size+1) // 2` iterations: Send rank-local part of y to next process in circular fashion. # Remaining `(size + 1) // 2` iterations: Send rank-local part of y to next process in circular fashion.
# We can perform less iterations due to the symmetric nature of the metric. # We can perform less iterations due to the symmetric nature of the metric.
for iter_idx in range(1, (size + 2) // 2): for iter_idx in range(1, (size + 2) // 2):
print(f"Rank [{rank}/{size}]: Starting iteration {iter_idx}") print(f"Rank [{rank}/{size}]: Starting iteration {iter_idx}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment