diff --git a/econpizza/config.py b/econpizza/config.py index 1690ef2..432ea58 100644 --- a/econpizza/config.py +++ b/econpizza/config.py @@ -1,80 +1,41 @@ -"""Configuration object""" - import os import jax class EconPizzaConfig(dict): def __init__(self, *args, **kwargs): super(EconPizzaConfig, self).__init__(*args, **kwargs) - self._enable_persistent_cache = False - self._econpizza_cache_folder = "__econpizza_cache__" - self._jax_cache_folder = "__jax_cache__" - - @property - def enable_persistent_cache(self): - return self._enable_persistent_cache + self.__dict__ = self + self.enable_jax_persistent_cache = False + self.jax_cache_folder = "__jax_cache__" - @enable_persistent_cache.setter - def enable_persistent_cache(self, value): - self._enable_persistent_cache = value - self.setup_persistent_cache() + self._setup_persistent_cache_map = { + "enable_jax_persistent_cache": self.setup_persistent_cache_jax + } - @property - def jax_cache_folder(self): - return self._jax_cache_folder - - @jax_cache_folder.setter - def jax_cache_folder(self, value): - self._jax_cache_folder = value + def __setitem__(self, key, value): + return self.update(key, value) - @property - def econpizza_cache_folder(self): - return self._econpizza_cache_folder - - @econpizza_cache_folder.setter - def econpizza_cache_folder(self, value): - self._econpizza_cache_folder = value - def update(self, key, value): + """Updates the attribute, and if it's related to caching, calls the appropriate setup function.""" if hasattr(self, key): setattr(self, key, value) + if key in self._setup_persistent_cache_map and value: + self._setup_persistent_cache_map[key]() else: raise AttributeError(f"'EconPizzaConfig' object has no attribute '{key}'") - + def _create_cache_dir(self, folder_name: str): cwd = os.getcwd() folder_path = os.path.join(cwd, folder_name) os.makedirs(folder_path, exist_ok=True) - return folder_path - def setup_persistent_cache(self): - """Create folders for JAX and EconPizza cache. - By default, they are created in callee working directory. - """ - if self.enable_persistent_cache == True: - if not os.path.exists(self.econpizza_cache_folder): - folder_path_pizza = self._create_cache_dir(self.econpizza_cache_folder) - self.econpizza_cache_folder = folder_path_pizza - else: - folder_path_pizza = self.econpizza_cache_folder - - # Jax cache is enabled by the used via JAX API. In this case we should not set another folder - if jax.config.jax_compilation_cache_dir is None: - folder_path_jax = self._create_cache_dir(self.jax_cache_folder) - jax.config.update("jax_compilation_cache_dir", folder_path_jax) - jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) - jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) - self.jax_cache_folder = folder_path_jax - - def __repr__(self): - properties = { - k.lstrip("_"): v for k, v in self.__dict__.items() if k.startswith("_") - } - return f"{properties}" - - def __str__(self): - return self.__repr__() - + def setup_persistent_cache_jax(self): + """Setup JAX persistent cache if enabled.""" + if jax.config.jax_compilation_cache_dir is None and not os.path.exists(self.jax_cache_folder): + folder_path_jax = self._create_cache_dir(self.jax_cache_folder) + jax.config.update("jax_compilation_cache_dir", folder_path_jax) + jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) + self.jax_cache_folder = folder_path_jax config = EconPizzaConfig() diff --git a/econpizza/testing/test_config.py b/econpizza/testing/test_config.py index 8a91425..d53496b 100644 --- a/econpizza/testing/test_config.py +++ b/econpizza/testing/test_config.py @@ -1,3 +1,4 @@ +"""Tests for the config module. Delete any __econpizza__ or __jax_cache__ folders you might have in the current folder before running""" import pytest import jax from unittest.mock import patch @@ -16,20 +17,19 @@ def ep_config_reset(): @pytest.fixture(scope="function", autouse=True) def os_getcwd_create(): - folder_path = "./config_working_dir" + test_cache_folder = os.path.abspath("config_working_dir") - if not os.path.exists(folder_path): - os.makedirs(folder_path) + if not os.path.exists(test_cache_folder): + os.makedirs(test_cache_folder) - with patch("os.getcwd", return_value="./config_working_dir"): + with patch("os.getcwd", return_value=test_cache_folder): yield - if os.path.exists(folder_path): - shutil.rmtree(folder_path) + if os.path.exists(test_cache_folder): + shutil.rmtree(test_cache_folder) def test_config_default_values(): - assert ep.config.enable_persistent_cache == False - assert ep.config.econpizza_cache_folder == "__econpizza_cache__" + assert ep.config["enable_jax_persistent_cache"] == False assert ep.config.jax_cache_folder == "__jax_cache__" def test_config_jax_default_values(): @@ -39,59 +39,49 @@ def test_config_jax_default_values(): @patch("os.makedirs") @patch("jax.config.update") -def test_config_enable_persistent_cache(mock_jax_update, mock_makedirs): - ep.config.enable_persistent_cache = True - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) +def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs): + ep.config["enable_jax_persistent_cache"] = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__")) - mock_jax_update.assert_any_call("jax_persistent_cache_min_entry_size_bytes", -1) mock_jax_update.assert_any_call("jax_persistent_cache_min_compile_time_secs", 0) -@patch("os.makedirs") -@patch("jax.config.update") -def test_config_set_econpizza_folder(mock_jax_update, mock_makedirs): - ep.config.econpizza_cache_folder = "test1" - ep.config.enable_persistent_cache = True - - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True) - mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__")) - @patch("os.makedirs") @patch("jax.config.update") def test_config_set_jax_folder(mock_jax_update, mock_makedirs): ep.config.jax_cache_folder = "test1" - ep.config.enable_persistent_cache = True + ep.config["enable_jax_persistent_cache"] = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True) mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "test1")) @patch("jax.config.update") def test_config_jax_folder_set_from_outside(mock_jax_update): mock_jax_update("jax_compilation_cache_dir", "jax_from_outside") - ep.config.enable_persistent_cache = True + ep.config["enable_jax_persistent_cache"] = True mock_jax_update.assert_any_call("jax_compilation_cache_dir", "jax_from_outside") +@patch("os.path.exists") @patch("os.makedirs") @patch("jax.config.update") -def test_econpizza_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs): - ep.config.enable_persistent_cache = True - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) - - ep.config.enable_persistent_cache = True - # only jax config is updated - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) - - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) - -@patch("jax.config.update") -def test_config_enable_persistent_cache_called_after_model_load(mock_jax_update): +def test_jax_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs, mock_exists): + # Set to first return False when the folder is not created, then True when the folder is created + mock_exists.side_effect = [False, True] + + # When called for the first time, a cache folder should be created(default is __jax_cache__) + ep.config["enable_jax_persistent_cache"] = True + mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) + assert mock_jax_update.call_count == 2 + # Now reset the mock so that the calls are 0 again. + mock_makedirs.reset_mock() + mock_jax_update.reset_mock() + # The second time we should not create a folder + ep.config["enable_jax_persistent_cache"] = True + mock_makedirs.assert_not_called() + assert mock_jax_update.call_count == 0 + +def test_config_enable_jax_persistent_cache_called_after_model_load(): _ = ep.load(ep.examples.dsge) - assert os.path.exists(ep.config.econpizza_cache_folder) == False - ep.config.enable_persistent_cache = True - assert os.path.exists(ep.config.econpizza_cache_folder) == True - - - - + assert os.path.exists(ep.config.jax_cache_folder) == False + ep.config["enable_jax_persistent_cache"] = True + assert os.path.exists(ep.config.jax_cache_folder) == True