Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,4 @@ dmypy.json
test.ipynb
dummy.py
.DS_Store
.cursor/
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ line-length = 110
atomic = true
profile = "black"
skip_gitignore = true
known_first_party = ["black", "blib2to3", "blackd", "_black_version"]
known_first_party = ["wnb"]
1 change: 0 additions & 1 deletion wnb/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def _check_feature_names(estimator, X, *, reset):
except ImportError:

def _fit_context(*, prefer_skip_nested_validation: bool):

def decorator(fit_method):
@functools.wraps(fit_method)
def wrapper(estimator, *args: Any, **kwargs: Any):
Expand Down
6 changes: 4 additions & 2 deletions wnb/gnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class GeneralNB(_BaseNB):
>>> X = np.array([[-1, 1], [-2, 1], [-3, 2], [1, 1], [2, 1], [3, 2]])
>>> Y = np.array([1, 1, 1, 2, 2, 2])
>>> from wnb import GeneralNB, Distribution as D
>>> clf = GeneralNB(distributions=[D.NORMAL, D.POISSON])
>>> clf = GeneralNB([D.NORMAL, D.POISSON])
>>> clf.fit(X, Y)
GeneralNB(distributions=[<Distribution.NORMAL: 'Normal'>,
<Distribution.POISSON: 'Poisson'>])
Expand Down Expand Up @@ -200,7 +200,9 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]:

@staticmethod
def _find_dist(
feature_idx: int, feature_name: str | None, dist_mapping: dict[DistributionLike, ColumnKey]
feature_idx: int,
feature_name: str | None,
dist_mapping: dict[DistributionLike, ColumnKey],
) -> DistributionLike:
for dist, cols in dist_mapping.items():
cols_ = (
Expand Down
3 changes: 1 addition & 2 deletions wnb/gwnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,7 @@ def _init_parameters(self) -> None:
# Ensure the size of error weights matrix matches number of classes
if error_weights.shape != (self.n_classes_, self.n_classes_):
raise ValueError(
"The shape of error weights matrix does not match the number of classes, "
"must be (n_classes, n_classes)."
"The shape of error weights matrix does not match the number of classes, must be (n_classes, n_classes)."
)

self.error_weights_ = error_weights
Expand Down