diff --git a/2_psrs_cdist/sheet_2.ipynb b/2_psrs_cdist/sheet_2.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..fe71c6109382e75c78e493358de44138f378bbd4 --- /dev/null +++ b/2_psrs_cdist/sheet_2.ipynb @@ -0,0 +1,550 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ea5b6890", + "metadata": {}, + "source": [ + "# Skalierbare Methoden der Künstlichen Intelligenz\n", + "Dr. Charlotte Debus (charlotte.debus@kit.edu) \n", + "Dr. Markus Götz (markus.goetz@kit.edu) \n", + "Dr. Marie Weiel (marie.weiel@kit.edu) \n", + "Dr. Kaleb Phipps (kaleb.phipps@kit.edu) \n", + "\n", + "## Übung 2 am 03.12.24: Parallel Sorting by Regular Sampling und Pairwise Distances\n", + "In der zweiten Übung beschäftigen wir uns mit dem \"Parallel Sorting by Regular Sampling\" (PSRS) Algorithmus (siehe Vorlesung vom 07.11.24) und der parallelen Berechnung paarweiser Distanzen (\"pairwise distances\", siehe Vorlesung vom 14.11.24). \n", + "\n", + "### Aufgabe 1\n", + "Untenstehend finden Sie eine parallele Implementierung eines Algorithmus zur Berechnung paarweiser Distanzen in `Python3`. Wir verwenden 50 000 Samples des [SUSY-Datensatzes](https://archive.ics.uci.edu/dataset/279/susy). Diese finden Sie in der HDF5-Datei `/pfs/work7/workspace/scratch/ku4408-VL-ScalableAI/data/SUSY_50k.h5` auf dem bwUniCluster. Der SUSY-Datensatz enthält insgesamt 5 000 000 Samples aus Monte-Carlo-Simulationen hochenergetischer Teilchenkollisionen. Jedes Sample hat 18 Features, bestehend aus kinematischen Eigenschaften, die typischerweise von Teilchendetektoren gemessen werden, sowie aus diesen Messungen abgeleiteten Größen. Führen Sie den Code auf einem, zwei, vier, acht sowie 16 CPU-basierten Knoten in den Partitionen \"single\" bzw. \"multiple\" des bwUniClusters aus. Untersuchen Sie das schwache sowie das starke Skalierungsverhalten des Algorithmus und stellen Sie diese grafisch dar, z.B. mit `matplotlib.pyplot` in `Python3`. \n", + "\n", + "**Zur Erinnerung (siehe auch Vorlesung vom 2.11.23):** Bei der starken Skalierung wird die Problemgröße konstant gehalten, während man die Anzahl der Prozesse erhöht, d.h. es wird untersucht, inwieweit sich ein Problem konstanter Größe durch Hinzunahme von mehr Rechenressourcen schneller lösen lässt. Bei der schwachen Skalierung wird die Problemgröße pro Prozess konstant gehalten, während man die Anzahl der Prozesse erhöht, d.h. es wird untersucht, inwieweit sich ein größeres Problem durch Hinzunahme von mehr Rechenressourcen in gleicher Zeit lösen lässt. Das bedeutet, dass Sie die Problemgröße zur Untersuchung des schwachen Skalierungsverhaltens proportional anpassen müssen! \n", + "\n", + "**Vorgehensweise (analog zum ersten Übungsblatt):**\n", + "- Laden Sie zunächst die benötigten Module auf dem bwUniCluster.\n", + "- Setzen Sie dann eine virtuelle Umgebung mit `Python` auf, in der Sie die benötigten Pakete installieren. An dieser Stelle können Sie auch Ihre virtuelle Umgebung vom letzten Übungsblatt nutzen.\n", + "- Erstellen Sie basierend auf untenstehendem Code ein `Python`-Skript, welches Sie mithilfe eines `bash`-Skripts über SLURM auf dem Cluster submittieren (siehe Übung vom 05.11.24). Nachfolgend finden Sie ein beispielhaftes Template für das Submit-Skript für einen Multi-Node-Job inklusive der benötigten Module. Wenn Sie eine andere Anzahl an Knoten verwenden möchten, müssen Sie die `#SBATCH`-Optionen entsprechend modifizieren. Weitere Informationen dazu finden Sie [hier](https://wiki.bwhpc.de/e/BwUniCluster_2.0_Slurm_common_Features)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd16f08b-584d-4913-86bb-f5bb6d4822b6", + "metadata": {}, + "outputs": [], + "source": [ + "#!/bin/bash\n", + "\n", + "#SBATCH --job-name=cdist # job name\n", + "#SBATCH --partition=multiple # queue for resource allocation\n", + "#SBATCH --nodes=2 # number of nodes to be used\n", + "#SBATCH --time=4:00 # wall-clock time limit\n", + "#SBATCH --mem=90000 # memory per node \n", + "#SBATCH --cpus-per-task=40 # number of CPUs required per MPI task\n", + "#SBATCH --ntasks-per-node=1 # maximum count of tasks per node\n", + "#SBATCH --mail-type=ALL # Notify user by email when certain event types occur.\n", + "\n", + "export IBV_FORK_SAFE=1\n", + "export VENVDIR=<path/to/your/venv/folder> # Export path to your virtual environment.\n", + "export PYDIR=<path/to/your/python/script> # Export path to directory containing Python script.\n", + "\n", + "# Set up modules.\n", + "module purge # Unload all currently loaded modules.\n", + "module load compiler/gnu/13.3 # Load required modules.\n", + "module load mpi/openmpi/4.1\n", + "module load devel/cuda/12.4\n", + "module load lib/hdf5/1.14.4-gnu-13.3-openmpi-4.1\n", + "\n", + "source ${VENVDIR}/bin/activate # Activate your virtual environment.\n", + "\n", + "mpirun --mca mpi_warn_on_fork 0 python -u ${PYDIR}/cdist.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e172b4f", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Parallel calculation of pairwise distances\"\"\"\n", + "import time\n", + "\n", + "import h5py\n", + "import torch\n", + "from mpi4py import MPI\n", + "\n", + "torch.set_default_dtype(torch.float32)\n", + "\n", + "\n", + "def dist(x: torch.Tensor, y: torch.Tensor, comm: MPI.Comm = MPI.COMM_WORLD) -> torch.Tensor:\n", + " \"\"\"\n", + " Calculate pairwise distances between all rows (samples, i.e., along axis 0) of two tensors x and y in parallel.\n", + "\n", + " The distance matrix is calculated tile-wise with ring communication between processes, each holding a piece of x\n", + " and/or y.\n", + "\n", + " Parameters\n", + " ----------\n", + " x : torch.Tensor\n", + " First 2d tensor (of shape m/p x f). m is the total number of samples in x, distributed over p processors.\n", + " f is the number of features.\n", + " y : torch.Tensor\n", + " Second 2d tensor (of shape n/p x f). n is the total number of samples in x, distributed over p processors.\n", + " The number of features f must be the same as for x.\n", + " comm : MPI.Comm\n", + " Communicator to use. Default is ``MPI.COMM_WORLD``.\n", + " \"\"\"\n", + " # Check whether two input tensors are compatible.\n", + " if len(x.shape) != len(y.shape) != 2:\n", + " raise ValueError(\"Input tensors must be two-dimensional.\")\n", + " if x.shape[1] != y.shape[1]:\n", + " raise ValueError(f\"Input tensors must have the same number of features but {x.shape[1]} != {y.shape[1]}.\")\n", + "\n", + " size, rank = comm.size, comm.rank # Set up communication.\n", + "\n", + " if size == 1: # Use torch functionality in non-parallel case.\n", + " return torch.cdist(x, y)\n", + "\n", + " else: # Parallel case\n", + " # --- Setup and Matrix Initialization ---\n", + " mp, f = x.shape # Get number of samples in local chunk of x and number of features.\n", + " np = y.shape[0] # Get number of samples in local chunk of y.\n", + " \n", + " # Each process initializes a local matrix, `local_distances`, of shape `(mp, n)`, where `mp` is the local chunk\n", + " # size of `x`, and `n` is the total number of samples in `y`. Each rank thus calculates the distance matrix\n", + " # chunk of size `mp x n`, i.e., rank 0 has distances from its own local `x` to all other `y`'s.\n", + "\n", + " # Determine overall number of samples in y.\n", + " n = comm.allreduce(np, op=MPI.SUM)\n", + " print(f\"Overall number of samples is {n}.\")\n", + " # Initialize rank-local chunk of mp x n distance matrix with zeros.\n", + " local_distances = torch.zeros((mp, n))\n", + "\n", + " # --- Managing Chunks and Displacements ---\n", + " # Determine where to put each result in the rank-local distance matrix chunk.\n", + " # Determine number of samples (rows) in each rank-local y.\n", + " y_counts = torch.tensor(comm.allgather(torch.numel(y) // f), dtype=torch.int)\n", + " # Calculate corresponding displacements from counts to record the starting index of each chunk in y. Thus, each\n", + " # process can identify where in the result matrix it should write the distances.\n", + " y_displ = (0,) + tuple(torch.cumsum(y_counts, dim=0, dtype=torch.int)[:-1])\n", + " # Extract actual result columns in distance matrix chunk for each rank.\n", + " cols = (y_displ[rank], y_displ[rank + 1] if (rank + 1) != size else n)\n", + "\n", + " # --- Ring Communication Pattern ---\n", + " # Calculate distances in a \"ring\" pattern. Each process calculates distances for its local x chunk against its\n", + " # local y chunk (diagonal calculation). Then, through `size - 1` iterations, each process sends its y chunk to\n", + " # the next process in the \"ring\" while receiving a new y chunk from the previous process. This continues until\n", + " # each process has calculated distances between its x chunk and all chunks of y across all processes.\n", + " stationary = y\n", + " \n", + " # 0th iteration: Calculate diagonal of global distance matrix.\n", + " # Each process calculates distances for its local x chunk against its local y chunk.\n", + " print(f\"Rank [{rank}/{size}]: Calculate diagonal blocks...\")\n", + " local_distances[:, cols[0]: cols[1]] = torch.cdist(x, stationary)\n", + " \n", + " print(f\"Rank [{rank}/{size}]: Start tile-wise ring communication...\")\n", + " # Remaining `size-1` iterations: Send rank-local part of y to next process in circular fashion.\n", + " for iter_idx in range(1, size):\n", + "\n", + " receiver = (rank + iter_idx) % size # Determine receiving process.\n", + " sender = (rank - iter_idx) % size # Determine sending process.\n", + " # Determine columns of rank-local distance matrix chunk to write result to.\n", + " col1 = y_displ[sender]\n", + " col2 = y_displ[sender + 1] if sender != size - 1 else n\n", + " columns = (col1, col2)\n", + " # All but first `iter_idx` processes are first receiving, then sending.\n", + " if (rank // iter_idx) != 0:\n", + " stat = MPI.Status()\n", + " # Probe for incoming message containing the next chunk of y to consider.\n", + " comm.Probe(source=sender, tag=iter_idx, status=stat)\n", + " # Determine number of samples to receive (= overall number of floats to receive / number of features).\n", + " count = int(stat.Get_count(MPI.FLOAT) / f)\n", + " # Initialize tensor for incoming chunk of y with zeros.\n", + " moving = torch.zeros((count, f))\n", + " comm.Recv(moving, source=sender, tag=iter_idx)\n", + " # Send rank-local chunk of y to next process.\n", + " comm.Send(stationary, dest=receiver, tag=iter_idx)\n", + " # First `iter_idx` processes can now receive after sending.\n", + " if (rank // iter_idx) == 0:\n", + " stat = MPI.Status()\n", + " comm.Probe(source=sender, tag=iter_idx, status=stat)\n", + " count = int(stat.Get_count(MPI.FLOAT) / f)\n", + " moving = torch.zeros((count, f))\n", + " comm.Recv(moving, source=sender, tag=iter_idx)\n", + " # Calculate distances between stationary chunk of x and currently considered, moving chunk of y.\n", + " # Write result at correct position in distance matrix.\n", + " local_distances[:, columns[0]: columns[1]] = torch.cdist(x, moving)\n", + " print(f\"Rank [{rank}/{size}]: [DONE]\")\n", + "\n", + " return local_distances\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " comm = MPI.COMM_WORLD\n", + " rank, size = comm.rank, comm.size\n", + "\n", + " data_path = \"/pfs/work7/workspace/scratch/ku4408-VL-ScalableAI/data/SUSY_50k.h5\"\n", + " dataset = \"data\"\n", + "\n", + " if rank == 0:\n", + " print(\n", + " \"######################\\n\"\n", + " \"# Pairwise distances #\\n\"\n", + " \"######################\\n\"\n", + " f\"COMM_WORLD size is {size}.\\n\"\n", + " f\"Loading data... {data_path}[{dataset}]\"\n", + " )\n", + "\n", + " # Parallel data loader for SUSY data.\n", + " with h5py.File(data_path, \"r\") as handle:\n", + " chunk = int(handle[dataset].shape[0]/size)\n", + " if rank == size - 1:\n", + " data = torch.FloatTensor(handle[dataset][rank*chunk:])\n", + " else:\n", + " data = torch.FloatTensor(handle[dataset][rank*chunk:(rank+1)*chunk])\n", + "\n", + " print(f\"\\t[OK]\\nRank [{rank}/{size}]: Local data chunk has shape {list(data.shape)}...\")\n", + "\n", + " if rank == 0:\n", + " print(\"Start distance calculations...\")\n", + " # Calculate distances of all SUSY samples w.r.t. each other and measure runtime.\n", + " start = time.perf_counter()\n", + " distances = dist(data, data, comm)\n", + " local_runtime = time.perf_counter() - start\n", + " # Calculate process-averaged runtime.\n", + " average_runtime = comm.allreduce(local_runtime, op=MPI.SUM) / size\n", + " print(f\"Rank [{rank}/{size}]: Local distance matrix has shape {list(distances.shape)}.\")\n", + " if rank == 0:\n", + " print(f\"Process-averaged run time:\\t{average_runtime} s\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "6c7900f7", + "metadata": {}, + "source": [ + "### Aufgabe 2\n", + "Untenstehend finden Sie eine parallele Implementierung des PSRS-Algorithmus. Mithilfe dieses Codes sollen unterschiedliche Sequenzen von ganzen Zahlen sortiert werden. Diese finden Sie in der HDF5-Datei `/pfs/work7/workspace/scratch/ku4408-VL-ScalableAI/data/psrs_data.h5` auf dem bwUniCluster. Die Datei enthält fünf verschiedene Ganzzahl-Sequenzen als Datensätze `['duplicates_10', 'duplicates_5', 'many_duplicates', 'no_duplicates', 'triplicates']`, die jeweils einen unterschiedlichen Anteil an Duplikaten bzw. Triplikaten enthalten. Alle Sequenzen bestehen aus $10^9$ Elementen. \n", + "Führen Sie den Code mithilfe des untenstehenden Submit-Skripts für alle fünf Datensätze auf vier CPU-basierten Knoten in der Partition \"multiple\" des bwUniClusters aus. Die Datensätze können über ein Command-Line-Argument des Python-Skripts spezifiziert werden, z.B. `mpirun python psrs.py --dataset no_duplicates`. Messen und vergleichen Sie die Laufzeiten. Was fällt Ihnen auf?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "691eeb17", + "metadata": {}, + "outputs": [], + "source": [ + "#!/bin/bash\n", + "\n", + "#SBATCH --job-name=psrs # Job name\n", + "#SBATCH --partition=dev_multiple # Queue for the resource allocation.\n", + "#SBATCH --time=5:00 # Wall-clock time limit \n", + "#SBATCH --mem=90000 # Memory per node\n", + "#SBATCH --nodes=4 # Number of nodes to be used\n", + "#SBATCH --cpus-per-task=40 # Number of CPUs per MPI task\n", + "#SBATCH --ntasks-per-node=1 # Number of tasks per node\n", + "#SBATCH --mail-type=ALL # Notify user by email when certain event types occur.\n", + "\n", + "export IBV_FORK_SAFE=1\n", + "export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}\n", + "export VENVDIR=<path/to/your/venv/folder> # Export path to your virtual environment.\n", + "export PYDIR=<path/to/your/python/script> # Export path to directory containing Python script.\n", + "\n", + "# Set up modules.\n", + "module purge # Unload all currently loaded modules.\n", + "module load compiler/gnu/13.3 # Load required modules.\n", + "module load mpi/openmpi/4.1\n", + "module load devel/cuda/12.4\n", + "module load lib/hdf5/1.14.4-gnu-13.3-openmpi-4.1\n", + "\n", + "source ${VENVDIR}/bin/activate # Activate your virtual environment.\n", + "\n", + "mpirun --mca mpi_warn_on_fork 0 python ${PYDIR}/psrs.py --dataset no_duplicates # Specify dataset via command-line argument." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "529e07b3", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Parallel Sorting by Regular Sampling\"\"\"\n", + "import argparse\n", + "import time\n", + "\n", + "import h5py\n", + "import numpy\n", + "import torch\n", + "from mpi4py import MPI\n", + "\n", + "__mpi_type_mappings = {\n", + " torch.bool: MPI.BOOL,\n", + " torch.uint8: MPI.UNSIGNED_CHAR,\n", + " torch.int8: MPI.SIGNED_CHAR,\n", + " torch.int16: MPI.SHORT,\n", + " torch.int32: MPI.INT,\n", + " torch.int64: MPI.LONG,\n", + " torch.bfloat16: MPI.INT16_T,\n", + " torch.float16: MPI.INT16_T,\n", + " torch.float32: MPI.FLOAT,\n", + " torch.float64: MPI.DOUBLE,\n", + " torch.complex64: MPI.COMPLEX,\n", + " torch.complex128: MPI.DOUBLE_COMPLEX,\n", + "}\n", + "\n", + "\n", + "def sort(a: torch.Tensor, comm: MPI.Comm = MPI.COMM_WORLD) -> torch.Tensor:\n", + " \"\"\"\n", + " Sort a's elements along given dimension in ascending order by their value.\n", + "\n", + " The sorting is not stable which means that equal elements in the result may have different ordering than in\n", + " original array.\n", + "\n", + " Parameters\n", + " ----------\n", + " a : torch.Tensor\n", + " The 1D input array to be sorted.\n", + " comm : MPI.Comm\n", + " The communicator to use.\n", + "\n", + " Returns\n", + " -------\n", + " torch.Tensor\n", + " Sorted local results.\n", + " \"\"\"\n", + " size, rank = comm.size, comm.rank\n", + "\n", + " if size == 1:\n", + " local_sorted, _ = torch.sort(a)\n", + " return local_sorted\n", + "\n", + " else:\n", + " ###########\n", + " # PHASE 1 #\n", + " ###########\n", + " # p is comm size, n is overall number of samples.\n", + " # Each rank sorts its local chunk and chooses p regular samples as representatives.\n", + " if rank == 0:\n", + " print(\n", + " \"###########\\n\"\n", + " \"# PHASE 1 #\\n\"\n", + " \"###########\"\n", + " )\n", + " local_sorted, local_indices = torch.sort(a)\n", + " print(f\"Rank {rank}/{size}: Local sorting done...[OK]\")\n", + "\n", + " n_local = torch.tensor(\n", + " torch.numel(local_sorted), dtype=torch.int\n", + " ) # Number of elements in local chunk.\n", + " print(f\"Rank {rank}/{size}: Number of elements in local chunk is {n_local}.\")\n", + " counts = torch.zeros(\n", + " size, dtype=torch.int\n", + " ) # Initialize array for local element numbers.\n", + " comm.Allgather([n_local, MPI.INT], [counts, MPI.INT])\n", + "\n", + " # Each rank chooses p regular samples.\n", + " # For this, separate sorted tensor into p+1 equal-length partitions.\n", + " # Regular samples have indices 1, w+1, 2w+1,...,(p−1)w+1\n", + " # where w=n/p^2 (here: `size` = p, `n_local` = overall number of samples/p).\n", + " partitions = [int(x * n_local / size) for x in range(0, size)]\n", + "\n", + " reg_samples_local = local_sorted[partitions]\n", + " assert len(partitions) == size\n", + " print(\n", + " f\"Rank {rank}/{size}: There are {len(partitions)} local regular samples: {reg_samples_local}\"\n", + " )\n", + "\n", + " # Root gathers regular samples.\n", + " num_regs_global = int(\n", + " comm.allreduce(torch.numel(reg_samples_local), op=MPI.SUM)\n", + " ) # Get overall number of regular samples.\n", + " if rank == 0:\n", + " print(f\"Overall number of regular samples is {num_regs_global}.\")\n", + " reg_samples_global = torch.zeros(num_regs_global, dtype=a.dtype)\n", + " comm.Gather(reg_samples_local, reg_samples_global, root=0)\n", + " if rank == 0:\n", + " print(\"On root: Regular samples gathered...[OK]\")\n", + "\n", + " ###########\n", + " # PHASE 2 #\n", + " ###########\n", + " # Root sorts gathered regular samples, chooses pivots, and shares them with other processes.\n", + " if rank == 0:\n", + " print(\n", + " \"###########\\n\"\n", + " \"# PHASE 2 #\\n\"\n", + " \"###########\"\n", + " )\n", + " global_pivots = torch.zeros((size - 1,), dtype=local_sorted.dtype)\n", + " if rank == 0:\n", + " sorted_regs_global, _ = torch.sort(reg_samples_global)\n", + " print(f\"On root: Regular samples are {sorted_regs_global}.\")\n", + " # Choose p-1 pivot indices (p = MPI size).\n", + " global_partitions = [\n", + " int(x * num_regs_global / size) for x in range(1, size)\n", + " ]\n", + " global_pivots = sorted_regs_global[global_partitions]\n", + " if len(global_partitions) == size - 1:\n", + " print(\n", + " f\"On root: There are {len(global_partitions)} global pivots: {global_pivots}\"\n", + " )\n", + " comm.Bcast(\n", + " global_pivots, root=0\n", + " ) # Broadcast copy of pivots to all processes from root.\n", + " if rank == 0:\n", + " print(\"Pivots broadcast to all processes...\")\n", + " ###########\n", + " # PHASE 3 #\n", + " ###########\n", + " if rank == 0:\n", + " print(\n", + " \"###########\\n\"\n", + " \"# PHASE 3 #\\n\"\n", + " \"###########\\n\"\n", + " \"Each processor forms p disjunct partitions of locally sorted elements using pivots as splits.\"\n", + " )\n", + " # Each processor forms p disjunct partitions of locally sorted elements using p-1 pivots as splits.\n", + " lt_partitions = torch.empty((size, local_sorted.shape[0]), dtype=torch.int64)\n", + " last = torch.zeros_like(local_sorted, dtype=torch.int64)\n", + " # Iterate over all pivots and store index of first pivot greater than element's value\n", + " if rank == 0:\n", + " print(\"Iterate over pivots to find index of first pivot > element's value.\")\n", + "\n", + " for idx, p in enumerate(global_pivots):\n", + " # torch.lt(input, other, *, out=None) computes `input < other` element-wise.\n", + " # Returns boolean tensor that is True where input is less than other and False elsewhere.\n", + " lt = torch.lt(local_sorted, p).int()\n", + " if idx > 0:\n", + " lt_partitions[idx] = lt - last\n", + " else:\n", + " lt_partitions[idx] = lt\n", + " last = lt\n", + " lt_partitions[size - 1] = torch.ones_like(local_sorted, dtype=last.dtype) - last\n", + "\n", + " # lt_partitions contains p elements, first encodes which elements in local_sorted are smaller than 1st\n", + " # (= smallest) pivot, second encodes which elements are larger than 1st and smaller than 2nd pivot, ..., last\n", + " # elements encodes which elements are larger than last ( = largest) pivot. Now set up matrix holding info how\n", + " # many values will be sent for each partition. Processor i keeps ith partitions and sends jth partition to\n", + " # processor j.\n", + " local_partitions = torch.sum(\n", + " lt_partitions, dim=1\n", + " ) # How many values will be sent where (local)?\n", + " print(\n", + " f\"Rank {rank}/{size}: Local # elements to be sent to other ranks (keep own section): {local_partitions}\"\n", + " )\n", + " partition_matrix = torch.zeros_like(\n", + " local_partitions\n", + " ) # How many values will be sent where (global)?\n", + " comm.Allreduce(local_partitions, partition_matrix, op=MPI.SUM)\n", + " if rank == 0:\n", + " print(\n", + " f\"Global # of elements on all ranks (partition matrix): {partition_matrix}\"\n", + " )\n", + " # Matrix holding info which value will be shipped where.\n", + " index_matrix = torch.empty_like(local_sorted, dtype=torch.int64)\n", + " # Loop over `lt_partitions` (binary encoding of which elements is in which partition formed by pivots).\n", + " for i, x in enumerate(lt_partitions):\n", + " index_matrix[x > 0] = i\n", + " # Elements in 0th partition (< first pivot) get 0, i.e., will be collected at rank 0, elements in 1st\n", + " # partition (> than first + < than second pivot) get 1, i.e., will be collected at rank 1,...\n", + " print(f\"Rank {rank}/{size}: Ship element to rank: {index_matrix}\")\n", + " send_counts_local = numpy.zeros(size, dtype=int)\n", + " for s in numpy.arange(size):\n", + " send_counts_local[s] = int((index_matrix == s).sum(dim=0))\n", + " send_displ = numpy.zeros(size, dtype=int)\n", + " send_displ[1:] = numpy.cumsum(send_counts_local, axis=0)[:-1]\n", + " send_counts_global = numpy.zeros((size, size), dtype=int)\n", + " comm.Allgather([send_counts_local, MPI.INT], [send_counts_global, MPI.INT])\n", + " recv_counts_global = numpy.transpose(send_counts_global)\n", + " recv_counts_local = recv_counts_global[rank]\n", + " recv_displ = numpy.zeros(size, dtype=int)\n", + " recv_displ[1:] = numpy.cumsum(recv_counts_local, axis=0)[:-1]\n", + " # Counts + displacements for Alltoallv are rank-specific!\n", + " # send_counts_local on rank i: Integer array, entry j specifies number of values to be sent to rank j.\n", + " # recv_counts_local on rank i: Integer array, entry j specifies number of values to be received from rank j.\n", + " val_buf = torch.zeros((partition_matrix[rank],), dtype=local_sorted.dtype)\n", + "\n", + " send_buf = [\n", + " MPI.memory.fromaddress(local_sorted.data_ptr(), 0),\n", + " (send_counts_local.tolist(), send_displ.tolist()),\n", + " __mpi_type_mappings[local_sorted.dtype],\n", + " ]\n", + " recv_buf = [\n", + " MPI.memory.fromaddress(val_buf.data_ptr(), 0),\n", + " (recv_counts_local.tolist(), recv_displ.tolist()),\n", + " __mpi_type_mappings[val_buf.dtype],\n", + " ]\n", + " comm.Alltoallv(send_buf, recv_buf)\n", + " result, _ = torch.sort(val_buf)\n", + " return result\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " data_path = \"/pfs/work7/workspace/scratch/ku4408-VL-ScalableAI/data/psrs_data.h5\"\n", + "\n", + " parser = argparse.ArgumentParser(prog=\"Parallel Sorting by Regular Samples\")\n", + " parser.add_argument(\n", + " \"--dataset\",\n", + " type=str,\n", + " default=\"duplicates_1\",\n", + " help=\"The dataset to be sorted.\",\n", + " )\n", + "\n", + " args = parser.parse_args()\n", + "\n", + " comm = MPI.COMM_WORLD\n", + " rank, size = comm.rank, comm.size\n", + "\n", + " with h5py.File(data_path, \"r\") as f:\n", + " chunk = int(f[args.dataset].shape[0] / size)\n", + " if rank == size - 1:\n", + " data = torch.tensor(f[args.dataset][rank * chunk:])\n", + " else:\n", + " data = torch.tensor(f[args.dataset][rank * chunk:(rank + 1) * chunk])\n", + "\n", + " if rank == 0:\n", + " print(\n", + " \"########\\n\"\n", + " \"# PSRS #\\n\"\n", + " \"########\"\n", + " )\n", + "\n", + " print(f\"Local data on rank {rank} = {data}\")\n", + "\n", + " if rank == 0:\n", + " print(\"Start sorting...\")\n", + " start = time.perf_counter()\n", + " result = sort(data)\n", + " elapsed_local = time.perf_counter() - start\n", + " elapsed_global = comm.allreduce(elapsed_local, op=MPI.SUM)\n", + " elapsed_global /= size\n", + " if rank == 0:\n", + " print(f\"Sorting done...\\nRank-averaged run time: {elapsed_global} s\")\n", + " print(f\"Sorted chunk on rank {rank}/{size}: {result}\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}