From fd12e3468ca3ad34ce777e7f2ab7bd3d78c8feb1 Mon Sep 17 00:00:00 2001 From: shtopane Date: Wed, 11 Sep 2024 18:59:48 +0300 Subject: [PATCH 1/5] add another option for jax persistent cache; update tests --- econpizza/config.py | 25 ++++++++--- econpizza/testing/test_config.py | 76 +++++++++++++++++++++++--------- 2 files changed, 74 insertions(+), 27 deletions(-) diff --git a/econpizza/config.py b/econpizza/config.py index 1690ef2..bf301b3 100644 --- a/econpizza/config.py +++ b/econpizza/config.py @@ -7,8 +7,9 @@ class EconPizzaConfig(dict): def __init__(self, *args, **kwargs): super(EconPizzaConfig, self).__init__(*args, **kwargs) self._enable_persistent_cache = False - self._econpizza_cache_folder = "__econpizza_cache__" + self._enable_jax_persistent_cache = False self._jax_cache_folder = "__jax_cache__" + self._econpizza_cache_folder = "__econpizza_cache__" @property def enable_persistent_cache(self): @@ -19,6 +20,15 @@ def enable_persistent_cache(self, value): self._enable_persistent_cache = value self.setup_persistent_cache() + @property + def enable_jax_persistent_cache(self): + return self._enable_jax_persistent_cache + + @enable_jax_persistent_cache.setter + def enable_jax_persistent_cache(self, value): + self._enable_jax_persistent_cache = value + self.setup_persistent_cache_jax() + @property def jax_cache_folder(self): return self._jax_cache_folder @@ -49,8 +59,9 @@ def _create_cache_dir(self, folder_name: str): return folder_path def setup_persistent_cache(self): - """Create folders for JAX and EconPizza cache. - By default, they are created in callee working directory. + """Create folder the econpizza cache. + Exported functions via JAX export will be saved there. + By default, it is created in callee working directory. """ if self.enable_persistent_cache == True: if not os.path.exists(self.econpizza_cache_folder): @@ -59,8 +70,12 @@ def setup_persistent_cache(self): else: folder_path_pizza = self.econpizza_cache_folder - # Jax cache is enabled by the used via JAX API. In this case we should not set another folder - if jax.config.jax_compilation_cache_dir is None: + def setup_persistent_cache_jax(self): + """Setup JAX persistent cache. + By default, it is created in callee working directory. + """ + if self.enable_jax_persistent_cache == True: + if jax.config.jax_compilation_cache_dir is None and not os.path.exists(self.jax_cache_folder): folder_path_jax = self._create_cache_dir(self.jax_cache_folder) jax.config.update("jax_compilation_cache_dir", folder_path_jax) jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) diff --git a/econpizza/testing/test_config.py b/econpizza/testing/test_config.py index 8a91425..79de0c6 100644 --- a/econpizza/testing/test_config.py +++ b/econpizza/testing/test_config.py @@ -16,19 +16,20 @@ def ep_config_reset(): @pytest.fixture(scope="function", autouse=True) def os_getcwd_create(): - folder_path = "./config_working_dir" + test_cache_folder = os.path.abspath("config_working_dir") - if not os.path.exists(folder_path): - os.makedirs(folder_path) + if not os.path.exists(test_cache_folder): + os.makedirs(test_cache_folder) - with patch("os.getcwd", return_value="./config_working_dir"): + with patch("os.getcwd", return_value=test_cache_folder): yield - if os.path.exists(folder_path): - shutil.rmtree(folder_path) + if os.path.exists(test_cache_folder): + shutil.rmtree(test_cache_folder) def test_config_default_values(): assert ep.config.enable_persistent_cache == False + assert ep.config.enable_jax_persistent_cache == False assert ep.config.econpizza_cache_folder == "__econpizza_cache__" assert ep.config.jax_cache_folder == "__jax_cache__" @@ -38,10 +39,16 @@ def test_config_jax_default_values(): assert jax.config.values["jax_persistent_cache_min_compile_time_secs"] == 1.0 @patch("os.makedirs") -@patch("jax.config.update") -def test_config_enable_persistent_cache(mock_jax_update, mock_makedirs): +def test_config_enable_persistent_cache(setup_jax_mock, mock_makedirs): ep.config.enable_persistent_cache = True + assert setup_jax_mock.call_count == 0 mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) + + +@patch("os.makedirs") +@patch("jax.config.update") +def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs): + ep.config.enable_jax_persistent_cache = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__")) @@ -49,49 +56,74 @@ def test_config_enable_persistent_cache(mock_jax_update, mock_makedirs): mock_jax_update.assert_any_call("jax_persistent_cache_min_compile_time_secs", 0) @patch("os.makedirs") -@patch("jax.config.update") -def test_config_set_econpizza_folder(mock_jax_update, mock_makedirs): +def test_config_set_econpizza_folder(mock_makedirs): ep.config.econpizza_cache_folder = "test1" ep.config.enable_persistent_cache = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True) - mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__")) @patch("os.makedirs") @patch("jax.config.update") def test_config_set_jax_folder(mock_jax_update, mock_makedirs): ep.config.jax_cache_folder = "test1" - ep.config.enable_persistent_cache = True + ep.config.enable_jax_persistent_cache = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True) mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "test1")) @patch("jax.config.update") def test_config_jax_folder_set_from_outside(mock_jax_update): mock_jax_update("jax_compilation_cache_dir", "jax_from_outside") - ep.config.enable_persistent_cache = True + ep.config.enable_jax_persistent_cache = True mock_jax_update.assert_any_call("jax_compilation_cache_dir", "jax_from_outside") +@patch("os.path.exists") @patch("os.makedirs") -@patch("jax.config.update") -def test_econpizza_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs): - ep.config.enable_persistent_cache = True - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) +def test_econpizza_cache_folder_not_created_second_time(mock_makedirs, mock_exists): + # Set to first return False when the folder is not created, then True when the folder is created + mock_exists.side_effect = [False, True] + # When called for the first time, a cache folder should be created(default is __econpizza_cache__) ep.config.enable_persistent_cache = True - # only jax config is updated mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) + # Now reset the mock so that the calls are 0 again. + mock_makedirs.reset_mock() + # The second time we should not create a folder + ep.config.enable_persistent_cache = True + mock_makedirs.assert_not_called() - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) - +@patch("os.path.exists") +@patch("os.makedirs") @patch("jax.config.update") -def test_config_enable_persistent_cache_called_after_model_load(mock_jax_update): +def test_jax_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs, mock_exists): + # Set to first return False when the folder is not created, then True when the folder is created + mock_exists.side_effect = [False, True] + + # When called for the first time, a cache folder should be created(default is __jax_cache__) + ep.config.enable_jax_persistent_cache = True + mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) + assert mock_jax_update.call_count == 3 + # Now reset the mock so that the calls are 0 again. + mock_makedirs.reset_mock() + mock_jax_update.reset_mock() + # The second time we should not create a folder + ep.config.enable_jax_persistent_cache = True + mock_makedirs.assert_not_called() + assert mock_jax_update.call_count == 0 + +def test_config_enable_persistent_cache_called_after_model_load(): _ = ep.load(ep.examples.dsge) assert os.path.exists(ep.config.econpizza_cache_folder) == False ep.config.enable_persistent_cache = True assert os.path.exists(ep.config.econpizza_cache_folder) == True +def test_config_enable_jax_persistent_cache_called_after_model_load(): + _ = ep.load(ep.examples.dsge) + + assert os.path.exists(ep.config.jax_cache_folder) == False + ep.config.enable_jax_persistent_cache = True + assert os.path.exists(ep.config.jax_cache_folder) == True + From 4ee66dfb0ab28470852e6b8e44629ea75e3a6452 Mon Sep 17 00:00:00 2001 From: shtopane Date: Wed, 11 Sep 2024 20:49:25 +0300 Subject: [PATCH 2/5] fix test and add comment --- econpizza/testing/test_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/econpizza/testing/test_config.py b/econpizza/testing/test_config.py index 79de0c6..73b3ff6 100644 --- a/econpizza/testing/test_config.py +++ b/econpizza/testing/test_config.py @@ -1,3 +1,4 @@ +"""Tests for the config module. Delete any __econpizza__ or __jax_cache__ folders you might have in the current folder before running""" import pytest import jax from unittest.mock import patch @@ -39,9 +40,8 @@ def test_config_jax_default_values(): assert jax.config.values["jax_persistent_cache_min_compile_time_secs"] == 1.0 @patch("os.makedirs") -def test_config_enable_persistent_cache(setup_jax_mock, mock_makedirs): +def test_config_enable_persistent_cache(mock_makedirs): ep.config.enable_persistent_cache = True - assert setup_jax_mock.call_count == 0 mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) From 863ab572b82c66367d74531bf1f13b87d5d304b3 Mon Sep 17 00:00:00 2001 From: shtopane Date: Mon, 16 Sep 2024 11:44:44 +0300 Subject: [PATCH 3/5] use hasattr and setattr in update. Use a map of property:method to call setup functions. --- econpizza/config.py | 100 +++++++++---------------------- econpizza/testing/test_config.py | 34 +++++------ 2 files changed, 46 insertions(+), 88 deletions(-) diff --git a/econpizza/config.py b/econpizza/config.py index bf301b3..a049207 100644 --- a/econpizza/config.py +++ b/econpizza/config.py @@ -1,95 +1,53 @@ -"""Configuration object""" - import os import jax class EconPizzaConfig(dict): def __init__(self, *args, **kwargs): super(EconPizzaConfig, self).__init__(*args, **kwargs) - self._enable_persistent_cache = False - self._enable_jax_persistent_cache = False - self._jax_cache_folder = "__jax_cache__" - self._econpizza_cache_folder = "__econpizza_cache__" - - @property - def enable_persistent_cache(self): - return self._enable_persistent_cache - - @enable_persistent_cache.setter - def enable_persistent_cache(self, value): - self._enable_persistent_cache = value - self.setup_persistent_cache() - - @property - def enable_jax_persistent_cache(self): - return self._enable_jax_persistent_cache - - @enable_jax_persistent_cache.setter - def enable_jax_persistent_cache(self, value): - self._enable_jax_persistent_cache = value - self.setup_persistent_cache_jax() + self.__dict__ = self + self.enable_persistent_cache = False + self.enable_jax_persistent_cache = False + self.jax_cache_folder = "__jax_cache__" + self.econpizza_cache_folder = "__econpizza_cache__" + + self._setup_persistent_cache_map = { + # "enable_persistent_cache": self.setup_persistent_cache, + "enable_jax_persistent_cache": self.setup_persistent_cache_jax + } - @property - def jax_cache_folder(self): - return self._jax_cache_folder - - @jax_cache_folder.setter - def jax_cache_folder(self, value): - self._jax_cache_folder = value + def __setitem__(self, key, value): + return self.update(key, value) - @property - def econpizza_cache_folder(self): - return self._econpizza_cache_folder - - @econpizza_cache_folder.setter - def econpizza_cache_folder(self, value): - self._econpizza_cache_folder = value - def update(self, key, value): + """Updates the attribute, and if it's related to caching, calls the appropriate setup function.""" if hasattr(self, key): setattr(self, key, value) + if key in self._setup_persistent_cache_map and value: + self._setup_persistent_cache_map[key]() else: raise AttributeError(f"'EconPizzaConfig' object has no attribute '{key}'") - + def _create_cache_dir(self, folder_name: str): cwd = os.getcwd() folder_path = os.path.join(cwd, folder_name) os.makedirs(folder_path, exist_ok=True) - return folder_path def setup_persistent_cache(self): - """Create folder the econpizza cache. - Exported functions via JAX export will be saved there. - By default, it is created in callee working directory. - """ - if self.enable_persistent_cache == True: - if not os.path.exists(self.econpizza_cache_folder): - folder_path_pizza = self._create_cache_dir(self.econpizza_cache_folder) - self.econpizza_cache_folder = folder_path_pizza - else: - folder_path_pizza = self.econpizza_cache_folder + """Create econpizza cache folder. If caching is enabled, sets up the cache.""" + if not os.path.exists(self.econpizza_cache_folder): + folder_path_pizza = self._create_cache_dir(self.econpizza_cache_folder) + self.econpizza_cache_folder = folder_path_pizza + else: + folder_path_pizza = self.econpizza_cache_folder def setup_persistent_cache_jax(self): - """Setup JAX persistent cache. - By default, it is created in callee working directory. - """ - if self.enable_jax_persistent_cache == True: - if jax.config.jax_compilation_cache_dir is None and not os.path.exists(self.jax_cache_folder): - folder_path_jax = self._create_cache_dir(self.jax_cache_folder) - jax.config.update("jax_compilation_cache_dir", folder_path_jax) - jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) - jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) - self.jax_cache_folder = folder_path_jax - - def __repr__(self): - properties = { - k.lstrip("_"): v for k, v in self.__dict__.items() if k.startswith("_") - } - return f"{properties}" - - def __str__(self): - return self.__repr__() - + """Setup JAX persistent cache if enabled.""" + if jax.config.jax_compilation_cache_dir is None and not os.path.exists(self.jax_cache_folder): + folder_path_jax = self._create_cache_dir(self.jax_cache_folder) + jax.config.update("jax_compilation_cache_dir", folder_path_jax) + jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) + jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) + self.jax_cache_folder = folder_path_jax config = EconPizzaConfig() diff --git a/econpizza/testing/test_config.py b/econpizza/testing/test_config.py index 73b3ff6..d245b16 100644 --- a/econpizza/testing/test_config.py +++ b/econpizza/testing/test_config.py @@ -29,8 +29,8 @@ def os_getcwd_create(): shutil.rmtree(test_cache_folder) def test_config_default_values(): - assert ep.config.enable_persistent_cache == False - assert ep.config.enable_jax_persistent_cache == False + assert ep.config["enable_persistent_cache"] == False + assert ep.config["enable_jax_persistent_cache"] == False assert ep.config.econpizza_cache_folder == "__econpizza_cache__" assert ep.config.jax_cache_folder == "__jax_cache__" @@ -40,15 +40,16 @@ def test_config_jax_default_values(): assert jax.config.values["jax_persistent_cache_min_compile_time_secs"] == 1.0 @patch("os.makedirs") +@pytest.mark.skip(reason="Skipping until enable_persistent_cache gets exposed for end users") def test_config_enable_persistent_cache(mock_makedirs): - ep.config.enable_persistent_cache = True + ep.config["enable_persistent_cache"] = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) @patch("os.makedirs") @patch("jax.config.update") def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs): - ep.config.enable_jax_persistent_cache = True + ep.config["enable_jax_persistent_cache"] = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__")) @@ -56,9 +57,10 @@ def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs): mock_jax_update.assert_any_call("jax_persistent_cache_min_compile_time_secs", 0) @patch("os.makedirs") +@pytest.mark.skip(reason="Skipping until enable_persistent_cache gets exposed for end users") def test_config_set_econpizza_folder(mock_makedirs): ep.config.econpizza_cache_folder = "test1" - ep.config.enable_persistent_cache = True + ep.config["enable_persistent_cache"] = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True) @@ -66,29 +68,30 @@ def test_config_set_econpizza_folder(mock_makedirs): @patch("jax.config.update") def test_config_set_jax_folder(mock_jax_update, mock_makedirs): ep.config.jax_cache_folder = "test1" - ep.config.enable_jax_persistent_cache = True + ep.config["enable_jax_persistent_cache"] = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True) mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "test1")) @patch("jax.config.update") def test_config_jax_folder_set_from_outside(mock_jax_update): mock_jax_update("jax_compilation_cache_dir", "jax_from_outside") - ep.config.enable_jax_persistent_cache = True + ep.config["enable_jax_persistent_cache"] = True mock_jax_update.assert_any_call("jax_compilation_cache_dir", "jax_from_outside") @patch("os.path.exists") @patch("os.makedirs") +@pytest.mark.skip(reason="Skipping until enable_persistent_cache gets exposed for end users") def test_econpizza_cache_folder_not_created_second_time(mock_makedirs, mock_exists): # Set to first return False when the folder is not created, then True when the folder is created mock_exists.side_effect = [False, True] # When called for the first time, a cache folder should be created(default is __econpizza_cache__) - ep.config.enable_persistent_cache = True + ep.config["enable_persistent_cache"] = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) # Now reset the mock so that the calls are 0 again. mock_makedirs.reset_mock() # The second time we should not create a folder - ep.config.enable_persistent_cache = True + ep.config["enable_persistent_cache"] = True mock_makedirs.assert_not_called() @patch("os.path.exists") @@ -99,31 +102,28 @@ def test_jax_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs mock_exists.side_effect = [False, True] # When called for the first time, a cache folder should be created(default is __jax_cache__) - ep.config.enable_jax_persistent_cache = True + ep.config["enable_jax_persistent_cache"] = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) assert mock_jax_update.call_count == 3 # Now reset the mock so that the calls are 0 again. mock_makedirs.reset_mock() mock_jax_update.reset_mock() # The second time we should not create a folder - ep.config.enable_jax_persistent_cache = True + ep.config["enable_jax_persistent_cache"] = True mock_makedirs.assert_not_called() assert mock_jax_update.call_count == 0 +@pytest.mark.skip(reason="Skipping until enable_persistent_cache gets exposed for end users") def test_config_enable_persistent_cache_called_after_model_load(): _ = ep.load(ep.examples.dsge) assert os.path.exists(ep.config.econpizza_cache_folder) == False - ep.config.enable_persistent_cache = True + ep.config["enable_persistent_cache"] = True assert os.path.exists(ep.config.econpizza_cache_folder) == True def test_config_enable_jax_persistent_cache_called_after_model_load(): _ = ep.load(ep.examples.dsge) assert os.path.exists(ep.config.jax_cache_folder) == False - ep.config.enable_jax_persistent_cache = True + ep.config["enable_jax_persistent_cache"] = True assert os.path.exists(ep.config.jax_cache_folder) == True - - - - From 789eba2d1dac838fc3c30fad98a6a5d3358d2044 Mon Sep 17 00:00:00 2001 From: shtopane Date: Mon, 16 Sep 2024 16:12:32 +0300 Subject: [PATCH 4/5] remove econpizza config --- econpizza/config.py | 11 --------- econpizza/testing/test_config.py | 41 -------------------------------- 2 files changed, 52 deletions(-) diff --git a/econpizza/config.py b/econpizza/config.py index a049207..287e785 100644 --- a/econpizza/config.py +++ b/econpizza/config.py @@ -5,13 +5,10 @@ class EconPizzaConfig(dict): def __init__(self, *args, **kwargs): super(EconPizzaConfig, self).__init__(*args, **kwargs) self.__dict__ = self - self.enable_persistent_cache = False self.enable_jax_persistent_cache = False self.jax_cache_folder = "__jax_cache__" - self.econpizza_cache_folder = "__econpizza_cache__" self._setup_persistent_cache_map = { - # "enable_persistent_cache": self.setup_persistent_cache, "enable_jax_persistent_cache": self.setup_persistent_cache_jax } @@ -33,14 +30,6 @@ def _create_cache_dir(self, folder_name: str): os.makedirs(folder_path, exist_ok=True) return folder_path - def setup_persistent_cache(self): - """Create econpizza cache folder. If caching is enabled, sets up the cache.""" - if not os.path.exists(self.econpizza_cache_folder): - folder_path_pizza = self._create_cache_dir(self.econpizza_cache_folder) - self.econpizza_cache_folder = folder_path_pizza - else: - folder_path_pizza = self.econpizza_cache_folder - def setup_persistent_cache_jax(self): """Setup JAX persistent cache if enabled.""" if jax.config.jax_compilation_cache_dir is None and not os.path.exists(self.jax_cache_folder): diff --git a/econpizza/testing/test_config.py b/econpizza/testing/test_config.py index d245b16..0f9833b 100644 --- a/econpizza/testing/test_config.py +++ b/econpizza/testing/test_config.py @@ -29,9 +29,7 @@ def os_getcwd_create(): shutil.rmtree(test_cache_folder) def test_config_default_values(): - assert ep.config["enable_persistent_cache"] == False assert ep.config["enable_jax_persistent_cache"] == False - assert ep.config.econpizza_cache_folder == "__econpizza_cache__" assert ep.config.jax_cache_folder == "__jax_cache__" def test_config_jax_default_values(): @@ -39,13 +37,6 @@ def test_config_jax_default_values(): assert jax.config.values["jax_persistent_cache_min_entry_size_bytes"] == .0 assert jax.config.values["jax_persistent_cache_min_compile_time_secs"] == 1.0 -@patch("os.makedirs") -@pytest.mark.skip(reason="Skipping until enable_persistent_cache gets exposed for end users") -def test_config_enable_persistent_cache(mock_makedirs): - ep.config["enable_persistent_cache"] = True - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) - - @patch("os.makedirs") @patch("jax.config.update") def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs): @@ -56,14 +47,6 @@ def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs): mock_jax_update.assert_any_call("jax_persistent_cache_min_entry_size_bytes", -1) mock_jax_update.assert_any_call("jax_persistent_cache_min_compile_time_secs", 0) -@patch("os.makedirs") -@pytest.mark.skip(reason="Skipping until enable_persistent_cache gets exposed for end users") -def test_config_set_econpizza_folder(mock_makedirs): - ep.config.econpizza_cache_folder = "test1" - ep.config["enable_persistent_cache"] = True - - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True) - @patch("os.makedirs") @patch("jax.config.update") def test_config_set_jax_folder(mock_jax_update, mock_makedirs): @@ -78,22 +61,6 @@ def test_config_jax_folder_set_from_outside(mock_jax_update): ep.config["enable_jax_persistent_cache"] = True mock_jax_update.assert_any_call("jax_compilation_cache_dir", "jax_from_outside") -@patch("os.path.exists") -@patch("os.makedirs") -@pytest.mark.skip(reason="Skipping until enable_persistent_cache gets exposed for end users") -def test_econpizza_cache_folder_not_created_second_time(mock_makedirs, mock_exists): - # Set to first return False when the folder is not created, then True when the folder is created - mock_exists.side_effect = [False, True] - - # When called for the first time, a cache folder should be created(default is __econpizza_cache__) - ep.config["enable_persistent_cache"] = True - mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True) - # Now reset the mock so that the calls are 0 again. - mock_makedirs.reset_mock() - # The second time we should not create a folder - ep.config["enable_persistent_cache"] = True - mock_makedirs.assert_not_called() - @patch("os.path.exists") @patch("os.makedirs") @patch("jax.config.update") @@ -113,14 +80,6 @@ def test_jax_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs mock_makedirs.assert_not_called() assert mock_jax_update.call_count == 0 -@pytest.mark.skip(reason="Skipping until enable_persistent_cache gets exposed for end users") -def test_config_enable_persistent_cache_called_after_model_load(): - _ = ep.load(ep.examples.dsge) - - assert os.path.exists(ep.config.econpizza_cache_folder) == False - ep.config["enable_persistent_cache"] = True - assert os.path.exists(ep.config.econpizza_cache_folder) == True - def test_config_enable_jax_persistent_cache_called_after_model_load(): _ = ep.load(ep.examples.dsge) From 1dd31cf76d1647f1cc1dd4a66a183b0b2b405033 Mon Sep 17 00:00:00 2001 From: Krasimira Kirilova Date: Thu, 10 Oct 2024 10:54:51 +0000 Subject: [PATCH 5/5] Remove call to "jax_persistent_cache_min_entry_size_bytes" setting as the default value of 0 is the behavior we want. --- econpizza/config.py | 1 - econpizza/testing/test_config.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/econpizza/config.py b/econpizza/config.py index 287e785..432ea58 100644 --- a/econpizza/config.py +++ b/econpizza/config.py @@ -35,7 +35,6 @@ def setup_persistent_cache_jax(self): if jax.config.jax_compilation_cache_dir is None and not os.path.exists(self.jax_cache_folder): folder_path_jax = self._create_cache_dir(self.jax_cache_folder) jax.config.update("jax_compilation_cache_dir", folder_path_jax) - jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) self.jax_cache_folder = folder_path_jax diff --git a/econpizza/testing/test_config.py b/econpizza/testing/test_config.py index 0f9833b..d53496b 100644 --- a/econpizza/testing/test_config.py +++ b/econpizza/testing/test_config.py @@ -44,7 +44,6 @@ def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs): mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__")) - mock_jax_update.assert_any_call("jax_persistent_cache_min_entry_size_bytes", -1) mock_jax_update.assert_any_call("jax_persistent_cache_min_compile_time_secs", 0) @patch("os.makedirs") @@ -71,7 +70,7 @@ def test_jax_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs # When called for the first time, a cache folder should be created(default is __jax_cache__) ep.config["enable_jax_persistent_cache"] = True mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True) - assert mock_jax_update.call_count == 3 + assert mock_jax_update.call_count == 2 # Now reset the mock so that the calls are 0 again. mock_makedirs.reset_mock() mock_jax_update.reset_mock()