Skip to content

Commit

Permalink
Merge pull request #5 from shtopane/kk/detach-config-jax-econpizza
Browse files Browse the repository at this point in the history
Update config to split `jax` and `econpizza` cache
  • Loading branch information
shtopane authored Oct 10, 2024
2 parents a276464 + 1dd31cf commit 60130ee
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 101 deletions.
77 changes: 19 additions & 58 deletions econpizza/config.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,41 @@
"""Configuration object"""

import os
import jax

class EconPizzaConfig(dict):
def __init__(self, *args, **kwargs):
super(EconPizzaConfig, self).__init__(*args, **kwargs)
self._enable_persistent_cache = False
self._econpizza_cache_folder = "__econpizza_cache__"
self._jax_cache_folder = "__jax_cache__"

@property
def enable_persistent_cache(self):
return self._enable_persistent_cache
self.__dict__ = self
self.enable_jax_persistent_cache = False
self.jax_cache_folder = "__jax_cache__"

@enable_persistent_cache.setter
def enable_persistent_cache(self, value):
self._enable_persistent_cache = value
self.setup_persistent_cache()
self._setup_persistent_cache_map = {
"enable_jax_persistent_cache": self.setup_persistent_cache_jax
}

@property
def jax_cache_folder(self):
return self._jax_cache_folder

@jax_cache_folder.setter
def jax_cache_folder(self, value):
self._jax_cache_folder = value
def __setitem__(self, key, value):
return self.update(key, value)

@property
def econpizza_cache_folder(self):
return self._econpizza_cache_folder

@econpizza_cache_folder.setter
def econpizza_cache_folder(self, value):
self._econpizza_cache_folder = value

def update(self, key, value):
"""Updates the attribute, and if it's related to caching, calls the appropriate setup function."""
if hasattr(self, key):
setattr(self, key, value)
if key in self._setup_persistent_cache_map and value:
self._setup_persistent_cache_map[key]()
else:
raise AttributeError(f"'EconPizzaConfig' object has no attribute '{key}'")

def _create_cache_dir(self, folder_name: str):
cwd = os.getcwd()
folder_path = os.path.join(cwd, folder_name)
os.makedirs(folder_path, exist_ok=True)

return folder_path

def setup_persistent_cache(self):
"""Create folders for JAX and EconPizza cache.
By default, they are created in callee working directory.
"""
if self.enable_persistent_cache == True:
if not os.path.exists(self.econpizza_cache_folder):
folder_path_pizza = self._create_cache_dir(self.econpizza_cache_folder)
self.econpizza_cache_folder = folder_path_pizza
else:
folder_path_pizza = self.econpizza_cache_folder

# Jax cache is enabled by the used via JAX API. In this case we should not set another folder
if jax.config.jax_compilation_cache_dir is None:
folder_path_jax = self._create_cache_dir(self.jax_cache_folder)
jax.config.update("jax_compilation_cache_dir", folder_path_jax)
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
self.jax_cache_folder = folder_path_jax

def __repr__(self):
properties = {
k.lstrip("_"): v for k, v in self.__dict__.items() if k.startswith("_")
}
return f"{properties}"

def __str__(self):
return self.__repr__()

def setup_persistent_cache_jax(self):
"""Setup JAX persistent cache if enabled."""
if jax.config.jax_compilation_cache_dir is None and not os.path.exists(self.jax_cache_folder):
folder_path_jax = self._create_cache_dir(self.jax_cache_folder)
jax.config.update("jax_compilation_cache_dir", folder_path_jax)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
self.jax_cache_folder = folder_path_jax

