Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
# PyCharm
/.idea

# IPython
.ipynb_checkpoints/
*.ipynb

# bytecode
__pycache__/

*.ipynb
/poetry.toml

# Cython
build
*.pyd
*.so
71 changes: 71 additions & 0 deletions build-extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

import os
import shutil
import sys

from pathlib import Path

from Cython.Build import cythonize
from setuptools import Distribution
from setuptools import Extension
from setuptools.command.build_ext import build_ext


# Compiler/linker settings applied to every extension module built by build().
if sys.platform == "win32":
    # MSVC spelling: optimise for speed, relaxed floating-point model.
    COMPILE_ARGS = ["/O2", "/fp:fast"]
    LINK_ARGS = []
    INCLUDE_DIRS = []
    LIBRARIES = []
else:
    # The SSE/FMA switches only exist for x86 targets; GCC and Clang reject
    # them on other architectures (e.g. aarch64 / Apple Silicon), so they are
    # emitted only when the build machine is x86.  os.uname() is available on
    # every non-Windows platform this branch handles.
    if os.uname().machine.lower() in {"x86_64", "amd64", "i386", "i486", "i586", "i686"}:
        COMPILE_ARGS = ["-march=native", "-O3", "-msse", "-msse2", "-mfma", "-mfpmath=sse"]
    else:
        COMPILE_ARGS = ["-O3"]
    LINK_ARGS = []
    INCLUDE_DIRS = []
    # Link against libm for the C math routines.
    LIBRARIES = ["m"]


def build() -> None:
    """
    Compile the package's Cython modules and copy the binaries next to the sources.

    Every ``*.pyx`` file directly inside ``smarttree/`` becomes an extension
    module named ``smarttree.<stem>``.  The modules are cythonized (generated
    sources go to ``build/``), compiled through setuptools' ``build_ext``, and
    each resulting binary is then copied from the build directory to the same
    relative path under the current working directory.
    """
    extensions = [
        Extension(
            f"smarttree.{source.stem}",
            [str(source)],
            extra_compile_args=COMPILE_ARGS,
            extra_link_args=LINK_ARGS,
            include_dirs=INCLUDE_DIRS,
            libraries=LIBRARIES,
        )
        for source in Path("smarttree").glob("*.pyx")
    ]

    cythonized = cythonize(
        extensions,
        include_path=INCLUDE_DIRS,
        compiler_directives={"binding": True, "language_level": 3},
        build_dir="build",
    )

    dist = Distribution({"name": "smarttree", "ext_modules": cythonized})

    builder = build_ext(dist)
    builder.ensure_finalized()
    builder.run()

    # Copy built extensions back to the project
    for built in map(Path, builder.get_outputs()):
        target = Path(".") / built.relative_to(builder.build_lib)
        shutil.copyfile(built, target)

        # Grant execute permission wherever read permission is already set
        # (each r bit, shifted right by two, lands on the matching x bit).
        permissions = os.stat(target).st_mode
        os.chmod(target, permissions | ((permissions & 0o444) >> 2))


# Run the extension build only when this file is executed as a script
# (pyproject.toml registers it under [tool.poetry.build]); importing the
# module has no build side effects.
if __name__ == "__main__":
    build()
263 changes: 168 additions & 95 deletions poetry.lock

Large diffs are not rendered by default.

