diff --git a/tests/conftest.py b/tests/conftest.py index 7b38747..8e04955 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -217,3 +217,35 @@ def render( @pytest.fixture def concrete_smart_tree(): return ConcreteSmartTree() + + +def pytest_collection_modifyitems(session, config, items): + + module_order = [ + "tests/test__criterion.py", + "tests/test__repr_tree_node.py", + "tests/column_splitter/test__base_column_splitter.py", + "tests/column_splitter/test__num_column_splitter.py", + "tests/column_splitter/test__cat_column_splitter.py", + "tests/column_splitter/test__rank_column_splitter.py", + "tests/test__node_splitter.py", + "tests/decision_tree/base/test__check_params.py", + "tests/decision_tree/base/test__get_set_params.py", + "tests/decision_tree/base/test__check_data.py", + "tests/decision_tree/base/test__base_not_fitted.py", + "tests/decision_tree/classifier/test__repr_tree.py", + "tests/decision_tree/classifier/test__classifier_not_fitted.py", + "tests/decision_tree/classifier/test__fit_predict.py", + "tests/test__renderer.py", + ] + + def get_priority(item): + + filepath = item.fspath.relto(session.fspath).replace("\\", "/") + + if filepath in module_order: + return module_order.index(filepath) + else: + return len(module_order) + + items.sort(key=get_priority)