Skip to content

Commit

Permalink
Merge pull request #2 from shtopane/kk/add-cache-config
Browse files Browse the repository at this point in the history
Add configuration for cache
  • Loading branch information
gboehl authored Sep 4, 2024
2 parents d7ae0eb + cb2cb27 commit a276464
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 1 deletion.
2 changes: 1 addition & 1 deletion econpizza/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .solvers.solve_linear_state_space import solve_linear_state_space, find_path_linear_state_space
from .solvers.shooting import find_path_shooting
from .parser import parse, load

from .config import config

# set number of cores for XLA
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"
Expand Down
80 changes: 80 additions & 0 deletions econpizza/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Configuration object"""

import os
import jax

class EconPizzaConfig(dict):
def __init__(self, *args, **kwargs):
super(EconPizzaConfig, self).__init__(*args, **kwargs)
self._enable_persistent_cache = False
self._econpizza_cache_folder = "__econpizza_cache__"
self._jax_cache_folder = "__jax_cache__"

@property
def enable_persistent_cache(self):
return self._enable_persistent_cache

@enable_persistent_cache.setter
def enable_persistent_cache(self, value):
self._enable_persistent_cache = value
self.setup_persistent_cache()

@property
def jax_cache_folder(self):
return self._jax_cache_folder

@jax_cache_folder.setter
def jax_cache_folder(self, value):
self._jax_cache_folder = value

@property
def econpizza_cache_folder(self):
return self._econpizza_cache_folder

@econpizza_cache_folder.setter
def econpizza_cache_folder(self, value):
self._econpizza_cache_folder = value

def update(self, key, value):
if hasattr(self, key):
setattr(self, key, value)
else:
raise AttributeError(f"'EconPizzaConfig' object has no attribute '{key}'")

def _create_cache_dir(self, folder_name: str):
cwd = os.getcwd()
folder_path = os.path.join(cwd, folder_name)
os.makedirs(folder_path, exist_ok=True)

return folder_path

def setup_persistent_cache(self):
"""Create folders for JAX and EconPizza cache.
By default, they are created in callee working directory.
"""
if self.enable_persistent_cache == True:
if not os.path.exists(self.econpizza_cache_folder):
folder_path_pizza = self._create_cache_dir(self.econpizza_cache_folder)
self.econpizza_cache_folder = folder_path_pizza
else:
folder_path_pizza = self.econpizza_cache_folder

# Jax cache is enabled by the used via JAX API. In this case we should not set another folder
if jax.config.jax_compilation_cache_dir is None:
folder_path_jax = self._create_cache_dir(self.jax_cache_folder)
jax.config.update("jax_compilation_cache_dir", folder_path_jax)
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
self.jax_cache_folder = folder_path_jax

def __repr__(self):
properties = {
k.lstrip("_"): v for k, v in self.__dict__.items() if k.startswith("_")
}
return f"{properties}"

def __str__(self):
return self.__repr__()


config = EconPizzaConfig()
97 changes: 97 additions & 0 deletions econpizza/testing/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pytest
import jax
from unittest.mock import patch
import shutil
import os
import sys
# autopep8: off
sys.path.insert(0, os.path.abspath("."))
import econpizza as ep
from econpizza.config import EconPizzaConfig
# autopep8: on

@pytest.fixture(scope="function", autouse=True)
def ep_config_reset():
ep.config = EconPizzaConfig()

@pytest.fixture(scope="function", autouse=True)
def os_getcwd_create():
folder_path = "./config_working_dir"

if not os.path.exists(folder_path):
os.makedirs(folder_path)

with patch("os.getcwd", return_value="./config_working_dir"):
yield

if os.path.exists(folder_path):
shutil.rmtree(folder_path)

def test_config_default_values():
assert ep.config.enable_persistent_cache == False
assert ep.config.econpizza_cache_folder == "__econpizza_cache__"
assert ep.config.jax_cache_folder == "__jax_cache__"

def test_config_jax_default_values():
assert jax.config.values["jax_compilation_cache_dir"] is None
assert jax.config.values["jax_persistent_cache_min_entry_size_bytes"] == .0
assert jax.config.values["jax_persistent_cache_min_compile_time_secs"] == 1.0

@patch("os.makedirs")
@patch("jax.config.update")
def test_config_enable_persistent_cache(mock_jax_update, mock_makedirs):
ep.config.enable_persistent_cache = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True)
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True)

mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__"))
mock_jax_update.assert_any_call("jax_persistent_cache_min_entry_size_bytes", -1)
mock_jax_update.assert_any_call("jax_persistent_cache_min_compile_time_secs", 0)

@patch("os.makedirs")
@patch("jax.config.update")
def test_config_set_econpizza_folder(mock_jax_update, mock_makedirs):
ep.config.econpizza_cache_folder = "test1"
ep.config.enable_persistent_cache = True

mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True)
mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "__jax_cache__"))

@patch("os.makedirs")
@patch("jax.config.update")
def test_config_set_jax_folder(mock_jax_update, mock_makedirs):
ep.config.jax_cache_folder = "test1"
ep.config.enable_persistent_cache = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "test1"), exist_ok=True)
mock_jax_update.assert_any_call("jax_compilation_cache_dir", os.path.join(os.getcwd(), "test1"))

@patch("jax.config.update")
def test_config_jax_folder_set_from_outside(mock_jax_update):
mock_jax_update("jax_compilation_cache_dir", "jax_from_outside")
ep.config.enable_persistent_cache = True
mock_jax_update.assert_any_call("jax_compilation_cache_dir", "jax_from_outside")

@patch("os.makedirs")
@patch("jax.config.update")
def test_econpizza_cache_folder_not_created_second_time(mock_jax_update, mock_makedirs):
ep.config.enable_persistent_cache = True
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True)
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True)

ep.config.enable_persistent_cache = True
# only jax config is updated
mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__econpizza_cache__"), exist_ok=True)

mock_makedirs.assert_any_call(os.path.join(os.getcwd(), "__jax_cache__"), exist_ok=True)

@patch("jax.config.update")
def test_config_enable_persistent_cache_called_after_model_load(mock_jax_update):
_ = ep.load(ep.examples.dsge)

assert os.path.exists(ep.config.econpizza_cache_folder) == False
ep.config.enable_persistent_cache = True
assert os.path.exists(ep.config.econpizza_cache_folder) == True




0 comments on commit a276464

Please sign in to comment.