20 changes: 15 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,34 @@
[build-system]
requires = ["poetry-core", "cython", "setuptools"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "smarttree"
version = "0.1.0"
description = "Custom Decision Tree with bells and whistles"
authors = ["Mikhail Martin <mikhailmartin95@yandex.ru>"]
readme = "README.md"
packages = [{ include = "smarttree" }]
include = [
{ path = "smarttree/**/*.pyd", format = "wheel" },
{ path = "smarttree/**/*.so", format = "wheel" },
{ path = "smarttree/**/*.lib", format = "wheel" },
]
exclude = ["**/*.c"]

[tool.poetry.build]
script = "build-extension.py"

[tool.poetry.dependencies]
python = "^3.11"
pandas = "2.1.4"
scikit-learn = "^1.3.2"
graphviz = "^0.20.1"
poetry-core = "^2.2.0"

[tool.poetry.group.dev.dependencies]
notebook = "^7.0.6"
jupyterlab-execute-time = "^3.1.0"
cython = "^3.1.4"

[tool.poetry.group.test.dependencies]
pytest = "^7.4.3"
Expand All @@ -25,10 +39,6 @@ pandas-stubs = "^2.3.2.250827"
scikit-learn-stubs = "^0.0.3"
ruff = "^0.12.12"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[[tool.mypy.overrides]]
module = "graphviz"
ignore_missing_imports = true
Expand Down
4 changes: 4 additions & 0 deletions smarttree/_cgini_index.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import pandas as pd
from numpy.typing import NDArray

# Type stub for the compiled Cython function in ``_cgini_index.pyx``.
# Returns the Gini index (1 - sum of squared class shares) of the samples
# selected by ``mask``; ``mask`` and ``y`` are matched by position and
# ``class_names`` lists the labels to count.  Returns 0.0 when ``mask``
# selects no samples.
def cgini_index(mask: pd.Series, y: pd.Series, class_names: NDArray) -> float: ...
47 changes: 47 additions & 0 deletions smarttree/_cgini_index.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
cimport cython
from libc.stdint cimport int8_t
import numpy as np


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def cgini_index(mask, y, class_names):
    """
    Compute the Gini index of the samples selected by ``mask``.

    Gini index = 1 - sum_i(p_i ** 2), where p_i is the share of selected
    samples whose label equals ``class_names[i]``.

    Parameters
    ----------
    mask
        pandas Series of truthy/falsy flags; a truthy value selects the
        sample at that position.
    y
        pandas Series of class labels, matched with ``mask`` by position
        (not by index).  Any dtype is accepted; labels are compared with
        ``==`` as Python objects.
    class_names
        Indexable sequence of the class labels to count.

    Returns
    -------
    float
        The Gini index, or 0.0 when ``mask`` selects no samples.

    Raises
    ------
    ValueError
        If ``y`` has fewer elements than ``mask``.
    """
    cdef int8_t[:] mask_arr = mask.values.astype(np.int8)
    # Convert to an object array explicitly: an ``object[:]`` memoryview
    # rejects int/bool/categorical label columns with a dtype mismatch.
    cdef object[:] y_arr = y.to_numpy(dtype=object)
    # Py_ssize_t is the native index type and, unlike ``long``, is 64-bit on
    # 64-bit Windows as well.
    cdef Py_ssize_t n = mask_arr.shape[0]
    cdef Py_ssize_t n_classes = len(class_names)
    cdef Py_ssize_t N = 0
    cdef Py_ssize_t N_i = 0
    cdef Py_ssize_t i
    cdef Py_ssize_t j
    cdef double p_i = 0.0
    cdef double gini_index = 1.0
    cdef object class_name

    # Bounds checking is disabled for this function, so a short ``y`` would
    # otherwise be read past its end in the loops below.
    if y_arr.shape[0] < n:
        raise ValueError("y must have at least as many elements as mask")

    # Total number of selected samples.
    for i in range(n):
        if mask_arr[i]:
            N += 1

    # Nothing selected: return early instead of dividing by zero.
    if N == 0:
        return 0.0

    for j in range(n_classes):
        N_i = 0
        class_name = class_names[j]

        # Count the selected samples that carry the current label.
        for i in range(n):
            if mask_arr[i] and y_arr[i] == class_name:
                N_i += 1

        p_i = <double>N_i / <double>N
        gini_index -= p_i * p_i

    return gini_index
20 changes: 11 additions & 9 deletions smarttree/_column_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd
from numpy.typing import NDArray

from ._cgini_index import cgini_index
from ._dataset import Dataset
from ._tree import TreeNode
from ._types import ClassificationCriterionType, NaModeType
Expand Down Expand Up @@ -132,15 +133,16 @@ def gini_index(self, mask: pd.Series) -> float:
C - total number of classes;
p_i - the probability of choosing a sample with class i.
"""
N = mask.sum()

gini_index = 1
for label in self.dataset.class_names:
N_i = (mask & (self.dataset.y == label)).sum()
p_i = N_i / N
gini_index -= pow(p_i, 2)

return gini_index
# N = mask.sum()
#
# gini_index = 1
# for label in self.dataset.class_names:
# N_i = (mask & (self.dataset.y == label)).sum()
# p_i = N_i / N
# gini_index -= pow(p_i, 2)
#
# return gini_index
return cgini_index(mask, self.dataset.y, self.dataset.class_names)

def entropy(self, mask: pd.Series) -> float:
r"""
Expand Down
4 changes: 3 additions & 1 deletion smarttree/_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass

import numpy as np
import pandas as pd
from numpy.typing import NDArray


@dataclass
Expand All @@ -10,7 +12,7 @@ class Dataset:
y: pd.Series

def __post_init__(self) -> None:
self.class_names = sorted(self.y.unique())
self.class_names: NDArray = np.sort(self.y.unique())
self.mask_na = {column: self.X[column].isna() for column in self.X.columns}

@property
Expand Down
Loading