Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/cmds/core/ban.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, bot: Bot):
@slash_command(guild_ids=settings.guild_ids, description="Ban a user from the server permanently.")
@has_any_role(*settings.role_groups.get("ALL_ADMINS"), *settings.role_groups.get("ALL_SR_MODS"))
async def ban(
self, ctx: ApplicationContext, user: discord.Member, reason: str, evidence: str = None
self, ctx: ApplicationContext, user: discord.Member, reason: str, evidence: str
) -> Interaction | WebhookMessage:
"""Ban a user from the server permanently."""
await ctx.defer(ephemeral=False)
Expand All @@ -48,7 +48,7 @@ async def ban(
*settings.role_groups.get("ALL_HTB_STAFF")
)
async def tempban(
self, ctx: ApplicationContext, user: discord.Member, duration: str, reason: str, evidence: str = None
self, ctx: ApplicationContext, user: discord.Member, duration: str, reason: str, evidence: str
) -> Interaction | WebhookMessage:
"""Ban a user from the server temporarily."""
await ctx.defer(ephemeral=False)
Expand Down
7 changes: 5 additions & 2 deletions src/cmds/core/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from src.core import settings
from src.database.models import HtbDiscordLink
from src.database.session import AsyncSessionLocal
from src.helpers.ban import add_infraction
from src.helpers.ban import add_infraction, validate_evidence
from src.helpers.checks import member_is_staff

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -72,7 +72,7 @@ async def bad_name(self, ctx: ApplicationContext, user: Member) -> Interaction |

