diff --git a/2_psrs_cdist/solutions/A1/cdist_symmetric.py b/2_psrs_cdist/solutions/A1/cdist_symmetric.py
index 1c606153e642bb7adaac54b74f8753c9c2317a3d..af24da8a7ad6c5d952b17b0f373e5ef3019a7551 100644
--- a/2_psrs_cdist/solutions/A1/cdist_symmetric.py
+++ b/2_psrs_cdist/solutions/A1/cdist_symmetric.py
@@ -63,9 +63,10 @@ def dist_symmetric(x: torch.Tensor, y: torch.Tensor, comm: MPI.Comm = MPI.COMM_W
 
     # --- Ring Communication Pattern ---
     # Calculate distances in a "ring" pattern. Each process calculates distances for its local x chunk against its
-    # local y chunk (diagonal calculation). Then, through `size - 1` iterations, each process sends its y chunk to
-    # the next process in the "ring" while receiving a new y chunk from the previous process. This continues until
-    # each process has calculated distances between its x chunk and all chunks of y across all processes.
+    # local y chunk (diagonal calculation). Then, through `(size + 1) // 2` iterations, each process sends its y
+    # chunk to the next process in the "ring" while receiving a new y chunk from the previous process. This
+    # continues until each process has calculated distances between its x chunk and all chunks of y across all
+    # processes.
     x_ = x
     stationary = y
 
@@ -76,7 +77,7 @@ def dist_symmetric(x: torch.Tensor, y: torch.Tensor, comm: MPI.Comm = MPI.COMM_W
         local_distances[:, cols[0]: cols[1]] = d_ij
 
     print(f"Rank [{rank}/{size}]: Start tile-wise ring communication...")
-    # Remaining `(size+1) // 2` iterations: Send rank-local part of y to next process in circular fashion.
+    # Remaining `(size + 1) // 2` iterations: Send rank-local part of y to next process in circular fashion.
     # We can perform less iterations due to the symmetric nature of the metric.
     for iter_idx in range(1, (size + 2) // 2):
        print(f"Rank [{rank}/{size}]: Starting iteration {iter_idx}")
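
For reference, below is a minimal, self-contained sketch of the halved ring pattern that the updated comment describes. It is not part of the patch: the equal-sized chunk layout, the function name ring_cdist_symmetric, and the use of torch.cdist together with mpi4py's pickle-based sendrecv are assumptions, and both the write-back of the mirrored tiles and any special handling of an even process count are omitted.

# Illustrative sketch only -- names and chunk layout are assumptions, not the patch's code.
from mpi4py import MPI
import torch


def ring_cdist_symmetric(x_local: torch.Tensor, y_local: torch.Tensor,
                         comm: MPI.Comm = MPI.COMM_WORLD) -> torch.Tensor:
    rank, size = comm.rank, comm.size
    chunk = y_local.shape[0]  # assumes every rank holds an equally sized y chunk
    dist = torch.zeros(x_local.shape[0], size * chunk)

    # Diagonal tile: local x chunk against local y chunk.
    dist[:, rank * chunk:(rank + 1) * chunk] = torch.cdist(x_local, y_local)

    # Remaining ring steps, halved because the metric is symmetric: the tile
    # computed here could be transposed and reused for the mirror block on the
    # partner rank (that exchange is not shown in this sketch).
    moving = y_local
    for it in range(1, (size + 2) // 2):
        # Pass the currently held chunk to the next rank, receive from the previous one.
        moving = comm.sendrecv(moving, dest=(rank + 1) % size,
                               source=(rank - 1) % size)
        origin = (rank - it) % size  # rank that originally owned the received chunk
        cols = slice(origin * chunk, (origin + 1) * chunk)
        dist[:, cols] = torch.cdist(x_local, moving)
    return dist

The combined sendrecv keeps the blocking exchange deadlock-free, and shifting the same moving buffer one neighbor per step is what makes the pattern a ring rather than an all-to-all exchange.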