From 00c108ff24a4ad58ec67fc93609ba8e6cf7be7b8 Mon Sep 17 00:00:00 2001
From: Michael Simon <simon@kit.edu>
Date: Wed, 30 Oct 2024 08:14:02 +0100
Subject: [PATCH] NO_STORY add capability to filter home orgs on host basis

Also make the extra home orgs scriptable, which should help configuring
the discovery page very much.
---
 .../service/disco/DiscoveryCacheService.java  | 55 ++++++++++++++-----
 .../disco/UserProvisionerCachedEntry.java     |  2 +-
 .../disco/UserProvisionerComparator.java      | 14 +++++
 .../webreg/bean/disco/DiscoveryLoginBean.java | 31 ++++++++++-
 4 files changed, 85 insertions(+), 17 deletions(-)
 create mode 100644 bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/UserProvisionerComparator.java

diff --git a/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/DiscoveryCacheService.java b/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/DiscoveryCacheService.java
index d22b3fd51..804198fc7 100644
--- a/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/DiscoveryCacheService.java
+++ b/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/DiscoveryCacheService.java
@@ -5,7 +5,6 @@ import static edu.kit.scc.webreg.dao.ops.RqlExpressions.equal;
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.Comparator;
 import java.util.Date;
 import java.util.List;
 import java.util.Set;
