Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Always reference these instructions first and fallback to search or bash command
- **Bootstrap and install the repository:**

- `cd /home/runner/work/dpdata/dpdata` (or wherever the repo is cloned)
- `uv pip install -e .` -- installs dpdata in development mode with core dependencies (numpy, scipy, h5py, monty, wcmatch)
- `uv pip install -e .` -- installs dpdata in development mode with core dependencies (numpy, scipy, h5py, wcmatch)
Comment thread
njzjz marked this conversation as resolved.
- Test installation: `dpdata --version` -- should show version like "dpdata v0.1.dev2+..."

- **Run tests:**
Expand Down Expand Up @@ -93,7 +93,7 @@ The following are outputs from frequently run commands. Reference them instead o

### Key dependencies

- Core: numpy>=1.14.3, scipy, h5py, monty, wcmatch
- Core: numpy>=1.14.3, scipy, h5py, wcmatch
Comment thread
njzjz marked this conversation as resolved.
- Optional: ase (ASE integration), parmed (AMBER), pymatgen (Materials Project), rdkit (molecular analysis)
- Testing: unittest (built-in), coverage
- Linting: ruff
Expand Down Expand Up @@ -132,7 +132,7 @@ The following are outputs from frequently run commands. Reference them instead o

- **Installation timeouts:** Network timeouts during `uv pip install` are common. If this occurs, try:

- Individual package installation: `uv pip install numpy scipy h5py monty wcmatch`
- Individual package installation: `uv pip install numpy scipy h5py wcmatch`
Comment thread
njzjz marked this conversation as resolved.
- Use `--timeout` option: `uv pip install --timeout 300 -e .`
- Verify existing installation works: `dpdata --version` should work even if reinstall fails

Expand Down
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def setup(app):
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"python": ("https://docs.python.org/", None),
"ase": ("https://wiki.fysik.dtu.dk/ase/", None),
"monty": ("https://guide.materialsvirtuallab.org/monty/", None),
"h5py": ("https://docs.h5py.org/en/stable/", None),
}

Expand Down
1 change: 0 additions & 1 deletion docs/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ dependencies:
- xeus-python
- numpy
- scipy
- monty
- wcmatch
- pip:
- ..
226 changes: 226 additions & 0 deletions dpdata/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
from __future__ import annotations

import bz2
import datetime
import gzip
import importlib
import json
from enum import Enum
from pathlib import Path
from typing import Any, BinaryIO, TextIO, cast
from uuid import UUID

import numpy as np


def _detect_format(filename: str | Path, fmt: str | None = None) -> str:
if fmt is not None:
return fmt
basename = Path(filename).name.lower()
if ".mpk" in basename:
return "mpk"
if ".yaml" in basename or ".yml" in basename:
return "yaml"
return "json"


def _open_text(filename: str | Path, mode: str) -> TextIO:
path = str(filename)
lower_path = path.lower()
if lower_path.endswith((".gz", ".z")):
return cast("TextIO", gzip.open(path, mode, encoding="utf-8"))
if lower_path.endswith(".bz2"):
return cast("TextIO", bz2.open(path, mode, encoding="utf-8"))
return cast("TextIO", open(path, mode, encoding="utf-8"))


def _open_binary(filename: str | Path, mode: str) -> BinaryIO:
path = str(filename)
lower_path = path.lower()
if lower_path.endswith((".gz", ".z")):
return cast("BinaryIO", gzip.open(path, mode))
if lower_path.endswith(".bz2"):
return cast("BinaryIO", bz2.open(path, mode))
return cast("BinaryIO", open(path, mode))


