diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index e7de4118..b1899af0 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -572,9 +572,9 @@ enum mta_status_t mta_load_plugin(const char *path); * status code if an error occurs. You can get more details about the * error with `mta_last_error`. */ -enum mta_status_t mta_load_model(const char *plugin_name, - const char *load_from, +enum mta_status_t mta_load_model(const char *load_from, const char *options_json, + const char *plugin_name, struct mta_model_t *model); #ifdef __cplusplus diff --git a/metatomic-core/include/metatomic/plugin.hpp b/metatomic-core/include/metatomic/plugin.hpp index 1cae91bd..42300984 100644 --- a/metatomic-core/include/metatomic/plugin.hpp +++ b/metatomic-core/include/metatomic/plugin.hpp @@ -1,7 +1,51 @@ #pragma once +#include + #include +#include namespace metatomic { + /// Load the shared library at `path` and register the plugin contained + /// within. The library must export the symbols generated by the + /// `MTA_REGISTER_PLUGIN` macro. + /// + /// @param path path to the plugin shared library + inline void load_plugin(const std::string& path) { + auto status = mta_load_plugin(path.c_str()); + details::check_status(status); + } + + /// Load a model from `load_from` with the given options. + /// + /// If `plugin_name` is empty, metatomic will try to determine the correct + /// plugin to use by checking the `load_from` parameter. If we can not + /// determine the correct plugin, we then try to load the model with each + /// registered plugin until one succeeds. + /// + /// If `plugin_name` is given, then we only try to load the model with the + /// specified plugin, and return an error if the plugin can not load the + /// model. + /// + /// @param load_from where to load the model from (e.g. a file path, a + /// model name, etc.) + /// @param plugin_name optional name of the plugin to use for loading the + /// model, or empty to let metatomic search + /// @param options_json optional JSON object containing string keys and + /// string values for loading the model + /// @return the loaded model + inline mta_model_t load_model( + const std::string& load_from, + const std::string& options_json = "", + const std::string& plugin_name = "" + ) { + mta_model_t model; + const char* plugin_name_ptr = plugin_name.empty() ? nullptr : plugin_name.c_str(); + const char* options_json_ptr = options_json.empty() ? nullptr : options_json.c_str(); + + auto status = mta_load_model(load_from.c_str(), options_json_ptr, plugin_name_ptr, &model); + details::check_status(status); + return model; + } } // namespace metatomic diff --git a/metatomic-core/src/c_api/plugin.rs b/metatomic-core/src/c_api/plugin.rs index 76b9989d..42ee45de 100644 --- a/metatomic-core/src/c_api/plugin.rs +++ b/metatomic-core/src/c_api/plugin.rs @@ -112,9 +112,9 @@ pub unsafe extern "C" fn mta_load_plugin(path: *const c_char) -> mta_status_t { /// error with `mta_last_error`. #[no_mangle] pub unsafe extern "C" fn mta_load_model( - plugin_name: *const c_char, load_from: *const c_char, options_json: *const c_char, + plugin_name: *const c_char, model: *mut mta_model_t, ) -> mta_status_t { let unwind_wrapper = std::panic::AssertUnwindSafe(model); @@ -157,7 +157,7 @@ pub unsafe extern "C" fn mta_load_model( } } - let loaded = crate::plugin::load_model(plugin_name, CStr::from_ptr(load_from), options_json)?; + let loaded = crate::plugin::load_model(CStr::from_ptr(load_from), options_json, plugin_name)?; let _ = &unwind_wrapper; *unwind_wrapper.0 = loaded.into_raw(); diff --git a/metatomic-core/src/plugin.rs b/metatomic-core/src/plugin.rs index 4a9da5ab..cd087965 100644 --- a/metatomic-core/src/plugin.rs +++ b/metatomic-core/src/plugin.rs @@ -141,16 +141,26 @@ pub fn load_plugin(path: &str) -> Result<(), Error> { /// Load a model from `load_from`, using the given options. pub fn load_model( - plugin_name: Option<&str>, load_from: &CStr, options_json: &CStr, + plugin_name: Option<&str>, ) -> Result { let plugins = PLUGINS.lock().expect("plugin registry mutex was poisoned"); if let Some(plugin_name) = plugin_name { for plugin in plugins.iter() { if plugin.name() == plugin_name { - return plugin.load_model(load_from, options_json); + return plugin.load_model(load_from, options_json).map_err(|e| { + if let Error::CallbackError(mta_status_t::MTA_MODEL_NOT_SUPPORTED_ERROR) = e { + Error::InvalidParameter(format!( + "failed to load model from '{}': plugin '{}' could not load the model", + load_from.to_string_lossy(), + plugin_name + )) + } else { + e + } + }); } } diff --git a/metatomic-core/tests/c-model.cpp b/metatomic-core/tests/c-model.cpp index ec1da92f..500e2a95 100644 --- a/metatomic-core/tests/c-model.cpp +++ b/metatomic-core/tests/c-model.cpp @@ -162,7 +162,7 @@ TEST_CASE("simple C model can be registered and loaded through the C API") { mta_register_plugin(PLUGIN); auto model = mta_model_t{}; - auto status = mta_load_model("test-c-plugin", "test-c-model", nullptr, &model); + auto status = mta_load_model("test-c-model", "{}", "test-c-plugin", &model); REQUIRE(status == MTA_SUCCESS); CHECK(model.data != nullptr); diff --git a/metatomic-core/tests/misc.cpp b/metatomic-core/tests/misc.cpp index 41bbe5b2..db9f21d4 100644 --- a/metatomic-core/tests/misc.cpp +++ b/metatomic-core/tests/misc.cpp @@ -83,13 +83,11 @@ TEST_CASE("unit conversion factor") { factor = metatomic::unit_conversion_factor("kJ/mol", "eV"); CHECK(factor == Approx(0.010364269656262174).epsilon(1e-15)); - // dimension mismatch -> error - try{ - factor = metatomic::unit_conversion_factor("m", "kg"); - } - catch(metatomic::Error& e){ - CHECK(std::string(e.what()) == "invalid parameter: dimension mismatch in unit conversion: 'm' has dimension [L] but 'kg' has dimension [M]"); - } + REQUIRE_THROWS_WITH( + metatomic::unit_conversion_factor("m", "kg"), + "invalid parameter: dimension mismatch in unit conversion: " + "'m' has dimension [L] but 'kg' has dimension [M]" + ); } } diff --git a/metatomic-core/tests/plugins.cpp b/metatomic-core/tests/plugins.cpp index 79b53947..2652d534 100644 --- a/metatomic-core/tests/plugins.cpp +++ b/metatomic-core/tests/plugins.cpp @@ -1,45 +1,71 @@ #include #include "metatomic.h" +#include "metatomic.hpp" TEST_CASE("Load plugins") { - auto status = mta_load_plugin(PLUGIN_DIR "/test-c-plugin.so"); - CHECK(status == MTA_SUCCESS); + SECTION("C API") { + auto status = mta_load_plugin(PLUGIN_DIR "/test-c-plugin.so"); + CHECK(status == MTA_SUCCESS); - // try to load the model with an explicit plugin name - struct mta_model_t model; - status = mta_load_model("test-c-plugin", "some_model", "{}", &model); - CHECK(status == MTA_MODEL_NOT_SUPPORTED_ERROR); + const char* error_message; + const char* error_origin; - // load the plugin without specifying the plugin name - status = mta_load_model(nullptr, "some_model", "{}", &model); - CHECK(status == MTA_INVALID_PARAMETER_ERROR); + struct mta_model_t model; + status = mta_load_model("some_model", "{}", "test-c-plugin", &model); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); - const char* error_message; - const char* error_origin; + status = mta_last_error(&error_message, &error_origin, nullptr); + REQUIRE(status == MTA_SUCCESS); - status = mta_last_error(&error_message, &error_origin, nullptr); - REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(error_origin) == "metatomic-core"); + CHECK(std::string(error_message) == ( + "invalid parameter: failed to load model from 'some_model': plugin 'test-c-plugin' could not load the model" + )); - CHECK(std::string(error_origin) == "metatomic-core"); - const char* expected_message = ( - "invalid parameter: failed to load model from 'some_model': tried the " - "following plugins, but none could load the model: test-c-plugin" - ); - CHECK(std::string(error_message) == expected_message); + status = mta_load_model("some_model", "{}", nullptr, &model); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + status = mta_last_error(&error_message, &error_origin, nullptr); + REQUIRE(status == MTA_SUCCESS); - status = mta_load_plugin(PLUGIN_DIR "/bad-abi-plugin.so"); - CHECK(status == MTA_INVALID_PARAMETER_ERROR); + CHECK(std::string(error_origin) == "metatomic-core"); + CHECK(std::string(error_message) == ( + "invalid parameter: failed to load model from 'some_model': tried the " + "following plugins, but none could load the model: test-c-plugin" + )); - status = mta_last_error(&error_message, &error_origin, nullptr); - REQUIRE(status == MTA_SUCCESS); - CHECK(std::string(error_origin) == "metatomic-core"); - expected_message = ( - "invalid parameter: can not register plugin 'bad-abi-plugin': " - "plugin ABI version is 2, but metatomic expects 1" - ); - CHECK(std::string(error_message) == expected_message); + status = mta_load_plugin(PLUGIN_DIR "/bad-abi-plugin.so"); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + + status = mta_last_error(&error_message, &error_origin, nullptr); + REQUIRE(status == MTA_SUCCESS); + + CHECK(std::string(error_origin) == "metatomic-core"); + CHECK(std::string(error_message) == ( + "invalid parameter: can not register plugin 'bad-abi-plugin': " + "plugin ABI version is 2, but metatomic expects 1" + )); + } + + SECTION("C++ API") { + REQUIRE_THROWS_WITH( + metatomic::load_model("some_model", "{}", "test-c-plugin"), + "invalid parameter: failed to load model from 'some_model': plugin 'test-c-plugin' could not load the model" + ); + + REQUIRE_THROWS_WITH( + metatomic::load_model("some_model"), + "invalid parameter: failed to load model from 'some_model': tried the " + "following plugins, but none could load the model: test-c-plugin" + ); + + REQUIRE_THROWS_WITH( + metatomic::load_plugin(PLUGIN_DIR "/bad-abi-plugin.so"), + "invalid parameter: can not register plugin 'bad-abi-plugin': " + "plugin ABI version is 2, but metatomic expects 1" + ); + } }