From 3661768fb1db5189919537d0234090efeff5a55d Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Wed, 17 Sep 2025 02:08:54 +0300 Subject: [PATCH] added _check__ambiguous() --- smarttree/_check.py | 23 +++++++++ .../decision_tree/base/test__check_params.py | 51 +++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/smarttree/_check.py b/smarttree/_check.py index cdb9e40..883a965 100644 --- a/smarttree/_check.py +++ b/smarttree/_check.py @@ -58,6 +58,18 @@ def check__params( if rank_features is not None: _check__rank_features(rank_features) + if num_features is not None and cat_features is not None: + param_name1, param_name2 = "num_features", "cat_features" + _check__ambiguous(num_features, cat_features, param_name1, param_name2) + + if num_features is not None and rank_features is not None: + param_name1, param_name2 = "num_features", "rank_features" + _check__ambiguous(num_features, list(rank_features.keys()), param_name1, param_name2) + + if cat_features is not None and rank_features is not None: + param_name1, param_name2 = "cat_features", "rank_features" + _check__ambiguous(cat_features, list(rank_features.keys()), param_name1, param_name2) + if hierarchy is not None: _check__hierarchy(hierarchy) @@ -221,6 +233,17 @@ def _check__features_contain_duplicates(param_name, features): raise ValueError(f"`{param_name}` contains duplicates.") +def _check__ambiguous(features1, features2, param_name1, param_name2): + set_features1 = {features1} if isinstance(features1, str) else set(features1) + set_features2 = {features2} if isinstance(features2, str) else set(features2) + intersection = set_features1 & set_features2 + if intersection: + raise ValueError( + "Following feature names are ambiguous, they are defined in both" + f" '{param_name1}' and '{param_name2}': {intersection}." + ) + + def _check__hierarchy(hierarchy): common_message = ( "`hierarchy` must be a dictionary" diff --git a/tests/decision_tree/base/test__check_params.py b/tests/decision_tree/base/test__check_params.py index 9043388..2bcd540 100644 --- a/tests/decision_tree/base/test__check_params.py +++ b/tests/decision_tree/base/test__check_params.py @@ -442,6 +442,57 @@ def test__check_params__features_contain_duplicates(params_to_set, expected_cont SmartDecisionTreeClassifier(**params_to_set) +@pytest.mark.parametrize( + ("n_features", "c_features", "r_features", "expected_context"), + [ + ("f1", "f2", None, does_not_raise()), + ( + "f1", + "f1", + None, + pytest.raises( + ValueError, + match=( + "Following feature names are ambiguous, they are defined in" + " both 'num_features' and 'cat_features': {'f1'}." + ), + ), + ), + ("f1", None, {"f2": ["v2"]}, does_not_raise()), + ( + "f1", + None, + {"f1": ["v1"]}, + pytest.raises( + ValueError, + match=( + "Following feature names are ambiguous, they are defined in" + " both 'num_features' and 'rank_features': {'f1'}." + ), + ), + ), + (None, "f1", {"f2": ["v2"]}, does_not_raise()), + ( + None, + "f1", + {"f1": ["v1"]}, + pytest.raises( + ValueError, + match=( + "Following feature names are ambiguous, they are defined in" + " both 'cat_features' and 'rank_features': {'f1'}." + ), + ), + ), + ], +) +def test__check_params__ambiguous(n_features, c_features, r_features, expected_context): + with expected_context: + SmartDecisionTreeClassifier( + num_features=n_features, cat_features=c_features, rank_features=r_features + ) + + @pytest.mark.parametrize( ("hierarchy", "expected_context"), [