def _yaml_dump(obj: Any, fp, *args: Any, **kwargs: Any) -> None:
    """Write *obj* to the open text stream *fp* as YAML.

    PyYAML is preferred; ruamel.yaml is used only when PyYAML is not
    installed. Extra positional and keyword arguments are forwarded to the
    backend's dump call.

    With PyYAML, ``sort_keys`` defaults to ``False`` so mapping order is kept.
    With ruamel.yaml, an ``indent`` keyword is translated into the emitter's
    mapping/sequence indentation instead of being forwarded.

    Raises:
        RuntimeError: If neither PyYAML nor ruamel.yaml can be imported.
    """
    # Only the import is guarded: an error raised while dumping must surface
    # as-is rather than trigger a second dump into a partly written stream.
    try:
        pyyaml = importlib.import_module("yaml")
    except ModuleNotFoundError:
        pyyaml = None
    if pyyaml is not None:
        kwargs.setdefault("sort_keys", False)
        pyyaml.safe_dump(obj, fp, *args, **kwargs)
        return

    try:
        ruamel_yaml = importlib.import_module("ruamel.yaml")
    except ModuleNotFoundError as e:
        raise RuntimeError(
            "Dumping YAML files requires PyYAML or ruamel.yaml."
        ) from e
    emitter = ruamel_yaml.YAML()
    if "indent" in kwargs:
        indent = kwargs.pop("indent")
        emitter.indent(mapping=indent, sequence=indent, offset=2)
    emitter.dump(obj, fp, *args, **kwargs)


def _yaml_load(fp, *args: Any, **kwargs: Any) -> Any:
    """Parse YAML from the open text stream *fp* with a safe loader.

    PyYAML's ``safe_load`` is preferred; ruamel.yaml's safe loader is used
    only when PyYAML is not installed. Extra arguments are forwarded to the
    backend's load call.

    Raises:
        RuntimeError: If neither PyYAML nor ruamel.yaml can be imported.
    """
    # Only the import is guarded: an error raised while parsing must surface
    # as-is rather than trigger a second parse of a partly consumed stream.
    try:
        pyyaml = importlib.import_module("yaml")
    except ModuleNotFoundError:
        pyyaml = None
    if pyyaml is not None:
        return pyyaml.safe_load(fp, *args, **kwargs)

    try:
        ruamel_yaml = importlib.import_module("ruamel.yaml")
    except ModuleNotFoundError as e:
        raise RuntimeError(
            "Loading YAML files requires PyYAML or ruamel.yaml."
        ) from e
    return ruamel_yaml.YAML(typ="safe").load(fp, *args, **kwargs)


def _encode_ndarray(obj: np.ndarray) -> dict[str, Any]:
    """Describe *obj* as a monty-style tagged dict.

    The ``data`` entry is the nested-list form of the array. Complex arrays
    are stored as a two-item list holding the real part followed by the
    imaginary part.
    """
    dtype_name = str(obj.dtype)
    is_complex = dtype_name.startswith("complex")
    payload = [obj.real.tolist(), obj.imag.tolist()] if is_complex else obj.tolist()
    encoded: dict[str, Any] = {"@module": "numpy", "@class": "array"}
    encoded["dtype"] = dtype_name
    encoded["data"] = payload
    return encoded


def to_serializable(obj: Any) -> Any:
    """Recursively reduce *obj* to plain data in monty's tagged-dict layout.

    Dicts and sequences are walked, with tuples coming back as lists. NumPy
    arrays become tagged dicts and NumPy scalars become native Python scalars.
    Datetime, UUID, Path and Enum values, plus anything exposing
    ``as_dict()``, are written as dicts carrying ``@module``/``@class``
    markers. Any other value is returned unchanged.
    """

    def tagged(module: str, name: str, field: str, value: Any) -> dict:
        # The markers always come first, matching the established layout.
        return {"@module": module, "@class": name, field: value}

    if isinstance(obj, dict):
        return {
            to_serializable(key): to_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return list(map(to_serializable, obj))
    if isinstance(obj, np.ndarray):
        return _encode_ndarray(obj)
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, datetime.datetime):
        return tagged("datetime", "datetime", "string", str(obj))
    if isinstance(obj, UUID):
        return tagged("uuid", "UUID", "string", str(obj))
    if isinstance(obj, Path):
        return tagged("pathlib", "Path", "string", str(obj))
    if isinstance(obj, Enum):
        owner = obj.__class__
        return tagged(
            owner.__module__, owner.__name__, "value", to_serializable(obj.value)
        )
    if hasattr(obj, "as_dict"):
        data = obj.as_dict()
        # Fill in the markers only when as_dict() did not provide them.
        for marker, attr in (("@module", "__module__"), ("@class", "__name__")):
            if marker not in data:
                data[marker] = getattr(obj.__class__, attr)
        return to_serializable(data)
    return obj


