Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ dependencies = [
"pandas>=2.2,<3.0",
"pint>=0.23.0",
"argon2_cffi>=23.1.0",
"oso>=0.27.3,<0.28",
"alembic>=1.8.0,<2.0",
"click>=8.1.3,<9.0",
"celery>=5.3.1,<6.0",
Expand Down
1 change: 0 additions & 1 deletion requirements/install-min.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ psycopg==3.1.10
sqlalchemy==2.0.8
pandas==2.2
argon2_cffi==23.1.0
oso==0.27.3
alembic==1.8.0
click==8.1.3
celery==5.3.1
Expand Down
6 changes: 1 addition & 5 deletions requirements/install-min.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ celery==5.3.1
certifi==2024.12.14
# via requests
cffi==1.17.1
# via
# argon2-cffi-bindings
# oso
# via argon2-cffi-bindings
charset-normalizer==3.4.1
# via requests
click==8.1.3
Expand Down Expand Up @@ -53,8 +51,6 @@ markupsafe==3.0.2
# via mako
numpy==1.26.4
# via pandas
oso==0.27.3
# via -r requirements/install-min.in
packaging==24.2
# via redis
pandas==2.2.0
Expand Down
6 changes: 1 addition & 5 deletions requirements/install.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ celery==5.4.0
certifi==2024.12.14
# via requests
cffi==1.17.1
# via
# argon2-cffi-bindings
# oso
# via argon2-cffi-bindings
charset-normalizer==3.4.1
# via requests
click==8.1.8
Expand Down Expand Up @@ -55,8 +53,6 @@ markupsafe==3.0.2
# via mako
numpy==2.0.2
# via pandas
oso==0.27.3
# via bemserver-core (pyproject.toml)
pandas==2.2.3
# via bemserver-core (pyproject.toml)
pint==0.24.4
Expand Down
14 changes: 0 additions & 14 deletions src/bemserver_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
import pandas as pd

from bemserver_core import (
authorization,
common,
database,
input_output, # noqa
model,
plugins,
settings,
tasks, # noqa
Expand All @@ -26,12 +24,6 @@

class BEMServerCore:
def __init__(self):
self.auth_model_classes = model.AUTH_MODEL_CLASSES
self.auth_polar_files = [
authorization.AUTH_POLAR_FILE,
model.AUTH_POLAR_FILE,
]

# Load config
self.config = settings.DEFAULT_CONFIG.copy()
file_path = os.environ.get("BEMSERVER_CORE_SETTINGS_FILE")
Expand All @@ -48,12 +40,6 @@ def __init__(self):
# Set db URL
database.db.set_db_url(self.config["SQLALCHEMY_DATABASE_URI"])

# Init auth
authorization.auth.init_authorization(
self.auth_model_classes,
self.auth_polar_files,
)

# Load unit definition files
for file_path in self.config["UNIT_DEFINITION_FILES"]:
common.ureg.load_definitions(file_path)
Expand Down
18 changes: 0 additions & 18 deletions src/bemserver_core/authorization.polar

This file was deleted.

149 changes: 92 additions & 57 deletions src/bemserver_core/authorization.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
"""Authorization"""

import functools
import typing
import warnings
from contextvars import ContextVar
from pathlib import Path

from oso import Oso, OsoError, Relation # noqa
from polar.data.adapter.sqlalchemy_adapter import SqlAlchemyAdapter

from bemserver_core.database import db
from bemserver_core.exceptions import BEMServerAuthorizationError
from bemserver_core.exceptions import (
BEMServerAuthorizationError,
BEMServerAuthorizationUndefinedActionError,
)
from bemserver_core.utils import make_context_var_manager

if typing.TYPE_CHECKING:
from bemserver_core.model import User

CURRENT_USER = ContextVar("current_user", default=None)
OPEN_BAR = ContextVar("open_bar", default=False)

CurrentUser = make_context_var_manager(CURRENT_USER)
OpenBar = functools.partial(make_context_var_manager(OPEN_BAR), True)


AUTH_POLAR_FILE = Path(__file__).parent / "authorization.polar"


def get_current_user():
current_user = CURRENT_USER.get()
if current_user is None or not current_user.is_active:
Expand All @@ -35,66 +36,80 @@ def get():
return OPEN_BAR.get()


class OsoProxy:
"""Oso proxy class

Provides lazy loading of classes and authorization rules
"""

def __init__(self, *args, **kwargs):
self.oso = None
self.oso_args = args
self.oso_kwargs = kwargs

def __getattr__(self, attr):
return getattr(self.oso, attr)

def init_authorization(self, model_classes, polar_files):
"""Register model classes and load rules

Must be done after model classes are imported
"""
self.oso = Oso(*self.oso_args, **self.oso_kwargs)
self.set_data_filtering_adapter(SqlAlchemyAdapter(db.session))

# Register classes
self.register_class(OpenBarPolarClass)
AuthMixin.register_class(name="Base")
for cls in model_classes:
cls.register_class()

# Load authorization policy
self.load_files(polar_files)

class AuthorizationsManager:
def __init__(self) -> None:
self._rules: dict = {}

def add_rule(self, action: str) -> typing.Callable:
def decorator(func: typing.Callable):
if action in self._rules:
warnings.warn(
f"Redefining authorization rule for {action}",
RuntimeWarning,
stacklevel=1,
)
self._rules[action] = func
return func

return decorator

def eval_rule(self, action: str, actor: "User", item: any):
try:
rule = self._rules[action]
except KeyError as exc:
raise BEMServerAuthorizationUndefinedActionError(
f"Undefined action: {action}"
) from exc
return rule(actor, item)

def authorize(self, action: str, item: any) -> bool:
actor = get_current_user()
if not (
OPEN_BAR.get() or actor.is_admin or self.eval_rule(action, actor, item)
):
raise BEMServerAuthorizationError

def authorize_query(self, model_cls, query):
actor = get_current_user()
if not (OPEN_BAR.get() or actor.is_admin):
query = model_cls.authorize_query(actor, query)
return query

auth = OsoProxy(
forbidden_error=BEMServerAuthorizationError,
not_found_error=BEMServerAuthorizationError,
)

auth_mgr: AuthorizationsManager = AuthorizationsManager()

class AuthMixin:
@classmethod
def register_class(cls, *args, **kwargs):
auth.register_class(cls, *args, **kwargs)

class AuthMgrMixin:
@classmethod
def _query(cls, **kwargs):
user = get_current_user()
# TODO: Workaround for https://github.com/osohq/oso/issues/1536
if OPEN_BAR.get() or user.is_admin:
query = db.session.query(cls)
else:
query = auth.authorized_query(user, "read", cls)
query = db.session.query(cls)
query = auth_mgr.authorize_query(cls, query)
for key, val in kwargs.items():
query = query.filter(getattr(cls, key) == val)
return query

@classmethod
def authorize_query(cls, actor: "User", query):
"""Override in model class to add custom rules"""
return query

def authorize_create(self, actor):
return False

def authorize_read(self, actor):
return False

def authorize_update(self, actor):
return False

def authorize_delete(self, actor):
return False

@classmethod
def new(cls, **kwargs):
# Override Base.new to avoid adding to the session if auth failed
item = cls(**kwargs)
auth.authorize(get_current_user(), "create", item)
auth_mgr.authorize("create", item)
db.session.add(item)
return item

Expand All @@ -103,13 +118,33 @@ def get_by_id(cls, item_id, **kwargs):
item = super().get_by_id(item_id)
if item is None:
return None
auth.authorize(get_current_user(), "read", item)
auth_mgr.authorize("read", item)
return item

def update(self, **kwargs):
auth.authorize(get_current_user(), "update", self)
auth_mgr.authorize("update", self)
super().update(**kwargs)

def delete(self):
auth.authorize(get_current_user(), "delete", self)
auth_mgr.authorize("delete", self)
super().delete()


@auth_mgr.add_rule("create")
def authorize_create(actor, item):
return item.authorize_create(actor)


@auth_mgr.add_rule("read")
def authorize_read(actor, item):
return item.authorize_read(actor)


@auth_mgr.add_rule("update")
def authorize_update(actor, item):
return item.authorize_update(actor)


@auth_mgr.add_rule("delete")
def authorize_delete(actor, item):
return item.authorize_delete(actor)
4 changes: 4 additions & 0 deletions src/bemserver_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ class BEMServerAuthorizationError(BEMServerCoreIOError):
"""Operation not autorized to current user"""


class BEMServerAuthorizationUndefinedActionError(BEMServerCoreIOError):
"""Action undefined"""


class PropertyTypeInvalidError(BEMServerCoreError):
"""Invalid property value type: cast error"""

Expand Down
16 changes: 8 additions & 8 deletions src/bemserver_core/input_output/timeseries_data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import pandas as pd

from bemserver_core.authorization import auth, get_current_user
from bemserver_core.authorization import auth_mgr
from bemserver_core.common import ureg
from bemserver_core.database import db
from bemserver_core.exceptions import (
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_last(
"""
# Check permissions
for ts in timeseries:
auth.authorize(get_current_user(), "read_data", ts)
auth_mgr.authorize("read_ts_data", ts)

params = {
"timeseries_ids": [ts.id for ts in timeseries],
Expand Down Expand Up @@ -146,7 +146,7 @@ def get_timeseries_stats(
"""
# Check permissions
for ts in timeseries:
auth.authorize(get_current_user(), "read_data", ts)
auth_mgr.authorize("read_ts_data", ts)

params = {
"timeseries_ids": [ts.id for ts in timeseries],
Expand Down Expand Up @@ -276,7 +276,7 @@ def set_timeseries_data(

# Check permissions
for ts in timeseries:
auth.authorize(get_current_user(), "write_data", ts)
auth_mgr.authorize("write_ts_data", ts)

if convert_from:
cls._convert_from(
Expand Down Expand Up @@ -355,7 +355,7 @@ def get_timeseries_data(
"""
# Check permissions
for ts in timeseries:
auth.authorize(get_current_user(), "read_data", ts)
auth_mgr.authorize("read_ts_data", ts)

# Get timeseries data
stmt = (
Expand Down Expand Up @@ -455,7 +455,7 @@ def get_timeseries_buckets_data(

# Check permissions
for ts in timeseries:
auth.authorize(get_current_user(), "read_data", ts)
auth_mgr.authorize("read_ts_data", ts)

fill_value = 0 if aggregation == "count" else np.nan
dtype = int if aggregation == "count" else float
Expand Down Expand Up @@ -574,7 +574,7 @@ def get_timeseries_aggregate_data(
Returns a dataframe.
"""
for ts in timeseries:
auth.authorize(get_current_user(), "read_data", ts)
auth_mgr.authorize("read_ts_data", ts)

if agg == "avg":
agg_func = sqla.func.avg(TimeseriesData.value)
Expand Down Expand Up @@ -633,7 +633,7 @@ def delete(cls, start_dt, end_dt, timeseries, data_state):
"""
# Check permissions
for ts in timeseries:
auth.authorize(get_current_user(), "write_data", ts)
auth_mgr.authorize("write_ts_data", ts)

# Delete timeseries data
(
Expand Down
Loading
Loading