Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • marie.weiel/scalableai2425
1 result
Show changes
Commits on Source (2)
......@@ -63,9 +63,10 @@ def dist_symmetric(x: torch.Tensor, y: torch.Tensor, comm: MPI.Comm = MPI.COMM_W
# --- Ring Communication Pattern ---
# Calculate distances in a "ring" pattern. Each process calculates distances for its local x chunk against its
# local y chunk (diagonal calculation). Then, through `size - 1` iterations, each process sends its y chunk to
# the next process in the "ring" while receiving a new y chunk from the previous process. This continues until
# each process has calculated distances between its x chunk and all chunks of y across all processes.
# local y chunk (diagonal calculation). Then, through `(size + 1) // 2` iterations, each process sends its y
# chunk to the next process in the "ring" while receiving a new y chunk from the previous process. This
# continues until each process has calculated distances between its x chunk and all chunks of y across all
# processes.
x_ = x
stationary = y
......@@ -76,7 +77,7 @@ def dist_symmetric(x: torch.Tensor, y: torch.Tensor, comm: MPI.Comm = MPI.COMM_W
local_distances[:, cols[0]: cols[1]] = d_ij
print(f"Rank [{rank}/{size}]: Start tile-wise ring communication...")
# Remaining `(size+1) // 2` iterations: Send rank-local part of y to next process in circular fashion.
# Remaining `(size + 1) // 2` iterations: Send rank-local part of y to next process in circular fashion.
# We can perform less iterations due to the symmetric nature of the metric.
for iter_idx in range(1, (size + 2) // 2):
print(f"Rank [{rank}/{size}]: Starting iteration {iter_idx}")
......