def _decode_ndarray(data: dict[str, Any]) -> np.ndarray:
    """Rebuild a NumPy array from its monty-style tagged dict.

    Args:
        data: Mapping with a ``dtype`` name and a ``data`` payload. For
            complex dtypes the payload is a two-item sequence holding the
            real part followed by the imaginary part.

    Returns:
        The decoded array with the recorded dtype.
    """
    dtype = data["dtype"]
    if dtype.startswith("complex"):
        real, imag = data["data"]
        decoded = np.array(real, dtype=dtype)
        # Write the imaginary component directly. Forming ``imag * 1j`` is a
        # complex multiplication, which turns the real part into NaN for a
        # non-finite imaginary part and loses the sign of zeros.
        decoded.imag[...] = imag
        return decoded
    return np.array(data["data"], dtype=dtype)


def process_decoded(obj: Any) -> Any:
    """Decode monty-style dictionaries used by existing dpdata JSON files.

    Lists and dicts are walked recursively. A dict carrying both ``@module``
    and ``@class`` is turned back into the object it describes:

    * NumPy arrays, datetimes, UUIDs and paths are decoded directly.
    * Any other class is imported by name. It is rebuilt through
      ``from_dict`` when the class has one, or by value when it is an
      ``Enum``.

    A tagged dict that cannot be resolved is returned as a plain dict with
    its values decoded. Any other value is returned unchanged.
    """
    if isinstance(obj, list):
        return [process_decoded(item) for item in obj]
    if not isinstance(obj, dict):
        return obj

    if "@module" in obj and "@class" in obj:
        module_name = obj["@module"]
        class_name = obj["@class"]
        if module_name == "numpy" and class_name == "array":
            return _decode_ndarray(obj)
        if module_name == "datetime" and class_name == "datetime":
            return _decode_datetime(obj["string"])
        if module_name == "uuid" and class_name == "UUID":
            return UUID(obj["string"])
        if module_name == "pathlib" and class_name == "Path":
            return Path(obj["string"])
        # NOTE(security): the module and class names come from the input, so
        # decoding imports whatever module the file names and calls its
        # ``from_dict``. Only load files from trusted sources.
        try:
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
        except (AttributeError, ImportError, TypeError, ValueError):
            # Empty, relative or non-string names raise ValueError/TypeError.
            # Treat them like any other unresolvable tag.
            cls = None
        if cls is not None:
            data = {
                k: v
                for k, v in obj.items()
                if not (isinstance(k, str) and k.startswith("@"))
            }
            if hasattr(cls, "from_dict"):
                return cls.from_dict(data)
            if isinstance(cls, type) and issubclass(cls, Enum):
                return cls(process_decoded(data["value"]))
    return {process_decoded(k): process_decoded(v) for k, v in obj.items()}


def _decode_datetime(text: str) -> datetime.datetime:
    """Parse the ``str(datetime)`` form into a naive ``datetime``.

    A UTC offset of either sign is accepted and then dropped, which keeps the
    long-standing behaviour of returning naive values. Strings the ISO parser
    rejects fall back to the legacy ``strptime`` formats.
    """
    try:
        parsed = datetime.datetime.fromisoformat(text)
    except ValueError:
        value = text.split("+")[0]
        try:
            return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
        except ValueError:
            return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
    return parsed.replace(tzinfo=None)


def dumpfn(
    obj: Any,
    filename: str | Path,
    *args: Any,
    fmt: str | None = None,
    **kwargs: Any,
) -> None:
    """Serialize *obj* to *filename* as JSON, YAML, or msgpack.

    The format is taken from ``fmt`` when given, otherwise from the file
    name. The object is first reduced to plain data, and remaining arguments
    are forwarded to the backend's dump call. Compressed file names are
    handled by the file openers.

    Raises:
        RuntimeError: If msgpack output is requested but msgpack is missing.
        TypeError: If the format is not one of ``json``, ``yaml``, ``mpk``.
    """
    fmt = _detect_format(filename, fmt)
    payload = to_serializable(obj)
    if fmt in ("json", "yaml"):
        writer = json.dump if fmt == "json" else _yaml_dump
        with _open_text(filename, "wt") as stream:
            writer(payload, stream, *args, **kwargs)
    elif fmt == "mpk":
        # Import before opening so a missing backend leaves no file behind.
        try:
            import msgpack
        except ModuleNotFoundError as e:
            raise RuntimeError("Dumping msgpack files requires msgpack.") from e
        with _open_binary(filename, "wb") as stream:
            msgpack.dump(payload, stream, *args, **kwargs)
    else:
        raise TypeError(f"Invalid format: {fmt}")


