diff --git a/.gitignore b/.gitignore index 3eae1d6..f2ca833 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,4 @@ dmypy.json test.ipynb dummy.py .DS_Store +.cursor/ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 2dc8f14..de39ddf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,4 +75,4 @@ line-length = 110 atomic = true profile = "black" skip_gitignore = true -known_first_party = ["black", "blib2to3", "blackd", "_black_version"] +known_first_party = ["wnb"] diff --git a/wnb/_utils.py b/wnb/_utils.py index cc9d534..2aa9f81 100644 --- a/wnb/_utils.py +++ b/wnb/_utils.py @@ -54,7 +54,6 @@ def _check_feature_names(estimator, X, *, reset): except ImportError: def _fit_context(*, prefer_skip_nested_validation: bool): - def decorator(fit_method): @functools.wraps(fit_method) def wrapper(estimator, *args: Any, **kwargs: Any): diff --git a/wnb/gnb.py b/wnb/gnb.py index dcd6c98..dbc397a 100644 --- a/wnb/gnb.py +++ b/wnb/gnb.py @@ -114,7 +114,7 @@ class GeneralNB(_BaseNB): >>> X = np.array([[-1, 1], [-2, 1], [-3, 2], [1, 1], [2, 1], [3, 2]]) >>> Y = np.array([1, 1, 1, 2, 2, 2]) >>> from wnb import GeneralNB, Distribution as D - >>> clf = GeneralNB(distributions=[D.NORMAL, D.POISSON]) + >>> clf = GeneralNB([D.NORMAL, D.POISSON]) >>> clf.fit(X, Y) GeneralNB(distributions=[, ]) @@ -200,7 +200,9 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]: @staticmethod def _find_dist( - feature_idx: int, feature_name: str | None, dist_mapping: dict[DistributionLike, ColumnKey] + feature_idx: int, + feature_name: str | None, + dist_mapping: dict[DistributionLike, ColumnKey], ) -> DistributionLike: for dist, cols in dist_mapping.items(): cols_ = ( diff --git a/wnb/gwnb.py b/wnb/gwnb.py index 219bdb6..9f52ebf 100644 --- a/wnb/gwnb.py +++ b/wnb/gwnb.py @@ -250,8 +250,7 @@ def _init_parameters(self) -> None: # Ensure the size of error weights matrix matches number of classes if error_weights.shape != (self.n_classes_, self.n_classes_): raise ValueError( - "The shape of error weights matrix does not match the number of classes, " - "must be (n_classes, n_classes)." + "The shape of error weights matrix does not match the number of classes, must be (n_classes, n_classes)." ) self.error_weights_ = error_weights