diff --git a/admin/admin_theme.py b/admin/admin_theme.py
index 74cd5691f..fa2181085 100644
--- a/admin/admin_theme.py
+++ b/admin/admin_theme.py
@@ -12,7 +12,7 @@
from core.models import Config
from core.template import (
AdminTemplates, TEMPLATES, TemplateService, UserTemplates,
- get_current_theme, get_theme_list, get_theme_info, register_theme_statics,
+ get_current_theme, get_theme_list, get_theme_info,
)
from lib.dependency.dependencies import validate_super_admin, validate_theme
@@ -110,7 +110,7 @@ async def theme_update(
os.unlink(file_path)
# 테마 관련 정적 파일을 등록합니다.
- register_theme_statics(app)
+ TemplateService.register_statics(app)
# 이전 테마 경로를 제거 후 새로운 테마 경로를 추가합니다.
user_template = UserTemplates()
diff --git a/admin/admin_visit.py b/admin/admin_visit.py
index 24f87c368..f142a123a 100644
--- a/admin/admin_visit.py
+++ b/admin/admin_visit.py
@@ -424,11 +424,10 @@ async def visit_hour(
# 합계
total_count = db.scalar(query.add_columns(func.count(Visit.vi_id)))
# 시간별 접속자집계
- # TODO: postgresql는 테스트가 안되어 있음
if dialect == 'mysql':
query = query.add_columns(func.hour(Visit.vi_time).label('hour'))
elif dialect == 'postgresql':
- query = query.add_columns(func.to_char(Visit.vi_time, 'HH24').label('hour'))
+ query = query.add_columns(extract('hour', Visit.vi_time).label('hour'))
elif dialect == 'sqlite':
query = query.add_columns(func.strftime('%H', Visit.vi_time).label('hour'))
query_result = db.execute(
@@ -472,11 +471,10 @@ async def visit_weekday(
# 합계
total_count = db.scalar(query.add_columns(func.count(Visit.vi_id)))
# 요일별 접속자집계
- # TODO: postgresql는 테스트가 안되어 있음
if dialect == 'mysql':
query = query.add_columns(func.dayofweek(Visit.vi_date).label('dow'))
elif dialect == 'postgresql':
- query = query.add_columns(func.to_char(Visit.vi_date, 'D').label('dow'))
+ query = query.add_columns(extract('dow', Visit.vi_date).label('dow'))
elif dialect == 'sqlite':
query = query.add_columns(func.strftime('%w', Visit.vi_date).label('dow'))
query_result = db.execute(
diff --git a/bbs/board.py b/bbs/board.py
index 6c622c67c..1385c4f76 100644
--- a/bbs/board.py
+++ b/bbs/board.py
@@ -4,6 +4,7 @@
from typing_extensions import Annotated, List
from fastapi import APIRouter, Depends, Request, Form, Path, Query, File, UploadFile
+from typing import Union
from fastapi.responses import FileResponse, RedirectResponse
from core.database import db_session
@@ -183,7 +184,15 @@ async def write_form_add(
else:
service.validate_write_level()
- # TODO: 포인트 검증
+ # 포인트 검증
+ required_point = (
+ board.bo_comment_point if parent_write else board.bo_write_point
+ )
+ service.point_service.validate_enough_point(
+ service.member.mb_id,
+ required_point,
+ "답변 작성" if parent_write else "게시글 작성",
+ )
# 게시판 제목 설정
board.subject = service.subject
@@ -298,7 +307,7 @@ async def create_post(
form_data: Annotated[WriteForm, Depends()],
service: Annotated[CreatePostService, Depends(CreatePostService.async_init)],
file_service: Annotated[BoardFileService, Depends()],
- parent_id: int = Form(None),
+ parent_id: Union[int, None, str] = Form(None),
notice: bool = Form(False),
secret: str = Form(""),
html: str = Form(""),
@@ -310,6 +319,10 @@ async def create_post(
recaptcha_response: str = Form("", alias="g-recaptcha-response"),
):
"""게시글을 작성한다."""
+ if parent_id in ("", None):
+ parent_id = None
+ else:
+ parent_id = int(parent_id)
await service.validate_captcha(recaptcha_response)
service.validate_write_delay()
service.validate_write_level()
@@ -466,6 +479,12 @@ async def write_comment_update(
form: WriteCommentForm = Depends(),
recaptcha_response: str = Form("", alias="g-recaptcha-response"),
):
+ # 여기서 if 문을 사용해야 함!
+ if form.comment_id in ("", None):
+ comment_id = None
+ else:
+ comment_id = int(form.comment_id)
+
"""
댓글 등록/수정
"""
@@ -489,7 +508,8 @@ async def write_comment_update(
elif form.w == "cu":
# 댓글 수정
write_model = service.write_model
- comment = service.db.get(write_model, form.comment_id)
+ # comment = service.db.get(write_model, form.comment_id)
+ comment = service.db.get(write_model, comment_id)
if not comment:
raise AlertException(f"{form.comment_id} : 존재하지 않는 댓글입니다.", 404)
diff --git a/core/exception.py b/core/exception.py
index 0f38ba06f..0845d66bf 100644
--- a/core/exception.py
+++ b/core/exception.py
@@ -3,7 +3,6 @@
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, RedirectResponse
-from fastapi.templating import Jinja2Templates
from starlette.templating import _TemplateResponse
from slowapi.errors import RateLimitExceeded
@@ -135,13 +134,10 @@ def template_response(
Returns:
_TemplateResponse: 템플릿 응답 객체
"""
- from core.template import TemplateService, theme_asset
+ from core.template import TemplateService
- # 새로운 템플릿 응답 객체를 생성합니다.
- # - UserTemplates, AdminTemplates 클래스는 기본 컨텍스트 설정 시 DB를 조회하는데,
- # 처음 설치 시에는 DB가 없으므로 새로운 템플릿 응답 객체를 생성합니다.
- template = Jinja2Templates(directory=TemplateService.get_templates_dir())
- template.env.globals["theme_asset"] = theme_asset
+ # UserTemplates/AdminTemplates는 DB 조회가 필요하므로 사용하지 않는다.
+ template = TemplateService.get_templates()
return template.TemplateResponse(
name=template_html,
context=context,
diff --git a/core/formclass.py b/core/formclass.py
index 185d7c9b4..083aaee2a 100644
--- a/core/formclass.py
+++ b/core/formclass.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
-
+from typing import Union
from fastapi import Form
from core.exception import AlertException
@@ -456,7 +456,9 @@ class WriteCommentForm:
wr_name: str = Form(None)
wr_password: str = Form(None)
wr_secret: str = Form(None)
- comment_id: int = Form(None)
+ # comment_id: int = Form(None)
+ comment_id: Union[int, str, None] = Form(None)
+
@dataclass
diff --git a/core/template.py b/core/template.py
index 6459e5df7..7a79e7ddc 100644
--- a/core/template.py
+++ b/core/template.py
@@ -92,21 +92,75 @@ def get_admin_theme_path() -> str:
ADMIN_TEMPLATES_DIR = get_admin_theme_path() # 관리자 템플릿 경로
class TemplateService():
- """템플릿 서비스 클래스
- - TODO: 이외의 다른 부분도 클래스화 해야한다.
+ """템플릿 서비스 클래스.
+
+ 템플릿 경로와 정적 파일, 렌더링 옵션을 관리한다.
"""
- _templates_dir: str = None # 사용자 템플릿 경로
+
+ _templates_dir: typing.Optional[str] = None # 사용자 템플릿 경로
+ _templates: typing.Optional[Jinja2Templates] = None
+ _env_options: dict = {}
@classmethod
def get_templates_dir(cls) -> str:
+ """현재 테마의 템플릿 디렉터리 경로를 반환한다."""
if cls._templates_dir is None:
cls.set_templates_dir()
return cls._templates_dir
@classmethod
- def set_templates_dir(cls) -> None:
- cls._templates_dir = get_theme_path()
+ def set_templates_dir(cls, template_dir: typing.Optional[str] = None) -> None:
+ """템플릿 디렉터리를 설정한다."""
+ cls._templates_dir = template_dir or get_theme_path()
+ cls._templates = None
+
+ @classmethod
+ def set_env_options(cls, **env_options) -> None:
+ """Jinja2 환경 설정을 갱신한다."""
+ cls._env_options.update(env_options)
+ cls._templates = None
+
+ @classmethod
+ def get_templates(cls, **env_options) -> Jinja2Templates:
+ """Jinja2Templates 객체를 반환한다.
+
+ Args:
+ **env_options: Environment 옵션
+ """
+ if env_options:
+ options = {**cls._env_options, **env_options}
+ templates = Jinja2Templates(
+ directory=cls.get_templates_dir(),
+ **options
+ )
+ templates.env.globals["theme_asset"] = theme_asset
+ return templates
+
+ if cls._templates is None:
+ cls._templates = Jinja2Templates(
+ directory=cls.get_templates_dir(),
+ **cls._env_options
+ )
+ cls._templates.env.globals["theme_asset"] = theme_asset
+
+ return cls._templates
+
+ @classmethod
+ def register_statics(cls, app: FastAPI) -> None:
+ """현재 테마의 static 디렉터리를 FastAPI에 등록한다."""
+ theme = get_current_theme()
+ directories = ["/mobile", ""]
+ for directory in directories:
+ static_directory = f"{TEMPLATES}/{theme}{directory}/static"
+
+ if not os.path.isdir(static_directory):
+ continue
+
+ url = f"/theme_static/{theme}{directory}"
+ path = StaticFiles(directory=static_directory)
+ static_device = directory.replace("/", "_")
+ app.mount(url, path, name=f"static_{theme}{static_device}")
class UserTemplates(Jinja2Templates):
@@ -285,29 +339,8 @@ def theme_asset(request: Request, asset_path: str) -> str:
def register_theme_statics(app: FastAPI) -> None:
- """
- 현재 테마의 static 경로를 가상의 경로로 등록하는 함수
- - ex) PC: /{theme}/basic/static/css -> /theme_static/basic/css
- - ex) Mobile: /{theme}/basic/mobile/static/css -> /theme_static/basic/mobile/css
-
- Args:
- app (FastAPI): FastAPI 객체
- """
- theme = get_current_theme()
- directories = ["/mobile", ""]
- for directory in directories:
- static_directory = f"{TEMPLATES}/{theme}{directory}/static"
-
- if not os.path.isdir(static_directory):
- # logger = logging.getLogger("uvicorn.error")
- # logger.warning("theme has not static directory : ",
- # static_directory)
- continue
-
- url = f"/theme_static/{theme}{directory}"
- path = StaticFiles(directory=static_directory)
- static_device = directory.replace("/", "_")
- app.mount(url, path, name=f"static_{theme}{static_device}") # tag
+ """Backward compatible wrapper for :meth:`TemplateService.register_statics`."""
+ TemplateService.register_statics(app)
def get_theme_list():
diff --git a/lib/board_lib.py b/lib/board_lib.py
index 70eddbc58..59ee7b213 100644
--- a/lib/board_lib.py
+++ b/lib/board_lib.py
@@ -6,10 +6,10 @@
import bleach
from typing import List
from fastapi import Request
-from fastapi.templating import Jinja2Templates
from sqlalchemy import and_, asc, desc, func, insert, or_, select
from sqlalchemy.sql.expression import Select
from sqlalchemy.orm import Session
+from cachetools import TTLCache
from core.database import DBConnect
from core.exception import AlertException
@@ -24,6 +24,53 @@
from service.board_file_service import BoardFileService as FileService
+# Caches for reducing DB and disk I/O
+FILE_META_CACHE = TTLCache(maxsize=1024, ttl=300)
+WRITE_CACHE = TTLCache(maxsize=1024, ttl=300)
+FILE_EXIST_CACHE = TTLCache(maxsize=1024, ttl=300)
+
+
+def _cache_key(bo_table: str, wr_id: int) -> str:
+ """Create a unified cache key."""
+ return f"{bo_table}:{wr_id}"
+
+
+def get_write_cached(bo_table: str, wr_id: int) -> WriteBaseModel | None:
+ """Return write object from cache or database."""
+ key = _cache_key(bo_table, wr_id)
+ write = WRITE_CACHE.get(key)
+ if write is None:
+ with DBConnect().sessionLocal() as db:
+ write_model = dynamic_create_write_table(bo_table)
+ write = db.get(write_model, wr_id)
+ if write:
+ WRITE_CACHE[key] = write
+ return write
+
+
+def get_board_files_cached(request: Request, bo_table: str, wr_id: int):
+ """Return board files grouped by type using cache."""
+ key = _cache_key(bo_table, wr_id)
+ result = FILE_META_CACHE.get(key)
+ if result is None:
+ with DBConnect().sessionLocal() as db:
+ service = FileService(request, db)
+ result = service.get_board_files_by_type(bo_table, wr_id)
+ FILE_META_CACHE[key] = result
+ return result
+
+
+def is_file_exist_cached(request: Request, db: Session, bo_table: str, wr_id: int) -> bool:
+ """Check file existence using cache."""
+ key = _cache_key(bo_table, wr_id)
+ exist = FILE_EXIST_CACHE.get(key)
+ if exist is None:
+ service = FileService(request, db)
+ exist = service.is_exist(bo_table, wr_id)
+ FILE_EXIST_CACHE[key] = exist
+ return exist
+
+
class BoardConfig():
"""게시판 설정 정보를 담는 클래스."""
@@ -560,7 +607,6 @@ def get_list(request: Request, db: Session, write: WriteBaseModel, board_config:
Returns:
WriteBaseModel: 게시글 목록.
"""
- file_service = FileService(request, db)
write.subject = board_config.cut_write_subject(write.wr_subject, subject_len)
write.name = cut_name(request, write.wr_name)
write.email = StringEncrypt().encrypt(write.wr_email)
@@ -570,7 +616,7 @@ def get_list(request: Request, db: Session, write: WriteBaseModel, board_config:
write.icon_secret = "secret" in write.wr_option
write.icon_hot = board_config.is_icon_hot(write.wr_hit)
write.icon_new = board_config.is_icon_new(write.wr_datetime)
- write.icon_file = file_service.is_exist(board_config.board.bo_table, write.wr_id)
+ write.icon_file = is_file_exist_cached(request, db, board_config.board.bo_table, write.wr_id)
write.icon_link = write.wr_link1 or write.wr_link2
write.icon_reply = write.wr_reply
@@ -674,8 +720,7 @@ def send_write_mail(request: Request, board: Board, write: WriteBaseModel, origi
"""
with DBConnect().sessionLocal() as db:
config = request.state.config
- templates = Jinja2Templates(
- directory=TemplateService.get_templates_dir())
+ templates = TemplateService.get_templates()
def _add_admin_email(admin_id: str):
admin = db.scalar(select(Member).filter_by(mb_id=admin_id))
@@ -738,9 +783,7 @@ def get_list_thumbnail(request: Request, board: Board, write: WriteBaseModel, th
thumb_height (int, optional): _description_. Defaults to 0.
"""
config = request.state.config
- with DBConnect().sessionLocal() as db:
- service = FileService(request, db)
- images, files = service.get_board_files_by_type(board.bo_table, write.wr_id)
+ images, files = get_board_files_cached(request, board.bo_table, write.wr_id)
source_file = None
result = {"src": "", "alt": "", "noimg":""}
diff --git a/lib/common.py b/lib/common.py
index d81c816b5..69a3d031e 100644
--- a/lib/common.py
+++ b/lib/common.py
@@ -671,12 +671,29 @@ def delete_old_records():
# 탈퇴회원 자동 삭제
if config.cf_leave_day > 0:
- # TODO: 회원삭제 처리 추가
- # query = update(Member).where(Member.mb_leave_date < datetime.now() - timedelta(days=config.cf_leave_day))
- # data = {}
- # result = db.execute(query, data)
- # print("회원 삭제 기준일 : ", datetime.now() - timedelta(days=config.cf_leave_day), f"{result}건 삭제")
- pass
+ base_datetime = today - timedelta(days=config.cf_leave_day)
+ engine_name = db.bind.dialect.name
+ if engine_name == "sqlite":
+ cutoff = base_datetime.strftime("%Y%m%d")
+ member_ids = db.scalars(
+ select(Member.mb_id)
+ .where(Member.mb_leave_date != "")
+ .where(Member.mb_leave_date <= cutoff)
+ ).all()
+ else:
+ member_ids = db.scalars(
+ select(Member.mb_id)
+ .where(Member.mb_leave_date != "")
+ .where(func.cast(Member.mb_leave_date, DateTime) <= base_datetime)
+ ).all()
+
+ if member_ids:
+ db.execute(delete(Member).where(Member.mb_id.in_(member_ids)))
+ logging.info(
+ "Deleted leave members older than %s: %s",
+ base_datetime.strftime("%Y-%m-%d"),
+ ", ".join(member_ids),
+ )
db.commit()
except Exception as e:
print(e)
diff --git a/lib/mail.py b/lib/mail.py
index 9f4dc0498..d8df8dc54 100644
--- a/lib/mail.py
+++ b/lib/mail.py
@@ -10,7 +10,6 @@
from email.utils import formataddr
from fastapi import Request
-from fastapi.templating import Jinja2Templates
from core.database import DBConnect
from core.models import Config, Member, PollEtc, QaConfig, QaContent
@@ -88,8 +87,7 @@ async def send_password_reset_mail(request: Request, member: Member) -> None:
request.state.config = config = db.query(Config).first()
try:
- templates = Jinja2Templates(
- directory=TemplateService.get_templates_dir())
+ templates = TemplateService.get_templates()
subject = f"[{config.cf_title}] 요청하신 비밀번호 찾기 메일입니다."
body = templates.TemplateResponse(
@@ -117,7 +115,7 @@ async def send_register_mail(request: Request, member: Member) -> None:
request.state.config = config = db.query(Config).first()
try:
- templates = Jinja2Templates(directory=TemplateService.get_templates_dir())
+ templates = TemplateService.get_templates()
from_email = get_admin_email(request)
from_name = get_admin_email_name(request)
context = {"request": request, "member": member}
@@ -154,8 +152,7 @@ async def send_register_admin_mail(request: Request, member: Member) -> None:
request.state.config = config = db.query(Config).first()
try:
- templates = Jinja2Templates(
- directory=TemplateService.get_templates_dir())
+ templates = TemplateService.get_templates()
from_email = get_admin_email(request)
from_name = get_admin_email_name(request)
context = {"request": request, "member": member}
@@ -185,8 +182,7 @@ async def send_poll_etc_mail(request: Request, poll_etc: PollEtc) -> None:
try:
if config.cf_email_po_super_admin and config.cf_admin_email:
- templates = Jinja2Templates(
- directory=TemplateService.get_templates_dir())
+ templates = TemplateService.get_templates()
email = get_admin_email(request)
from_name = get_admin_email_name(request)
subject = f"[{config.cf_title}] 설문조사 - 기타의견 메일"
@@ -223,8 +219,7 @@ async def send_qa_mail(request: Request, qa: QaContent) -> None:
from_name = get_admin_email_name(request)
subject = f"[{config.cf_title}] {qa_config.qa_title} 질문 알림 메일"
content = qa.qa_subject + "
" + qa.qa_content
- templates = Jinja2Templates(
- directory=TemplateService.get_templates_dir())
+ templates = TemplateService.get_templates()
if qa.qa_parent:
question = db.get(QaContent, qa.qa_parent)
diff --git a/lib/member.py b/lib/member.py
index c5e5d1038..1ca4353b0 100644
--- a/lib/member.py
+++ b/lib/member.py
@@ -1,5 +1,6 @@
"""회원 관련 기능을 제공하는 모듈입니다."""
import math
+from dataclasses import dataclass
from datetime import date, datetime, timedelta
from typing import Tuple, Union
@@ -8,51 +9,48 @@
from core.models import Board, Config, Group, Member
+@dataclass(slots=True)
class MemberDetails:
- mb_no: int = 0
- mb_id: str = None
- mb_name: str = None
- mb_nick: str = None
- mb_email: str = None
- mb_homepage: str = None
- mb_level: int = 1
- mb_tel: str = None
- mb_hp: str = None
- mb_certify: str = None
- mb_adult: int = 0
- mb_signature: str = None
- mb_point: int = 0
- mb_today_login: datetime = None
- mb_login_ip: str = None
- mb_datetime: datetime = None
- mb_ip: str = None
- mb_leave_date: str = None
- mb_intercept_date: str = None
- mb_mailling: int = 0
- mb_sms: int = 0
- mb_profile: int = 0
-
- _admin_type: str = None
-
- def __init__(
- self,
- request: Request,
- member: Member,
- board: Board = None,
- group: Group = None
- ):
- # TODO: 반복적으로 호출되는 문제 해결해야함.
- # print("__init__", member)
- super().__init__()
+ """서비스 전반에서 사용되는 경량 회원 정보"""
+
+ request: Request
+ member: Member | None
+ config: Config
+ level: int
+ admin_type: Union[str, None]
+
+ def __init__(self, request: Request, member: Member, board: Board = None,
+ group: Group = None):
+ # 캐시를 이용해 중복 초기화를 방지한다.
+ cache_key = (
+ getattr(member, "mb_no", 0),
+ getattr(board, "bo_table", None),
+ getattr(group, "gr_id", None),
+ )
+ cache = getattr(request.state, "_member_details_cache", None)
+ if cache is None:
+ cache = {}
+ setattr(request.state, "_member_details_cache", cache)
+ if cache_key in cache:
+ cached = cache[cache_key]
+ self.request = request
+ self.member = cached.member
+ self.config = request.state.config
+ self.level = cached.level
+ self.admin_type = cached.admin_type
+ return
self.request = request
+ self.member = member
self.config = request.state.config
- # member의 속성을 class 속성에 복사
- if member:
- for key, value in member.__dict__.items():
- setattr(self, key, value)
- self.level: int = self.mb_level
- self.admin_type: Union[str, None] = self.get_admin_type(group, board)
+ self.level = int(member.mb_level) if member else 1
+ self.admin_type = self.get_admin_type(group, board)
+ cache[cache_key] = self
+
+ def __getattr__(self, item):
+ if self.member:
+ return getattr(self.member, item, None)
+ return None
def get_admin_type(self, group: Group = None, board: Board = None) -> Union[str, None]:
"""게시판 관리자 여부 확인 후 관리자 타입 반환
diff --git a/main.py b/main.py
index 553c2a55b..43e97b7d9 100644
--- a/main.py
+++ b/main.py
@@ -21,7 +21,7 @@
)
from core.routers import router as template_router
from core.settings import ENV_PATH, settings
-from core.template import register_theme_statics
+from core.template import TemplateService
from lib.common import (
get_client_ip, is_intercept_ip, is_possible_ip, session_member_key
)
@@ -66,7 +66,7 @@ async def lifespan(app: FastAPI):
os.mkdir("data")
# 각 경로에 있는 파일들을 정적 파일로 등록합니다.
-register_theme_statics(app)
+TemplateService.register_statics(app)
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/data", StaticFiles(directory="data"), name="data")
diff --git a/service/point_service.py b/service/point_service.py
index f7d05db9e..e4df74503 100644
--- a/service/point_service.py
+++ b/service/point_service.py
@@ -3,6 +3,7 @@
from datetime import datetime, timedelta
from typing import List
from typing_extensions import Annotated
+from lib.template_filters import number_format
from fastapi import Depends, Request
from sqlalchemy import delete, func, select, update
@@ -143,6 +144,21 @@ def get_total_point(self, mb_id: str) -> int:
return int(point_sum) if point_sum else 0
+ def validate_enough_point(self, mb_id: str, required_point: int, action: str) -> None:
+ """필요한 포인트가 있는지 검증한다."""
+ if not self.use_point or required_point >= 0:
+ return
+
+ current_point = self.get_total_point(mb_id) if mb_id else 0
+ if (current_point + required_point) >= 0:
+ return
+
+ point_str = number_format(abs(required_point))
+ message = f"{action}에 필요한 포인트({point_str})가 부족합니다."
+ if not mb_id:
+ message += " 로그인 후 다시 시도해주세요."
+ self.raise_exception(status_code=403, detail=message)
+
def insert_use_point(self, mb_id: str, point: int, po_id: int = None) -> None:
"""
사용한 포인트 내역 입력&업데이트
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_delete_old_records.py b/tests/test_delete_old_records.py
new file mode 100644
index 000000000..f00c413da
--- /dev/null
+++ b/tests/test_delete_old_records.py
@@ -0,0 +1,47 @@
+import logging
+import sys
+from pathlib import Path
+from datetime import datetime, timedelta
+
+sys.path.append(str(Path(__file__).resolve().parents[1]))
+
+from sqlalchemy import select
+
+from lib.common import delete_old_records
+from core.models import Base, Config, Member
+from core.database import DBConnect
+
+
+def setup_module(module):
+ # initialize in-memory database
+ engine = DBConnect().engine
+ Base.metadata.create_all(bind=engine)
+
+
+def get_session():
+ return DBConnect().sessionLocal()
+
+
+def test_leave_member_deletion(caplog):
+ db = get_session()
+ # insert config
+ config = Config(cf_id=1, cf_leave_day=7)
+ db.add(config)
+ old_date = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d')
+ new_date = (datetime.now() - timedelta(days=3)).strftime('%Y%m%d')
+ db.add_all([
+ Member(mb_id='olduser', mb_leave_date=old_date),
+ Member(mb_id='newuser', mb_leave_date=new_date),
+ ])
+ db.commit()
+ db.close()
+
+ with caplog.at_level(logging.INFO):
+ delete_old_records()
+
+ db = get_session()
+ remaining = {m.mb_id for m in db.scalars(select(Member)).all()}
+ assert 'olduser' not in remaining
+ assert 'newuser' in remaining
+ assert 'olduser' in caplog.text
+ db.close()
diff --git a/tests/test_visit_queries.py b/tests/test_visit_queries.py
new file mode 100644
index 000000000..65e12b527
--- /dev/null
+++ b/tests/test_visit_queries.py
@@ -0,0 +1,57 @@
+import pytest
+from sqlalchemy import select, func, extract
+from sqlalchemy.dialects import mysql, postgresql, sqlite
+from core.models import Visit
+
+
+def build_hour_query(dialect_name):
+ query = select()
+ if dialect_name == 'mysql':
+ query = query.add_columns(func.hour(Visit.vi_time).label('hour'))
+ elif dialect_name == 'postgresql':
+ query = query.add_columns(extract('hour', Visit.vi_time).label('hour'))
+ elif dialect_name == 'sqlite':
+ query = query.add_columns(func.strftime('%H', Visit.vi_time).label('hour'))
+ return query.add_columns(func.count().label('hour_count')).group_by('hour')
+
+
+def build_weekday_query(dialect_name):
+ query = select()
+ if dialect_name == 'mysql':
+ query = query.add_columns(func.dayofweek(Visit.vi_date).label('dow'))
+ elif dialect_name == 'postgresql':
+ query = query.add_columns(extract('dow', Visit.vi_date).label('dow'))
+ elif dialect_name == 'sqlite':
+ query = query.add_columns(func.strftime('%w', Visit.vi_date).label('dow'))
+ return query.add_columns(func.count().label('dow_count')).group_by('dow')
+
+
+def test_hour_query_mysql():
+ sql = str(build_hour_query('mysql').compile(dialect=mysql.dialect()))
+ assert 'hour(' in sql.lower()
+
+
+def test_hour_query_postgresql():
+ sql = str(build_hour_query('postgresql').compile(dialect=postgresql.dialect()))
+ assert 'extract(hour' in sql.lower()
+
+
+def test_hour_query_sqlite():
+ sql = str(build_hour_query('sqlite').compile(dialect=sqlite.dialect()))
+ assert 'strftime' in sql
+
+
+def test_weekday_query_mysql():
+ sql = str(build_weekday_query('mysql').compile(dialect=mysql.dialect()))
+ assert 'dayofweek' in sql.lower()
+
+
+def test_weekday_query_postgresql():
+ sql = str(build_weekday_query('postgresql').compile(dialect=postgresql.dialect()))
+ assert 'extract(dow' in sql.lower()
+
+
+def test_weekday_query_sqlite():
+ sql = str(build_weekday_query('sqlite').compile(dialect=sqlite.dialect()))
+ assert 'strftime' in sql
+