From 25aa9a70fe5aca93327cae25d7ff752f6d32dbfd Mon Sep 17 00:00:00 2001 From: Antonio Stanziola Date: Tue, 17 Sep 2024 11:19:38 +0100 Subject: [PATCH 1/3] fix: bug in get_implemented in new plum --- jaxdf/util.py | 122 +++++++++++++++++++++++++------------------------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/jaxdf/util.py b/jaxdf/util.py index 2d05d66..e7e9697 100644 --- a/jaxdf/util.py +++ b/jaxdf/util.py @@ -1,61 +1,61 @@ -import warnings - -from jax.numpy import expand_dims, ndarray - - -def append_dimension(x: ndarray): - return expand_dims(x, -1) - - -def update_dictionary(old: dict, new_entries: dict): - r"""Update a dictionary with new entries. - - Args: - old (dict): The dictionary to update - new_entries (dict): The new entries to add to the dictionary - - Returns: - dict: The updated dictionary - """ - for key, val in zip(new_entries.keys(), new_entries.values()): - old[key] = val - return old - - -def _get_implemented(f): - warnings.warn( - "jaxdf.util._get_implemented is deprecated. Use jaxdf.util.get_implemented instead.", - DeprecationWarning, - ) - return get_implemented(f) - - -def get_implemented(f): - r"""Prints the implemented methods of an operator - - Arguments: - f (Callable): The operator to get the implemented methods of. - - Returns: - None - - """ - - # TODO: Why there are more instances for the same types? - - print(f.__name__ + ":") - instances = [] - a = f.methods - for f_instance in a: - # Get types - types = f_instance.types - - # Change each type with its classname - types = tuple(map(lambda x: x.__name__, types)) - - # Append - instances.append(str(types)) - - instances = set(instances) - for instance in instances: - print(" ─ " + instance) +import warnings + +from jax.numpy import expand_dims, ndarray + + +def append_dimension(x: ndarray): + return expand_dims(x, -1) + + +def update_dictionary(old: dict, new_entries: dict): + r"""Update a dictionary with new entries. + + Args: + old (dict): The dictionary to update + new_entries (dict): The new entries to add to the dictionary + + Returns: + dict: The updated dictionary + """ + for key, val in zip(new_entries.keys(), new_entries.values()): + old[key] = val + return old + + +def _get_implemented(f): + warnings.warn( + "jaxdf.util._get_implemented is deprecated. Use jaxdf.util.get_implemented instead.", + DeprecationWarning, + ) + return get_implemented(f) + + +def get_implemented(f): + r"""Prints the implemented methods of an operator + + Arguments: + f (Callable): The operator to get the implemented methods of. + + Returns: + None + + """ + + # TODO: Why there are more instances for the same types? + + print(f.__name__ + ":") + instances = [] + a = f.methods + for f_instance in a: + # Get types + types = f_instance.signature.types + + # Change each type with its classname + types = tuple(map(lambda x: x.__name__, types)) + + # Append + instances.append(str(types)) + + instances = set(instances) + for instance in instances: + print(" ─ " + instance) From 2a907cc03c6c5175b1f0587ebb11910b3dc521ca Mon Sep 17 00:00:00 2001 From: Antonio Stanziola Date: Tue, 17 Sep 2024 11:24:25 +0100 Subject: [PATCH 2/3] updated plum dependency --- CHANGELOG.md | 3 +- pyproject.toml | 240 ++++++++++++++++++++++++------------------------- 2 files changed, 122 insertions(+), 121 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d80607..0b9158a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Fixed +- Fixed `util.get_implemented` bug that was happening with the new version of `plum` ## [0.2.7] - 2023-11-24 ### Changed @@ -66,4 +68,3 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [0.2.7]: https://github.com/ucl-bug/jaxdf/compare/0.2.6...0.2.7 [0.2.6]: https://github.com/ucl-bug/jaxdf/compare/0.2.5...0.2.6 [0.2.5]: https://github.com/ucl-bug/jaxdf/tree/0.2.5 - diff --git a/pyproject.toml b/pyproject.toml index fd26eb8..42344aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,120 +1,120 @@ -[tool.poetry] -name = "jaxdf" -version = "0.2.7" -description = "A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations" -authors = [ - "Antonio Stanziola ", - "Simon Arridge", - "Ben T. Cox", - "Bradley E. Treeby", -] -readme = "README.md" -keywords = [ - "jax", - "pde", - "discretization", - "differential equations", - "simulation", - "differentiable programming", -] -license = "LGPL-3.0-only" -classifiers=[ - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Environment :: GPU", - "Environment :: GPU :: NVIDIA CUDA", - "Environment :: GPU :: NVIDIA CUDA :: 11.6", - "Environment :: GPU :: NVIDIA CUDA :: 11.7", - "Environment :: GPU :: NVIDIA CUDA :: 11.8", - "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.0", - "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Scientific/Engineering :: Physics", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", -] - -packages = [ - { include="jaxdf", from="." } -] - -[tool.poetry.urls] -"Homepage" = "https://ucl-bug.github.io/jaxdf" -"Repository" = "https://github.com/ucl-bug/jaxdf" -"Bug Tracker" = "https://github.com/ucl-bug/jaxdf/issues" -"Support" = "https://discord.gg/VtUb4fFznt" - -[tool.poetry.dependencies] -python = "^3.9" -plum-dispatch = "^2.2.2" -jax = "^0.4.20" -equinox = "^0.11.2" - -[tool.poetry.group.dev.dependencies] -coverage = "^7.3.2" -mypy = "^1.4.0" -pre-commit = "^3.3.3" -mkdocs-material-extensions = "^1.3.1" -mkdocs-material = "^9.4.12" -mkdocs-jupyter = "^0.24.6" -mkdocs-autorefs = "^0.5.0" -mkdocs-mermaid2-plugin = "^0.6.0" -mkdocstrings-python = "^1.7.5" -isort = "^5.12.0" -pycln = "^2.4.0" -python-kacl = "^0.4.6" -mkdocs-macros-plugin = "^1.0.5" -pymdown-extensions = "^10.4" -pytest = "^7.4.0" -plumkdocs = "^0.0.5" -jupyterlab = "^4.0.9" - -[build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" - -[tools.isort] -src_paths = ["jaxdf", "tests"] -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true - -[tool.pycln] -all = true - -[tool.mypy] -disallow_any_unimported = true -disallow_untyped_defs = true -no_implicit_optional = true -strict_equality = true -warn_unused_ignores = true -warn_redundant_casts = true -warn_return_any = true -check_untyped_defs = true -show_error_codes = true -ignore_missing_imports = true -allow_redefinition = true -exclude = ['jaxdf/operators/'] - -[tool.yapf] -based_on_style = "pep8" -spaces_before_comment = 4 -split_before_logical_operator = true -indent_width = 2 - -[tool.pytest.ini_options] -addopts = """\ - --doctest-modules \ -""" - -[tool.coverage.report] -exclude_lines = [ - 'if TYPE_CHECKING:', - 'pragma: no cover' -] +[tool.poetry] +name = "jaxdf" +version = "0.2.7" +description = "A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations" +authors = [ + "Antonio Stanziola ", + "Simon Arridge", + "Ben T. Cox", + "Bradley E. Treeby", +] +readme = "README.md" +keywords = [ + "jax", + "pde", + "discretization", + "differential equations", + "simulation", + "differentiable programming", +] +license = "LGPL-3.0-only" +classifiers=[ + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Environment :: GPU", + "Environment :: GPU :: NVIDIA CUDA", + "Environment :: GPU :: NVIDIA CUDA :: 11.6", + "Environment :: GPU :: NVIDIA CUDA :: 11.7", + "Environment :: GPU :: NVIDIA CUDA :: 11.8", + "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.0", + "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Physics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", +] + +packages = [ + { include="jaxdf", from="." } +] + +[tool.poetry.urls] +"Homepage" = "https://ucl-bug.github.io/jaxdf" +"Repository" = "https://github.com/ucl-bug/jaxdf" +"Bug Tracker" = "https://github.com/ucl-bug/jaxdf/issues" +"Support" = "https://discord.gg/VtUb4fFznt" + +[tool.poetry.dependencies] +python = "^3.9" +plum-dispatch = "^2.5.2" +jax = "^0.4.20" +equinox = "^0.11.2" + +[tool.poetry.group.dev.dependencies] +coverage = "^7.3.2" +mypy = "^1.4.0" +pre-commit = "^3.3.3" +mkdocs-material-extensions = "^1.3.1" +mkdocs-material = "^9.4.12" +mkdocs-jupyter = "^0.24.6" +mkdocs-autorefs = "^0.5.0" +mkdocs-mermaid2-plugin = "^0.6.0" +mkdocstrings-python = "^1.7.5" +isort = "^5.12.0" +pycln = "^2.4.0" +python-kacl = "^0.4.6" +mkdocs-macros-plugin = "^1.0.5" +pymdown-extensions = "^10.4" +pytest = "^7.4.0" +plumkdocs = "^0.0.5" +jupyterlab = "^4.0.9" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tools.isort] +src_paths = ["jaxdf", "tests"] +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true + +[tool.pycln] +all = true + +[tool.mypy] +disallow_any_unimported = true +disallow_untyped_defs = true +no_implicit_optional = true +strict_equality = true +warn_unused_ignores = true +warn_redundant_casts = true +warn_return_any = true +check_untyped_defs = true +show_error_codes = true +ignore_missing_imports = true +allow_redefinition = true +exclude = ['jaxdf/operators/'] + +[tool.yapf] +based_on_style = "pep8" +spaces_before_comment = 4 +split_before_logical_operator = true +indent_width = 2 + +[tool.pytest.ini_options] +addopts = """\ + --doctest-modules \ +""" + +[tool.coverage.report] +exclude_lines = [ + 'if TYPE_CHECKING:', + 'pragma: no cover' +] From 9c1629ffd2b1da93e2ab7e85b81498486801a168 Mon Sep 17 00:00:00 2001 From: Antonio Stanziola Date: Tue, 17 Sep 2024 11:30:50 +0100 Subject: [PATCH 3/3] removed util._get_implemented --- CHANGELOG.md | 3 +++ jaxdf/util.py | 10 ---------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b9158a..f89ecd1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Fixed - Fixed `util.get_implemented` bug that was happening with the new version of `plum` +### Removed +- Removed the deprecated `util._get_implemented` function + ## [0.2.7] - 2023-11-24 ### Changed - The Quickstart tutorial has been upgdated. diff --git a/jaxdf/util.py b/jaxdf/util.py index e7e9697..57d315d 100644 --- a/jaxdf/util.py +++ b/jaxdf/util.py @@ -1,5 +1,3 @@ -import warnings - from jax.numpy import expand_dims, ndarray @@ -22,14 +20,6 @@ def update_dictionary(old: dict, new_entries: dict): return old -def _get_implemented(f): - warnings.warn( - "jaxdf.util._get_implemented is deprecated. Use jaxdf.util.get_implemented instead.", - DeprecationWarning, - ) - return get_implemented(f) - - def get_implemented(f): r"""Prints the implemented methods of an operator