@@ -141,24 +140,13 @@ public class DiscoveryCacheService implements Serializable {
 	}
 
 	public List<UserProvisionerCachedEntry> getExtraEntryList(List<ScriptEntity> filterScriptList) {
-		return filterAllEntries(filterScriptList, singleton.getExtraEntryList());
+		return filterExtraEntries(filterScriptList, singleton.getExtraEntryList(), singleton.getAllEntryList());
 	}
 
 	private List<UserProvisionerCachedEntry> filterAllEntries(List<ScriptEntity> filterScriptList,
 			List<UserProvisionerCachedEntry> entryList) {
 		if (filterScriptList != null && filterScriptList.size() > 0) {
-			Comparator<UserProvisionerCachedEntry> comparator = new Comparator<UserProvisionerCachedEntry>() {
-
-				@Override
-				public int compare(UserProvisionerCachedEntry e1, UserProvisionerCachedEntry e2) {
-					if (e1.getDisplayName() != null)
-						return e1.getDisplayName().compareTo(e2.getDisplayName());
-					else 
-						return 0;
-				}
-			};
-
-			Set<UserProvisionerCachedEntry> returnList = new TreeSet<>(comparator);
+			Set<UserProvisionerCachedEntry> returnList = new TreeSet<>(new UserProvisionerComparator());
 			for (ScriptEntity script : filterScriptList) {
 				returnList.addAll(filterEntries(script, entryList));
 			}
@@ -167,6 +155,18 @@ public class DiscoveryCacheService implements Serializable {
 			return entryList;
 	}
 
+	private List<UserProvisionerCachedEntry> filterExtraEntries(List<ScriptEntity> filterScriptList,
+			List<UserProvisionerCachedEntry> extraEntryList, List<UserProvisionerCachedEntry> allEntryList) {
+		if (filterScriptList != null && filterScriptList.size() > 0) {
+			Set<UserProvisionerCachedEntry> returnList = new TreeSet<>(new UserProvisionerComparator());
+			for (ScriptEntity script : filterScriptList) {
+				returnList.addAll(filterExtraEntries(script, extraEntryList, allEntryList));
+			}
+			return new ArrayList<>(returnList);
+		} else
+			return extraEntryList;
+	}
+
 	private List<UserProvisionerCachedEntry> filterEntries(ScriptEntity scriptEntity,
 			List<UserProvisionerCachedEntry> entryList) {
 		ScriptEngine engine = (new ScriptEngineManager()).getEngineByName(scriptEntity.getScriptEngine());
@@ -211,4 +211,31 @@ public class DiscoveryCacheService implements Serializable {
 			return entryList;
 		}
 	}
+	
+	private List<UserProvisionerCachedEntry> filterExtraEntries(ScriptEntity scriptEntity,
+			List<UserProvisionerCachedEntry> extraEntryList, List<UserProvisionerCachedEntry> allEntryList) {
+		ScriptEngine engine = (new ScriptEngineManager()).getEngineByName(scriptEntity.getScriptEngine());
+
+		if (engine == null) {
+			logger.warn("No engine set for script {}. Returning all IDPs", scriptEntity.getName());
+			return extraEntryList;
+		}
+
+		try {
+			engine.eval(scriptEntity.getScript());
+			Invocable invocable = (Invocable) engine;
+
+			try {
+				List<UserProvisionerCachedEntry> extraList = new ArrayList<>();
+				invocable.invokeFunction("filterExtra", extraEntryList, extraList, allEntryList, logger);
+				return extraList;
+			} catch (NoSuchMethodException e) {
+			}
+
+			return extraEntryList;
+		} catch (ScriptException e) {
+			logger.warn("Script execution failed.", e);
+			return extraEntryList;
+		}
+	}
 }
diff --git a/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/UserProvisionerCachedEntry.java b/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/UserProvisionerCachedEntry.java
index b02bf5945..aa29c6455 100644
--- a/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/UserProvisionerCachedEntry.java
+++ b/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/UserProvisionerCachedEntry.java
@@ -8,7 +8,7 @@ public class UserProvisionerCachedEntry implements Serializable {
 	
 	private Long id;
 	private String name;
-	// for backwards compatibility in filterIdp scipts
+	// for backwards compatibility in filterIdp scripts
 	private String entityId;
 	private String displayName;
 	private String orgName;
diff --git a/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/UserProvisionerComparator.java b/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/UserProvisionerComparator.java
new file mode 100644
index 000000000..34c88c9c0
--- /dev/null
+++ b/bwreg-service/src/main/java/edu/kit/scc/webreg/service/disco/UserProvisionerComparator.java
@@ -0,0 +1,14 @@
+package edu.kit.scc.webreg.service.disco;
+
+import java.util.Comparator;
+
+public class UserProvisionerComparator implements Comparator<UserProvisionerCachedEntry> {
+
+	@Override
+	public int compare(UserProvisionerCachedEntry e1, UserProvisionerCachedEntry e2) {
+		if (e1.getDisplayName() != null)
+			return e1.getDisplayName().compareTo(e2.getDisplayName());
+		else 
+			return 0;
+	}
+}
diff --git a/bwreg-webapp/src/main/java/edu/kit/scc/webreg/bean/disco/DiscoveryLoginBean.java b/bwreg-webapp/src/main/java/edu/kit/scc/webreg/bean/disco/DiscoveryLoginBean.java
index d3d97c82f..2d0e36dbd 100644
--- a/bwreg-webapp/src/main/java/edu/kit/scc/webreg/bean/disco/DiscoveryLoginBean.java
+++ b/bwreg-webapp/src/main/java/edu/kit/scc/webreg/bean/disco/DiscoveryLoginBean.java
@@ -16,6 +16,8 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
 
+import org.slf4j.Logger;
+
 import edu.kit.scc.webreg.bootstrap.ApplicationConfig;
 import edu.kit.scc.webreg.entity.SamlIdpConfigurationEntity;
 import edu.kit.scc.webreg.entity.SamlIdpMetadataEntity;
@@ -33,6 +35,7 @@ import edu.kit.scc.webreg.service.SamlIdpConfigurationService;
 import edu.kit.scc.webreg.service.SamlIdpMetadataService;
 import edu.kit.scc.webreg.service.SamlSpConfigurationService;
 import edu.kit.scc.webreg.service.SamlSpMetadataService;
+import edu.kit.scc.webreg.service.ScriptService;
 import edu.kit.scc.webreg.service.disco.DiscoveryCacheService;
 import edu.kit.scc.webreg.service.disco.UserProvisionerCachedEntry;
 import edu.kit.scc.webreg.service.identity.UserProvisionerService;
@@ -51,6 +54,7 @@ import jakarta.faces.view.ViewScoped;
 import jakarta.inject.Inject;
 import jakarta.inject.Named;
 import jakarta.servlet.http.Cookie;
+import jakarta.servlet.http.HttpServletRequest;
 
 @Named
 @ViewScoped
@@ -58,6 +62,9 @@ public class DiscoveryLoginBean implements Serializable {
 
 	private static final long serialVersionUID = 1L;
 
+	@Inject
+	private Logger logger;
+
 	@Inject
 	private SamlIdpMetadataService idpService;
 
@@ -103,6 +110,12 @@ public class DiscoveryLoginBean implements Serializable {
 	@Inject
 	private UserProvisionerService userProvisionerService;
 
+	@Inject
+	private ScriptService scriptService;
+
+	@Inject
+	private HttpServletRequest request;
+
 	// private Object selectedIdp;
 	private UserProvisionerCachedEntry selected;
 
@@ -211,6 +224,20 @@ public class DiscoveryLoginBean implements Serializable {
 				}
 			}
 
+			/*
+			 * filter home orgs based on hostname
+			 */
+			if (appConfig.getConfigValue(request.getServerName() + "_ds_filter") != null) {
+				ScriptEntity script = scriptService.findByAttr("name",
+						appConfig.getConfigValue(request.getServerName() + "_ds_filter"));
+				if (script != null) {
+					filterScriptList.add(script);
+				} else {
+					logger.warn("Script for filtering is set ({}), but missing",
+							appConfig.getConfigValue(request.getServerName() + "_ds_filter"));
+				}
+			}
+
 			Integer largeLimit = Integer
 					.parseInt(appConfig.getConfigValueOrDefault("discovery_large_list_threshold", "100"));
 			if (getAllList().size() > largeLimit)
@@ -355,9 +382,9 @@ public class DiscoveryLoginBean implements Serializable {
 	public Boolean getLargeList() {
 		return largeList;
 	}
-	
+
 	public void clearPanel() {
-	    this.selected = null;
+		this.selected = null;
 	}
 
 }
-- 
GitLab