diff --git a/smarttree/_column_splitter.py b/smarttree/_column_splitter.py index 2b0b22f..e245624 100644 --- a/smarttree/_column_splitter.py +++ b/smarttree/_column_splitter.py @@ -20,6 +20,7 @@ class ColumnSplitResult(NamedTuple): + information_gain: float feature_values: list[list] child_masks: list[pd.Series] @@ -61,11 +62,77 @@ def __init__( def split(self, *args, **kwargs) -> ColumnSplitResult: raise NotImplementedError + def foo( + self, + parent_mask: pd.Series, + split_feature: str, + child_masks: list[pd.Series], + ) -> tuple[float, list[pd.Series], int]: + + if self.dataset.has_na[split_feature]: + mask_na = parent_mask & self.dataset.mask_na[split_feature] + na_mode = self.feature_na_mode[split_feature] + if na_mode == "include_all": + return self.include_all_split(parent_mask, mask_na, child_masks) + elif na_mode == "include_best": + return self.include_best_split(parent_mask, mask_na, child_masks) + else: + assert False + else: + information_gain = self.information_gain(parent_mask, child_masks) + return information_gain, child_masks, -1 + + def include_all_split( + self, + parent_mask: pd.Series, + mask_na: pd.Series, + child_masks: list[pd.Series], + ) -> tuple[float, list[pd.Series], int]: + + for i, child_mask in enumerate(child_masks): + child_masks[i] = child_mask | (parent_mask & mask_na) + if child_masks[i].sum() < self.min_samples_leaf: + return NO_INFORMATION_GAIN, [], -1 + + information_gain = self.information_gain(parent_mask, child_masks, normalize=True) + + return information_gain, child_masks, -1 + + def include_best_split( + self, + parent_mask: pd.Series, + mask_na: pd.Series, + child_masks: list[pd.Series], + ) -> tuple[float, list[pd.Series], int]: + + candidates = [] + origin_child_masks = child_masks + for i, child_mask in enumerate(origin_child_masks): + child_masks = deepcopy(origin_child_masks) + child_masks[i] = child_mask | (parent_mask & mask_na) + for child_mask in child_masks: + if child_mask.sum() < self.min_samples_leaf: + break + else: + candidates.append(child_masks) + + best_information_gain = NO_INFORMATION_GAIN + best_child_masks = [] + best_child_na_index = -1 + for child_na_index, child_masks in enumerate(candidates): + information_gain = self.information_gain(parent_mask, child_masks) + if best_information_gain < information_gain: + best_information_gain = information_gain + best_child_masks = child_masks + best_child_na_index = child_na_index + + return best_information_gain, best_child_masks, best_child_na_index + def information_gain( self, parent_mask: pd.Series, child_masks: list[pd.Series], - na_mode: NaModeType | None = None, + normalize: bool = False, ) -> float: r""" Calculates information gain of the split. @@ -75,8 +142,9 @@ def information_gain( boolean mask of parent node. child_masks: pd.Series list of boolean masks of child nodes. - na_mode: {"include_all", ...}, default=None - If "include_all" use normalization. + normalize: bool, default=False + if True, normalizes information gain by split factor to handle + unbalanced splits. Uses child node counts for normalization. Returns: float: information gain. @@ -113,7 +181,7 @@ def information_gain( impurity_child_i = self.impurity(child_mask_i) weighted_impurity_childs += (N_child_i / N_parent) * impurity_child_i - if na_mode == "include_all": + if normalize: norm_coef = N_parent / N_childs weighted_impurity_childs *= norm_coef @@ -133,15 +201,6 @@ def gini_index(self, mask: pd.Series) -> float: C - total number of classes; p_i - the probability of choosing a sample with class i. """ - # N = mask.sum() - # - # gini_index = 1 - # for label in self.dataset.class_names: - # N_i = (mask & (self.dataset.y == label)).sum() - # p_i = N_i / N - # gini_index -= pow(p_i, 2) - # - # return gini_index return cgini_index(mask, self.dataset.y, self.dataset.class_names) def entropy(self, mask: pd.Series) -> float: @@ -217,51 +276,17 @@ def __moving_average(array: NDArray, window: int = 2) -> NDArray: return np.convolve(array, np.ones(window), mode="valid") / window def __num_split( - self, parent_mask: pd.Series, + self, + parent_mask: pd.Series, split_feature: str, threshold: float, ) -> tuple[float, list[pd.Series], int]: - mask_na = parent_mask & self.dataset.mask_na[split_feature] - mask_less = parent_mask & (self.dataset.X[split_feature] <= threshold) mask_more = parent_mask & (self.dataset.X[split_feature] > threshold) child_masks = [mask_less, mask_more] - na_mode = self.feature_na_mode[split_feature] - if na_mode == "include_all": - for i, child_mask in enumerate(child_masks): - child_masks[i] = child_mask | (parent_mask & mask_na) # update - if child_masks[i].sum() < self.min_samples_leaf: - return NO_INFORMATION_GAIN, [], -1 - - elif na_mode == "include_best": - candidates = [] - origin_child_masks = child_masks - for i, child_mask in enumerate(origin_child_masks): - child_masks = deepcopy(origin_child_masks) - child_masks[i] = child_mask | (parent_mask & mask_na) # update - for child_mask in child_masks: - if child_mask.sum() < self.min_samples_leaf: - break - else: - candidates.append(child_masks) - - best_information_gain = NO_INFORMATION_GAIN - best_child_masks = [] - best_child_na_index = -1 - for i, child_masks in enumerate(candidates): - information_gain = self.information_gain(parent_mask, child_masks, na_mode) - if best_information_gain < information_gain: - best_information_gain = information_gain - best_child_masks = child_masks - best_child_na_index = i - - return best_information_gain, best_child_masks, best_child_na_index - - information_gain = self.information_gain(parent_mask, child_masks, na_mode) - - return information_gain, child_masks, -1 + return self.foo(parent_mask, split_feature, child_masks) class CatColumnSplitter(BaseColumnSplitter): @@ -329,48 +354,13 @@ def __cat_split( feature_values: list[list], ) -> tuple[float, list[pd.Series], int]: - mask_na = parent_mask & self.dataset.mask_na[split_feature] - child_masks = [] for partition in feature_values: partition_mask = self.dataset.X[split_feature].isin(partition) child_mask = parent_mask & partition_mask child_masks.append(child_mask) - na_mode = self.feature_na_mode[split_feature] - if na_mode == "include_all": - for i, child_mask in enumerate(child_masks): - child_masks[i] = child_mask | (parent_mask & mask_na) # update - if child_masks[i].sum() < self.min_samples_leaf: - return NO_INFORMATION_GAIN, [], -1 - - elif na_mode == "include_best": - candidates = [] - origin_child_masks = child_masks - for i, child_mask in enumerate(origin_child_masks): - child_masks = deepcopy(origin_child_masks) - child_masks[i] = child_mask | (parent_mask & mask_na) # update - for child_mask in child_masks: - if child_mask.sum() < self.min_samples_leaf: - break - else: - candidates.append(child_masks) - - best_information_gain = NO_INFORMATION_GAIN - best_child_masks = [] - best_child_na_index = -1 - for i, child_masks in enumerate(candidates): - information_gain = self.information_gain(parent_mask, child_masks, na_mode) - if best_information_gain < information_gain: - best_information_gain = information_gain - best_child_masks = child_masks - best_child_na_index = i - - return best_information_gain, best_child_masks, best_child_na_index - - information_gain = self.information_gain(parent_mask, child_masks, na_mode) - - return information_gain, child_masks, -1 + return self.foo(parent_mask, split_feature, child_masks) def __cat_partitions( self, @@ -417,12 +407,12 @@ def split(self, node: TreeNode, split_feature: str) -> ColumnSplitResult: best_split_result = ColumnSplitResult.no_split() for feature_values in self.__rank_partitions(available_feature_values): - information_gain, child_masks = self.__rank_split( + information_gain, child_masks, child_na_index = self.__rank_split( node.mask, split_feature, feature_values ) if best_split_result.information_gain < information_gain: best_split_result = ColumnSplitResult( - information_gain, list(feature_values), child_masks + information_gain, list(feature_values), child_masks, child_na_index ) return best_split_result @@ -432,7 +422,7 @@ def __rank_split( parent_mask: pd.Series, split_feature: str, feature_values: tuple[list, list], - ) -> tuple[float, list[pd.Series]]: + ) -> tuple[float, list[pd.Series], int]: feature_values_left, feature_values_right = feature_values @@ -440,13 +430,7 @@ def __rank_split( mask_right = parent_mask & self.dataset.X[split_feature].isin(feature_values_right) child_masks = [mask_left, mask_right] - for child_mask in child_masks: - if child_mask.sum() < self.min_samples_leaf: - return NO_INFORMATION_GAIN, [] - - information_gain = self.information_gain(parent_mask, child_masks) - - return information_gain, child_masks + return self.foo(parent_mask, split_feature, child_masks) @staticmethod def __rank_partitions(collection: list) -> Generator[tuple[list, list], None, None]: diff --git a/smarttree/_dataset.py b/smarttree/_dataset.py index fed54eb..90ee9b0 100644 --- a/smarttree/_dataset.py +++ b/smarttree/_dataset.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field import numpy as np import pandas as pd @@ -10,10 +10,20 @@ class Dataset: X: pd.DataFrame y: pd.Series + class_names: NDArray = field(init=False) + has_na: dict[str, bool] = field(init=False) + mask_na: dict[str, pd.Series] = field(init=False) def __post_init__(self) -> None: - self.class_names: NDArray = np.sort(self.y.unique()) - self.mask_na = {column: self.X[column].isna() for column in self.X.columns} + self.class_names = np.sort(self.y.unique()) + self.has_na = dict() + self.mask_na = dict() + for column in self.X.columns: + mask_na = self.X[column].isna() + has_na = mask_na.any() + self.has_na[column] = has_na + if has_na: + self.mask_na[column] = mask_na @property def size(self) -> int: diff --git a/smarttree/_node_splitter.py b/smarttree/_node_splitter.py index fca0af4..c4fec9b 100644 --- a/smarttree/_node_splitter.py +++ b/smarttree/_node_splitter.py @@ -9,6 +9,7 @@ class NodeSplitResult(NamedTuple): + information_gain: float split_type: str split_feature: str diff --git a/tests/conftest.py b/tests/conftest.py index 89da3ef..482b5e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +from itertools import chain import numpy as np import pandas as pd @@ -151,12 +152,8 @@ def feature_na_mode( ) -> dict[str, NaModeType | None]: result = dict() - for num_feature in num_features: - result[num_feature] = "min" - for cat_feature in cat_features: - result[cat_feature] = "as_category" - for rank_feature in rank_features: - result[rank_feature] = None + for feature in chain(num_features, cat_features, rank_features): + result[feature] = "include_best" return result diff --git a/tests/test__node_splitter.py b/tests/test__node_splitter.py index e816f04..010af8b 100644 --- a/tests/test__node_splitter.py +++ b/tests/test__node_splitter.py @@ -4,7 +4,7 @@ @pytest.fixture(scope="module") -def concrete_node_splitter( +def node_splitter( X, y, num_features, cat_features, rank_features, feature_na_mode ) -> NodeSplitter: return NodeSplitter( @@ -24,10 +24,10 @@ def concrete_node_splitter( ) -def test__find_best_split(concrete_node_splitter, root_node): - concrete_node_splitter.find_best_split_for(root_node, leaf_counter=0) +def test__find_best_split(node_splitter, root_node): + node_splitter.find_best_split_for(root_node, leaf_counter=0) -def test__is_splittable(concrete_node_splitter, root_node): - concrete_node_splitter.find_best_split_for(root_node, leaf_counter=0) - concrete_node_splitter.is_splittable(root_node, leaf_counter=0) +def test__is_splittable(node_splitter, root_node): + node_splitter.find_best_split_for(root_node, leaf_counter=0) + node_splitter.is_splittable(root_node, leaf_counter=0)