Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ class SomeOtherClass:
│ │ └── io_wrapper.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── adj_mat.py
│ │ ├── entities.py
│ │ ├── optimisation.py
│ │ ├── parsing.py
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ build-backend = "uv_build"

[dependency-groups]
dev = [
"hypothesis>=6.140.2",
"ipykernel>=6.29.5",
"pre-commit>=4.1.0",
"pytest>=8.3.4",
Expand Down
6 changes: 3 additions & 3 deletions src/spaghettree/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from spaghettree import Result
from spaghettree.adapters.io_wrapper import IOProtocol, IOWrapper
from spaghettree.domain.adj_mat import AdjMat
from spaghettree.domain.optimisation import (
AdjMat,
get_dwm,
get_top_suggested_merges,
yellow,
Expand Down Expand Up @@ -66,7 +66,7 @@ def run_process(
else:
# remove any new_root so that it doesn't try to use ruff on the json
new_root = ""
adj_mat = AdjMat.from_call_tree_no_optimisation(call_tree).unwrap()
adj_mat = AdjMat.from_call_tree(call_tree, optimise=optimise_src_code).unwrap()
print( # noqa: T201
yellow(
f"Current Directed Weighted Modularity (DWM): {get_dwm(adj_mat.mat, adj_mat.communities): .5f}"
Expand All @@ -79,7 +79,7 @@ def run_process(

res = {Path(call_tree_save_path).absolute(): json.dumps(call_tree, indent=4)}

return io.write_files(res, ruff_root=new_root, format_code=optimise_src_code)
return io.write_files(res, ruff_root=new_root, format_bulk=optimise_src_code)


if __name__ == "__main__":
Expand Down
13 changes: 7 additions & 6 deletions src/spaghettree/adapters/io_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def read_files(self, root: str | Path) -> Result: ...
def write(self, modified_code: str, filepath: str, *, format_code: bool = True) -> None: ...

def write_files(
self, src_code: dict[str, str], ruff_root: str | None = None, *, format_code: bool = True
self, src_code: dict[str, str], ruff_root: str | None = None, *, format_bulk: bool = True
) -> Result: ...


Expand Down Expand Up @@ -74,12 +74,12 @@ def write(self, modified_code: str, filepath: str, *, format_code: bool = True)
self._run_ruff(filepath)

def write_files(
self, src_code: dict[str, str], ruff_root: str | None = None, *, format_code: bool = True
self, src_code: dict[str, str], ruff_root: str | None = None, *, format_bulk: bool = True
) -> Result:
results, fails = {}, {}

for filepath, modified_code in src_code.items():
if not ruff_root or not format_code:
if not ruff_root or format_bulk:
# format all at the end instead
res = self.write(modified_code, filepath, format_code=False)
else:
Expand All @@ -90,9 +90,10 @@ def write_files(
print(yellow(f"File written to `{filepath}`")) # noqa: T201
results[filepath] = res.inner
else:
logger.error(yellow(f"failed to write {filepath = } {res.err_msg = }"))
fails[filepath] = res

if ruff_root:
if ruff_root and format_bulk:
self._run_ruff(ruff_root)
if fails:
return Err(fails)
Expand Down Expand Up @@ -146,13 +147,13 @@ def write(self, modified_code: str, filepath: str, *, format_code: bool = True)
self.files[filepath] = format_code_str(modified_code) if format_code else modified_code

def write_files(
self, src_code: dict[str, str], ruff_root: str | None = None, *, format_code: bool = True
self, src_code: dict[str, str], ruff_root: str | None = None, *, format_bulk: bool = True
) -> Result:
results, fails = {}, {}

for filepath, modified_code in src_code.items():
if ruff_root is not None:
res = self.write(modified_code, filepath, format_code=format_code)
res = self.write(modified_code, filepath, format_code=format_bulk)

if res.is_ok():
results[filepath] = res.inner
Expand Down
54 changes: 0 additions & 54 deletions src/spaghettree/domain/adj_mat.py

This file was deleted.

56 changes: 53 additions & 3 deletions src/spaghettree/domain/optimisation.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,76 @@
from collections import defaultdict
from copy import deepcopy
from typing import Self

import attrs
import numpy as np

from spaghettree import safe
from spaghettree.domain.adj_mat import AdjMat
from spaghettree.logger import logger


@attrs.define
class AdjMat:
    """Adjacency matrix of a call tree plus a community label for every entity.

    Instances are normally built with :meth:`from_call_tree`, which derives
    all four fields from a ``{caller: [callee, ...]}`` mapping.
    """

    # mat[i, j] is the number of calls from entity i to entity j.
    mat: np.ndarray = attrs.field()
    # Row/column index -> entity name.
    node_map: dict[int, str] = attrs.field()
    # One community label per entity, in the same order as the rows of `mat`.
    communities: list[int] = attrs.field()
    # Module index -> module name.
    comm_map: dict[int, str] = attrs.field(factory=dict)

    @classmethod
    @safe
    def from_call_tree(cls, call_tree: dict[str, list[str]], *, optimise: bool = True) -> Self:
        """Build an ``AdjMat`` from a call tree and report its starting DWM.

        Args:
            call_tree: Maps each entity's dotted name to the names it calls.
                Every callee must also appear as a key, otherwise the index
                lookup in ``_create_adj_map`` raises ``KeyError``.
            optimise: When true, every entity starts in its own community
                (its node index). When false, entities are labelled with the
                index of the module they currently live in.

        Returns:
            The populated ``AdjMat``. The method is wrapped by ``safe``, so
            callers receive whatever wrapper that decorator produces
            (presumably a ``Result`` -- callers in this project call
            ``.unwrap()`` on it).
        """
        logger.debug(f"{call_tree = }")
        entity_index, index_to_entity, module_names, module_index = AdjMat._get_components(
            call_tree
        )
        module_labels = [module_index[module] for module in module_names]

        matrix = AdjMat._create_adj_map(call_tree, entity_index)

        # The starting score is always measured against the existing module layout,
        # regardless of which labelling the returned instance ends up with.
        starting_dwm = get_dwm(matrix, module_labels)
        print(cyan("Starting DWM:"), yellow(starting_dwm))  # noqa: T201
        logger.debug(f"{starting_dwm = }")

        if optimise:
            labels = list(index_to_entity.keys())
        else:
            labels = module_labels

        index_to_module = {index: module for module, index in module_index.items()}
        return cls(matrix, index_to_entity, labels, comm_map=index_to_module)

    @staticmethod
    def _create_adj_map(call_tree: dict[str, list[str]], ent_idx: dict[str, int]) -> np.ndarray:
        """Return an ``n x n`` integer matrix counting calls between entities.

        ``ent_idx`` maps entity names to row/column positions; repeated calls
        to the same callee accumulate in the same cell.
        """
        size = len(ent_idx)
        counts = np.zeros((size, size), dtype=int)

        for caller, called in call_tree.items():
            for callee in called:
                counts[ent_idx[caller], ent_idx[callee]] += 1
        return counts

    @staticmethod
    def _get_components(
        call_tree: dict[str, list[str]],
    ) -> tuple[dict[str, int], dict[int, str], list[str], dict[str, int]]:
        """Derive the lookup tables needed to build the matrix.

        Returns, in order: entity name -> index, index -> entity name, the
        module name of each entity (everything before the last dot, in
        ``call_tree`` order), and module name -> index numbered by first
        appearance.
        """
        ent_idx: dict[str, int] = {}
        node_map: dict[int, str] = {}
        modules: list[str] = []

        for position, entity in enumerate(call_tree):
            ent_idx[entity] = position
            node_map[position] = entity
            # Drop the final dotted component to get the containing module.
            modules.append(entity.rpartition(".")[0])

        mod_map: dict[str, int] = {}
        for module in modules:
            # First occurrence wins, so indices follow order of appearance.
            mod_map.setdefault(module, len(mod_map))

        return ent_idx, node_map, modules, mod_map


@safe
def optimise_communities(adj_mat: AdjMat) -> AdjMat:
valid_merges = get_merge_pairs(adj_mat)
logger.debug(f"{get_dwm(adj_mat.mat, adj_mat.communities) = }")

while valid_merges:
to_merge = remove_overlapping_pairs(valid_merges)
adj_mat.communities = apply_merges(adj_mat.communities, to_merge)
valid_merges = get_merge_pairs(adj_mat)

logger.debug(f"{get_dwm(adj_mat.mat, adj_mat.communities) = }")
opt_dwm = get_dwm(adj_mat.mat, adj_mat.communities)

print(cyan("Optimised DWM:"), yellow(opt_dwm)) # noqa: T201
logger.debug(f"{opt_dwm = }")
logger.debug(f"{adj_mat.communities = }")

return adj_mat


Expand Down
2 changes: 1 addition & 1 deletion src/spaghettree/domain/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from tqdm import tqdm

from spaghettree import safe
from spaghettree.domain.adj_mat import AdjMat
from spaghettree.domain.entities import EntityCST
from spaghettree.domain.optimisation import AdjMat
from spaghettree.domain.visitors import EntityLocation, OnePassVisitor
from spaghettree.logger import logger

Expand Down
9 changes: 3 additions & 6 deletions src/spaghettree/domain/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from functools import partial

from spaghettree import Result, safe
from spaghettree.domain.adj_mat import AdjMat
from spaghettree.domain.entities import EntityCST, ImportCST, ImportType
from spaghettree.domain.optimisation import (
AdjMat,
merge_single_entity_communities_if_no_gain_penalty,
optimise_communities,
)
Expand All @@ -26,7 +26,7 @@ def optimise_entity_positions(
new_root: str,
) -> Result:
return (
AdjMat.from_call_tree(call_tree)
AdjMat.from_call_tree(call_tree, optimise=True)
.and_then(pair_exclusive_calls)
.and_then(optimise_communities)
.and_then(merge_single_entity_communities_if_no_gain_penalty)
Expand Down Expand Up @@ -106,11 +106,8 @@ def rename_mod_name(name: str, renamed_modules: list[str]) -> str:
dirname_counts = Counter(dirnames)

logger.debug(f"{dirname_counts = }")
top_level_init = 2

if len(name_parts) == top_level_init and basename == "__init__":
pass
elif (basename in ("__all__", "logger") and dirname.endswith(".__init__")) or (
if (basename in ("__all__", "logger") and dirname.endswith(".__init__")) or (
dirname not in renamed_modules and dirname_counts.get(dirname, 0) <= 1
):
name = dirname
Expand Down
86 changes: 83 additions & 3 deletions tests/domain/test_optimisation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import string

import hypothesis.extra.numpy as hnp
import numpy as np
import pytest
from hypothesis import given
from hypothesis import strategies as st

from spaghettree.domain.adj_mat import AdjMat
from spaghettree.domain.optimisation import SuggestedMerge, get_top_suggested_merges
from spaghettree.domain.optimisation import (
AdjMat,
SuggestedMerge,
get_dwm,
get_top_suggested_merges,
optimise_communities,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -78,7 +89,76 @@
],
)
def test_get_top_suggested_merges(call_tree, expected_result):
adj_mat = AdjMat.from_call_tree_no_optimisation(call_tree).unwrap()
adj_mat = AdjMat.from_call_tree(call_tree, optimise=False).unwrap()
res = get_top_suggested_merges(adj_mat)
assert res.is_ok()
assert res.unwrap() == expected_result


@st.composite
def st_adj_mat_and_comms(draw, max_n: int = 20, max_val: int = 20) -> tuple[np.ndarray, list[int]]:
    """Draw a random square matrix together with one community label per row.

    The matrix is ``n x n`` with ``1 <= n <= max_n`` and int64 entries in
    ``[0, max_val]``; the label list has exactly ``n`` entries, each in
    ``[1, max_val]``.
    """
    size = draw(st.integers(min_value=1, max_value=max_n))

    edge_weights = st.integers(min_value=0, max_value=max_val)
    matrix = draw(hnp.arrays(dtype=np.int64, shape=(size, size), elements=edge_weights))

    community_ids = st.integers(min_value=1, max_value=max_val)
    labels = draw(st.lists(community_ids, min_size=size, max_size=size))

    return matrix, labels


@given(st_adj_mat_and_comms())
def test_get_dwm_is_within_bounds(data):
    """`get_dwm` must stay within [-0.5, 1.0] for any generated matrix and labelling."""
    lower_bound, upper_bound = -0.5, 1.0
    # `data` is the (matrix, communities) pair produced by the strategy.
    score = get_dwm(*data)
    assert lower_bound <= score <= upper_bound


def st_ent_path(
    root: str = "package",
    min_modules: int = 1,
    max_modules: int = 3,
    id_min_size: int = 3,
    id_max_size: int = 3,
):
    """Strategy producing dotted entity paths such as ``package.abc.xyz``.

    Each path is ``root``, then between ``min_modules`` and ``max_modules``
    module identifiers, then one leaf identifier. Identifiers are lowercase
    ASCII strings whose length lies in ``[id_min_size, id_max_size]``.
    """
    identifier = st.text(
        alphabet=string.ascii_lowercase,
        min_size=id_min_size,
        max_size=id_max_size,
    )
    module_parts = st.lists(identifier, min_size=min_modules, max_size=max_modules)

    def join_path(modules, leaf):
        return ".".join((root, *modules, leaf))

    return st.builds(join_path, module_parts, identifier)


def st_call_tree(keys_count: int = 20):
    """Strategy producing call trees: ``{entity_path: [called_entity_path, ...]}``.

    Between 2 and ``keys_count`` unique entity paths are generated first; each
    then receives a list of 0 to 5 callees sampled from that same set, so every
    callee is guaranteed to also be a key of the tree.
    """

    def tree_for(names):
        return st.fixed_dictionaries(
            {
                name: st.lists(st.sampled_from(names), min_size=0, max_size=5)
                for name in names
            }
        )

    unique_names = st.lists(st_ent_path(), min_size=2, max_size=keys_count, unique=True)
    return unique_names.flatmap(tree_for)


@given(st_call_tree())
def test_does_not_produce_worse_dwm(tree):
    """Optimising communities must never lower the DWM of the starting labelling."""
    adj_mat = AdjMat.from_call_tree(tree, optimise=True).unwrap()
    # Take the baseline before optimising: this must precede the call below.
    dwm_before = get_dwm(adj_mat.mat, adj_mat.communities)

    outcome = optimise_communities(adj_mat)
    assert outcome.is_ok()

    optimised = outcome.unwrap()
    dwm_after = get_dwm(optimised.mat, optimised.communities)
    assert dwm_before <= dwm_after
4 changes: 2 additions & 2 deletions tests/domain/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import pytest

from spaghettree.domain.adj_mat import AdjMat
from spaghettree.domain.entities import ClassCST, FuncCST, GlobalCST, ImportCST, ImportType
from spaghettree.domain.optimisation import AdjMat
from spaghettree.domain.parsing import str_to_cst
from spaghettree.domain.processing import (
add_empty_inits_if_needed,
Expand Down Expand Up @@ -125,7 +125,7 @@
],
)
def test_second_half_of_processing(call_tree, entities, location_map, src_root, expected_result):
adj_mat = AdjMat.from_call_tree(call_tree).inner
adj_mat = AdjMat.from_call_tree(call_tree, optimise=True).inner
adj_mat.communities = [0, 2, 2, 4, 4, 0]
res = (
create_new_module_map(adj_mat, entities=entities)
Expand Down
Loading