From 7fdb96c2d5f7daacb948fd68ed0ba5a49fbb436c Mon Sep 17 00:00:00 2001 From: ulysse <ulysse.chosson@obspm.fr> Date: Tue, 21 Nov 2023 14:03:24 +0100 Subject: [PATCH] refoctor getting env var so that it can be reused in other projects --- exodam_utils/__init__.py | 2 +- exodam_utils/config/__init__.py | 2 +- exodam_utils/config/env_var.py | 94 ++++++++++++++++++------------ exodam_utils/connect.py | 6 +- tests/utils/config/conftest.py | 9 +++ tests/utils/config/test_env_var.py | 46 ++++++++++++--- 6 files changed, 109 insertions(+), 50 deletions(-) diff --git a/exodam_utils/__init__.py b/exodam_utils/__init__.py index b0c1ccf..45284e8 100644 --- a/exodam_utils/__init__.py +++ b/exodam_utils/__init__.py @@ -1,7 +1,7 @@ """Exodam utils module.""" # Local imports -from .config import get_conn_info_form_env # noqa: F401 +from .config import ExodamDbConnectionInfo # noqa: F401 from .connect import conn_manager, db_connect # noqa: F401 from .const import BASE_DIR, EXODICT_TYPE, REGEX_GS, REGEX_UNIT # noqa: F401 from .exception import ( # noqa: F401 diff --git a/exodam_utils/config/__init__.py b/exodam_utils/config/__init__.py index 2ef8268..4d55086 100644 --- a/exodam_utils/config/__init__.py +++ b/exodam_utils/config/__init__.py @@ -1,4 +1,4 @@ """Config module.""" # Local imports -from .env_var import DbConnectionInfo, get_conn_info_form_env # noqa: F401 +from .env_var import DbConnectionInfo, ExodamDbConnectionInfo # noqa: F401 diff --git a/exodam_utils/config/env_var.py b/exodam_utils/config/env_var.py index 4325a70..fce81ae 100644 --- a/exodam_utils/config/env_var.py +++ b/exodam_utils/config/env_var.py @@ -3,7 +3,6 @@ # Standard imports from functools import lru_cache from os import getenv -from typing import cast # Third party imports from dotenv import load_dotenv @@ -22,6 +21,7 @@ class DbConnectionInfo: host: Host of the database port: Port of the database. db_name: Database name. + _raise_msg: Raise messages for the deferred raise method. """ user: str @@ -29,47 +29,67 @@ class DbConnectionInfo: host: str port: str db_name: str - - def __init__(self) -> None: + _raise_msg: list[str] + + def _getenv_defer_fail(self, key: str) -> str: + """ + Get env variable of defer fail. + + Try to get given environment variable if is empty of doesn't exist, stock name + of the variable to raise an error after. + + Args: + key: Name of the env var. + + Returns: + The env var or None if is empty or doesn't exist. + """ + env_var = getenv(key) + + if env_var is None or not env_var: + self._raise_msg.append(key) + return "" + + return env_var + + def _raise_defered_getenv_failure(self) -> None: + """Raise deferred missing environment variable failure.""" + msg = f"Missing environment variable(s): {', '.join(self._raise_msg)}" + raise MissingEnvironmentVariableError(msg) + + # PLR0913 = Too many arguments to function call + def __init__( # noqa: PLR0913 + self, + user_env_var_name: str, + password_env_var_name: str, + host_env_var_name: str, + port_env_var_name: str, + database_env_var_name: str, + ) -> None: """Initialize a DbConnectionInfo instance.""" load_dotenv() + self._raise_msg = [] - local_db_user = getenv("LOCAL_DB_USER") - local_db_password = getenv("LOCAL_DB_PASSWORD") - local_db_host = getenv("LOCAL_DB_HOST") - local_db_port = getenv("LOCAL_DB_PORT") - local_db_database_name = getenv("LOCAL_DB_DATABASE_NAME") - - name2var = { - "LOCAL_DB_USER": local_db_user, - "LOCAL_DB_PASSWORD": local_db_password, - "LOCAL_DB_HOST": local_db_host, - "LOCAL_DB_PORT": local_db_port, - "LOCAL_DB_DATABASE_NAME": local_db_database_name, - } - - missing_vars = [ - name for name, val in name2var.items() if val is None or not val - ] + self.user = self._getenv_defer_fail(user_env_var_name) + self.password = self._getenv_defer_fail(password_env_var_name) + self.host = self._getenv_defer_fail(host_env_var_name) + self.port = self._getenv_defer_fail(port_env_var_name) + self.db_name = self._getenv_defer_fail(database_env_var_name) - if missing_vars: - var_msg = "variable" if len(missing_vars) == 1 else "variables" - msg = f"Missing environment {var_msg}: {', '.join(missing_vars)}" - raise MissingEnvironmentVariableError(msg) - - self.user = cast(str, local_db_user) - self.password = cast(str, local_db_password) - self.host = cast(str, local_db_host) - self.port = cast(str, local_db_port) - self.db_name = cast(str, local_db_database_name) + if self._raise_msg: + self._raise_defered_getenv_failure() @lru_cache(maxsize=1) -def get_conn_info_form_env() -> DbConnectionInfo: - """ - Get connection information from environment variables. +class ExodamDbConnectionInfo(DbConnectionInfo): + """Singleton for exodam database connection information.""" - Returns: - Connection information. - """ - return DbConnectionInfo() + def __init__(self) -> None: + """Initialize an ExodamDbConnectionInfo instance.""" + super().__init__( + "LOCAL_DB_USER", + "LOCAL_DB_PASSWORD", + "LOCAL_DB_HOST", + "LOCAL_DB_PORT", + "LOCAL_DB_DATABASE_NAME", + ) diff --git a/exodam_utils/connect.py b/exodam_utils/connect.py index 5aeb989..3c8edcd 100644 --- a/exodam_utils/connect.py +++ b/exodam_utils/connect.py @@ -10,10 +10,10 @@ from psycopg import Connection from py_linq_sql import connect as py_linq_sql_connect # Local imports -from .config import DbConnectionInfo, get_conn_info_form_env +from .config import ExodamDbConnectionInfo -def db_connect(conn_info: DbConnectionInfo) -> Connection: +def db_connect(conn_info: ExodamDbConnectionInfo) -> Connection: """ Connect to a database. @@ -47,7 +47,7 @@ def conn_manager(conn: Connection | None) -> Generator[Connection, None, None]: # No cover because it's use when we are not in tests. if not we_have_a_given_connection: # pragma: no cover - conn_info = get_conn_info_form_env() + conn_info = ExodamDbConnectionInfo() connection = db_connect(conn_info) yield cast(Connection, connection) diff --git a/tests/utils/config/conftest.py b/tests/utils/config/conftest.py index 183ec9e..0041206 100644 --- a/tests/utils/config/conftest.py +++ b/tests/utils/config/conftest.py @@ -8,6 +8,9 @@ from os import getenv # Third party imports from dotenv import load_dotenv +# First party imports +from exodam_utils import ExodamDbConnectionInfo + def _get_old_env_vars() -> dict[str, str]: load_dotenv() @@ -54,3 +57,9 @@ def empty_some_environment_vars(): yield _set_env_vars_with_old(old_env_vars) + + +@pytest.fixture(scope="function") +def ExodamDbConnectionInfo_lru_cache_clear(): + yield + ExodamDbConnectionInfo.cache_clear() diff --git a/tests/utils/config/test_env_var.py b/tests/utils/config/test_env_var.py index ae8d00d..f296e22 100644 --- a/tests/utils/config/test_env_var.py +++ b/tests/utils/config/test_env_var.py @@ -1,19 +1,49 @@ # Pytest imports import pytest +# Third party imports +from assertpy import assert_that, soft_assertions + # First party imports -from exodam_utils import MissingEnvironmentVariableError, get_conn_info_form_env +from exodam_utils import ExodamDbConnectionInfo, MissingEnvironmentVariableError from exodam_utils.config import DbConnectionInfo -def test_DbConnectionInfo_success(dummy_environment_vars): - assert DbConnectionInfo() +def test_DbConnectionInfo_success( + dummy_environment_vars, + ExodamDbConnectionInfo_lru_cache_clear, +): + assert DbConnectionInfo( + "LOCAL_DB_USER", + "LOCAL_DB_PASSWORD", + "LOCAL_DB_HOST", + "LOCAL_DB_PORT", + "LOCAL_DB_DATABASE_NAME", + ) -def test_DbConnectionInfo_missing_env_var(empty_some_environment_vars): - with pytest.raises(MissingEnvironmentVariableError): - assert DbConnectionInfo() +def test_ExodamDbConnectionInfo( + dummy_environment_vars, + ExodamDbConnectionInfo_lru_cache_clear, +): + conn_info = ExodamDbConnectionInfo() + with soft_assertions(): + assert_that(conn_info.user).is_equal_to("test_user") + assert_that(conn_info.password).is_equal_to("dummy") + assert_that(conn_info.host).is_equal_to("test_localhost") + assert_that(conn_info.port).is_equal_to("5555") + assert_that(conn_info.db_name).is_equal_to("test_exodam") -def test_get_conn_info_form_env(): - assert get_conn_info_form_env() +def test_DbConnectionInfo_missing_env_var( + empty_some_environment_vars, + ExodamDbConnectionInfo_lru_cache_clear, +): + with pytest.raises(MissingEnvironmentVariableError): + assert DbConnectionInfo( + "LOCAL_DB_USER", + "LOCAL_DB_PASSWORD", + "LOCAL_DB_HOST", + "LOCAL_DB_PORT", + "LOCAL_DB_DATABASE_NAME", + ) -- GitLab