From 548c80d72b281d4477ec56bd1e0804f0501a4add Mon Sep 17 00:00:00 2001 From: ed cuss Date: Thu, 25 Sep 2025 20:01:26 +0100 Subject: [PATCH 1/4] feat: use hypothesis for the optimisation --- pyproject.toml | 1 + tests/domain/test_optimisation.py | 82 ++++++++++++++++++++++++++++++- uv.lock | 24 +++++++++ 3 files changed, 106 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2e73e85..e8a26c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ build-backend = "uv_build" [dependency-groups] dev = [ + "hypothesis>=6.140.2", "ipykernel>=6.29.5", "pre-commit>=4.1.0", "pytest>=8.3.4", diff --git a/tests/domain/test_optimisation.py b/tests/domain/test_optimisation.py index 8c4aa43..597d296 100644 --- a/tests/domain/test_optimisation.py +++ b/tests/domain/test_optimisation.py @@ -1,7 +1,18 @@ +import string + +import hypothesis.extra.numpy as hnp +import numpy as np import pytest +from hypothesis import given +from hypothesis import strategies as st from spaghettree.domain.adj_mat import AdjMat -from spaghettree.domain.optimisation import SuggestedMerge, get_top_suggested_merges +from spaghettree.domain.optimisation import ( + SuggestedMerge, + get_dwm, + get_top_suggested_merges, + optimise_communities, +) @pytest.mark.parametrize( @@ -82,3 +93,72 @@ def test_get_top_suggested_merges(call_tree, expected_result): res = get_top_suggested_merges(adj_mat) assert res.is_ok() assert res.unwrap() == expected_result + + +@st.composite +def st_adj_mat_and_comms(draw, max_n: int = 20, max_val: int = 20) -> tuple[np.ndarray, list[int]]: + n = draw(st.integers(min_value=1, max_value=max_n)) + + adj_mat = draw( + hnp.arrays( + dtype=np.int64, + shape=(n, n), + elements=st.integers(min_value=0, max_value=max_val), + ) + ) + + comms = draw( + st.lists( + st.integers(min_value=1, max_value=max_val), + min_size=n, + max_size=n, + ) + ) + + return adj_mat, comms + + +@given(st_adj_mat_and_comms()) +def test_get_dwm_is_within_bounds(data): + mat, comms = data + dwm = get_dwm(mat, comms) + assert -0.5 <= dwm <= 1.0 + + +def st_ent_path( + root: str = "package", + min_modules: int = 1, + max_modules: int = 3, + id_min_size: int = 3, + id_max_size: int = 3, +): + alphabet = string.ascii_lowercase + identifier = st.text(alphabet=alphabet, min_size=id_min_size, max_size=id_max_size) + + return st.builds( + lambda modules, leaf: ".".join([root, *modules, leaf]), + st.lists(identifier, min_size=min_modules, max_size=max_modules), + identifier, + ) + + +def st_call_tree(keys_count: int = 10): + keys = st.lists(st_ent_path(), min_size=2, max_size=keys_count, unique=True) + return keys.flatmap( + lambda k: st.fixed_dictionaries( + {key: st.lists(st.sampled_from(k), min_size=0, max_size=5) for key in k} + ) + ) + + +@given(st_call_tree()) +def test_does_not_produce_worse_dwm(tree): + adj_mat = AdjMat.from_call_tree(tree).unwrap() + starting_dwm = get_dwm(adj_mat.mat, adj_mat.communities) + res = optimise_communities(adj_mat) + + assert res.is_ok() + + res_adj_mat = res.unwrap() + final_dwm = get_dwm(res_adj_mat.mat, res_adj_mat.communities) + assert starting_dwm <= final_dwm diff --git a/uv.lock b/uv.lock index 7475db3..e01b801 100644 --- a/uv.lock +++ b/uv.lock @@ -1087,6 +1087,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "hypothesis" +version = "6.140.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/4a/3c340178b986b44b4f71ddb04625c8fb8bf815e7c7e23a6aabb2ce17e849/hypothesis-6.140.2.tar.gz", hash = "sha256:b3b4a162134eeef8a992621de6c43d80e03d44704a3c3bfb5b9d0661b375b0d2", size = 466699, upload-time = "2025-09-23T00:07:21.087Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/7d/7dd3684f9cb707b6b1e808c7f23dd0fa4a96fe106b6accd9b757c9985c50/hypothesis-6.140.2-py3-none-any.whl", hash = "sha256:4524cb84be90961563ef15634e2efe96150bbcce47621a13cff3c1b03a326663", size = 534388, upload-time = "2025-09-23T00:07:16.555Z" }, +] + [[package]] name = "icecream" version = "2.1.4" @@ -2996,6 +3009,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, +] + [[package]] name = "spaghettree" version = "0.2.1" @@ -3013,6 +3035,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "hypothesis" }, { name = "ipykernel" }, { name = "pre-commit" }, { name = "pytest" }, @@ -3068,6 +3091,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "hypothesis", specifier = ">=6.140.2" }, { name = "ipykernel", specifier = ">=6.29.5" }, { name = "pre-commit", specifier = ">=4.1.0" }, { name = "pytest", specifier = ">=8.3.4" }, From 189a2c0da59ad9f6b5e9b7e01e1f7dfbac4bbe75 Mon Sep 17 00:00:00 2001 From: ed cuss Date: Thu, 25 Sep 2025 20:05:26 +0100 Subject: [PATCH 2/4] feat: use hypothesis for the optimisation --- tests/domain/test_optimisation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/domain/test_optimisation.py b/tests/domain/test_optimisation.py index 597d296..84cc236 100644 --- a/tests/domain/test_optimisation.py +++ b/tests/domain/test_optimisation.py @@ -142,7 +142,7 @@ def st_ent_path( ) -def st_call_tree(keys_count: int = 10): +def st_call_tree(keys_count: int = 20): keys = st.lists(st_ent_path(), min_size=2, max_size=keys_count, unique=True) return keys.flatmap( lambda k: st.fixed_dictionaries( From 1685a1a53be57c06ffe0ce63b22d60feeae06f66 Mon Sep 17 00:00:00 2001 From: ed cuss Date: Thu, 25 Sep 2025 20:20:08 +0100 Subject: [PATCH 3/4] test: catch unwrap on err result + add logging --- src/spaghettree/adapters/io_wrapper.py | 1 + src/spaghettree/domain/processing.py | 5 +---- tests/test_result.py | 5 +++++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/spaghettree/adapters/io_wrapper.py b/src/spaghettree/adapters/io_wrapper.py index 1204181..8800a00 100644 --- a/src/spaghettree/adapters/io_wrapper.py +++ b/src/spaghettree/adapters/io_wrapper.py @@ -90,6 +90,7 @@ def write_files( print(yellow(f"File written to `{filepath}`")) # noqa: T201 results[filepath] = res.inner else: + logger.debug(f"failed to write {filepath = } {res.err_msg = }") fails[filepath] = res if ruff_root: diff --git a/src/spaghettree/domain/processing.py b/src/spaghettree/domain/processing.py index bcd480f..3dd646a 100644 --- a/src/spaghettree/domain/processing.py +++ b/src/spaghettree/domain/processing.py @@ -106,11 +106,8 @@ def rename_mod_name(name: str, renamed_modules: list[str]) -> str: dirname_counts = Counter(dirnames) logger.debug(f"{dirname_counts = }") - top_level_init = 2 - if len(name_parts) == top_level_init and basename == "__init__": - pass - elif (basename in ("__all__", "logger") and dirname.endswith(".__init__")) or ( + if (basename in ("__all__", "logger") and dirname.endswith(".__init__")) or ( dirname not in renamed_modules and dirname_counts.get(dirname, 0) <= 1 ): name = dirname diff --git a/tests/test_result.py b/tests/test_result.py index 11987fd..7838ddd 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -1,3 +1,5 @@ +import pytest + from spaghettree import safe @@ -11,3 +13,6 @@ def raises(): assert list(res.details[0].keys()) == ["file", "func", "line_no", "locals"] # make sure Err.and_then => Err assert res.and_then(lambda x: x) == res + + with pytest.raises(ValueError): + res.unwrap() From 25b670cf3f25c6966f4fd2b02402965b30935b18 Mon Sep 17 00:00:00 2001 From: ed cuss Date: Thu, 25 Sep 2025 21:13:46 +0100 Subject: [PATCH 4/4] refactor: fold from call tree methods --- README.md | 1 - src/spaghettree/__main__.py | 6 +-- src/spaghettree/adapters/io_wrapper.py | 14 +++---- src/spaghettree/domain/adj_mat.py | 54 ------------------------- src/spaghettree/domain/optimisation.py | 56 ++++++++++++++++++++++++-- src/spaghettree/domain/parsing.py | 2 +- src/spaghettree/domain/processing.py | 4 +- tests/domain/test_optimisation.py | 6 +-- tests/domain/test_processing.py | 4 +- 9 files changed, 71 insertions(+), 76 deletions(-) delete mode 100644 src/spaghettree/domain/adj_mat.py diff --git a/README.md b/README.md index 48c5360..4193b8d 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,6 @@ class SomeOtherClass: │ │ └── io_wrapper.py │ ├── domain │ │ ├── __init__.py -│ │ ├── adj_mat.py │ │ ├── entities.py │ │ ├── optimisation.py │ │ ├── parsing.py diff --git a/src/spaghettree/__main__.py b/src/spaghettree/__main__.py index f1a0cea..378ffff 100644 --- a/src/spaghettree/__main__.py +++ b/src/spaghettree/__main__.py @@ -4,8 +4,8 @@ from spaghettree import Result from spaghettree.adapters.io_wrapper import IOProtocol, IOWrapper -from spaghettree.domain.adj_mat import AdjMat from spaghettree.domain.optimisation import ( + AdjMat, get_dwm, get_top_suggested_merges, yellow, @@ -66,7 +66,7 @@ def run_process( else: # remove any new_root so that it doesn't try to use ruff on the json new_root = "" - adj_mat = AdjMat.from_call_tree_no_optimisation(call_tree).unwrap() + adj_mat = AdjMat.from_call_tree(call_tree, optimise=optimise_src_code).unwrap() print( # noqa: T201 yellow( f"Current Directed Weighted Modularity (DWM): {get_dwm(adj_mat.mat, adj_mat.communities): .5f}" @@ -79,7 +79,7 @@ def run_process( res = {Path(call_tree_save_path).absolute(): json.dumps(call_tree, indent=4)} - return io.write_files(res, ruff_root=new_root, format_code=optimise_src_code) + return io.write_files(res, ruff_root=new_root, format_bulk=optimise_src_code) if __name__ == "__main__": diff --git a/src/spaghettree/adapters/io_wrapper.py b/src/spaghettree/adapters/io_wrapper.py index 8800a00..d906367 100644 --- a/src/spaghettree/adapters/io_wrapper.py +++ b/src/spaghettree/adapters/io_wrapper.py @@ -31,7 +31,7 @@ def read_files(self, root: str | Path) -> Result: ... def write(self, modified_code: str, filepath: str, *, format_code: bool = True) -> None: ... def write_files( - self, src_code: dict[str, str], ruff_root: str | None = None, *, format_code: bool = True + self, src_code: dict[str, str], ruff_root: str | None = None, *, format_bulk: bool = True ) -> Result: ... @@ -74,12 +74,12 @@ def write(self, modified_code: str, filepath: str, *, format_code: bool = True) self._run_ruff(filepath) def write_files( - self, src_code: dict[str, str], ruff_root: str | None = None, *, format_code: bool = True + self, src_code: dict[str, str], ruff_root: str | None = None, *, format_bulk: bool = True ) -> Result: results, fails = {}, {} for filepath, modified_code in src_code.items(): - if not ruff_root or not format_code: + if not ruff_root or format_bulk: # format all at the end instead res = self.write(modified_code, filepath, format_code=False) else: @@ -90,10 +90,10 @@ def write_files( print(yellow(f"File written to `{filepath}`")) # noqa: T201 results[filepath] = res.inner else: - logger.debug(f"failed to write {filepath = } {res.err_msg = }") + logger.error(yellow(f"failed to write {filepath = } {res.err_msg = }")) fails[filepath] = res - if ruff_root: + if ruff_root and format_bulk: self._run_ruff(ruff_root) if fails: return Err(fails) @@ -147,13 +147,13 @@ def write(self, modified_code: str, filepath: str, *, format_code: bool = True) self.files[filepath] = format_code_str(modified_code) if format_code else modified_code def write_files( - self, src_code: dict[str, str], ruff_root: str | None = None, *, format_code: bool = True + self, src_code: dict[str, str], ruff_root: str | None = None, *, format_bulk: bool = True ) -> Result: results, fails = {}, {} for filepath, modified_code in src_code.items(): if ruff_root is not None: - res = self.write(modified_code, filepath, format_code=format_code) + res = self.write(modified_code, filepath, format_code=format_bulk) if res.is_ok(): results[filepath] = res.inner diff --git a/src/spaghettree/domain/adj_mat.py b/src/spaghettree/domain/adj_mat.py deleted file mode 100644 index 4c36703..0000000 --- a/src/spaghettree/domain/adj_mat.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Self - -import attrs -import numpy as np - -from spaghettree import safe -from spaghettree.logger import logger - - -@attrs.define -class AdjMat: - mat: np.ndarray = attrs.field() - node_map: dict[int, str] = attrs.field() - communities: list[int] = attrs.field() - comm_map: dict[int, str] = attrs.field(factory=dict) - - @classmethod - @safe - def from_call_tree(cls, call_tree: dict[str, list[str]]) -> Self: - logger.debug(f"{call_tree = }") - ent_idx: dict[str, int] = {node: i for i, node in enumerate(call_tree)} - node_map: dict[int, str] = {idx: ent_name for ent_name, idx in ent_idx.items()} - - adj_mat = AdjMat._create_adj_map(call_tree, ent_idx) - - return cls(adj_mat, node_map, list(node_map.keys())) - - @classmethod - @safe - def from_call_tree_no_optimisation(cls, call_tree: dict[str, list[str]]) -> Self: - logger.debug(f"{call_tree = }") - ent_idx: dict[str, int] = {node: i for i, node in enumerate(call_tree)} - node_map: dict[int, str] = {idx: ent_name for ent_name, idx in ent_idx.items()} - - modules = [".".join(k.split(".")[:-1]) for k in call_tree] - unique_mods = list(dict.fromkeys(modules)) - mod_map = {name: idx for idx, name in enumerate(unique_mods)} - communities = [mod_map[name] for name in modules] - - adj_mat = AdjMat._create_adj_map(call_tree, ent_idx) - - return cls(adj_mat, node_map, communities, comm_map={v: k for k, v in mod_map.items()}) - - @staticmethod - def _create_adj_map(call_tree: dict[str, list[str]], ent_idx: dict[str, int]) -> np.ndarray: - n = len(ent_idx) - adj_mat = np.zeros((n, n), dtype=int) - - for caller, called in call_tree.items(): - for call in called: - src_idx = ent_idx[caller] - dst_idx = ent_idx[call] - adj_mat[src_idx, dst_idx] += 1 - return adj_mat diff --git a/src/spaghettree/domain/optimisation.py b/src/spaghettree/domain/optimisation.py index 847202e..9986061 100644 --- a/src/spaghettree/domain/optimisation.py +++ b/src/spaghettree/domain/optimisation.py @@ -1,26 +1,76 @@ from collections import defaultdict from copy import deepcopy +from typing import Self import attrs import numpy as np from spaghettree import safe -from spaghettree.domain.adj_mat import AdjMat from spaghettree.logger import logger +@attrs.define +class AdjMat: + mat: np.ndarray = attrs.field() + node_map: dict[int, str] = attrs.field() + communities: list[int] = attrs.field() + comm_map: dict[int, str] = attrs.field(factory=dict) + + @classmethod + @safe + def from_call_tree(cls, call_tree: dict[str, list[str]], *, optimise: bool = True) -> Self: + logger.debug(f"{call_tree = }") + ent_idx, node_map, modules, mod_map = AdjMat._get_components(call_tree) + communities = [mod_map[name] for name in modules] + + mat = AdjMat._create_adj_map(call_tree, ent_idx) + + starting_dwm = get_dwm(mat, communities) + print(cyan("Starting DWM:"), yellow(starting_dwm)) # noqa: T201 + logger.debug(f"{starting_dwm = }") + + communities = list(node_map.keys()) if optimise else communities + + return cls(mat, node_map, communities, comm_map={v: k for k, v in mod_map.items()}) + + @staticmethod + def _create_adj_map(call_tree: dict[str, list[str]], ent_idx: dict[str, int]) -> np.ndarray: + n = len(ent_idx) + adj_mat = np.zeros((n, n), dtype=int) + + for caller, called in call_tree.items(): + for call in called: + src_idx = ent_idx[caller] + dst_idx = ent_idx[call] + adj_mat[src_idx, dst_idx] += 1 + return adj_mat + + @staticmethod + def _get_components( + call_tree: dict[str, list[str]], + ) -> tuple[dict[str, int], dict[int, str], list[str], dict[str, int]]: + ent_idx: dict[str, int] = {node: i for i, node in enumerate(call_tree)} + node_map: dict[int, str] = {idx: ent_name for ent_name, idx in ent_idx.items()} + modules: list[str] = [".".join(k.split(".")[:-1]) for k in call_tree] + unique_mods = list(dict.fromkeys(modules)) + return ent_idx, node_map, modules, {name: idx for idx, name in enumerate(unique_mods)} + + @safe def optimise_communities(adj_mat: AdjMat) -> AdjMat: valid_merges = get_merge_pairs(adj_mat) - logger.debug(f"{get_dwm(adj_mat.mat, adj_mat.communities) = }") while valid_merges: to_merge = remove_overlapping_pairs(valid_merges) adj_mat.communities = apply_merges(adj_mat.communities, to_merge) valid_merges = get_merge_pairs(adj_mat) - logger.debug(f"{get_dwm(adj_mat.mat, adj_mat.communities) = }") + opt_dwm = get_dwm(adj_mat.mat, adj_mat.communities) + + print(cyan("Optimised DWM:"), yellow(opt_dwm)) # noqa: T201 + logger.debug(f"{opt_dwm = }") logger.debug(f"{adj_mat.communities = }") + return adj_mat diff --git a/src/spaghettree/domain/parsing.py b/src/spaghettree/domain/parsing.py index 32c281d..9a796cb 100644 --- a/src/spaghettree/domain/parsing.py +++ b/src/spaghettree/domain/parsing.py @@ -9,8 +9,8 @@ from tqdm import tqdm from spaghettree import safe -from spaghettree.domain.adj_mat import AdjMat from spaghettree.domain.entities import EntityCST +from spaghettree.domain.optimisation import AdjMat from spaghettree.domain.visitors import EntityLocation, OnePassVisitor from spaghettree.logger import logger diff --git a/src/spaghettree/domain/processing.py b/src/spaghettree/domain/processing.py index 3dd646a..9b720cc 100644 --- a/src/spaghettree/domain/processing.py +++ b/src/spaghettree/domain/processing.py @@ -4,9 +4,9 @@ from functools import partial from spaghettree import Result, safe -from spaghettree.domain.adj_mat import AdjMat from spaghettree.domain.entities import EntityCST, ImportCST, ImportType from spaghettree.domain.optimisation import ( + AdjMat, merge_single_entity_communities_if_no_gain_penalty, optimise_communities, ) @@ -26,7 +26,7 @@ def optimise_entity_positions( new_root: str, ) -> Result: return ( - AdjMat.from_call_tree(call_tree) + AdjMat.from_call_tree(call_tree, optimise=True) .and_then(pair_exclusive_calls) .and_then(optimise_communities) .and_then(merge_single_entity_communities_if_no_gain_penalty) diff --git a/tests/domain/test_optimisation.py b/tests/domain/test_optimisation.py index 84cc236..b517e48 100644 --- a/tests/domain/test_optimisation.py +++ b/tests/domain/test_optimisation.py @@ -6,8 +6,8 @@ from hypothesis import given from hypothesis import strategies as st -from spaghettree.domain.adj_mat import AdjMat from spaghettree.domain.optimisation import ( + AdjMat, SuggestedMerge, get_dwm, get_top_suggested_merges, @@ -89,7 +89,7 @@ ], ) def test_get_top_suggested_merges(call_tree, expected_result): - adj_mat = AdjMat.from_call_tree_no_optimisation(call_tree).unwrap() + adj_mat = AdjMat.from_call_tree(call_tree, optimise=False).unwrap() res = get_top_suggested_merges(adj_mat) assert res.is_ok() assert res.unwrap() == expected_result @@ -153,7 +153,7 @@ def st_call_tree(keys_count: int = 20): @given(st_call_tree()) def test_does_not_produce_worse_dwm(tree): - adj_mat = AdjMat.from_call_tree(tree).unwrap() + adj_mat = AdjMat.from_call_tree(tree, optimise=True).unwrap() starting_dwm = get_dwm(adj_mat.mat, adj_mat.communities) res = optimise_communities(adj_mat) diff --git a/tests/domain/test_processing.py b/tests/domain/test_processing.py index 7f11bcf..147c0be 100644 --- a/tests/domain/test_processing.py +++ b/tests/domain/test_processing.py @@ -2,8 +2,8 @@ import pytest -from spaghettree.domain.adj_mat import AdjMat from spaghettree.domain.entities import ClassCST, FuncCST, GlobalCST, ImportCST, ImportType +from spaghettree.domain.optimisation import AdjMat from spaghettree.domain.parsing import str_to_cst from spaghettree.domain.processing import ( add_empty_inits_if_needed, @@ -125,7 +125,7 @@ ], ) def test_second_half_of_processing(call_tree, entities, location_map, src_root, expected_result): - adj_mat = AdjMat.from_call_tree(call_tree).inner + adj_mat = AdjMat.from_call_tree(call_tree, optimise=True).inner adj_mat.communities = [0, 2, 2, 4, 4, 0] res = ( create_new_module_map(adj_mat, entities=entities)