From e84be0130b46ac63e6e633931d04804f89fea665 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 20 Jun 2026 01:40:29 +0800 Subject: [PATCH] fix(serialization): remove monty dependency --- AGENTS.md | 6 +- docs/conf.py | 1 - docs/environment.yml | 1 - dpdata/serialization.py | 226 ++++++++++++++++++++++++++++++++ dpdata/system.py | 10 +- pyproject.toml | 2 - tests/test_json.py | 18 +++ tests/test_to_pymatgen_entry.py | 3 +- 8 files changed, 253 insertions(+), 14 deletions(-) create mode 100644 dpdata/serialization.py diff --git a/AGENTS.md b/AGENTS.md index 19d633b99..feb6116cd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -9,7 +9,7 @@ Always reference these instructions first and fallback to search or bash command - **Bootstrap and install the repository:** - `cd /home/runner/work/dpdata/dpdata` (or wherever the repo is cloned) - - `uv pip install -e .` -- installs dpdata in development mode with core dependencies (numpy, scipy, h5py, monty, wcmatch) + - `uv pip install -e .` -- installs dpdata in development mode with core dependencies (numpy, scipy, h5py, wcmatch) - Test installation: `dpdata --version` -- should show version like "dpdata v0.1.dev2+..." - **Run tests:** @@ -93,7 +93,7 @@ The following are outputs from frequently run commands. Reference them instead o ### Key dependencies -- Core: numpy>=1.14.3, scipy, h5py, monty, wcmatch +- Core: numpy>=1.14.3, scipy, h5py, wcmatch - Optional: ase (ASE integration), parmed (AMBER), pymatgen (Materials Project), rdkit (molecular analysis) - Testing: unittest (built-in), coverage - Linting: ruff @@ -132,7 +132,7 @@ The following are outputs from frequently run commands. Reference them instead o - **Installation timeouts:** Network timeouts during `uv pip install` are common. If this occurs, try: - - Individual package installation: `uv pip install numpy scipy h5py monty wcmatch` + - Individual package installation: `uv pip install numpy scipy h5py wcmatch` - Use `--timeout` option: `uv pip install --timeout 300 -e .` - Verify existing installation works: `dpdata --version` should work even if reinstall fails diff --git a/docs/conf.py b/docs/conf.py index 263cb5507..ca38237ed 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -202,7 +202,6 @@ def setup(app): "numpy": ("https://docs.scipy.org/doc/numpy/", None), "python": ("https://docs.python.org/", None), "ase": ("https://wiki.fysik.dtu.dk/ase/", None), - "monty": ("https://guide.materialsvirtuallab.org/monty/", None), "h5py": ("https://docs.h5py.org/en/stable/", None), } diff --git a/docs/environment.yml b/docs/environment.yml index 89d2e5cad..14d5a7ee4 100644 --- a/docs/environment.yml +++ b/docs/environment.yml @@ -6,7 +6,6 @@ dependencies: - xeus-python - numpy - scipy - - monty - wcmatch - pip: - .. diff --git a/dpdata/serialization.py b/dpdata/serialization.py new file mode 100644 index 000000000..44b985671 --- /dev/null +++ b/dpdata/serialization.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import bz2 +import datetime +import gzip +import importlib +import json +from enum import Enum +from pathlib import Path +from typing import Any, BinaryIO, TextIO, cast +from uuid import UUID + +import numpy as np + + +def _detect_format(filename: str | Path, fmt: str | None = None) -> str: + if fmt is not None: + return fmt + basename = Path(filename).name.lower() + if ".mpk" in basename: + return "mpk" + if ".yaml" in basename or ".yml" in basename: + return "yaml" + return "json" + + +def _open_text(filename: str | Path, mode: str) -> TextIO: + path = str(filename) + lower_path = path.lower() + if lower_path.endswith((".gz", ".z")): + return cast("TextIO", gzip.open(path, mode, encoding="utf-8")) + if lower_path.endswith(".bz2"): + return cast("TextIO", bz2.open(path, mode, encoding="utf-8")) + return cast("TextIO", open(path, mode, encoding="utf-8")) + + +def _open_binary(filename: str | Path, mode: str) -> BinaryIO: + path = str(filename) + lower_path = path.lower() + if lower_path.endswith((".gz", ".z")): + return cast("BinaryIO", gzip.open(path, mode)) + if lower_path.endswith(".bz2"): + return cast("BinaryIO", bz2.open(path, mode)) + return cast("BinaryIO", open(path, mode)) + + +def _yaml_dump(obj: Any, fp, *args: Any, **kwargs: Any) -> None: + try: + yaml = importlib.import_module("yaml") + + if "sort_keys" not in kwargs: + kwargs["sort_keys"] = False + getattr(yaml, "safe_dump")(obj, fp, *args, **kwargs) + except ModuleNotFoundError: + try: + ruamel_yaml = importlib.import_module("ruamel.yaml") + except ModuleNotFoundError as e: + raise RuntimeError( + "Dumping YAML files requires PyYAML or ruamel.yaml." + ) from e + yaml = getattr(ruamel_yaml, "YAML")() + if "indent" in kwargs: + indent = kwargs.pop("indent") + yaml.indent(mapping=indent, sequence=indent, offset=2) + yaml.dump(obj, fp, *args, **kwargs) + + +def _yaml_load(fp, *args: Any, **kwargs: Any) -> Any: + try: + yaml = importlib.import_module("yaml") + + return getattr(yaml, "safe_load")(fp, *args, **kwargs) + except ModuleNotFoundError: + try: + ruamel_yaml = importlib.import_module("ruamel.yaml") + except ModuleNotFoundError as e: + raise RuntimeError( + "Loading YAML files requires PyYAML or ruamel.yaml." + ) from e + yaml = getattr(ruamel_yaml, "YAML")(typ="safe") + return yaml.load(fp, *args, **kwargs) + + +def _encode_ndarray(obj: np.ndarray) -> dict[str, Any]: + if str(obj.dtype).startswith("complex"): + data = [obj.real.tolist(), obj.imag.tolist()] + else: + data = obj.tolist() + return { + "@module": "numpy", + "@class": "array", + "dtype": str(obj.dtype), + "data": data, + } + + +def to_serializable(obj: Any) -> Any: + """Convert common dpdata objects to monty-compatible plain data.""" + if isinstance(obj, dict): + return {to_serializable(k): to_serializable(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [to_serializable(v) for v in obj] + if isinstance(obj, np.ndarray): + return _encode_ndarray(obj) + if isinstance(obj, np.generic): + return obj.item() + if isinstance(obj, datetime.datetime): + return { + "@module": "datetime", + "@class": "datetime", + "string": str(obj), + } + if isinstance(obj, UUID): + return {"@module": "uuid", "@class": "UUID", "string": str(obj)} + if isinstance(obj, Path): + return {"@module": "pathlib", "@class": "Path", "string": str(obj)} + if isinstance(obj, Enum): + return { + "@module": obj.__class__.__module__, + "@class": obj.__class__.__name__, + "value": to_serializable(obj.value), + } + if hasattr(obj, "as_dict"): + data = obj.as_dict() + if "@module" not in data: + data["@module"] = obj.__class__.__module__ + if "@class" not in data: + data["@class"] = obj.__class__.__name__ + return to_serializable(data) + return obj + + +def _decode_ndarray(data: dict[str, Any]) -> np.ndarray: + dtype = data["dtype"] + if dtype.startswith("complex"): + real, imag = data["data"] + return np.array(real, dtype=dtype) + np.array(imag, dtype=dtype) * 1j + return np.array(data["data"], dtype=dtype) + + +def process_decoded(obj: Any) -> Any: + """Decode monty-style dictionaries used by existing dpdata JSON files.""" + if isinstance(obj, dict): + if "@module" in obj and "@class" in obj: + module_name = obj["@module"] + class_name = obj["@class"] + if module_name == "numpy" and class_name == "array": + return _decode_ndarray(obj) + if module_name == "datetime" and class_name == "datetime": + value = obj["string"].split("+")[0] + try: + return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f") + except ValueError: + return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S") + if module_name == "uuid" and class_name == "UUID": + return UUID(obj["string"]) + if module_name == "pathlib" and class_name == "Path": + return Path(obj["string"]) + try: + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + except (AttributeError, ImportError, ModuleNotFoundError): + cls = None + if cls is not None: + data = {k: v for k, v in obj.items() if not k.startswith("@")} + if hasattr(cls, "from_dict"): + return cls.from_dict(data) + if isinstance(cls, type) and issubclass(cls, Enum): + return cls(process_decoded(data["value"])) + return {process_decoded(k): process_decoded(v) for k, v in obj.items()} + if isinstance(obj, list): + return [process_decoded(v) for v in obj] + return obj + + +def dumpfn( + obj: Any, + filename: str | Path, + *args: Any, + fmt: str | None = None, + **kwargs: Any, +) -> None: + """Dump an object to JSON, YAML, or msgpack without requiring monty.""" + fmt = _detect_format(filename, fmt) + obj = to_serializable(obj) + if fmt == "json": + with _open_text(filename, "wt") as fp: + json.dump(obj, fp, *args, **kwargs) + return + if fmt == "yaml": + with _open_text(filename, "wt") as fp: + _yaml_dump(obj, fp, *args, **kwargs) + return + if fmt == "mpk": + try: + import msgpack + except ModuleNotFoundError as e: + raise RuntimeError("Dumping msgpack files requires msgpack.") from e + with _open_binary(filename, "wb") as fp: + msgpack.dump(obj, fp, *args, **kwargs) + return + raise TypeError(f"Invalid format: {fmt}") + + +def loadfn( + filename: str | Path, + *args: Any, + fmt: str | None = None, + **kwargs: Any, +) -> Any: + """Load JSON, YAML, or msgpack data and decode monty-style objects.""" + fmt = _detect_format(filename, fmt) + if fmt == "json": + with _open_text(filename, "rt") as fp: + return process_decoded(json.load(fp, *args, **kwargs)) + if fmt == "yaml": + with _open_text(filename, "rt") as fp: + return process_decoded(_yaml_load(fp, *args, **kwargs)) + if fmt == "mpk": + try: + import msgpack + except ModuleNotFoundError as e: + raise RuntimeError("Loading msgpack files requires msgpack.") from e + with _open_binary(filename, "rb") as fp: + return process_decoded(msgpack.load(fp, *args, **kwargs)) + raise TypeError(f"Invalid format: {fmt}") diff --git a/dpdata/system.py b/dpdata/system.py index 4150abc89..1ffc4f244 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -331,7 +331,7 @@ def __add__(self, others): def dump(self, filename: str, indent: int = 4): """Dump .json or .yaml file.""" - from monty.serialization import dumpfn + from dpdata.serialization import dumpfn dumpfn(self.as_dict(), filename, indent=indent) @@ -378,19 +378,17 @@ def map_atom_types( @staticmethod def load(filename: str): """Rebuild System obj. from .json or .yaml file.""" - from monty.serialization import loadfn + from dpdata.serialization import loadfn return loadfn(filename) @classmethod def from_dict(cls, data: dict): """Construct a System instance from a data dict.""" - from monty.serialization import MontyDecoder # type: ignore + from dpdata.serialization import process_decoded decoded = { - k: MontyDecoder().process_decoded(v) - for k, v in data.items() - if not k.startswith("@") + k: process_decoded(v) for k, v in data.items() if not k.startswith("@") } return cls(**decoded) diff --git a/pyproject.toml b/pyproject.toml index b94fce2e9..b88e305a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ classifiers = [ ] dependencies = [ 'numpy>=1.14.3', - 'monty', 'scipy', 'h5py', 'wcmatch', @@ -121,7 +120,6 @@ banned-module-level-imports = [ "deepmd", "h5py", "wcmatch", - "monty", "scipy", ] diff --git a/tests/test_json.py b/tests/test_json.py index 0b6f1b9dd..c98e21db2 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,5 +1,7 @@ from __future__ import annotations +import os +import tempfile import unittest from comp_sys import CompLabeledSys, IsPBC @@ -26,5 +28,21 @@ def setUp(self): self.v_places = 4 +class TestJsonDumpLoad(unittest.TestCase, CompLabeledSys, IsPBC): + def setUp(self): + self.system_1 = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar") + self.tmpdir = tempfile.TemporaryDirectory() + self.filename = os.path.join(self.tmpdir.name, "h2o.md.json") + self.system_1.dump(self.filename) + self.system_2 = dpdata.LabeledSystem.load(self.filename) + self.places = 6 + self.e_places = 6 + self.f_places = 6 + self.v_places = 4 + + def tearDown(self): + self.tmpdir.cleanup() + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_to_pymatgen_entry.py b/tests/test_to_pymatgen_entry.py index dfdeb4680..512b86b8b 100644 --- a/tests/test_to_pymatgen_entry.py +++ b/tests/test_to_pymatgen_entry.py @@ -4,7 +4,8 @@ import unittest from context import dpdata -from monty.serialization import loadfn # noqa: TID253 + +from dpdata.serialization import loadfn try: from pymatgen.entries.computed_entries import ComputedStructureEntry # noqa: F401