From 0e4f5c9d11c82d2ba48e649a671e8778937b895c Mon Sep 17 00:00:00 2001
From: Marie Weiel <marie.weiel@kit.edu>
Date: Mon, 16 Dec 2024 19:05:43 +0100
Subject: [PATCH] modify communication

---
 3_ensembles/sheet_3.ipynb | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/3_ensembles/sheet_3.ipynb b/3_ensembles/sheet_3.ipynb
index 92c6487..dadb819 100644
--- a/3_ensembles/sheet_3.ipynb
+++ b/3_ensembles/sheet_3.ipynb
@@ -539,8 +539,7 @@
     "        train_counts = n_train_samples_local * np.ones(\n",
     "            size, dtype=int\n",
     "        )  # Determine load-balanced counts and displacements.\n",
-    "        for idx in range(remainder_train):\n",
-    "            train_counts[idx] += 1\n",
+    "        train_counts[:remainder_train] += 1\n",
     "        train_displs = np.concatenate(\n",
     "            (np.zeros(1, dtype=int), np.cumsum(train_counts, dtype=int)[:-1]), dtype=int\n",
     "        )\n",
@@ -570,7 +569,7 @@
     "        ]\n",
     "\n",
     "    else:\n",
-    "        train_counts = None\n",
+    "        train_counts = np.empty(size, dtype=int)\n",
     "        n_features = None\n",
     "        n_test_samples = None\n",
     "        send_buf_train_samples = None\n",
@@ -581,7 +580,7 @@
     "    ## BROADCAST NUMBER OF TEST SAMPLES FROM ROOT TO ALL OTHERS.\n",
     "    ## n_test_samples = ...\n",
     "    ## BROADCAST TRAIN COUNTS FROM ROOT TO ALL OTHERS.\n",
-    "    ## train_counts = ...\n",
+    "    ## ...\n",
     "    \n",
     "    samples_train_local = np.empty((train_counts[rank], n_features), dtype=float)\n",
     "    targets_train_local = np.empty((train_counts[rank],), dtype=float)\n",
-- 
GitLab