def loadfn(
    filename: str | Path,
    *args: Any,
    fmt: str | None = None,
    **kwargs: Any,
) -> Any:
    """Read JSON, YAML, or msgpack from *filename* and decode tagged objects.

    The format is taken from ``fmt`` when given, otherwise from the file
    name. Remaining arguments are forwarded to the backend's load call, and
    the parsed data is passed through ``process_decoded``.

    Raises:
        RuntimeError: If a msgpack file is requested but msgpack is missing.
        TypeError: If the format is not one of ``json``, ``yaml``, ``mpk``.
    """
    fmt = _detect_format(filename, fmt)
    if fmt in ("json", "yaml"):
        reader = json.load if fmt == "json" else _yaml_load
        with _open_text(filename, "rt") as stream:
            return process_decoded(reader(stream, *args, **kwargs))
    if fmt != "mpk":
        raise TypeError(f"Invalid format: {fmt}")
    try:
        import msgpack
    except ModuleNotFoundError as e:
        raise RuntimeError("Loading msgpack files requires msgpack.") from e
    with _open_binary(filename, "rb") as stream:
        return process_decoded(msgpack.load(stream, *args, **kwargs))
10 changes: 4 additions & 6 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def __add__(self, others):

def dump(self, filename: str, indent: int = 4):
"""Dump .json or .yaml file."""
from monty.serialization import dumpfn
from dpdata.serialization import dumpfn

dumpfn(self.as_dict(), filename, indent=indent)

Expand Down Expand Up @@ -378,19 +378,17 @@ def map_atom_types(
@staticmethod
def load(filename: str):
"""Rebuild System obj. from .json or .yaml file."""
from monty.serialization import loadfn
from dpdata.serialization import loadfn

return loadfn(filename)

@classmethod
def from_dict(cls, data: dict):
"""Construct a System instance from a data dict."""
from monty.serialization import MontyDecoder # type: ignore
from dpdata.serialization import process_decoded

decoded = {
k: MontyDecoder().process_decoded(v)
for k, v in data.items()
if not k.startswith("@")
k: process_decoded(v) for k, v in data.items() if not k.startswith("@")
}
return cls(**decoded)

Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ classifiers = [
]
dependencies = [
'numpy>=1.14.3',
'monty',
'scipy',
'h5py',
'wcmatch',
Expand Down Expand Up @@ -121,7 +120,6 @@ banned-module-level-imports = [
"deepmd",
"h5py",
"wcmatch",
"monty",
"scipy",
]

Expand Down
18 changes: 18 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import os
import tempfile
import unittest

from comp_sys import CompLabeledSys, IsPBC
Expand All @@ -26,5 +28,21 @@ def setUp(self):
self.v_places = 4


class TestJsonDumpLoad(unittest.TestCase, CompLabeledSys, IsPBC):
    """Round-trip a labeled system through ``dump``/``load`` on a JSON file."""

    def setUp(self):
        self.system_1 = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")
        self.tmpdir = tempfile.TemporaryDirectory()
        # Register the cleanup right away: unittest skips tearDown() when
        # setUp() raises, but registered cleanups still run. A failing
        # dump/load therefore cannot leak the temporary directory.
        self.addCleanup(self.tmpdir.cleanup)
        self.filename = os.path.join(self.tmpdir.name, "h2o.md.json")
        self.system_1.dump(self.filename)
        self.system_2 = dpdata.LabeledSystem.load(self.filename)
        # Decimal places used by the comparison mixins.
        self.places = 6
        self.e_places = 6
        self.f_places = 6
        self.v_places = 4


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion tests/test_to_pymatgen_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import unittest

from context import dpdata
from monty.serialization import loadfn # noqa: TID253

from dpdata.serialization import loadfn

try:
from pymatgen.entries.computed_entries import ComputedStructureEntry # noqa: F401
Expand Down
Loading