diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3dde5382..fa7e8245 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,6 +39,18 @@ jobs: path: ~/.cache/uv key: ${{ runner.os }}-uv-${{ hashFiles('pyproject.toml') }} + - name: Cache image classification model + id: cache-img-model + uses: actions/cache@v4 + with: + path: models/image_classification/model.onnx + key: img-class-model + + - name: Download ONNX model + if: steps.cache-img-model.outputs.cache-hit != 'true' + run: | + curl -L -o models/image_classification/model.onnx https://huggingface.co/dima806/yoga_pose_image_classification/resolve/main/onnx/model.onnx + - name: Project setup uses: ./.github/actions/project-setup diff --git a/Cargo.lock b/Cargo.lock index 5fe17a05..e19ad472 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,6 +31,24 @@ dependencies = [ "memchr", ] +[[package]] +name = "aligned" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4508988c62edf04abd8d92897fca0c2995d907ce1dfeaf369dac3716a40685" +dependencies = [ + "as-slice", +] + +[[package]] +name = "aligned-vec" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc890384c8602f339876ded803c97ad529f3842aba97f6392b3dba0dd171769b" +dependencies = [ + "equator", +] + [[package]] name = "android_system_properties" version = "0.1.5" @@ -105,6 +123,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" + [[package]] name = "arc-swap" version = "1.9.1" @@ -114,6 +138,32 @@ dependencies = [ "rustversion", ] +[[package]] +name = "arg_enum_proc_macro" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "arrayvec" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f02882884d3e1bc524fb12c79f107f6ad0e1cfd498c536ffb494301740995dfe" + +[[package]] +name = "as-slice" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "516b6b4f0e40d50dcda9365d53964ec74560ad4284da2e7fc97122cd83174516" +dependencies = [ + "stable_deref_trait", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -146,6 +196,49 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "av-scenechange" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f321d77c20e19b92c39e7471cf986812cbb46659d2af674adc4331ef3f18394" +dependencies = [ + "aligned", + "anyhow", + "arg_enum_proc_macro", + "arrayvec", + "log", + "num-rational", + "num-traits", + "pastey", + "rayon", + "thiserror 2.0.18", + "v_frame", + "y4m", +] + +[[package]] +name = "av1-grain" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cfddb07216410377231960af4fcab838eaa12e013417781b78bd95ee22077f8" +dependencies = [ + "anyhow", + "arrayvec", + "log", + "nom 8.0.0", + "num-rational", + "v_frame", +] + +[[package]] +name = "avif-serialize" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7178fe5f7d460b13895ebb9dcb28a3a6216d2df2574a0806cb51b555d297f38" +dependencies = [ + "arrayvec", +] + [[package]] name = "aws-lc-rs" version = "1.16.2" @@ -187,6 +280,7 @@ dependencies = [ "matchit", "memchr", "mime", + "multer", "percent-encoding", "pin-project-lite", "serde_core", @@ -260,12 +354,27 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "bit_field" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e4b40c7323adcfc0a41c4b88143ed58346ff65a288fc144329c5c45e05d70c6" + [[package]] name = "bitflags" version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +[[package]] +name = "bitstream-io" +version = "4.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eff00be299a18769011411c9def0d827e8f2d7bf0c3dbf53633147a8867fd1f" +dependencies = [ + "no_std_io2", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -285,6 +394,12 @@ dependencies = [ "serde", ] +[[package]] +name = "built" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c0e531d93d39c34eef561e929e8a7f86d77a5af08aac4f6d6e39976c51858e9" + [[package]] name = "bumpalo" version = "3.20.2" @@ -303,11 +418,20 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "byteorder-lite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" + [[package]] name = "bytes" version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +dependencies = [ + "serde", +] [[package]] name = "castaway" @@ -471,6 +595,12 @@ dependencies = [ "regex-lite", ] +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + [[package]] name = "colorchoice" version = "1.0.5" @@ -599,6 +729,12 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + [[package]] name = "crypto-common" version = "0.1.7" @@ -819,6 +955,7 @@ dependencies = [ "anyhow", "axum", "axum-server", + "bytes", "clap", "clap_derive", "codspeed-divan-compat", @@ -827,6 +964,8 @@ dependencies = [ "dotenv", "figment", "flate2", + "image", + "image-ndarray", "mlua", "ndarray", "ndarray-stats", @@ -918,6 +1057,26 @@ dependencies = [ "log", ] +[[package]] +name = "equator" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4711b213838dfee0117e3be6ac926007d7f433d7bbe33595975d4190cb07e6fc" +dependencies = [ + "equator-macro", +] + +[[package]] +name = "equator-macro" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -943,12 +1102,42 @@ dependencies = [ "cc", ] +[[package]] +name = "exr" +version = "1.74.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4300e043a56aa2cb633c01af81ca8f699a321879a7854d3896a0ba89056363be" +dependencies = [ + "bit_field", + "half", + "lebe", + "miniz_oxide", + "rayon-core", + "smallvec 1.15.1", + "zune-inflate", +] + [[package]] name = "fastrand" version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" +[[package]] +name = "fax" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf1079563223d5d59d83c85886a56e586cfd5c1a26292e971a0fa266531ac5a" + +[[package]] +name = "fdeflate" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" +dependencies = [ + "simd-adler32", +] + [[package]] name = "figment" version = "0.10.19" @@ -1186,6 +1375,16 @@ dependencies = [ "wasip3", ] +[[package]] +name = "gif" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee8cfcc411d9adbbaba82fb72661cc1bcca13e8bba98b364e62b2dba8f960159" +dependencies = [ + "color_quant", + "weezl", +] + [[package]] name = "glob" version = "0.3.3" @@ -1211,6 +1410,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -1492,6 +1702,59 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "image" +version = "0.25.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85ab80394333c02fe689eaf900ab500fbd0c2213da414687ebf995a65d5a6104" +dependencies = [ + "bytemuck", + "byteorder-lite", + "color_quant", + "exr", + "gif", + "image-webp", + "moxcms", + "num-traits", + "png", + "qoi", + "ravif", + "rayon", + "rgb", + "serde", + "tiff", + "zune-core", + "zune-jpeg", +] + +[[package]] +name = "image-ndarray" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "366ec4e7613badea5930852b9fc8781fdbb010a59845a3a5c1cf61d0ccc3f133" +dependencies = [ + "image", + "ndarray", + "num-traits", + "thiserror 2.0.18", +] + +[[package]] +name = "image-webp" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3" +dependencies = [ + "byteorder-lite", + "quick-error", +] + +[[package]] +name = "imgref" +version = "1.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89194689a993ab15268672e99e7b0e19da2da3268ac682e8f02d29d4d1434cd7" + [[package]] name = "indexmap" version = "2.14.0" @@ -1532,6 +1795,17 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "interpolate_name" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "ipnet" version = "2.12.0" @@ -1656,12 +1930,28 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" +[[package]] +name = "lebe" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a79a3332a6609480d7d0c9eab957bca6b455b91bb84e66d19f5ff66294b85b8" + [[package]] name = "libc" version = "0.2.185" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f" +[[package]] +name = "libfuzzer-sys" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9fd2f41a1cba099f79a0b6b6c35656cf7c03351a7bae8ff0f28f25270f929d2" +dependencies = [ + "arbitrary", + "cc", +] + [[package]] name = "libredox" version = "0.1.16" @@ -1701,6 +1991,15 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "loop9" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fae87c125b03c1d2c0150c90365d7d6bcc53fb73a9acaef207d2d065860f062" +dependencies = [ + "imgref", +] + [[package]] name = "lru-slab" version = "0.1.2" @@ -1767,6 +2066,16 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "maybe-rayon" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" +dependencies = [ + "cfg-if", + "rayon", +] + [[package]] name = "memchr" version = "2.8.0" @@ -1867,6 +2176,33 @@ dependencies = [ "syn", ] +[[package]] +name = "moxcms" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb85c154ba489f01b25c0d36ae69a87e4a1c73a72631fc6c0eb6dde34a73e44b" +dependencies = [ + "num-traits", + "pxfm", +] + +[[package]] +name = "multer" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http", + "httparse", + "memchr", + "mime", + "spin", + "version_check", +] + [[package]] name = "multimap" version = "0.10.1" @@ -1920,6 +2256,12 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + [[package]] name = "nix" version = "0.31.2" @@ -1932,6 +2274,15 @@ dependencies = [ "libc", ] +[[package]] +name = "no_std_io2" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418abd1b6d34fbf6cae440dc874771b0525a604428704c76e48b29a5e67b8003" +dependencies = [ + "memchr", +] + [[package]] name = "noisy_float" version = "0.2.1" @@ -1951,6 +2302,21 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + +[[package]] +name = "noop_proc_macro" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -1960,6 +2326,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-complex" version = "0.4.6" @@ -1969,6 +2345,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "num-integer" version = "0.1.46" @@ -1978,6 +2365,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -2216,6 +2614,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pastey" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" + [[package]] name = "pear" version = "0.2.9" @@ -2303,6 +2707,19 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" +[[package]] +name = "png" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61" +dependencies = [ + "bitflags", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + [[package]] name = "portable-atomic" version = "1.13.1" @@ -2377,6 +2794,25 @@ dependencies = [ "yansi", ] +[[package]] +name = "profiling" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d595e54a326bc53c1c197b32d295e14b169e3cfeaa8dc82b529f947fba6bcf5" +dependencies = [ + "profiling-procmacros", +] + +[[package]] +name = "profiling-procmacros" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4488a4a36b9a4ba6b9334a32a39971f77c1436ec82c38707bce707699cc3bbcb" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "prost" version = "0.14.3" @@ -2450,6 +2886,12 @@ dependencies = [ "pulldown-cmark", ] +[[package]] +name = "pxfm" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f" + [[package]] name = "pyo3" version = "0.27.2" @@ -2511,6 +2953,21 @@ dependencies = [ "syn", ] +[[package]] +name = "qoi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + [[package]] name = "quinn" version = "0.11.9" @@ -2647,6 +3104,56 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rav1e" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43b6dd56e85d9483277cde964fd1bdb0428de4fec5ebba7540995639a21cb32b" +dependencies = [ + "aligned-vec", + "arbitrary", + "arg_enum_proc_macro", + "arrayvec", + "av-scenechange", + "av1-grain", + "bitstream-io", + "built", + "cfg-if", + "interpolate_name", + "itertools 0.14.0", + "libc", + "libfuzzer-sys", + "log", + "maybe-rayon", + "new_debug_unreachable", + "noop_proc_macro", + "num-derive", + "num-traits", + "paste", + "profiling", + "rand 0.9.3", + "rand_chacha 0.9.0", + "simd_helpers", + "thiserror 2.0.18", + "v_frame", + "wasm-bindgen", +] + +[[package]] +name = "ravif" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e52310197d971b0f5be7fe6b57530dcd27beb35c1b013f29d66c1ad73fbbcc45" +dependencies = [ + "avif-serialize", + "imgref", + "loop9", + "quick-error", + "rav1e", + "rayon", + "rgb", +] + [[package]] name = "rawpointer" version = "0.2.1" @@ -2852,6 +3359,12 @@ dependencies = [ "web-sys", ] +[[package]] +name = "rgb" +version = "0.8.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b34b781b31e5d73e9fbc8689c70551fd1ade9a19e3e28cfec8580a79290cc4" + [[package]] name = "ring" version = "0.17.14" @@ -3237,6 +3750,15 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" +[[package]] +name = "simd_helpers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95890f873bec569a0362c235787f3aca6e1e887302ba4840839bcc6459c42da6" +dependencies = [ + "quote", +] + [[package]] name = "slab" version = "0.4.12" @@ -3276,6 +3798,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -3283,7 +3811,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" dependencies = [ "base64 0.13.1", - "nom", + "nom 7.1.3", "serde", "unicode-segmentation", ] @@ -3498,6 +4026,20 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "tiff" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b63feaf3343d35b6ca4d50483f94843803b0f51634937cc2ec519fc32232bc52" +dependencies = [ + "fax", + "flate2", + "half", + "quick-error", + "weezl", + "zune-jpeg", +] + [[package]] name = "tinystr" version = "0.8.3" @@ -4071,6 +4613,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "v_frame" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "666b7727c8875d6ab5db9533418d7c764233ac9c0cff1d469aec8fa127597be2" +dependencies = [ + "aligned-vec", + "num-traits", + "wasm-bindgen", +] + [[package]] name = "valuable" version = "0.1.1" @@ -4272,6 +4825,12 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "weezl" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" + [[package]] name = "which" version = "8.0.2" @@ -4726,6 +5285,12 @@ dependencies = [ "rustix", ] +[[package]] +name = "y4m" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5a4b21e1a62b67a2970e6831bc091d7b87e119e7f9791aef9702e3bef04448" + [[package]] name = "yansi" version = "1.0.1" @@ -4840,3 +5405,27 @@ name = "zmij" version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zune-core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9" + +[[package]] +name = "zune-inflate" +version = "0.2.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "zune-jpeg" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296" +dependencies = [ + "zune-core", +] diff --git a/Cargo.toml b/Cargo.toml index 7f7b8059..ecea85cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,15 @@ ndarray = "0.16.1" serde = { version = "1.0.228", features = ["serde_derive"] } tracing = "0.1.41" thiserror = "2.0.17" +image-ndarray = "0.1.5" + +[workspace.dependencies.bytes] +version = "1.11.1" +features = ["serde"] + +[workspace.dependencies.image] +version = "0.25.10" +features = ["serde"] [workspace.dependencies.parking_lot] version = "0.12.5" diff --git a/encoderfile/Cargo.toml b/encoderfile/Cargo.toml index 59f87ea7..8aee66dd 100644 --- a/encoderfile/Cargo.toml +++ b/encoderfile/Cargo.toml @@ -141,9 +141,18 @@ workspace = true [dependencies.serde_json] workspace = true +[dependencies.bytes] +workspace = true + [dependencies.ndarray] workspace = true +[dependencies.image] +workspace = true + +[dependencies.image-ndarray] +workspace = true + [dependencies.figment] version = "0.10.19" features = ["env", "serde_yaml", "yaml"] @@ -213,6 +222,7 @@ optional = true [dependencies.axum] version = "0.8.6" +features = ["multipart"] optional = true [dependencies.axum-server] diff --git a/encoderfile/benches/postprocessing.rs b/encoderfile/benches/postprocessing.rs index dbb66816..43660516 100644 --- a/encoderfile/benches/postprocessing.rs +++ b/encoderfile/benches/postprocessing.rs @@ -16,7 +16,7 @@ fn main() { #[divan::bench(args = [(8, 16, 384), (16, 128, 768), (64, 512, 1024)])] fn embedding_postprocess(b: Bencher, dim: (usize, usize, usize)) { - let tokenizer = &embedding_state().tokenizer; + let tokenizer = &embedding_state().model_input_state.tokenizer; let (batch, tokens, hidden) = dim; // Random embeddings @@ -35,7 +35,7 @@ fn embedding_postprocess(b: Bencher, dim: (usize, usize, usize)) { #[divan::bench(args = [8, 16, 64])] fn sequence_classification_postprocess(b: Bencher, batch: usize) { let state = sequence_classification_state(); - let config = &state.model_config; + let config = &state.task_state; let n_labels = config.id2label.clone().unwrap().len(); let mut rng = rand::rng(); @@ -51,10 +51,10 @@ fn sequence_classification_postprocess(b: Bencher, batch: usize) { #[divan::bench(args = [(8, 16), (16, 128), (64, 512)])] fn token_classification_postprocess(b: Bencher, dim: (usize, usize)) { let state = token_classification_state(); - let config = &state.model_config; + let config = &state.task_state; let n_labels = config.id2label.clone().unwrap().len(); - let tokenizer = &embedding_state().tokenizer; + let tokenizer = &embedding_state().model_input_state.tokenizer; let (batch, tokens) = dim; // Random embeddings diff --git a/encoderfile/build.rs b/encoderfile/build.rs index 7ecdf96c..016ed527 100644 --- a/encoderfile/build.rs +++ b/encoderfile/build.rs @@ -12,14 +12,18 @@ fn main() -> Result<(), Box> { "proto/sequence_classification.proto", "proto/token_classification.proto", "proto/sentence_embedding.proto", + "proto/image_classification.proto", "proto/manifest.proto", + "proto/image_types.proto", ], &[ "proto/embedding", "proto/sequence_classification", "proto/token_classification", "proto/sentence_embedding", + "proto/image_classification", "proto/manifest", + "proto/image_types", ], )?; diff --git a/encoderfile/proto/image_classification.proto b/encoderfile/proto/image_classification.proto new file mode 100644 index 00000000..065a3fa3 --- /dev/null +++ b/encoderfile/proto/image_classification.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +package encoderfile.image_classification; + +import "proto/metadata.proto"; +import "proto/image_types.proto"; + +service ImageClassificationInference { + rpc Predict(ImageClassificationRequest) returns (ImageClassificationResponse); + rpc GetModelMetadata(encoderfile.metadata.GetModelMetadataRequest) returns (encoderfile.metadata.GetModelMetadataResponse); +} + +message ImageClassificationRequest { + repeated encoderfile.image_types.ImageInput inputs = 1; + map metadata = 11; +} + +message ImageClassificationResponse { + repeated encoderfile.image_types.ImageLabels results = 1; + map metadata = 11; +} diff --git a/encoderfile/proto/image_segmentation.proto b/encoderfile/proto/image_segmentation.proto new file mode 100644 index 00000000..bea6d600 --- /dev/null +++ b/encoderfile/proto/image_segmentation.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package encoderfile.image_segmentation; + +import "proto/token.proto"; +import "proto/metadata.proto"; + +service ImageSegmentation { + rpc Predict(ImageSegmentationRequest) returns (ImageSegmentationResponse); + rpc GetModelMetadata(encoderfile.metadata.GetModelMetadataRequest) returns (encoderfile.metadata.GetModelMetadataResponse); +} + +message ImageSegmentationRequest { + repeated encoderfile.image_types.ImageInput images = 1; + map metadata = 11; +} + +message ImageSegment { + encoderfile.image_types.ImageLabelScore label = 1; + bytes mask = 2; +} + +message ImageSegments { + repeated ImageSegment segments = 1; +} + +message ImageSegmentationResponse { + repeated ImageSegments segments_batch = 1; + map metadata = 11; +} diff --git a/encoderfile/proto/image_types.proto b/encoderfile/proto/image_types.proto new file mode 100644 index 00000000..ffcd4c3c --- /dev/null +++ b/encoderfile/proto/image_types.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package encoderfile.image_types; + +message ImageInput { + bytes image = 1; +} + +message ImageLabelScore { + string label = 1; + optional float score = 2; +} + +message ImageLabels { + repeated ImageLabelScore labels = 1; +} diff --git a/encoderfile/proto/manifest.proto b/encoderfile/proto/manifest.proto index 16878fd4..f0ca23ca 100644 --- a/encoderfile/proto/manifest.proto +++ b/encoderfile/proto/manifest.proto @@ -55,6 +55,9 @@ message EncoderfileManifest { // Tokenizer data (vocab, merges, config). // Serialized runtime::tokenizer::TokenizerService optional Artifact tokenizer = 130; + + // Image preprocessor configuration. + optional Artifact image_preprocessor = 140; } message LuaLibs { diff --git a/encoderfile/proto/metadata.proto b/encoderfile/proto/metadata.proto index de67f847..25e14f33 100644 --- a/encoderfile/proto/metadata.proto +++ b/encoderfile/proto/metadata.proto @@ -6,6 +6,7 @@ message GetModelMetadataRequest {} message GetModelMetadataResponse { string model_id = 1; + // TODO decide if we want a model family/area at a higher level ModelType model_type = 2; map id2label = 3; } @@ -16,4 +17,8 @@ enum ModelType { SEQUENCE_CLASSIFICATION = 2; TOKEN_CLASSIFICATION = 3; SENTENCE_EMBEDDING = 4; + + IMAGE_CLASSIFICATION = 21; + // IMAGE_SEGMENTATION = 22; + // OBJECT_DETECTION = 23; } diff --git a/encoderfile/proto/object_detection.proto b/encoderfile/proto/object_detection.proto new file mode 100644 index 00000000..a55b4d85 --- /dev/null +++ b/encoderfile/proto/object_detection.proto @@ -0,0 +1,33 @@ +syntax = "proto3"; + +package encoderfile.object_detection; + +import "proto/token.proto"; +import "proto/metadata.proto"; + +service ObjectDetection { + rpc Predict(ObjectDetectionRequest) returns (ObjectDetectionResponse); + rpc GetModelMetadata(encoderfile.metadata.GetModelMetadataRequest) returns (encoderfile.metadata.GetModelMetadataResponse); +} + +message ObjectDetectionRequest { + repeated encoderfile.image_types.ImageInput inputs = 1; + map metadata = 11; +} + +message ImageBoundingBox { + encoderfile.image_types.ImageLabelScore label = 1; + xmin int32 = 2; + xmax int32 = 3; + ymin int32 = 4; + ymax int32 = 5; +} + +message ImageBoundingBoxes { + repeated ImageBoundingBox box = 1; +} + +message ObjectDetectionResponse { + repeated ImageBoundingBoxes boxes = 1; + map metadata = 11; +} diff --git a/encoderfile/proto/sentence_embedding.proto b/encoderfile/proto/sentence_embedding.proto index f7afc989..b14a72a7 100644 --- a/encoderfile/proto/sentence_embedding.proto +++ b/encoderfile/proto/sentence_embedding.proto @@ -2,7 +2,6 @@ syntax = "proto3"; package encoderfile.sentence_embedding; -import "proto/token.proto"; import "proto/metadata.proto"; service SentenceEmbeddingInference { diff --git a/encoderfile/src/builder/builder.rs b/encoderfile/src/builder/builder.rs index 6bdecdda..f5691e31 100644 --- a/encoderfile/src/builder/builder.rs +++ b/encoderfile/src/builder/builder.rs @@ -16,6 +16,7 @@ use crate::{ assets::{AssetKind, AssetPlan, AssetSource, PlannedAsset}, codec::EncoderfileCodec, }, + runtime::Input, }; use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; @@ -26,6 +27,10 @@ pub struct EncoderfileBuilder { pub config: BuildConfig, } +pub fn validate(_input: &Input) -> Result<()> { + Ok(()) +} + impl EncoderfileBuilder { pub fn new(config: BuildConfig) -> EncoderfileBuilder { Self { config } @@ -89,10 +94,22 @@ impl EncoderfileBuilder { } // validate tokenizer - let tokenizer_asset = - crate::builder::tokenizer::validate_tokenizer(&self.config.encoderfile)?; - planned_assets.push(tokenizer_asset); - terminal::success("Tokenizer validated"); + match self.config.encoderfile.model_type.input_type() { + Input::Text => { + let tokenizer_asset = + crate::builder::tokenizer::validate_tokenizer(&self.config.encoderfile)?; + planned_assets.push(tokenizer_asset); + terminal::success("Tokenizer validated"); + } + Input::Image => { + let image_preprocessor_asset = + crate::builder::image_preprocessor::validate_image_preprocessor( + &self.config.encoderfile, + )?; + planned_assets.push(image_preprocessor_asset); + terminal::success("Image preprocessor validated"); + } + } // initialize final binary terminal::info("Writing encoderfile..."); diff --git a/encoderfile/src/builder/config.rs b/encoderfile/src/builder/config.rs index d57c3b68..751fc968 100644 --- a/encoderfile/src/builder/config.rs +++ b/encoderfile/src/builder/config.rs @@ -1,4 +1,4 @@ -use crate::common::{Config as EmbeddedConfig, LuaLibs, ModelConfig, ModelType}; +use crate::common::{Config as EmbeddedConfig, LuaLibs, ModelConfig, model_type::ModelType}; use anyhow::{Context, Result, bail}; use schemars::JsonSchema; use std::string::String; @@ -24,7 +24,7 @@ pub struct BuildConfig { pub encoderfile: EncoderfileConfig, } -pub const DEFAULT_VERSION: &str = "0.1.0"; +pub const DEFAULT_VERSION: &str = "0.2.0"; pub const CONFIG_FILE_NOT_FOUND_MSG: &str = "Encoderfile config not found"; @@ -268,6 +268,7 @@ pub enum ModelPath { model_weights_path: PathBuf, tokenizer_path: PathBuf, tokenizer_config_path: Option, + preprocessor_config_path: Option, }, } @@ -328,6 +329,7 @@ macro_rules! asset_path { impl ModelPath { asset_path!(model_config_path, "config.json", "model config"); asset_path!(tokenizer_path, "tokenizer.json", "tokenizer"); + asset_path!(@Optional preprocessor_config_path, "preprocessor_config.json", "image preprocessing"); asset_path!(model_weights_path, "model.onnx", "model weights"); asset_path!(@Optional tokenizer_config_path, "tokenizer_config.json", "tokenizer config"); } @@ -414,6 +416,23 @@ mod tests { tokenizer_path: base.join("tokenizer.json"), model_weights_path: base.join("model.onnx"), tokenizer_config_path: Some(base.join("tokenizer_config.json")), + preprocessor_config_path: None, + }; + + assert!(mp.model_config_path().is_ok()); + + cleanup(&base); + } + + #[test] + fn test_modelpath_explicit_paths_image() { + let base = create_temp_model_dir(); + let mp = ModelPath::Paths { + model_config_path: base.join("config.json"), + tokenizer_path: PathBuf::new(), // not needed for image model + model_weights_path: base.join("model.onnx"), + tokenizer_config_path: None, + preprocessor_config_path: Some(base.join("preprocessor_config.json")), }; assert!(mp.model_config_path().is_ok()); diff --git a/encoderfile/src/builder/image_preprocessor.rs b/encoderfile/src/builder/image_preprocessor.rs new file mode 100644 index 00000000..80c09a11 --- /dev/null +++ b/encoderfile/src/builder/image_preprocessor.rs @@ -0,0 +1,87 @@ +// IMPORTANT NOTE: +// +// Image preprocessor configuration is NOT a stable, self-contained artifact (see tokenizer situation). +// +// It seems to vary widely between models and is often not even explicitly defined anywhere, so for now we just +// require users to provide the config for the model they are using, and we will deal with new +// models on a case-by-case basis as they come in. + +use crate::format::assets::{AssetKind, AssetSource, PlannedAsset}; +use anyhow::Result; + +use super::config::EncoderfileConfig; +use crate::runtime::ImagePreprocessing; + +pub fn validate_image_preprocessor<'a>( + encoderfile_config: &'a EncoderfileConfig, +) -> Result> { + let config = match encoderfile_config.path.preprocessor_config_path()? { + // if preprocessor_config.json is provided, use that + Some(preprocessor_config_path) => { + // open preprocessor_config + let contents = std::fs::read_to_string(preprocessor_config_path)?; + let preprocessor_config: ImagePreprocessing = serde_json::from_str(contents.as_str())?; + preprocessor_config + } + // some values may be present in config.json + None => { + // from_model_config(&image_preprocessing.config)?; + anyhow::bail!("FATAL: No preprocessor_config.json provided"); + } + }; + let model_config = encoderfile_config.model_config()?; + let serialized = serde_json::to_vec(&config)?; + + // num_channels must be same as len for mean and std + if let Some(num_channels) = model_config.num_channels { + if let Some(image_mean) = config.image_mean.as_ref() + && image_mean.len() != num_channels as usize + { + anyhow::bail!("num_channels must match length of image_mean"); + } + if let Some(image_std) = config.image_std.as_ref() + && image_std.len() != num_channels as usize + { + anyhow::bail!("num_channels must match length of image_std"); + } + } + + PlannedAsset::from_asset_source( + AssetSource::InMemory(std::borrow::Cow::Owned(serialized)), + AssetKind::ImagePreprocessor, + ) +} + +#[cfg(test)] +mod tests { + use crate::builder::config::ModelPath; + use crate::common::model_type::ModelType; + + use super::*; + + #[test] + fn test_validate_preprocessor_config() { + let config = EncoderfileConfig { + name: "my-model".into(), + version: "0.0.1".into(), + path: ModelPath::Directory("../models/image_classification".into()), + model_type: ModelType::Embedding, + output_path: None, + cache_dir: None, + transform: None, + lua_libs: None, + tokenizer: None, + validate_transform: false, + base_binary_path: None, + target: None, + }; + + let preprocessor_config = validate_image_preprocessor(&config) + .expect("Failed to validate image preprocessor config"); + + println!( + "Validated image preprocessor config: {:?}", + preprocessor_config + ); + } +} diff --git a/encoderfile/src/builder/mod.rs b/encoderfile/src/builder/mod.rs index 96afa60b..463a2e67 100644 --- a/encoderfile/src/builder/mod.rs +++ b/encoderfile/src/builder/mod.rs @@ -5,6 +5,7 @@ pub mod builder; pub mod cache; pub mod cli; pub mod config; +pub mod image_preprocessor; pub mod model; pub mod templates; /// Terminal logging utilities. diff --git a/encoderfile/src/builder/model.rs b/encoderfile/src/builder/model.rs index 85c14e84..1a27ea8f 100644 --- a/encoderfile/src/builder/model.rs +++ b/encoderfile/src/builder/model.rs @@ -13,7 +13,7 @@ pub trait ModelTypeExt { fn validate_model<'a>(&self, path: &'a Path) -> Result>; } -impl ModelTypeExt for crate::common::ModelType { +impl ModelTypeExt for crate::common::model_type::ModelType { fn validate_model<'a>(&self, path: &'a Path) -> Result> { let model = ORTSessionBuilder::default().from_file(path)?; @@ -22,6 +22,7 @@ impl ModelTypeExt for crate::common::ModelType { Self::SequenceClassification => validate_sequence_classification_model(model), Self::TokenClassification => validate_token_classification_model(model), Self::SentenceEmbedding => validate_sentence_embedding_model(model), + Self::ImageClassification => validate_image_classification_model(model), }?; PlannedAsset::from_asset_source(AssetSource::File(path), AssetKind::ModelWeights) @@ -68,6 +69,16 @@ fn validate_token_classification_model(model: Session) -> Result<()> { Ok(()) } +fn validate_image_classification_model(model: Session) -> Result<()> { + let shape = get_outp_dim(model.outputs.as_slice(), "logits")?; + + if shape.len() != 2 { + bail!("Model must return tensor of shape [batch_size, n_labels]") + } + + Ok(()) +} + fn get_outp_dim<'a>(outputs: &'a [Output], outp_name: &str) -> Result<&'a Shape> { outputs .iter() diff --git a/encoderfile/src/builder/tokenizer.rs b/encoderfile/src/builder/tokenizer.rs index cc472a7d..a8e32800 100644 --- a/encoderfile/src/builder/tokenizer.rs +++ b/encoderfile/src/builder/tokenizer.rs @@ -348,7 +348,7 @@ impl<'a> TokenizerConfigBuilder<'a> { #[cfg(test)] mod tests { use crate::builder::config::{ModelPath, TokenizerBuildConfig}; - use crate::common::ModelType; + use crate::common::model_type::ModelType; use super::*; @@ -452,6 +452,7 @@ mod tests { model_weights_path: path.model_weights_path().unwrap(), tokenizer_path: path.tokenizer_path().unwrap(), tokenizer_config_path: None, + preprocessor_config_path: None, }; let config = EncoderfileConfig { diff --git a/encoderfile/src/builder/transforms/validation/embedding.rs b/encoderfile/src/builder/transforms/validation/embedding.rs index 20785c07..83a48784 100644 --- a/encoderfile/src/builder/transforms/validation/embedding.rs +++ b/encoderfile/src/builder/transforms/validation/embedding.rs @@ -56,7 +56,7 @@ impl TransformValidatorExt for EmbeddingTransform { #[cfg(test)] mod tests { use crate::builder::config::{EncoderfileConfig, ModelPath}; - use crate::common::ModelType; + use crate::common::model_type::ModelType; use crate::transforms::DEFAULT_LIBS; use super::*; diff --git a/encoderfile/src/builder/transforms/validation/image_classification.rs b/encoderfile/src/builder/transforms/validation/image_classification.rs new file mode 100644 index 00000000..9f5fb81a --- /dev/null +++ b/encoderfile/src/builder/transforms/validation/image_classification.rs @@ -0,0 +1,131 @@ +use super::{ + TransformValidatorExt, + utils::{BATCH_SIZE, random_tensor, validation_err, validation_err_ctx}, +}; +use crate::{ + common::ModelConfig, + transforms::{ImageClassificationTransform, Postprocessor}, +}; +use anyhow::{Context, Result}; + +const TEST_NUM_LABELS: usize = 16; + +impl TransformValidatorExt for ImageClassificationTransform { + fn dry_run(&self, model_config: &ModelConfig) -> Result<()> { + let num_labels = match model_config.num_labels() { + Some(n) => n, + None => validation_err( + "Model config does not have `num_labels`, `id2label`, or `label2id` field. Please make sure you're using an ImageClassification model.", + )?, + }; + + let dummy_logits = random_tensor(&[BATCH_SIZE, TEST_NUM_LABELS], (-1.0, 1.0))?; + let shape = dummy_logits.shape().to_owned(); + + let res = self.postprocess(dummy_logits) + .with_context(|| { + validation_err_ctx( + format!( + "Failed to run postprocessing on dummy logits (randomly generated in range -1.0..1.0) of shape {:?}", + shape.as_slice(), + ) + ) + })?; + + // result must return tensor of rank 2 + if res.ndim() != 2 { + validation_err(format!( + "Transform must return tensor of rank 2. Got tensor of shape {:?}.", + res.shape() + ))? + } + + // result must have same shape as original + if res.shape() != shape { + validation_err(format!( + "Transform must return Tensor of shape [batch_size, num_labels]. Expected shape [{}, {}], got shape {:?}", + BATCH_SIZE, + num_labels, + res.shape() + ))? + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::builder::config::{EncoderfileConfig, ModelPath}; + use crate::common::model_type::ModelType; + use crate::transforms::DEFAULT_LIBS; + + use super::*; + + fn test_encoderfile_config() -> EncoderfileConfig { + EncoderfileConfig { + name: "my-model".to_string(), + version: "0.0.1".to_string(), + path: ModelPath::Directory(std::path::PathBuf::from("models/image_classification")), + model_type: ModelType::ImageClassification, + cache_dir: None, + output_path: None, + transform: None, + lua_libs: None, + validate_transform: true, + tokenizer: None, + base_binary_path: None, + target: None, + } + } + + fn test_model_config() -> ModelConfig { + let config_json = include_str!("../../../../../models/token_classification/config.json"); + + serde_json::from_str(config_json).unwrap() + } + + #[test] + fn test_identity_validation() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + ImageClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return arr end".to_string()), + ) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config) + .expect("Failed to validate"); + } + + #[test] + fn test_bad_return_type() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + let result = ImageClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return 1 end".to_string()), + ) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); + + assert!(result.is_err()); + } + + #[test] + fn test_bad_dimensionality() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + let result = ImageClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return arr:sum_axis(1) end".to_string()), + ) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); + + assert!(result.is_err()); + } +} diff --git a/encoderfile/src/builder/transforms/validation/mod.rs b/encoderfile/src/builder/transforms/validation/mod.rs index 8d1a3783..0b74cb05 100644 --- a/encoderfile/src/builder/transforms/validation/mod.rs +++ b/encoderfile/src/builder/transforms/validation/mod.rs @@ -1,5 +1,5 @@ use crate::{ - common::{ModelConfig, ModelType}, + common::{ModelConfig, model_type::ModelType}, format::assets::{AssetKind, AssetSource, PlannedAsset}, generated::manifest::LuaLibs as ManifestLuaLibs, transforms::{TransformSpec, convert_libs}, @@ -10,6 +10,7 @@ use crate::builder::config::EncoderfileConfig; use prost::Message; mod embedding; +mod image_classification; mod sentence_embedding; mod sequence_classification; mod token_classification; @@ -89,6 +90,12 @@ pub fn validate_transform<'a>( encoderfile_config, model_config ), + ModelType::ImageClassification => validate_transform!( + ImageClassificationTransform, + transform_str, + encoderfile_config, + model_config + ), }?; let lua_libs: Option = encoderfile_config diff --git a/encoderfile/src/builder/transforms/validation/sentence_embedding.rs b/encoderfile/src/builder/transforms/validation/sentence_embedding.rs index e478927d..22beda8b 100644 --- a/encoderfile/src/builder/transforms/validation/sentence_embedding.rs +++ b/encoderfile/src/builder/transforms/validation/sentence_embedding.rs @@ -59,7 +59,7 @@ impl TransformValidatorExt for SentenceEmbeddingTransform { #[cfg(test)] mod tests { use crate::builder::config::{EncoderfileConfig, ModelPath}; - use crate::common::ModelType; + use crate::common::model_type::ModelType; use crate::transforms::DEFAULT_LIBS; use super::*; diff --git a/encoderfile/src/builder/transforms/validation/sequence_classification.rs b/encoderfile/src/builder/transforms/validation/sequence_classification.rs index 6c4879dc..50615834 100644 --- a/encoderfile/src/builder/transforms/validation/sequence_classification.rs +++ b/encoderfile/src/builder/transforms/validation/sequence_classification.rs @@ -55,7 +55,7 @@ impl TransformValidatorExt for SequenceClassificationTransform { #[cfg(test)] mod tests { use crate::builder::config::{EncoderfileConfig, ModelPath}; - use crate::common::ModelType; + use crate::common::model_type::ModelType; use crate::transforms::DEFAULT_LIBS; use super::*; diff --git a/encoderfile/src/builder/transforms/validation/token_classification.rs b/encoderfile/src/builder/transforms/validation/token_classification.rs index 42801a9b..30d2c75e 100644 --- a/encoderfile/src/builder/transforms/validation/token_classification.rs +++ b/encoderfile/src/builder/transforms/validation/token_classification.rs @@ -56,7 +56,7 @@ impl TransformValidatorExt for TokenClassificationTransform { #[cfg(test)] mod tests { use crate::builder::config::{EncoderfileConfig, ModelPath}; - use crate::common::ModelType; + use crate::common::model_type::ModelType; use crate::transforms::DEFAULT_LIBS; use super::*; diff --git a/encoderfile/src/common/image_classification.rs b/encoderfile/src/common/image_classification.rs new file mode 100644 index 00000000..0af0e264 --- /dev/null +++ b/encoderfile/src/common/image_classification.rs @@ -0,0 +1,92 @@ +use crate::common::FromReadInput; +use crate::common::image_types::{ImageInfo, ImageLabelScore}; +use anyhow::Result; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::{collections::HashMap, io::Read}; +use utoipa::ToSchema; + +#[derive(Debug, Serialize, Deserialize)] +pub struct ImageClassificationRequest { + pub images: Vec, + pub metadata: Option>, +} + +impl super::FromCliInput for ImageClassificationRequest { + fn from_cli_input(inputs: Vec) -> Self { + let images = inputs + .into_iter() + .map(|path| { + let image_data = std::fs::read(path).expect("Failed to read image file"); + let format = + image::guess_format(&image_data).expect("Failed to guess image format"); + ImageInfo { + image_bytes: Bytes::from(image_data), + image_format: format, + } + }) + .collect(); + + Self { + images, + metadata: Some(HashMap::default()), + } + } +} + +impl FromReadInput for ImageClassificationRequest { + fn from_read_input(input: Vec<&mut impl Read>) -> Result { + let images = input + .into_iter() + .map(|reader| { + let mut image_data = Vec::new(); + reader + .read_to_end(&mut image_data) + .map_err(|e| anyhow::anyhow!("Failed to read image data: {}", e))?; + let format = image::guess_format(&image_data) + .map_err(|e| anyhow::anyhow!("Failed to guess image format: {}", e))?; + Ok(ImageInfo { + image_bytes: Bytes::from(image_data), + image_format: format, + }) + }) + .collect::>>()?; + + Ok(Self { + images, + metadata: Some(HashMap::default()), + }) + } +} + +#[derive(Debug, Serialize, Deserialize, ToSchema, utoipa::ToResponse)] +pub struct ImageClassificationResponse { + pub results: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)] +pub struct ImageClassificationResult { + pub labels: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + use image::ImageFormat; + use std::fs::File; + + #[test] + fn test_image_classification_request_from_read_input() { + let mut file = + File::open("../test-pictures/yoga01.jpg").expect("Failed to open test image"); + let file_vec = vec![&mut file]; + let request = ImageClassificationRequest::from_read_input(file_vec) + .expect("Failed to create request from read input"); + + assert_eq!(request.images.len(), 1); + assert_eq!(request.images[0].image_format, ImageFormat::Jpeg); + assert!(!request.images[0].image_bytes.is_empty()); + } +} diff --git a/encoderfile/src/common/image_types.rs b/encoderfile/src/common/image_types.rs new file mode 100644 index 00000000..7423cbeb --- /dev/null +++ b/encoderfile/src/common/image_types.rs @@ -0,0 +1,21 @@ +use bytes::Bytes; +use image::ImageFormat; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +#[derive(Debug, Serialize, Deserialize)] +pub struct ImageInfo { + pub image_bytes: Bytes, + pub image_format: ImageFormat, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)] +pub struct ImageLabelScore { + pub label: String, + pub score: Option, +} + +#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)] +pub struct ImageLabels { + pub labels: Vec, +} diff --git a/encoderfile/src/common/mod.rs b/encoderfile/src/common/mod.rs index 56549e37..2990ae73 100644 --- a/encoderfile/src/common/mod.rs +++ b/encoderfile/src/common/mod.rs @@ -8,16 +8,31 @@ mod sequence_classification; mod token; mod token_classification; +// CV +mod image_classification; +mod image_types; + pub use config::*; pub use embedding::*; pub use model_config::*; pub use model_metadata::*; -pub use model_type::ModelType; pub use sentence_embedding::*; pub use sequence_classification::*; pub use token::*; pub use token_classification::*; +// CV +use anyhow::Result; +pub use image_classification::*; +pub use image_types::*; +use std::io::Read; + pub trait FromCliInput { fn from_cli_input(inputs: Vec) -> Self; } + +pub trait FromReadInput { + fn from_read_input(input: Vec<&mut impl Read>) -> Result + where + Self: Sized; +} diff --git a/encoderfile/src/common/model_config.rs b/encoderfile/src/common/model_config.rs index 89eb4e2d..16233c62 100644 --- a/encoderfile/src/common/model_config.rs +++ b/encoderfile/src/common/model_config.rs @@ -4,11 +4,17 @@ use std::collections::HashMap; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ModelConfig { pub model_type: String, + // FIXME to be moved to per-task structs pub num_labels: Option, pub id2label: Option>, pub label2id: Option>, + pub height: Option, + pub width: Option, + pub image_size: Option, + pub num_channels: Option, } +// TODO add image handling metadata impl ModelConfig { pub fn id2label(&self, id: u32) -> Option<&str> { self.id2label.as_ref()?.get(&id).map(|s| s.as_str()) @@ -19,8 +25,8 @@ impl ModelConfig { } pub fn num_labels(&self) -> Option { - if self.num_labels.is_some() { - return self.num_labels; + if let Some(num_labels) = self.num_labels { + return Some(num_labels); } if let Some(id2label) = &self.id2label { @@ -33,6 +39,17 @@ impl ModelConfig { None } + pub fn height(&self) -> Option { + self.height.or(self.image_size) + } + + pub fn width(&self) -> Option { + self.width.or(self.image_size) + } + + pub fn num_channels(&self) -> Option { + self.num_channels + } } #[cfg(test)] @@ -58,6 +75,10 @@ mod tests { num_labels: Some(3), id2label: Some(id2label.clone()), label2id: Some(label2id.clone()), + height: None, + width: None, + image_size: None, + num_channels: None, }; assert_eq!(config.num_labels(), Some(3)); @@ -67,6 +88,10 @@ mod tests { num_labels: None, id2label: Some(id2label.clone()), label2id: Some(label2id.clone()), + height: None, + width: None, + image_size: None, + num_channels: None, }; assert_eq!(config.num_labels(), Some(3)); @@ -76,6 +101,10 @@ mod tests { num_labels: None, id2label: None, label2id: Some(label2id.clone()), + height: None, + width: None, + image_size: None, + num_channels: None, }; assert_eq!(config.num_labels(), Some(3)); diff --git a/encoderfile/src/common/model_type.rs b/encoderfile/src/common/model_type.rs index 96cc7d3e..b4ca0d34 100644 --- a/encoderfile/src/common/model_type.rs +++ b/encoderfile/src/common/model_type.rs @@ -38,16 +38,18 @@ macro_rules! model_type { } } )* + } } +pub trait ModelTypeSpec: Send + Sync + Clone + std::fmt::Debug + 'static { + fn enum_val() -> ModelType; +} + model_type![ Embedding, SequenceClassification, TokenClassification, SentenceEmbedding, + ImageClassification ]; - -pub trait ModelTypeSpec: Send + Sync + Clone + std::fmt::Debug + 'static { - fn enum_val() -> ModelType; -} diff --git a/encoderfile/src/dev_utils/mod.rs b/encoderfile/src/dev_utils/mod.rs index 6a4f128a..6f145700 100644 --- a/encoderfile/src/dev_utils/mod.rs +++ b/encoderfile/src/dev_utils/mod.rs @@ -1,20 +1,31 @@ use crate::{ common::{ - Config, ModelConfig, TokenizerConfig, + Config, TokenizerConfig, model_type::{self, ModelTypeSpec}, }, - runtime::{AppState, EncoderfileState, ORTSessionBuilder}, + runtime::{ + AppState, ClassifierState, EncoderfileState, FeatureExtractorState, ImageConfig, + ImageInputState, ImagePreprocessing, ImageSize, InputType, ORTSessionBuilder, TaskType, + TextInputState, + }, }; use ort::session::Session; use parking_lot::Mutex; use std::str::FromStr; -use std::{fs::File, io::BufReader}; +use std::{fmt::Debug, fs::File, io::BufReader}; const EMBEDDING_DIR: &str = "../models/embedding"; const SEQUENCE_CLASSIFICATION_DIR: &str = "../models/sequence_classification"; const TOKEN_CLASSIFICATION_DIR: &str = "../models/token_classification"; +const IMAGE_CLASSIFICATION_DIR: &str = "../models/image_classification"; -pub fn get_state(dir: &str) -> AppState { +pub fn get_state<'a, T: ModelTypeSpec + InputType + TaskType>(dir: &'a str) -> AppState +where + ::State: TryFrom<&'a str>, + <::State as TryFrom<&'a str>>::Error: Debug, + ::State: TryFrom<&'a str>, + <::State as TryFrom<&'a str>>::Error: Debug, +{ let config = Config { name: "my-model".to_string(), version: "0.0.1".to_string(), @@ -23,11 +34,96 @@ pub fn get_state(dir: &str) -> AppState { lua_libs: None, }; - let model_config = get_model_config(dir); - let tokenizer = get_tokenizer(dir); let session = get_model(dir); - EncoderfileState::new(config, session, tokenizer, model_config).into() + let model_input_state = + ::State::try_from(dir).expect("could not load model input state from file"); + let model_task_state = + ::State::try_from(dir).expect("could not load model task state from file"); + + EncoderfileState::new(config, session, model_input_state, model_task_state).into() +} + +pub trait TaskTypeFromFile: TaskType { + fn get_task_state(dir: &str) -> Result; +} + +pub fn get_config_reader(dir: &str) -> BufReader { + let file = File::open(format!("{}/{}", dir, "config.json")).expect("Config not found"); + BufReader::new(file) +} + +pub fn get_preproc_reader(dir: &str) -> BufReader { + let file = File::open(format!("{}/{}", dir, "preprocessor_config.json")) + .expect("Preprocessing config not found"); + BufReader::new(file) +} + +// Input types +fn get_text_input_state(dir: &str) -> Result { + let reader = get_config_reader(dir); + let tokenizer = get_tokenizer(dir); + let model_config = serde_json::from_reader(reader)?; + + Ok(TextInputState { + tokenizer, + model_config, + }) +} + +fn get_image_input_state(dir: &str) -> Result { + let config_reader = get_config_reader(dir); + let preproc_reader = get_preproc_reader(dir); + let config_state: ImageConfig = serde_json::from_reader(config_reader)?; + let preproc_state: ImagePreprocessing = serde_json::from_reader(preproc_reader)?; + Ok(ImageInputState { + config: ImageConfig { + num_channels: config_state.num_channels, + image_size: config_state.image_size, + }, + preprocessing: ImagePreprocessing { + do_normalize: preproc_state.do_normalize, + do_rescale: preproc_state.do_rescale, + do_resize: preproc_state.do_resize, + image_processor_type: preproc_state.image_processor_type, + rescale_factor: preproc_state.rescale_factor, + image_mean: preproc_state.image_mean, + image_std: preproc_state.image_std, + size: preproc_state.size.or(Some(ImageSize { + width: config_state.image_size, + height: config_state.image_size, + shortest_edge: None, + })), + }, + }) +} + +macro_rules! state_impl { + ($input_type:ty, $state_fun:ident) => { + impl TryFrom<&str> for $input_type { + type Error = anyhow::Error; + + fn try_from(dir: &str) -> Result { + $state_fun(dir) + } + } + }; +} + +state_impl!(TextInputState, get_text_input_state); +state_impl!(ImageInputState, get_image_input_state); +state_impl!(ClassifierState, get_class_task_state); +state_impl!(FeatureExtractorState, get_feature_task_state); + +// Task types +fn get_class_task_state(dir: &str) -> Result { + let reader = get_config_reader(dir); + let state: ClassifierState = serde_json::from_reader(reader)?; + Ok(state) +} + +fn get_feature_task_state(_dir: &str) -> Result { + Ok(FeatureExtractorState {}) } pub fn embedding_state() -> AppState { @@ -46,12 +142,8 @@ pub fn token_classification_state() -> AppState get_state(TOKEN_CLASSIFICATION_DIR) } -fn get_model_config(dir: &str) -> ModelConfig { - let file = File::open(format!("{}/{}", dir, "config.json")).expect("Config not found"); - let reader = BufReader::new(file); - - // Deserialize into struct - serde_json::from_reader(reader).expect("Invalid model config") +pub fn image_classification_state() -> AppState { + get_state(IMAGE_CLASSIFICATION_DIR) } fn get_tokenizer(dir: &str) -> crate::runtime::TokenizerService { diff --git a/encoderfile/src/format/assets/kind.rs b/encoderfile/src/format/assets/kind.rs index d7a12f1a..6c5be0ab 100644 --- a/encoderfile/src/format/assets/kind.rs +++ b/encoderfile/src/format/assets/kind.rs @@ -1,4 +1,7 @@ -use crate::common::model_type::ModelTypeSpec; +use crate::{ + common::model_type::ModelTypeSpec, + runtime::{Input, InputType, Task, TaskType}, +}; /// Identifies the semantic role of an embedded artifact. /// @@ -39,6 +42,9 @@ pub enum AssetKind { /// Tokenizer data required for text-based models. Tokenizer, + + /// Optional image preprocessing configuration. + ImagePreprocessor, } impl AssetKind { @@ -47,30 +53,49 @@ impl AssetKind { AssetKind::Transform, AssetKind::ModelConfig, AssetKind::Tokenizer, + AssetKind::ImagePreprocessor, ]; } -pub trait AssetPolicySpec: ModelTypeSpec { - fn required_assets() -> &'static [AssetKind]; - fn optional_assets() -> &'static [AssetKind]; +pub trait AssetPolicySpec: ModelTypeSpec + InputType + TaskType { + fn required_assets() -> &'static [AssetKind] { + match (Self::input_type(), Self::task_type()) { + (Input::Text, Task::Classification) => &[ + AssetKind::ModelWeights, + AssetKind::ModelConfig, + AssetKind::Tokenizer, + ], + (Input::Text, Task::FeatureExtraction) => &[ + AssetKind::ModelWeights, + AssetKind::ModelConfig, + AssetKind::Tokenizer, + ], + (Input::Image, Task::Classification) => &[ + AssetKind::ModelWeights, + AssetKind::ModelConfig, + AssetKind::ImagePreprocessor, + ], + (Input::Image, Task::FeatureExtraction) => &[ + AssetKind::ModelWeights, + AssetKind::ModelConfig, + AssetKind::ImagePreprocessor, + ], + } + } + fn optional_assets() -> &'static [AssetKind] { + match (Self::input_type(), Self::task_type()) { + (Input::Text, Task::Classification) => &[AssetKind::Transform], + (Input::Text, Task::FeatureExtraction) => &[AssetKind::Transform], + (Input::Image, Task::Classification) => &[AssetKind::Transform], + (Input::Image, Task::FeatureExtraction) => &[AssetKind::Transform], + } + } } macro_rules! asset_policy_spec { // Huggingface-style encoders (Encoder, $model_type:ident) => { - impl AssetPolicySpec for crate::common::model_type::$model_type { - fn required_assets() -> &'static [AssetKind] { - &[ - AssetKind::ModelWeights, - AssetKind::ModelConfig, - AssetKind::Tokenizer, - ] - } - - fn optional_assets() -> &'static [AssetKind] { - &[AssetKind::Transform] - } - } + impl AssetPolicySpec for crate::common::model_type::$model_type {} }; } @@ -78,3 +103,4 @@ asset_policy_spec!(Encoder, Embedding); asset_policy_spec!(Encoder, SequenceClassification); asset_policy_spec!(Encoder, TokenClassification); asset_policy_spec!(Encoder, SentenceEmbedding); +asset_policy_spec!(Encoder, ImageClassification); diff --git a/encoderfile/src/format/codec/encoder.rs b/encoderfile/src/format/codec/encoder.rs index 242054b6..c79d9c99 100644 --- a/encoderfile/src/format/codec/encoder.rs +++ b/encoderfile/src/format/codec/encoder.rs @@ -3,7 +3,8 @@ use anyhow::{Result, bail}; use crate::{ common::model_type::{ - Embedding, ModelType, SentenceEmbedding, SequenceClassification, TokenClassification, + Embedding, ImageClassification, ModelType, SentenceEmbedding, SequenceClassification, + TokenClassification, }, format::{ assets::{AssetPlan, AssetPolicySpec}, @@ -85,6 +86,7 @@ impl EncoderfileCodec { } ModelType::TokenClassification => Self::validate_assets::(plan)?, ModelType::SentenceEmbedding => Self::validate_assets::(plan)?, + ModelType::ImageClassification => Self::validate_assets::(plan)?, }; let model_type: crate::generated::metadata::ModelType = model_type.into(); @@ -99,6 +101,7 @@ impl EncoderfileCodec { weights: None, transform: None, tokenizer: None, + image_preprocessor: None, }; // Populate artifacts with length + hash diff --git a/encoderfile/src/format/codec/mod.rs b/encoderfile/src/format/codec/mod.rs index 6878500e..b153dbb8 100644 --- a/encoderfile/src/format/codec/mod.rs +++ b/encoderfile/src/format/codec/mod.rs @@ -41,6 +41,7 @@ impl EncoderfileManifest { AssetKind::ModelConfig => &mut self.model_config, AssetKind::Transform => &mut self.transform, AssetKind::Tokenizer => &mut self.tokenizer, + AssetKind::ImagePreprocessor => &mut self.image_preprocessor, } } @@ -50,6 +51,7 @@ impl EncoderfileManifest { AssetKind::ModelConfig => &self.model_config, AssetKind::Transform => &self.transform, AssetKind::Tokenizer => &self.tokenizer, + AssetKind::ImagePreprocessor => &self.image_preprocessor, } } @@ -74,6 +76,7 @@ mod tests { weights: None, transform: None, tokenizer: None, + image_preprocessor: None, } } diff --git a/encoderfile/src/format/container.rs b/encoderfile/src/format/container.rs index a7216bcc..38ef1af2 100644 --- a/encoderfile/src/format/container.rs +++ b/encoderfile/src/format/container.rs @@ -2,7 +2,7 @@ use anyhow::Result; use std::io::{Read, Seek, SeekFrom}; use crate::{ - common::ModelType, + common::model_type::ModelType, format::{assets::AssetKind, footer::EncoderfileFooter}, generated::manifest::{Artifact, EncoderfileManifest}, }; diff --git a/encoderfile/src/generated/image_classification.rs b/encoderfile/src/generated/image_classification.rs new file mode 100644 index 00000000..0ca34dd6 --- /dev/null +++ b/encoderfile/src/generated/image_classification.rs @@ -0,0 +1,47 @@ +use crate::{common, generated::image_types::ImageLabels}; + +tonic::include_proto!("encoderfile.image_classification"); + +impl From for common::ImageClassificationRequest { + fn from(val: ImageClassificationRequest) -> Self { + let images = val + .inputs + .into_iter() + .map(|input| { + common::ImageInfo { + image_bytes: bytes::Bytes::from(input.image), + image_format: image::ImageFormat::Png, // TODO: detect format properly + } + }) + .collect(); + Self { + images, + metadata: if val.metadata.is_empty() { + None + } else { + Some(val.metadata) + }, + } + } +} + +impl From for ImageClassificationResponse { + fn from(val: common::ImageClassificationResponse) -> Self { + Self { + results: val + .results + .into_iter() + .map(|result| result.into()) + .collect(), + metadata: val.metadata.unwrap_or_default(), + } + } +} + +impl From for ImageLabels { + fn from(val: common::ImageClassificationResult) -> Self { + ImageLabels { + labels: val.labels.into_iter().map(|label| label.into()).collect(), + } + } +} diff --git a/encoderfile/src/generated/image_types.rs b/encoderfile/src/generated/image_types.rs new file mode 100644 index 00000000..5fbdd38e --- /dev/null +++ b/encoderfile/src/generated/image_types.rs @@ -0,0 +1,28 @@ +use crate::common; + +tonic::include_proto!("encoderfile.image_types"); + +impl From for ImageInput { + fn from(val: common::ImageInfo) -> Self { + ImageInput { + image: val.image_bytes.to_vec(), + } + } +} + +impl From for ImageLabelScore { + fn from(val: common::ImageLabelScore) -> Self { + ImageLabelScore { + label: val.label, + score: val.score, + } + } +} + +impl From for ImageLabels { + fn from(val: common::ImageLabels) -> Self { + ImageLabels { + labels: val.labels.into_iter().map(|label| label.into()).collect(), + } + } +} diff --git a/encoderfile/src/generated/metadata.rs b/encoderfile/src/generated/metadata.rs index 33a660a4..d240d1f2 100644 --- a/encoderfile/src/generated/metadata.rs +++ b/encoderfile/src/generated/metadata.rs @@ -12,24 +12,28 @@ impl From for GetModelMetadataResponse { } } -impl From for ModelType { - fn from(val: common::ModelType) -> Self { +impl From for ModelType { + fn from(val: common::model_type::ModelType) -> Self { match val { - common::ModelType::Embedding => Self::Embedding, - common::ModelType::SequenceClassification => Self::SequenceClassification, - common::ModelType::TokenClassification => Self::TokenClassification, - common::ModelType::SentenceEmbedding => Self::SentenceEmbedding, + common::model_type::ModelType::Embedding => Self::Embedding, + common::model_type::ModelType::SequenceClassification => Self::SequenceClassification, + common::model_type::ModelType::TokenClassification => Self::TokenClassification, + common::model_type::ModelType::SentenceEmbedding => Self::SentenceEmbedding, + common::model_type::ModelType::ImageClassification => Self::ImageClassification, } } } -impl From for common::ModelType { +impl From for common::model_type::ModelType { fn from(val: ModelType) -> Self { match val { - ModelType::Embedding => common::ModelType::Embedding, - ModelType::SequenceClassification => common::ModelType::SequenceClassification, - ModelType::TokenClassification => common::ModelType::TokenClassification, - ModelType::SentenceEmbedding => common::ModelType::SentenceEmbedding, + ModelType::Embedding => common::model_type::ModelType::Embedding, + ModelType::SequenceClassification => { + common::model_type::ModelType::SequenceClassification + } + ModelType::TokenClassification => common::model_type::ModelType::TokenClassification, + ModelType::SentenceEmbedding => common::model_type::ModelType::SentenceEmbedding, + ModelType::ImageClassification => common::model_type::ModelType::ImageClassification, ModelType::Unspecified => { unreachable!("Unspecified model type. This should not happen.") } diff --git a/encoderfile/src/generated/mod.rs b/encoderfile/src/generated/mod.rs index 79d1fbe4..4b5a8dbc 100644 --- a/encoderfile/src/generated/mod.rs +++ b/encoderfile/src/generated/mod.rs @@ -1,4 +1,6 @@ pub mod embedding; +pub mod image_classification; +pub mod image_types; pub mod manifest; pub mod metadata; pub mod sentence_embedding; diff --git a/encoderfile/src/inference/embedding.rs b/encoderfile/src/inference/embedding.rs index af125053..9aa062c8 100644 --- a/encoderfile/src/inference/embedding.rs +++ b/encoderfile/src/inference/embedding.rs @@ -13,7 +13,7 @@ pub fn embedding<'a>( transform: &EmbeddingTransform, encodings: Vec, ) -> Result, ApiError> { - let (a_ids, a_mask, a_type_ids) = crate::prepare_inputs!(encodings); + let (a_ids, a_mask, a_type_ids) = crate::prepare_text_inputs!(encodings); let mut outputs = crate::run_model!(session, a_ids, a_mask, a_type_ids)? .get("last_hidden_state") diff --git a/encoderfile/src/inference/image_classification.rs b/encoderfile/src/inference/image_classification.rs new file mode 100644 index 00000000..f29b00c1 --- /dev/null +++ b/encoderfile/src/inference/image_classification.rs @@ -0,0 +1,60 @@ +use ndarray::{Array2, Array4, Axis, Ix2}; + +use crate::error::ApiError; + +use crate::common::ImageLabelScore; + +/* +fn logit_to_prob(logit: f32) -> f32 { + 1.0 / (1.0 + (-logit).exp()) +} +*/ + +#[tracing::instrument(skip_all)] +pub fn image_classification<'a>( + mut session: crate::runtime::Model<'a>, + // CHECK if this is a vec of flattened rgb images with num_channels X height X width + images: Array4, + classes: Vec, +) -> Result>, ApiError> { + let grouped_images = ort::value::TensorRef::from_array_view(&images) + .unwrap() + .to_owned(); + let raw_outputs = crate::run_cv_model!(session, grouped_images)?; + let /*mut*/ outputs = raw_outputs + .get("logits") + .ok_or(ApiError::InternalError("Model does not return logits"))? + .try_extract_array::() + .map_err(|_| ApiError::InternalError("Model does not return tensor extractable to f32"))? + .into_dimensionality::() + .map_err(|_| { + ApiError::InternalError("Model does not return tensor of shape [n_batch, n_classes]") + })? + .into_owned(); + // outputs.mapv_inplace(logit_to_prob); + + Ok(postprocess(outputs, classes)) +} + +#[tracing::instrument(skip_all)] +pub fn postprocess(outputs: Array2, classes: Vec) -> Vec> { + outputs + .axis_iter(Axis(0)) + .map(|logs| { + logs.iter() + .enumerate() + .map(|(idx, score)| { + ImageLabelScore { + label: classes[idx].to_string(), // TODO: get label from config + score: Some(*score), + } + }) + .collect() + }) + .collect() +} + +#[cfg(test)] +mod tests { + // Add your test cases here +} diff --git a/encoderfile/src/inference/mod.rs b/encoderfile/src/inference/mod.rs index 09803536..e9b82a92 100644 --- a/encoderfile/src/inference/mod.rs +++ b/encoderfile/src/inference/mod.rs @@ -1,5 +1,8 @@ +// text pub mod embedding; pub mod sentence_embedding; pub mod sequence_classification; pub mod token_classification; +// cv +pub mod image_classification; pub mod utils; diff --git a/encoderfile/src/inference/sentence_embedding.rs b/encoderfile/src/inference/sentence_embedding.rs index ea0e3051..d2a876af 100644 --- a/encoderfile/src/inference/sentence_embedding.rs +++ b/encoderfile/src/inference/sentence_embedding.rs @@ -13,7 +13,7 @@ pub fn sentence_embedding<'a>( transform: &SentenceEmbeddingTransform, encodings: Vec, ) -> Result, ApiError> { - let (a_ids, a_mask, a_type_ids) = crate::prepare_inputs!(encodings); + let (a_ids, a_mask, a_type_ids) = crate::prepare_text_inputs!(encodings); let a_mask_arr = a_mask .try_extract_array::() diff --git a/encoderfile/src/inference/sequence_classification.rs b/encoderfile/src/inference/sequence_classification.rs index f003a07e..73941211 100644 --- a/encoderfile/src/inference/sequence_classification.rs +++ b/encoderfile/src/inference/sequence_classification.rs @@ -1,6 +1,7 @@ use crate::{ - common::{ModelConfig, SequenceClassificationResult}, + common::SequenceClassificationResult, error::ApiError, + runtime::ClassifierState, transforms::{Postprocessor, SequenceClassificationTransform}, }; use ndarray::{Array2, Axis, Ix2}; @@ -11,10 +12,10 @@ use tokenizers::Encoding; pub fn sequence_classification<'a>( mut session: crate::runtime::Model<'a>, transform: &SequenceClassificationTransform, - config: &ModelConfig, + config: &ClassifierState, encodings: Vec, ) -> Result, ApiError> { - let (a_ids, a_mask, a_type_ids) = crate::prepare_inputs!(encodings); + let (a_ids, a_mask, a_type_ids) = crate::prepare_text_inputs!(encodings); let mut outputs = crate::run_model!(session, a_ids, a_mask, a_type_ids)? .get("logits") @@ -35,7 +36,7 @@ pub fn sequence_classification<'a>( #[tracing::instrument(skip_all)] pub fn postprocess( outputs: Array2, - config: &ModelConfig, + config: &ClassifierState, ) -> Vec { outputs .axis_iter(Axis(0)) diff --git a/encoderfile/src/inference/token_classification.rs b/encoderfile/src/inference/token_classification.rs index 732073f0..8f0899c7 100644 --- a/encoderfile/src/inference/token_classification.rs +++ b/encoderfile/src/inference/token_classification.rs @@ -1,6 +1,7 @@ use crate::{ - common::{ModelConfig, TokenClassification, TokenClassificationResult, TokenInfo}, + common::{TokenClassification, TokenClassificationResult, TokenInfo}, error::ApiError, + runtime::ClassifierState, transforms::{Postprocessor, TokenClassificationTransform}, }; use ndarray::{Array3, Axis, Ix3}; @@ -11,10 +12,10 @@ use tokenizers::Encoding; pub fn token_classification<'a>( mut session: crate::runtime::Model<'a>, transform: &TokenClassificationTransform, - config: &ModelConfig, + config: &ClassifierState, encodings: Vec, ) -> Result, ApiError> { - let (a_ids, a_mask, a_type_ids) = crate::prepare_inputs!(encodings); + let (a_ids, a_mask, a_type_ids) = crate::prepare_text_inputs!(encodings); let mut outputs = crate::run_model!(session, a_ids, a_mask, a_type_ids)? .get("logits") @@ -36,7 +37,7 @@ pub fn token_classification<'a>( pub fn postprocess( outputs: Array3, encodings: Vec, - config: &ModelConfig, + config: &ClassifierState, ) -> Vec { let mut predictions = Vec::new(); diff --git a/encoderfile/src/inference/utils.rs b/encoderfile/src/inference/utils.rs index a59f1f62..6d4d1329 100644 --- a/encoderfile/src/inference/utils.rs +++ b/encoderfile/src/inference/utils.rs @@ -3,7 +3,7 @@ use ort::session::Session; use parking_lot::MutexGuard; #[macro_export] -macro_rules! prepare_inputs { +macro_rules! prepare_text_inputs { ($encodings:ident) => {{ let padded_token_length = $encodings[0].len(); @@ -75,3 +75,13 @@ macro_rules! run_model { }) }}; } + +#[macro_export] +macro_rules! run_cv_model { + ($session:expr, $image_bytes:expr) => {{ + $session.run(ort::inputs!($image_bytes)).map_err(|e| { + tracing::error!("Error running model: {:?}", e); + $crate::error::ApiError::InternalError("Error running model") + }) + }}; +} diff --git a/encoderfile/src/runtime/loader.rs b/encoderfile/src/runtime/loader.rs index 38a9d480..213350ca 100644 --- a/encoderfile/src/runtime/loader.rs +++ b/encoderfile/src/runtime/loader.rs @@ -5,10 +5,10 @@ use std::io::{Read, Seek}; use ort::session::{Session, builder::GraphOptimizationLevel}; use crate::{ - common::{Config, LuaLibs, ModelConfig, ModelType}, + common::{Config, LuaLibs, ModelConfig, model_type::ModelType}, format::{assets::AssetKind, codec::EncoderfileCodec, container::Encoderfile}, generated::manifest::{self, TransformType}, - runtime::{ORTExecutionProvider, ORTSessionBuilder, TokenizerService}, + runtime::{ImagePreprocessing, ORTExecutionProvider, ORTSessionBuilder, TokenizerService}, }; pub struct EncoderfileLoader<'a, R: Read + Seek> { @@ -129,6 +129,21 @@ impl<'a, R: Read + Seek> EncoderfileLoader<'a, R> { Err(e) => bail!("Error loading model config: {e:?}"), } } + + pub fn image_preprocessor_config(&mut self) -> Result { + match self + .encoderfile + .open_required(self.reader, AssetKind::ImagePreprocessor) + { + Ok(mut r) => { + let mut buf = vec![0u8; r.len() as usize]; + r.read_exact(&mut buf)?; + + Ok(serde_json::from_slice(buf.as_slice())?) + } + Err(e) => bail!("Error loading image preprocessor config: {e:?}"), + } + } } pub fn load_assets<'a, R: Read + Seek>(file: &'a mut R) -> Result> { diff --git a/encoderfile/src/runtime/mod.rs b/encoderfile/src/runtime/mod.rs index be3bf6e7..886b4356 100644 --- a/encoderfile/src/runtime/mod.rs +++ b/encoderfile/src/runtime/mod.rs @@ -8,7 +8,11 @@ mod tokenizer; pub use loader::{EncoderfileLoader, load_assets}; pub use session::{ORTExecutionProvider, ORTSessionBuilder}; -pub use state::{AppState, EncoderfileState}; +pub use state::{ + AppState, ClassifierState, EncoderfileState, FeatureExtractorState, ImageConfig, + ImageInputState, ImagePreprocessing, ImageSize, Input, InputType, Task, TaskType, + TextInputState, +}; pub use tokenizer::TokenizerService; pub type Model<'a> = MutexGuard<'a, Session>; diff --git a/encoderfile/src/runtime/state.rs b/encoderfile/src/runtime/state.rs index 5690d99e..d3d65787 100644 --- a/encoderfile/src/runtime/state.rs +++ b/encoderfile/src/runtime/state.rs @@ -1,32 +1,384 @@ -use std::{marker::PhantomData, sync::Arc}; +use mlua::prelude::*; +use serde::{Deserialize, Serialize}; +use std::{ + fmt::Debug, + io::{Read, Seek}, + marker::PhantomData, + sync::Arc, +}; use ort::session::Session; use parking_lot::Mutex; use crate::{ - common::{Config, ModelConfig, ModelType, model_type::ModelTypeSpec}, + common::{ + Config, ModelConfig, + model_type::{self, ModelType, ModelTypeSpec}, + }, runtime::TokenizerService, + runtime::loader::EncoderfileLoader, transforms::DEFAULT_LIBS, }; pub type AppState = Arc>; +#[derive(PartialEq)] +pub enum Task { + Classification, + FeatureExtraction, +} + +#[derive(PartialEq)] +pub enum Input { + Text, + Image, +} + +pub trait TaskType { + const TASK: Task; + fn task_type_val(&self) -> Task { + Self::task_type() + } + fn task_type() -> Task { + Self::TASK + } + type State: Debug; +} + +pub trait InputType { + const INPUT: Input; + fn input_type_val(&self) -> Input { + Self::input_type() + } + fn input_type() -> Input { + Self::INPUT + } + type State: Debug; +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TextInputState { + // TODO check Clone impl + pub tokenizer: TokenizerService, + pub model_config: ModelConfig, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ImageInputState { + pub config: ImageConfig, + pub preprocessing: ImagePreprocessing, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ImageConfig { + pub num_channels: u32, + pub image_size: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ImagePreprocessing { + pub rescale_factor: Option, + pub image_mean: Option>, + pub image_std: Option>, + pub do_normalize: Option, + pub do_rescale: Option, + pub do_resize: Option, + pub image_processor_type: Option, + pub size: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ImageSize { + pub height: Option, + pub width: Option, + pub shortest_edge: Option, +} + +impl LuaUserData for ImageInputState { + fn add_fields>(fields: &mut F) { + fields.add_field_method_get("num_channels", |_, this| Ok(this.config.num_channels)); + fields.add_field_method_get("image_size", |_, this| Ok(this.config.image_size)); + fields.add_field_method_get("rescale_factor", |_, this| { + Ok(this.preprocessing.rescale_factor) + }); + fields.add_field_method_get("image_mean", |_, this| { + Ok(this.preprocessing.image_mean.clone()) + }); + fields.add_field_method_get("image_std", |_, this| { + Ok(this.preprocessing.image_std.clone()) + }); + fields.add_field_method_get("do_normalize", |_, this| { + Ok(this.preprocessing.do_normalize) + }); + fields.add_field_method_get("do_rescale", |_, this| Ok(this.preprocessing.do_rescale)); + fields.add_field_method_get("do_resize", |_, this| Ok(this.preprocessing.do_resize)); + fields.add_field_method_get("size_height", |_, this| { + Ok(this.preprocessing.size.as_ref().and_then(|s| s.height)) + }); + fields.add_field_method_get("size_width", |_, this| { + Ok(this.preprocessing.size.as_ref().and_then(|s| s.width)) + }); + fields.add_field_method_get("size_shortest_edge", |_, this| { + Ok(this + .preprocessing + .size + .as_ref() + .and_then(|s| s.shortest_edge)) + }); + } +} + +impl LuaUserData for TextInputState { + fn add_fields>(fields: &mut F) { + fields.add_field_method_get("model_type", |_, this| { + Ok(this.model_config.model_type.clone()) + }); + fields.add_field_method_get("num_labels", |_, this| Ok(this.model_config.num_labels())); + fields.add_field_method_get("id2label", |_, this| Ok(this.model_config.id2label.clone())); + fields.add_field_method_get("label2id", |_, this| Ok(this.model_config.label2id.clone())); + } +} + +impl LuaUserData for ClassifierState { + fn add_fields>(fields: &mut F) { + fields.add_field_method_get("num_labels", |_, this| Ok(this.num_labels())); + fields.add_field_method_get("id2label", |_, this| Ok(this.id2label.clone())); + fields.add_field_method_get("label2id", |_, this| Ok(this.label2id.clone())); + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ClassifierState { + pub id2label: Option>, + pub label2id: Option>, + pub num_labels: Option, +} +impl ClassifierState { + pub fn id2label(&self, id: u32) -> Option<&str> { + self.id2label.as_ref()?.get(&id).map(|s| s.as_str()) + } + + pub fn label2id(&self, label: &str) -> Option { + self.label2id.as_ref()?.get(label).copied() + } + + pub fn num_labels(&self) -> Option { + if self.num_labels.is_some() { + return self.num_labels; + } + + if let Some(id2label) = &self.id2label { + return Some(id2label.len()); + } + + if let Some(label2id) = &self.label2id { + return Some(label2id.len()); + } + + None + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FeatureExtractorState {} + +fn text_input_state_try_from_loader<'a, R>( + loader: &mut EncoderfileLoader<'a, R>, +) -> Result +where + R: Read + Seek, +{ + let tokenizer = loader.tokenizer()?; + let model_config = loader.model_config()?; + Ok(TextInputState { + tokenizer, + model_config, + }) +} + +fn image_input_state_try_from_loader<'a, R>( + loader: &mut EncoderfileLoader<'a, R>, +) -> Result +where + R: Read + Seek, +{ + let model_config = loader.model_config()?; + let preprocessor_config = loader.image_preprocessor_config()?; + Ok(ImageInputState { + config: ImageConfig { + num_channels: model_config + .num_channels + .ok_or_else(|| anyhow::anyhow!("num_channels is required for image models"))?, + image_size: model_config.image_size, + }, + preprocessing: ImagePreprocessing { + rescale_factor: preprocessor_config.rescale_factor, + image_mean: preprocessor_config.image_mean, + image_std: preprocessor_config.image_std, + do_normalize: preprocessor_config.do_normalize, + do_rescale: preprocessor_config.do_rescale, + do_resize: preprocessor_config.do_resize, + image_processor_type: preprocessor_config.image_processor_type, + size: preprocessor_config.size, + }, + }) +} + +fn classifier_state_try_from_loader<'a, R>( + loader: &mut EncoderfileLoader<'a, R>, +) -> Result +where + R: Read + Seek, +{ + let model_config = loader.model_config()?.clone(); + Ok(ClassifierState { + id2label: model_config.id2label.clone(), + label2id: model_config.label2id.clone(), + num_labels: model_config.num_labels(), + }) +} + +fn feature_extractor_state_try_from_loader<'a, R>( + _loader: &mut EncoderfileLoader<'a, R>, +) -> Result +where + R: Read + Seek, +{ + Ok(FeatureExtractorState {}) +} + +macro_rules! state_from_source_impl { + ($base_type:tt, $state_type:ty, $state_fun:ident) => { + impl<'a, 'borrow, R> TryFrom<&'borrow mut EncoderfileLoader<'a, R>> for $state_type + where + R: Read + Seek, + { + type Error = anyhow::Error; + + fn try_from( + loader: &'borrow mut EncoderfileLoader<'a, R>, + ) -> Result { + $state_fun::(loader) + } + } + }; +} + +state_from_source_impl!(InputType, TextInputState, text_input_state_try_from_loader); +state_from_source_impl!( + InputType, + ImageInputState, + image_input_state_try_from_loader +); +state_from_source_impl!(TaskType, ClassifierState, classifier_state_try_from_loader); +state_from_source_impl!( + TaskType, + FeatureExtractorState, + feature_extractor_state_try_from_loader +); + +macro_rules! input_state_impl { + ($model_type:ty, $state_type:ty, $input:expr) => { + impl InputType for $model_type { + const INPUT: Input = $input; + type State = $state_type; + } + }; +} + +input_state_impl!(model_type::Embedding, TextInputState, Input::Text); +input_state_impl!(model_type::SentenceEmbedding, TextInputState, Input::Text); +input_state_impl!( + model_type::SequenceClassification, + TextInputState, + Input::Text +); +input_state_impl!(model_type::TokenClassification, TextInputState, Input::Text); +input_state_impl!( + model_type::ImageClassification, + ImageInputState, + Input::Image +); + +macro_rules! task_state_impl { + ($model_type:ty, $state_type:ty, $task:expr) => { + impl TaskType for $model_type { + const TASK: Task = $task; + type State = $state_type; + } + }; +} + +task_state_impl!( + model_type::SequenceClassification, + ClassifierState, + Task::Classification +); +task_state_impl!( + model_type::TokenClassification, + ClassifierState, + Task::Classification +); +task_state_impl!( + model_type::ImageClassification, + ClassifierState, + Task::Classification +); +task_state_impl!( + model_type::Embedding, + FeatureExtractorState, + Task::FeatureExtraction +); +task_state_impl!( + model_type::SentenceEmbedding, + FeatureExtractorState, + Task::FeatureExtraction +); + +macro_rules! input_type_impl { + [ $( $x:ident ),* $(,)? ] => { + impl ModelType { + pub fn input_type(&self) -> crate::runtime::Input { + match self { + $( + ModelType::$x => model_type::$x::input_type(), + )* + } + } + pub fn task_type(&self) -> crate::runtime::Task { + match self { + $( + ModelType::$x => model_type::$x::task_type(), + )* + } + } + } + } +} +input_type_impl![ + Embedding, + SequenceClassification, + TokenClassification, + SentenceEmbedding, + ImageClassification +]; + #[derive(Debug)] -pub struct EncoderfileState { +pub struct EncoderfileState { pub config: Config, pub session: Mutex, - pub tokenizer: TokenizerService, - pub model_config: ModelConfig, + pub model_input_state: ::State, + pub task_state: ::State, pub lua_libs: Vec, _marker: PhantomData, } -impl EncoderfileState { +impl EncoderfileState { pub fn new( config: Config, session: Mutex, - tokenizer: TokenizerService, - model_config: ModelConfig, + model_input_state: ::State, + task_state: ::State, ) -> EncoderfileState { let lua_libs = match config.lua_libs { Some(ref libs) => Vec::::from(libs), @@ -35,8 +387,8 @@ impl EncoderfileState { EncoderfileState { config, session, - tokenizer, - model_config, + model_input_state, + task_state, lua_libs, _marker: PhantomData, } diff --git a/encoderfile/src/services/embedding.rs b/encoderfile/src/services/embedding.rs index 53e89f7b..cfc846f7 100644 --- a/encoderfile/src/services/embedding.rs +++ b/encoderfile/src/services/embedding.rs @@ -15,7 +15,10 @@ impl Inference for AppState { fn inference(&self, request: impl Into) -> Result { let request = request.into(); - let encodings = self.tokenizer.encode_text(request.inputs)?; + let encodings = self + .model_input_state + .tokenizer + .encode_text(request.inputs)?; let transform = EmbeddingTransform::new(self.lua_libs.clone(), self.transform_str())?; diff --git a/encoderfile/src/services/image_classification.rs b/encoderfile/src/services/image_classification.rs new file mode 100644 index 00000000..c8969955 --- /dev/null +++ b/encoderfile/src/services/image_classification.rs @@ -0,0 +1,256 @@ +use crate::{ + common::{ + ImageClassificationRequest, ImageClassificationResponse, ImageClassificationResult, + model_type, + }, + error::ApiError, + runtime::AppState, + transforms::{DEFAULT_LIBS, Image, ImageClassificationTransform, Preprocessor}, +}; +use ndarray::{ArrayD, Axis, Ix4, Zip}; + +use super::inference::Inference; +use crate::inference::image_classification::image_classification; + +// No service impl yet + +impl Inference for AppState { + type Input = ImageClassificationRequest; + type Output = ImageClassificationResponse; + + fn inference(&self, request: impl Into) -> Result { + // let transform = ImageClassificationTransform::new(self.lua_libs.clone(), self.transform_str())?; + + let request = request.into(); + if request.images.is_empty() { + return Err(ApiError::InputError("Cannot classify empty image list")); + } + + let postprocess_code = r##" + function Preprocess(img) + return img:resize(224,224):to_array(3) + end + "## + .to_string(); + + let engine = + ImageClassificationTransform::new(DEFAULT_LIBS.to_vec(), Some(postprocess_code)) + .expect("Failed to create engine"); + + let num_channels = self.model_input_state.config.num_channels as usize; + let rescale_factor = self + .model_input_state + .preprocessing + .rescale_factor + .ok_or(ApiError::InternalError("missing rescale factor"))?; + let image_mean = self + .model_input_state + .preprocessing + .image_mean + .as_ref() + .ok_or(ApiError::InternalError("missing image mean"))?; + let image_std = self + .model_input_state + .preprocessing + .image_std + .as_ref() + .ok_or(ApiError::InternalError("missing image std"))?; + + let images: Vec> = request + .images + .iter() + .map(|image_info| { + let img = image::load_from_memory(&image_info.image_bytes) + .expect("Failed to load image from bytes"); + let mut res = engine + .preprocess((Image(img), self.model_input_state.clone())) + .expect("Failed") + .into_inner(); + let mean_arr = + ndarray::Array::from_shape_vec((num_channels, 1, 1), image_mean.to_vec()) + .expect("mean shape mismatch"); + let std_arr = + ndarray::Array::from_shape_vec((num_channels, 1, 1), image_std.to_vec()) + .expect("std shape mismatch"); + Zip::from(&mut res) + .and_broadcast(&mean_arr) + .and_broadcast(&std_arr) + .for_each(|x, &m, &s| *x = (*x * rescale_factor - m) / s); + res + }) + .collect(); + + let images_array = ndarray::stack( + Axis(0), + &images.iter().map(|x| x.view()).collect::>(), + ) + .unwrap() + .into_dimensionality::() + .unwrap(); + + // TODO overlap preprocessing and inference, but for now just do it sequentially + // Since we are adding gpu providers now, preprocessing could run in cpu while inference + // is running. Using some sort of task queue will pave the way for more efficient batch + // processing. However, it will not be implemented right now. + + let label_map = self.task_state.id2label.clone().unwrap(); + let mut entries: Vec<_> = label_map.iter().collect(); + entries.sort_by(|x, y| x.0.cmp(y.0)); + let classes: Vec = entries + .into_iter() + .map(|(_, label)| label.clone()) + .collect(); + + let labels_batch = image_classification(self.session.lock(), images_array, classes)?; + + Ok(ImageClassificationResponse { + results: labels_batch + .iter() + .map(|labels| ImageClassificationResult { + labels: labels.clone(), + }) + .collect(), + metadata: request.metadata, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::FromReadInput; + use crate::common::ImageClassificationRequest; + use crate::common::model_type::ImageClassification; + use crate::dev_utils; + use std::fs::File; + use std::sync::{Arc, Once}; + + fn init_tracing() { + static TRACING: Once = Once::new(); + + TRACING.call_once(|| { + let _ = tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("debug,ort=warn")), + ) + .with_test_writer() + .try_init(); + }); + } + + #[test] + fn test_image_classification_request_from_file() { + init_tracing(); + + let state = dev_utils::get_state::("../models/image_classification"); + let mut file = + File::open("../test-pictures/yoga02.jpg").expect("Failed to open test image"); + let file_vec = vec![&mut file]; + let request = ImageClassificationRequest::from_read_input(file_vec) + .expect("Failed to create request from read input"); + let response = state.inference(request).expect("Inference failed"); + println!("Inference response: {:?}", response); + assert_eq!(response.results.len(), 1); + assert_eq!(response.results[0].labels.len(), 9); + assert!( + response.results[0] + .labels + .iter() + .enumerate() + .max_by(|a, b| a.1.score.partial_cmp(&b.1.score).unwrap()) + .unwrap() + .1 + .label + == "Downward-Dog" + ); // top label should be "yoga mat" + } + + #[test] + fn test_image_classification_empty() { + init_tracing(); + + let state = dev_utils::get_state::("../models/image_classification"); + let request = ImageClassificationRequest { + images: vec![], + metadata: Default::default(), + }; + let response = state.inference(request); + assert!(response.is_err()); + } + + #[test] + fn test_image_classification_missing_rescale_factor() { + init_tracing(); + + let mut state = + dev_utils::get_state::("../models/image_classification"); + Arc::get_mut(&mut state) + .expect("state should not be shared") + .model_input_state + .preprocessing + .rescale_factor = None; + + let mut file = + File::open("../test-pictures/yoga02.jpg").expect("Failed to open test image"); + let file_vec = vec![&mut file]; + let request = ImageClassificationRequest::from_read_input(file_vec) + .expect("Failed to create request from read input"); + + let response = state.inference(request); + assert!(matches!( + response, + Err(ApiError::InternalError("missing rescale factor")) + )); + } + + #[test] + fn test_image_classification_missing_image_mean() { + init_tracing(); + + let mut state = + dev_utils::get_state::("../models/image_classification"); + Arc::get_mut(&mut state) + .expect("state should not be shared") + .model_input_state + .preprocessing + .image_mean = None; + + let mut file = + File::open("../test-pictures/yoga02.jpg").expect("Failed to open test image"); + let file_vec = vec![&mut file]; + let request = ImageClassificationRequest::from_read_input(file_vec) + .expect("Failed to create request from read input"); + + let response = state.inference(request); + assert!(matches!( + response, + Err(ApiError::InternalError("missing image mean")) + )); + } + + #[test] + fn test_image_classification_missing_image_std() { + init_tracing(); + + let mut state = + dev_utils::get_state::("../models/image_classification"); + Arc::get_mut(&mut state) + .expect("state should not be shared") + .model_input_state + .preprocessing + .image_std = None; + + let mut file = + File::open("../test-pictures/yoga02.jpg").expect("Failed to open test image"); + let file_vec = vec![&mut file]; + let request = ImageClassificationRequest::from_read_input(file_vec) + .expect("Failed to create request from read input"); + + let response = state.inference(request); + assert!(matches!( + response, + Err(ApiError::InternalError("missing image std")) + )); + } +} diff --git a/encoderfile/src/services/inference.rs b/encoderfile/src/services/inference.rs index 5e55b15b..d1832431 100644 --- a/encoderfile/src/services/inference.rs +++ b/encoderfile/src/services/inference.rs @@ -1,8 +1,9 @@ use crate::{common::FromCliInput, error::ApiError, services::Metadata}; +// FIXME enforce the openapi schema later on pub trait Inference: Metadata { - type Input: FromCliInput + serde::de::DeserializeOwned + Sync + Send + utoipa::ToSchema; - type Output: serde::Serialize + Sync + Send + utoipa::ToSchema; + type Input: FromCliInput + serde::de::DeserializeOwned + Sync + Send; + type Output: serde::Serialize + Sync + Send; fn inference(&self, request: impl Into) -> Result; } diff --git a/encoderfile/src/services/mod.rs b/encoderfile/src/services/mod.rs index 43720d50..4e7db263 100644 --- a/encoderfile/src/services/mod.rs +++ b/encoderfile/src/services/mod.rs @@ -1,4 +1,5 @@ mod embedding; +mod image_classification; mod inference; mod model_metadata; mod sentence_embedding; diff --git a/encoderfile/src/services/model_metadata.rs b/encoderfile/src/services/model_metadata.rs index d43cc28d..8d4637d3 100644 --- a/encoderfile/src/services/model_metadata.rs +++ b/encoderfile/src/services/model_metadata.rs @@ -1,8 +1,11 @@ use std::collections::HashMap; use crate::{ - common::{GetModelMetadataResponse, ModelType, model_type::ModelTypeSpec}, - runtime::AppState, + common::{ + GetModelMetadataResponse, + model_type::{ModelType, ModelTypeSpec}, + }, + runtime::{AppState, ClassifierState, FeatureExtractorState, InputType, TaskType}, }; pub trait Metadata { @@ -21,7 +24,27 @@ pub trait Metadata { fn id2label(&self) -> Option>; } -impl Metadata for AppState { +trait TaskStateMetadata { + fn id2label(&self) -> Option>; +} + +impl TaskStateMetadata for ClassifierState { + fn id2label(&self) -> Option> { + println!("ClassifierState: {:?}", self); + self.id2label.clone() + } +} + +impl TaskStateMetadata for FeatureExtractorState { + fn id2label(&self) -> Option> { + None + } +} + +impl Metadata for AppState +where + ::State: TaskStateMetadata, +{ fn model_id(&self) -> String { self.config.name.clone() } @@ -31,6 +54,6 @@ impl Metadata for AppState { } fn id2label(&self) -> Option> { - self.model_config.id2label.clone() + self.task_state.id2label() } } diff --git a/encoderfile/src/services/sentence_embedding.rs b/encoderfile/src/services/sentence_embedding.rs index 115c6322..465e3424 100644 --- a/encoderfile/src/services/sentence_embedding.rs +++ b/encoderfile/src/services/sentence_embedding.rs @@ -15,7 +15,10 @@ impl Inference for AppState { fn inference(&self, request: impl Into) -> Result { let request = request.into(); - let encodings = self.tokenizer.encode_text(request.inputs)?; + let encodings = self + .model_input_state + .tokenizer + .encode_text(request.inputs)?; let transform = SentenceEmbeddingTransform::new(self.lua_libs.clone(), self.transform_str())?; diff --git a/encoderfile/src/services/sequence_classification.rs b/encoderfile/src/services/sequence_classification.rs index 52af7313..f354cb21 100644 --- a/encoderfile/src/services/sequence_classification.rs +++ b/encoderfile/src/services/sequence_classification.rs @@ -15,7 +15,10 @@ impl Inference for AppState { fn inference(&self, request: impl Into) -> Result { let request = request.into(); - let encodings = self.tokenizer.encode_text(request.inputs)?; + let encodings = self + .model_input_state + .tokenizer + .encode_text(request.inputs)?; let transform = SequenceClassificationTransform::new(self.lua_libs.clone(), self.transform_str())?; @@ -23,7 +26,7 @@ impl Inference for AppState { let results = inference::sequence_classification::sequence_classification( self.session.lock(), &transform, - &self.model_config, + &self.task_state, encodings, )?; diff --git a/encoderfile/src/services/token_classification.rs b/encoderfile/src/services/token_classification.rs index 2fd12329..fe7e3402 100644 --- a/encoderfile/src/services/token_classification.rs +++ b/encoderfile/src/services/token_classification.rs @@ -17,7 +17,10 @@ impl Inference for AppState { let session = self.session.lock(); - let encodings = self.tokenizer.encode_text(request.inputs)?; + let encodings = self + .model_input_state + .tokenizer + .encode_text(request.inputs)?; let transform = TokenClassificationTransform::new(self.lua_libs.clone(), self.transform_str())?; @@ -25,7 +28,7 @@ impl Inference for AppState { let results = inference::token_classification::token_classification( session, &transform, - &self.model_config, + &self.task_state, encodings, )?; diff --git a/encoderfile/src/transforms/engine/image_classification.rs b/encoderfile/src/transforms/engine/image_classification.rs new file mode 100644 index 00000000..be9fd8a6 --- /dev/null +++ b/encoderfile/src/transforms/engine/image_classification.rs @@ -0,0 +1,206 @@ +use crate::{common::model_type, error::ApiError, runtime::ImageInputState}; + +use super::{super::image::Image, super::tensor::Tensor, Postprocessor, Preprocessor, Transform}; +use ndarray::{Array2, Ix2}; + +impl Postprocessor for Transform { + type Input = Array2; + type Output = Array2; + + fn postprocess(&self, data: Self::Input) -> Result { + let func = match self.postprocessor() { + Some(p) => p, + None => return Ok(data), + }; + + let expected_shape = data.shape().to_owned(); + + let tensor = Tensor(data.into_dyn()); + + let result = func + .call::(tensor) + .map_err(|e| ApiError::LuaError(e.to_string()))? + .into_inner() + .into_dimensionality::().map_err(|e| { + tracing::error!("Failed to cast array into Ix2: {e}. Check your lua transform to make sure it returns a tensor of shape [batch_size, num_classes]"); + ApiError::LuaError("Error postprocessing image classifications".to_string()) + })?; + + let result_shape = result.shape(); + + if expected_shape.as_slice() != result_shape { + tracing::error!( + "Transform error: expected tensor of shape {:?}, got tensor of shape {:?}", + expected_shape.as_slice(), + result_shape + ); + + return Err(ApiError::LuaError( + "Error postprocessing image classifications".to_string(), + )); + } + + Ok(result) + } +} + +impl Preprocessor for Transform { + type Input = (Image, ImageInputState); + type Output = Tensor; + + fn preprocess(&self, (image, config): Self::Input) -> Result { + let func = match self.preprocessor() { + Some(p) => p, + None => { + return Err(ApiError::InternalError( + "No preprocessor defined for this model", + )); + } + }; + + self.lua + .globals() + .set("input_config", config) + .map_err(|e| ApiError::LuaError(e.to_string()))?; + + func.call::(image) + .map_err(|e| ApiError::LuaError(e.to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::transforms::DEFAULT_LIBS; + + #[test] + fn test_image_cls_no_transform() { + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some("".to_string()), + ) + .expect("Failed to create Transform"); + + let arr = ndarray::Array2::::from_elem((32, 16), 2.0); + + let result = engine.postprocess(arr.clone()).expect("Failed"); + + assert_eq!(arr, result); + } + + #[test] + fn test_image_cls_identity_transform() { + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" + function Postprocess(arr) + return arr + end + "## + .to_string(), + ), + ) + .expect("Failed to create engine"); + + let arr = ndarray::Array2::::from_elem((16, 32), 2.0); + + let result = engine.postprocess(arr.clone()).expect("Failed"); + + assert_eq!(arr, result); + } + + #[test] + fn test_image_cls_transform_bad_fn() { + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" + function Postprocess(arr) + return 1 + end + "## + .to_string(), + ), + ) + .expect("Failed to create engine"); + + let arr = ndarray::Array2::::from_elem((16, 32), 2.0); + + let result = engine.postprocess(arr.clone()); + + assert!(result.is_err()) + } + + #[test] + fn test_bad_dimensionality_transform_postprocessing() { + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" + function Postprocess(x) + return x:sum_axis(1) + end + "## + .to_string(), + ), + ) + .unwrap(); + + let arr = ndarray::Array2::::from_elem((3, 3), 2.0); + let result = engine.postprocess(arr.clone()); + + assert!(result.is_err()); + + if let Err(e) = result { + match e { + ApiError::LuaError(s) => { + assert!(s.contains("Error postprocessing image classifications")) + } + _ => panic!("Didn't return lua error"), + } + } + } + + #[test] + fn test_image_preprocess() { + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" + function Preprocess(img) + return img:resize(input_config.size_height, input_config.size_width):to_array(input_config.num_channels) + end + "## + .to_string(), + ), + ) + .expect("Failed to create engine"); + + let img = image::open("../test-pictures/yoga02.jpg").expect("Failed to open test image"); + + let config = ImageInputState { + config: crate::runtime::ImageConfig { + num_channels: 3, + image_size: Some(224), + }, + preprocessing: crate::runtime::ImagePreprocessing { + rescale_factor: None, + image_mean: None, + image_std: None, + do_normalize: None, + do_rescale: None, + do_resize: None, + image_processor_type: None, + size: Some(crate::runtime::ImageSize { + height: Some(224), + width: Some(224), + shortest_edge: None, + }), + }, + }; + let result = engine.preprocess((Image(img), config)).expect("Failed"); + + assert!(result.into_inner().shape() == [3, 224, 224]); + } +} diff --git a/encoderfile/src/transforms/engine/mod.rs b/encoderfile/src/transforms/engine/mod.rs index 73cee6e2..bc6e9712 100644 --- a/encoderfile/src/transforms/engine/mod.rs +++ b/encoderfile/src/transforms/engine/mod.rs @@ -13,6 +13,7 @@ use super::tensor::Tensor; use mlua::prelude::*; mod embedding; +mod image_classification; mod sentence_embedding; mod sequence_classification; mod token_classification; @@ -86,6 +87,12 @@ transform!(EmbeddingTransform, Embedding); transform!(SequenceClassificationTransform, SequenceClassification); transform!(TokenClassificationTransform, TokenClassification); transform!(SentenceEmbeddingTransform, SentenceEmbedding); +transform!(ImageClassificationTransform, ImageClassification); + +pub trait TransformSpec { + fn has_postprocessor(&self) -> bool; + fn has_preprocessor(&self) -> bool; +} pub trait Postprocessor: TransformSpec { type Input; @@ -94,14 +101,18 @@ pub trait Postprocessor: TransformSpec { fn postprocess(&self, data: Self::Input) -> Result; } -pub trait TransformSpec { - fn has_postprocessor(&self) -> bool; +pub trait Preprocessor: TransformSpec { + type Input; + type Output; + + fn preprocess(&self, data: Self::Input) -> Result; } #[derive(Debug)] pub struct Transform { #[allow(dead_code)] lua: Lua, + preprocessor: Option, postprocessor: Option, _marker: PhantomData, } @@ -111,6 +122,10 @@ impl Transform { &self.postprocessor } + fn preprocessor(&self) -> &Option { + &self.preprocessor + } + #[tracing::instrument(name = "new_transform", skip_all)] pub fn new(libs: Vec, transform: Option) -> Result { let lua = new_lua(libs)?; @@ -124,8 +139,14 @@ impl Transform { .get::>("Postprocess") .map_err(|e| ApiError::LuaError(e.to_string()))?; + let preprocessor = lua + .globals() + .get::>("Preprocess") + .map_err(|e| ApiError::LuaError(e.to_string()))?; + Ok(Self { lua, + preprocessor, postprocessor, _marker: PhantomData, }) @@ -136,6 +157,10 @@ impl TransformSpec for Transform { fn has_postprocessor(&self) -> bool { self.postprocessor.is_some() } + + fn has_preprocessor(&self) -> bool { + self.preprocessor.is_some() + } } fn new_lua(libs: Vec) -> Result { diff --git a/encoderfile/src/transforms/image/mod.rs b/encoderfile/src/transforms/image/mod.rs new file mode 100644 index 00000000..a01573a3 --- /dev/null +++ b/encoderfile/src/transforms/image/mod.rs @@ -0,0 +1,88 @@ +use super::Tensor; +use image::{DynamicImage, GenericImageView}; +use mlua::prelude::*; +use ndarray::Array3; + +const DEFAULT_FILTER_TYPE: image::imageops::FilterType = image::imageops::FilterType::Triangle; + +#[derive(Debug, Clone)] +pub struct Image(pub DynamicImage); + +impl Image { + pub fn into_inner(&self) -> &DynamicImage { + &self.0 + } +} + +impl FromLua for Image { + fn from_lua(value: LuaValue, _lua: &Lua) -> Result { + match value { + LuaValue::UserData(data) => data.borrow::().map(|i| i.to_owned()), + _ => Err(LuaError::external( + format!("Unknown type: {}", value.type_name()).as_str(), + )), + } + } +} + +fn dyn_image_to_array3(image: &DynamicImage, num_channels: u32) -> Array3 { + // TODO num_channels is tied to the format we convert to + let raw = image.to_rgb8().into_raw(); + let (h_us, w_us) = image.dimensions(); + let h_us: usize = h_us as usize; + let w_us: usize = w_us as usize; + let nc_us: usize = num_channels as usize; + + // Build CHW array directly from raw HWC bytes, avoiding an intermediate array and transpose. + Array3::from_shape_fn((nc_us, h_us, w_us), |(c, y, x)| { + raw[y * w_us * nc_us + x * nc_us + c] as f32 + }) +} + +fn resize_image(image: &DynamicImage, height: u32, width: u32) -> DynamicImage { + image.resize_exact(width, height, DEFAULT_FILTER_TYPE) +} + +impl LuaUserData for Image { + fn add_methods>(methods: &mut M) { + // tensor ops + methods.add_method("to_array", |_, this, num_channels| { + Ok(Tensor( + dyn_image_to_array3(this.into_inner(), num_channels).into_dyn(), + )) + }); + methods.add_method("resize", |_, this, (height, width)| { + Ok(Image(resize_image(this.into_inner(), height, width))) + }); + } +} + +#[cfg(test)] +fn load_env() -> Lua { + Lua::new() +} + +#[test] +fn test_resize_image() { + use image::GenericImageView; + let img = image::open("../test-pictures/yoga02.jpg").expect("Failed to open test image"); + assert_ne!(img.dimensions(), (224, 224)); + let lua = load_env(); + let img_val = Image(img); + lua.globals().set("img", img_val).unwrap(); + let resized: Image = lua.load("return img:resize(224, 224)").eval().unwrap(); + assert_eq!(resized.into_inner().dimensions(), (224, 224)); +} + +#[test] +fn test_image_to_array() { + let img = image::open("../test-pictures/yoga02.jpg").expect("Failed to open test image"); + let lua = load_env(); + let img_val = Image(img); + lua.globals().set("img", img_val).unwrap(); + let array: Tensor = lua + .load("return img:resize(224,224):to_array(3)") + .eval() + .unwrap(); + assert_eq!(array.into_inner().shape(), &[3, 224, 224]); +} diff --git a/encoderfile/src/transforms/mod.rs b/encoderfile/src/transforms/mod.rs index 6f00903a..e1489a7c 100644 --- a/encoderfile/src/transforms/mod.rs +++ b/encoderfile/src/transforms/mod.rs @@ -1,8 +1,10 @@ mod engine; +mod image; mod tensor; mod utils; pub use engine::*; +pub use image::Image; pub use tensor::Tensor; pub const DEFAULT_LIBS: [mlua::StdLib; 3] = [ diff --git a/encoderfile/src/transport/cli.rs b/encoderfile/src/transport/cli.rs index 8a934fb4..c15f3681 100644 --- a/encoderfile/src/transport/cli.rs +++ b/encoderfile/src/transport/cli.rs @@ -1,9 +1,9 @@ use crate::{ common::{ - FromCliInput, ModelType, - model_type::{self, ModelTypeSpec}, + FromCliInput, + model_type::{self, ModelType, ModelTypeSpec}, }, - runtime::{EncoderfileLoader, EncoderfileState, ORTExecutionProvider}, + runtime::{EncoderfileLoader, EncoderfileState, InputType, ORTExecutionProvider, TaskType}, services::{Inference, Metadata}, transport::{ grpc::GrpcRouter, @@ -18,7 +18,7 @@ use opentelemetry::trace::TracerProvider as _; use opentelemetry_otlp::{Protocol, WithExportConfig}; use opentelemetry_sdk::trace::SdkTracerProvider; use std::{ - fmt::Display, + fmt::{Debug, Display}, io::{Read, Seek, Write}, sync::Arc, }; @@ -110,9 +110,9 @@ pub enum Commands { } impl Commands { - pub async fn execute<'a, R: Read + Seek>( + pub async fn execute<'loader, R: Read + Seek>( self, - loader: &mut EncoderfileLoader<'a, R>, + loader: &mut EncoderfileLoader<'loader, R>, ) -> Result<()> { match loader.model_type() { ModelType::Embedding => { @@ -131,14 +131,28 @@ impl Commands { self.execute_from_loader::(loader) .await } + ModelType::ImageClassification => { + self.execute_from_loader::(loader) + .await + } } } - pub async fn execute_from_loader<'a, R: Read + Seek, T: ModelTypeSpec>( + pub async fn execute_from_loader< + 'loader, + R: Read + Seek, + T: ModelTypeSpec + InputType + TaskType, + >( self, - loader: &mut EncoderfileLoader<'a, R>, + loader: &mut EncoderfileLoader<'loader, R>, ) -> Result<()> where Arc>: Inference + GrpcRouter + HttpRouter + McpRouter + CliRoute, + ::State: Debug, + ::State: Debug, + for<'b> ::State: + TryFrom<&'b mut EncoderfileLoader<'loader, R>, Error = anyhow::Error>, + for<'b> ::State: + TryFrom<&'b mut EncoderfileLoader<'loader, R>, Error = anyhow::Error>, { match self { Commands::Serve { @@ -161,15 +175,15 @@ impl Commands { onnx_args.graph_optimization_level(), )? .into(); - let model_config = loader.model_config()?; - let tokenizer = loader.tokenizer()?; let config = loader.encoderfile_config()?; let state = Arc::new(EncoderfileState::::new( config, session, - tokenizer, - model_config, + ::State::try_from(loader) + .expect("could not load model input state from file"), + ::State::try_from(loader) + .expect("could not load model task state from file"), )); let banner = crate::get_banner(state.model_id().as_str()); @@ -225,15 +239,15 @@ impl Commands { )? .into(); - let model_config = loader.model_config()?; - let tokenizer = loader.tokenizer()?; let config = loader.encoderfile_config()?; let state = Arc::new(EncoderfileState::::new( config, session, - tokenizer, - model_config, + ::State::try_from(loader) + .expect("could not load model input state from file"), + ::State::try_from(loader) + .expect("could not load model task state from file"), )); setup_tracing(None)?; @@ -255,15 +269,15 @@ impl Commands { )? .into(); - let model_config = loader.model_config()?; - let tokenizer = loader.tokenizer()?; let config = loader.encoderfile_config()?; let state = Arc::new(EncoderfileState::::new( config, session, - tokenizer, - model_config, + ::State::try_from(loader) + .expect("could not load model input state from file"), + ::State::try_from(loader) + .expect("could not load model input state from file"), )); let banner = crate::get_banner(state.model_id().as_str()); diff --git a/encoderfile/src/transport/grpc/mod.rs b/encoderfile/src/transport/grpc/mod.rs index 162de39a..9cc81ee2 100644 --- a/encoderfile/src/transport/grpc/mod.rs +++ b/encoderfile/src/transport/grpc/mod.rs @@ -1,6 +1,9 @@ use crate::{ common::model_type, - generated::{embedding, sentence_embedding, sequence_classification, token_classification}, + generated::{ + embedding, image_classification, sentence_embedding, sequence_classification, + token_classification, + }, runtime::AppState, services::{Inference, Metadata}, }; @@ -71,6 +74,7 @@ macro_rules! generate_grpc_server { tonic::Response<$crate::generated::metadata::GetModelMetadataResponse>, tonic::Status, > { + println!("And the metadata is...: {:?}", self.state.metadata()); Ok(tonic::Response::new(self.state.metadata().into())) } } @@ -116,3 +120,13 @@ generate_grpc_server!( SentenceEmbeddingInference, SentenceEmbeddingInferenceServer ); + +generate_grpc_server!( + ImageClassification, + image_classification, + image_classification_inference_server, + ImageClassificationRequest, + ImageClassificationResponse, + ImageClassificationInference, + ImageClassificationInferenceServer +); diff --git a/encoderfile/src/transport/http/example.md b/encoderfile/src/transport/http/example.md new file mode 100644 index 00000000..3eb1bfc9 --- /dev/null +++ b/encoderfile/src/transport/http/example.md @@ -0,0 +1,316 @@ +# Multipart OpenAPI Service Example + +This document provides examples of how to interact with the multipart file upload and prediction endpoint. + +## Endpoint Overview + +- **POST /predict/multipart** - Submit a JSON payload with binary file attachments +- **GET /predict/multipart/openapi.json** - Retrieve the OpenAPI specification + +## Example 1: cURL with Two Image Files + +```bash +curl -X POST http://localhost:8080/predict/multipart \ + -F "payload={\"model_version\": \"1.0\", \"threshold\": 0.8}" \ + -F "files=@/path/to/image1.png" \ + -F "files=@/path/to/image2.jpg" +``` + +### Request Body (multipart/form-data) + +``` +--boundary_123abc456def +Content-Disposition: form-data; name="payload" +Content-Type: application/json + +{"model_version": "1.0", "threshold": 0.8} +--boundary_123abc456def +Content-Disposition: form-data; name="files"; filename="image1.png" +Content-Type: image/png + + +--boundary_123abc456def +Content-Disposition: form-data; name="files"; filename="image2.jpg" +Content-Type: image/jpeg + + +--boundary_123abc456def-- +``` + +### Response + +```json +{ + "payload": { + "model_version": "1.0", + "threshold": 0.8 + }, + "attachment_count": 2, + "attachments": [ + { + "file_name": "image1.png", + "content_type": "image/png", + "size_bytes": 45230 + }, + { + "file_name": "image2.jpg", + "content_type": "image/jpeg", + "size_bytes": 52104 + } + ] +} +``` + +## Example 2: Python Requests Library + +```python +import requests +import json + +url = "http://localhost:8080/predict/multipart" + +# Prepare the payload +payloaquest Body (multipart/form-data) + +``` +--boundary_xyz789pqr012 +Content-Disposition: form-data; name="payload"; filename="payload.json" +Content-Type: application/json + +{"model_version": "1.0", "threshold": 0.8, "batch_id": "batch_12345"} +--boundary_xyz789pqr012 +Content-Disposition: form-data; name="files"; filename="image1.png" +Content-Type: image/png + + +--boundary_xyz789pqr012 +Content-Disposition: form-data; name="files"; filename="image2.jpg" +Content-Type: image/jpeg + + +--boundary_xyz789pqr012 +Content-Disposition: form-data; name="files"; filename="document.pdf" +Content-Type: application/pdf + + +--boundary_xyz789pqr012-- +``` + +### Red = { + "model_version": "1.0", + "threshold": 0.8, + "batch_id": "batch_12345" +} + +# Prepare files +files = [ + ("payload", ("payload.json", json.dumps(payload), "application/json")), + ("files", ("image1.png", open("image1.png", "rb"), "image/png")), + ("files", ("image2.jpg", open("image2.jpg", "rb"), "image/jpeg")), + ("files", ("document.pdf", open("document.pdf", "rb"), "application/pdf")), +] + +# Send the request +response = requests.post(url, files=files) + +print("Status Code:", response.status_code) +print("Response:", response.json()) +``` + +### Response + +```json +{ + "payload": { + "model_version": "1.0", + "threshold": 0.8, + "batch_id": "batch_12345" + }, + "attachment_count": 3, + "attachments": [ + { + "file_name": "image1.png", + "content_type": "image/png", + quest Body (multipart/form-data) + +``` +--boundary_webkit_abc123 +Content-Disposition: form-data; name="payload" + +{"model_version":"1.0","threshold":0.8,"inference_id":"inf_abc123"} +--boundary_webkit_abc123 +Content-Disposition: form-data; name="files"; filename="photo1.jpg" +Content-Type: image/jpeg + + +--boundary_webkit_abc123 +Content-Disposition: form-data; name="files"; filename="photo2.jpg" +Content-Type: image/jpeg + + +--boundary_webkit_abc123-- +``` + +### Re"size_bytes": 45230 + }, + { + "file_name": "image2.jpg", + "content_type": "image/jpeg", + "size_bytes": 52104 + }, + { + "file_name": "document.pdf", + "content_type": "application/pdf", + "size_bytes": 128512 + } + ] +} +``` + +## Example 3: JavaScript Fetch API + +```javascript +const payload = { + model_version: "1.0", + threshold: 0.8, + inference_id: "inf_abc123" +}; + +const formData = new FormData(); + +// Add the JSON payload as a form field +formData.append("payload", JSON.stringify(payload)); + +// Add multiple binary files +const imageFile1 = document.getElementById("imageInput1").files[0]; +const imageFile2 = document.getElementById("imageInput2").files[0]; + +formData.append("files", imageFile1); +formData.append("files", imageFile2); + +// Make the request +const response = await fetch("http://localhost:8080/predict/multipart", { + method: "POST", + body: formData +}); + +const result = await response.json(); +console.log("Success:", result); +``` + +### Response + +```json +{ + "payload": { + "model_version": "1.0", + "threshold": 0.8, + "inference_id": "inf_abc123" + }, + "attachment_count": 2, + "attachments": [ + { + "file_name": "photo1.jpg", + "content_type": "image/jpeg", + "size_bytes": 245120 + }, + { + "file_name": "photo2.jpg", + "content_type": "image/jpeg", + "size_bytes": 187904 + } + ] +} +``` + +## Example 4: Error Handling + +### Missing Payload + +If the request is sent without a `payload` form field: + +```bash +curl -X POST http://localhost:8080/predict/multipart \ + -F "files=@/path/to/image.png" +``` + +**Response (422 Unprocessable Entity):** + +``` +missing required multipart field 'payload' +``` + +### Invalid JSON in Payload + +If the payload field contains invalid JSON: + +```bash +curl -X POST http://localhost:8080/predict/multipart \ + -F "payload=not valid json" \ + -F "files=@/path/to/image.png" +``` + +**Response (422 Unprocessable Entity):** + +``` +invalid json in 'payload' field +``` + +### Malformed Multipart Body + +If the multipart encoding is corrupted: + +**Response (400 Bad Request):** + +``` +multipart parse error: [error details] +``` + +## Request Parts Specification + +### Required: `payload` Part + +- **Name**: `payload` (exactly one) +- **Content-Type**: `application/json` (recommended) +- **Content**: Valid JSON object or array + +### Optional: `files` Parts + +- **Name**: `files` (zero or more) +- **Content-Type**: Any MIME type (e.g., `image/png`, `application/pdf`) +- **Content**: Binary data +- **Filename**: Optional but recommended (used in response metadata) + +## Response Structure + +```json +{ + "payload": "...", // Echo of the submitted JSON payload + "attachment_count": 3, // Number of files attached + "attachments": [ // Metadata for each file + { + "file_name": "...", // Original filename if provided, null otherwise + "content_type": "...", // MIME type if provided, null otherwise + "size_bytes": 12345 // File size in bytes + } + ] +} +``` + +## HTTP Status Codes + +| Status | Meaning | Condition | +|--------|---------|-----------| +| 200 | OK | Request processed successfully | +| 400 | Bad Request | Malformed multipart body | +| 422 | Unprocessable Entity | Missing `payload` or invalid JSON | + +## OpenAPI Specification + +To retrieve the OpenAPI specification for this endpoint: + +```bash +curl -X GET http://localhost:8080/predict/multipart/openapi.json +``` + +This returns a machine-readable OpenAPI 3.0 document describing the endpoint. diff --git a/encoderfile/src/transport/http/mod.rs b/encoderfile/src/transport/http/mod.rs index f5b5ffd1..cc526318 100644 --- a/encoderfile/src/transport/http/mod.rs +++ b/encoderfile/src/transport/http/mod.rs @@ -1,5 +1,6 @@ mod base; mod error; +pub mod multipart_openapi; pub trait HttpRouter where diff --git a/encoderfile/src/transport/http/multipart_openapi.rs b/encoderfile/src/transport/http/multipart_openapi.rs new file mode 100644 index 00000000..a07a793d --- /dev/null +++ b/encoderfile/src/transport/http/multipart_openapi.rs @@ -0,0 +1,315 @@ +use crate::common::model_type::ImageClassification; +use crate::common::{ImageClassificationRequest, ImageClassificationResponse, ImageInfo}; +use crate::runtime::AppState; +use crate::services::Inference; +use axum::{ + Json, + extract::{Multipart, State}, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use utoipa::OpenApi; + +pub const MULTIPART_PREDICT_ENDPOINT: &str = "/predict/multipart"; +pub const MULTIPART_OPENAPI_ENDPOINT: &str = "/predict/multipart/openapi.json"; + +#[derive(Debug, Serialize, Deserialize, utoipa::ToSchema)] +pub struct MultipartPredictBody { + /// Arbitrary JSON payload sent in the multipart part named `payload`. + pub payload: serde_json::Value, + + /// Binary attachments sent as repeated `files` multipart parts. + #[schema(value_type = Vec)] + pub files: Vec, +} + +#[derive(Debug, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ParsedAttachment { + pub file_name: Option, + pub content_type: Option, + pub size_bytes: usize, +} + +#[derive(Debug, thiserror::Error)] +pub enum MultipartApiError { + #[error("missing required multipart field 'payload'")] + MissingPayload, + #[error("invalid json in 'payload' field")] + InvalidPayload, + #[error("multipart parse error: {0}")] + Multipart(String), + #[error("failed to construct request from multipart: {0}")] + RequestConstruction(String), +} + +impl IntoResponse for MultipartApiError { + fn into_response(self) -> Response { + let status = match self { + Self::MissingPayload | Self::InvalidPayload => StatusCode::UNPROCESSABLE_ENTITY, + Self::RequestConstruction(_) => StatusCode::UNPROCESSABLE_ENTITY, + Self::Multipart(_) => StatusCode::BAD_REQUEST, + }; + + (status, self.to_string()).into_response() + } +} + +/// Trait for converting multipart payload and attachments into a typed request. +pub trait FromMultipart: Sized { + /// Construct an instance from a JSON payload and list of attachment bytes. + fn from_multipart( + payload: serde_json::Value, + attachments: Vec<(Option, Option, bytes::Bytes)>, + ) -> Result; +} + +impl FromMultipart for ImageClassificationRequest { + fn from_multipart( + payload: serde_json::Value, + attachments: Vec<(Option, Option, bytes::Bytes)>, + ) -> Result { + let images = attachments + .into_iter() + .map(|(_file_name, _content_type, image_bytes)| { + let format = image::guess_format(&image_bytes).map_err(|e| { + MultipartApiError::RequestConstruction(format!( + "Failed to detect image format: {}", + e + )) + })?; + Ok(ImageInfo { + image_bytes, + image_format: format, + }) + }) + .collect::, _>>()?; + + let metadata = if payload.is_null() || payload == serde_json::json!({}) { + Some(HashMap::default()) + } else { + serde_json::from_value(payload) + .ok() + .or(Some(HashMap::default())) + }; + + Ok(Self { images, metadata }) + } +} + +#[derive(Debug, utoipa::OpenApi)] +#[openapi( + paths(post_multipart), + components(schemas(MultipartPredictBody, ImageClassificationResponse, ParsedAttachment)) +)] +pub struct MultipartApiDoc; + +#[utoipa::path( + get, + path = MULTIPART_OPENAPI_ENDPOINT, + responses( + (status = 200, description = "Successful") + ) +)] +pub async fn openapi() -> impl IntoResponse { + Json(MultipartApiDoc::openapi()) +} + +#[utoipa::path( + post, + path = MULTIPART_PREDICT_ENDPOINT, + request_body( + content = MultipartPredictBody, + content_type = "multipart/form-data", + description = "Multipart payload with a JSON part named 'payload' and 0..N binary parts named 'files'" + ), + responses( + (status = 200, body = ImageClassificationResponse), + (status = 422, description = "Missing or invalid payload JSON"), + (status = 400, description = "Invalid multipart body") + ) +)] +pub async fn post_multipart( + state: State>, + mut multipart: Multipart, +) -> Result, MultipartApiError> { + parse_multipart(state, &mut multipart).await +} + +/// Generic multipart parser that extracts payload and attachments. +pub async fn parse_multipart( + State(state): State>, + multipart: &mut Multipart, +) -> Result, MultipartApiError> { + let mut payload: Option = None; + let mut attachments = Vec::new(); + let mut attachment_metadata = Vec::new(); + + while let Some(field) = multipart + .next_field() + .await + .map_err(|e| MultipartApiError::Multipart(e.to_string()))? + { + let name = field.name().map(ToOwned::to_owned); + let file_name = field.file_name().map(ToOwned::to_owned); + let content_type = field.content_type().map(ToOwned::to_owned); + let bytes = field + .bytes() + .await + .map_err(|e| MultipartApiError::Multipart(e.to_string()))?; + + match name.as_deref() { + Some("payload") => { + payload = Some( + serde_json::from_slice(&bytes) + .map_err(|_| MultipartApiError::InvalidPayload)?, + ); + } + Some("files") => { + attachment_metadata.push(ParsedAttachment { + file_name: file_name.clone(), + content_type: content_type.clone(), + size_bytes: bytes.len(), + }); + attachments.push((file_name, content_type, bytes)); + } + _ => {} + } + } + + let payload = payload.ok_or(MultipartApiError::MissingPayload)?; + + // Convert to typed request + let request = ImageClassificationRequest::from_multipart(payload.clone(), attachments)?; + let result = state + .inference(request) + .map(Json) + .map_err(|e| MultipartApiError::RequestConstruction(format!("Inference error: {}", e)))?; + + Ok(result) +} + +/// Generic handler that converts multipart request into typed request. +pub async fn post_multipart_typed( + State(state): State>, + mut multipart: Multipart, +) -> Result, MultipartApiError> { + let mut payload: Option = None; + let mut attachments = Vec::new(); + let mut attachment_metadata = Vec::new(); + + while let Some(field) = multipart + .next_field() + .await + .map_err(|e| MultipartApiError::Multipart(e.to_string()))? + { + let name = field.name().map(ToOwned::to_owned); + let file_name = field.file_name().map(ToOwned::to_owned); + let content_type = field.content_type().map(ToOwned::to_owned); + let bytes = field + .bytes() + .await + .map_err(|e| MultipartApiError::Multipart(e.to_string()))?; + + match name.as_deref() { + Some("payload") => { + payload = Some( + serde_json::from_slice(&bytes) + .map_err(|_| MultipartApiError::InvalidPayload)?, + ); + } + Some("files") => { + attachment_metadata.push(ParsedAttachment { + file_name: file_name.clone(), + content_type: content_type.clone(), + size_bytes: bytes.len(), + }); + attachments.push((file_name, content_type, bytes)); + } + _ => {} + } + } + + let payload = payload.ok_or(MultipartApiError::MissingPayload)?; + + // Convert to typed request + let request = ImageClassificationRequest::from_multipart(payload.clone(), attachments)?; + let result = state + .inference(request) + .map(Json) + .map_err(|e| MultipartApiError::RequestConstruction(format!("Inference error: {}", e)))?; + + Ok(result) +} + +/// HttpRouter implementation for ImageClassification model type. +/// Combines standard model serving endpoints with multipart file upload capability. +impl super::HttpRouter for crate::runtime::AppState { + fn http_router(self) -> axum::Router { + axum::Router::new() + .route("/health", axum::routing::get(super::base::health)) + .route( + "/model", + axum::routing::get(super::base::get_model_metadata::), + ) + .route("/predict", axum::routing::post(predict_handler)) + .route("/openapi.json", axum::routing::get(standard_openapi)) + .route( + MULTIPART_PREDICT_ENDPOINT, + axum::routing::post(post_multipart_image_classification), + ) + .route(MULTIPART_OPENAPI_ENDPOINT, axum::routing::get(openapi)) + .with_state(self) + } +} + +/// Multipart handler specialized for ImageClassificationRequest. +async fn post_multipart_image_classification( + state: State>, + multipart: Multipart, +) -> Result, MultipartApiError> { + post_multipart_typed::(state, multipart).await +} + +/// Standard predict endpoint for ImageClassification. +async fn predict_handler( + State(state): State>, + Json(req): Json< as crate::services::Inference>::Input>, +) -> impl IntoResponse { + super::base::predict(State(state), Json(req)).await +} + +/// Standard OpenAPI endpoint for ImageClassification model service (without multipart). +async fn standard_openapi() -> impl IntoResponse { + Json(serde_json::json!({ + "openapi": "3.0.0", + "info": { + "title": "ImageClassification Model API", + "version": "1.0.0" + }, + "paths": { + "/health": { + "get": { + "responses": { + "200": { "description": "Successful" } + } + } + }, + "/model": { + "get": { + "responses": { + "200": { "description": "Successful" } + } + } + }, + "/predict": { + "post": { + "responses": { + "200": { "description": "Successful" } + } + } + } + } + })) +} diff --git a/encoderfile/src/transport/mcp/mod.rs b/encoderfile/src/transport/mcp/mod.rs index 4512c711..b9fb2241 100644 --- a/encoderfile/src/transport/mcp/mod.rs +++ b/encoderfile/src/transport/mcp/mod.rs @@ -3,6 +3,9 @@ use rmcp::transport::streamable_http_server::{ StreamableHttpService, session::local::LocalSessionManager, }; +use crate::common::model_type; +use crate::runtime::AppState; + mod error; pub trait McpRouter @@ -13,7 +16,7 @@ where const NEW_TOOL: fn(Self) -> Self::Tool; // TODO figure out the lifetimes of a state so a ref can be safely passed - fn mcp_router(self) -> axum::Router + fn mcp_router(self) -> Result where ::Tool: rmcp::ServerHandler, { @@ -23,7 +26,42 @@ where Default::default(), ); - axum::Router::new().nest_service("/mcp", service) + Ok(axum::Router::new().nest_service("/mcp", service)) + } +} + +pub struct DummyTool {} + +impl ServerHandler for DummyTool { + fn get_info(&self) -> rmcp::model::ServerInfo { + rmcp::model::ServerInfo { + protocol_version: rmcp::model::ProtocolVersion::V_2025_06_18, + capabilities: rmcp::model::ServerCapabilities::default(), + server_info: rmcp::model::Implementation::default(), + instructions: None, + } + } + + async fn initialize( + &self, + _request: rmcp::model::InitializeRequestParam, + _context: rmcp::service::RequestContext, + ) -> Result { + Err(rmcp::ErrorData { + code: rmcp::model::ErrorCode::INTERNAL_ERROR, + message: std::borrow::Cow::Borrowed("This is a dummy tool with no functionality."), + data: None, + }) + } +} + +impl McpRouter for AppState { + type Tool = DummyTool; + const NEW_TOOL: fn(Self) -> Self::Tool = |_state| Self::Tool {}; + fn mcp_router(self) -> Result { + Err(crate::error::ApiError::InternalError( + "MCP not implemented for ImageClassification model type", + )) } } @@ -151,3 +189,16 @@ generate_mcp!( "Performs sentence embedding of input text sequences.", "This tool will embed a sequence of texts." ); + +// Doesn't use a json schema, see how we can go around this limitation +/* +generate_mcp!( + ImageClassification, + ImageClassificationTool, + image_classification, + ImageClassificationRequest, + ImageClassificationResponse, + "Performs image classification of input images.", + "This tool will classify input images." +); +*/ diff --git a/encoderfile/src/transport/server.rs b/encoderfile/src/transport/server.rs index 6f6cda1f..15f7c13d 100644 --- a/encoderfile/src/transport/server.rs +++ b/encoderfile/src/transport/server.rs @@ -23,14 +23,14 @@ pub async fn run_grpc( "gRPC", state, |state| { - state + Ok(state .clone() .grpc_router() .layer( tower_http::trace::TraceLayer::new_for_grpc() .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)), ) - .into_make_service_with_connect_info::() + .into_make_service_with_connect_info::()) }, ) .await @@ -51,14 +51,14 @@ pub async fn run_http( "HTTP", state, |state| { - state + Ok(state .clone() .http_router() .layer( tower_http::trace::TraceLayer::new_for_http() .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)), ) - .into_make_service_with_connect_info::() + .into_make_service_with_connect_info::()) }, ) .await @@ -79,16 +79,15 @@ pub async fn run_mcp( "MCP", state, |state| { - state - .clone() - .mcp_router() + let mcp_router = state.clone().mcp_router()?; + Ok(mcp_router .layer( tower_http::trace::TraceLayer::new_for_http() // TODO check if otel is enabled // .make_span_with(crate::middleware::format_span) .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)), ) - .into_make_service_with_connect_info::() + .into_make_service_with_connect_info::()) }, ) .await @@ -101,11 +100,16 @@ async fn serve_with_optional_tls( maybe_key_file: Option, server_type_str: &str, state: S, - into_service_fn: impl Fn(&S) -> IntoMakeServiceWithConnectInfo, + into_service_fn: impl Fn( + &S, + ) -> Result< + IntoMakeServiceWithConnectInfo, + crate::error::ApiError, + >, ) -> Result<()> { let addr = format!("{}:{}", &hostname, &port); - let router = into_service_fn(&state); + let router = into_service_fn(&state)?; let model_type = state.model_type(); diff --git a/encoderfile/tests/test_grpc.rs b/encoderfile/tests/test_grpc.rs index 5cd95276..5927bb46 100644 --- a/encoderfile/tests/test_grpc.rs +++ b/encoderfile/tests/test_grpc.rs @@ -1,4 +1,6 @@ use std::collections::HashMap; +use std::fs::File; +use std::io::Read; use encoderfile::{ dev_utils::*, @@ -6,6 +8,11 @@ use encoderfile::{ embedding::{ EmbeddingRequest, EmbeddingResponse, embedding_inference_server::EmbeddingInference, }, + image_classification::{ + ImageClassificationRequest, ImageClassificationResponse, + image_classification_inference_server::ImageClassificationInference, + }, + image_types::ImageInput, metadata::{GetModelMetadataRequest, GetModelMetadataResponse}, sentence_embedding::{ SentenceEmbeddingRequest, SentenceEmbeddingResponse, @@ -46,8 +53,6 @@ macro_rules! test_grpc_service { .unwrap() .into_inner(); - println!("Model metadata: {:?}", response); - if $has_labels { assert!(!response.id2label.is_empty(), "id2label is an empty dict") } else { @@ -140,3 +145,29 @@ test_grpc_service!( }, SentenceEmbeddingResponse ); + +const TEST_IMAGE_PATH: &str = "../test-pictures/yoga01.jpg"; + +fn get_file_bytes(filename: &str) -> Vec { + let mut file = File::open(filename).expect("Failed to open test image"); + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer) + .expect("Failed to read test image"); + buffer +} + +test_grpc_service!( + image_classification_tests, + { GrpcService::new(image_classification_state()) }, + true, + ImageClassificationRequest { + inputs: [TEST_IMAGE_PATH, TEST_IMAGE_PATH] + .iter() + .map(|s| ImageInput { + image: get_file_bytes(s) + }) + .collect(), + metadata: HashMap::new(), + }, + ImageClassificationResponse +); diff --git a/encoderfile/tests/test_http.rs b/encoderfile/tests/test_http.rs index 441d001b..77d575ae 100644 --- a/encoderfile/tests/test_http.rs +++ b/encoderfile/tests/test_http.rs @@ -130,3 +130,81 @@ test_router_mod!( metadata: None, } ); + +mod image_classification_tests { + use axum::http::{Request, StatusCode}; + use encoderfile::{dev_utils, transport::http::HttpRouter}; + use tower::ServiceExt; + + fn router() -> axum::Router { + let state = dev_utils::image_classification_state(); + state.http_router() + } + + #[tokio::test] + async fn test_predict_route() { + let router = router(); + let img_loc1 = "../test-pictures/yoga01.jpg"; + let img_loc2 = "../test-pictures/yoga02.jpg"; + let img_bytes1 = std::fs::read(img_loc1).unwrap(); + let img_bytes2 = std::fs::read(img_loc2).unwrap(); + let payload = serde_json::json!({ + "inputs": ["yoga01.jpg", "yoga02.jpg"], + "metadata": {} + }); + + let boundary = "----encoderfile-boundary"; + let mut multipart_body = Vec::new(); + + multipart_body.extend_from_slice( + format!( + "--{boundary}\r\nContent-Disposition: form-data; name=\"payload\"\r\nContent-Type: application/json\r\n\r\n{}\r\n", + payload + ) + .as_bytes(), + ); + + multipart_body.extend_from_slice( + format!( + "--{boundary}\r\nContent-Disposition: form-data; name=\"files\"; filename=\"yoga01.jpg\"\r\nContent-Type: image/jpeg\r\n\r\n" + ) + .as_bytes(), + ); + multipart_body.extend_from_slice(&img_bytes1); + multipart_body.extend_from_slice(b"\r\n"); + + multipart_body.extend_from_slice( + format!( + "--{boundary}\r\nContent-Disposition: form-data; name=\"files\"; filename=\"yoga02.jpg\"\r\nContent-Type: image/jpeg\r\n\r\n" + ) + .as_bytes(), + ); + multipart_body.extend_from_slice(&img_bytes2); + multipart_body.extend_from_slice(b"\r\n"); + + multipart_body.extend_from_slice(format!("--{boundary}--\r\n").as_bytes()); + + let request = Request::post("/predict/multipart") + .header( + "Content-Type", + format!("multipart/form-data; boundary={boundary}"), + ) + .body(axum::body::Body::from(multipart_body)) + .unwrap(); + + let resp = router.oneshot(request).await.unwrap(); + + if resp.status() != StatusCode::OK { + panic!("{} {:#?}", resp.status(), resp.body()) + } + + assert_eq!(resp.status(), StatusCode::OK); + + // gather the body into a single bytes object and convert it into a string for easier debugging if the test fails + let body_bytes = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_string = String::from_utf8(body_bytes.to_vec()).unwrap(); + println!("Response body: {}", body_string); + } +} diff --git a/encoderfile/tests/test_mcp.rs b/encoderfile/tests/test_mcp.rs index 55ac6ab0..864181f1 100644 --- a/encoderfile/tests/test_mcp.rs +++ b/encoderfile/tests/test_mcp.rs @@ -1,12 +1,13 @@ use anyhow::Result; use encoderfile::AppState; use encoderfile::common::model_type::ModelTypeSpec; +use encoderfile::runtime::{InputType, TaskType}; use encoderfile::transport::mcp::McpRouter; use tokio::net::TcpListener; use tokio::sync::oneshot; use tower_http::trace::DefaultOnResponse; -async fn run_mcp( +async fn run_mcp( addr: String, state: AppState, shutdown_receiver: oneshot::Receiver<()>, @@ -15,7 +16,7 @@ async fn run_mcp( where AppState: McpRouter, { - let router = state.mcp_router().layer( + let router = state.mcp_router()?.layer( tower_http::trace::TraceLayer::new_for_http() // TODO check if otel is enabled // .make_span_with(crate::middleware::format_span) diff --git a/encoderfile/tests/test_model_validation.rs b/encoderfile/tests/test_model_validation.rs index 25d971b4..edb9aa64 100644 --- a/encoderfile/tests/test_model_validation.rs +++ b/encoderfile/tests/test_model_validation.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use encoderfile::{builder::model::ModelTypeExt as _, common::ModelType}; +use encoderfile::{builder::model::ModelTypeExt as _, common::model_type::ModelType}; #[test] pub fn test_embedding() { @@ -45,3 +45,15 @@ pub fn test_sequence_classification() { .is_ok() ); } + +#[test] +pub fn test_image_classification() { + let path = PathBuf::from("../models/image_classification/model.onnx"); + + assert!(ModelType::ImageClassification.validate_model(&path).is_ok()); + assert!( + ModelType::TokenClassification + .validate_model(&path) + .is_err() + ); +} diff --git a/encoderfile/tests/test_models.rs b/encoderfile/tests/test_models.rs index 44718dc0..fed03a82 100644 --- a/encoderfile/tests/test_models.rs +++ b/encoderfile/tests/test_models.rs @@ -10,6 +10,7 @@ fn test_embedding_model() { let state = embedding_state(); let encodings = state + .model_input_state .tokenizer .encode_text(vec![ "hello world".to_string(), @@ -34,6 +35,7 @@ fn test_embedding_inference_with_bad_model() { let state = token_classification_state(); let encodings = state + .model_input_state .tokenizer .encode_text(vec![ "hello world".to_string(), @@ -54,6 +56,7 @@ fn test_sequence_classification_model() { let state = sequence_classification_state(); let encodings = state + .model_input_state .tokenizer .encode_text(vec![ "hello world".to_string(), @@ -69,7 +72,7 @@ fn test_sequence_classification_model() { let results = sequence_classification( session_lock, &transform, - &state.model_config, + &state.task_state, encodings.clone(), ) .expect("Failed to compute results"); @@ -77,12 +80,15 @@ fn test_sequence_classification_model() { assert!(results.len() == encodings.len()); } +// FIXME doesn't compile +/* #[test] #[should_panic] fn test_sequence_classification_inference_with_bad_model() { let state = embedding_state(); let encodings = state + .per_model_input_state .tokenizer .encode_text(vec![ "hello world".to_string(), @@ -98,17 +104,19 @@ fn test_sequence_classification_inference_with_bad_model() { sequence_classification( session_lock, &transform, - &state.model_config, + &state.per_task_state, encodings.clone(), ) .expect("Failed to compute results"); } +*/ #[test] fn test_token_classification_model() { let state = token_classification_state(); let encodings = state + .model_input_state .tokenizer .encode_text(vec![ "hello world".to_string(), @@ -124,7 +132,7 @@ fn test_token_classification_model() { let results = token_classification( session_lock, &transform, - &state.model_config, + &state.task_state, encodings.clone(), ) .expect("Failed to compute results"); @@ -138,6 +146,7 @@ fn test_token_classification_inference_with_bad_model() { let state = sequence_classification_state(); let encodings = state + .model_input_state .tokenizer .encode_text(vec![ "hello world".to_string(), @@ -153,8 +162,35 @@ fn test_token_classification_inference_with_bad_model() { token_classification( session_lock, &transform, - &state.model_config, + &state.task_state, encodings.clone(), ) .expect("Failed to compute results"); } + +#[test] +fn test_image_classification_model() { + // TODO + /* + let state = embedding_state(); + + let encodings = state + .per_model_input_state + .tokenizer + .encode_text(vec![ + "hello world".to_string(), + "the quick brown fox jumps over the lazy dog".to_string(), + ]) + .expect("Failed to encode text"); + + let session_lock = state.session.lock(); + + let transform = + Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform"); + + let results = + embedding(session_lock, &transform, encodings.clone()).expect("Failed to compute results"); + + assert!(results.len() == encodings.len()); + */ +} diff --git a/models/image_classification/config.json b/models/image_classification/config.json new file mode 100644 index 00000000..e2309f2a --- /dev/null +++ b/models/image_classification/config.json @@ -0,0 +1,45 @@ +{ + "_name_or_path": "dima806/yoga_pose_image_classification", + "architectures": [ + "ViTForImageClassification" + ], + "attention_probs_dropout_prob": 0.0, + "encoder_stride": 16, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 768, + "id2label": { + "0": "Bridge", + "1": "Child", + "2": "Cobra", + "3": "Downward-Dog", + "4": "Pigeon", + "5": "Standing-Mountain", + "6": "Tree", + "7": "Triangle", + "8": "Warrior" + }, + "image_size": 224, + "initializer_range": 0.02, + "intermediate_size": 3072, + "label2id": { + "Bridge": 0, + "Child": 1, + "Cobra": 2, + "Downward-Dog": 3, + "Pigeon": 4, + "Standing-Mountain": 5, + "Tree": 6, + "Triangle": 7, + "Warrior": 8 + }, + "layer_norm_eps": 1e-12, + "model_type": "vit", + "num_attention_heads": 12, + "num_channels": 3, + "num_hidden_layers": 12, + "patch_size": 16, + "problem_type": "single_label_classification", + "qkv_bias": true, + "transformers_version": "4.37.2" +} diff --git a/models/image_classification/preprocessor_config.json b/models/image_classification/preprocessor_config.json new file mode 100644 index 00000000..02018dec --- /dev/null +++ b/models/image_classification/preprocessor_config.json @@ -0,0 +1,22 @@ +{ + "do_normalize": true, + "do_rescale": true, + "do_resize": true, + "image_mean": [ + 0.5, + 0.5, + 0.5 + ], + "image_processor_type": "ViTFeatureExtractor", + "image_std": [ + 0.5, + 0.5, + 0.5 + ], + "resample": 2, + "rescale_factor": 0.00392156862745098, + "size": { + "height": 224, + "width": 224 + } +} diff --git a/test-pictures/yoga01.jpg b/test-pictures/yoga01.jpg new file mode 100644 index 00000000..ecbec5f7 Binary files /dev/null and b/test-pictures/yoga01.jpg differ diff --git a/test-pictures/yoga02.jpg b/test-pictures/yoga02.jpg new file mode 100644 index 00000000..bcf54057 Binary files /dev/null and b/test-pictures/yoga02.jpg differ diff --git a/test_img_class_config.yml b/test_img_class_config.yml new file mode 100644 index 00000000..8653088f --- /dev/null +++ b/test_img_class_config.yml @@ -0,0 +1,8 @@ +# optimum-cli export onnx --model dima806/yoga_pose_image_classification --task image-classification ./yoga +encoderfile: + name: test-img-class + path: models/image_classification + model_type: image_classification + output_path: ./test-img-class.encoderfile + base_binary_path: ./target/debug/encoderfile-runtime + no_download: true