From 600187e6eddd239dde9aa93f584d82effafd9bf4 Mon Sep 17 00:00:00 2001
From: Julian Keck <julian.keck9@kit.edu>
Date: Fri, 2 Aug 2024 14:24:01 +0200
Subject: [PATCH] UPD properly handle session timeouts

---
 api/__init__.py    | 33 +++++++++++++++++++++++++++++----
 model/settings.py  |  3 +++
 model/wapi/cntl.py |  8 ++++----
 util/wapi_util.py  | 17 +++++++++++++++--
 4 files changed, 51 insertions(+), 10 deletions(-)

diff --git a/api/__init__.py b/api/__init__.py
index 42a9f16..de85627 100644
--- a/api/__init__.py
+++ b/api/__init__.py
@@ -17,18 +17,43 @@ def get_db(request: Request):
     return database
 
 
+def set_connection_parameters(db_conn):
+    # client-connection-defaults setzen, falls fuer unseren hostnamen definiert:
+    get_stmt_bindings = dict({'bnd_fqdn': settings.api_client_name.rstrip('.') + '.'})
+    get_stmt = """
+    select pg_set.name, ccp.cfg_param_value
+      from netadmin.dns_ntree n
+        JOIN netadmin.cntl_clnt_conn_params ccp ON ccp.dns_ntree_key_nr = n.key_nr
+        JOIN pg_catalog.pg_settings pg_set ON pg_set.name = ccp.cfg_param_name
+      where n.fqdn = %(bnd_fqdn)s
+    """
+    set_stmt = """select pg_catalog.set_config(%(bnd_k)s, %(bnd_v)s, %(bnd_is_local)s)"""
+    # apply param setting transaction-local (only for current TA)?
+    set_ta_local = False
+
+    #
+    # eigener transaction block
+    # ta-context: implizites commit nach block-exit, implizites cursor-schliessen
+    with db_conn.transaction():
+        with db_conn.cursor() as csr:
+            csr.execute(get_stmt, get_stmt_bindings)
+            ccp_rows = csr.fetchall()
+            for ccp_row in ccp_rows:
+                (p_name, p_value) = ccp_row
+                set_stmt_bindings = dict({'bnd_k': p_name, 'bnd_v': p_value, 'bnd_is_local': set_ta_local})
+                csr.execute(set_stmt, set_stmt_bindings)
+    return
+
+
 def get_conn(request: Request):
     if hasattr(request.state, 'db_conn'):
         return request.state.db_conn
 
     db_connection = get_db(request).connect()
-    # db_connection = db.connect()
     query = "SET search_path TO netadmin, public;"
     db.execute(db_connection, query)
-    db.execute(db_connection, "SET idle_in_transaction_session_timeout TO '60s';")
-    db.execute(db_connection, "SET idle_session_timeout TO '60s';")
-    db.execute(db_connection, "SET statement_timeout TO '60s';")
     db_connection.commit()
+    set_connection_parameters(db_connection)
     request.state.db_conn = db_connection
 
     # Register callbacks to ensure transactions are properly committed after request is over
diff --git a/model/settings.py b/model/settings.py
index d3a62af..365c9e5 100644
--- a/model/settings.py
+++ b/model/settings.py
@@ -1,3 +1,4 @@
+import socket
 from typing import Optional
 
 from pydantic import Extra
@@ -61,6 +62,8 @@ class Settings(BaseSettings):
 
     null_login_name: Optional[str] = None
 
+    api_client_name: str = socket.getfqdn()
+
     # Load settings from .env
     model_config = SettingsConfigDict(env_file=".env", extra=Extra.allow)
 
diff --git a/model/wapi/cntl.py b/model/wapi/cntl.py
index 68599a3..4a18679 100644
--- a/model/wapi/cntl.py
+++ b/model/wapi/cntl.py
@@ -52,7 +52,8 @@ class Mgr(BaseModel):
 
         descr = 'Session-Token ({})'.format(datetime.datetime.now())
 
-        with get_cursor(conn) as cursor:
+        cursor = get_cursor(conn)
+        with conn.transaction():
             cursor.execute("SELECT out_token_text, out_token_gpk "
                            "FROM cntl.create_sess_auth(%(login_name)s, %(descr)s, %(stay_logged_in)s)",
                            {
@@ -61,7 +62,6 @@ class Mgr(BaseModel):
                                'descr': descr
                            })
             res = cursor.fetchone()
-        conn.commit()
 
         return APIToken(gpk=res['out_token_gpk'],
                         token=res['out_token_text'],
@@ -70,11 +70,11 @@ class Mgr(BaseModel):
 
     @staticmethod
     def check_token(conn, token: str | APIToken):
-        with get_cursor(conn) as cursor:
+        cursor = get_cursor(conn)
+        with conn.transaction() as t:
             if isinstance(token, APIToken):
                 token = token.token
             cursor.execute("select * from cntl.validate_jwt(in_token:=%(token)s)", {'token': token})
-            conn.rollback()
 
             result = cursor.fetchone()
             if result is None:
diff --git a/util/wapi_util.py b/util/wapi_util.py
index c791ac6..bdbada7 100644
--- a/util/wapi_util.py
+++ b/util/wapi_util.py
@@ -1,6 +1,8 @@
 import json
 import logging
 
+import psycopg
+
 from model.settings import settings
 from util.util import get_cursor
 
@@ -29,9 +31,18 @@ def execute_wapi_function(conn, request: list[dict[str, str | dict]], user: str
         if 'idx' not in stmt.keys():
             stmt['idx'] = stmt['name'].replace('.', '_')
 
+    # if any statement is not a list statement, we need to use a higher isolation level
+    # this is just a heuristic, but it should work in most cases and it only creates false positives.
+    maybe_contains_dml = any(not stmt['name'].endswith('.list') for stmt in request)
+
     if user is None and superuser:
         user = settings.wapi_netvs_superuser
 
+    old_isolation_level = conn.isolation_level
+    conn.isolation_level = psycopg.IsolationLevel.READ_COMMITTED
+    if maybe_contains_dml:
+        conn.isolation_level = psycopg.IsolationLevel.SERIALIZABLE
+
     query = """
                select * from wapi_4_1.ta_handler(
                  in_login_name => %(login_name)s,
@@ -44,17 +55,19 @@ def execute_wapi_function(conn, request: list[dict[str, str | dict]], user: str
             """
     result = None
     try:
-        with get_cursor(conn) as cursor:
+        cursor = get_cursor(conn)
+        with conn.transaction():
             cursor.execute(query, {
                 'login_name': user,
                 'rq': json.dumps(request),
                 'dry_mode': dry_mode,
             })
             result = cursor.fetchall()
-        conn.commit()
     except Exception as e:
         logger.error('Error in executing statement:\n\n{stmt}\n\n'.format(stmt=json.dumps(request, indent=2)))
         raise e
+    finally:
+        conn.isolation_level = old_isolation_level
 
     if result is None:
         return {}
-- 
GitLab