config = EconPizzaConfig()
76 changes: 33 additions & 43 deletions econpizza/testing/test_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Tests for the config module. Delete any __econpizza__ or __jax_cache__ folders you might have in the current folder before running"""
import pytest
import jax
from unittest.mock import patch
Expand All @@ -16,20 +17,19 @@ def ep_config_reset():

@pytest.fixture(scope="function", autouse=True)
def os_getcwd_create():
folder_path = "./config_working_dir"
test_cache_folder = os.path.abspath("config_working_dir")

if not os.path.exists(folder_path):
os.makedirs(folder_path)
if not os.path.exists(test_cache_folder):
os.makedirs(test_cache_folder)

with patch("os.getcwd", return_value="./config_working_dir"):
with patch("os.getcwd", return_value=test_cache_folder):
yield

if os.path.exists(folder_path):
shutil.rmtree(folder_path)
if os.path.exists(test_cache_folder):
shutil.rmtree(test_cache_folder)

def test_config_default_values():
assert ep.config.enable_persistent_cache == False
assert ep.config.econpizza_cache_folder == "__econpizza_cache__"
assert ep.config["enable_jax_persistent_cache"] == False
assert ep.config.jax_cache_folder == "__jax_cache__"

def test_config_jax_default_values():
Expand All @@ -39,59 +39,49 @@ def test_config_jax_default_values():

@patch("os.makedirs")
@patch("jax.config.update")
def test_config_enable_persistent_cache(mock_jax_update, mock_makedirs):
ep.config.enable_persistent_cache = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True)
def test_config_enable_jax_persistent_cache(mock_jax_update, mock_makedirs):
ep.config["enable_jax_persistent_cache"] = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True)

mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__"))
mock_jax_update.assert_any_call("jax_persistent_cache_min_entry_size_bytes", -1)
mock_jax_update.assert_any_call("jax_persistent_cache_min_compile_time_secs", 0)

@patch("os.makedirs")
@patch("jax.config.update")
def test_config_set_econpizza_folder(mock_jax_update, mock_makedirs):
ep.config.econpizza_cache_folder = "test1"
ep.config.enable_persistent_cache = True

mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True)
mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__"))

@patch("os.makedirs")
@patch("jax.config.update")
def test_config_set_jax_folder(mock_jax_update, mock_makedirs):
ep.config.jax_cache_folder = "test1"
ep.config.enable_persistent_cache = True
ep.config["enable_jax_persistent_cache"] = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True)
mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "test1"))

@patch("jax.config.update")
def test_config_jax_folder_set_from_outside(mock_jax_update):
mock_jax_update("jax_compilation_cache_dir", "jax_from_outside")
ep.config.enable_persistent_cache = True
ep.config["enable_jax_persistent_cache"] = True
mock_jax_update.assert_any_call("jax_compilation_cache_dir", "jax_from_outside")

@patch("os.path.exists")
@patch("os.makedirs")
@patch("jax.config.update")
def test_econpizza_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs):
ep.config.enable_persistent_cache = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True)
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True)

ep.config.enable_persistent_cache = True
# only jax config is updated
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True)

mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True)

@patch("jax.config.update")
def test_config_enable_persistent_cache_called_after_model_load(mock_jax_update):
def test_jax_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs, mock_exists):
# Set to first return False when the folder is not created, then True when the folder is created
mock_exists.side_effect = [False, True]

# When called for the first time, a cache folder should be created(default is __jax_cache__)
ep.config["enable_jax_persistent_cache"] = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True)
assert mock_jax_update.call_count == 2
# Now reset the mock so that the calls are 0 again.
mock_makedirs.reset_mock()
mock_jax_update.reset_mock()
# The second time we should not create a folder
ep.config["enable_jax_persistent_cache"] = True
mock_makedirs.assert_not_called()
assert mock_jax_update.call_count == 0

def test_config_enable_jax_persistent_cache_called_after_model_load():
_ = ep.load(ep.examples.dsge)

assert os.path.exists(ep.config.econpizza_cache_folder) == False
ep.config.enable_persistent_cache = True
assert os.path.exists(ep.config.econpizza_cache_folder) == True




assert os.path.exists(ep.config.jax_cache_folder) == False
ep.config["enable_jax_persistent_cache"] = True
assert os.path.exists(ep.config.jax_cache_folder) == True

0 comments on commit 60130ee

Please sign in to comment.