From 0e4f5c9d11c82d2ba48e649a671e8778937b895c Mon Sep 17 00:00:00 2001 From: Marie Weiel <marie.weiel@kit.edu> Date: Mon, 16 Dec 2024 19:05:43 +0100 Subject: [PATCH] modify communication --- 3_ensembles/sheet_3.ipynb | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/3_ensembles/sheet_3.ipynb b/3_ensembles/sheet_3.ipynb index 92c6487..dadb819 100644 --- a/3_ensembles/sheet_3.ipynb +++ b/3_ensembles/sheet_3.ipynb @@ -539,8 +539,7 @@ " train_counts = n_train_samples_local * np.ones(\n", " size, dtype=int\n", " ) # Determine load-balanced counts and displacements.\n", - " for idx in range(remainder_train):\n", - " train_counts[idx] += 1\n", + " train_counts[:remainder_train] += 1\n", " train_displs = np.concatenate(\n", " (np.zeros(1, dtype=int), np.cumsum(train_counts, dtype=int)[:-1]), dtype=int\n", " )\n", @@ -570,7 +569,7 @@ " ]\n", "\n", " else:\n", - " train_counts = None\n", + " train_counts = np.empty(size, dtype=int)\n", " n_features = None\n", " n_test_samples = None\n", " send_buf_train_samples = None\n", @@ -581,7 +580,7 @@ " ## BROADCAST NUMBER OF TEST SAMPLES FROM ROOT TO ALL OTHERS.\n", " ## n_test_samples = ...\n", " ## BROADCAST TRAIN COUNTS FROM ROOT TO ALL OTHERS.\n", - " ## train_counts = ...\n", + " ## ...\n", " \n", " samples_train_local = np.empty((train_counts[rank], n_features), dtype=float)\n", " targets_train_local = np.empty((train_counts[rank],), dtype=float)\n", -- GitLab