@slash_command(guild_ids=settings.guild_ids, description="Kick a user from the server.")
@has_any_role(*settings.role_groups.get("ALL_ADMINS"), *settings.role_groups.get("ALL_MODS"))
async def kick(self, ctx: ApplicationContext, user: Member, reason: str, evidence: str = None) \
async def kick(self, ctx: ApplicationContext, user: Member, reason: str, evidence: str) \
-> Interaction | WebhookMessage:
"""Kick a user from the server."""
await ctx.defer(ephemeral=False)
Expand All @@ -88,6 +88,9 @@ async def kick(self, ctx: ApplicationContext, user: Member, reason: str, evidenc
if ctx.user.id == member.id:
return await ctx.followup.send("You cannot kick yourself.")

if evidence_error := validate_evidence(evidence):
return await ctx.followup.send(evidence_error.message, delete_after=evidence_error.delete_after)

if len(reason) == 0:
reason = "No reason given..."

Expand Down
14 changes: 12 additions & 2 deletions src/helpers/ban.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@
logger = logging.getLogger(__name__)


EVIDENCE_REQUIRED_MESSAGE = "Evidence is required."


def validate_evidence(evidence: str | None) -> SimpleResponse | None:
"""Return an error response when evidence is missing or blank."""
if not isinstance(evidence, str) or not evidence.strip():
return SimpleResponse(message=EVIDENCE_REQUIRED_MESSAGE, delete_after=15)
return None


class BanCodes(Enum):
SUCCESS = "SUCCESS"
ALREADY_EXISTS = "ALREADY_EXISTS"
Expand Down Expand Up @@ -291,8 +301,8 @@ async def ban_member_with_epoch(
if len(reason) == 0:
reason = "No reason given ..."

if not evidence:
evidence = "none provided"
if evidence_error := validate_evidence(evidence):
return evidence_error

# Validate epoch time is in the future
current_time = datetime.now(tz=timezone.utc).timestamp()
Expand Down
4 changes: 2 additions & 2 deletions src/webhooks/handlers/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def _handle_account_banned(self, body: WebhookBody, bot: Bot) -> dict:
discord_id, account_id = self.validate_common_properties(body)
expires_at = self.validate_property(self.get_property_or_trait(body, "expires_at"), "expires_at")
reason = body.properties.get("reason")
notes = body.properties.get("notes")
notes = self.validate_non_empty_string(body.properties.get("notes"), "notes")
created_by = body.properties.get("created_by")

expires_ts = int(datetime.fromisoformat(expires_at).timestamp()) # type: ignore
Expand All @@ -114,7 +114,7 @@ async def _handle_account_banned(self, body: WebhookBody, bot: Bot) -> dict:
member=member,
expires_timestamp=expires_ts,
reason=f"Platform Ban - {reason}",
evidence=notes or "N/A",
evidence=notes,
author_name=created_by or "System",
expires_at_str=expires_at, # type: ignore
log_channel_id=settings.channels.BOT_LOGS,
Expand Down
21 changes: 21 additions & 0 deletions src/webhooks/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,27 @@ def validate_property(self, property: T | None, name: str) -> T:

return property

def validate_non_empty_string(self, value: str | None, name: str) -> str:
"""
Validates a string property is present and not blank.

Args:
value (str | None): The string to validate.
name (str): The name of the property.

Returns:
str: The validated string.

Raises:
HTTPException: If the value is None or blank (400)
"""
if not isinstance(value, str) or not value.strip():
msg = f"Invalid {name}"
self.logger.debug(msg)
raise HTTPException(status_code=400, detail=msg)

return value.strip()

def validate_discord_id(self, discord_id: str | int | None) -> int | str:
"""
Validates the Discord ID. See validate_property function.
Expand Down
34 changes: 25 additions & 9 deletions tests/src/cmds/core/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ async def test_kick_success(self, ctx, guild, bot, session):
patch('src.cmds.core.user.member_is_staff', return_value=False)
):
cog = user.UserCog(bot)
await cog.kick.callback(cog, ctx, user_to_kick, "Violation of rules")
await cog.kick.callback(cog, ctx, user_to_kick, "Violation of rules", "Some evidence")

reason = "Violation of rules"
add_infraction_mock.assert_called_once_with(
ctx.guild, user_to_kick, 0, f"Previously kicked for: {reason} - Evidence: None", ctx.user
ctx.guild, user_to_kick, 0, f"Previously kicked for: {reason} - Evidence: Some evidence", ctx.user
)

# Assertions
Expand All @@ -59,7 +59,7 @@ async def test_kick_fail_user_left(self, ctx, guild, bot, session):

with patch('src.cmds.core.user.member_is_staff', return_value=False):
cog = user.UserCog(bot)
await cog.kick.callback(cog, ctx, user_to_kick, "Violation of rules")
await cog.kick.callback(cog, ctx, user_to_kick, "Violation of rules", "Some evidence")

bot.get_member_or_user.assert_called_once_with(ctx.guild, user_to_kick.id)
ctx.guild.kick.assert_not_called()
Expand All @@ -76,7 +76,7 @@ async def test_kick_fail_user_not_found(self, ctx, guild, bot, session):

with patch('src.cmds.core.user.member_is_staff', return_value=False):
cog = user.UserCog(bot)
await cog.kick.callback(cog, ctx, user_to_kick, "Violation of rules")
await cog.kick.callback(cog, ctx, user_to_kick, "Violation of rules", "Some evidence")

bot.get_member_or_user.assert_called_once_with(ctx.guild, user_to_kick.id)
ctx.guild.kick.assert_not_called()
Expand All @@ -92,7 +92,7 @@ async def test_kick_fail_staff_member(self, ctx, guild, bot):

with patch('src.cmds.core.user.member_is_staff', return_value=True):
cog = user.UserCog(bot)
await cog.kick.callback(cog, ctx, member, "Violation of rules")
await cog.kick.callback(cog, ctx, member, "Violation of rules", "Some evidence")

ctx.defer.assert_awaited_once_with(ephemeral=False)
ctx.followup.send.assert_called_once_with("You cannot kick another staff member.")
Expand All @@ -106,7 +106,7 @@ async def test_kick_fail_bot_member(self, ctx, guild, bot):

with patch('src.cmds.core.user.member_is_staff', return_value=False):
cog = user.UserCog(bot)
await cog.kick.callback(cog, ctx, member, "Violation of rules")
await cog.kick.callback(cog, ctx, member, "Violation of rules", "Some evidence")

ctx.defer.assert_awaited_once_with(ephemeral=False)
ctx.followup.send.assert_called_once_with("You cannot kick a bot.")
Expand All @@ -120,11 +120,27 @@ async def test_kick_fail_self_kick(self, ctx, guild, bot):

with patch('src.cmds.core.user.member_is_staff', return_value=False):
cog = user.UserCog(bot)
await cog.kick.callback(cog, ctx, member, "Violation of rules")
await cog.kick.callback(cog, ctx, member, "Violation of rules", "Some evidence")

ctx.defer.assert_awaited_once_with(ephemeral=False)
ctx.followup.send.assert_called_once_with("You cannot kick yourself.")

@pytest.mark.asyncio
async def test_kick_fail_missing_evidence(self, ctx, guild, bot):
ctx.user = helpers.MockMember(id=1, name="Test Moderator")
member = helpers.MockMember(id=2, name="User to Kick", bot=False)
ctx.guild = guild
ctx.guild.kick = AsyncMock()
bot.get_member_or_user = AsyncMock(return_value=member)

with patch('src.cmds.core.user.member_is_staff', return_value=False):
cog = user.UserCog(bot)
await cog.kick.callback(cog, ctx, member, "Violation of rules", " ")

ctx.defer.assert_awaited_once_with(ephemeral=False)
ctx.guild.kick.assert_not_called()
ctx.followup.send.assert_called_once_with("Evidence is required.", delete_after=15)

@pytest.mark.asyncio
async def test_kick_http_exception_returns_error(self, ctx, guild, bot):
ctx.user = helpers.MockMember(id=1, name="Test Moderator")
Expand All @@ -139,7 +155,7 @@ async def test_kick_http_exception_returns_error(self, ctx, guild, bot):
patch('src.cmds.core.user.member_is_staff', return_value=False)
):
cog = user.UserCog(bot)
await cog.kick.callback(cog, ctx, member, "Violation of rules")
await cog.kick.callback(cog, ctx, member, "Violation of rules", "Some evidence")

ctx.defer.assert_awaited_once_with(ephemeral=False)
ctx.guild.kick.assert_not_called()
Expand All @@ -161,7 +177,7 @@ async def test_kick_forbidden_dm_sends_notice_and_continues(self, ctx, guild, bo
patch('src.cmds.core.user.member_is_staff', return_value=False)
):
cog = user.UserCog(bot)
await cog.kick.callback(cog, ctx, member, "Violation of rules")
await cog.kick.callback(cog, ctx, member, "Violation of rules", "Some evidence")

ctx.defer.assert_awaited_once_with(ephemeral=False)
assert ctx.followup.send.await_count == 2
Expand Down
13 changes: 12 additions & 1 deletion tests/src/helpers/test_ban.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from discord import Forbidden, HTTPException

from src.helpers.ban import _check_member, _dm_banned_member, ban_member
from src.helpers.ban import EVIDENCE_REQUIRED_MESSAGE, _check_member, _dm_banned_member, ban_member
from src.helpers.responses import SimpleResponse
from tests import helpers

Expand Down Expand Up @@ -150,6 +150,17 @@ async def test_ban_member_invalid_duration(self, bot, guild, member, author):
assert isinstance(result, SimpleResponse)
assert result.message == "Invalid duration: could not parse."

@pytest.mark.asyncio
async def test_ban_member_missing_evidence(self, bot, guild, member, author):
duration = "1d"
reason = "xf reason"
member.display_name = "Banned Member"

with mock.patch("src.helpers.ban._check_member", return_value=None):
result = await ban_member(bot, guild, member, duration, reason, " ")
assert isinstance(result, SimpleResponse)
assert result.message == EVIDENCE_REQUIRED_MESSAGE

@pytest.mark.asyncio
async def test_ban_member_permanently_success(self, bot, guild, member, author):
duration = "500w"
Expand Down
29 changes: 29 additions & 0 deletions tests/src/webhooks/handlers/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ async def test_handle_account_banned_member_not_found(self, bot):
"discord_id": discord_id,
"account_id": account_id,
"expires_at": expires_at,
"notes": "Repeated violations",
},
traits={},
)
Expand All @@ -508,3 +509,31 @@ async def test_handle_account_banned_member_not_found(self, bot):
result = await handler._handle_account_banned(body, bot)
mock_log.assert_called()
assert result == handler.fail()

@pytest.mark.asyncio
async def test_handle_account_banned_missing_notes(self, bot):
"""Test account banned event rejects missing notes."""
handler = AccountHandler()
discord_id = 123456789
account_id = 987654321
expires_at = "2024-12-31T23:59:59"
body = WebhookBody(
platform=Platform.ACCOUNT,
event=WebhookEvent.ACCOUNT_BANNED,
properties={
"discord_id": discord_id,
"account_id": account_id,
"expires_at": expires_at,
},
traits={},
)
with (
patch.object(handler, "validate_discord_id", return_value=discord_id),
patch.object(handler, "validate_account_id", return_value=account_id),
patch.object(handler, "validate_property", return_value=expires_at),
):
with pytest.raises(HTTPException) as exc_info:
await handler._handle_account_banned(body, bot)

assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Invalid notes"
17 changes: 17 additions & 0 deletions tests/src/webhooks/handlers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,23 @@ def test_validate_property_none(self):
assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Invalid test_property"

def test_validate_non_empty_string_success(self):
"""Test successful non-empty string validation."""
handler = ConcreteHandler()

assert handler.validate_non_empty_string("valid value", "notes") == "valid value"

def test_validate_non_empty_string_invalid(self):
"""Test non-empty string validation rejects blank values."""
handler = ConcreteHandler()

for invalid_value in (None, "", " "):
with pytest.raises(HTTPException) as exc_info:
handler.validate_non_empty_string(invalid_value, "notes")

assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Invalid notes"

def test_validate_discord_id_success(self):
"""Test successful Discord ID validation."""
handler = ConcreteHandler()
Expand Down