Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions dpdata/formats/deepmd/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ def to_system_data(
data["atom_types"] = g["type.raw"][:]
ntypes = np.max(data["atom_types"]) + 1
natoms = data["atom_types"].size
data["atom_numbs"] = []
for ii in range(ntypes):
data["atom_numbs"].append(np.count_nonzero(data["atom_types"] == ii))
data["atom_names"] = []
# if find type_map.raw, use it
if "type_map.raw" in g.keys():
my_type_map = list(np.char.decode(g["type_map.raw"][:]))
Expand All @@ -60,9 +56,11 @@ def to_system_data(
my_type_map = []
for ii in range(ntypes):
my_type_map.append("Type_%d" % ii) # noqa: UP031
assert len(my_type_map) >= len(data["atom_numbs"])
for ii in range(len(data["atom_numbs"])):
data["atom_names"].append(my_type_map[ii])
assert len(my_type_map) >= ntypes
data["atom_names"] = my_type_map
data["atom_numbs"] = []
for ii, _ in enumerate(data["atom_names"]):
data["atom_numbs"].append(np.count_nonzero(data["atom_types"] == ii))

data["orig"] = np.zeros([3])
if "nopbc" in g.keys():
Expand All @@ -81,7 +79,6 @@ def to_system_data(
"atom_names",
"atom_types",
"orig",
"real_atom_types",
"real_atom_names",
"nopbc",
):
Expand Down Expand Up @@ -184,7 +181,6 @@ def dump(
"atom_names",
"atom_types",
"orig",
"real_atom_types",
"real_atom_names",
"nopbc",
):
Expand Down
103 changes: 98 additions & 5 deletions dpdata/formats/deepmd/mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,34 @@ def _strip_virtual_atoms(atom_types_row, coords, extra_data, dtypes):
return atom_types, coords, stripped


def to_system_data(folder, type_map=None, labels=True):
data = comp_to_system_data(folder, type_map, labels)
# data is empty
def _to_system_data(data, type_map=None, labels=True):
"""Split one mixed-type data dict into regular System data dicts.

Mixed DeePMD data stores all atoms as one placeholder atom type and keeps
the original atom type of every frame in ``real_atom_types``. This helper
groups frames with the same ``real_atom_types`` row, restores the original
``atom_types`` and ``atom_numbs``, and strips virtual atoms introduced by
``atom_numb_pad``.

Parameters
----------
data : dict
Mixed-type data loaded by a backend reader. The dict must contain
``real_atom_types`` and the usual System/LabeledSystem frame data.
type_map : list[str], optional
Type map used to remap stored atom types while loading. Virtual atoms
marked by ``-1`` are preserved during remapping.
labels : bool, default=True
Whether the data should be interpreted with
:class:`dpdata.LabeledSystem` data types. Set to ``False`` for
unlabeled System data.

Returns
-------
list[dict]
Regular System/LabeledSystem data dicts, one for each unique real atom
type layout found in the mixed input.
"""
old_type_map = data["atom_names"].copy()
if type_map is not None:
assert isinstance(type_map, list)
Expand Down Expand Up @@ -220,7 +245,73 @@ def to_system_data(folder, type_map=None, labels=True):
return data_list


def dump(folder, data, set_size=2000, comp_prec=np.float32, remove_sets=True):
def to_system_data(folder, type_map=None, labels=True, load_func=None):
"""Load mixed-type DeePMD data and split it into regular systems.

By default this function reads the ``deepmd/npy/mixed`` directory layout
through :mod:`dpdata.formats.deepmd.comp`. Other storage backends can pass
``load_func`` to reuse the same mixed-type reconstruction logic. The loader
must return the same data dict shape as ``deepmd/npy`` and include
``real_atom_types``.

Parameters
----------
folder
Backend-specific location to load. For the default npy backend this is
a directory; HDF5 callers pass an HDF5 group.
type_map : list[str], optional
Type map used to remap atom types while loading.
labels : bool, default=True
Whether labeled data such as energies and forces should be loaded.
load_func : callable, optional
Backend reader with signature ``load_func(folder, type_map, labels)``.

Returns
-------
list[dict]
Regular System/LabeledSystem data dicts split out of the mixed input.
"""
if load_func is None:
load_func = comp_to_system_data
data = load_func(folder, type_map=type_map, labels=labels)
return _to_system_data(data, type_map=type_map, labels=labels)


def dump(
folder,
data,
set_size=2000,
comp_prec=np.float32,
remove_sets=True,
dump_func=None,
):
"""Dump one System data dict in mixed-type DeePMD layout.

If ``data`` has not already been converted to mixed type, it is copied and
converted first. The converted data stores the original element names in
``real_atom_names`` and the per-frame real atom type table in
``real_atom_types``; the backend writer receives the converted data with
``real_atom_names`` exposed as ``atom_names`` so it is written to
``type_map.raw``.

Parameters
----------
folder
Backend-specific destination. For the default npy backend this is a
directory; HDF5 callers pass an HDF5 group.
data : dict
System or LabeledSystem data dict to dump.
set_size : int, default=2000
Maximum number of frames per ``set.*`` chunk.
comp_prec : numpy.dtype, default=numpy.float32
Floating point precision used by the backend writer.
remove_sets : bool, default=True
Whether existing npy ``set.*`` directories should be removed before
dumping. Backends that do not use directories may ignore this argument.
dump_func : callable, optional
Backend writer with signature
``dump_func(folder, data, set_size, comp_prec, remove_sets)``.
"""
# if not converted to mixed
if "real_atom_types" not in data:
from dpdata import LabeledSystem, System
Expand All @@ -236,7 +327,9 @@ def dump(folder, data, set_size=2000, comp_prec=np.float32, remove_sets=True):

data = data.copy()
data["atom_names"] = data.pop("real_atom_names")
comp_dump(folder, data, set_size, comp_prec, remove_sets)
if dump_func is None:
dump_func = comp_dump
dump_func(folder, data, set_size, comp_prec, remove_sets)


def mix_system(*system, type_map, atom_numb_pad=None, **kwargs):
Expand Down
Loading
Loading