Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions metatomic-core/include/metatomic.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,9 @@ enum mta_status_t mta_load_plugin(const char *path);
* status code if an error occurs. You can get more details about the
* error with `mta_last_error`.
*/
enum mta_status_t mta_load_model(const char *plugin_name,
const char *load_from,
enum mta_status_t mta_load_model(const char *load_from,
const char *options_json,
const char *plugin_name,
struct mta_model_t *model);

#ifdef __cplusplus
Expand Down
44 changes: 44 additions & 0 deletions metatomic-core/include/metatomic/plugin.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,51 @@
#pragma once

#include <string>

#include <metatomic.h>
#include <metatomic/errors.hpp>

namespace metatomic {
/// Load the shared library at `path` and register the plugin contained
/// within. The library must export the symbols generated by the
/// `MTA_REGISTER_PLUGIN` macro.
///
/// @param path path to the plugin shared library
inline void load_plugin(const std::string& path) {
auto status = mta_load_plugin(path.c_str());
details::check_status(status);
}

/// Load a model from `load_from` with the given options.
///
/// If `plugin_name` is empty, metatomic will try to determine the correct
/// plugin to use by checking the `load_from` parameter. If we can not
/// determine the correct plugin, we then try to load the model with each
/// registered plugin until one succeeds.
///
/// If `plugin_name` is given, then we only try to load the model with the
/// specified plugin, and return an error if the plugin can not load the
/// model.
///
/// @param load_from where to load the model from (e.g. a file path, a
/// model name, etc.)
/// @param plugin_name optional name of the plugin to use for loading the
/// model, or empty to let metatomic search
/// @param options_json optional JSON object containing string keys and
/// string values for loading the model
/// @return the loaded model
inline mta_model_t load_model(
const std::string& load_from,
const std::string& options_json = "",
const std::string& plugin_name = ""
) {
Comment thread
RMeli marked this conversation as resolved.
mta_model_t model;
const char* plugin_name_ptr = plugin_name.empty() ? nullptr : plugin_name.c_str();
const char* options_json_ptr = options_json.empty() ? nullptr : options_json.c_str();

auto status = mta_load_model(load_from.c_str(), options_json_ptr, plugin_name_ptr, &model);
details::check_status(status);

return model;
}
} // namespace metatomic
4 changes: 2 additions & 2 deletions metatomic-core/src/c_api/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ pub unsafe extern "C" fn mta_load_plugin(path: *const c_char) -> mta_status_t {
/// error with `mta_last_error`.
#[no_mangle]
pub unsafe extern "C" fn mta_load_model(
plugin_name: *const c_char,
load_from: *const c_char,
options_json: *const c_char,
plugin_name: *const c_char,
model: *mut mta_model_t,
) -> mta_status_t {
let unwind_wrapper = std::panic::AssertUnwindSafe(model);
Expand Down Expand Up @@ -157,7 +157,7 @@ pub unsafe extern "C" fn mta_load_model(
}
}

let loaded = crate::plugin::load_model(plugin_name, CStr::from_ptr(load_from), options_json)?;
let loaded = crate::plugin::load_model(CStr::from_ptr(load_from), options_json, plugin_name)?;

let _ = &unwind_wrapper;
*unwind_wrapper.0 = loaded.into_raw();
Expand Down
14 changes: 12 additions & 2 deletions metatomic-core/src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,26 @@ pub fn load_plugin(path: &str) -> Result<(), Error> {

/// Load a model from `load_from`, using the given options.
pub fn load_model(
plugin_name: Option<&str>,
load_from: &CStr,
options_json: &CStr,
plugin_name: Option<&str>,
) -> Result<Model, Error> {
let plugins = PLUGINS.lock().expect("plugin registry mutex was poisoned");

if let Some(plugin_name) = plugin_name {
for plugin in plugins.iter() {
if plugin.name() == plugin_name {
return plugin.load_model(load_from, options_json);
return plugin.load_model(load_from, options_json).map_err(|e| {
if let Error::CallbackError(mta_status_t::MTA_MODEL_NOT_SUPPORTED_ERROR) = e {
Error::InvalidParameter(format!(
"failed to load model from '{}': plugin '{}' could not load the model",
load_from.to_string_lossy(),
plugin_name
))
} else {
e
}
});
}
}

Expand Down
2 changes: 1 addition & 1 deletion metatomic-core/tests/c-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ TEST_CASE("simple C model can be registered and loaded through the C API") {
mta_register_plugin(PLUGIN);

auto model = mta_model_t{};
auto status = mta_load_model("test-c-plugin", "test-c-model", nullptr, &model);
auto status = mta_load_model("test-c-model", "{}", "test-c-plugin", &model);
REQUIRE(status == MTA_SUCCESS);

CHECK(model.data != nullptr);
Expand Down
12 changes: 5 additions & 7 deletions metatomic-core/tests/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,11 @@ TEST_CASE("unit conversion factor") {
factor = metatomic::unit_conversion_factor("kJ/mol", "eV");
CHECK(factor == Approx(0.010364269656262174).epsilon(1e-15));

// dimension mismatch -> error
try{
factor = metatomic::unit_conversion_factor("m", "kg");
}
catch(metatomic::Error& e){
CHECK(std::string(e.what()) == "invalid parameter: dimension mismatch in unit conversion: 'm' has dimension [L] but 'kg' has dimension [M]");
}
REQUIRE_THROWS_WITH(
metatomic::unit_conversion_factor("m", "kg"),
"invalid parameter: dimension mismatch in unit conversion: "
"'m' has dimension [L] but 'kg' has dimension [M]"
);
}
}

Expand Down
84 changes: 55 additions & 29 deletions metatomic-core/tests/plugins.cpp
Original file line number Diff line number Diff line change
@@ -1,45 +1,71 @@
#include <catch.hpp>

#include "metatomic.h"
#include "metatomic.hpp"


TEST_CASE("Load plugins") {
auto status = mta_load_plugin(PLUGIN_DIR "/test-c-plugin.so");
CHECK(status == MTA_SUCCESS);
SECTION("C API") {
auto status = mta_load_plugin(PLUGIN_DIR "/test-c-plugin.so");
CHECK(status == MTA_SUCCESS);

// try to load the model with an explicit plugin name
struct mta_model_t model;
status = mta_load_model("test-c-plugin", "some_model", "{}", &model);
CHECK(status == MTA_MODEL_NOT_SUPPORTED_ERROR);
const char* error_message;
const char* error_origin;

// load the plugin without specifying the plugin name
status = mta_load_model(nullptr, "some_model", "{}", &model);
CHECK(status == MTA_INVALID_PARAMETER_ERROR);
struct mta_model_t model;
status = mta_load_model("some_model", "{}", "test-c-plugin", &model);
CHECK(status == MTA_INVALID_PARAMETER_ERROR);

const char* error_message;
const char* error_origin;
status = mta_last_error(&error_message, &error_origin, nullptr);
REQUIRE(status == MTA_SUCCESS);

status = mta_last_error(&error_message, &error_origin, nullptr);
REQUIRE(status == MTA_SUCCESS);
CHECK(std::string(error_origin) == "metatomic-core");
CHECK(std::string(error_message) == (
"invalid parameter: failed to load model from 'some_model': plugin 'test-c-plugin' could not load the model"
));

CHECK(std::string(error_origin) == "metatomic-core");
const char* expected_message = (
"invalid parameter: failed to load model from 'some_model': tried the "
"following plugins, but none could load the model: test-c-plugin"
);
CHECK(std::string(error_message) == expected_message);
status = mta_load_model("some_model", "{}", nullptr, &model);
CHECK(status == MTA_INVALID_PARAMETER_ERROR);

status = mta_last_error(&error_message, &error_origin, nullptr);
REQUIRE(status == MTA_SUCCESS);

status = mta_load_plugin(PLUGIN_DIR "/bad-abi-plugin.so");
CHECK(status == MTA_INVALID_PARAMETER_ERROR);
CHECK(std::string(error_origin) == "metatomic-core");
CHECK(std::string(error_message) == (
"invalid parameter: failed to load model from 'some_model': tried the "
"following plugins, but none could load the model: test-c-plugin"
));

status = mta_last_error(&error_message, &error_origin, nullptr);
REQUIRE(status == MTA_SUCCESS);

CHECK(std::string(error_origin) == "metatomic-core");
expected_message = (
"invalid parameter: can not register plugin 'bad-abi-plugin': "
"plugin ABI version is 2, but metatomic expects 1"
);
CHECK(std::string(error_message) == expected_message);
status = mta_load_plugin(PLUGIN_DIR "/bad-abi-plugin.so");
CHECK(status == MTA_INVALID_PARAMETER_ERROR);

status = mta_last_error(&error_message, &error_origin, nullptr);
REQUIRE(status == MTA_SUCCESS);

CHECK(std::string(error_origin) == "metatomic-core");
CHECK(std::string(error_message) == (
"invalid parameter: can not register plugin 'bad-abi-plugin': "
"plugin ABI version is 2, but metatomic expects 1"
));
}

SECTION("C++ API") {
REQUIRE_THROWS_WITH(
metatomic::load_model("some_model", "{}", "test-c-plugin"),
"invalid parameter: failed to load model from 'some_model': plugin 'test-c-plugin' could not load the model"
);

REQUIRE_THROWS_WITH(
metatomic::load_model("some_model"),
"invalid parameter: failed to load model from 'some_model': tried the "
"following plugins, but none could load the model: test-c-plugin"
);

REQUIRE_THROWS_WITH(
metatomic::load_plugin(PLUGIN_DIR "/bad-abi-plugin.so"),
"invalid parameter: can not register plugin 'bad-abi-plugin': "
"plugin ABI version is 2, but metatomic expects 1"
);
}
}
Loading