diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 5a7eaa8ecde..2766bb4ddb8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -38,17 +38,17 @@ jobs: uses: actions-rs/cargo@v1 with: command: check - args: --all --verbose --locked + args: --workspace --verbose --locked - name: Build uses: actions-rs/cargo@v1 with: command: build - args: --all --verbose --locked + args: --workspace --verbose --locked - name: Test uses: actions-rs/cargo@v1 with: command: test - args: --all --verbose --locked + args: --workspace --verbose --locked compile-to-wasm: name: Compile to WebAssembly @@ -77,15 +77,18 @@ jobs: target: wasm32-unknown-unknown - name: Compile to WebAssembly uses: actions-rs/cargo@v1 - continue-on-error: true with: command: xtask - args: dist --out-dir target/proc-blocks + args: dist --keep-going env: RUST_LOG: xtask=debug + - name: Check all proc-blocks can be loaded + uses: actions-rs/cargo@v1 + with: + command: xtask + args: doc target/proc-blocks/*.wasm - name: Save Compiled proc-blocks uses: actions/upload-artifact@v2 - continue-on-error: true with: name: compiled-proc-blocks path: target/proc-blocks diff --git a/.rustfmt.toml b/.rustfmt.toml index eaf410e5b1e..6e88a104687 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -11,4 +11,4 @@ use_try_shorthand = true normalize_doc_attributes = true report_todo = "Always" report_fixme = "Always" -edition = "2018" +edition = "2021" diff --git a/Cargo.lock b/Cargo.lock index 52615f63865..6af8e03dfb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6,7 +6,6 @@ version = 3 name = "accuracy" version = "0.12.2" dependencies = [ - "getrandom", "hotg-rune-proc-blocks", "smartcore", "wit-bindgen-rust", @@ -28,12 +27,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] -name = "aho-corasick" -version = "0.7.18" +name = "adler32" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" + +[[package]] +name = "ahash" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" dependencies = [ - "memchr", + "getrandom", + "once_cell", + "version_check", ] [[package]] @@ -47,9 +54,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4361135be9122e0870de935d7c439aef945b9f9ddd4199a553b5270b49c82a27" +checksum = "08f9b8508dccb7687a1d6c4ce66b2b0ecef467c94667de27d8d7fe1f8d2a9cdc" [[package]] name = "apodize" @@ -78,9 +85,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.53" +version = "0.1.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed6aa3524a2dfcf9fe180c51eae2b58738348d819517ceadf95789c51fff7600" +checksum = "96cf8829f67d2eab0b2dfa42c5d0ef737e0724e4a82b01b3e292456202b19716" dependencies = [ "proc-macro2", "quote", @@ -115,9 +122,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.64" +version = "0.3.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e121dee8023ce33ab248d9ce1493df03c3b38a659b240096fcbd7048ff9c31f" +checksum = "11a17d453482a265fd5f8479f2a3f405566e6ca627837aaddb85af8b1ab8ef61" dependencies = [ "addr2line", "cc", @@ -128,12 +135,6 @@ dependencies = [ "rustc-demangle", ] -[[package]] -name = "base64" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" - [[package]] name = "binary_classification" version = "0.12.0" @@ -142,15 +143,6 @@ dependencies = [ "wit-bindgen-rust", ] -[[package]] -name = "bincode" -version = "1.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" -dependencies = [ - "serde", -] - [[package]] name = "bitflags" version = "1.3.2" @@ -158,19 +150,49 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] -name = "block-buffer" -version = "0.9.0" +name = "bumpalo" +version = "3.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37ccbd214614c6783386c1af30caf03192f17891059cecc394b4fb119e363de3" + +[[package]] +name = "bytecheck" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a31f923c2db9513e4298b72df143e6e655a759b3d6a0966df18f81223fff54f" +dependencies = [ + "bytecheck_derive", + "ptr_meta", +] + +[[package]] +name = "bytecheck_derive" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +checksum = "edb17c862a905d912174daa27ae002326fff56dc8b8ada50a0a5f0976cb174f0" dependencies = [ - "generic-array", + "proc-macro2", + "quote", + "syn", ] +[[package]] +name = "bytemuck" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdead85bdec19c194affaeeb670c0e41fe23de31459efd1c174d049269cf02cc" + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + [[package]] name = "camino" -version = "1.0.7" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f3132262930b0522068049f5870a856ab8affc80c70d08b6ecb785771a6fc23" +checksum = "869119e97797867fd90f5e22af7d0bd274bd4635ebb9eb68c04f3f513ae6c412" dependencies = [ "serde", ] @@ -202,9 +224,6 @@ name = "cc" version = "1.0.73" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" -dependencies = [ - "jobserver", -] [[package]] name = "cfg-if" @@ -228,37 +247,38 @@ dependencies = [ ] [[package]] -name = "cpp_demangle" -version = "0.3.5" +name = "color_quant" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eeaa953eaad386a53111e47172c2fedba671e5684c8dd601a5f474f4f118710f" -dependencies = [ - "cfg-if", -] +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] -name = "cpufeatures" -version = "0.2.2" +name = "corosensei" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59a6001667ab124aebae2a495118e11d30984c3a653e99d86d58971708cf5e4b" +checksum = "ab4b310cff9117ec16d05970743c20df3eaddafd461829f2758e76a8de2863a9" dependencies = [ + "autocfg", + "cfg-if", "libc", + "scopeguard", + "windows-sys", ] [[package]] name = "cranelift-bforest" -version = "0.81.2" +version = "0.82.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eba0f73ab0da95f5d3bd5161da14edc586a88aeae1d09e4a0924f7a141a0093" +checksum = "38faa2a16616c8e78a18d37b4726b98bfd2de192f2fdc8a39ddf568a408a0f75" dependencies = [ "cranelift-entity", ] [[package]] name = "cranelift-codegen" -version = "0.81.2" +version = "0.82.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9cff8758662518d743460f32c3ca6f32d726070af612c19ba92d01ea727e6d9" +checksum = "26f192472a3ba23860afd07d2b0217dc628f21fcc72617aa1336d98e1671f33b" dependencies = [ "cranelift-bforest", "cranelift-codegen-meta", @@ -273,33 +293,30 @@ dependencies = [ [[package]] name = "cranelift-codegen-meta" -version = "0.81.2" +version = "0.82.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfc82fef9d470dd617c4d2537d8f4146d82526bb3bc3ef35b599a3978dad8c81" +checksum = "0f32ddb89e9b89d3d9b36a5b7d7ea3261c98235a76ac95ba46826b8ec40b1a24" dependencies = [ "cranelift-codegen-shared", ] [[package]] name = "cranelift-codegen-shared" -version = "0.81.2" +version = "0.82.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06f531b6173eb2fd92d9a9b2a0dbb2450079f913040bdc323ec43ec752b7e44" +checksum = "01fd0d9f288cc1b42d9333b7a776b17e278fc888c28e6a0f09b5573d45a150bc" [[package]] name = "cranelift-entity" -version = "0.81.2" +version = "0.82.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d84f8e8a408071d67f479a00c6d3da965b1f9b4b240b7e7e27edb1a34401b3cd" -dependencies = [ - "serde", -] +checksum = "9e3bfe172b83167604601faf9dc60453e0d0a93415b57a9c4d1a7ae6849185cf" [[package]] name = "cranelift-frontend" -version = "0.81.2" +version = "0.82.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72cc22592c10f1fa6664a55e34ec52593125a94176856d3ec2f7af5664374da1" +checksum = "a006e3e32d80ce0e4ba7f1f9ddf66066d052a8c884a110b91d05404d6ce26dce" dependencies = [ "cranelift-codegen", "log", @@ -307,33 +324,6 @@ dependencies = [ "target-lexicon", ] -[[package]] -name = "cranelift-native" -version = "0.81.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3da723ebbee69f348feb49acc9f6f5b7ad668c04a145abbc7a75b669f9b0afd" -dependencies = [ - "cranelift-codegen", - "libc", - "target-lexicon", -] - -[[package]] -name = "cranelift-wasm" -version = "0.81.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "642c30e1600295e9c58fc349376187831dce1df6822ece7e8ab880010d6e4be2" -dependencies = [ - "cranelift-codegen", - "cranelift-entity", - "cranelift-frontend", - "itertools", - "log", - "smallvec", - "wasmparser 0.82.0", - "wasmtime-types", -] - [[package]] name = "crc32fast" version = "1.3.2" @@ -399,41 +389,54 @@ dependencies = [ ] [[package]] -name = "diff" -version = "0.1.12" +name = "darling" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e25ea47919b1560c4e3b7fe0aaab9becf5b84a10325ddf7db0f0ba5e1026499" +checksum = "a01d95850c592940db9b8194bc39f4bc0e89dee5c4265e4b1807c34a9aba453c" +dependencies = [ + "darling_core", + "darling_macro", +] [[package]] -name = "digest" -version = "0.9.0" +name = "darling_core" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +checksum = "859d65a907b6852c9361e3185c862aae7fafd2887876799fa55f5f99dc40d610" dependencies = [ - "generic-array", + "fnv", + "ident_case", + "proc-macro2", + "quote", + "syn", ] [[package]] -name = "directories-next" -version = "2.0.0" +name = "darling_macro" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "339ee130d97a610ea5a5872d2bbb130fdf68884ff09d3028b81bec8a1ac23bbc" +checksum = "9c972679f83bdf9c42bd905396b6c3588a843a17f0f16dfcfa3e2c5d57441835" dependencies = [ - "cfg-if", - "dirs-sys-next", + "darling_core", + "quote", + "syn", ] [[package]] -name = "dirs-sys-next" -version = "0.1.2" +name = "deflate" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +checksum = "c86f7e25f518f4b81808a2cf1c50996a61f5c2eb394b2393bd87f2a4780a432f" dependencies = [ - "libc", - "redox_users", - "winapi", + "adler32", ] +[[package]] +name = "diff" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e25ea47919b1560c4e3b7fe0aaab9becf5b84a10325ddf7db0f0ba5e1026499" + [[package]] name = "either" version = "1.6.1" @@ -444,51 +447,56 @@ checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" name = "elastic_net" version = "0.12.1" dependencies = [ - "getrandom", "hotg-rune-proc-blocks", "smartcore", "wit-bindgen-rust", ] [[package]] -name = "env_logger" -version = "0.9.0" +name = "enum-iterator" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b2cf0344971ee6c64c31be0d530793fba457d322dfec2810c453d0ef228f9c3" +checksum = "4eeac5c5edb79e4e39fe8439ef35207780a11f69c52cbe424ce3dfad4cb78de6" dependencies = [ - "atty", - "humantime", - "log", - "regex", - "termcolor", + "enum-iterator-derive", ] [[package]] -name = "errno" -version = "0.2.8" +name = "enum-iterator-derive" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +checksum = "c134c37760b27a871ba422106eedbb8247da973a09e82558bf26d619c882b159" dependencies = [ - "errno-dragonfly", - "libc", - "winapi", + "proc-macro2", + "quote", + "syn", ] [[package]] -name = "errno-dragonfly" -version = "0.1.2" +name = "enumset" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +checksum = "4799cdb24d48f1f8a7a98d06b7fde65a85a2d1e42b25a889f5406aa1fbefe074" dependencies = [ - "cc", - "libc", + "enumset_derive", +] + +[[package]] +name = "enumset_derive" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea83a3fbdc1d999ccfbcbee717eab36f8edf2d71693a23ce0d7cca19e085304c" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", ] [[package]] name = "f1-score" version = "0.12.0" dependencies = [ - "getrandom", "hotg-rune-proc-blocks", "smartcore", "wit-bindgen-rust", @@ -514,47 +522,50 @@ name = "fft" version = "0.12.0" dependencies = [ "hotg-rune-proc-blocks", - "hound", - "libm", "mel", "nalgebra", - "normalize", "pretty_assertions", "sonogram", "wit-bindgen-rust", ] [[package]] -name = "file-per-thread-logger" -version = "0.1.5" +name = "flate2" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21e16290574b39ee41c71aeb90ae960c504ebaf1e2a1c87bd52aa56ed6e1a02f" +checksum = "f82b0f4c27ad9f8bfd1f3208d882da2b09c301bc1c828fd3a00d0216d2fbbff6" dependencies = [ - "env_logger", - "log", + "crc32fast", + "miniz_oxide", ] [[package]] -name = "generic-array" -version = "0.14.5" +name = "fnv" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803" -dependencies = [ - "typenum", - "version_check", -] +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "getrandom" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9be70c98951c83b8d2f8f60d7065fa6d5146873094452a1008da8c2f1e4205ad" +checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" dependencies = [ "cfg-if", "libc", "wasi", ] +[[package]] +name = "gif" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3a7187e78088aead22ceedeee99779455b23fc231fe13ec443f99bb71694e5b" +dependencies = [ + "color_quant", + "weezl", +] + [[package]] name = "gimli" version = "0.26.1" @@ -571,6 +582,18 @@ name = "hashbrown" version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashbrown" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db0d4cf898abf0081f964436dc980e96670a0f36863e4b83aaacdb65c9d7ccc3" +dependencies = [ + "ahash", +] [[package]] name = "heck" @@ -608,30 +631,45 @@ dependencies = [ name = "hotg-rune-proc-blocks" version = "0.1.0" dependencies = [ + "bytemuck", "getrandom", "ndarray", - "once_cell", - "rand", + "thiserror", + "tracing", + "tracing-subscriber", "wit-bindgen-rust", ] [[package]] -name = "hound" -version = "3.4.0" +name = "id-arena" +version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a164bb2ceaeff4f42542bdb847c41517c78a60f5649671b2a07312b6e117549" +checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" [[package]] -name = "humantime" -version = "2.1.0" +name = "ident_case" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] -name = "id-arena" -version = "2.2.1" +name = "image" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" +checksum = "28edd9d7bc256be2502e325ac0628bde30b7001b9b52e0abe31a1a9dc2701212" +dependencies = [ + "bytemuck", + "byteorder", + "color_quant", + "gif", + "jpeg-decoder", + "num-iter", + "num-rational", + "num-traits", + "png", + "scoped_threadpool", + "tiff", +] [[package]] name = "image-normalization" @@ -643,21 +681,23 @@ dependencies = [ ] [[package]] -name = "image_input" +name = "image_decode" version = "0.12.0" dependencies = [ "hotg-rune-proc-blocks", + "image", + "strum", "wit-bindgen-rust", ] [[package]] name = "indexmap" -version = "1.8.1" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f647032dfaa1f8b6dc29bd3edb7bbef4861b8b8007ebb118d6db284fd59f6ee" +checksum = "e6012d540c5baa3589337a98ce73408de9b5a25ec9fc2c6fd6be8f0d39e0ca5a" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.11.2", "serde", ] @@ -670,12 +710,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "io-lifetimes" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec58677acfea8a15352d42fc87d11d63596ade9239e0a7c9352914417515dbe6" - [[package]] name = "itertools" version = "0.10.3" @@ -687,17 +721,26 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35" +checksum = "112c678d4050afce233f4f2852bb2eb519230b3cf12f33585275537d7e41578d" [[package]] -name = "jobserver" -version = "0.1.24" +name = "jpeg-decoder" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af25a77299a7f711a01975c35a6a424eb6862092cc2d6c72c4ed6cbc56dfc1fa" +checksum = "9478aa10f73e7528198d75109c8be5cd7d15fb530238040148d5f9a22d4c5b3b" dependencies = [ - "libc", + "rayon", +] + +[[package]] +name = "js-sys" +version = "0.3.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3fac17f7123a73ca62df411b1bf727ccc805daa070338fda671c86dac1bdc27" +dependencies = [ + "wasm-bindgen", ] [[package]] @@ -726,9 +769,19 @@ checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67" [[package]] name = "libc" -version = "0.2.121" +version = "0.2.126" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" + +[[package]] +name = "libloading" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efaa7b300f3b5fe8eb6bf21ce3895e1751d9665086af2d64b42f19701015ff4f" +checksum = "efbc0f03f9a775e9f6aed295c6a1ba2253c5757a9e03d55c6caa46a681abcddd" +dependencies = [ + "cfg-if", + "winapi", +] [[package]] name = "libm" @@ -746,23 +799,16 @@ checksum = "806de604a37f6d73a83c850af7b3ba33a44f330d12fd5e4ac216645b54da912b" name = "linear_regression" version = "0.12.1" dependencies = [ - "getrandom", "hotg-rune-proc-blocks", "smartcore", "wit-bindgen-rust", ] -[[package]] -name = "linux-raw-sys" -version = "0.0.42" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5284f00d480e1c39af34e72f8ad60b94f47007e3481cd3b731c1d67190ddc7b7" - [[package]] name = "log" -version = "0.4.16" +version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6389c490849ff5bc16be905ae24bc913a9c8892e19b2341dbc175e14c341c2b8" +checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" dependencies = [ "cfg-if", ] @@ -771,12 +817,32 @@ dependencies = [ name = "logistic_regression" version = "0.12.5" dependencies = [ - "getrandom", "hotg-rune-proc-blocks", "smartcore", "wit-bindgen-rust", ] +[[package]] +name = "loupe" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6a72dfa44fe15b5e76b94307eeb2ff995a8c5b283b55008940c02e0c5b634d" +dependencies = [ + "indexmap", + "loupe-derive", + "rustversion", +] + +[[package]] +name = "loupe-derive" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0fbfc88337168279f2e9ae06e157cfed4efd3316e14dc96ed074d4f2e6c5952" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "mach" version = "0.3.2" @@ -817,9 +883,18 @@ dependencies = [ [[package]] name = "memchr" -version = "2.4.1" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" + +[[package]] +name = "memmap2" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" +checksum = "d5172b50c23043ff43dd53e51392f36519d9b35a8f3a410d30ece5d1aedd58ae" +dependencies = [ + "libc", +] [[package]] name = "memoffset" @@ -832,12 +907,11 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.4.4" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b" +checksum = "6f5c75688da582b8ffc1f1799e9db273f32133c49e048f614d22ec3256773ccc" dependencies = [ "adler", - "autocfg", ] [[package]] @@ -936,18 +1010,18 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +checksum = "97fbc387afefefd5e9e39493299f3069e14a140dd34dc19b4c1c1a8fddb6a790" dependencies = [ "num-traits", ] [[package]] name = "num-integer" -version = "0.1.44" +version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" dependencies = [ "autocfg", "num-traits", @@ -955,9 +1029,9 @@ dependencies = [ [[package]] name = "num-iter" -version = "0.1.42" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2021c8337a54d21aca0d59a92577a029af9431cb59b909b03252b9c164fad59" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" dependencies = [ "autocfg", "num-integer", @@ -978,9 +1052,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", "libm", @@ -998,11 +1072,12 @@ dependencies = [ [[package]] name = "object" -version = "0.27.1" +version = "0.28.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67ac1d3f9a1d3616fd9a60c8d74296f22406a238b6a72f5cc1e6f314df4ffbf9" +checksum = "e42c982f2d955fac81dd7e1d0e1426a7d702acd9c98d19ab01083a6a0328c424" dependencies = [ "crc32fast", + "hashbrown 0.11.2", "indexmap", "memchr", ] @@ -1022,12 +1097,6 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7709cef83f0c1f58f666e746a08b21e0085f7440fa6a29cc194d68aac97a4225" -[[package]] -name = "opaque-debug" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" - [[package]] name = "output_vt100" version = "0.1.3" @@ -1037,14 +1106,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "parse" -version = "0.12.0" -dependencies = [ - "hotg-rune-proc-blocks", - "wit-bindgen-rust", -] - [[package]] name = "password_strength" version = "0.12.0" @@ -1061,9 +1122,21 @@ checksum = "0c520e05135d6e763148b6426a837e239041653ba7becd2e538c076c738025fc" [[package]] name = "pin-project-lite" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e280fbe77cc62c91527259e9442153f4688736748d24660126286329742b4c6c" +checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" + +[[package]] +name = "png" +version = "0.17.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc38c0ad57efb786dd57b9864e5b18bae478c00c824dc55a38bbc9da95dde3ba" +dependencies = [ + "bitflags", + "crc32fast", + "deflate", + "miniz_oxide", +] [[package]] name = "ppv-lite86" @@ -1075,7 +1148,6 @@ checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" name = "prediction_errors" version = "0.12.0" dependencies = [ - "getrandom", "hotg-rune-proc-blocks", "smartcore", "wit-bindgen-rust", @@ -1119,20 +1191,31 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.37" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec757218438d5fda206afc041538b2f6d889286160d649a86a24d37e1235afd1" +checksum = "c54b25569025b7fc9651de43004ae593a75ad88543b17178aa5e1b9c4f15f56f" dependencies = [ - "unicode-xid", + "unicode-ident", ] [[package]] -name = "psm" -version = "0.1.18" +name = "ptr_meta" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "871372391786ccec00d3c5d3d6608905b3d4db263639cfe075d3b60a736d115a" +checksum = "0738ccf7ea06b608c10564b31debd4f5bc5e197fc8bfe088f68ae5ce81e7a4f1" dependencies = [ - "cc", + "ptr_meta_derive", +] + +[[package]] +name = "ptr_meta_derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b845dbfca988fa33db069c0e230574d15a3088f147a87b64c7589eb662c9ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -1148,9 +1231,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "632d02bff7f874a36f33ea8bb416cd484b90cc66c1194b1a1110d067a7013f58" +checksum = "a1feb54ed693b93a84e14094943b84b7c4eae204c512b7ccb95ab0c66d278ad1" dependencies = [ "proc-macro2", ] @@ -1203,9 +1286,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.5.1" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90" +checksum = "bd99e5772ead8baa5215278c9b15bf92087709e9c1b2d1f97cdb5a183c933a7d" dependencies = [ "autocfg", "crossbeam-deque", @@ -1215,14 +1298,13 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.9.1" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e" +checksum = "258bcdb5ac6dad48491bb2992db6b7cf74878b0384908af124823d118c99683f" dependencies = [ "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "lazy_static", "num_cpus", ] @@ -1235,17 +1317,6 @@ dependencies = [ "bitflags", ] -[[package]] -name = "redox_users" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" -dependencies = [ - "getrandom", - "redox_syscall", - "thiserror", -] - [[package]] name = "regalloc" version = "0.0.34" @@ -1259,12 +1330,10 @@ dependencies = [ [[package]] name = "regex" -version = "1.5.5" +version = "1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a11647b6b25ff05a515cb92c365cec08801e83423a235b51e231e1808747286" +checksum = "d83f127d94bdbcda4c8cc2e50f6f84f4b611f69c902699ca385a39c3a75f9ff1" dependencies = [ - "aho-corasick", - "memchr", "regex-syntax", ] @@ -1279,15 +1348,15 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.25" +version = "0.6.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" +checksum = "49b3de9ec5dc0a3417da371aab17d729997c15010e7fd24ff707773a33bddb64" [[package]] name = "region" -version = "2.2.0" +version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877e54ea2adcd70d80e9179344c97f93ef0dffd6b03e1f4529e6e83ab2fa9ae0" +checksum = "76e189c2369884dce920945e2ddf79b3dff49e071a167dd1817fa9c4c00d512e" dependencies = [ "bitflags", "libc", @@ -1305,10 +1374,44 @@ dependencies = [ ] [[package]] -name = "rustc-demangle" -version = "0.1.21" +name = "rend" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" +checksum = "79af64b4b6362ffba04eef3a4e10829718a4896dac19daa741851c86781edf95" +dependencies = [ + "bytecheck", +] + +[[package]] +name = "rkyv" +version = "0.7.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "517a3034eb2b1499714e9d1e49b2367ad567e07639b69776d35e259d9c27cca6" +dependencies = [ + "bytecheck", + "hashbrown 0.12.1", + "ptr_meta", + "rend", + "rkyv_derive", + "seahash", +] + +[[package]] +name = "rkyv_derive" +version = "0.7.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "505c209ee04111a006431abf39696e640838364d67a107c559ababaf6fd8c9dd" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" [[package]] name = "rustc-hash" @@ -1317,24 +1420,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] -name = "rustix" -version = "0.33.5" +name = "rustversion" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03627528abcc4a365554d32a9f3bbf67f7694c102cfeda792dc86a2d6057cc85" -dependencies = [ - "bitflags", - "errno", - "io-lifetimes", - "libc", - "linux-raw-sys", - "winapi", -] +checksum = "f2cc38e8fa666e2de3c4aba7edeb5ffc5246c1c2ed0e3d17e560aeeba736b23f" [[package]] name = "ryu" -version = "1.0.9" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3f6f92acf49d1b98f7a81226834412ada05458b7364277387724a237f062695" + +[[package]] +name = "scoped_threadpool" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f" +checksum = "1d51f5df5af43ab3f1360b429fa5e0152ac5ce8c0bd6485cae490332e96846a8" [[package]] name = "scopeguard" @@ -1342,6 +1443,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "seahash" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" + [[package]] name = "segment_output" version = "0.12.0" @@ -1352,27 +1459,36 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.7" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d65bd28f48be7196d222d95b9243287f48d27aca604e08497513019ff0502cc4" +checksum = "a41d061efea015927ac527063765e73601444cdc344ba855bc7bd44578b25e1c" dependencies = [ "serde", ] [[package]] name = "serde" -version = "1.0.136" +version = "1.0.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce31e24b01e1e524df96f1c2fdd054405f8d7376249a5110886fb4b658484789" +checksum = "61ea8d54c77f8315140a05f4c7237403bf38b72704d031543aa1d16abbf517d1" dependencies = [ "serde_derive", ] +[[package]] +name = "serde_bytes" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "212e73464ebcde48d723aa02eb270ba62eff38a9b732df31f33f1b4e145f3a54" +dependencies = [ + "serde", +] + [[package]] name = "serde_derive" -version = "1.0.136" +version = "1.0.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08597e7152fcd306f41838ed3e37be9eaeed2b61c42e2117266a554fab4662f9" +checksum = "1f26faba0c3959972377d3b2d306ee9f71faee9714294e41bb777f83f88578be" dependencies = [ "proc-macro2", "quote", @@ -1381,28 +1497,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.79" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e8d9fa5c3b304765ce1fd9c4c8a3de2c8db365a5b91be52f186efc675681d95" +checksum = "9b7ce2b32a1aed03c558dc61a5cd328f15aff2dbc17daad8fb8af04d2100e15c" dependencies = [ "itoa", "ryu", "serde", ] -[[package]] -name = "sha2" -version = "0.9.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" -dependencies = [ - "block-buffer", - "cfg-if", - "cpufeatures", - "digest", - "opaque-debug", -] - [[package]] name = "sharded-slab" version = "0.1.4" @@ -1501,11 +1604,32 @@ dependencies = [ "syn", ] +[[package]] +name = "strum" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6878079b17446e4d3eba6192bb0a2950d5b14f0ed8424b852310e5a94345d0ef" +dependencies = [ + "heck 0.4.0", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "support_vector_classifier" version = "0.12.3" dependencies = [ - "getrandom", "hotg-rune-proc-blocks", "smartcore", "wit-bindgen-rust", @@ -1515,7 +1639,6 @@ dependencies = [ name = "support_vector_regression" version = "0.12.3" dependencies = [ - "getrandom", "hotg-rune-proc-blocks", "smartcore", "wit-bindgen-rust", @@ -1523,20 +1646,20 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.91" +version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b683b2b825c8eef438b77c36a06dc262294da3d5a5813fac20da149241dcd44d" +checksum = "0748dd251e24453cb8717f0354206b91557e4ec8703673a4b30208f2abaf1ebf" dependencies = [ "proc-macro2", "quote", - "unicode-xid", + "unicode-ident", ] [[package]] name = "target-lexicon" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7fa7e55043acb85fca6b3c01485a2eeb6b69c5d21002e273c79e465f43b7ac1" +checksum = "c02424087780c9b71cc96799eaeddff35af2bc513278cda5c99fc1f5d026d3c1" [[package]] name = "tempfile" @@ -1552,23 +1675,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "tensor_input" -version = "0.12.0" -dependencies = [ - "hotg-rune-proc-blocks", - "wit-bindgen-rust", -] - -[[package]] -name = "termcolor" -version = "1.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" -dependencies = [ - "winapi-util", -] - [[package]] name = "text_extractor" version = "0.12.0" @@ -1588,18 +1694,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.30" +version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" +checksum = "bd829fe32373d27f76265620b5309d0340cb8550f523c1dda251d6298069069a" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.30" +version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" +checksum = "0396bc89e626244658bef819e22d0cc459e795a5ebe878e6ec336d1674a8d79a" dependencies = [ "proc-macro2", "quote", @@ -1615,11 +1721,22 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tiff" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cfada0986f446a770eca461e8c6566cb879682f7d687c8348aa0c857bd52286" +dependencies = [ + "flate2", + "jpeg-decoder", + "weezl", +] + [[package]] name = "tinyvec" -version = "1.5.1" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c1c1d5a42b6245520c249549ec267180beaffcc0615401ac8e31853d4b6d8d2" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" dependencies = [ "tinyvec_macros", ] @@ -1641,22 +1758,14 @@ dependencies = [ "wit-bindgen-rust", ] -[[package]] -name = "toml" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa" -dependencies = [ - "serde", -] - [[package]] name = "tracing" -version = "0.1.32" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a1bdf54a7c28a2bbf701e1d2233f6c77f473486b94bee4f9678da5a148dca7f" +checksum = "a400e31aa60b9d44a52a8ee0343b5b18566b03a8321e0d321f695cf56e940160" dependencies = [ "cfg-if", + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -1664,9 +1773,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.20" +version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e65ce065b4b5c53e73bb28912318cb8c9e9ad3921f1d669eb0e68b4c8143a2b" +checksum = "cc6b8ad3567499f98a1db7a752b07a7c8c7c7c34c332ec00effb2b0027974b7c" dependencies = [ "proc-macro2", "quote", @@ -1675,19 +1784,19 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.24" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90442985ee2f57c9e1b548ee72ae842f4a9a20e3f417cc38dbc5dc684d9bb4ee" +checksum = "7709595b8878a4965ce5e87ebf880a7d39c9afc6837721b21a5a816a8117d921" dependencies = [ - "lazy_static", + "once_cell", "valuable", ] [[package]] name = "tracing-log" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6923477a48e41c1951f1999ef8bb5a3023eb723ceadafe78ffb65dc366761e3" +checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" dependencies = [ "lazy_static", "log", @@ -1696,9 +1805,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9df98b037d039d03400d9dd06b0f8ce05486b5f25e9a2d7d36196e142ebbc52" +checksum = "4bc28f93baff38037f64e6f43d34cfa1605f27a49c34e8a04c5e78b0babf2596" dependencies = [ "ansi_term", "lazy_static", @@ -1736,6 +1845,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "unicode-ident" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bd2fe26506023ed7b5e1e315add59d6f584c621d037f9368fea9cfb988f368c" + [[package]] name = "unicode-normalization" version = "0.1.19" @@ -1759,17 +1874,9 @@ checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" [[package]] name = "unicode-xid" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" - -[[package]] -name = "utf8_decode" -version = "0.12.0" -dependencies = [ - "hotg-rune-proc-blocks", - "wit-bindgen-rust", -] +checksum = "957e51f3646910546462e67d5f7599b9e4fb8acdd304b087a6494730f9eebf04" [[package]] name = "valuable" @@ -1817,212 +1924,354 @@ dependencies = [ [[package]] name = "wasi" -version = "0.10.2+wasi-snapshot-preview1" +version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] -name = "wasmparser" -version = "0.77.0" +name = "wasm-bindgen" +version = "0.2.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b35c86d22e720a07d954ebbed772d01180501afe7d03d464f413bb5f8914a8d6" +checksum = "7c53b543413a17a202f4be280a7e5c62a1c69345f5de525ee64f8cfdbc954994" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] [[package]] -name = "wasmparser" -version = "0.82.0" +name = "wasm-bindgen-backend" +version = "0.2.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0559cc0f1779240d6f894933498877ea94f693d84f3ee39c9a9932c6c312bd70" +checksum = "5491a68ab4500fa6b4d726bd67408630c3dbe9c4fe7bda16d5c82a1fd8c7340a" +dependencies = [ + "bumpalo", + "lazy_static", + "log", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] [[package]] -name = "wasmtime" -version = "0.34.2" +name = "wasm-bindgen-macro" +version = "0.2.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc8463ad287e1d87d9a141a010cbe4b3f8227ade85cc8ac64f2bef3219b66f94" +checksum = "c441e177922bc58f1e12c022624b6216378e5febc2f0533e41ba443d505b80aa" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d94ac45fcf608c1f45ef53e748d35660f168490c10b23704c7779ab8f5c3048" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a89911bd99e5f3659ec4acf9c4d93b0a90fe4a2a11f15328472058edc5261be" + +[[package]] +name = "wasm-encoder" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31f0c17267a5ffd6ae3d897589460e21db1673c84fb7016b909c9691369a75ea" +dependencies = [ + "leb128", +] + +[[package]] +name = "wasmer" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea8d8361c9d006ea3d7797de7bd6b1492ffd0f91a22430cfda6c1658ad57bedf" dependencies = [ - "anyhow", - "async-trait", - "backtrace", - "bincode", "cfg-if", "indexmap", - "lazy_static", - "libc", - "log", - "object", - "once_cell", - "paste", - "psm", - "rayon", - "region", - "serde", + "js-sys", + "loupe", + "more-asserts", "target-lexicon", - "wasmparser 0.82.0", - "wasmtime-cache", - "wasmtime-cranelift", - "wasmtime-environ", - "wasmtime-fiber", - "wasmtime-jit", - "wasmtime-runtime", + "thiserror", + "wasm-bindgen", + "wasmer-artifact", + "wasmer-compiler", + "wasmer-compiler-cranelift", + "wasmer-derive", + "wasmer-engine", + "wasmer-engine-dylib", + "wasmer-engine-universal", + "wasmer-types", + "wasmer-vm", "wat", "winapi", ] [[package]] -name = "wasmtime-cache" -version = "0.34.2" +name = "wasmer-artifact" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b066cd527050ed06eba8f4eb8948d833f033401f09313a5e5231ebe3e316bb9d" +checksum = "7aaf9428c29c1d8ad2ac0e45889ba8a568a835e33fd058964e5e500f2f7ce325" dependencies = [ - "anyhow", - "base64", - "bincode", - "directories-next", - "file-per-thread-logger", - "log", - "rustix", + "enumset", + "loupe", + "thiserror", + "wasmer-compiler", + "wasmer-types", +] + +[[package]] +name = "wasmer-compiler" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e67a6cd866aed456656db2cfea96c18baabbd33f676578482b85c51e1ee19d2c" +dependencies = [ + "enumset", + "loupe", + "rkyv", "serde", - "sha2", - "toml", - "winapi", - "zstd", + "serde_bytes", + "smallvec", + "target-lexicon", + "thiserror", + "wasmer-types", + "wasmparser 0.83.0", ] [[package]] -name = "wasmtime-cranelift" -version = "0.34.2" +name = "wasmer-compiler-cranelift" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "381b034926e26980a0aed3f26ec4ba2ff3be9763f386bfb18b7bf2a3fbc1a284" +checksum = "48be2f9f6495f08649e4f8b946a2cbbe119faf5a654aa1457f9504a99d23dae0" dependencies = [ - "anyhow", "cranelift-codegen", "cranelift-entity", "cranelift-frontend", - "cranelift-native", - "cranelift-wasm", "gimli", - "log", + "loupe", "more-asserts", - "object", + "rayon", + "smallvec", "target-lexicon", - "thiserror", - "wasmparser 0.82.0", - "wasmtime-environ", + "tracing", + "wasmer-compiler", + "wasmer-types", ] [[package]] -name = "wasmtime-environ" -version = "0.34.2" +name = "wasmer-derive" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877230e7f92f8b5509845e804bb27c7c993197339a7cf0de4a2af411ee6ea75b" +checksum = "00e50405cc2a2f74ff574584710a5f2c1d5c93744acce2ca0866084739284b51" dependencies = [ - "anyhow", - "cranelift-entity", - "gimli", - "indexmap", - "log", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "wasmer-engine" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f98f010978c244db431b392aeab0661df7ea0822343334f8f2a920763548e45" +dependencies = [ + "backtrace", + "enumset", + "lazy_static", + "loupe", + "memmap2", "more-asserts", - "object", + "rustc-demangle", "serde", + "serde_bytes", "target-lexicon", "thiserror", - "wasmparser 0.82.0", - "wasmtime-types", + "wasmer-artifact", + "wasmer-compiler", + "wasmer-types", + "wasmer-vm", ] [[package]] -name = "wasmtime-fiber" -version = "0.34.2" +name = "wasmer-engine-dylib" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dffb509e67c6c2ea49f38bd5db3712476fcc94c4776521012e5f69ae4bb27b4a" +checksum = "ad0358af9c154724587731175553805648d9acb8f6657880d165e378672b7e53" dependencies = [ - "cc", - "rustix", - "winapi", + "cfg-if", + "enum-iterator", + "enumset", + "leb128", + "libloading", + "loupe", + "object", + "rkyv", + "serde", + "tempfile", + "tracing", + "wasmer-artifact", + "wasmer-compiler", + "wasmer-engine", + "wasmer-object", + "wasmer-types", + "wasmer-vm", + "which", ] [[package]] -name = "wasmtime-jit" -version = "0.34.2" +name = "wasmer-engine-universal" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ee2da33bb337fbdfb6e031d485bf2a39d51f37f48e79c6327228d3fc68ec531" +checksum = "440dc3d93c9ca47865a4f4edd037ea81bf983b5796b59b3d712d844b32dbef15" dependencies = [ - "addr2line", - "anyhow", - "bincode", "cfg-if", - "cpp_demangle", - "gimli", - "log", - "object", + "enumset", + "leb128", + "loupe", "region", - "rustc-demangle", - "rustix", + "rkyv", + "wasmer-compiler", + "wasmer-engine", + "wasmer-engine-universal-artifact", + "wasmer-types", + "wasmer-vm", + "winapi", +] + +[[package]] +name = "wasmer-engine-universal-artifact" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f1db3f54152657eb6e86c44b66525ff7801dad8328fe677da48dd06af9ad41" +dependencies = [ + "enum-iterator", + "enumset", + "loupe", + "rkyv", + "thiserror", + "wasmer-artifact", + "wasmer-compiler", + "wasmer-types", +] + +[[package]] +name = "wasmer-object" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d831335ff3a44ecf451303f6f891175c642488036b92ceceb24ac8623a8fa8b" +dependencies = [ + "object", + "thiserror", + "wasmer-compiler", + "wasmer-types", +] + +[[package]] +name = "wasmer-types" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39df01ea05dc0a9bab67e054c7cb01521e53b35a7bb90bd02eca564ed0b2667f" +dependencies = [ + "backtrace", + "enum-iterator", + "indexmap", + "loupe", + "more-asserts", + "rkyv", "serde", - "target-lexicon", "thiserror", - "wasmtime-environ", - "wasmtime-runtime", - "winapi", ] [[package]] -name = "wasmtime-runtime" -version = "0.34.2" +name = "wasmer-vm" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcb5bd981c971c398dac645874748f261084dc907a98b3ee70fa41e005a2b365" +checksum = "30d965fa61f4dc4cdb35a54daaf7ecec3563fbb94154a6c35433f879466247dd" dependencies = [ - "anyhow", "backtrace", "cc", "cfg-if", + "corosensei", + "enum-iterator", "indexmap", "lazy_static", "libc", - "log", + "loupe", "mach", "memoffset", "more-asserts", - "rand", "region", - "rustix", + "rkyv", + "scopeguard", + "serde", "thiserror", - "wasmtime-environ", - "wasmtime-fiber", + "wasmer-artifact", + "wasmer-types", "winapi", ] [[package]] -name = "wasmtime-types" -version = "0.34.2" +name = "wasmparser" +version = "0.77.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73696a97fb815c2944896ae9e4fc49182fd7ec0b58088f9ad9768459a521e347" -dependencies = [ - "cranelift-entity", - "serde", - "thiserror", - "wasmparser 0.82.0", -] +checksum = "b35c86d22e720a07d954ebbed772d01180501afe7d03d464f413bb5f8914a8d6" + +[[package]] +name = "wasmparser" +version = "0.83.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "718ed7c55c2add6548cca3ddd6383d738cd73b892df400e96b9aa876f0141d7a" [[package]] name = "wast" -version = "40.0.0" +version = "42.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bb4f48a8b083dbc50e291e430afb8f524092bb00428957bcc63f49f856c64ac" +checksum = "badcb03f976f983ff0daf294da9697be659442f61e6b0942bb37a2b6cbfe9dd4" dependencies = [ "leb128", "memchr", "unicode-width", + "wasm-encoder", ] [[package]] name = "wat" -version = "1.0.42" +version = "1.0.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0401b6395ce0db91629a75b29597ccb66ea29950af9fc859f1bb3a736609c76e" +checksum = "b92f20b742ac527066c8414bc0637352661b68cab07ef42586cefaba71c965cf" dependencies = [ "wast", ] +[[package]] +name = "weezl" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c97e489d8f836838d497091de568cf16b117486d529ec5579233521065bd5e4" + +[[package]] +name = "which" +version = "4.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c4fb54e6113b6a8772ee41c3404fb0301ac79604489467e0a9ce1f3e97c24ae" +dependencies = [ + "either", + "lazy_static", + "libc", +] + [[package]] name = "winapi" version = "0.3.9" @@ -2039,15 +2288,6 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" -[[package]] -name = "winapi-util" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" -dependencies = [ - "winapi", -] - [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -2055,65 +2295,90 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] -name = "wit-bindgen-gen-core" -version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b#e165c0226ad0f00a840ef786079b58d0910dbb6b" +name = "windows-sys" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43dbb096663629518eb1dfa72d80243ca5a6aca764cae62a2df70af760a9be75" dependencies = [ - "anyhow", - "wit-parser 0.1.0 (git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b)", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_msvc", ] [[package]] -name = "wit-bindgen-gen-core" -version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen#e165c0226ad0f00a840ef786079b58d0910dbb6b" -dependencies = [ - "anyhow", - "wit-parser 0.1.0 (git+https://github.com/bytecodealliance/wit-bindgen)", -] +name = "windows_aarch64_msvc" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd761fd3eb9ab8cc1ed81e56e567f02dd82c4c837e48ac3b2181b9ffc5060807" [[package]] -name = "wit-bindgen-gen-rust" +name = "windows_i686_gnu" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab0cf703a96bab2dc0c02c0fa748491294bf9b7feb27e1f4f96340f208ada0e" + +[[package]] +name = "windows_i686_msvc" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cfdbe89cc9ad7ce618ba34abc34bbb6c36d99e96cae2245b7943cd75ee773d0" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4dd9b0c0e9ece7bb22e84d70d01b71c6d6248b81a3c60d11869451b4cb24784" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff1e4aa646495048ec7f3ffddc411e1d829c026a2ec62b39da15c1055e406eaa" + +[[package]] +name = "wit-bindgen-gen-core" version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b#e165c0226ad0f00a840ef786079b58d0910dbb6b" +source = "git+https://github.com/wasmerio/wit-bindgen?branch=wasmer#8772835affd9aa1cff1831e2a5f822d832006001" dependencies = [ - "heck 0.3.3", - "wit-bindgen-gen-core 0.1.0 (git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b)", + "anyhow", + "wit-parser", ] [[package]] name = "wit-bindgen-gen-rust" version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen#e165c0226ad0f00a840ef786079b58d0910dbb6b" +source = "git+https://github.com/wasmerio/wit-bindgen?branch=wasmer#8772835affd9aa1cff1831e2a5f822d832006001" dependencies = [ "heck 0.3.3", - "wit-bindgen-gen-core 0.1.0 (git+https://github.com/bytecodealliance/wit-bindgen)", + "wit-bindgen-gen-core", ] [[package]] name = "wit-bindgen-gen-rust-wasm" version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen#e165c0226ad0f00a840ef786079b58d0910dbb6b" +source = "git+https://github.com/wasmerio/wit-bindgen?branch=wasmer#8772835affd9aa1cff1831e2a5f822d832006001" dependencies = [ "heck 0.3.3", - "wit-bindgen-gen-core 0.1.0 (git+https://github.com/bytecodealliance/wit-bindgen)", - "wit-bindgen-gen-rust 0.1.0 (git+https://github.com/bytecodealliance/wit-bindgen)", + "wit-bindgen-gen-core", + "wit-bindgen-gen-rust", ] [[package]] -name = "wit-bindgen-gen-wasmtime" +name = "wit-bindgen-gen-wasmer" version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b#e165c0226ad0f00a840ef786079b58d0910dbb6b" +source = "git+https://github.com/wasmerio/wit-bindgen?branch=wasmer#8772835affd9aa1cff1831e2a5f822d832006001" dependencies = [ "heck 0.3.3", - "wit-bindgen-gen-core 0.1.0 (git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b)", - "wit-bindgen-gen-rust 0.1.0 (git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b)", + "wit-bindgen-gen-core", + "wit-bindgen-gen-rust", ] [[package]] name = "wit-bindgen-rust" version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen#e165c0226ad0f00a840ef786079b58d0910dbb6b" +source = "git+https://github.com/wasmerio/wit-bindgen?branch=wasmer#8772835affd9aa1cff1831e2a5f822d832006001" dependencies = [ "async-trait", "bitflags", @@ -2123,53 +2388,42 @@ dependencies = [ [[package]] name = "wit-bindgen-rust-impl" version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen#e165c0226ad0f00a840ef786079b58d0910dbb6b" +source = "git+https://github.com/wasmerio/wit-bindgen?branch=wasmer#8772835affd9aa1cff1831e2a5f822d832006001" dependencies = [ "proc-macro2", "syn", - "wit-bindgen-gen-core 0.1.0 (git+https://github.com/bytecodealliance/wit-bindgen)", + "wit-bindgen-gen-core", "wit-bindgen-gen-rust-wasm", ] [[package]] -name = "wit-bindgen-wasmtime" +name = "wit-bindgen-wasmer" version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b#e165c0226ad0f00a840ef786079b58d0910dbb6b" +source = "git+https://github.com/wasmerio/wit-bindgen?branch=wasmer#8772835affd9aa1cff1831e2a5f822d832006001" dependencies = [ "anyhow", "bitflags", "thiserror", - "wasmtime", - "wit-bindgen-wasmtime-impl", + "tracing", + "wasmer", + "wit-bindgen-wasmer-impl", ] [[package]] -name = "wit-bindgen-wasmtime-impl" +name = "wit-bindgen-wasmer-impl" version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b#e165c0226ad0f00a840ef786079b58d0910dbb6b" +source = "git+https://github.com/wasmerio/wit-bindgen?branch=wasmer#8772835affd9aa1cff1831e2a5f822d832006001" dependencies = [ "proc-macro2", "syn", - "wit-bindgen-gen-core 0.1.0 (git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b)", - "wit-bindgen-gen-wasmtime", -] - -[[package]] -name = "wit-parser" -version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen?rev=e165c0226ad0f00a840ef786079b58d0910dbb6b#e165c0226ad0f00a840ef786079b58d0910dbb6b" -dependencies = [ - "anyhow", - "id-arena", - "pulldown-cmark", - "unicode-normalization", - "unicode-xid", + "wit-bindgen-gen-core", + "wit-bindgen-gen-wasmer", ] [[package]] name = "wit-parser" version = "0.1.0" -source = "git+https://github.com/bytecodealliance/wit-bindgen#e165c0226ad0f00a840ef786079b58d0910dbb6b" +source = "git+https://github.com/wasmerio/wit-bindgen?branch=wasmer#8772835affd9aa1cff1831e2a5f822d832006001" dependencies = [ "anyhow", "id-arena", @@ -2187,42 +2441,15 @@ dependencies = [ "heck 0.4.0", "itertools", "once_cell", + "rand", "serde", "serde_json", "structopt", "tempfile", + "thiserror", "tracing", "tracing-subscriber", "walrus", - "wasmtime", - "wit-bindgen-wasmtime", -] - -[[package]] -name = "zstd" -version = "0.10.0+zstd.1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b1365becbe415f3f0fcd024e2f7b45bacfb5bdd055f0dc113571394114e7bdd" -dependencies = [ - "zstd-safe", -] - -[[package]] -name = "zstd-safe" -version = "4.1.4+zstd.1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f7cd17c9af1a4d6c24beb1cc54b17e2ef7b593dc92f19e9d9acad8b182bbaee" -dependencies = [ - "libc", - "zstd-sys", -] - -[[package]] -name = "zstd-sys" -version = "1.6.3+zstd.1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc49afa5c8d634e75761feda8c592051e7eeb4683ba827211eb0d731d3402ea8" -dependencies = [ - "cc", - "libc", + "wasmer", + "wit-bindgen-wasmer", ] diff --git a/Cargo.toml b/Cargo.toml index 7efb0fc4a96..06d3b336739 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,42 +1,38 @@ [workspace] members = [ + "accuracy", "argmax", "audio_float_conversion", "binary_classification", + "elastic_net", + "f1-score", "fft", + "image_decode", "image-normalization", - "image_input", "label", + "linear_regression", + "logistic_regression", "modulo", "most_confident_indices", "noise-filtering", "normalize", "object_filter", - "parse", + "password_strength", + "prediction_errors", "segment_output", "softmax", + "support_vector_classifier", + "support_vector_regression", "support", - "tensor_input", "text_extractor", "tokenizers", - "utf8_decode", - "xtask", - "password_strength", - "logistic_regression", - "linear_regression", - "elastic_net", - "support_vector_classifier", - "support_vector_regression", "train_test_split", - "accuracy", - "f1-score", - "prediction_errors" + "xtask", ] # Uncomment these lines if you need to override wit-bindgen with hotg's fork. -#[patch.'https://github.com/bytecodealliance/wit-bindgen'] -#wit-bindgen-rust = { git = "https://github.com/hotg-ai/wit-bindgen", branch = "variant-should-be-clone" } -#wit-bindgen-wasmtime = { git = "https://github.com/hotg-ai/wit-bindgen", branch = "variant-should-be-clone" } +[patch.'https://github.com/bytecodealliance/wit-bindgen'] +wit-bindgen-rust = { git = "https://github.com/wasmerio/wit-bindgen", branch = "wasmer" } [profile.dev] opt-level = 1 diff --git a/accuracy/Cargo.toml b/accuracy/Cargo.toml index 9b3f99ec628..47b17dd1e0c 100644 --- a/accuracy/Cargo.toml +++ b/accuracy/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "accuracy" version = "0.12.2" -edition = "2018" +edition = "2021" description = "calculates accuracy of predicted labels when compared to true labels" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -10,7 +10,6 @@ description = "calculates accuracy of predicted labels when compared to true lab hotg-rune-proc-blocks = { path = "../support" } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } smartcore = { git = "https://github.com/hotg-ai/smartcore", branch = "development" } -getrandom = { version = "0.2.6", default-features = false, features = ["custom"] } [lib] crate-type = ["cdylib", "rlib"] diff --git a/accuracy/src/lib.rs b/accuracy/src/lib.rs index 3fa856ff330..e2cfaf24d35 100644 --- a/accuracy/src/lib.rs +++ b/accuracy/src/lib.rs @@ -1,155 +1,92 @@ -// use linfa_logistic::LogisticRegression; -use smartcore::metrics::*; - -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, +use hotg_rune_proc_blocks::{ + guest::{ + Argument, ArgumentHint, ArgumentMetadata, Dimensions, + ElementTypeConstraint, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, + }, + ndarray::ArrayView1, }; -use hotg_rune_proc_blocks::{runtime_v1::*, BufferExt, SliceExt}; - -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -/// A proc block which can perform linear regression -struct ProcBlockV1; +use smartcore::metrics::*; -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("Accuracy", env!("CARGO_PKG_VERSION")); - metadata.set_description( - "calculates accuracy of predicted labels when compared to true labels", - ); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("metric"); - metadata.add_tag("analytics"); +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: Accuracy, +} - let element_type = ArgumentMetadata::new("element_type"); - element_type - .set_description("The type of tensor this proc-block will accept"); - element_type.set_default_value("f64"); - element_type.add_hint(&interpret_as_string_in_enum(&[ +fn metadata() -> Metadata { + Metadata::new("Accuracy", env!("CARGO_PKG_VERSION")) + .with_description("calculates accuracy of predicted labels when compared to true labels") + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("metric") + .with_tag("analytics") + .with_argument(ArgumentMetadata::new("element_type") + .with_description("The type of tensor this proc-block will accept") + .with_default_value("f64") + .with_hint(ArgumentHint::one_of([ "u8", "i8", "u16", "i16", "u32", "i32", "f32", "u64", "i64", "f64", - ])); - metadata.add_argument(&element_type); - - let y_true = TensorMetadata::new("y_true"); - let hint = - supported_shapes(&[ElementType::F64], DimensionsParam::Fixed(&[0])); - y_true.add_hint(&hint); - metadata.add_input(&y_true); - - let y_pred = TensorMetadata::new("y_pred"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0])); - y_pred.add_hint(&hint); - metadata.add_input(&y_pred); - - let accuracy = TensorMetadata::new("accuracy"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[1])); - accuracy.add_hint(&hint); - metadata.add_output(&accuracy); + ])) + ) + .with_input(TensorMetadata::new("y_true")) + .with_input(TensorMetadata::new("y_pred")) + .with_output(TensorMetadata::new("accuracy")) +} - register_node(&metadata); +/// A proc block which can perform linear regression +struct Accuracy; + +impl ProcBlock for Accuracy { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![ + TensorConstraint::numeric("y_true", vec![0]), + TensorConstraint::numeric("y_pred", vec![0]), + ], + outputs: vec![TensorConstraint { + name: "accuracy".to_string(), + dimensions: Dimensions::Fixed(vec![1]), + element_type: ElementTypeConstraint::F64, + }], + } } - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - let element_type = match ctx.get_argument("element_type").as_deref() { - Some("f64") => ElementType::F64, - Some(_) => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })); - }, - None => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::NotFound, - })) - }, - }; - - ctx.add_input_tensor( - "y_true", - element_type, - DimensionsParam::Fixed(&[0]), - ); + fn run(&self, inputs: Vec) -> Result, RunError> { + let y_true = Tensor::get_named(&inputs, "y_true")?.view_1d()?; + let y_pred = Tensor::get_named(&inputs, "y_pred")?.view_1d()?; - ctx.add_input_tensor( - "y_pred", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_output_tensor( - "accuracy", - element_type, - DimensionsParam::Fixed(&[1]), - ); + let accuracy = transform(y_true, y_pred); - Ok(()) + Ok(vec![Tensor::new_1d("accuracy", &[accuracy])]) } +} - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let y_true = ctx.get_input_tensor("y_true").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "y_true".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let y_pred = ctx.get_input_tensor("y_pred").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "y_pred".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let accuracy = transform( - y_true.buffer.elements().to_vec(), - y_pred.buffer.elements().to_vec(), - ); - - let output = vec![accuracy]; - - ctx.set_output_tensor( - "accuracy", - TensorParam { - element_type: ElementType::F64, - dimensions: &[1 as u32], - buffer: &output.as_bytes(), - }, - ); - - Ok(()) - } +impl From> for Accuracy { + fn from(_: Vec) -> Self { Accuracy } } -fn transform(y_true: Vec, y_pred: Vec) -> f64 { +fn transform(y_true: ArrayView1<'_, f64>, y_pred: ArrayView1<'_, f64>) -> f64 { + // Note: We need to unnecessarily copy our inputs here because + // smartcore's accuracy metric accepts types implementing BaseVector. + // However, they only have an implementation for Vec and not &[T] + // or ndarray's 1D arrays. + let y_true: Vec = y_true.iter().copied().collect(); + let y_pred: Vec = y_pred.iter().copied().collect(); + ClassificationMetrics::accuracy().get_score(&y_true, &y_pred) } #[cfg(test)] mod tests { + use hotg_rune_proc_blocks::ndarray; + use super::*; #[test] fn check_transform() { - let y_pred: Vec = vec![0., 2., 1., 3.]; - let y_true: Vec = vec![0., 1., 2., 3.]; + let y_pred = ndarray::array![0., 2., 1., 3.]; + let y_true = ndarray::array![0., 1., 2., 3.]; - let accuracy = transform(y_true, y_pred); + let accuracy = transform(y_true.view(), y_pred.view()); assert_eq!(0.5, accuracy); } diff --git a/argmax/Cargo.toml b/argmax/Cargo.toml index a4aa2927488..cccd8e3b0d0 100644 --- a/argmax/Cargo.toml +++ b/argmax/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "argmax" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "Find the index of the largest element." diff --git a/argmax/src/lib.rs b/argmax/src/lib.rs index 9d3f70e8672..c19f12967cc 100644 --- a/argmax/src/lib.rs +++ b/argmax/src/lib.rs @@ -1,117 +1,76 @@ -use crate::proc_block_v1::{ - BadInputReason, GraphError, InvalidInput, KernelError, +use hotg_rune_proc_blocks::{ + guest::{ + Argument, ElementType, InvalidInput, Metadata, ProcBlock, RunError, + Tensor, TensorConstraint, TensorConstraints, TensorMetadata, + }, + ndarray::ArrayViewD, }; -use hotg_rune_proc_blocks::{runtime_v1::*, BufferExt}; -use std::cmp::Ordering; - -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("Arg Max", env!("CARGO_PKG_VERSION")); - metadata.set_description(env!("CARGO_PKG_DESCRIPTION")); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("max"); - metadata.add_tag("index"); - metadata.add_tag("numeric"); - - let input = TensorMetadata::new("input"); - let hint = - supported_shapes(&[ElementType::F32], DimensionsParam::Fixed(&[0])); - input.add_hint(&hint); - metadata.add_input(&input); - - let max = TensorMetadata::new("max_index"); - max.set_description("The index of the element with the highest value"); - let hint = - supported_shapes(&[ElementType::U32], DimensionsParam::Fixed(&[1])); - max.add_hint(&hint); - metadata.add_output(&max); - - register_node(&metadata); - } +use std::{cmp::Ordering, convert::TryFrom}; + +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: ArgMax, +} + +#[derive(Debug, Clone, Default, PartialEq)] +struct ArgMax; + +impl From> for ArgMax { + fn from(_: Vec) -> Self { ArgMax } +} - fn graph(id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&id).ok_or_else(|| { - GraphError::Other("Unable to get the graph context".to_string()) - })?; - - ctx.add_input_tensor( - "input", - ElementType::F32, - DimensionsParam::Fixed(&[0]), - ); - ctx.add_output_tensor( - "max_index", - ElementType::U32, - DimensionsParam::Fixed(&[1]), - ); - - Ok(()) +fn metadata() -> Metadata { + Metadata::new("Arg Max", env!("CARGO_PKG_VERSION")) + .with_description(env!("CARGO_PKG_DESCRIPTION")) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("max") + .with_tag("index") + .with_tag("numeric") + .with_input(TensorMetadata::new("input")) + .with_output(TensorMetadata::new("max_index").with_description( + "The index of the element with the highest value", + )) +} + +impl ProcBlock for ArgMax { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::numeric("input", vec![0])], + outputs: vec![TensorConstraint::numeric("max_index", vec![1])], + } } - fn kernel(id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&id).ok_or_else(|| { - KernelError::Other("Unable to get the kernel context".to_string()) - })?; - - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let index = match element_type { - ElementType::U8 => arg_max(buffer.elements::()), - ElementType::I8 => arg_max(buffer.elements::()), - ElementType::U16 => arg_max(buffer.elements::()), - ElementType::I16 => arg_max(buffer.elements::()), - ElementType::U32 => arg_max(buffer.elements::()), - ElementType::I32 => arg_max(buffer.elements::()), - ElementType::F32 => arg_max(buffer.elements::()), - ElementType::U64 => arg_max(buffer.elements::()), - ElementType::I64 => arg_max(buffer.elements::()), - ElementType::F64 => arg_max(buffer.elements::()), - other => { - return Err(KernelError::Other(format!( - "The Arg Max proc-block doesn't support {:?} element type", - other, - ))) + fn run(&self, inputs: Vec) -> Result, RunError> { + let tensor = Tensor::get_named(&inputs, "input")?; + + let index = match tensor.element_type { + ElementType::U8 => arg_max(tensor.view::()?), + ElementType::I8 => arg_max(tensor.view::()?), + ElementType::U16 => arg_max(tensor.view::()?), + ElementType::I16 => arg_max(tensor.view::()?), + ElementType::U32 => arg_max(tensor.view::()?), + ElementType::I32 => arg_max(tensor.view::()?), + ElementType::F32 => arg_max(tensor.view::()?), + ElementType::U64 => arg_max(tensor.view::()?), + ElementType::I64 => arg_max(tensor.view::()?), + ElementType::F64 => arg_max(tensor.view::()?), + _ => { + return Err(InvalidInput::incompatible_element_type( + &tensor.name, + ) + .into()); }, }; - let index = match index { - Some(ix) => ix, - None => { - return Err(KernelError::Other( - "The input tensor was empty".to_string(), - )) - }, - }; - let resulting_tensor = (index as u32).to_le_bytes(); - - ctx.set_output_tensor( - "max_index", - TensorParam { - element_type: ElementType::U32, - dimensions: &dimensions, - buffer: &resulting_tensor, - }, - ); + let index = index + .ok_or_else(|| RunError::other("The input tensor was empty"))?; - Ok(()) + Ok(vec![Tensor::new_1d("max_index", &[index as u32])]) } } -fn arg_max(values: &[T]) -> Option +fn arg_max(values: ArrayViewD<'_, T>) -> Option where T: PartialOrd, { @@ -129,18 +88,21 @@ mod tests { #[test] fn test_argmax() { - let values = [2.3, 12.4, 55.1, 15.4]; + let inputs = vec![Tensor::new_1d("input", &[2.3, 12.4, 55.1, 15.4])]; + let should_be = vec![Tensor::new_1d("max_index", &[2_u32])]; - let max = arg_max(&values).unwrap(); + let got = ArgMax.run(inputs).unwrap(); - assert_eq!(max, 2); + assert_eq!(got, should_be); } #[test] fn empty_inputs_are_an_error() { let empty: &[f32] = &[]; - let result = arg_max(empty); + let inputs = vec![Tensor::new_1d("input", empty)]; + + let error = ArgMax.run(inputs).unwrap_err(); - assert!(result.is_none()); + assert_eq!(error, RunError::other("The input tensor was empty")); } } diff --git a/audio_float_conversion/Cargo.toml b/audio_float_conversion/Cargo.toml index 8f5872308dd..0ee8aa675c4 100644 --- a/audio_float_conversion/Cargo.toml +++ b/audio_float_conversion/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "audio_float_conversion" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "Converted values from i16 data type to a floating-point value." diff --git a/audio_float_conversion/src/lib.rs b/audio_float_conversion/src/lib.rs index 068cc88466b..59230b509c7 100644 --- a/audio_float_conversion/src/lib.rs +++ b/audio_float_conversion/src/lib.rs @@ -1,140 +1,81 @@ -use crate::proc_block_v1::{ - BadInputReason, GraphError, InvalidInput, KernelError, -}; use hotg_rune_proc_blocks::{ - ndarray::ArrayView1, runtime_v1::*, BufferExt, SliceExt, + guest::{ + Argument, Dimensions, ElementTypeConstraint, Metadata, ProcBlock, + RunError, Tensor, TensorConstraint, TensorConstraints, TensorMetadata, + }, + ndarray::{Array1, ArrayView1}, }; -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - const I16_MAX_AS_FLOAT: f32 = i16::MAX as f32; -#[derive(Debug, Clone, PartialEq)] -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = - Metadata::new("Audio Float Conversion", env!("CARGO_PKG_VERSION")); - metadata.set_description(env!("CARGO_PKG_DESCRIPTION")); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("audio"); - metadata.add_tag("float"); - - let input = TensorMetadata::new("input"); - let hint = supported_shapes( - &[ElementType::I16], - DimensionsParam::Fixed(&[1, 0]), - ); - input.add_hint(&hint); - metadata.add_input(&input); - - let output = TensorMetadata::new("output"); - output.set_description( - "converted values from i16 data type to a floating-point value.", - ); - let hint = supported_shapes( - &[ElementType::F32], - DimensionsParam::Fixed(&[1, 0]), - ); - output.add_hint(&hint); - metadata.add_output(&output); - - register_node(&metadata); - } +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: AudioFloatConversion, +} - fn graph(id: String) -> Result<(), GraphError> { - let ctx = - GraphContext::for_node(&id).ok_or(GraphError::MissingContext)?; +fn metadata() -> Metadata { + Metadata::new("Audio Float Conversion", env!("CARGO_PKG_VERSION")) + .with_description(env!("CARGO_PKG_DESCRIPTION")) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("audio") + .with_tag("float") + .with_input(TensorMetadata::new("input")) + .with_output(TensorMetadata::new("output").with_description( + "converted values from i16 data type to a floating-point value.", + )) +} - ctx.add_input_tensor( - "input", - ElementType::I16, - DimensionsParam::Fixed(&[1, 0]), - ); - ctx.add_output_tensor( - "output", - ElementType::F32, - DimensionsParam::Fixed(&[1, 0]), - ); +#[derive(Debug, Clone, PartialEq)] +struct AudioFloatConversion; - Ok(()) +impl ProcBlock for AudioFloatConversion { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint { + name: "input".to_string(), + element_type: ElementTypeConstraint::I16, + dimensions: Dimensions::Fixed(vec![1, 0]), + }], + outputs: vec![TensorConstraint { + name: "output".to_string(), + element_type: ElementTypeConstraint::F32, + dimensions: Dimensions::Fixed(vec![1, 0]), + }], + } } - fn kernel(id: String) -> Result<(), KernelError> { - let ctx = - KernelContext::for_node(&id).ok_or(KernelError::MissingContext)?; + fn run(&self, inputs: Vec) -> Result, RunError> { + let tensor = Tensor::get_named(&inputs, "input")?; - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let tensor: ArrayView1 = match element_type { - ElementType::I16 => { - let tensor = buffer.view::(&dimensions) - .map_err(|e| KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::Other(e.to_string()), - }))?; - - tensor.into_dimensionality() - .map_err(|_| KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::UnsupportedShape, - }))? - }, - - other => { - return Err(KernelError::Other(format!( - "The Audio Float Conversion proc-block only accepts I16 tensors, found {:?}", - other, - ))) - }, - }; - - let output = audio_float_conversion(tensor); - - ctx.set_output_tensor( - "output", - TensorParam { - element_type: ElementType::F32, - dimensions: &dimensions, - buffer: output.as_bytes(), - }, - ); - - Ok(()) + let output = audio_float_conversion(tensor.view_1d()?); + + Ok(vec![Tensor::new("output", &output)]) } } -fn audio_float_conversion(values: ArrayView1<'_, i16>) -> Vec { - values - .iter() - .map(|&value| (value as f32 / I16_MAX_AS_FLOAT).clamp(-1.0, 1.0)) - .collect() +impl From> for AudioFloatConversion { + fn from(_: Vec) -> Self { AudioFloatConversion } +} + +fn audio_float_conversion(values: ArrayView1<'_, i16>) -> Array1 { + values.mapv(|value| (value as f32 / I16_MAX_AS_FLOAT).clamp(-1.0, 1.0)) } #[cfg(test)] mod tests { use super::*; - extern crate alloc; - use alloc::vec; - use hotg_rune_proc_blocks::ndarray::array; + use hotg_rune_proc_blocks::ndarray::{self, array}; #[test] fn handle_empty() { - let input = array![0, 0, 0, 0, 0, 0]; - let should_be = vec![0.0_f32, 0.0, 0.0, 0.0, 0.0, 0.0]; + let input = vec![Tensor::new_1d("input", &[0_i16, 0, 0, 0, 0, 0])]; + let should_be = vec![Tensor::new_1d( + "output", + &[0.0_f32, 0.0, 0.0, 0.0, 0.0, 0.0], + )]; - let got = audio_float_conversion(input.view()); + let got = AudioFloatConversion.run(input).unwrap(); assert_eq!(got, should_be); } @@ -148,7 +89,7 @@ mod tests { let got = audio_float_conversion(input.view()); - assert_eq!(got, vec![0.0, 0.49998474, -0.50001526]); + assert_eq!(got, ndarray::array![0.0, 0.49998474, -0.50001526]); } #[test] fn clamp_to_bounds() { @@ -159,6 +100,6 @@ mod tests { let got = audio_float_conversion(input.view()); - assert_eq!(got, vec![1.0, -1.0, -1.0]); + assert_eq!(got, ndarray::array![1.0, -1.0, -1.0]); } } diff --git a/binary_classification/Cargo.toml b/binary_classification/Cargo.toml index 1a86147f892..b203418298c 100644 --- a/binary_classification/Cargo.toml +++ b/binary_classification/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "binary_classification" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "A proc-block takes a probability (0.0 to 1.0) score as input and divides the output into two classes ( 0 or 1) based on a threshold" diff --git a/binary_classification/src/lib.rs b/binary_classification/src/lib.rs index 6e81d796bd0..4c72f936124 100644 --- a/binary_classification/src/lib.rs +++ b/binary_classification/src/lib.rs @@ -1,156 +1,77 @@ -use std::fmt::Display; - -use crate::proc_block_v1::{ - BadInputReason, GraphError, InvalidInput, KernelError, -}; -use hotg_rune_proc_blocks::{ - runtime_v1::{self, *}, - BufferExt, SliceExt, +use hotg_rune_proc_blocks::guest::{ + Argument, ArgumentMetadata, ArgumentType, CreateError, Dimensions, + ElementTypeConstraint, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, }; -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -#[macro_use] -extern crate alloc; +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: BinaryClassification, +} -use alloc::vec::Vec; -use proc_block_v1::{BadArgumentReason, InvalidArgument}; +fn metadata() -> Metadata { + Metadata::new("Binary Classification", env!("CARGO_PKG_VERSION")) + .with_description( + "Classify each element in a tensor depending on whether they are above or below a certain threshold.", + ) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("classify") + .with_argument(ArgumentMetadata::new("threshold") + .with_default_value("0.5") + .with_description("The classification threshold") + .with_hint(ArgumentType::Float)) + .with_input(TensorMetadata::new("input").with_description("The numbers to classify")) + .with_output( + TensorMetadata::new("classified") + .with_description("A tensor of `1`'s and `0`'s, where `1` indicates an element was above the `threshold` and `0` means it was below."), + ) +} /// A proc-block which takes a rank 1 `tensor` as input, return 1 if value /// inside the tensor is greater than 1 otherwise 0. -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = - Metadata::new("Binary Classification", env!("CARGO_PKG_VERSION")); - metadata.set_description( - "Classify each element in a tensor depending on whether they are above or below a certain threshold.", - ); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("classify"); - - let threshold = ArgumentMetadata::new("threshold"); - threshold.set_default_value("0.5"); - threshold.set_description("The classification threshold."); - let hint = runtime_v1::supported_argument_type(ArgumentType::Float); - threshold.add_hint(&hint); - metadata.add_argument(&threshold); - - let input = TensorMetadata::new("input"); - input.set_description("The numbers to classify"); - let hint = - supported_shapes(&[ElementType::F32], DimensionsParam::Fixed(&[0])); - input.add_hint(&hint); - metadata.add_input(&input); - - let output = TensorMetadata::new("classified"); - output.set_description("A tensor of `1`'s and `0`'s, where `1` indicates an element was above the `threshold` and `0` means it was below."); - let hint = - supported_shapes(&[ElementType::U32], DimensionsParam::Fixed(&[0])); - output.add_hint(&hint); - metadata.add_output(&output); - - register_node(&metadata); - } - - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - ctx.add_input_tensor( - "input", - ElementType::U32, - DimensionsParam::Fixed(&[0]), - ); - ctx.add_output_tensor( - "classified", - ElementType::U32, - DimensionsParam::Fixed(&[0]), - ); +struct BinaryClassification { + threshold: f32, +} - Ok(()) +impl ProcBlock for BinaryClassification { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::new( + "input", + ElementTypeConstraint::F32, + Dimensions::Dynamic, + )], + outputs: vec![TensorConstraint::new( + "output", + ElementTypeConstraint::U32, + Dimensions::Dynamic, + )], + } } - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let threshold = get_threshold(|n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let output = match element_type { - ElementType::F32 =>{ - buffer.view::(&dimensions) - .map_err(|e| KernelError::InvalidInput(InvalidInput{ name: "bounding_boxes".to_string(), reason: BadInputReason::InvalidValue(e.to_string()) }))?; - transform(buffer.elements(), threshold) - } - other => { - return Err(KernelError::Other(format!( - "The Object Filter proc-block doesn't support {:?} element type", - other, - ))) - }, - }; + fn run(&self, inputs: Vec) -> Result, RunError> { + let tensor = Tensor::get_named(&inputs, "input")?.view::()?; - ctx.set_output_tensor( - "normalized", - TensorParam { - element_type: ElementType::U32, - dimensions: &dimensions, - buffer: &output.as_bytes(), - }, - ); + let output = + tensor.mapv(|v| if v >= self.threshold { 1_u32 } else { 0 }); - Ok(()) + Ok(vec![Tensor::new("output", &output)]) } } -fn get_threshold( - get_argument: impl FnOnce(&str) -> Option, -) -> Result { - get_argument("threshold") - .ok_or_else(|| InvalidArgument::not_found("threshold"))? - .parse::() - .map_err(|e| InvalidArgument::invalid_value("threshold", e)) -} +impl TryFrom> for BinaryClassification { + type Error = CreateError; -impl InvalidArgument { - fn not_found(name: impl Into) -> Self { - InvalidArgument { - name: name.into(), - reason: BadArgumentReason::NotFound, - } - } + fn try_from(args: Vec) -> Result { + let threshold = hotg_rune_proc_blocks::guest::parse::optional_arg( + &args, + "threshold", + )? + .unwrap_or(0.5); - fn invalid_value(name: impl Into, reason: impl Display) -> Self { - InvalidArgument { - name: name.into(), - reason: BadArgumentReason::InvalidValue(reason.to_string()), - } - } -} - -fn transform(input: &[f32], threshold: f32) -> Vec { - // let value = input.into(); - let mut label: u32 = 0; - if input > &[threshold] { - label = 1 + Ok(BinaryClassification { threshold }) } - let v: Vec = vec![label]; - v } #[cfg(test)] @@ -159,9 +80,12 @@ mod tests { #[test] fn test_binary_classification() { - let v = vec![0.7]; - let output = transform(&v, 0.5); - let should_be = vec![1]; - assert_eq!(output, should_be); + let transform = BinaryClassification { threshold: 0.5 }; + let inputs = vec![Tensor::new_1d("input", &[0.7_f32])]; + let should_be = vec![Tensor::new_1d("output", &[1_u32])]; + + let got = transform.run(inputs).unwrap(); + + assert_eq!(got, should_be); } } diff --git a/elastic_net/Cargo.toml b/elastic_net/Cargo.toml index 122186c3faf..39ed29a11c9 100644 --- a/elastic_net/Cargo.toml +++ b/elastic_net/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "elastic_net" version = "0.12.1" -edition = "2018" +edition = "2021" description = "an extension of linear regression that adds L1 and L2 regularization penalties to the loss function during training" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -10,11 +10,10 @@ description = "an extension of linear regression that adds L1 and L2 regularizat hotg-rune-proc-blocks = { path = "../support" } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } smartcore = { git = "https://github.com/hotg-ai/smartcore", branch = "development" } -getrandom = { version = "0.2.6", default-features = false, features = ["custom"] } [lib] crate-type = ["cdylib", "rlib"] [package.metadata.wapm] namespace = "hotg-ai" -abi = "none" \ No newline at end of file +abi = "none" diff --git a/elastic_net/src/lib.rs b/elastic_net/src/lib.rs index e0c5e9468be..7a8b63ead38 100644 --- a/elastic_net/src/lib.rs +++ b/elastic_net/src/lib.rs @@ -1,231 +1,164 @@ -// use linfa_logistic::LogisticRegression; +use hotg_rune_proc_blocks::{ + guest::{ + Argument, ElementTypeConstraint, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, + }, + ndarray::{Array1, ArrayView1, ArrayView2}, +}; use smartcore::{linalg::naive::dense_matrix::*, linear::elastic_net::*}; -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; -use hotg_rune_proc_blocks::{runtime_v1::*, BufferExt, SliceExt}; +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: Elastic, +} -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); +fn metadata() -> Metadata { + Metadata::new("Elastic Net", env!("CARGO_PKG_VERSION")) + .with_description( + "a linear approach for modelling the relationship between a scalar response and one or more explanatory variables", + ) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("regression") + .with_tag("linear modeling") + .with_tag("analytics") + .with_input(TensorMetadata::new("x_train")) + .with_input(TensorMetadata::new("y_train")) + .with_input(TensorMetadata::new("x_test")) + .with_output(TensorMetadata::new("y_test")) +} /// A proc block which can perform linear regression -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("Elastic Net", env!("CARGO_PKG_VERSION")); - metadata.set_description( - "a linear approach for modelling the relationship between a scalar response and one or more explanatory variables", - ); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("regression"); - metadata.add_tag("linear modeling"); - metadata.add_tag("analytics"); - - let x_train = TensorMetadata::new("x_train"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0, 0])); - x_train.add_hint(&hint); - metadata.add_input(&x_train); - - let y_train = TensorMetadata::new("y_train"); - let hint = - supported_shapes(&[ElementType::F64], DimensionsParam::Fixed(&[0])); - y_train.add_hint(&hint); - metadata.add_input(&y_train); - - let x_test = TensorMetadata::new("x_test"); - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0, 0])); - x_test.add_hint(&hint); - metadata.add_input(&x_test); - - let y_test = TensorMetadata::new("y_test"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0])); - y_test.add_hint(&hint); - metadata.add_output(&y_test); - - register_node(&metadata); +struct Elastic; + +impl ProcBlock for Elastic { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![ + TensorConstraint::new( + "x_train", + ElementTypeConstraint::F64, + vec![0, 0], + ), + TensorConstraint::new( + "y_train", + ElementTypeConstraint::F64, + vec![0], + ), + TensorConstraint::new( + "x_test", + ElementTypeConstraint::F64, + vec![0, 0], + ), + ], + outputs: vec![TensorConstraint::new( + "y_test", + ElementTypeConstraint::F64, + vec![0], + )], + } } - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - let element_type = match ctx.get_argument("element_type").as_deref() { - Some("f64") => ElementType::F64, - Some(_) => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })); - }, - None => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::NotFound, - })) - }, - }; - - ctx.add_input_tensor( - "x_train", - element_type, - DimensionsParam::Fixed(&[0, 0]), - ); - - ctx.add_input_tensor( - "y_train", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_input_tensor( - "x_test", - element_type, - DimensionsParam::Fixed(&[0, 0]), - ); - - ctx.add_output_tensor( - "y_test", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - Ok(()) - } + fn run(&self, inputs: Vec) -> Result, RunError> { + let x_train = Tensor::get_named(&inputs, "x_train")?.view_2d()?; + let y_train = Tensor::get_named(&inputs, "y_train")?.view_1d()?; + let x_test = Tensor::get_named(&inputs, "x_test")?.view_2d()?; - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let x_train = ctx.get_input_tensor("x_train").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "x_train".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let y_train = ctx.get_input_tensor("y_train").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "y_train".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let x_test = ctx.get_input_tensor("x_test").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "x_test".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let output = transform( - &x_train.buffer.elements(), - &x_train.dimensions, - &y_train.buffer.elements(), - &x_test.buffer.elements(), - &x_test.dimensions, - ); - - let y_test_dimension = [x_test.dimensions[0]]; - - ctx.set_output_tensor( - "y_test", - TensorParam { - element_type: ElementType::F64, - dimensions: &y_test_dimension, - buffer: &output.to_vec().as_bytes(), - }, - ); - - Ok(()) + let output = transform(x_train, y_train, x_test)?; + + Ok(vec![Tensor::new("y_test", &output)]) } } +impl From> for Elastic { + fn from(_: Vec) -> Self { Elastic } +} + fn transform( - x_train: &[f64], - x_train_dim: &[u32], - y_train: &[f64], - x_test: &[f64], - x_test_dim: &[u32], -) -> Vec { - // Iris data - let x_train = DenseMatrix::from_array( - x_train_dim[0] as usize, - x_train_dim[1] as usize, - x_train, - ); - - let model = - ElasticNet::fit(&x_train, &y_train.to_vec(), Default::default()) - .unwrap(); - - let x_test = DenseMatrix::from_array( - x_test_dim[0] as usize, - x_test_dim[1] as usize, - x_test, - ); - - let y_hat = model.predict(&x_test).unwrap(); - - y_hat + x_train: ArrayView2<'_, f64>, + y_train: ArrayView1<'_, f64>, + x_test: ArrayView2<'_, f64>, +) -> Result, RunError> { + // Note: we need to copy our values because elasticnet doesn't interoperate + // with ndarray and it can't use &[T] slices. + + let (rows, columns) = x_train.dim(); + let x_train: Vec = x_train.iter().map(|e| *e as f64).collect(); + + let x_train = + DenseMatrix::from_array(rows, columns, &x_train); + + let y_train: Vec<_> = y_train.to_vec(); + + let model = ElasticNet::fit(&x_train, &y_train, Default::default()) + .map_err(RunError::other)?; + + let (rows, columns) = x_test.dim(); + let x_test: Vec = x_test.iter().map(|e| *e as f64).collect(); + + let x_test = + DenseMatrix::from_array(rows, columns, &x_test); + + model + .predict(&x_test) + .map(Array1::from_vec) + .map_err(RunError::other) } -// comenting out test because it will in after deciaml places everytime so we -// can't generate a fixed y_pred. BUt I have tested in local and it's working. -// :) #[cfg(test)] -// mod tests { -// use super::*; - -// #[test] -// fn check_model() { -// let x_train = -// [234.289, 235.6, 159.0, 107.608, 1947., 60.323, -// 259.426, 232.5, 145.6, 108.632, 1948., 61.122, -// 258.054, 368.2, 161.6, 109.773, 1949., 60.171, -// 284.599, 335.1, 165.0, 110.929, 1950., 61.187, -// 328.975, 209.9, 309.9, 112.075, 1951., 63.221, -// 346.999, 193.2, 359.4, 113.270, 1952., 63.639, -// 365.385, 187.0, 354.7, 115.094, 1953., 64.989, -// 363.112, 357.8, 335.0, 116.219, 1954., 63.761, -// 397.469, 290.4, 304.8, 117.388, 1955., 66.019, -// 419.180, 282.2, 285.7, 118.734, 1956., 67.857, -// 442.769, 293.6, 279.8, 120.445, 1957., 68.169, -// 444.546, 468.1, 263.7, 121.950, 1958., 66.513, -// 482.704, 381.3, 255.2, 123.366, 1959., 68.655, -// 502.601, 393.1, 251.4, 125.368, 1960., 69.564, -// 518.173, 480.6, 257.2, 127.852, 1961., 69.331, -// 554.894, 400.7, 282.7, 130.081, 1962., 70.551]; - -// let y_train: Vec = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, -// 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]; - -// let dim: Vec = vec![16, 6]; - -// let y_pred = transform( -// &x_train, -// &dim, -// &y_train, -// &x_train, -// &dim, -// ); - -// println!("{:?}", y_pred); - -// let should_be = vec![112.7901174966222, 115.23028619478328, -// 104.00652847960953, 106.91893927853232, 101.89562519168146, -// 98.62225598974453, 100.3986322888735, 90.34439937146931, 99.44618079637769, -// 102.87598179071631, 103.51961064304874, 92.90632404596613, -// 101.22197835350744, 101.6134669106201, 95.40896231278623, 99.70071085566008]; - -// assert_eq!(y_pred, should_be); -// } -// } +#[cfg(test)] +mod tests { + use hotg_rune_proc_blocks::ndarray::{self, Array2}; + + use super::*; + + #[test] + fn check_model() { + let x_train: Array2 = ndarray::array![ + [234.289, 235.6, 159.0, 107.608, 1947., 60.323], + [259.426, 232.5, 145.6, 108.632, 1948., 61.122], + [258.054, 368.2, 161.6, 109.773, 1949., 60.171], + [284.599, 335.1, 165.0, 110.929, 1950., 61.187], + [328.975, 209.9, 309.9, 112.075, 1951., 63.221], + [346.999, 193.2, 359.4, 113.270, 1952., 63.639], + [365.385, 187.0, 354.7, 115.094, 1953., 64.989], + [363.112, 357.8, 335.0, 116.219, 1954., 63.761], + [397.469, 290.4, 304.8, 117.388, 1955., 66.019], + [419.180, 282.2, 285.7, 118.734, 1956., 67.857], + [442.769, 293.6, 279.8, 120.445, 1957., 68.169], + [444.546, 468.1, 263.7, 121.950, 1958., 66.513], + [482.704, 381.3, 255.2, 123.366, 1959., 68.655], + [502.601, 393.1, 251.4, 125.368, 1960., 69.564], + [518.173, 480.6, 257.2, 127.852, 1961., 69.331], + [554.894, 400.7, 282.7, 130.081, 1962., 70.551], + ]; + + let y_train: Array1 = ndarray::array![ + 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, + 108.4, 110.8, 112.6, 114.2, 115.7, 116.9 + ]; + + let y_pred = + transform(x_train.view(), y_train.view(), x_train.view()).unwrap(); + + let should_be = vec![ + 112.7901174966222, + 115.23028619478328, + 104.00652847960953, + 106.91893927853232, + 101.89562519168146, + 98.62225598974453, + 100.3986322888735, + 90.34439937146931, + 99.44618079637769, + 102.87598179071631, + 103.51961064304874, + 92.90632404596613, + 101.22197835350744, + 101.6134669106201, + 95.40896231278623, + 99.70071085566008, + ]; + + assert_eq!(y_pred.to_vec(), should_be); + } +} diff --git a/f1-score/Cargo.toml b/f1-score/Cargo.toml index 944dfa6f670..5c77797f8fd 100644 --- a/f1-score/Cargo.toml +++ b/f1-score/Cargo.toml @@ -10,11 +10,10 @@ description = "a proc-block used to calculate f1-score" hotg-rune-proc-blocks = { path = "../support" } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } smartcore = { git = "https://github.com/hotg-ai/smartcore", branch = "development" } -getrandom = { version = "0.2.6", default-features = false, features = ["custom"] } [lib] crate-type = ["cdylib", "rlib"] [package.metadata.wapm] namespace = "hotg-ai" -abi = "none" \ No newline at end of file +abi = "none" diff --git a/f1-score/src/lib.rs b/f1-score/src/lib.rs index 98f66ecd3f5..515e22726d4 100644 --- a/f1-score/src/lib.rs +++ b/f1-score/src/lib.rs @@ -1,216 +1,138 @@ +use hotg_rune_proc_blocks::guest::{ + Argument, ElementTypeConstraint, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, +}; use smartcore::metrics::{f1::F1, precision::Precision, recall::Recall}; -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; -use hotg_rune_proc_blocks::{runtime_v1::*, BufferExt, SliceExt}; +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: F1Score, +} -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); +fn metadata() -> Metadata { + Metadata::new("F-Score", env!("CARGO_PKG_VERSION")) + .with_description("for assessing prediction error") + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("metric") + .with_tag("analytics") + .with_input(TensorMetadata::new("y_true")) + .with_input(TensorMetadata::new("y_pred")) + .with_output(TensorMetadata::new("f1_score")) + .with_output(TensorMetadata::new("precision")) + .with_output(TensorMetadata::new("recall")) +} /// A proc-block used to calculate f1-score -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("F-Score", env!("CARGO_PKG_VERSION")); - metadata.set_description("for assessing prediction error"); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("metric"); - metadata.add_tag("analytics"); - - let y_true = TensorMetadata::new("y_true"); - let hint = - supported_shapes(&[ElementType::F64], DimensionsParam::Fixed(&[0])); - y_true.add_hint(&hint); - metadata.add_input(&y_true); - - let y_pred = TensorMetadata::new("y_pred"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0])); - y_pred.add_hint(&hint); - metadata.add_input(&y_pred); - - let f1 = TensorMetadata::new("f1_score"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[1])); - f1.add_hint(&hint); - metadata.add_input(&f1); - - let precision = TensorMetadata::new("precision"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[1])); - precision.add_hint(&hint); - metadata.add_output(&precision); - - let recall = TensorMetadata::new("recall"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[1])); - recall.add_hint(&hint); - metadata.add_output(&recall); - - register_node(&metadata); +struct F1Score; + +impl ProcBlock for F1Score { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![ + TensorConstraint::new( + "y_true", + ElementTypeConstraint::F64, + vec![0], + ), + TensorConstraint::new( + "y_pred", + ElementTypeConstraint::F64, + vec![0], + ), + ], + outputs: vec![ + TensorConstraint::new( + "f1_score", + ElementTypeConstraint::F64, + vec![1], + ), + TensorConstraint::new( + "precision", + ElementTypeConstraint::F64, + vec![1], + ), + TensorConstraint::new( + "recall", + ElementTypeConstraint::F64, + vec![1], + ), + ], + } } - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - let element_type = match ctx.get_argument("element_type").as_deref() { - Some("f64") => ElementType::F64, - Some(_) => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })); - }, - None => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::NotFound, - })) - }, - }; - - ctx.add_input_tensor( - "y_true", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_input_tensor( - "y_pred", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_output_tensor( - "f1_score", - element_type, - DimensionsParam::Fixed(&[1]), - ); - - ctx.add_output_tensor( - "precision", - element_type, - DimensionsParam::Fixed(&[1]), - ); - - ctx.add_output_tensor( - "recall", - element_type, - DimensionsParam::Fixed(&[1]), - ); - - Ok(()) - } + fn run(&self, inputs: Vec) -> Result, RunError> { + let y_true = Tensor::get_named(&inputs, "y_true")?.view_1d()?; + let y_pred = Tensor::get_named(&inputs, "y_pred")?.view_1d()?; + + let (f1_score, precision, recall) = + transform(y_true.to_vec(), y_pred.to_vec())?; - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let y_true = ctx.get_input_tensor("y_true").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "y_true".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let y_pred = ctx.get_input_tensor("y_pred").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "y_pred".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let metric = transform( - y_true.buffer.elements().to_vec(), - y_pred.buffer.elements().to_vec(), - ); - - let f1 = vec![metric.0]; - - ctx.set_output_tensor( - "f1_score", - TensorParam { - element_type: ElementType::F64, - dimensions: &[1 as u32], - buffer: &f1.as_bytes(), - }, - ); - - let precision = vec![metric.1]; - - ctx.set_output_tensor( - "precision", - TensorParam { - element_type: ElementType::F64, - dimensions: &[1 as u32], - buffer: &precision.as_bytes(), - }, - ); - - let recall = vec![metric.2]; - - ctx.set_output_tensor( - "recall", - TensorParam { - element_type: ElementType::F64, - dimensions: &[1 as u32], - buffer: &recall.as_bytes(), - }, - ); - - Ok(()) + Ok(vec![ + Tensor::new_1d("f1_score", &[f1_score]), + Tensor::new_1d("precision", &[precision]), + Tensor::new_1d("recall", &[recall]), + ]) } } -fn transform(y_true: Vec, y_pred: Vec) -> (f64, f64, f64) { +impl From> for F1Score { + fn from(_: Vec) -> Self { F1Score } +} + +fn transform( + y_true: Vec, + y_pred: Vec, +) -> Result<(f64, f64, f64), RunError> { + if y_true.len() != y_pred.len() { + let msg = format!( + "\"y_true\" and \"y_pred\" should have the same dimensions ({} != {})", + y_true.len(), y_pred.len(), + ); + return Err(RunError::other(msg)); + } + let f1 = F1 { beta: 1.0 }.get_score(&y_pred, &y_true); let precision = Precision {}.get_score(&y_pred, &y_true); let recall = Recall {}.get_score(&y_pred, &y_true); - (f1, precision, recall) + Ok((f1, precision, recall)) } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn check_f1() { - let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; - let y_true: Vec = vec![0., 1., 1., 0., 1., 0.]; - - let metric = transform(y_true, y_pred); + use hotg_rune_proc_blocks::ndarray; - assert_eq!(0.5714285714285715, metric.0); - } + use super::*; #[test] - fn check_precision() { - let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; - let y_true: Vec = vec![0., 1., 1., 0., 1., 0.]; - - let metric = transform(y_true, y_pred); - - assert_eq!(0.6666666666666666, metric.1); + fn known_inputs() { + let y_pred = ndarray::array![0_f64, 0., 1., 1., 1., 1.]; + let y_true = ndarray::array![0_f64, 1., 1., 0., 1., 0.]; + let inputs = vec![ + Tensor::new("y_pred", &y_pred), + Tensor::new("y_true", &y_true), + ]; + + let got = F1Score.run(inputs).unwrap(); + + let should_be = vec![ + Tensor::new_1d("f1_score", &[0.5714285714285715_f64]), + Tensor::new_1d("precision", &[0.6666666666666666_f64]), + Tensor::new_1d("recall", &[0.5_f64]), + ]; + assert_eq!(got, should_be); } #[test] - fn check_recall() { + fn check_f1() { let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; let y_true: Vec = vec![0., 1., 1., 0., 1., 0.]; - let metric = transform(y_true, y_pred); + let (f1, precision, recall) = transform(y_true, y_pred).unwrap(); - assert_eq!(0.5, metric.2); + assert_eq!(0.5714285714285715, f1); + assert_eq!(0.6666666666666666, precision); + assert_eq!(0.5, recall); } } diff --git a/fft/Cargo.toml b/fft/Cargo.toml index ea5d8a4d001..8ae5d5d4353 100644 --- a/fft/Cargo.toml +++ b/fft/Cargo.toml @@ -2,7 +2,7 @@ name = "fft" version = "0.12.0" authors = ["The Rune Developers "] -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "converts a signal from its original domain (often time or space) to a representation in the frequency domain." @@ -15,12 +15,9 @@ crate-type = ["cdylib", "rlib"] [dependencies] hotg-rune-proc-blocks = { path = "../support" } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen"} -hound = "3.4" -libm = "0.2.1" # See https://github.com/hotg-ai/rune/pull/107#issuecomment-825806000 mel = { git = "https://github.com/hotg-ai/mel", rev = "017694ee3143c11ea9b75ba6cd27fe7c8a69a867", default-features = false } nalgebra = { version = "0.29", default-features = false, features = ["alloc"] } -normalize = { path = "../normalize", version = "^0.12.0"} sonogram = {git = "https://github.com/hotg-ai/sonogram", rev = "009bc0cba44267d8a0807e43c9bb0712f0f334ea" } [dev-dependencies] diff --git a/fft/src/lib.rs b/fft/src/lib.rs index d96cb3fb1c3..7b62aa048bd 100644 --- a/fft/src/lib.rs +++ b/fft/src/lib.rs @@ -1,207 +1,112 @@ -use std::fmt::Display; - -use crate::proc_block_v1::*; - -use hotg_rune_proc_blocks::{ - runtime_v1::{self, *}, - BufferExt, SliceExt, -}; - #[cfg(test)] #[macro_use] extern crate pretty_assertions; -#[macro_use] -extern crate alloc; -use alloc::vec::Vec; +use hotg_rune_proc_blocks::{ + guest::{ + parse, Argument, ArgumentMetadata, ArgumentType, CreateError, + ElementType, Metadata, ProcBlock, RunError, Tensor, TensorConstraint, + TensorConstraints, TensorMetadata, + }, + ndarray::Array1, +}; use nalgebra::DMatrix; use sonogram::SpecOptionsBuilder; -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -struct ProcBlockV1; +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: Fft, +} -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("FFT", env!("CARGO_PKG_VERSION")); - metadata.set_description( +fn metadata() -> Metadata { + Metadata::new("FFT", env!("CARGO_PKG_VERSION")) + .with_description( "converts a signal from its original domain (often time or space) to a representation in the frequency domain.", - ); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("stft"); - metadata.add_tag("frequency domain"); - - let sampling_rate = ArgumentMetadata::new("Sample Rate"); - sampling_rate.set_description("Sampling Rate"); - sampling_rate.set_default_value("16000"); - let hint = - runtime_v1::supported_argument_type(ArgumentType::UnsignedInteger); - sampling_rate.add_hint(&hint); - metadata.add_argument(&sampling_rate); - - let bins = ArgumentMetadata::new("Bins"); - bins.set_description("Intervals between samples in frequency domain"); - bins.set_default_value("480"); - let hint = - runtime_v1::supported_argument_type(ArgumentType::UnsignedInteger); - bins.add_hint(&hint); - metadata.add_argument(&bins); - - let window_overlap = ArgumentMetadata::new("Window Overlap"); - window_overlap.set_description("Ratio of overlapped intervals."); - window_overlap.set_default_value("0.6666667"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Float); - window_overlap.add_hint(&hint); - metadata.add_argument(&window_overlap); - - let input = TensorMetadata::new("audio"); - input.set_description("A 1D tensor of `i16` samples."); - let hint = supported_shapes( - &[ElementType::I16], - DimensionsParam::Fixed(&[1, 0]), - ); - input.add_hint(&hint); - metadata.add_input(&input); - - let output = TensorMetadata::new("output"); - output.set_description("output signal after applying STFT"); - let hint = supported_shapes( - &[ElementType::U32], - DimensionsParam::Fixed(&[1, 0]), - ); - output.add_hint(&hint); - metadata.add_output(&output); - - register_node(&metadata); - } - - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; + ) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("stft") + .with_tag("frequency domain") + .with_argument( + ArgumentMetadata::new("sample_rate") + .with_description("Sampling rate") + .with_default_value("16000") + .with_hint(ArgumentType::UnsignedInteger) + ) + .with_argument( + ArgumentMetadata::new("bins") + .with_description("Intervals between samples in frequency domain") + .with_default_value("480") + .with_hint(ArgumentType::UnsignedInteger) + ) + .with_argument( + ArgumentMetadata::new("window_overlap") + .with_description("Ratio of overlapped intervals.") + .with_default_value("0.6666667") + .with_hint(ArgumentType::Float) + ) + .with_input( + TensorMetadata::new("audio") + .with_description("A 1D tensor containing PCM-encoded audio samples.") + ) + .with_output( + TensorMetadata::new("output") + .with_description("output signal after applying STFT") + ) +} - ctx.add_input_tensor( - "audio", - ElementType::I16, - DimensionsParam::Fixed(&[1, 0]), - ); - ctx.add_output_tensor( - "output", - ElementType::F32, - DimensionsParam::Fixed(&[1, 0]), - ); +#[derive(Debug, Copy, Clone, PartialEq)] +struct Fft { + sample_rate: u32, + bins: u32, + window_overlap: f32, +} - Ok(()) +impl ProcBlock for Fft { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::new( + "audio", + ElementType::I16, + [1, 0], + )], + outputs: vec![TensorConstraint::new( + "output", + ElementType::F32, + [1, 0], + )], + } } - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let sampling_rate = - get_u32_args("sampling_rate", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - let bins = get_u32_args("bins", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - let window_overlap = - get_f32_args("window_overlap", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "audio".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - check_input_dimensions(&dimensions); - - let input: Vec = buffer.elements().to_vec(); - - let output = match element_type { - ElementType::I16 => { - transform_inner(input, sampling_rate, bins, window_overlap) - }, - - other => { - return Err(KernelError::Other(format!( - "The FFT proc-block only accepts I16 tensors, found {:?}", - other, - ))) - }, - }; + fn run(&self, inputs: Vec) -> Result, RunError> { + let input = Tensor::get_named(&inputs, "audio")?.view_1d()?; - let output = match output { - Some(ix) => ix, - None => { - return Err(KernelError::Other( - "The input tensor was empty".to_string(), - )) - }, - }; - - let resulting_tensor = output.as_bytes(); - - ctx.set_output_tensor( - "output", - TensorParam { - element_type: ElementType::F32, - dimensions: &dimensions, - buffer: &resulting_tensor, - }, + let output = transform_inner( + input.to_vec(), + self.sample_rate, + self.bins, + self.window_overlap, ); - Ok(()) + Ok(vec![Tensor::new("output", &output)]) } } -fn check_input_dimensions(dimensions: &[u32]) { - assert_eq!( - (!(dimensions.len() == 2 && dimensions[0] == 1) - || !(dimensions.len() == 1)), - true, - "This proc block only supports 1D outputs (requested output: {:?})", - dimensions - ); -} - -fn get_u32_args( - name: &str, - get_argument: impl FnOnce(&str) -> Option, -) -> Result { - get_argument(name) - .ok_or_else(|| InvalidArgument::not_found(name))? - .parse::() - .map_err(|e| InvalidArgument::invalid_value(name, e)) -} - -fn get_f32_args( - name: &str, - get_argument: impl FnOnce(&str) -> Option, -) -> Result { - get_argument(name) - .ok_or_else(|| InvalidArgument::not_found(name))? - .parse::() - .map_err(|e| InvalidArgument::invalid_value(name, e)) -} +impl TryFrom> for Fft { + type Error = CreateError; -impl InvalidArgument { - fn not_found(name: impl Into) -> Self { - InvalidArgument { - name: name.into(), - reason: BadArgumentReason::NotFound, - } - } + fn try_from(args: Vec) -> Result { + let sample_rate = + parse::optional_arg(&args, "sample_rate")?.unwrap_or(16000); + let bins = parse::optional_arg(&args, "bins")?.unwrap_or(480); + let window_overlap = + parse::optional_arg(&args, "window_overlap")?.unwrap_or(0.6666667); - fn invalid_value(name: impl Into, reason: impl Display) -> Self { - InvalidArgument { - name: name.into(), - reason: BadArgumentReason::InvalidValue(reason.to_string()), - } + Ok(Fft { + sample_rate, + bins, + window_overlap, + }) } } @@ -210,7 +115,7 @@ fn transform_inner( sample_rate: u32, bins: u32, window_overlap: f32, -) -> Option<[u32; 1960]> { +) -> Array1 { // Build the spectrogram computation engine let mut spectrograph = SpecOptionsBuilder::new(49, 241) .set_window_fn(sonogram::hann_function) @@ -251,7 +156,7 @@ fn transform_inner( let power_spectrum_matrix: DMatrix = DMatrix::from_rows(&power_spectrum_vec); let mel_spectrum_matrix = &mel_filter_matrix * &power_spectrum_matrix; - let mel_spectrum_matrix = mel_spectrum_matrix.map(libm::sqrt); + let mel_spectrum_matrix = mel_spectrum_matrix.map(f64::sqrt); let min_value = mel_spectrum_matrix .data @@ -264,17 +169,13 @@ fn transform_inner( .iter() .fold(f64::NEG_INFINITY, |a, &b| a.max(b)); - let res: Vec = mel_spectrum_matrix + mel_spectrum_matrix .data .as_vec() .iter() .map(|freq| 65536.0 * (freq - min_value) / (max_value - min_value)) .map(|freq| freq as u32) - .collect(); - - let mut out = [0; 1960]; - out.copy_from_slice(&res[..1960]); - Some(out) + .collect() } #[cfg(test)] @@ -285,7 +186,7 @@ mod tests { fn it_works() { let input = [0; 16000].to_vec(); - let got = transform_inner(input, 16000, 480, 0.6666667).unwrap(); + let got = transform_inner(input, 16000, 480, 0.6666667); assert_eq!(got.len(), 1960); } diff --git a/image-normalization/Cargo.toml b/image-normalization/Cargo.toml index e62de660dfa..53d691c00f6 100644 --- a/image-normalization/Cargo.toml +++ b/image-normalization/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "image-normalization" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "A normalization routine takes the image matrix as input and fits their values to the range [0, 1] as f32's." @@ -18,4 +18,4 @@ wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } [package.metadata.wapm] namespace = "hotg-ai" -abi = "none" \ No newline at end of file +abi = "none" diff --git a/image_input/Cargo.toml b/image_decode/Cargo.toml similarity index 54% rename from image_input/Cargo.toml rename to image_decode/Cargo.toml index 42fb487b263..e0b7a1fd529 100644 --- a/image_input/Cargo.toml +++ b/image_decode/Cargo.toml @@ -1,8 +1,8 @@ [package] -name = "image_input" +name = "image_decode" version = "0.12.0" edition = "2021" -description = "Read an image from the environment." +description = "Decode an image from a well-known image format." [lib] crate-type = ["cdylib", "rlib"] @@ -11,6 +11,8 @@ crate-type = ["cdylib", "rlib"] [dependencies] hotg-rune-proc-blocks = { path = "../support" } +image = { version = "0.24.2", default-features = false, features = ["bmp", "dds", "dxt", "farbfeld", "gif", "hdr", "ico", "jpeg_rayon", "jpeg", "png", "pnm", "tga", "tiff", "webp"] } +strum = { version = "0.24.1", features = ["derive"] } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } [package.metadata.wapm] diff --git a/image_decode/src/image.png b/image_decode/src/image.png new file mode 100644 index 00000000000..896009a194f Binary files /dev/null and b/image_decode/src/image.png differ diff --git a/image_decode/src/lib.rs b/image_decode/src/lib.rs new file mode 100644 index 00000000000..af957a54915 --- /dev/null +++ b/image_decode/src/lib.rs @@ -0,0 +1,228 @@ +use hotg_rune_proc_blocks::{ + guest::{ + parse, Argument, ArgumentHint, ArgumentMetadata, CreateError, + ElementType, ElementTypeConstraint, InvalidInput, Metadata, + PrimitiveTensorElement, ProcBlock, RunError, Tensor, TensorConstraint, + TensorConstraints, TensorMetadata, + }, + ndarray::Array, +}; +use image::{ + flat::SampleLayout, imageops::FilterType, FlatSamples, ImageBuffer, Pixel, +}; +use strum::VariantNames; + +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: Image, +} + +fn metadata() -> Metadata { + Metadata::new("Image Decode", env!("CARGO_PKG_VERSION")) + .with_description(env!("CARGO_PKG_DESCRIPTION")) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("input") + .with_tag("image") + .with_argument( + ArgumentMetadata::new("width") + .with_description("The image width in pixels.") + .with_hint(ArgumentHint::NonNegativeNumber), + ) + .with_argument( + ArgumentMetadata::new("height") + .with_description("The image height in pixels.") + .with_hint(ArgumentHint::NonNegativeNumber), + ) + .with_argument( + ArgumentMetadata::new("pixel_format") + .with_description( + "The pixel format to use for the loaded image.", + ) + .with_default_value(PixelFormat::RGB8.to_string()) + .with_hint( + ArgumentHint::OneOf( + PixelFormat::VARIANTS + .iter() + .map(|s| s.to_string()) + .collect(), + ), + ) + .with_hint(ArgumentHint::NonNegativeNumber), + ) + .with_input( + TensorMetadata::new("file") + .with_description("A file containing the image"), + ) + .with_output(TensorMetadata::new("image")) +} + +#[derive(Debug, Clone, PartialEq)] +struct Image { + width: usize, + height: usize, + pixel_format: PixelFormat, +} + +impl ProcBlock for Image { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::new( + "file", + ElementTypeConstraint::U8, + vec![0], + )], + outputs: vec![TensorConstraint::new( + "image", + self.pixel_format.element_type(), + self.pixel_format + .dimensions(self.width, self.height) + .to_vec(), + )], + } + } + + fn run(&self, inputs: Vec) -> Result, RunError> { + let tensor = Tensor::get_named(&inputs, "file")?; + let view = tensor.view_1d::()?; + let bytes = view.as_slice().ok_or_else(|| { + RunError::other( + "Unable to view the file tensor as a contiguous slice", + ) + })?; + + let img = image::load_from_memory(bytes) + .map_err(|e| InvalidInput::other("file", e))?; + + let resized = img.resize_exact( + self.width as u32, + self.height as u32, + FilterType::Nearest, + ); + + let formatted = match self.pixel_format { + PixelFormat::RGB8 => to_tensor(resized.into_rgb8()), + PixelFormat::RGBA8 => to_tensor(resized.into_rgba8()), + }; + + Ok(vec![formatted]) + } +} + +fn to_tensor

(img: ImageBuffer>) -> Tensor +where + P: Pixel, + P::Subpixel: PrimitiveTensorElement, +{ + let FlatSamples { + samples, + layout: + SampleLayout { + channels, + width, + height, + .. + }, + .. + } = img.into_flat_samples(); + + let array = Array::from_shape_vec( + (width as usize, height as usize, channels as usize), + samples, + ) + .expect("Image dimensions should always be well-formed"); + + Tensor::new("image", &array) +} + +impl TryFrom> for Image { + type Error = CreateError; + + fn try_from(args: Vec) -> Result { + let pixel_format = parse::optional_arg(&args, "pixel_format")? + .unwrap_or(PixelFormat::RGB8); + let width = parse::required_arg(&args, "width")?; + let height = parse::required_arg(&args, "height")?; + + Ok(Image { + pixel_format, + height, + width, + }) + } +} + +#[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + Hash, + strum::EnumString, + strum::EnumVariantNames, + strum::Display, +)] +enum PixelFormat { + #[strum(serialize = "rgb8")] + RGB8, + #[strum(serialize = "rgba8")] + RGBA8, +} + +impl PixelFormat { + fn dimensions(self, width: usize, height: usize) -> [u32; 3] { + match self { + PixelFormat::RGBA8 | PixelFormat::RGB8 => { + [width as u32, height as u32, self.channels()] + }, + } + } + + fn channels(self) -> u32 { + match self { + PixelFormat::RGB8 => 3, + PixelFormat::RGBA8 => 4, + } + } + + fn element_type(self) -> ElementType { + match self { + PixelFormat::RGB8 | PixelFormat::RGBA8 => ElementType::U8, + } + } +} + +#[cfg(test)] +mod tests { + use hotg_rune_proc_blocks::ndarray; + + use super::*; + + #[test] + fn load_a_known_file() { + let bytes = include_bytes!("image.png"); + // [black, red] + // [green, blue] + let tensor = Tensor::new_1d("file", bytes); + let proc_block = Image { + height: 2, + width: 2, + pixel_format: PixelFormat::RGB8, + }; + + let got = proc_block.run(vec![tensor]).unwrap(); + + assert_eq!(got.len(), 1); + let image = Tensor::get_named(&got, "image") + .unwrap() + .view_3d::() + .unwrap(); + + let should_be = ndarray::array![ + [[255_u8, 0, 0], [0, 0, 0]], + [[0, 255, 0], [0, 0, 255],] + ]; + assert_eq!(image, should_be); + } +} diff --git a/image_input/src/lib.rs b/image_input/src/lib.rs deleted file mode 100644 index c09771c7ae4..00000000000 --- a/image_input/src/lib.rs +++ /dev/null @@ -1,199 +0,0 @@ -use std::{ - error::Error, - fmt::{self, Display, Formatter}, - str::FromStr, -}; - -use crate::{ - proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, - }, - runtime_v1::*, -}; -use hotg_rune_proc_blocks::{prelude::*, runtime_v1}; - -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("Image Input", env!("CARGO_PKG_VERSION")); - metadata.set_description(env!("CARGO_PKG_DESCRIPTION")); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("input"); - metadata.add_tag("image"); - - let width = ArgumentMetadata::new("width"); - width.set_description("The image width in pixels."); - let hint = runtime_v1::non_negative_number(); - width.add_hint(&hint); - metadata.add_argument(&width); - - let height = ArgumentMetadata::new("height"); - height.set_description("The image height in pixels."); - let hint = runtime_v1::non_negative_number(); - height.add_hint(&hint); - metadata.add_argument(&height); - - let pixel_format = ArgumentMetadata::new("pixel_format"); - pixel_format.set_description("The pixel format."); - let hint = runtime_v1::non_negative_number(); - pixel_format.add_hint(&hint); - metadata.add_argument(&pixel_format); - - let output = TensorMetadata::new("image"); - let hint = supported_shapes( - &[ElementType::U8, ElementType::F32], - DimensionsParam::Fixed(&[0, 0, 0, 0]), - ); - output.add_hint(&hint); - metadata.add_output(&output); - - register_node(&metadata); - } - - fn graph(id: String) -> Result<(), GraphError> { - let ctx = - GraphContext::for_node(&id).ok_or(GraphError::MissingContext)?; - - let width: u32 = ctx.parse_argument("width")?; - let height: u32 = ctx.parse_argument("height")?; - let pixel_format: PixelFormat = ctx.parse_argument("pixel_format")?; - - ctx.add_input_tensor( - "image", - pixel_format.element_type(), - DimensionsParam::Fixed(&[1, 0, 0, 3]), - ); - - ctx.add_output_tensor( - "output", - pixel_format.element_type(), - DimensionsParam::Fixed(&[ - 1, - width, - height, - pixel_format.channels(), - ]), - ); - - Ok(()) - } - - fn kernel(id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&id).ok_or_else(|| { - KernelError::Other("Unable to get the kernel context".to_string()) - })?; - - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - // TODO: use the width, height, and pixel format to resize the image for - // now, we're just going to copy it out as-is and hope for the best. - let _width: u32 = ctx.parse_argument("width")?; - let _height: u32 = ctx.parse_argument("height")?; - let _pixel_format: PixelFormat = ctx.parse_argument("pixel_format")?; - - ctx.set_output_tensor( - "output", - TensorParam { - element_type, - dimensions: &dimensions, - buffer: &buffer, - }, - ); - - Ok(()) - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -enum PixelFormat { - RGB8, -} - -impl PixelFormat { - fn channels(self) -> u32 { - match self { - PixelFormat::RGB8 => 3, - } - } - - fn element_type(self) -> ElementType { - match self { - PixelFormat::RGB8 => ElementType::U8, - } - } -} - -impl FromStr for PixelFormat { - type Err = UnknownPixelFormat; - - fn from_str(s: &str) -> Result { - match s { - "rgb" | "rgb8" => Ok(PixelFormat::RGB8), - _ => Err(UnknownPixelFormat), - } - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] -pub struct UnknownPixelFormat; - -impl Display for UnknownPixelFormat { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - "Unknown pixel format".fmt(f) - } -} - -impl Error for UnknownPixelFormat {} - -impl ContextErrorExt for GraphError { - type InvalidArgument = InvalidArgument; - - fn invalid_argument(inner: InvalidArgument) -> Self { - GraphError::InvalidArgument(inner) - } -} - -impl ContextErrorExt for KernelError { - type InvalidArgument = InvalidArgument; - - fn invalid_argument(inner: InvalidArgument) -> Self { - KernelError::InvalidArgument(inner) - } -} - -impl InvalidArgumentExt for InvalidArgument { - fn other(name: &str, msg: impl std::fmt::Display) -> Self { - InvalidArgument { - name: name.to_string(), - reason: BadArgumentReason::Other(msg.to_string()), - } - } - - fn invalid_value(name: &str, error: impl std::fmt::Display) -> Self { - InvalidArgument { - name: name.to_string(), - reason: BadArgumentReason::InvalidValue(error.to_string()), - } - } - - fn not_found(name: &str) -> Self { - InvalidArgument { - name: name.to_string(), - reason: BadArgumentReason::NotFound, - } - } -} diff --git a/label/Cargo.toml b/label/Cargo.toml index 038a76f5fc3..78ddae8759c 100644 --- a/label/Cargo.toml +++ b/label/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "label" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "Using a wordlist, retrieve the label that corresponds to each element in a tensor." diff --git a/label/src/lib.rs b/label/src/lib.rs index 1c999bfa5c3..0f0f9800c6b 100644 --- a/label/src/lib.rs +++ b/label/src/lib.rs @@ -1,157 +1,80 @@ -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; -use hotg_rune_proc_blocks::{ - ndarray::ArrayViewD, - runtime_v1::{ - self, ArgumentMetadata, ArgumentType, DimensionsParam, ElementType, - GraphContext, KernelContext, Metadata, TensorMetadata, TensorParam, - TensorResult, - }, - BufferExt, +use hotg_rune_proc_blocks::guest::{ + parse, Argument, ArgumentMetadata, ArgumentType, CreateError, Dimensions, + ElementType, Metadata, ProcBlock, RunError, Tensor, TensorConstraint, + TensorConstraints, TensorMetadata, }; use line_span::LineSpans; -use std::{fmt::Debug, ops::Range}; - -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); +use std::{fmt::Debug, ops::Range, str::FromStr}; -struct ProcBlockV1; +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: Labels, +} -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("Label", env!("CARGO_PKG_VERSION")); - metadata.set_description( +fn metadata() -> Metadata { + Metadata::new("Label", env!("CARGO_PKG_VERSION")) + .with_description( "Using a wordlist, retrieve the label that corresponds to each element in a tensor.", - ); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("classify"); - - let labels = ArgumentMetadata::new("wordlist"); - let hint = - runtime_v1::supported_argument_type(ArgumentType::LongString); - labels.add_hint(&hint); - metadata.add_argument(&labels); - - let fallback = ArgumentMetadata::new("fallback"); - fallback.set_default_value(""); - fallback - .set_description("The label to use if an index is out of bounds"); - let hint = runtime_v1::supported_argument_type(ArgumentType::String); - fallback.add_hint(&hint); - metadata.add_argument(&fallback); - - let indices = TensorMetadata::new("indices"); - indices.set_description("Indices for labels in the wordlist."); - let hint = runtime_v1::supported_shapes( - &[ElementType::U32], - DimensionsParam::Dynamic, - ); - indices.add_hint(&hint); - metadata.add_input(&indices); - - let output = TensorMetadata::new("labels"); - output.set_description("The corresponding labels."); - let hint = runtime_v1::supported_shapes( - &[ElementType::Utf8], - DimensionsParam::Dynamic, - ); - output.add_hint(&hint); - metadata.add_output(&output); - - runtime_v1::register_node(&metadata); - } - - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or_else(|| GraphError::MissingContext)?; - - let _ = get_wordlist(|n| ctx.get_argument(n)) - .map_err(GraphError::InvalidArgument)?; + ) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("classify") + .with_argument(ArgumentMetadata::new("wordlist") + .with_hint(ArgumentType::LongString) + ) + .with_argument(ArgumentMetadata::new("fallback") + .with_hint(ArgumentType::String) + .with_default_value("") + ) + .with_input(TensorMetadata::new("indices").with_description("Indices for labels in the wordlist.")) + .with_output(TensorMetadata::new("labels").with_description("The corresponding labels.")) +} - ctx.add_input_tensor( - "indices", - ElementType::U32, - DimensionsParam::Dynamic, - ); - ctx.add_output_tensor( - "labels", - ElementType::Utf8, - DimensionsParam::Dynamic, - ); +#[derive(Debug, Clone, PartialEq)] +struct Labels { + fallback: String, + wordlist: Lines, +} - Ok(()) +impl ProcBlock for Labels { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::new( + "indices", + ElementType::U32, + Dimensions::Dynamic, + )], + outputs: vec![TensorConstraint::new( + "labels", + ElementType::Utf8, + Dimensions::Dynamic, + )], + } } - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or_else(|| KernelError::MissingContext)?; - - let wordlist = get_wordlist(|n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let TensorResult { - buffer, - dimensions, - element_type, - } = ctx.get_input_tensor("indices").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "indices".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let indices = match element_type { - ElementType::U32 => { - buffer.view::(&dimensions).map_err(|e| { - KernelError::InvalidInput(InvalidInput { - name: "indices".to_string(), - reason: BadInputReason::InvalidValue(e.to_string()), - }) - })? - }, - _ => todo!(), - }; + fn run(&self, inputs: Vec) -> Result, RunError> { + let indices = Tensor::get_named(&inputs, "indices")?.view::()?; - let fallback = ctx.get_argument("fallback").unwrap_or_default(); - let serialized_labels = label(indices, &wordlist, &fallback); + let labels = indices.mapv(|ix| { + self.wordlist + .get(ix as usize) + .unwrap_or(self.fallback.as_str()) + }); - ctx.set_output_tensor( - "labels", - TensorParam { - element_type: ElementType::Utf8, - dimensions: &dimensions, - buffer: &serialized_labels, - }, - ); - - Ok(()) + Ok(vec![Tensor::from_strings("labels", &labels)]) } } -fn label( - indices: ArrayViewD<'_, u32>, - wordlist: &Lines, - fallback: &str, -) -> Vec { - let labels = indices.map(|&index| match wordlist.get(index as usize) { - Some(label) => label, - None => fallback, - }); - - hotg_rune_proc_blocks::string_tensor_from_ndarray(&labels) -} +impl TryFrom> for Labels { + type Error = CreateError; -fn get_wordlist( - get_argument: impl FnOnce(&str) -> Option, -) -> Result { - let wordlist = get_argument("wordlist").ok_or_else(|| InvalidArgument { - name: "wordlist".to_string(), - reason: BadArgumentReason::NotFound, - })?; + fn try_from(args: Vec) -> Result { + let wordlist = parse::required_arg(&args, "wordlist")?; + let fallback = + parse::optional_arg(&args, "fallback")?.unwrap_or_default(); - Ok(Lines::new(wordlist)) + Ok(Labels { wordlist, fallback }) + } } #[derive(Debug, Default, Clone, PartialEq)] @@ -173,6 +96,14 @@ impl Lines { } } +impl FromStr for Lines { + type Err = std::convert::Infallible; + + fn from_str(s: &str) -> Result { + Ok(Lines::new(s.to_string())) + } +} + #[cfg(test)] mod tests { use super::*; @@ -183,12 +114,20 @@ mod tests { let wordlist = "zero\none\ntwo\nthree"; let wordlist = Lines::new(wordlist.to_string()); let indices = ndarray::aview1(&[2_u32, 0, 1]); + let proc_block = Labels { + wordlist, + fallback: "...".to_string(), + }; - let serialized = label(indices.into_dyn(), &wordlist, "..."); + let got = proc_block + .run(vec![Tensor::new("indices", &indices)]) + .unwrap(); - let expected = ndarray::arr1(&["two", "zero", "one"]).into_dyn(); - let got = serialized.string_view(&[3]).unwrap(); - assert_eq!(got, expected); + let should_be = vec![Tensor::from_strings( + "labels", + &ndarray::arr1(&["two", "zero", "one"]), + )]; + assert_eq!(got, should_be); } #[test] @@ -196,11 +135,17 @@ mod tests { let wordlist = "zero\none\ntwo\nthree"; let wordlist = Lines::new(wordlist.to_string()); let indices = ndarray::aview1(&[100_u32]); + let proc_block = Labels { + wordlist, + fallback: "...".to_string(), + }; - let serialized = label(indices.into_dyn(), &wordlist, "UNKNOWN"); + let got = proc_block + .run(vec![Tensor::new("indices", &indices)]) + .unwrap(); - let expected = ndarray::arr1(&["UNKNOWN"]).into_dyn(); - let got = serialized.string_view(&[1]).unwrap(); - assert_eq!(got, expected); + let should_be = + vec![Tensor::from_strings("labels", &ndarray::arr1(&["..."]))]; + assert_eq!(got, should_be); } } diff --git a/linear_regression/Cargo.toml b/linear_regression/Cargo.toml index 809414dab4d..2df851d0ca2 100644 --- a/linear_regression/Cargo.toml +++ b/linear_regression/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "linear_regression" version = "0.12.1" -edition = "2018" +edition = "2021" description = " a linear approach for modelling the relationship between a scalar response and one or more explanatory variables." # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -10,11 +10,10 @@ description = " a linear approach for modelling the relationship between a scala hotg-rune-proc-blocks = { path = "../support" } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } smartcore = { git = "https://github.com/hotg-ai/smartcore", branch = "development" } -getrandom = { version = "0.2.6", default-features = false, features = ["custom"] } [lib] crate-type = ["cdylib", "rlib"] [package.metadata.wapm] namespace = "hotg-ai" -abi = "none" \ No newline at end of file +abi = "none" diff --git a/logistic_regression/Cargo.toml b/logistic_regression/Cargo.toml index d9f06cb1b6f..69d630bed35 100644 --- a/logistic_regression/Cargo.toml +++ b/logistic_regression/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "logistic_regression" version = "0.12.5" -edition = "2018" +edition = "2021" description = "a statistical model that models the probability of one event taking place by having the log-odds for the event be a linear combination of one or more independent variables." # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -9,7 +9,6 @@ description = "a statistical model that models the probability of one event taki hotg-rune-proc-blocks = { path = "../support" } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } smartcore = { git = "https://github.com/hotg-ai/smartcore", branch = "development" } -getrandom = { version = "0.2.6", default-features = false, features = ["custom"] } [lib] crate-type = ["cdylib", "rlib"] diff --git a/logistic_regression/src/lib.rs b/logistic_regression/src/lib.rs index bcf94cb53bd..a9930747fbb 100644 --- a/logistic_regression/src/lib.rs +++ b/logistic_regression/src/lib.rs @@ -1,252 +1,142 @@ -// use linfa_logistic::LogisticRegression; +use hotg_rune_proc_blocks::{ + guest::{ + Argument, ElementTypeConstraint, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, + }, + ndarray::{Array1, Array2, ArrayView1, ArrayView2}, +}; use smartcore::{ linalg::naive::dense_matrix::*, linear::logistic_regression::*, }; -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; -use hotg_rune_proc_blocks::{ndarray, runtime_v1::*, BufferExt, SliceExt}; - -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -/// A proc block which can perform linear regression -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = - Metadata::new("Logistic Regression", env!("CARGO_PKG_VERSION")); - metadata.set_description( - "a linear approach for modelling the relationship between a scalar response and one or more explanatory variables", - ); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("regression"); - metadata.add_tag("linear modeling"); - metadata.add_tag("analytics"); - - let element_type = ArgumentMetadata::new("element_type"); - element_type - .set_description("The type of tensor this proc-block will accept"); - element_type.set_default_value("f64"); - element_type.add_hint(&interpret_as_string_in_enum(&[ - "u8", "i8", "u16", "i16", "u32", "i32", "f32", "u64", "i64", "f64", - ])); - metadata.add_argument(&element_type); - - let x_train = TensorMetadata::new("x_train"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0, 0])); - x_train.add_hint(&hint); - metadata.add_input(&x_train); - - let y_train = TensorMetadata::new("y_train"); - let hint = - supported_shapes(&[ElementType::F64], DimensionsParam::Fixed(&[0])); - y_train.add_hint(&hint); - metadata.add_input(&y_train); - - let x_test = TensorMetadata::new("x_test"); - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0, 0])); - x_test.add_hint(&hint); - metadata.add_input(&x_test); - - let y_test = TensorMetadata::new("y_test"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0])); - y_test.add_hint(&hint); - metadata.add_output(&y_test); - - register_node(&metadata); - } +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: Logistic, +} + +fn metadata() -> Metadata { + Metadata::new("Logistic Regression", env!("CARGO_PKG_VERSION")) + .with_description( + "a statistical model that models the probability of one event taking place by having the log-odds for the event be a linear combination of one or more independent variables.", + ) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("classification") + .with_tag("linear modeling") + .with_tag("analytics") + .with_input(TensorMetadata::new("x_train")) + .with_input(TensorMetadata::new("y_train")) + .with_input(TensorMetadata::new("x_test")) + .with_output(TensorMetadata::new("y_test")) +} - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - let element_type = match ctx.get_argument("element_type").as_deref() { - Some("f64") => ElementType::F64, - Some(_) => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })); - }, - None => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::NotFound, - })) - }, - }; - - ctx.add_input_tensor( - "x_train", - element_type, - DimensionsParam::Fixed(&[0, 0]), - ); - - ctx.add_input_tensor( - "y_train", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_input_tensor( - "x_test", - element_type, - DimensionsParam::Fixed(&[0, 0]), - ); - - ctx.add_output_tensor( - "y_test", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - Ok(()) +struct Logistic; + +impl ProcBlock for Logistic { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![ + TensorConstraint::new( + "x_train", + ElementTypeConstraint::F64, + vec![0, 0], + ), + TensorConstraint::new( + "y_train", + ElementTypeConstraint::F64, + vec![0], + ), + TensorConstraint::new( + "x_test", + ElementTypeConstraint::F64, + vec![0, 0], + ), + ], + outputs: vec![TensorConstraint::new( + "y_test", + ElementTypeConstraint::F64, + vec![0], + )], + } } - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let x_train = ctx.get_input_tensor("x_train").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "x_train".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - let _xtrain: ndarray::ArrayView2 = x_train - .buffer - .view(&x_train.dimensions) - .and_then(|t| t.into_dimensionality()) - .map_err(|e| { - KernelError::InvalidInput(InvalidInput { - name: "x_train".to_string(), - reason: BadInputReason::Other(e.to_string()), - }) - })?; - - let y_train = ctx.get_input_tensor("y_train").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "y_train".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - let _ytrain: ndarray::ArrayView1 = y_train - .buffer - .view(&y_train.dimensions) - .and_then(|t| t.into_dimensionality()) - .map_err(|e| { - KernelError::InvalidInput(InvalidInput { - name: "y_train".to_string(), - reason: BadInputReason::Other(e.to_string()), - }) - })?; - - let x_test = ctx.get_input_tensor("x_test").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "x_test".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - let _xtest: ndarray::ArrayView2 = x_test - .buffer - .view(&x_test.dimensions) - .and_then(|t| t.into_dimensionality()) - .map_err(|e| { - KernelError::InvalidInput(InvalidInput { - name: "x_test".to_string(), - reason: BadInputReason::Other(e.to_string()), - }) - })?; - - let output = transform( - &x_train.buffer.elements(), - &x_train.dimensions, - &y_train.buffer.elements(), - &x_test.buffer.elements(), - &x_test.dimensions, - )?; - - let y_test_dimension = [x_test.dimensions[0]]; - - ctx.set_output_tensor( - "y_test", - TensorParam { - element_type: ElementType::F64, - dimensions: &y_test_dimension, - buffer: &output.to_vec().as_bytes(), - }, - ); - - Ok(()) + fn run(&self, inputs: Vec) -> Result, RunError> { + let x_train = Tensor::get_named(&inputs, "x_train")?.view_2d()?; + let y_train = Tensor::get_named(&inputs, "y_train")?.view_1d()?; + let x_test = Tensor::get_named(&inputs, "x_test")?.view_2d()?; + + let output = transform(x_train, y_train, x_test)?; + + Ok(vec![Tensor::new("y_test", &output)]) } } +impl From> for Logistic { + fn from(_: Vec) -> Self { Logistic } +} + fn transform( - x_train: &[f64], - x_train_dim: &[u32], - y_train: &[f64], - x_test: &[f64], - x_test_dim: &[u32], -) -> Result, KernelError> { - // Iris data - let x_train = DenseMatrix::from_array( - x_train_dim[0] as usize, - x_train_dim[1] as usize, - x_train, - ); - - let lr = LogisticRegression::fit( - &x_train, - &y_train.to_vec(), - Default::default(), - ) - .map_err(|e| KernelError::Other(e.to_string()))?; - - let x_test = DenseMatrix::from_array( - x_test_dim[0] as usize, - x_test_dim[1] as usize, - x_test, - ); - - lr.predict(&x_test) - .map_err(|e| KernelError::Other(e.to_string())) + x_train: ArrayView2<'_, f64>, + y_train: ArrayView1<'_, f64>, + x_test: ArrayView2<'_, f64>, +) -> Result, RunError> { + let (rows, columns) = x_train.dim(); + let x_train: Vec = x_train.t().iter().copied().collect(); + let x_train = DenseMatrix::new(rows, columns, x_train); + + let y_train: Vec<_> = y_train.to_vec(); + + let model = LogisticRegression::fit(&x_train, &y_train, Default::default()) + .map_err(RunError::other)?; + + let (rows, columns) = x_test.dim(); + let x_test: Vec = x_test.t().iter().copied().collect(); + let x_test = DenseMatrix::new(rows, columns, x_test); + + model + .predict(&x_test) + .map(Array1::from_vec) + .map_err(RunError::other) } #[cfg(test)] mod tests { + use hotg_rune_proc_blocks::ndarray::array; + use super::*; #[test] fn check_model() { - let x_train = vec![ - 5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 4.7, 3.2, 1.3, 0.2, 4.6, - 3.1, 1.5, 0.2, 5.0, 3.6, 1.4, 0.2, 5.4, 3.9, 1.7, 0.4, 4.6, 3.4, - 1.4, 0.3, 5.0, 3.4, 1.5, 0.2, 4.4, 2.9, 1.4, 0.2, 4.9, 3.1, 1.5, - 0.1, 7.0, 3.2, 4.7, 1.4, 6.4, 3.2, 4.5, 1.5, 6.9, 3.1, 4.9, 1.5, - 5.5, 2.3, 4.0, 1.3, 6.5, 2.8, 4.6, 1.5, 5.7, 2.8, 4.5, 1.3, 6.3, - 3.3, 4.7, 1.6, 4.9, 2.4, 3.3, 1.0, 6.6, 2.9, 4.6, 1.3, 5.2, 2.7, - 3.9, 1.4, + let x_train: Array2 = array![ + [5.1, 3.5, 1.4, 0.2], + [4.9, 3.0, 1.4, 0.2], + [4.7, 3.2, 1.3, 0.2], + [4.6, 3.1, 1.5, 0.2], + [5.0, 3.6, 1.4, 0.2], + [5.4, 3.9, 1.7, 0.4], + [4.6, 3.4, 1.4, 0.3], + [5.0, 3.4, 1.5, 0.2], + [4.4, 2.9, 1.4, 0.2], + [4.9, 3.1, 1.5, 0.1], + [7.0, 3.2, 4.7, 1.4], + [6.4, 3.2, 4.5, 1.5], + [6.9, 3.1, 4.9, 1.5], + [5.5, 2.3, 4.0, 1.3], + [6.5, 2.8, 4.6, 1.5], + [5.7, 2.8, 4.5, 1.3], + [6.3, 3.3, 4.7, 1.6], + [4.9, 2.4, 3.3, 1.0], + [6.6, 2.9, 4.6, 1.3], + [5.2, 2.7, 3.9, 1.4], ]; - let y_train: Vec = vec![ + let y_train: Array1 = array![ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., ]; - let dim: Vec = vec![20, 4]; - - let y_pred = transform(&x_train, &dim, &y_train, &x_train, &dim); + let y_pred = + transform(x_train.view(), y_train.view(), x_train.view()).unwrap(); - assert_eq!(y_pred.unwrap(), y_train); + assert_eq!(y_pred, y_train); } } diff --git a/metric/Cargo.toml b/metric/Cargo.toml index 0ce05af1d2a..608f444eaa6 100644 --- a/metric/Cargo.toml +++ b/metric/Cargo.toml @@ -10,11 +10,10 @@ description = "for assessing model evaluation and performace of model" hotg-rune-proc-blocks = { path = "../support" } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } smartcore = { git = "https://github.com/hotg-ai/smartcore", branch = "development" } -getrandom = { version = "0.2.6", default-features = false, features = ["custom"] } [lib] crate-type = ["cdylib", "rlib"] [package.metadata.wapm] namespace = "hotg-ai" -abi = "none" \ No newline at end of file +abi = "none" diff --git a/modulo/Cargo.toml b/modulo/Cargo.toml index ff8bd91727e..5050cacb7d6 100644 --- a/modulo/Cargo.toml +++ b/modulo/Cargo.toml @@ -2,7 +2,7 @@ name = "modulo" version = "0.12.0" authors = ["The Rune Developers "] -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "Apply the modulus operator to each element in a tensor." diff --git a/modulo/src/lib.rs b/modulo/src/lib.rs index f7445f72b9a..2af00dd4552 100644 --- a/modulo/src/lib.rs +++ b/modulo/src/lib.rs @@ -1,225 +1,155 @@ #![allow(dead_code)] -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -use std::fmt::Display; - -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; use hotg_rune_proc_blocks::{ - runtime_v1::{ - self, ArgumentMetadata, DimensionsParam, ElementType, GraphContext, - KernelContext, Metadata, TensorMetadata, TensorParam, TensorResult, + guest::{ + Argument, ArgumentHint, ArgumentMetadata, ArgumentType, CreateError, + Dimensions, InvalidInput, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, }, - BufferExt, + ndarray::ArrayViewMutD, }; use num_traits::{FromPrimitive, ToPrimitive}; -pub struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = - Metadata::new(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")); - metadata.set_description(env!("CARGO_PKG_DESCRIPTION")); - - let modulo = ArgumentMetadata::new("modulo"); - modulo.add_hint(&runtime_v1::non_negative_number()); - metadata.add_argument(&modulo); - - let element_type = ArgumentMetadata::new("element_type"); - element_type - .set_description("The type of tensor this proc-block will accept"); - element_type.set_default_value("f64"); - element_type.add_hint(&runtime_v1::interpret_as_string_in_enum(&[ - "u8", "i8", "u16", "i16", "u32", "i32", "f32", "u64", "i64", "f64", - ])); - metadata.add_argument(&element_type); - - let input = TensorMetadata::new("input"); - metadata.add_input(&input); +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: Modulo, +} - let output = TensorMetadata::new("output"); - metadata.add_output(&output); +fn metadata() -> Metadata { + Metadata::new("Modulo", env!("CARGO_PKG_VERSION")) + .with_description(env!("CARGO_PKG_DESCRIPTION")) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_argument( + ArgumentMetadata::new("modulus") + .with_description("The modulus operand") + .with_hint(ArgumentHint::ArgumentType(ArgumentType::Float)), + ) + .with_input(TensorMetadata::new("input")) + .with_output(TensorMetadata::new("output")) +} - runtime_v1::register_node(&metadata); - } +#[derive(Debug, Clone, PartialEq)] +pub struct Modulo { + modulus: f64, +} - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id).ok_or_else(|| { - GraphError::Other("Unable to load the graph context".to_string()) - })?; - - // make sure the modulus is valid - let _ = get_modulus(|n| ctx.get_argument(n)) - .map_err(GraphError::InvalidArgument)?; - - let element_type = match ctx.get_argument("element_type").as_deref() { - Some("u8") => ElementType::U8, - Some("i8") => ElementType::I8, - Some("u16") => ElementType::U16, - Some("i16") => ElementType::I16, - Some("u32") => ElementType::U32, - Some("i32") => ElementType::I32, - Some("f32") => ElementType::F32, - Some("u64") => ElementType::U64, - Some("i64") => ElementType::I64, - Some("f64") | None => ElementType::F64, - Some(_) => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })) - }, - }; - - ctx.add_input_tensor("input", element_type, DimensionsParam::Dynamic); - ctx.add_output_tensor("output", element_type, DimensionsParam::Dynamic); - - Ok(()) +impl ProcBlock for Modulo { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::numeric( + "input", + Dimensions::Dynamic, + )], + outputs: vec![TensorConstraint::numeric( + "output", + Dimensions::Dynamic, + )], + } } - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id).ok_or_else(|| { - KernelError::Other("Unable to load the kernel context".to_string()) - })?; - - let modulus = get_modulus(|n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let TensorResult { - dimensions, - element_type, - mut buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - // Note: The "element_type" argument is only used while constructing the - // ML pipeline. We see its effect at runtime in the form of the tensor - // data variant that gets used. - - match element_type { - ElementType::U8 => { - modulus_in_place(buffer.elements_mut::(), modulus)? - }, - ElementType::I8 => { - modulus_in_place(buffer.elements_mut::(), modulus)? - }, - ElementType::U16 => { - modulus_in_place(buffer.elements_mut::(), modulus)? - }, - ElementType::I16 => { - modulus_in_place(buffer.elements_mut::(), modulus)? - }, - ElementType::U32 => { - modulus_in_place(buffer.elements_mut::(), modulus)? - }, - ElementType::I32 => { - modulus_in_place(buffer.elements_mut::(), modulus)? - }, - ElementType::F32 => { - modulus_in_place(buffer.elements_mut::(), modulus)? - }, - ElementType::U64 => { - modulus_in_place(buffer.elements_mut::(), modulus)? - }, - ElementType::I64 => { - modulus_in_place(buffer.elements_mut::(), modulus)? - }, - ElementType::F64 => { - modulus_in_place(buffer.elements_mut::(), modulus)? - }, - ElementType::Utf8 => { - return Err(KernelError::Other( - "String tensors aren't supported".to_string(), - )) - }, + fn run(&self, inputs: Vec) -> Result, RunError> { + let mut tensor = inputs + .into_iter() + .find(|tensor| tensor.name == "input") + .ok_or_else(|| InvalidInput::not_found("input"))?; + + if let Ok(tensor) = tensor.view_mut::() { + modulo_in_place(tensor, self.modulus); + } else if let Ok(tensor) = tensor.view_mut::() { + modulo_in_place(tensor, self.modulus); + } else if let Ok(tensor) = tensor.view_mut::() { + modulo_in_place(tensor, self.modulus); + } else if let Ok(tensor) = tensor.view_mut::() { + modulo_in_place(tensor, self.modulus); + } else if let Ok(tensor) = tensor.view_mut::() { + modulo_in_place(tensor, self.modulus); + } else if let Ok(tensor) = tensor.view_mut::() { + modulo_in_place(tensor, self.modulus); + } else if let Ok(tensor) = tensor.view_mut::() { + modulo_in_place(tensor, self.modulus); + } else if let Ok(tensor) = tensor.view_mut::() { + modulo_in_place(tensor, self.modulus); + } else if let Ok(tensor) = tensor.view_mut::() { + modulo_in_place(tensor, self.modulus); + } else if let Ok(tensor) = tensor.view_mut::() { + modulo_in_place(tensor, self.modulus); + } else { + return Err( + InvalidInput::incompatible_element_type(tensor.name).into() + ); } - ctx.set_output_tensor( - "output", - TensorParam { - element_type, - dimensions: &dimensions, - buffer: &buffer, - }, - ); - - Ok(()) + Ok(vec![tensor.with_name("output")]) } } -fn modulus_in_place( - values: &mut [T], - modulus: f64, -) -> Result<(), KernelError> -where - T: ToPrimitive + FromPrimitive + Copy + Display, -{ - for value in values { - let as_float = value.to_f64().ok_or_else(|| error(*value))?; - let after_modulus = as_float % modulus; - *value = T::from_f64(after_modulus).ok_or_else(|| error(*value))?; - } +impl TryFrom> for Modulo { + type Error = CreateError; - Ok(()) -} + fn try_from(args: Vec) -> Result { + let modulus = hotg_rune_proc_blocks::guest::parse::required_arg( + &args, "modulus", + )?; -fn error(value: impl Display) -> KernelError { - KernelError::Other(format!( - "Unable to convert `{}` to/from a double", - value - )) + Ok(Modulo { modulus }) + } } -fn get_modulus( - get_argument: impl FnOnce(&str) -> Option, -) -> Result { - let value = match get_argument("modulus") { - Some(s) => s, - None => { - return Err(InvalidArgument { - name: "modulus".to_string(), - reason: BadArgumentReason::NotFound, - }) - }, - }; - - let value = value.parse::().map_err(|e| InvalidArgument { - name: "modulus".to_string(), - reason: BadArgumentReason::InvalidValue(e.to_string()), - })?; - - if value > 0.0 { - Ok(value) - } else { - Err(InvalidArgument { - name: "modulus".to_string(), - reason: BadArgumentReason::InvalidValue( - "The modulus must be a positive, non-zero number".to_string(), - ), - }) +fn modulo_in_place(mut array: ArrayViewMutD<'_, T>, modulus: f64) +where + T: ToPrimitive + FromPrimitive + Copy, +{ + for item in array.iter_mut() { + let result = item + .to_f64() + .map(|n| n % modulus) + .and_then(|n| T::from_f64(n)); + + if let Some(updated) = result { + *item = updated; + } } } #[cfg(test)] mod tests { use super::*; + use hotg_rune_proc_blocks::ndarray; + + fn args(arguments: &[(&str, &str)]) -> Vec { + arguments + .iter() + .map(|(n, v)| Argument { + name: n.to_string(), + value: v.to_string(), + }) + .collect() + } + + #[test] + fn create_modulo_with_good_modulus() { + let args = args(&[("modulus", "42.0")]); + + let proc_block = Modulo::try_from(args).unwrap(); + + assert_eq!(proc_block, Modulo { modulus: 42.0 }); + } #[test] fn apply_modulus() { - let mut values = [0.0_f64, 1.0, 2.0, 3.0, 4.0, 5.0]; + let inputs = vec![Tensor::new( + "input", + &ndarray::array![0.0_f64, 1.0, 2.0, 3.0, 4.0, 5.0], + )]; + let expected = vec![Tensor::new( + "output", + &ndarray::array![0.0_f64, 1.0, 0.0, 1.0, 0.0, 1.0], + )]; + let modulo = Modulo { modulus: 2.0 }; - modulus_in_place(&mut values, 2.0).unwrap(); + let outputs = modulo.run(inputs).unwrap(); - assert_eq!(values, [0.0_f64, 1.0, 0.0, 1.0, 0.0, 1.0]); + assert_eq!(outputs, expected); } } diff --git a/most_confident_indices/Cargo.toml b/most_confident_indices/Cargo.toml index 010d4295605..b81fecbd780 100644 --- a/most_confident_indices/Cargo.toml +++ b/most_confident_indices/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "most_confident_indices" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "Given some confidence values, create a tensor containing the indices of the top N highest confidences." diff --git a/noise-filtering/Cargo.toml b/noise-filtering/Cargo.toml index e45ca65fef2..2cc82d29937 100644 --- a/noise-filtering/Cargo.toml +++ b/noise-filtering/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "noise-filtering" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "reduces noise and applies a gain control algorithm within each frequency bin." diff --git a/noise-filtering/src/gain_control.rs b/noise-filtering/src/gain_control.rs index fb9db4fea56..ac812b16d83 100644 --- a/noise-filtering/src/gain_control.rs +++ b/noise-filtering/src/gain_control.rs @@ -2,8 +2,6 @@ //! //! [tf]: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.c -use alloc::vec::Vec; - pub const WIDE_DYNAMIC_FUNCTION_BITS: usize = 32; pub const WIDE_DYNAMIC_FUNCTION_LUT_SIZE: usize = 4 * WIDE_DYNAMIC_FUNCTION_BITS - 3; diff --git a/noise-filtering/src/lib.rs b/noise-filtering/src/lib.rs index 3ef1c4b212a..1524d5fa31e 100644 --- a/noise-filtering/src/lib.rs +++ b/noise-filtering/src/lib.rs @@ -1,321 +1,199 @@ -use std::{convert::TryInto, f64, fmt::Display, str::FromStr}; +mod gain_control; +mod noise_reduction; -pub use crate::noise_reduction::ScaledU16; +use std::sync::Mutex; -use crate::proc_block_v1::*; use hotg_rune_proc_blocks::{ - runtime_v1::{self, *}, - BufferExt, SliceExt, + guest::{ + parse, Argument, ArgumentMetadata, ArgumentType, CreateError, + ElementType, Metadata, ProcBlock, RunError, Tensor, TensorConstraint, + TensorConstraints, TensorMetadata, + }, + ndarray::Array1, }; -#[macro_use] -extern crate alloc; - -mod gain_control; -mod noise_reduction; - +pub use crate::noise_reduction::ScaledU16; use crate::{gain_control::GainControl, noise_reduction::NoiseReduction}; -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -// It reduces noise and applies a gain control algorithm within each frequency -// bin. -struct ProcBlockV1; +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: NoiseFiltering, +} -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = - Metadata::new("Noise Filtering", env!("CARGO_PKG_VERSION")); - metadata.set_description( +fn metadata() -> Metadata { + Metadata::new("Noise Filtering", env!("CARGO_PKG_VERSION")) + .with_description( "Reduce the amount of high frequency noise in an audio clip and increase its gain.", - ); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("audio"); - - let strength = ArgumentMetadata::new("strength"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Float); - strength.add_hint(&hint); - strength.set_default_value("0.95"); - metadata.add_argument(&strength); - - let offset = ArgumentMetadata::new("offset"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Float); - offset.add_hint(&hint); - offset.set_default_value("80"); - metadata.add_argument(&offset); - - let gain_bits = ArgumentMetadata::new("gain_bits"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Integer); - gain_bits.add_hint(&hint); - gain_bits.set_default_value("21"); - metadata.add_argument(&gain_bits); - - let smoothing_bits = ArgumentMetadata::new("smoothing_bits"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Integer); - smoothing_bits.add_hint(&hint); - smoothing_bits.set_default_value("10"); - metadata.add_argument(&smoothing_bits); - - let even_smoothing = ArgumentMetadata::new("even_smoothing"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Float); - even_smoothing.add_hint(&hint); - even_smoothing.set_default_value("0.025"); - metadata.add_argument(&even_smoothing); - - let odd_smoothing = ArgumentMetadata::new("odd_smoothing"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Float); - odd_smoothing.add_hint(&hint); - odd_smoothing.set_default_value("0.06"); - metadata.add_argument(&odd_smoothing); - - let min_signal_remaining = - ArgumentMetadata::new("min_signal_remaining"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Float); - min_signal_remaining.add_hint(&hint); - min_signal_remaining.set_default_value("0.05"); - metadata.add_argument(&min_signal_remaining); + ) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("audio") + .with_argument(ArgumentMetadata::new("strength") + .with_default_value("0.95") + .with_hint(ArgumentType::Float)) + .with_argument(ArgumentMetadata::new("offset") + .with_default_value("80") + .with_hint(ArgumentType::Float)) + .with_argument(ArgumentMetadata::new("gain_bits") + .with_default_value("21") + .with_hint(ArgumentType::Integer)) + .with_argument(ArgumentMetadata::new("smoothing_bits") + .with_default_value("10") + .with_hint(ArgumentType::Integer)) + .with_argument(ArgumentMetadata::new("even_smoothing") + .with_default_value("0.025") + .with_hint(ArgumentType::Float)) + .with_argument(ArgumentMetadata::new("odd_smoothing") + .with_default_value("0.06") + .with_hint(ArgumentType::Float)) + .with_argument(ArgumentMetadata::new("min_signal_remaining") + .with_default_value("0.05") + .with_hint(ArgumentType::Float)) + .with_input(TensorMetadata::new("audio") + .with_description("An audio clip")) + .with_output(TensorMetadata::new("filtered")) +} - let input = TensorMetadata::new("audio"); - input.set_description("An audio clip"); - let hint = supported_shapes( - &[ElementType::U32], - DimensionsParam::Fixed(&[1, 0]), - ); - input.add_hint(&hint); - metadata.add_input(&input); +#[derive(Debug)] +pub struct NoiseFiltering { + // gain control options + strength: f32, + offset: f32, + gain_bits: i32, - let output = TensorMetadata::new("filtered"); - let hint = supported_shapes( - &[ElementType::I8], - DimensionsParam::Fixed(&[1, 0]), - ); - output.add_hint(&hint); - metadata.add_output(&output); + // noise filtering options + smoothing_bits: u32, + even_smoothing: ScaledU16, + odd_smoothing: ScaledU16, + min_signal_remaining: ScaledU16, + state: Mutex<(gain_control::State, noise_reduction::State)>, +} - register_node(&metadata); +impl ProcBlock for NoiseFiltering { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::new( + "audio", + ElementType::U32, + [1, 0], + )], + outputs: vec![TensorConstraint::new( + "filtered", + ElementType::I8, + [1, 0], + )], + } } - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; + fn run(&self, mut inputs: Vec) -> Result, RunError> { + let mut audio = Tensor::take_named(&mut inputs, "audio")?; + let mut tensor = audio.view_1d_mut()?; + let samples = + tensor.as_slice_mut().expect("Should always be contiguous"); - ctx.add_input_tensor( - "audio", - ElementType::F32, - DimensionsParam::Fixed(&[1, 0]), - ); - ctx.add_output_tensor( - "filtered", - ElementType::F32, - DimensionsParam::Fixed(&[1, 0]), - ); + let filtered = self.transform(samples); - Ok(()) + Ok(vec![Tensor::new("filtered", &filtered)]) } +} - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let strength: f32 = get_args("strength", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let offset: f32 = get_args("offset", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let gain_bits: u32 = get_args("gain_bits", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; +impl TryFrom> for NoiseFiltering { + type Error = CreateError; + fn try_from(args: Vec) -> Result { + let strength: f32 = + parse::optional_arg(&args, "strength")?.unwrap_or(0.95); + let offset: f32 = parse::optional_arg(&args, "offset")?.unwrap_or(80.0); + let gain_bits: u32 = + parse::optional_arg(&args, "gain_bits")?.unwrap_or(21); let smoothing_bits: u32 = - get_args("smoothing_bits", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let even_smoothing: ScaledU16 = - get_args("even_smoothing", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let odd_smoothing: ScaledU16 = - get_args("odd_smoothing", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let min_signal_remaining: ScaledU16 = - get_args("min_signal_remaining", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; + parse::optional_arg(&args, "smoothing_bits")?.unwrap_or(10); + let even_smoothing = parse::optional_arg(&args, "even_smoothing")? + .unwrap_or(ScaledU16::from(0.025)); + let odd_smoothing = parse::optional_arg(&args, "odd_smoothing")? + .unwrap_or(ScaledU16::from(0.06)); + let min_signal_remaining = + parse::optional_arg(&args, "min_signal_remaining")? + .unwrap_or(ScaledU16::from(0.05)); let config: GainControl = GainControl { strength, offset, gain_bits: gain_bits.try_into().unwrap(), }; + let state = Mutex::new(( + gain_control::State::new(config, smoothing_bits as u16), + noise_reduction::State::default(), + )); - // todo Need to call estimate from the noise_reduction::State - - // let noise_reduction: NoiseReduction = NoiseReduction { - // smoothing_bits, - // even_smoothing, - // odd_smoothing, - // min_signal_remaining, - // }; - - let noise_filtering: NoiseFiltering = NoiseFiltering { + Ok(NoiseFiltering { strength, offset, gain_bits: gain_bits.try_into().unwrap(), - gain_control: gain_control::State::new( - config, - smoothing_bits as u16, - ), smoothing_bits, even_smoothing, odd_smoothing, min_signal_remaining, - noise_reduction: noise_reduction::State::default(), /* Todo need to change this to noise_reduction::State {estimate} */ - }; - - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("audio").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "bounding_boxes".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let mut buffer = buffer.clone(); - - let output = match element_type { - ElementType::F32 =>{ - buffer.view::(&dimensions) - .map_err(|e| KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::Other(e.to_string()), - }))?; - transform(noise_filtering, buffer.elements_mut()) - } - other => { - return Err(KernelError::Other(format!( - "The Noise Filtering proc-block doesn't support {:?} element type", - other, - ))) - }, - }; - - ctx.set_output_tensor( - "filtered", - TensorParam { - element_type: ElementType::F32, - dimensions: &dimensions, - buffer: &output.as_bytes(), - }, - ); - - Ok(()) - } -} - -fn get_args( - name: &str, - get_argument: impl FnOnce(&str) -> Option, -) -> Result -where - T: FromStr, - ::Err: Display, -{ - get_argument(name) - .ok_or_else(|| InvalidArgument::not_found(name))? - .parse::() - .map_err(|e| InvalidArgument::invalid_value(name, e)) -} - -impl InvalidArgument { - fn not_found(name: impl Into) -> Self { - InvalidArgument { - name: name.into(), - reason: BadArgumentReason::NotFound, - } - } - - fn invalid_value(name: impl Into, reason: impl Display) -> Self { - InvalidArgument { - name: name.into(), - reason: BadArgumentReason::InvalidValue(reason.to_string()), - } + state, + }) } } -pub struct NoiseFiltering { - // gain control options - strength: f32, - offset: f32, - gain_bits: i32, - gain_control: gain_control::State, - - // noise filtering options - smoothing_bits: u32, - even_smoothing: ScaledU16, - odd_smoothing: ScaledU16, - min_signal_remaining: ScaledU16, - noise_reduction: noise_reduction::State, -} - -fn transform( - mut noise_filtering: NoiseFiltering, - mut input: &mut [u32], -) -> Vec { - let NoiseFiltering { - strength, - offset, - gain_bits, - ref mut gain_control, - smoothing_bits, - even_smoothing, - odd_smoothing, - min_signal_remaining, - ref mut noise_reduction, - } = noise_filtering; - - let n = NoiseReduction { - smoothing_bits, - even_smoothing, - odd_smoothing, - min_signal_remaining, - }; - let cleaned = n.transform(&mut input, noise_reduction); +impl NoiseFiltering { + fn transform(&self, mut input: &mut [u32]) -> Array1 { + let NoiseFiltering { + strength, + offset, + gain_bits, + ref state, + smoothing_bits, + even_smoothing, + odd_smoothing, + min_signal_remaining, + } = *self; + let mut state = state.lock().unwrap(); + let (gain_control, noise_reduction) = &mut *state; - let g = GainControl { - gain_bits, - offset, - strength, - }; + let n = NoiseReduction { + smoothing_bits, + even_smoothing, + odd_smoothing, + min_signal_remaining, + }; + let cleaned = n.transform(&mut input, noise_reduction); - // let cleaned = cleaned.to_vec(); + let g = GainControl { + gain_bits, + offset, + strength, + }; - g.transform( - cleaned, - &noise_reduction.estimate, - smoothing_bits as u16, - gain_control, - ); - let amplified: Vec = input - .iter() - .map(|energy| libm::log2((*energy as f64) + 1.0)) - .collect(); + // let cleaned = cleaned.to_vec(); - let (min_value, max_value) = amplified.iter().copied().fold( - (f64::INFINITY, f64::NEG_INFINITY), - |(lower, upper), current| (lower.min(current), upper.max(current)), - ); + g.transform( + cleaned, + &noise_reduction.estimate, + smoothing_bits as u16, + gain_control, + ); + let amplified: Vec = input + .iter() + .map(|energy| libm::log2((*energy as f64) + 1.0)) + .collect(); + + let (min_value, max_value) = amplified.iter().copied().fold( + (f64::INFINITY, f64::NEG_INFINITY), + |(lower, upper), current| (lower.min(current), upper.max(current)), + ); - amplified - .iter() - .map(|energy| { - ((255.0 * (energy - min_value) / (max_value - min_value)) - 128.0) - as i8 - }) - .collect() + amplified + .iter() + .map(|energy| { + ((255.0 * (energy - min_value) / (max_value - min_value)) + - 128.0) as i8 + }) + .collect() + } } impl Default for NoiseFiltering { @@ -333,19 +211,20 @@ impl Default for NoiseFiltering { gain_bits, } = config; + let state = Mutex::new(( + gain_control::State::new(config, smoothing_bits as u16), + noise_reduction::State::default(), + )); + NoiseFiltering { strength, offset, gain_bits, - gain_control: gain_control::State::new( - config, - smoothing_bits as u16, - ), smoothing_bits, even_smoothing, odd_smoothing, min_signal_remaining, - noise_reduction: noise_reduction::State::default(), + state, } } } @@ -359,7 +238,7 @@ mod tests { #[test] fn smoke_test() { let pb = NoiseFiltering::default(); - let mut microspeech_fft = vec![ + let microspeech_fft = vec![ 9, 130, 180, 93, 61, 42, 43, 47, 75, 81, 73, 29, 10, 16, 11, 13, 18, 11, 5, 9, 7, 8, 4, 6, 10, 11, 13, 10, 11, 14, 8, 10, 13, 10, 9, 12, 9, 9, 9, 1, 33, 100, 133, 123, 38, 52, 30, 21, 21, 35, 37, 19, @@ -649,9 +528,11 @@ mod tests { -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, ]; + let inputs = vec![Tensor::new_1d::("audio", µspeech_fft)]; + let should_be = vec![Tensor::new_1d("filtered", &expected)]; - let output = transform(pb, &mut microspeech_fft); + let output = pb.run(inputs).unwrap(); - assert_eq!(output, expected); + assert_eq!(output, should_be); } } diff --git a/noise-filtering/src/noise_reduction.rs b/noise-filtering/src/noise_reduction.rs index fb1200ca1d0..9bbead963d3 100644 --- a/noise-filtering/src/noise_reduction.rs +++ b/noise-filtering/src/noise_reduction.rs @@ -2,12 +2,11 @@ //! //! [tf]: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/microfrontend/lib/noise_reduction.c -use alloc::vec::Vec; -use core::str::FromStr; +use std::str::FromStr; const NOISE_REDUCTION_BITS: usize = 14; -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, PartialEq)] pub(crate) struct State { pub estimate: Vec, } diff --git a/normalize/Cargo.toml b/normalize/Cargo.toml index 97a311b17ab..7333feacbf7 100644 --- a/normalize/Cargo.toml +++ b/normalize/Cargo.toml @@ -2,7 +2,7 @@ name = "normalize" version = "0.12.0" authors = ["The Rune Developers "] -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "normalizes the input to the range [0, 1]" diff --git a/normalize/src/lib.rs b/normalize/src/lib.rs index 08afee21625..7de1e62f867 100644 --- a/normalize/src/lib.rs +++ b/normalize/src/lib.rs @@ -1,173 +1,111 @@ -use crate::proc_block_v1::*; -use hotg_rune_proc_blocks::{runtime_v1::*, BufferExt, SliceExt}; +use hotg_rune_proc_blocks::{ + guest::{ + Argument, Dimensions, InvalidInput, Metadata, ProcBlock, RunError, + Tensor, TensorConstraint, TensorConstraints, TensorMetadata, + }, + ndarray::{ArrayD, ArrayViewD}, +}; use num_traits::ToPrimitive; -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -/// Normalize the input to the range `[0, 1]`. -struct ProcBlockV1; +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: Normalize, +} -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("Normalize", env!("CARGO_PKG_VERSION")); - metadata.set_description( +fn metadata() -> Metadata { + Metadata::new("Normalize", env!("CARGO_PKG_VERSION")) + .with_description( "Normalize a tensor's elements to the range, `[0, 1]`.", - ); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("normalize"); - - let input = TensorMetadata::new("input"); - let supported_types = [ - ElementType::U8, - ElementType::I8, - ElementType::U16, - ElementType::I16, - ElementType::U32, - ElementType::I32, - ElementType::F32, - ElementType::U64, - ElementType::I64, - ElementType::F64, - ]; - let hint = supported_shapes(&supported_types, DimensionsParam::Dynamic); - input.add_hint(&hint); - metadata.add_input(&input); - - let output = TensorMetadata::new("normalized"); - output.set_description("normalized tensor in the range [0, 1]"); - let hint = - supported_shapes(&[ElementType::F32], DimensionsParam::Dynamic); - output.add_hint(&hint); - metadata.add_output(&output); - - register_node(&metadata); - } - - fn graph(id: String) -> Result<(), GraphError> { - let ctx = - GraphContext::for_node(&id).ok_or(GraphError::MissingContext)?; - - let element_type = match ctx.get_argument("element_type").as_deref() { - Some("u8") => ElementType::U8, - Some("i8") => ElementType::I8, - Some("u16") => ElementType::U16, - Some("i16") => ElementType::I16, - Some("u32") => ElementType::U32, - Some("i32") => ElementType::I32, - Some("f32") => ElementType::F32, - Some("u64") => ElementType::U64, - Some("i64") => ElementType::I64, - Some("f64") => ElementType::F64, - Some(_) => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })); - }, - None => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::NotFound, - })) - }, - }; - - ctx.add_input_tensor("input", element_type, DimensionsParam::Dynamic); - ctx.add_output_tensor( - "normalized", - ElementType::F32, - DimensionsParam::Dynamic, - ); + ) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("normalize") + .with_input(TensorMetadata::new("input")) + .with_output( + TensorMetadata::new("normalized") + .with_description("normalized tensor in the range [0, 1]"), + ) +} - Ok(()) +/// Normalize the input to the range `[0, 1]`. +struct Normalize; + +impl ProcBlock for Normalize { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::numeric( + "input", + Dimensions::Dynamic, + )], + outputs: vec![TensorConstraint::numeric( + "normalized", + Dimensions::Dynamic, + )], + } } - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let output = match element_type { - ElementType::U8 => transform(buffer.elements::()), - ElementType::I8 => transform(buffer.elements::()), - ElementType::U16 => transform(buffer.elements::()), - ElementType::I16 => transform(buffer.elements::()), - ElementType::U32 => transform(buffer.elements::()), - ElementType::I32 => transform(buffer.elements::()), - ElementType::F32 => transform(buffer.elements::()), - ElementType::U64 => transform(buffer.elements::()), - ElementType::I64 => transform(buffer.elements::()), - ElementType::F64 => transform(buffer.elements::()), - other => { - return Err(KernelError::Other(format!( - "The Normalize proc-block doesn't support {:?} element type", - other, - ))) - }, - }; - - let output = match output { - Some(out) => out, - None => { - return Err(KernelError::Other( - "The input tensor was empty".to_string(), - )) - }, + fn run(&self, inputs: Vec) -> Result, RunError> { + let tensor = Tensor::get_named(&inputs, "input")?; + + let normalized = if let Ok(tensor) = tensor.view::() { + normalize(tensor) + } else if let Ok(tensor) = tensor.view::() { + normalize(tensor) + } else if let Ok(tensor) = tensor.view::() { + normalize(tensor) + } else if let Ok(tensor) = tensor.view::() { + normalize(tensor) + } else if let Ok(tensor) = tensor.view::() { + normalize(tensor) + } else if let Ok(tensor) = tensor.view::() { + normalize(tensor) + } else if let Ok(tensor) = tensor.view::() { + normalize(tensor) + } else if let Ok(tensor) = tensor.view::() { + normalize(tensor) + } else if let Ok(tensor) = tensor.view::() { + normalize(tensor) + } else { + return Err( + InvalidInput::incompatible_element_type(&tensor.name).into() + ); }; - ctx.set_output_tensor( - "normalized", - TensorParam { - element_type: ElementType::U32, - dimensions: &dimensions, - buffer: &output.as_bytes(), - }, - ); - - Ok(()) + Ok(vec![Tensor::new("normalized", &normalized)]) } } -fn transform(input: &[T]) -> Option> +impl From> for Normalize { + fn from(_: Vec) -> Self { Normalize } +} + +fn normalize(input: ArrayViewD<'_, T>) -> ArrayD where T: ToPrimitive, { + if input.is_empty() { + return ArrayD::zeros(input.shape()); + } + let (min, max) = - min_max(input.iter().map(|e| e.to_f32().unwrap())).unwrap(); + input.fold((f32::INFINITY, f32::NEG_INFINITY), |(min, max), elem| { + match elem.to_f32() { + Some(elem) => (min.min(elem), max.max(elem)), + None => (min, max), + } + }); + let range = max - min; + if range == 0.0 { - return Some(vec![0.0; input.len()]); + return ArrayD::zeros(input.shape()); } - let mut v: Vec = Vec::new(); - for e in input { - let e = e.to_f32().unwrap(); - v.push((e - min) / range) - } - return Some(v); -} + let mean = (max + min) / 2.0; -fn min_max(items: impl Iterator) -> Option<(f32, f32)> { - items.into_iter().fold(None, |bounds, item| match bounds { - Some((min, max)) => { - let min = if item < min { item } else { min }; - let max = if max < item { item } else { max }; - Some((min, max)) - }, - None => Some((item, item)), + input.map(|v| match v.to_f32() { + Some(elem) => (elem - min) / range, + None => mean, }) } @@ -177,29 +115,31 @@ mod tests { #[test] fn it_works() { - let input = [0.0, 1.0, 2.0]; + let inputs = vec![Tensor::new_1d("input", &[0.0_f64, 1.0, 2.0])]; - let output = transform(&input).unwrap(); + let output = Normalize.run(inputs).unwrap(); - assert_eq!(output, vec![0.0, 0.5, 1.0]); + assert_eq!( + output, + vec![Tensor::new_1d("normalized", &[0.0_f32, 0.5, 1.0])] + ); } #[test] - fn it_works_with_integers() { - let input = [0, 1, 2]; + fn handle_all_zeroes() { + let inputs = vec![Tensor::new_1d("input", &[0_i32; 64])]; - let output = transform(&input).unwrap(); + let output = Normalize.run(inputs).unwrap(); - assert_eq!(output, vec![0.0, 0.5, 1.0]); + assert_eq!(output, vec![Tensor::new_1d("normalized", &[0_f32; 64])]); } #[test] - fn handle_empty() { - let input = [0.0; 384]; + fn empty_input() { + let inputs = vec![Tensor::new_1d::("input", &[])]; - let output = transform(&input.clone()).unwrap(); + let output = Normalize.run(inputs).unwrap(); - assert_eq!(output, input); - assert_eq!(output.len(), 384); + assert_eq!(output, vec![Tensor::new_1d::("normalized", &[])]); } } diff --git a/object_filter/Cargo.toml b/object_filter/Cargo.toml index 019463f169d..67d2dcada20 100644 --- a/object_filter/Cargo.toml +++ b/object_filter/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "object_filter" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "A proc-block which filter the object detected by an Object Detection model to: 1. remove duplicate detection for a single object 2. remove the objects with low confidence" diff --git a/parse/Cargo.toml b/parse/Cargo.toml deleted file mode 100644 index d78acf45ce8..00000000000 --- a/parse/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "parse" -version = "0.12.0" -edition = "2018" -publish = false -repository = "https://github.com/hotg-ai/proc-blocks" -description = "Parse a string tensor into a tensor of numeric values." - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -hotg-rune-proc-blocks = { path = "../support" } -wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } - -[lib] -crate-type = ["cdylib", "rlib"] - -[package.metadata.wapm] -namespace = "hotg-ai" -abi = "none" diff --git a/parse/src/lib.rs b/parse/src/lib.rs deleted file mode 100644 index a3cfa1016d9..00000000000 --- a/parse/src/lib.rs +++ /dev/null @@ -1,293 +0,0 @@ -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; -use hotg_rune_proc_blocks::{ - common, - runtime_v1::{self, *}, - BufferExt, SliceExt, -}; -use std::{fmt::Display, str::FromStr}; - -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -/// A proc block which can parse a string to numbers. -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("Parse", env!("CARGO_PKG_VERSION")); - metadata.set_description(env!("CARGO_PKG_DESCRIPTION")); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("string"); - metadata.add_tag("numbers"); - - let input = TensorMetadata::new("input_string_of_numbers"); - let hint = - supported_shapes(&[ElementType::Utf8], DimensionsParam::Dynamic); - input.add_hint(&hint); - metadata.add_input(&input); - - let element_type = ArgumentMetadata::new(common::element_type::NAME); - element_type.set_description("The type that values get parsed into"); - element_type.add_hint(&runtime_v1::interpret_as_string_in_enum( - common::element_type::NUMERIC, - )); - metadata.add_argument(&element_type); - - let output = TensorMetadata::new("parsed_numbers"); - output.set_description("The parsed values"); - let supported_types = [ - ElementType::U8, - ElementType::I8, - ElementType::U16, - ElementType::I16, - ElementType::U32, - ElementType::I32, - ElementType::F32, - ElementType::U64, - ElementType::I64, - ElementType::F64, - ]; - let hint = supported_shapes(&supported_types, DimensionsParam::Dynamic); - output.add_hint(&hint); - metadata.add_output(&output); - - register_node(&metadata); - } - - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id).ok_or_else(|| { - GraphError::Other("Unable to get the graph context".to_string()) - })?; - - ctx.add_input_tensor( - "input_string_of_numbers", - ElementType::Utf8, - DimensionsParam::Dynamic, - ); - - let element_type = match ctx.get_argument("element_type").as_deref() { - Some("u8") => ElementType::U8, - Some("i8") => ElementType::I8, - Some("u16") => ElementType::U16, - Some("i16") => ElementType::I16, - Some("u32") => ElementType::U32, - Some("i32") => ElementType::I32, - Some("f32") => ElementType::F32, - Some("u64") => ElementType::U64, - Some("i64") => ElementType::I64, - Some("f64") => ElementType::F64, - Some(_) => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })); - }, - None => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::NotFound, - })) - }, - }; - - ctx.add_output_tensor( - "parsed_numbers", - element_type, - DimensionsParam::Dynamic, - ); - - Ok(()) - } - - fn kernel(id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&id).ok_or_else(|| { - KernelError::Other("Unable to get the kernel context".to_string()) - })?; - - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let numbers = match element_type { - ElementType::Utf8 => buffer - .strings() - .map_err(|e| KernelError::Other(e.to_string()))?, - other => { - return Err(KernelError::Other(format!( - "The Parse proc-block only accepts Utf8 tensors, found {:?}", - other, - ))) - }, - }; - - match ctx.get_argument("element_type").as_deref() { - Some("u8") => { - let transformed = transform::(&numbers)?; - let output = TensorParam { - element_type: ElementType::U8, - dimensions: &dimensions, - buffer: &transformed, - }; - ctx.set_output_tensor("parsed_numbers", output); - }, - Some("i8") => { - let transformed = transform::(&numbers)?; - let output = TensorParam { - element_type: ElementType::I8, - dimensions: &dimensions, - buffer: transformed.as_bytes(), - }; - ctx.set_output_tensor("parsed_numbers", output); - }, - Some("u16") => { - let transformed = transform::(&numbers)?; - let output = TensorParam { - element_type: ElementType::U16, - dimensions: &dimensions, - buffer: transformed.as_bytes(), - }; - ctx.set_output_tensor("parsed_numbers", output); - }, - Some("i16") => { - let transformed = transform::(&numbers)?; - let output = TensorParam { - element_type: ElementType::I16, - dimensions: &dimensions, - buffer: transformed.as_bytes(), - }; - ctx.set_output_tensor("parsed_numbers", output); - }, - Some("u32") => { - let transformed = transform::(&numbers)?; - let output = TensorParam { - element_type: ElementType::U32, - dimensions: &dimensions, - buffer: transformed.as_bytes(), - }; - ctx.set_output_tensor("parsed_numbers", output); - }, - Some("i32") => { - let transformed = transform::(&numbers)?; - let output = TensorParam { - element_type: ElementType::I32, - dimensions: &dimensions, - buffer: transformed.as_bytes(), - }; - ctx.set_output_tensor("parsed_numbers", output); - }, - Some("f32") => { - let transformed = transform::(&numbers)?; - let output = TensorParam { - element_type: ElementType::F32, - dimensions: &dimensions, - buffer: transformed.as_bytes(), - }; - ctx.set_output_tensor("parsed_numbers", output); - }, - Some("u64") => { - let transformed = transform::(&numbers)?; - let output = TensorParam { - element_type: ElementType::U64, - dimensions: &dimensions, - buffer: transformed.as_bytes(), - }; - ctx.set_output_tensor("parsed_numbers", output); - }, - Some("i64") => { - let transformed = transform::(&numbers)?; - let output = TensorParam { - element_type: ElementType::I64, - dimensions: &dimensions, - buffer: transformed.as_bytes(), - }; - ctx.set_output_tensor("parsed_numbers", output); - }, - Some("f64") => { - let transformed = transform::(&numbers)?; - let output = TensorParam { - element_type: ElementType::F64, - dimensions: &dimensions, - buffer: transformed.as_bytes(), - }; - ctx.set_output_tensor("parsed_numbers", output); - }, - Some(_) => { - return Err(KernelError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })); - }, - None => { - return Err(KernelError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::NotFound, - })); - }, - } - - Ok(()) - } -} - -fn transform(inputs: &[&str]) -> Result, KernelError> -where - T: FromStr, - T::Err: Display, -{ - let mut values: Vec = Vec::new(); - - for input in inputs { - match input.parse() { - Ok(v) => values.push(v), - Err(e) => { - return Err(KernelError::Other(format!( - "Unable to parse \"{input}\" because of {e}" - ))) - }, - } - } - Ok(values) -} - -#[cfg(test)] -mod tests { - extern crate alloc; - use super::*; - use alloc::vec; - - #[test] - fn test_for_number_in_vec() { - let bytes = vec!["5", "6", "7"]; - let output: Vec = transform(&bytes).unwrap(); - let should_be = vec![5, 6, 7]; - assert_eq!(output, should_be); - } - - #[test] - fn test_for_invalid_data_type() { - let bytes = ["1.0", "a"]; - let err = transform::(&bytes).unwrap_err(); - - match err { - KernelError::Other(msg) => assert_eq!( - msg, - "Unable to parse \"a\" because of invalid float literal" - ), - other => panic!("Unexpected error: {:?}", other), - } - } -} diff --git a/password_strength/Cargo.toml b/password_strength/Cargo.toml index 967adeb8ea1..94da57b7903 100644 --- a/password_strength/Cargo.toml +++ b/password_strength/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "password_strength" version = "0.12.0" -edition = "2018" +edition = "2021" description = "It takes the utf8 strings as input and return 0, 1, 2 based on the utf8 string length" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/password_strength/src/lib.rs b/password_strength/src/lib.rs index 84706669a6a..d90809c22e3 100644 --- a/password_strength/src/lib.rs +++ b/password_strength/src/lib.rs @@ -1,159 +1,108 @@ -use crate::proc_block_v1::*; -use hotg_rune_proc_blocks::{runtime_v1::*, BufferExt, SliceExt}; - -use std::str; - -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); +use hotg_rune_proc_blocks::guest::{ + Argument, Dimensions, ElementType, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, +}; + +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: PasswordStrength, +} -#[macro_use] -extern crate alloc; -use alloc::string::ToString; +fn metadata() -> Metadata { + Metadata::new("Password Strength", env!("CARGO_PKG_VERSION")) + .with_description("Gauge the strength of your password!") + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("text") + .with_tag("string") + .with_input(TensorMetadata::new("password")) + .with_output( + TensorMetadata::new("password_strength") + .with_description("Label for Password strength"), + ) +} /// A proc block which can convert u8 bytes to utf8 -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("UTF8 Decode", env!("CARGO_PKG_VERSION")); - metadata.set_description("Decode a string from UTF-8 bytes."); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("text"); - metadata.add_tag("bytes"); - metadata.add_tag("string"); - - let input = TensorMetadata::new("bytes"); - input.set_description("string"); - let hint = - supported_shapes(&[ElementType::Utf8], DimensionsParam::Dynamic); - input.add_hint(&hint); - metadata.add_input(&input); - - let output = TensorMetadata::new("password_strength"); - output.set_description("Label for Password strength"); - let hint = - supported_shapes(&[ElementType::U32], DimensionsParam::Fixed(&[0])); - output.add_hint(&hint); - metadata.add_output(&output); - - register_node(&metadata); - } - - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - ctx.add_input_tensor( - "string", - ElementType::Utf8, - DimensionsParam::Dynamic, - ); - - ctx.add_output_tensor( - "password_strength", - ElementType::U32, - DimensionsParam::Fixed(&[0]), - ); - - Ok(()) +#[derive(Debug, Default, Clone, PartialEq)] +struct PasswordStrength; + +impl ProcBlock for PasswordStrength { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::new( + "password", + ElementType::Utf8, + Dimensions::Dynamic, + )], + outputs: vec![TensorConstraint::new( + "password_strength", + ElementType::U32, + Dimensions::Dynamic, + )], + } } - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; + fn run(&self, inputs: Vec) -> Result, RunError> { + let password = Tensor::get_named(&inputs, "password")?.string_view()?; - let words = match element_type { - ElementType::Utf8 => buffer - .strings() - .map_err(|e| KernelError::Other(e.to_string()))?, - other => { - return Err(KernelError::Other(format!( - "The Parse proc-block only accepts Utf8 tensors, found {:?}", - other, - ))) - }, - }; - - let output = transform(words); - - ctx.set_output_tensor( - "password_strength", - TensorParam { - element_type: ElementType::U32, - dimensions: &dimensions, - buffer: &output.as_bytes(), - }, - ); + let strength = password.mapv(password_strength); - Ok(()) + Ok(vec![Tensor::new("password_strength", &strength)]) } } -fn transform(input: Vec<&str>) -> Vec { - let mut password_length: Vec = Vec::new(); +impl From> for PasswordStrength { + fn from(_: Vec) -> Self { PasswordStrength::default() } +} - for i in input { - println!("{:?}", &i); - if &i[i.len() - 1..] == String::from('\n').as_str() { - if i.len() > 11 { - password_length.push(0); - } else if i.len() > 7 && i.len() <= 11 { - password_length.push(1); - } else { - password_length.push(2); - } - continue; - } +fn password_strength(password: &str) -> u32 { + match password.len() { + 0..=6 => 2, + 7..=10 => 1, + _ => 0, } - - return password_length; } #[cfg(test)] mod tests { + use hotg_rune_proc_blocks::ndarray; + use super::*; #[test] fn test_for_utf8_decoding() { - let string = vec![ - "aeroplane\n", - "bicycle\n", - "bird\n", - "boat\n", - "bottle\n", - "bus\n", - "car\n", - "cat\n", - "chair\n", - "cow\n", - "diningtable\n", - "dog\n", - "horse\n", - "motorbike\n", - "person\n", - "pottedplant\n", - "sheep\n", - "sofa\n", - "train\n", - "tv\n", + let passwords = ndarray::array![ + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tv", ]; + let input = vec![Tensor::from_strings("password", &passwords)]; + let should_be = vec![Tensor::new_1d( + "password_strength", + &[ + 1_u32, 1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 1, 2, 0, 2, 2, 2, 2, + ], + )]; - let should_be = - vec![1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 1, 2, 0, 2, 2, 2, 2]; - - let output = transform(string); + let output = PasswordStrength::default().run(input).unwrap(); assert_eq!(output, should_be); } diff --git a/prediction_errors/Cargo.toml b/prediction_errors/Cargo.toml index 2e41cc5c8ce..b1ce0d3083e 100644 --- a/prediction_errors/Cargo.toml +++ b/prediction_errors/Cargo.toml @@ -10,11 +10,10 @@ description = "a proc-block to find Mean Absolute Error and Mean Squared Error. hotg-rune-proc-blocks = { path = "../support" } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } smartcore = { git = "https://github.com/hotg-ai/smartcore", branch = "development" } -getrandom = { version = "0.2.6", default-features = false, features = ["custom"] } [lib] crate-type = ["cdylib", "rlib"] [package.metadata.wapm] namespace = "hotg-ai" -abi = "none" \ No newline at end of file +abi = "none" diff --git a/prediction_errors/src/lib.rs b/prediction_errors/src/lib.rs index 12103ab2cce..3423cd685a1 100644 --- a/prediction_errors/src/lib.rs +++ b/prediction_errors/src/lib.rs @@ -1,161 +1,74 @@ +use hotg_rune_proc_blocks::guest::{ + Argument, ElementType, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, +}; use smartcore::metrics::{ mean_absolute_error::MeanAbsoluteError, mean_squared_error::MeanSquareError, }; -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; -use hotg_rune_proc_blocks::{runtime_v1::*, BufferExt, SliceExt}; +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: PredictionErrors, +} -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); +fn metadata() -> Metadata { + Metadata::new("Errors", env!("CARGO_PKG_VERSION")) + .with_description("for assessing prediction error") + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("metric") + .with_tag("analytics") + .with_tag("loss") + .with_input(TensorMetadata::new("y_true")) + .with_input(TensorMetadata::new("y_pred")) + .with_output(TensorMetadata::new("mean_absolute_error")) + .with_output(TensorMetadata::new("mean_square_error")) +} /// a proc-block to find Mean Absolute Error and Mean Squared Error -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("Errors", env!("CARGO_PKG_VERSION")); - metadata.set_description("for assessing prediction error"); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("metric"); - metadata.add_tag("analytics"); - metadata.add_tag("loss"); - - let y_true = TensorMetadata::new("y_true"); - let hint = - supported_shapes(&[ElementType::F64], DimensionsParam::Fixed(&[0])); - y_true.add_hint(&hint); - metadata.add_input(&y_true); - - let y_pred = TensorMetadata::new("y_pred"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0])); - y_pred.add_hint(&hint); - metadata.add_input(&y_pred); - - let mae = TensorMetadata::new("mean_absolute_error"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[1])); - mae.add_hint(&hint); - metadata.add_output(&mae); - - let mse = TensorMetadata::new("mean_square_error"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[1])); - mse.add_hint(&hint); - metadata.add_output(&mse); - - register_node(&metadata); +#[derive(Debug, Default, Clone, PartialEq)] +struct PredictionErrors; + +impl ProcBlock for PredictionErrors { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![ + TensorConstraint::new("y_true", ElementType::F64, [0]), + TensorConstraint::new("y_pred", ElementType::F64, [0]), + ], + outputs: vec![ + TensorConstraint::new( + "mean_absolute_error", + ElementType::F64, + [1], + ), + TensorConstraint::new( + "mean_square_error", + ElementType::F64, + [1], + ), + ], + } } - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - let element_type = match ctx.get_argument("element_type").as_deref() { - Some("f64") => ElementType::F64, - Some(_) => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })); - }, - None => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::NotFound, - })) - }, - }; - - ctx.add_input_tensor( - "y_true", - element_type, - DimensionsParam::Fixed(&[0]), - ); + fn run(&self, inputs: Vec) -> Result, RunError> { + let y_true = Tensor::get_named(&inputs, "y_true")?.view_1d::()?; + let y_pred = Tensor::get_named(&inputs, "y_pred")?.view_1d::()?; - ctx.add_input_tensor( - "y_pred", - element_type, - DimensionsParam::Fixed(&[0]), - ); + let y_pred = y_pred.to_vec(); + let y_true = y_true.to_vec(); + let mae = MeanAbsoluteError {}.get_score(&y_true, &y_pred); + let mse = MeanSquareError {}.get_score(&y_true, &y_pred); - ctx.add_output_tensor( - "mean_absolute_error", - element_type, - DimensionsParam::Fixed(&[1]), - ); - - ctx.add_output_tensor( - "mean_square_error", - element_type, - DimensionsParam::Fixed(&[1]), - ); - - Ok(()) - } - - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let y_true = ctx.get_input_tensor("y_true").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "y_true".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let y_pred = ctx.get_input_tensor("y_pred").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "y_pred".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let metric = transform( - y_true.buffer.elements().to_vec(), - y_pred.buffer.elements().to_vec(), - ); - - let mae = vec![metric.0]; - - ctx.set_output_tensor( - "mean_absolute_error", - TensorParam { - element_type: ElementType::F64, - dimensions: &[1 as u32], - buffer: &mae.as_bytes(), - }, - ); - - let mse = vec![metric.1]; - - ctx.set_output_tensor( - "mean_square_error", - TensorParam { - element_type: ElementType::F64, - dimensions: &[1 as u32], - buffer: &mse.as_bytes(), - }, - ); - - Ok(()) + Ok(vec![ + Tensor::new_1d("mean_absolute_error", &[mae]), + Tensor::new_1d("mean_square_error", &[mse]), + ]) } } -fn transform(y_true: Vec, y_pred: Vec) -> (f64, f64) { - let mae = MeanAbsoluteError {}.get_score(&y_pred, &y_true); - let mse = MeanSquareError {}.get_score(&y_pred, &y_true); - - (mae, mse) +impl From> for PredictionErrors { + fn from(_: Vec) -> Self { PredictionErrors::default() } } #[cfg(test)] @@ -163,21 +76,19 @@ mod tests { use super::*; #[test] - fn check_mae() { - let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; - let y_true: Vec = vec![0., 1., 1., 0., 1., 0.]; - let metric = transform(y_true, y_pred); - - assert_eq!(0.5, metric.0); - } - - #[test] - fn check_mse() { - let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; - let y_true: Vec = vec![0., 1., 1., 0., 1., 0.]; - - let metric = transform(y_true, y_pred); - - assert_eq!(0.5, metric.1); + fn known_values() { + let predict_errors = PredictionErrors; + let inputs = vec![ + Tensor::new_1d("y_pred", &[0.0_f64, 0.0, 1.0, 1.0, 1.0, 1.0]), + Tensor::new_1d("y_true", &[0.0_f64, 1.0, 1.0, 0.0, 1.0, 0.0]), + ]; + + let got = predict_errors.run(inputs).unwrap(); + + let should_be = vec![ + Tensor::new_1d("mean_absolute_error", &[0.5]), + Tensor::new_1d("mean_square_error", &[0.5]), + ]; + assert_eq!(got, should_be); } } diff --git a/segment_output/Cargo.toml b/segment_output/Cargo.toml index d661bfee0f3..d1d0c666aa0 100644 --- a/segment_output/Cargo.toml +++ b/segment_output/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "segment_output" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "A proc-block which takes a rank 4 `tensor` as input, whose dimension is of this form `[1, x, y, z]`. It will return: 1. a 2-d `tensor` after performing argmax along the axis-3 of the tensor 2. a 1-d `tensor` which a `set` of all the number present in the above 2-d`tensor`" diff --git a/segment_output/src/lib.rs b/segment_output/src/lib.rs index eaba33fa6fb..eb1e2f3f4b7 100644 --- a/segment_output/src/lib.rs +++ b/segment_output/src/lib.rs @@ -1,16 +1,36 @@ -use std::{collections::BTreeSet, convert::TryInto}; - -use crate::proc_block_v1::{ - BadInputReason, GraphError, InvalidInput, KernelError, -}; +use std::collections::BTreeSet; use hotg_rune_proc_blocks::{ - ndarray::{s, ArrayView4}, - runtime_v1::*, - BufferExt, SliceExt, + guest::{ + Argument, ElementType, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, + }, + ndarray::{s, Array1, Array2, ArrayView4}, }; -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: SegmentOutput, +} + +fn metadata() -> Metadata { + Metadata::new("Segment Output", env!("CARGO_PKG_VERSION")) + .with_description("Useful in image segmentation. A proc-block which takes a rank 4 tensor as input, whose dimension is of this form `[1, rows, columns, confidence]`.") + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("image") + .with_tag("segmentation") + .with_input(TensorMetadata::new("image") + .with_description("An image-like tensor with the dimensions, `[1, rows, columns, category_confidence]`. Each \"pixel\" is associated with a set of confidence values, where each value indicates how confident the model is that the pixel is in that category.")) + .with_output(TensorMetadata::new("segmentation_map") + .with_description( +"An image-like tensor where each pixel contains the index of the category with the highest confidence level." + )) + .with_output( + TensorMetadata::new("indices") + .with_description("The categories used in `segmentation_map`."), + ) +} /// A proc-block which takes a rank 4 `tensor` as input, whose dimension is of /// this form `[1, x, y, z]`. @@ -19,136 +39,51 @@ wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); /// 1. a 2-d `tensor` after performing argmax along the axis-3 of the tensor /// 2. a 1-d `tensor` which a `set` of all the number present in the above 2-d /// `tensor` -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = - Metadata::new("Segment Output", env!("CARGO_PKG_VERSION")); - metadata.set_description("Useful in image segmentation. A proc-block which takes a rank 4 tensor as input, whose dimension is of this form `[1, rows, columns, confidence]`."); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("image"); - metadata.add_tag("segmentation"); - - let input = TensorMetadata::new("image"); - input.set_description("An image-like tensor with the dimensions, `[1, rows, columns, category_confidence]`. Each \"pixel\" is associated with a set of confidence values, where each value indicates how confident the model is that the pixel is in that category."); - let hint = supported_shapes( - &[ElementType::F32], - DimensionsParam::Fixed(&[1, 0, 0, 0]), - ); - input.add_hint(&hint); - metadata.add_input(&input); - - let segmentation_map = TensorMetadata::new("segmentation_map"); - segmentation_map.set_description("An image-like tensor where each pixel contains the index of the category with the highest confidence level."); - let hint = supported_shapes( - &[ElementType::U32], - DimensionsParam::Fixed(&[0, 0]), - ); - segmentation_map.add_hint(&hint); - metadata.add_output(&segmentation_map); - - let indices = TensorMetadata::new("indices"); - indices.set_description("The categories used in `segmentation_map`."); - let hint = - supported_shapes(&[ElementType::U32], DimensionsParam::Fixed(&[0])); - indices.add_hint(&hint); - metadata.add_output(&indices); - - register_node(&metadata); +struct SegmentOutput; + +impl ProcBlock for SegmentOutput { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::new( + "input", + ElementType::F32, + vec![1, 0, 0, 0], + )], + outputs: vec![ + TensorConstraint::new( + "segmentation_map", + ElementType::U32, + vec![0, 0], + ), + TensorConstraint::new("indices", ElementType::U32, vec![0]), + ], + } } - fn graph(id: String) -> Result<(), GraphError> { - let ctx = - GraphContext::for_node(&id).ok_or(GraphError::MissingContext)?; - - ctx.add_input_tensor( - "input", - ElementType::F32, - DimensionsParam::Fixed(&[1, 0, 0, 0]), - ); - - ctx.add_output_tensor( - "segmentation_map", - ElementType::U32, - DimensionsParam::Fixed(&[0, 0]), - ); - - ctx.add_output_tensor( - "indices", - ElementType::U32, - DimensionsParam::Fixed(&[0]), - ); - - Ok(()) - } + fn run(&self, inputs: Vec) -> Result, RunError> { + let input = Tensor::get_named(&inputs, "input")?.view_4d::()?; + + let (segmented_map, indices) = transform(input); - fn kernel(id: String) -> Result<(), KernelError> { - let ctx = - KernelContext::for_node(&id).ok_or(KernelError::MissingContext)?; - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let (segmented_map, indices) = match element_type { - ElementType::F32 => { - let tensor =buffer.view::(&dimensions) - .and_then(|t| t.into_dimensionality()) - .map_err(|e| KernelError::InvalidInput(InvalidInput{ name: "input".to_string(), reason: BadInputReason::InvalidValue(e.to_string()) }))?; - transform(tensor) - }, - other => { - return Err(KernelError::Other(format!( - "The softmax proc-block only accepts f32 or f64 tensors, found {:?}", - other, - ))) - }, - }; - - ctx.set_output_tensor( - "segmentation_map", - TensorParam { - element_type, - dimensions: &[dimensions[1], dimensions[2]], - buffer: &segmented_map.as_bytes(), - }, - ); - - ctx.set_output_tensor( - "indices", - TensorParam { - element_type, - dimensions: &[indices.len().try_into().unwrap()], - buffer: &indices.as_bytes(), - }, - ); - - Ok(()) + Ok(vec![ + Tensor::new("segmentation_map", &segmented_map), + Tensor::new("indices", &indices), + ]) } } -fn transform(input: ArrayView4) -> (Vec, Vec) { - let dim = input.shape(); - - let mut vec_2d: Vec> = Vec::new(); +impl From> for SegmentOutput { + fn from(_: Vec) -> Self { SegmentOutput } +} - let rows = dim[1] as usize; - let columns = dim[2] as usize; +fn transform(input: ArrayView4<'_, f32>) -> (Array2, Array1) { + let (_, rows, columns, _) = input.dim(); + let mut map = Array2::zeros((rows, columns)); let mut label_index = BTreeSet::new(); for i in 0..rows { - vec_2d.push(vec![]); for j in 0..columns { - println!(" i, j {} {}", i, j); let val = input.slice(s![0 as usize, i, j, ..]); let (index, _) = val.iter().enumerate().fold((0, 0.0), |max, (ind, &val)| { @@ -157,17 +92,13 @@ fn transform(input: ArrayView4) -> (Vec, Vec) { } else { max } - }); // Doing argmax over the array - vec_2d[i].push(index as u32); + }); + map[[i, j]] = index as u32; label_index.insert(index as u32); } } - println!("{:?}", vec_2d); - ( - vec_2d.iter().flat_map(|arr| arr.iter()).cloned().collect(), - label_index.into_iter().collect(), - ) + (map, label_index.into_iter().collect()) } #[cfg(test)] @@ -213,10 +144,16 @@ mod tests { ]; let input = input.broadcast((1, 5, 4, 3)).unwrap(); - let output = transform(input); - let should_be: Vec = - vec![2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1]; - let label_index: Vec = vec![1, 2]; - assert_eq!(output, (should_be, label_index)); + let (segments, indices) = transform(input); + + assert_eq!(indices, ndarray::array![1, 2]); + let segments_should_be: Array2 = ndarray::array![ + [2, 2, 1, 1], + [2, 2, 1, 1], + [2, 2, 1, 1], + [2, 2, 1, 1], + [2, 2, 1, 1], + ]; + assert_eq!(segments, segments_should_be); } } diff --git a/softmax/Cargo.toml b/softmax/Cargo.toml index 63818f561a1..17555c1b531 100644 --- a/softmax/Cargo.toml +++ b/softmax/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "softmax" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "Find the normalised exponential function (softmax)." diff --git a/softmax/src/lib.rs b/softmax/src/lib.rs index 360225d18c3..2a685769308 100644 --- a/softmax/src/lib.rs +++ b/softmax/src/lib.rs @@ -1,133 +1,82 @@ -use crate::proc_block_v1::{ - BadInputReason, GraphError, InvalidInput, KernelError, -}; - use hotg_rune_proc_blocks::{ - ndarray::ArrayViewMut1, runtime_v1::*, BufferExt, ValueType, + guest::{ + Argument, Dimensions, ElementTypeConstraint, InvalidInput, Metadata, + ProcBlock, RunError, Tensor, TensorConstraint, TensorConstraints, + TensorMetadata, + }, + ndarray::ArrayViewMutD, }; use num_traits::Float; -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -struct ProcBlockV1; - -fn softmax(mut input: ArrayViewMut1<'_, T>) -where - T: Float + num_traits::FromPrimitive, -{ - input.mapv_inplace(|x| x.exp()); - - let sum = input.sum(); - if !sum.is_zero() { - input.mapv_inplace(|x| x / sum); - } +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: Softmax, } -fn preprocess_buffer<'buf, T>( - buffer: &'buf mut [u8], - dimensions: &[u32], -) -> Result, KernelError> -where - T: ValueType, -{ - buffer - .view_mut::(dimensions) - .and_then(|t| t.into_dimensionality()) - .map_err(|e| { - KernelError::InvalidInput(InvalidInput { - name: "confidences".to_string(), - reason: BadInputReason::InvalidValue(e.to_string()), - }) - }) +fn metadata() -> Metadata { + Metadata::new("Softmax", env!("CARGO_PKG_VERSION")) + .with_description(env!("CARGO_PKG_DESCRIPTION")) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("softmax") + .with_tag("image") + .with_tag("nlp") + .with_tag("numeric") + .with_tag("classification") + .with_input(TensorMetadata::new("input")) + .with_input(TensorMetadata::new("soft_max").with_description( + "Vector normalised into probability distribution", + )) } -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("Softmax", env!("CARGO_PKG_VERSION")); - metadata.set_description(env!("CARGO_PKG_DESCRIPTION")); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("softmax"); - metadata.add_tag("image"); - metadata.add_tag("nlp"); - metadata.add_tag("numeric"); - metadata.add_tag("classification"); - - let input = TensorMetadata::new("input"); - let hint = supported_shapes( - &[ElementType::F32, ElementType::F64], - DimensionsParam::Fixed(&[0]), - ); - input.add_hint(&hint); - metadata.add_input(&input); - - let soft_max = TensorMetadata::new("soft_max"); - soft_max - .set_description("Vector normalised into probability distribution"); - let hint = supported_shapes( - &[ElementType::F32, ElementType::F64], - DimensionsParam::Fixed(&[0]), - ); - soft_max.add_hint(&hint); - metadata.add_output(&soft_max); - - register_node(&metadata); +struct Softmax; + +impl ProcBlock for Softmax { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![TensorConstraint::new( + "input", + ElementTypeConstraint::F32 | ElementTypeConstraint::F64, + Dimensions::Dynamic, + )], + outputs: vec![TensorConstraint::new( + "soft_max", + ElementTypeConstraint::F32 | ElementTypeConstraint::F64, + Dimensions::Dynamic, + )], + } } - fn graph(id: String) -> Result<(), GraphError> { - let ctx = - GraphContext::for_node(&id).ok_or(GraphError::MissingContext)?; - - ctx.add_input_tensor( - "input", - ElementType::F32, - DimensionsParam::Fixed(&[0]), - ); + fn run(&self, mut inputs: Vec) -> Result, RunError> { + let mut input = Tensor::take_named(&mut inputs, "input")?; - ctx.add_output_tensor( - "soft_max", - ElementType::F32, - DimensionsParam::Fixed(&[0]), - ); + if let Ok(floats) = input.view_mut::() { + softmax_inplace(floats); + } else if let Ok(doubles) = input.view_mut::() { + softmax_inplace(doubles); + } else { + return Err( + InvalidInput::incompatible_element_type(&input.name).into() + ); + } - Ok(()) + Ok(vec![input.with_name("soft_max")]) } +} + +impl From> for Softmax { + fn from(_: Vec) -> Self { Softmax } +} - fn kernel(id: String) -> Result<(), KernelError> { - let ctx = - KernelContext::for_node(&id).ok_or(KernelError::MissingContext)?; - let TensorResult { - element_type, - dimensions, - mut buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - match element_type { - ElementType::F32 => preprocess_buffer::(&mut buffer, &dimensions).map(softmax)?, - ElementType::F64 => preprocess_buffer::(&mut buffer, &dimensions).map(softmax)?, - other => { - return Err(KernelError::Other(format!( - "The softmax proc-block only accepts f32 or f64 tensors, found {:?}", - other, - ))) - }, - }; - - ctx.set_output_tensor( - "soft_max", - TensorParam { - element_type, - dimensions: &dimensions, - buffer: &buffer, - }, - ); - - Ok(()) +fn softmax_inplace(mut input: ArrayViewMutD<'_, T>) +where + T: Float + num_traits::FromPrimitive, +{ + input.mapv_inplace(|x| x.exp()); + + let sum = input.sum(); + if !sum.is_zero() { + input.mapv_inplace(|x| x / sum); } } @@ -137,26 +86,26 @@ mod tests { use hotg_rune_proc_blocks::ndarray; #[test] - fn test_softmax_unfiorm() { + fn softmax_uniform() { let mut input = ndarray::arr1(&[1.0, 1.0, 1.0, 1.0]); let softmax_correct = ndarray::arr1(&[0.25, 0.25, 0.25, 0.25]); - softmax(input.view_mut()); + softmax_inplace(input.view_mut().into_dyn()); assert_eq!(input, softmax_correct); } #[test] - fn test_softmax_single() { + fn softmax_single() { let mut input = ndarray::arr1(&[1.0, 0.0]); let softmax_correct = ndarray::arr1(&[0.7310585786300049, 0.26894142136999510]); - softmax(input.view_mut()); + softmax_inplace(input.view_mut().into_dyn()); assert_eq!(input, softmax_correct); } #[test] - fn test_softmax() { + fn known_values() { let mut input = ndarray::arr1(&[1.0, 2.0, 3.0]); let softmax_correct = ndarray::arr1(&[ 0.09003057317038046, @@ -164,35 +113,49 @@ mod tests { 0.6652409557748219, ]); - softmax(input.view_mut()); + softmax_inplace(input.view_mut().into_dyn()); assert_eq!(input, softmax_correct); } #[test] - fn test_softmax_zeros() { + fn softmax_zeros() { let mut input = ndarray::arr1(&[0.0, 0.0]); let softmax_correct = ndarray::arr1(&[0.5, 0.5]); - softmax(input.view_mut()); + softmax_inplace(input.view_mut().into_dyn()); assert_eq!(input, softmax_correct); } #[test] - fn test_softmax_zero() { + fn softmax_zero() { let mut input = ndarray::arr1(&[0.0]); let softmax_correct = ndarray::arr1(&[1.0]); - softmax(input.view_mut()); + softmax_inplace(input.view_mut().into_dyn()); assert_eq!(input, softmax_correct); } #[test] - fn test_softmax_empty() { + fn softmax_empty() { let empty: &[f32] = &[]; let mut input = ndarray::Array::from_vec(empty.to_vec()); let softmax_correct = ndarray::Array::from_vec(empty.to_vec()); - softmax(input.view_mut()); + softmax_inplace(input.view_mut().into_dyn()); assert_eq!(input, softmax_correct); } + + #[test] + fn floats() { + let inputs = vec![Tensor::new_1d("input", &[1.0_f32, 2.0, 3.0])]; + let softmax_correct = ndarray::arr1(&[ + 0.09003057317038046_f32, + 0.24472847105479767, + 0.6652409557748219, + ]); + + let got = Softmax.run(inputs).unwrap(); + + assert_eq!(got, vec![Tensor::new("soft_max", &softmax_correct)]); + } } diff --git a/support/Cargo.toml b/support/Cargo.toml index 79d94b1c271..26cd363a0d4 100644 --- a/support/Cargo.toml +++ b/support/Cargo.toml @@ -6,12 +6,14 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bytemuck = "1.9.1" +getrandom = { version = "0.2.6", default-features = false, features = ["custom"], optional = true } ndarray = "0.15.4" -once_cell = "1.12.0" -rand = { version = "0.8.5", features = ["small_rng"] } -wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen", optional = true } -getrandom = { version = "0.2.6", default-features = false, features = ["custom"] } +thiserror = "1.0.31" +tracing = { version = "0.1.35", default-features = false, features = ["std"] } +tracing-subscriber = { version = "0.3.11", default-features = false, features = ["registry", "tracing-log"] } +wit-bindgen-rust = { git = "https://github.com/wasmerio/wit-bindgen", branch = "wasmer", version = "0.1.0", optional = true } [features] -default = ["runtime_v1"] -runtime_v1 = ["wit-bindgen-rust"] +default = ["guest"] +guest = ["wit-bindgen-rust", "getrandom"] diff --git a/support/src/bindings.rs b/support/src/bindings.rs deleted file mode 100644 index 673c526be7f..00000000000 --- a/support/src/bindings.rs +++ /dev/null @@ -1,186 +0,0 @@ -use std::{fmt::Display, str::FromStr}; - -pub mod runtime_v1 { - // Note: this also generates a `runtime_v1` module, but it's private and - // can't be exported. As a workaround, we've wrapped it in another - // runtime_v1 module and re-exported its contents. - wit_bindgen_rust::import!("../wit-files/rune/runtime-v1.wit"); - - use crate::bindings::ContextExt; - - pub use self::runtime_v1::*; - - use std::{ - fmt::{self, Display, Formatter}, - str::FromStr, - }; - - #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub struct InvalidElementType { - pub actual: String, - } - - impl std::error::Error for InvalidElementType {} - - impl Display for InvalidElementType { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!( - f, - "\"{}\" is not a valid input type, expected one of {:?}", - self.actual, - crate::common::element_type::ALL - ) - } - } - - impl ElementType { - pub const ALL: &'static [&'static str] = &[ - "u8", "i8", "u16", "i16", "u32", "i32", "f32", "u64", "i64", "f64", - "utf8", - ]; - pub const DESCRIPTION: &'static str = "The output type."; - pub const NAME: &'static str = "element_type"; - pub const NUMERIC: &'static [&'static str] = &[ - "u8", "i8", "u16", "i16", "u32", "i32", "f32", "u64", "i64", "f64", - ]; - - fn human_name(self) -> &'static str { - match self { - ElementType::U8 => "u8", - ElementType::I8 => "i8", - ElementType::U16 => "u16", - ElementType::I16 => "i16", - ElementType::U32 => "u32", - ElementType::I32 => "i32", - ElementType::F32 => "f32", - ElementType::U64 => "u64", - ElementType::I64 => "i64", - ElementType::F64 => "f64", - ElementType::Utf8 => "utf8", - } - } - } - impl Display for ElementType { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.human_name().fmt(f) - } - } - - impl FromStr for ElementType { - type Err = InvalidElementType; - - fn from_str(s: &str) -> Result { - match s { - "u8" => Ok(ElementType::U8), - "i8" => Ok(ElementType::I8), - "u16" => Ok(ElementType::U16), - "i16" => Ok(ElementType::I16), - "u32" => Ok(ElementType::U32), - "i32" => Ok(ElementType::I32), - "f32" => Ok(ElementType::F32), - "u64" => Ok(ElementType::U64), - "i64" => Ok(ElementType::I64), - "f64" => Ok(ElementType::F64), - "utf8" => Ok(ElementType::Utf8), - other => Err(InvalidElementType { - actual: other.to_string(), - }), - } - } - } - - impl ArgumentMetadata { - /// Register an `element_type` argument which accepts any - /// [`ElementType`] and defaults to [`ElementType::F32`]. - pub fn element_type() -> Self { - let element_type = ArgumentMetadata::new(ElementType::NAME); - element_type.set_description(ElementType::DESCRIPTION); - element_type.set_default_value(ElementType::F32.human_name()); - element_type.add_hint(&runtime_v1::interpret_as_string_in_enum( - ElementType::ALL, - )); - element_type - } - - /// Register an `element_type` argument which accepts any **numeric** - /// [`ElementType`] and defaults to [`ElementType::F32`]. - pub fn numeric_element_type() -> Self { - let element_type = ArgumentMetadata::new("element_type"); - element_type.set_description(ElementType::DESCRIPTION); - element_type.set_default_value(ElementType::F32.human_name()); - element_type.add_hint(&runtime_v1::interpret_as_string_in_enum( - ElementType::NUMERIC, - )); - element_type - } - } - - impl ContextExt for GraphContext { - fn _get_argument(&self, name: &str) -> Option { - self.get_argument(name) - } - } - - impl ContextExt for KernelContext { - fn _get_argument(&self, name: &str) -> Option { - self.get_argument(name) - } - } -} - -pub trait ContextErrorExt { - type InvalidArgument: InvalidArgumentExt; - - fn invalid_argument(inner: Self::InvalidArgument) -> Self; -} - -pub trait InvalidArgumentExt { - fn other(name: &str, msg: impl Display) -> Self; - fn invalid_value(name: &str, error: impl Display) -> Self; - fn not_found(name: &str) -> Self; -} - -pub trait ContextExt { - fn _get_argument(&self, name: &str) -> Option; - - fn required_argument(&self, name: &str) -> Result - where - E: ContextErrorExt, - { - self._get_argument(name).ok_or_else(|| { - E::invalid_argument(E::InvalidArgument::not_found(name)) - }) - } - - fn parse_argument(&self, name: &str) -> Result - where - T: FromStr, - T::Err: Display, - E: ContextErrorExt, - { - self.required_argument(name)? - .parse() - .map_err(|e| E::InvalidArgument::invalid_value(name, e)) - .map_err(E::invalid_argument) - } - - fn parse_argument_with_default( - &self, - name: &str, - default: T, - ) -> Result - where - T: FromStr, - T::Err: Display, - E: ContextErrorExt, - { - let arg = match self._get_argument(name) { - Some(a) => a, - None => return Ok(default), - }; - - arg.parse() - .map_err(|e| E::InvalidArgument::invalid_value(name, e)) - .map_err(E::invalid_argument) - } -} diff --git a/support/src/buffer_ext.rs b/support/src/buffer_ext.rs deleted file mode 100644 index 0c0f5772de1..00000000000 --- a/support/src/buffer_ext.rs +++ /dev/null @@ -1,234 +0,0 @@ -use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD, ErrorKind, ShapeError}; - -use crate::ValueType; - -/// Extension traits added to a byte buffer. -pub trait BufferExt { - /// Reinterpret this byte buffer as a slice of `T`'s, where `T` is a trivial - /// value type (`f32`, `u8`, etc.). - /// - /// For working with strings, see [`BufferExt::strings()`]. - /// - /// # Examples - /// - /// On a little-endian machine, the binary representation for - /// `[1_u16, 256, 65535]` is `[0x1, 0x0, 0x0, 0x1, 0xff, 0xff]`. We can use - /// the `elements()` method to reinterpret this binary representation as - /// a slice of `u16`s to get back the original numbers. - /// - /// ```rust - /// use hotg_rune_proc_blocks::BufferExt; - /// - /// let binary_representation: [u8; 6] = [0x1, 0x0, 0x0, 0x1, 0xff, 0xff]; - /// - /// let reinterpreted: &[u16] = binary_representation.elements(); - /// - /// // Note: - /// assert_eq!(reinterpreted, &[1, 256, 65535]); - /// ``` - fn elements(&self) -> &[T]; - - /// Reinterpret this byte buffer as a mutable slice of `T`'s. - /// - /// This is the mutable version of [`BufferExt::elements()`]. - fn elements_mut(&mut self) -> &mut [T]; - - /// Interpret this buffer as a sequence of UTF-8 strings, where each string - /// is prefixed by its length as a little-endian `u16`. - /// - /// # Examples - /// - /// The [`BufferExt::strings()`] method can extract strings from the byte - /// buffer created by a [`crate::StringBuilder`]. - /// - /// ```rust - /// use hotg_rune_proc_blocks::{StringBuilder, BufferExt}; - /// - /// let mut builder = StringBuilder::new(); - /// builder - /// .push("this") - /// .push("is") - /// .push("a") - /// .push("sentence"); - /// let bytes: Vec = builder.finish(); - /// - /// let strings = bytes.strings().unwrap(); - /// - /// assert_eq!(strings, &["this", "is", "a", "sentence"]); - /// ``` - fn strings(&self) -> Result, ShapeError>; - - /// View the buffer as a multi-dimensional array. - fn view( - &self, - dimensions: &[u32], - ) -> Result, ShapeError> { - let elements = self.elements(); - let dimensions = dims(dimensions); - ArrayViewD::from_shape(dimensions, elements) - } - - /// View the buffer as a mutable multi-dimensional array. - fn view_mut( - &mut self, - dimensions: &[u32], - ) -> Result, ShapeError> { - let elements = self.elements_mut(); - let dimensions = dims(dimensions); - ArrayViewMutD::from_shape(dimensions, elements) - } - - fn string_view<'a>( - &'a self, - dimensions: &[u32], - ) -> Result, ShapeError> { - let strings = self.strings()?; - let dimensions = dims(dimensions); - ArrayD::from_shape_vec(dimensions, strings) - } -} - -fn dims(d: &[u32]) -> Vec { - d.iter() - .map(|&dim| usize::try_from(dim).expect("Conversion should never fail")) - .collect() -} - -impl BufferExt for [u8] { - fn elements(&self) -> &[T] { - unsafe { - let (start, middle, end) = self.align_to::(); - assert!(start.is_empty()); - assert!(end.is_empty()); - middle - } - } - - fn elements_mut(&mut self) -> &mut [T] { - unsafe { - let (start, middle, end) = self.align_to_mut::(); - assert!(start.is_empty()); - assert!(end.is_empty()); - middle - } - } - - fn strings(&self) -> Result, ShapeError> { - const HEADER_SIZE: usize = std::mem::size_of::(); - - let mut strings = Vec::new(); - let mut buffer = self; - - while !buffer.is_empty() { - if buffer.len() < HEADER_SIZE { - // We don't have enough bytes remaining for a full length field, - // so something is probably wrong with our buffer. - return Err(ShapeError::from_kind(ErrorKind::OutOfBounds)); - } - - let (len, rest) = buffer.split_at(HEADER_SIZE); - - let len: [u8; HEADER_SIZE] = len.try_into().expect("Unreachable"); - let len = u32::from_le_bytes(len); - let len = usize::try_from(len).expect("Unreachable"); - - if rest.len() < len { - // We don't have enough bytes left in the buffer to read a - // string with this length. - return Err(ShapeError::from_kind(ErrorKind::OutOfBounds)); - } - - let (s, rest) = rest.split_at(len); - - match std::str::from_utf8(s) { - Ok(s) => strings.push(s), - Err(_) => { - // The string wasn't valid UTF-8. We're probably using the - // wrong ShapeError here, but our alternative would be - // introducing our own error type and that seems overkill. - return Err(ShapeError::from_kind( - ErrorKind::IncompatibleLayout, - )); - }, - } - - buffer = rest; - } - - Ok(strings) - } -} - -#[cfg(test)] -mod tests { - use std::io::Write; - - use ndarray::ErrorKind; - - use super::*; - - fn as_byte_buffer_mut(items: &mut [T]) -> &mut [u8] { - // Safety: Invariant upheld by the ValueType impl - let (head, bytes, tail) = unsafe { items.align_to_mut() }; - assert!(head.is_empty()); - assert!(tail.is_empty()); - bytes - } - - #[test] - fn view_4_floats_as_2x2() { - let floats = &[0.0_f32, 1.0, 2.0, 3.0]; - let buffer: Vec = - floats.iter().flat_map(|f| f.to_ne_bytes()).collect(); - let dimensions = &[2, 2]; - - let tensor = buffer.view::(dimensions).unwrap(); - - assert_eq!(tensor.dim(), ndarray::Dim(vec![2, 2])); - assert_eq!(tensor[[0, 0]], 0.0); - assert_eq!(tensor[[0, 1]], 1.0); - assert_eq!(tensor[[1, 0]], 2.0); - assert_eq!(tensor[[1, 1]], 3.0); - } - - #[test] - fn incorrect_size_is_error() { - let buffer = [1_u8, 2, 3, 4]; - let dimensions = &[5]; - - let error = buffer.view::(dimensions).unwrap_err(); - - let kind = error.kind(); - assert_eq!(kind, ErrorKind::OutOfBounds); - } - - #[test] - fn mutate_tensor_in_place() { - let mut floats = [0.0_f32, 0.0, 0.0, 0.0]; - let dimensions = &[2, 2]; - - { - let buffer = as_byte_buffer_mut(&mut floats); - let mut tensor = buffer.view_mut::(dimensions).unwrap(); - - tensor[[1, 0]] = 5.0; - }; - - assert_eq!(floats, [0.0, 0.0, 5.0, 0.0]); - } - - #[test] - fn load_string_tensor() { - let strings = ["this", "is a", "sentence", "."]; - let mut buffer = Vec::new(); - for s in &strings { - let length = (s.len() as u32).to_le_bytes(); - buffer.write_all(&length).unwrap(); - buffer.write_all(s.as_bytes()).unwrap(); - } - - let got = buffer.strings().unwrap(); - - assert_eq!(got, strings); - } -} diff --git a/support/src/common.rs b/support/src/common.rs deleted file mode 100644 index 16626491b9d..00000000000 --- a/support/src/common.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! Common arguments that are used across proc-blocks. - -pub mod element_type { - pub const NAME: &str = "element_type"; - pub const DESCRIPTION: &str = "The output type."; - pub const ALL: &[&str] = &[ - "u8", "i8", "u16", "i16", "u32", "i32", "f32", "u64", "i64", "f64", - "utf8", - ]; - pub const NUMERIC: &[&str] = &[ - "u8", "i8", "u16", "i16", "u32", "i32", "f32", "u64", "i64", "f64", - ]; -} diff --git a/support/src/guest/bindings.rs b/support/src/guest/bindings.rs new file mode 100644 index 00000000000..a72dca773a5 --- /dev/null +++ b/support/src/guest/bindings.rs @@ -0,0 +1,43 @@ +pub use self::{proc_block_v2::*, runtime_v2::*}; + +use crate::guest::ProcBlock; +use wit_bindgen_rust::Handle; + +wit_bindgen_rust::import!("../wit-files/rune/runtime-v2.wit"); +wit_bindgen_rust::export!("../wit-files/rune/proc-block-v2.wit"); + +extern "Rust" { + fn __proc_block_metadata() -> Metadata; + fn __proc_block_new( + args: Vec, + ) -> Result, CreateError>; +} + +struct ProcBlockV2; + +impl proc_block_v2::ProcBlockV2 for ProcBlockV2 { + fn metadata() -> Metadata { + crate::guest::ensure_initialized(); + unsafe { __proc_block_metadata() } + } + + fn create_node( + args: Vec, + ) -> Result, CreateError> { + crate::guest::ensure_initialized(); + let proc_block = unsafe { __proc_block_new(args)? }; + Ok(Handle::new(Node(Box::new(proc_block)))) + } +} + +pub struct Node(Box); + +impl proc_block_v2::Node for Node { + fn tensor_constraints(&self) -> TensorConstraints { + self.0.tensor_constraints() + } + + fn run(&self, inputs: Vec) -> Result, RunError> { + self.0.run(inputs) + } +} diff --git a/support/src/guest/element_type.rs b/support/src/guest/element_type.rs new file mode 100644 index 00000000000..c5ac2ca562c --- /dev/null +++ b/support/src/guest/element_type.rs @@ -0,0 +1,99 @@ +use std::str::FromStr; + +use crate::guest::ElementType; +use bytemuck::{AnyBitPattern, NoUninit}; + +/// A primitive value that can be stored directly in a [`crate::guest::Tensor`]. +pub trait PrimitiveTensorElement: AnyBitPattern + NoUninit { + const ELEMENT_TYPE: ElementType; +} + +impl PrimitiveTensorElement for u8 { + const ELEMENT_TYPE: ElementType = ElementType::U8; +} +impl PrimitiveTensorElement for i8 { + const ELEMENT_TYPE: ElementType = ElementType::I8; +} +impl PrimitiveTensorElement for u16 { + const ELEMENT_TYPE: ElementType = ElementType::U16; +} +impl PrimitiveTensorElement for i16 { + const ELEMENT_TYPE: ElementType = ElementType::I16; +} +impl PrimitiveTensorElement for u32 { + const ELEMENT_TYPE: ElementType = ElementType::U32; +} +impl PrimitiveTensorElement for i32 { + const ELEMENT_TYPE: ElementType = ElementType::I32; +} +impl PrimitiveTensorElement for f32 { + const ELEMENT_TYPE: ElementType = ElementType::F32; +} +impl PrimitiveTensorElement for u64 { + const ELEMENT_TYPE: ElementType = ElementType::U64; +} +impl PrimitiveTensorElement for i64 { + const ELEMENT_TYPE: ElementType = ElementType::I64; +} +impl PrimitiveTensorElement for f64 { + const ELEMENT_TYPE: ElementType = ElementType::F64; +} + +impl ElementType { + pub const NAMES: &'static [&'static str] = &[ + "u8", + "i8", + "u16", + "i16", + "u32", + "i32", + "f32", + "u64", + "i64", + "f64", + "complex64", + "complex128", + "utf8", + ]; +} + +impl TryFrom<&'_ str> for ElementType { + type Error = UnknownElementType; + + fn try_from(value: &'_ str) -> Result { + match value { + "u8" | "U8" => Ok(ElementType::U8), + "i8" | "I8" => Ok(ElementType::I8), + "u16" | "U16" => Ok(ElementType::U16), + "i16" | "I16" => Ok(ElementType::I16), + "u32" | "U32" => Ok(ElementType::U32), + "i32" | "I32" => Ok(ElementType::I32), + "f32" | "F32" => Ok(ElementType::F32), + "u64" | "U64" => Ok(ElementType::U64), + "i64" | "I64" => Ok(ElementType::I64), + "f64" | "F64" => Ok(ElementType::F64), + "complex64" => Ok(ElementType::Complex64), + "complex128" => Ok(ElementType::Complex128), + "utf8" | "UTF8" => Ok(ElementType::Utf8), + other => Err(UnknownElementType(other.to_string())), + } + } +} + +impl FromStr for ElementType { + type Err = UnknownElementType; + + fn from_str(s: &str) -> Result { s.try_into() } +} + +impl TryFrom for ElementType { + type Error = UnknownElementType; + + fn try_from(value: String) -> Result { + value.as_str().try_into() + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, thiserror::Error)] +#[error("Unknown element type, \"{_0}\"")] +pub struct UnknownElementType(String); diff --git a/support/src/guest/errors.rs b/support/src/guest/errors.rs new file mode 100644 index 00000000000..1fa282510bd --- /dev/null +++ b/support/src/guest/errors.rs @@ -0,0 +1,207 @@ +use std::{ + convert::Infallible, + fmt::{self, Display, Formatter}, +}; + +use crate::guest::bindings::*; + +impl RunError { + pub fn other(reason: impl Display) -> Self { + RunError::Other(reason.to_string()) + } + + pub fn missing_input(name: impl Into) -> Self { + RunError::InvalidInput(InvalidInput::not_found(name)) + } +} + +impl PartialEq for RunError { + fn eq(&self, other: &RunError) -> bool { + match (self, other) { + (RunError::Other(left), RunError::Other(right)) => left == right, + (RunError::InvalidInput(left), RunError::InvalidInput(right)) => { + left == right + }, + (RunError::Other(_), _) | (RunError::InvalidInput(_), _) => false, + } + } +} + +impl Display for RunError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + RunError::InvalidInput(i) => i.fmt(f), + RunError::Other(msg) => write!(f, "{}", msg), + } + } +} + +impl std::error::Error for RunError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + RunError::InvalidInput(i) => Some(i), + RunError::Other(_) => None, + } + } +} + +impl From for RunError { + fn from(e: InvalidInput) -> Self { RunError::InvalidInput(e) } +} + +impl CreateError { + pub fn other(error: impl Display) -> Self { + CreateError::Other(error.to_string()) + } +} + +impl From for CreateError { + fn from(v: Infallible) -> Self { match v {} } +} + +impl From for CreateError { + fn from(e: ArgumentError) -> Self { CreateError::Argument(e) } +} + +impl InvalidInput { + pub fn incompatible_dimensions(tensor_name: impl Into) -> Self { + InvalidInput { + name: tensor_name.into(), + reason: InvalidInputReason::IncompatibleDimensions, + } + } + + pub fn incompatible_element_type(tensor_name: impl Into) -> Self { + InvalidInput { + name: tensor_name.into(), + reason: InvalidInputReason::IncompatibleElementType, + } + } + + pub fn not_found(tensor_name: impl Into) -> Self { + InvalidInput { + name: tensor_name.into(), + reason: InvalidInputReason::NotFound, + } + } + + pub fn invalid_value( + tensor_name: impl Into, + error: impl Display, + ) -> Self { + InvalidInput { + name: tensor_name.into(), + reason: InvalidInputReason::InvalidValue(error.to_string()), + } + } + + pub fn other(tensor_name: impl Into, reason: impl Display) -> Self { + InvalidInput { + name: tensor_name.into(), + reason: InvalidInputReason::Other(reason.to_string()), + } + } +} + +impl PartialEq for InvalidInput { + fn eq(&self, other: &InvalidInput) -> bool { + let InvalidInput { name, reason } = self; + + name == &other.name && reason == &other.reason + } +} + +impl std::error::Error for InvalidInput { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.reason) + } +} + +impl Display for InvalidInput { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "The \"{}\" input tensor was invalid", self.name) + } +} + +impl PartialEq for InvalidInputReason { + fn eq(&self, other: &InvalidInputReason) -> bool { + match (self, other) { + ( + InvalidInputReason::Other(left), + InvalidInputReason::Other(right), + ) => left == right, + ( + InvalidInputReason::InvalidValue(left), + InvalidInputReason::InvalidValue(right), + ) => left == right, + (InvalidInputReason::NotFound, InvalidInputReason::NotFound) => { + true + }, + ( + InvalidInputReason::IncompatibleDimensions, + InvalidInputReason::IncompatibleDimensions, + ) => true, + ( + InvalidInputReason::IncompatibleElementType, + InvalidInputReason::IncompatibleElementType, + ) => true, + (InvalidInputReason::Other(_), _) + | (InvalidInputReason::InvalidValue(_), _) + | (InvalidInputReason::NotFound, _) + | (InvalidInputReason::IncompatibleElementType, _) + | (InvalidInputReason::IncompatibleDimensions, _) => false, + } + } +} + +impl Display for InvalidInputReason { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + InvalidInputReason::Other(msg) => write!(f, "{msg}"), + InvalidInputReason::NotFound => write!(f, "Not found"), + InvalidInputReason::InvalidValue(msg) => { + write!(f, "Invalid value: {msg}") + }, + InvalidInputReason::IncompatibleDimensions => { + write!(f, "Incompatible dimensions") + }, + InvalidInputReason::IncompatibleElementType => { + write!(f, "Incompatible element type") + }, + } + } +} + +impl std::error::Error for InvalidInputReason {} + +impl Display for ArgumentError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "The \"{}\" argument is invalid", self.name) + } +} + +impl std::error::Error for ArgumentError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.reason) + } +} + +impl Display for ArgumentErrorReason { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + ArgumentErrorReason::Other(msg) => write!(f, "{msg}"), + ArgumentErrorReason::NotFound => { + write!(f, "The argument wasn't defined") + }, + + ArgumentErrorReason::InvalidValue(msg) => { + write!(f, "Invalid value: {msg}") + }, + ArgumentErrorReason::ParseFailed(e) => { + write!(f, "Parse failed: {e}") + }, + } + } +} + +impl std::error::Error for ArgumentErrorReason {} diff --git a/support/src/guest/logging.rs b/support/src/guest/logging.rs new file mode 100644 index 00000000000..043e2a9a15c --- /dev/null +++ b/support/src/guest/logging.rs @@ -0,0 +1,148 @@ +use tracing::{Event, Metadata, Subscriber}; +use tracing_subscriber::{ + layer::{Context, SubscriberExt}, + util::{SubscriberInitExt, TryInitError}, + Registry, +}; + +use crate::guest::bindings::{self, LogLevel, LogMetadata, LogValue}; + +pub(crate) fn initialize_logger() -> Result<(), TryInitError> { + Registry::default().with(Layer).try_init() +} + +struct Layer; + +impl tracing_subscriber::Layer for Layer { + fn enabled(&self, metadata: &Metadata<'_>, _ctx: Context<'_, S>) -> bool { + bindings::is_enabled(LogMetadata::from(metadata)) + } + + fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) { + let meta = LogMetadata::from(event.metadata()); + + let mut visitor = Visitor::default(); + event.record(&mut visitor); + let (msg, data) = visitor.log_values(); + + bindings::log(meta, msg, &data); + } +} + +#[derive(Debug)] +enum OwnedLogValue { + Boolean(bool), + Integer(i64), + Float(f64), + String(String), +} + +#[derive(Debug, Default)] +struct Visitor(Vec<(&'static str, OwnedLogValue)>); + +impl Visitor { + fn log_values(&self) -> (&str, Vec<(&str, LogValue<'_>)>) { + let mut values = Vec::new(); + let mut msg = ""; + + for (key, value) in &self.0 { + if let ("message", OwnedLogValue::String(s)) = (*key, value) { + msg = s.as_str(); + continue; + } + + let borrowed = match *value { + OwnedLogValue::Boolean(b) => LogValue::Boolean(b), + OwnedLogValue::Integer(i) => LogValue::Integer(i), + OwnedLogValue::Float(f) => LogValue::Float(f), + OwnedLogValue::String(ref s) => LogValue::String(s), + }; + + values.push((*key, borrowed)); + } + + (msg, values) + } +} + +impl tracing::field::Visit for Visitor { + fn record_debug( + &mut self, + field: &tracing::field::Field, + value: &dyn std::fmt::Debug, + ) { + self.0 + .push((field.name(), OwnedLogValue::String(format!("{value:?}")))); + } + + fn record_f64(&mut self, field: &tracing::field::Field, value: f64) { + self.0.push((field.name(), OwnedLogValue::Float(value))); + } + + fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { + self.0.push((field.name(), OwnedLogValue::Integer(value))); + } + + fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { + match i64::try_from(value) { + Ok(i) => self.0.push((field.name(), OwnedLogValue::Integer(i))), + Err(_) => self.record_debug(field, &value), + } + } + + fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { + self.0.push((field.name(), OwnedLogValue::Boolean(value))); + } + + fn record_str(&mut self, field: &tracing::field::Field, value: &str) { + self.0 + .push((field.name(), OwnedLogValue::String(value.to_string()))); + } + + fn record_error( + &mut self, + field: &tracing::field::Field, + value: &(dyn std::error::Error + 'static), + ) { + self.0 + .push((field.name(), OwnedLogValue::String(value.to_string()))); + + let mut causes = Vec::new(); + let mut source = value.source(); + + while let Some(next_source) = source { + causes.push(next_source.to_string()); + source = next_source.source(); + } + + if !causes.is_empty() { + self.0 + .push(("causes", OwnedLogValue::String(format!("{causes:?}")))); + } + } +} + +impl From for LogLevel { + fn from(level: tracing::Level) -> Self { + match level { + tracing::Level::TRACE => LogLevel::Trace, + tracing::Level::DEBUG => LogLevel::Debug, + tracing::Level::INFO => LogLevel::Info, + tracing::Level::WARN => LogLevel::Warn, + tracing::Level::ERROR => LogLevel::Error, + } + } +} + +impl<'a> From<&'a Metadata<'a>> for LogMetadata<'a> { + fn from(metadata: &'a Metadata<'a>) -> Self { + LogMetadata { + name: metadata.name(), + target: metadata.target(), + level: LogLevel::from(*metadata.level()), + file: metadata.file(), + line: metadata.line(), + module: metadata.module_path(), + } + } +} diff --git a/support/src/guest/macros.rs b/support/src/guest/macros.rs new file mode 100644 index 00000000000..a9424835905 --- /dev/null +++ b/support/src/guest/macros.rs @@ -0,0 +1,22 @@ +/// Tell the runtime that a WebAssembly module contains a proc-block. +#[macro_export] +macro_rules! export_proc_block { + (metadata: $metadata_func:expr, proc_block: $proc_block:ty $(,)?) => { + #[doc(hidden)] + #[no_mangle] + pub fn __proc_block_metadata() -> $crate::guest::Metadata { $metadata_func() } + + #[doc(hidden)] + #[no_mangle] + pub fn __proc_block_new( + args: Vec<$crate::guest::Argument>, + ) -> Result, $crate::guest::CreateError> { + fn assert_impl_proc_block(_: &impl $crate::guest::ProcBlock) {} + + let proc_block = <$proc_block>::try_from(args)?; + assert_impl_proc_block(&proc_block); + + Ok(Box::new(proc_block) as Box) + } + }; +} diff --git a/support/src/guest/metadata.rs b/support/src/guest/metadata.rs new file mode 100644 index 00000000000..978d763740d --- /dev/null +++ b/support/src/guest/metadata.rs @@ -0,0 +1,179 @@ +use crate::guest::bindings::*; + +impl Metadata { + pub fn new(name: impl Into, version: impl Into) -> Self { + Metadata { + name: name.into(), + version: version.into(), + tags: Vec::new(), + description: None, + homepage: None, + repository: None, + arguments: Vec::new(), + inputs: Vec::new(), + outputs: Vec::new(), + } + } + + pub fn with_description(mut self, description: impl Into) -> Self { + let description = description.into(); + if !description.is_empty() { + self.description = Some(description); + } + + self + } + + pub fn with_homepage(mut self, homepage: impl Into) -> Self { + let homepage = homepage.into(); + if !homepage.is_empty() { + self.homepage = Some(homepage); + } + + self + } + + pub fn with_repository(mut self, repository: impl Into) -> Self { + let repository = repository.into(); + if !repository.is_empty() { + self.repository = Some(repository); + } + + self + } + + pub fn with_tag(mut self, tag: impl Into) -> Self { + self.tags.push(tag.into()); + self + } + + pub fn with_argument(mut self, arg: ArgumentMetadata) -> Self { + self.arguments.push(arg); + self + } + + pub fn with_input(mut self, input: TensorMetadata) -> Self { + self.inputs.push(input); + self + } + + pub fn with_output(mut self, output: TensorMetadata) -> Self { + self.outputs.push(output); + self + } +} + +impl ArgumentMetadata { + pub fn new(name: impl Into) -> Self { + ArgumentMetadata { + name: name.into(), + description: None, + default_value: None, + hints: Vec::new(), + } + } + + pub fn with_description(mut self, description: impl Into) -> Self { + let description = description.into(); + if !description.is_empty() { + self.description = Some(description); + } + self + } + + pub fn with_default_value(mut self, default_value: impl ToString) -> Self { + let default_value = default_value.to_string(); + if !default_value.is_empty() { + self.default_value = Some(default_value); + } + self + } + + pub fn with_hint(mut self, hint: impl Into) -> Self { + self.hints.push(hint.into()); + self + } +} + +impl TensorMetadata { + pub fn new(name: impl Into) -> Self { + TensorMetadata { + name: name.into(), + description: None, + hints: Vec::new(), + } + } + + pub fn with_description(mut self, description: impl Into) -> Self { + let description = description.into(); + if !description.is_empty() { + self.description = Some(description); + } + + self + } + + pub fn with_hint(mut self, hint: TensorHint) -> Self { + self.hints.push(hint); + self + } +} + +impl From for TensorHint { + fn from(m: MediaType) -> Self { TensorHint::MediaType(m) } +} + +impl TensorConstraint { + pub fn new( + name: impl Into, + element_type: impl Into, + dimensions: impl Into, + ) -> Self { + TensorConstraint { + name: name.into(), + element_type: element_type.into(), + dimensions: dimensions.into(), + } + } + + pub fn numeric( + name: impl Into, + dimensions: impl Into, + ) -> Self { + TensorConstraint { + name: name.into(), + element_type: !ElementTypeConstraint::UTF8, + dimensions: dimensions.into(), + } + } +} + +impl ArgumentHint { + pub fn one_of(items: impl IntoIterator) -> Self { + ArgumentHint::OneOf(items.into_iter().map(|s| s.to_string()).collect()) + } +} + +impl From for ArgumentHint { + fn from(a: ArgumentType) -> Self { ArgumentHint::ArgumentType(a) } +} + +impl From for ElementTypeConstraint { + fn from(e: ElementType) -> Self { + match e { + ElementType::U8 => ElementTypeConstraint::U8, + ElementType::I8 => ElementTypeConstraint::I8, + ElementType::U16 => ElementTypeConstraint::U16, + ElementType::I16 => ElementTypeConstraint::I16, + ElementType::U32 => ElementTypeConstraint::U32, + ElementType::I32 => ElementTypeConstraint::I32, + ElementType::F32 => ElementTypeConstraint::F32, + ElementType::U64 => ElementTypeConstraint::U64, + ElementType::I64 => ElementTypeConstraint::I64, + ElementType::F64 => ElementTypeConstraint::F64, + ElementType::Complex64 => ElementTypeConstraint::COMPLEX64, + ElementType::Complex128 => ElementTypeConstraint::COMPLEX128, + ElementType::Utf8 => ElementTypeConstraint::UTF8, + } + } +} diff --git a/support/src/guest/mod.rs b/support/src/guest/mod.rs new file mode 100644 index 00000000000..fe1870c003a --- /dev/null +++ b/support/src/guest/mod.rs @@ -0,0 +1,52 @@ +//! Types and utilities for implementing a proc-block. + +#[macro_use] +mod macros; + +pub(crate) mod bindings; +mod element_type; +mod errors; +mod logging; +mod metadata; +pub mod parse; +mod proc_block; +mod tensor; + +use std::{panic::PanicInfo, sync::Once}; + +pub use self::{ + bindings::{ + abort, Argument, ArgumentError, ArgumentErrorReason, ArgumentHint, + ArgumentMetadata, ArgumentType, CreateError, Dimensions, ElementType, + ElementTypeConstraint, InvalidInput, InvalidInputReason, MediaType, + Metadata, RunError, Tensor, TensorConstraint, TensorConstraints, + TensorHint, TensorMetadata, + }, + element_type::{PrimitiveTensorElement, UnknownElementType}, + proc_block::ProcBlock, +}; + +getrandom::register_custom_getrandom!(host_rng); + +fn host_rng(buffer: &mut [u8]) -> Result<(), getrandom::Error> { + bindings::get_random(buffer); + Ok(()) +} + +/// Run any necessary initialization code. +pub(crate) fn ensure_initialized() { + static ONCE: Once = Once::new(); + + ONCE.call_once(|| { + let _ = logging::initialize_logger(); + std::panic::set_hook(Box::new(panic_hook)); + }); +} + +fn panic_hook(panic_info: &PanicInfo<'_>) { + let location = panic_info.location(); + let file = location.map(|loc| loc.file()); + let line = location.map(|loc| loc.line()); + + tracing::error!(panic.file = file, panic.line = line, "{panic_info}"); +} diff --git a/support/src/guest/parse.rs b/support/src/guest/parse.rs new file mode 100644 index 00000000000..5c6ecf22b8f --- /dev/null +++ b/support/src/guest/parse.rs @@ -0,0 +1,48 @@ +use std::{fmt::Display, str::FromStr}; + +use crate::guest::{Argument, ArgumentError, ArgumentErrorReason}; + +pub fn required_arg( + args: &[Argument], + name: &str, +) -> Result +where + T: FromStr, + T::Err: Display, +{ + for arg in args { + if arg.name == name { + return arg.value.parse::().map_err(|e| ArgumentError { + name: name.to_string(), + reason: ArgumentErrorReason::InvalidValue(e.to_string()), + }); + } + } + + Err(ArgumentError { + name: name.to_string(), + reason: ArgumentErrorReason::NotFound, + }) +} + +pub fn optional_arg( + args: &[Argument], + name: &str, +) -> Result, ArgumentError> +where + T: FromStr, + T::Err: Display, +{ + for arg in args { + if arg.name == name { + return arg.value.parse::().map(Some).map_err(|e| { + ArgumentError { + name: name.to_string(), + reason: ArgumentErrorReason::InvalidValue(e.to_string()), + } + }); + } + } + + Ok(None) +} diff --git a/support/src/guest/proc_block.rs b/support/src/guest/proc_block.rs new file mode 100644 index 00000000000..1b179d72399 --- /dev/null +++ b/support/src/guest/proc_block.rs @@ -0,0 +1,17 @@ +use crate::guest::{RunError, Tensor, TensorConstraints}; + +/// The implementation of a processing block. +pub trait ProcBlock { + fn tensor_constraints(&self) -> TensorConstraints; + fn run(&self, inputs: Vec) -> Result, RunError>; +} + +impl ProcBlock for Box { + fn tensor_constraints(&self) -> TensorConstraints { + (**self).tensor_constraints() + } + + fn run(&self, inputs: Vec) -> Result, RunError> { + (**self).run(inputs) + } +} diff --git a/support/src/guest/tensor.rs b/support/src/guest/tensor.rs new file mode 100644 index 00000000000..b0cc18a0570 --- /dev/null +++ b/support/src/guest/tensor.rs @@ -0,0 +1,360 @@ +use ndarray::{ArrayD, Dim, Dimension, IntoDimension}; + +use crate::{ + guest::{bindings::*, PrimitiveTensorElement}, + StringBuilder, +}; + +impl Tensor { + pub fn new( + name: impl Into, + array: &crate::ndarray::ArrayBase, + ) -> Self + where + T: PrimitiveTensorElement, + S: crate::ndarray::Data, + Dims: crate::ndarray::Dimension, + { + let dimensions = array.shape().iter().map(|&d| d as u32).collect(); + + // Safety: + let mut buffer = Vec::new(); + + for element in array.iter() { + buffer.extend(bytemuck::bytes_of(element)); + } + + Tensor { + name: name.into(), + dimensions, + element_type: T::ELEMENT_TYPE, + buffer, + } + } + + /// Serialize a string tensor so it can be passed to the runtime. + /// + /// # Examples + /// + /// ``` + /// # use hotg_rune_proc_blocks::guest::Tensor; + /// # fn main() -> Result<(), Box> { + /// let strings = ndarray::arr2(&[ + /// ["this", "is", "a", "sentence"], + /// ["and", "this", "is", "another"], + /// ]); + /// + /// let tensor = Tensor::from_strings("tensor", &strings); + /// + /// let deserialized = tensor.string_view()?; + /// assert_eq!(deserialized, strings.into_dyn()); + /// # Ok(()) } + /// ``` + pub fn from_strings( + name: impl Into, + array: &ndarray::ArrayBase, + ) -> Self + where + Dim: ndarray::Dimension, + Data: ndarray::Data, + S: AsRef, + { + let mut builder = StringBuilder::new(); + for s in array.iter() { + builder.push(s.as_ref()); + } + let buffer = builder.finish(); + + let dimensions = array.shape().iter().map(|&dim| dim as u32).collect(); + + Tensor { + name: name.into(), + element_type: ElementType::Utf8, + dimensions, + buffer, + } + } + + pub fn new_1d(name: impl Into, elements: &[T]) -> Self + where + T: PrimitiveTensorElement, + { + let array = crate::ndarray::aview1(elements); + Tensor::new(name, &array) + } + + pub fn with_name(self, name: impl Into) -> Self { + Tensor { + name: name.into(), + ..self + } + } + + pub fn take_named( + tensors: &mut Vec, + name: &str, + ) -> Result { + let index = tensors + .iter() + .position(|t| t.name == name) + .ok_or_else(|| RunError::missing_input(name))?; + + Ok(tensors.remove(index)) + } + + pub fn get_named<'t>( + tensors: &'t [Tensor], + name: &str, + ) -> Result<&'t Self, RunError> { + tensors + .iter() + .find(|t| t.name == name) + .ok_or_else(|| RunError::missing_input(name)) + } + + pub fn view( + &self, + ) -> Result, InvalidInput> + where + T: PrimitiveTensorElement, + { + let dimensions: Vec<_> = self.dimensions().collect(); + let elements = self.elements()?; + + crate::ndarray::ArrayViewD::from_shape(dimensions, elements) + .map_err(|e| InvalidInput::other(&self.name, e)) + } + + pub fn view_mut( + &mut self, + ) -> Result, InvalidInput> + where + T: PrimitiveTensorElement, + { + let dimensions: Vec<_> = self.dimensions().collect(); + let name = self.name.clone(); + let elements = self.elements_mut()?; + + crate::ndarray::ArrayViewMutD::from_shape(dimensions, elements) + .map_err(|e| InvalidInput::other(name, e)) + } + + pub fn view_with_dimensions_mut( + &mut self, + ) -> Result< + crate::ndarray::ArrayViewMut<'_, T, Dim<[usize; N]>>, + InvalidInput, + > + where + T: PrimitiveTensorElement, + [usize; N]: IntoDimension>, + Dim<[usize; N]>: Dimension, + { + let dimensions: [usize; N] = self.as_nd_shape()?; + let name = self.name.clone(); + let elements = self.elements_mut()?; + + let shape = ndarray::Shape::from(ndarray::Dim(dimensions)); + + crate::ndarray::ArrayViewMut::from_shape(shape, elements) + .map_err(|e| InvalidInput::other(name, e)) + } + + pub fn view_with_dimensions( + &self, + ) -> Result>, InvalidInput> + where + T: PrimitiveTensorElement, + [usize; N]: IntoDimension>, + Dim<[usize; N]>: Dimension, + { + let dimensions: [usize; N] = self.as_nd_shape()?; + let elements = self.elements()?; + + let shape = ndarray::Shape::from(ndarray::Dim(dimensions)); + + crate::ndarray::ArrayView::from_shape(shape, elements) + .map_err(|e| InvalidInput::other(&self.name, e)) + } + + fn elements(&self) -> Result<&[T], InvalidInput> + where + T: PrimitiveTensorElement, + { + if self.element_type != T::ELEMENT_TYPE { + return Err(InvalidInput::incompatible_element_type(&self.name)); + } + + // Note: If our buffer is empty, the slice you get when from + // the Deref implementation will be null + align_of(u8) with + // a length of 0. + // + // This is normally fine, but if we later use bytemuck to + // cast the &[u8] to &[T] and T has an alignment greater + // than 1, we'll panic due to being mis-aligned. + // + // To prevent this, we return a view into an empty slice. + + if self.dimensions.iter().product::() == 0 { + return Ok(&[]); + } + + bytemuck::try_cast_slice(&self.buffer) + .map_err(|e| InvalidInput::other(&self.name, e)) + } + + fn elements_mut(&mut self) -> Result<&mut [T], InvalidInput> + where + T: PrimitiveTensorElement, + { + if self.element_type != T::ELEMENT_TYPE { + return Err(InvalidInput::incompatible_element_type(&self.name)); + } + + if self.dimensions.iter().product::() == 0 { + return Ok(&mut []); + } + + bytemuck::try_cast_slice_mut(&mut self.buffer) + .map_err(|e| InvalidInput::other(&self.name, e)) + } + + fn dimensions( + &self, + ) -> impl Iterator + DoubleEndedIterator + '_ { + self.dimensions.iter().map(|&d| d as usize) + } + + fn as_nd_shape(&self) -> Result<[usize; N], InvalidInput> { + let mut shape = [1; N]; + let mut last_index = N; + + for dim in self.dimensions().rev() { + if dim == 1 { + continue; + } + + match last_index.checked_sub(1) { + Some(ix) => last_index = ix, + None => { + return Err(InvalidInput::incompatible_dimensions( + &self.name, + )); + }, + } + + shape[last_index] = dim; + } + + Ok(shape) + } + + pub fn view_1d( + &self, + ) -> Result, InvalidInput> + where + T: PrimitiveTensorElement, + { + self.view_with_dimensions() + } + + pub fn view_2d( + &self, + ) -> Result, InvalidInput> + where + T: PrimitiveTensorElement, + { + self.view_with_dimensions() + } + + pub fn view_3d( + &self, + ) -> Result, InvalidInput> + where + T: PrimitiveTensorElement, + { + self.view_with_dimensions() + } + + pub fn view_4d( + &self, + ) -> Result, InvalidInput> + where + T: PrimitiveTensorElement, + { + self.view_with_dimensions() + } + + pub fn view_1d_mut( + &mut self, + ) -> Result, InvalidInput> + where + T: PrimitiveTensorElement, + { + self.view_with_dimensions_mut() + } + + pub fn string_view(&self) -> Result, InvalidInput> { + let dimensions: Vec<_> = self.dimensions().collect(); + + crate::strings::decode_strings(&self.buffer) + .and_then(|strings| ArrayD::from_shape_vec(dimensions, strings)) + .map_err(|e| InvalidInput::other(&self.name, e)) + } +} + +impl PartialEq for Tensor { + fn eq(&self, other: &Tensor) -> bool { + let Tensor { + name, + element_type, + dimensions, + buffer, + } = self; + + name == &other.name + && element_type == &other.element_type + && dimensions == &other.dimensions + && buffer == &other.buffer + } +} + +impl From> for Dimensions { + fn from(fixed: Vec) -> Self { Dimensions::Fixed(fixed) } +} + +impl From<[u32; N]> for Dimensions { + fn from(fixed: [u32; N]) -> Self { Dimensions::Fixed(fixed.to_vec()) } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn viewing_with_dimensionality_can_strip_or_add_leading_1s() { + let elements = ndarray::arr2(&[[0.0_f64, 0.5, 10.0, 3.5, -200.0]]); + let tensor = Tensor::new("x", &elements); + + // We should be able to view as both 1D, 2D, and 3D + + let view = tensor.view_1d::().unwrap(); + assert_eq!(view.dim(), 5); + + let view = tensor.view_2d::().unwrap(); + assert_eq!(view.dim(), (1, 5)); + + let view = tensor.view_3d::().unwrap(); + assert_eq!(view.dim(), (1, 1, 5)); + } + + #[test] + fn cant_view_2d_as_1d() { + let elements = ndarray::arr2(&[[0.0_f64, 0.5], [10.0, -200.0]]); + let tensor = Tensor::new("x", &elements); + + let err = tensor.view_1d::().unwrap_err(); + + assert_eq!(err.reason, InvalidInputReason::IncompatibleDimensions); + } +} diff --git a/support/src/lib.rs b/support/src/lib.rs index 7d151ff18ab..4d2e1791e71 100644 --- a/support/src/lib.rs +++ b/support/src/lib.rs @@ -2,42 +2,10 @@ pub extern crate ndarray; -#[cfg(feature = "runtime_v1")] -mod bindings; +mod macros; +mod strings; -mod buffer_ext; -pub mod common; -mod string_builder; -mod value_type; +#[cfg(feature = "guest")] +pub mod guest; -use std::sync::Mutex; - -pub use crate::{ - buffer_ext::BufferExt, - string_builder::{string_tensor_from_ndarray, StringBuilder}, - value_type::{SliceExt, ValueType}, -}; - -#[cfg(feature = "runtime_v1")] -pub use bindings::runtime_v1; -use once_cell::sync::Lazy; -use rand::{prelude::SmallRng, Rng, SeedableRng}; - -pub mod prelude { - #[cfg(feature = "runtime_v1")] - pub use crate::bindings::{ - ContextErrorExt, ContextExt, InvalidArgumentExt, - }; -} - -// Note: getrandom is pulled in by the linfa_logistic crate -getrandom::register_custom_getrandom!(unsupported_rng); - -fn unsupported_rng(buffer: &mut [u8]) -> Result<(), getrandom::Error> { - // FIXME: We should probably seed this with something more useful. - static RNG: Lazy> = - Lazy::new(|| Mutex::new(SmallRng::from_seed(Default::default()))); - - RNG.lock().unwrap().fill(buffer); - Ok(()) -} +pub use crate::strings::{decode_strings, StringBuilder}; diff --git a/support/src/macros.rs b/support/src/macros.rs new file mode 100644 index 00000000000..fe3355d347b --- /dev/null +++ b/support/src/macros.rs @@ -0,0 +1,517 @@ +/// Define helper functions and trait implementations for working with the +/// `proc_block_v2` glue generated by `wit_bindgen_rust`. +#[macro_export] +macro_rules! generate_support { + ($($proc_block:ident)::*) => { + mod support { + use std::{fmt::{self, Display, Formatter}, str::FromStr}; + use $($proc_block)::*::*; + use $crate::ndarray::{Dimension, IntoDimension}; + + pub fn parse_arg(args: &[Argument], name: &str) -> Result + where + T: FromStr, + T::Err: Display, + { + for arg in args { + if arg.name == name { + return arg.value.parse::().map_err(|e| ArgumentError { + name: name.to_string(), + reason: ArgumentErrorReason::InvalidValue(e.to_string()), + }); + } + } + + Err(ArgumentError { + name: name.to_string(), + reason: ArgumentErrorReason::NotFound, + }) + } + + pub fn get_input_tensor<'t>(tensors: &'t [Tensor], name: &str) -> Result<&'t Tensor, KernelError> { + tensors.iter() + .find(|t| t.name == name) + .ok_or_else(|| KernelError::InvalidInput(InvalidInput { + name: name.to_string(), + reason: InvalidInputReason::NotFound, + })) + } + + impl Tensor { + pub fn new(name: impl Into, array: &$crate::ndarray::ArrayBase) -> Self + where + T: ValueType, + S: $crate::ndarray::Data, + Dims: $crate::ndarray::Dimension, + { + let dimensions = array.shape() + .iter() + .map(|&d| d as u32) + .collect(); + + // Safety: + let mut buffer = Vec::new(); + + for element in array.iter() { + buffer.extend($crate::bytemuck::bytes_of(element)); + } + + Tensor { + name: name.into(), + dimensions, + element_type: T::ELEMENT_TYPE, + buffer, + } + } + + pub fn new_1d(name: impl Into, elements: &[T]) -> Self + where + T: ValueType, + { + let array = $crate::ndarray::aview1(elements); + Tensor::new(name, &array) + } + + pub fn with_name(self, name: impl Into) -> Self { + Tensor { + name: name.into(), + ..self + } + } + + pub fn view(&self) -> Result<$crate::ndarray::ArrayViewD<'_, T>, KernelError> + where + T: ValueType, + { + if self.element_type != T::ELEMENT_TYPE { + return Err(InvalidInput::unsupported_shape(&self.name).into()); + } + + let dimensions: Vec<_> = self.dimensions + .iter() + .map(|&d| d as usize) + .collect(); + + // Note: If our buffer is empty, the slice you get when from + // the Deref implementation will be null + align_of(u8) with + // a length of 0. + // + // This is normally fine, but if we later use bytemuck to + // cast the &[u8] to &[T] and T has an alignment greater + // than 1, we'll panic due to being mis-aligned. + // + // To prevent this, we return a view into an empty slice. + if dimensions.iter().product::() == 0 { + return $crate::ndarray::ArrayViewD::from_shape(dimensions, &[]) + .map_err(|e| InvalidInput::other(&self.name, e).into()); + } + + let elements = $crate::bytemuck::try_cast_slice(&self.buffer) + .expect("Unable to reinterpret the buffer's bytes as the desired element type"); + + $crate::ndarray::ArrayViewD::from_shape(dimensions, elements) + .map_err(|e| InvalidInput::other(&self.name, e).into()) + } + + pub fn view_with_dimensions(&self) -> Result<$crate::ndarray::ArrayView<'_, T, Dims>, KernelError> + where + T: ValueType, + Dims: $crate::ndarray::Dimension, + { + self.view::()? + .into_dimensionality::() + .map_err(|e| InvalidInput::other(&self.name, e).into()) + } + + pub fn view_1d(&self) -> Result<$crate::ndarray::ArrayView1<'_, T>, KernelError> + where + T: ValueType, + { + self.view_with_dimensions() + } + + pub fn view_2d(&self) -> Result<$crate::ndarray::ArrayView2<'_, T>, KernelError> + where + T: ValueType, + { + self.view_with_dimensions() + } + + pub fn view_mut(&mut self) -> Result<$crate::ndarray::ArrayViewMutD<'_, T>, KernelError> + where + T: ValueType, + { + if self.element_type != T::ELEMENT_TYPE { + return Err(KernelError::InvalidInput(InvalidInput { + name: self.name.clone(), + reason: InvalidInputReason::UnsupportedShape, + })); + } + + let dimensions: Vec<_> = self.dimensions + .iter() + .map(|&d| d as usize) + .collect(); + + // See the comment in Tensor::view() for why we need this. + if dimensions.iter().product::() == 0 { + return $crate::ndarray::ArrayViewMutD::from_shape(dimensions, &mut []) + .map_err(|e| InvalidInput::other(&self.name, e).into()); + } + + let elements = $crate::bytemuck::try_cast_slice_mut(&mut self.buffer) + .expect("Unable to reinterpret the buffer's bytes as the desired element type"); + + $crate::ndarray::ArrayViewMutD::from_shape(dimensions, elements) + .map_err(|e| InvalidInput::other(&self.name, e).into()) + } + + pub fn view_with_dimensions_mut(&mut self) -> Result<$crate::ndarray::ArrayViewMut<'_, T, Dims>, KernelError> + where + T: ValueType, + Dims: $crate::ndarray::Dimension, + { + // FIXME: It'd be nice if we didn't need to make this copy, + // but the borrow checker isn't able to figure out that + // the into_dimensionality() call consumes our view and + // therefore the mutable borrow is finished. + let name = self.name.clone(); + + self.view_mut::()? + .into_dimensionality::() + .map_err(|e| InvalidInput::other(name, e).into()) + } + } + + impl PartialEq for Tensor { + fn eq(&self, other: &Tensor) -> bool { + let Tensor { name, element_type, dimensions, buffer } = self; + + name == &other.name + && element_type == &other.element_type + && dimensions == &other.dimensions + && buffer == &other.buffer + } + } + + impl TensorConstraint { + pub fn new( + name: impl Into, + element_type: ElementTypeConstraint, + dimensions: impl Into, + ) -> Self { + TensorConstraint { + name: name.into(), + element_type, + dimensions: dimensions.into(), + } + } + + pub fn numeric( + name: impl Into, + dimensions: impl Into, + ) -> Self { + TensorConstraint { + name: name.into(), + element_type: !ElementTypeConstraint::UTF8, + dimensions: dimensions.into(), + } + } + } + + impl From> for Dimensions { + fn from(fixed: Vec) -> Self { + Dimensions::Fixed(fixed) + } + } + + impl KernelError { + pub fn other(reason: impl Display) -> Self { + KernelError::Other(reason.to_string()) + } + + pub fn unsupported_shape(tensor_name: impl Into) -> Self { + KernelError::InvalidInput(InvalidInput { + name: tensor_name.into(), + reason: InvalidInputReason::UnsupportedShape, + }) + } + } + + impl PartialEq for KernelError { + fn eq(&self, other: &KernelError) -> bool { + match (self, other) { + (KernelError::Other(left), KernelError::Other(right)) => left == right, + (KernelError::InvalidInput(left), KernelError::InvalidInput(right)) => left == right, + (KernelError::Other(_), _) + | (KernelError::InvalidInput(_), _) => false, + } + } + } + + impl InvalidInput { + pub fn unsupported_shape(tensor_name: impl Into) -> Self { + InvalidInput { + name: tensor_name.into(), + reason: InvalidInputReason::UnsupportedShape, + } + } + + pub fn not_found(tensor_name: impl Into) -> Self { + InvalidInput { + name: tensor_name.into(), + reason: InvalidInputReason::NotFound, + } + } + + pub fn invalid_value(tensor_name: impl Into, error: impl Display) -> Self { + InvalidInput { + name: tensor_name.into(), + reason: InvalidInputReason::InvalidValue(error.to_string()), + } + } + + pub fn other(tensor_name: impl Into, reason: impl Display) -> Self { + InvalidInput { + name: tensor_name.into(), + reason: InvalidInputReason::Other(reason.to_string()), + } + } + } + + impl PartialEq for InvalidInput { + fn eq(&self, other: &InvalidInput) -> bool { + let InvalidInput { name, reason } = self; + + name == &other.name && reason == &other.reason + } + } + + impl PartialEq for InvalidInputReason { + fn eq(&self, other: &InvalidInputReason) -> bool { + match (self, other) { + (InvalidInputReason::Other(left), InvalidInputReason::Other(right)) => left == right, + (InvalidInputReason::InvalidValue(left), InvalidInputReason::InvalidValue(right)) => left == right, + (InvalidInputReason::NotFound, InvalidInputReason::NotFound) => true, + (InvalidInputReason::UnsupportedShape, InvalidInputReason::UnsupportedShape) => true, + (InvalidInputReason::Other(_), _) + | (InvalidInputReason::InvalidValue(_), _) + | (InvalidInputReason::NotFound, _) + | (InvalidInputReason::UnsupportedShape, _) => false, + } + } + } + + impl Metadata { + pub fn new(name: impl Into, version: impl Into) -> Self { + Metadata { + name: name.into(), + version: version.into(), + tags: Vec::new(), + description: None, + homepage: None, + repository: None, + arguments: Vec::new(), + inputs: Vec::new(), + outputs: Vec::new(), + } + } + + pub fn with_tag(mut self, tag: impl Into) -> Self { + self.tags.push(tag.into()); + self + } + + pub fn with_description(mut self, description: impl Into) -> Self { + let description = description.into(); + if !description.is_empty() { + self.description = Some(description); + } + + self + } + + pub fn with_homepage(mut self, homepage: impl Into) -> Self { + let homepage = homepage.into(); + if !homepage.is_empty() { + self.homepage = Some(homepage); + } + + self + } + + pub fn with_repository(mut self, repository: impl Into) -> Self { + let repository = repository.into(); + if !repository.is_empty() { + self.repository = Some(repository); + } + + self + } + + pub fn with_argument(mut self, arg: ArgumentMetadata) -> Self { + self.arguments.push(arg); + self + } + + pub fn with_input(mut self, input: TensorMetadata) -> Self { + self.inputs.push(input); + self + } + + pub fn with_output(mut self, output: TensorMetadata) -> Self { + self.outputs.push(output); + self + } + } + + impl ArgumentMetadata { + pub fn new(name: impl Into) -> Self { + ArgumentMetadata { + name: name.into(), + description: None, + default_value: None, + hints: Vec::new(), + } + } + + pub fn with_description(mut self, description: impl Into) -> Self { + let description = description.into(); + if !description.is_empty() { + self.description = Some(description); + } + self + } + + pub fn with_default_value(mut self, default_value: impl ToString) -> Self { + let default_value = default_value.to_string(); + if !default_value.is_empty() { + self.default_value = Some(default_value); + } + self + } + + pub fn with_hint(mut self, hint: ArgumentHint) -> Self { + self.hints.push(hint); + self + } + } + + impl TensorMetadata { + pub fn new(name: impl Into) -> Self { + TensorMetadata { + name: name.into(), + description: None, + hints: Vec::new(), + } + } + + pub fn with_description(mut self, description: impl Into) -> Self { + let description = description.into(); + if !description.is_empty() { + self.description = Some(description); + } + + self + } + + pub fn with_hint(mut self, hint: TensorHint) -> Self { + self.hints.push(hint); + self + } + } + + impl From for KernelError { + fn from(i: InvalidInput) -> Self { + KernelError::InvalidInput(i) + } + } + + pub trait ValueType: $crate::ValueType + $crate::bytemuck::AnyBitPattern + $crate::bytemuck::NoUninit { + const ELEMENT_TYPE: ElementType; + } + + impl ValueType for u8 { const ELEMENT_TYPE: ElementType = ElementType::U8; } + impl ValueType for i8 { const ELEMENT_TYPE: ElementType = ElementType::I8; } + impl ValueType for u16 { const ELEMENT_TYPE: ElementType = ElementType::U16; } + impl ValueType for i16 { const ELEMENT_TYPE: ElementType = ElementType::I16; } + impl ValueType for u32 { const ELEMENT_TYPE: ElementType = ElementType::U32; } + impl ValueType for i32 { const ELEMENT_TYPE: ElementType = ElementType::I32; } + impl ValueType for f32 { const ELEMENT_TYPE: ElementType = ElementType::F32; } + impl ValueType for u64 { const ELEMENT_TYPE: ElementType = ElementType::U64; } + impl ValueType for i64 { const ELEMENT_TYPE: ElementType = ElementType::I64; } + impl ValueType for f64 { const ELEMENT_TYPE: ElementType = ElementType::F64; } + + impl Display for KernelError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + KernelError::InvalidInput(i) => i.fmt(f), + KernelError::Other(msg) => write!(f, "{}", msg), + } + } + } + + impl std::error::Error for KernelError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + KernelError::InvalidInput(i) => Some(i), + KernelError::Other(_) => None, + } + } + } + + impl std::error::Error for InvalidInput { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.reason) + } + } + + impl Display for InvalidInput { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "The \"{}\" input tensor was invalid", self.name) + } + } + + impl Display for InvalidInputReason { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + InvalidInputReason::Other(msg) => write!(f, "{msg}"), + InvalidInputReason::NotFound => write!(f, "Not found"), + InvalidInputReason::InvalidValue(msg) => write!(f, "Invalid value: {msg}"), + InvalidInputReason::UnsupportedShape => write!(f, "Unsupported shape"), + } + } + } + + impl std::error::Error for InvalidInputReason {} + + impl Display for ArgumentError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "The \"{}\" argument is invalid", self.name) + } + } + + impl std::error::Error for ArgumentError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.reason) + } + } + + impl Display for ArgumentErrorReason { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + ArgumentErrorReason::Other(msg) => write!(f, "{msg}"), + ArgumentErrorReason::NotFound => { + write!(f, "The argument wasn't defined") + }, + ArgumentErrorReason::InvalidValue(msg) => { + write!(f, "Invalid value: {msg}") + }, + } + } + } + + impl std::error::Error for ArgumentErrorReason {} + } + }; +} diff --git a/support/src/string_builder.rs b/support/src/strings.rs similarity index 51% rename from support/src/string_builder.rs rename to support/src/strings.rs index ca05c7b2c38..a672f3eeb7d 100644 --- a/support/src/string_builder.rs +++ b/support/src/strings.rs @@ -1,36 +1,6 @@ -/// Serialize a string tensor so it can be passed to the runtime. -/// -/// # Examples -/// -/// ``` -/// use hotg_rune_proc_blocks::{BufferExt, string_tensor_from_ndarray}; -/// -/// let tensor = ndarray::arr2(&[ -/// ["this", "is", "a", "sentence"], -/// ["and", "this", "is", "another"], -/// ]); -/// -/// let serialized: Vec = string_tensor_from_ndarray(&tensor); -/// -/// let deserialized = serialized.string_view(&[2, 4]).unwrap(); -/// assert_eq!(deserialized, tensor.into_dyn()); -/// ``` -pub fn string_tensor_from_ndarray( - array: &ndarray::ArrayBase, -) -> Vec -where - Dim: ndarray::Dimension, - Data: ndarray::Data, - S: AsRef, -{ - let mut builder = StringBuilder::new(); - - for s in array.iter() { - builder.push(s.as_ref()); - } +use std::fmt::{self, Debug, Formatter}; - builder.finish() -} +use ndarray::{ErrorKind, ShapeError}; /// A builder for serializing multiple UTF-8 strings to a flat byte array. /// @@ -38,7 +8,6 @@ where /// /// ```rust /// # use hotg_rune_proc_blocks::StringBuilder; -/// use hotg_rune_proc_blocks::BufferExt; /// // Construct a new string builder and add some strings to it /// let mut builder = StringBuilder::new(); /// builder.push("this").push("is").push("a").push("sentence"); @@ -46,13 +15,12 @@ where /// // once all the strings have been added, we can get the serialized tensor. /// let buffer: Vec = builder.finish(); /// -/// // The BufferExt trait lets us deserialize the strings again. -/// let strings: Vec<&str> = buffer.strings()?; +/// let strings: Vec<&str> = hotg_rune_proc_blocks::decode_strings(&buffer)?; /// /// assert_eq!(strings, &["this", "is", "a", "sentence"]); /// # Ok::<(), Box>(()) /// ``` -#[derive(PartialEq)] +#[derive(PartialEq, Eq)] pub struct StringBuilder { buffer: Vec, } @@ -88,14 +56,66 @@ impl StringBuilder { } } +impl Debug for StringBuilder { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("StringBuilder").finish_non_exhaustive() + } +} + impl Default for StringBuilder { fn default() -> Self { StringBuilder::new() } } +/// Decode list of strings from their serialized form. +/// +/// See [`StringBuilder`] for how to serialize a list of strings. +pub fn decode_strings(raw: &[u8]) -> Result, ShapeError> { + const HEADER_SIZE: usize = std::mem::size_of::(); + + let mut strings = Vec::new(); + let mut buffer = raw; + + while !buffer.is_empty() { + if buffer.len() < HEADER_SIZE { + // We don't have enough bytes remaining for a full length field, + // so something is probably wrong with our buffer. + return Err(ShapeError::from_kind(ErrorKind::OutOfBounds)); + } + + let (len, rest) = buffer.split_at(HEADER_SIZE); + + let len: [u8; HEADER_SIZE] = len.try_into().expect("Unreachable"); + let len = u32::from_le_bytes(len); + let len = usize::try_from(len).expect("Unreachable"); + + if rest.len() < len { + // We don't have enough bytes left in the buffer to read a + // string with this length. + return Err(ShapeError::from_kind(ErrorKind::OutOfBounds)); + } + + let (s, rest) = rest.split_at(len); + + match std::str::from_utf8(s) { + Ok(s) => strings.push(s), + Err(_) => { + // The string wasn't valid UTF-8. We're probably using the + // wrong ShapeError here, but our alternative would be + // introducing our own error type and that seems overkill. + return Err(ShapeError::from_kind( + ErrorKind::IncompatibleLayout, + )); + }, + } + + buffer = rest; + } + + Ok(strings) +} + #[cfg(test)] mod tests { - use crate::BufferExt; - use super::*; #[test] @@ -104,7 +124,7 @@ mod tests { builder.push("this").push("is").push("a").push("sentence"); let buffer = builder.finish(); - let strings = buffer.strings().unwrap(); + let strings = decode_strings(&buffer).unwrap(); assert_eq!(strings, &["this", "is", "a", "sentence"]); } diff --git a/support/src/value_type.rs b/support/src/value_type.rs deleted file mode 100644 index 2fe0ca7f7ff..00000000000 --- a/support/src/value_type.rs +++ /dev/null @@ -1,60 +0,0 @@ -/// A value type which can be reinterpreted to/from a byte buffer. -/// -/// # Safety -/// -/// It must be safe to reinterpret a `&[u8]` as a `&[Self]`. Among other things, -/// this means: -/// - The type must not have any padding (observing padding bytes is UB) -/// - It must not contain any fields with indirection (e.g. we don't want to -/// interpret random bytes as a pointer... that's just asking for trouble) -pub unsafe trait ValueType: Sized {} - -unsafe impl ValueType for u8 {} -unsafe impl ValueType for i8 {} -unsafe impl ValueType for u16 {} -unsafe impl ValueType for i16 {} -unsafe impl ValueType for u32 {} -unsafe impl ValueType for i32 {} -unsafe impl ValueType for f32 {} -unsafe impl ValueType for u64 {} -unsafe impl ValueType for i64 {} -unsafe impl ValueType for f64 {} - -/// Extension traits for slices of [`ValueType`]s. -pub trait SliceExt { - /// Interpet this as a slice of bytes. - /// - /// # Examples - /// - /// ```rust - /// use hotg_rune_proc_blocks::SliceExt; - /// - /// let numbers: [u16; 2] = [0, 1]; - /// - /// let bytes = numbers.as_bytes(); - /// - /// assert_eq!(bytes, &[0x00, 0x00, 0x01, 0x00]); - /// ``` - fn as_bytes(&self) -> &[u8]; - - /// Interpet this as a mutable slice of bytes. - /// - /// This is the mutable version of [`SliceExt::as_bytes()`]. - fn as_bytes_mut(&mut self) -> &mut [u8]; -} - -impl SliceExt for [T] { - fn as_bytes(&self) -> &[u8] { - let length = std::mem::size_of_val(self); - - unsafe { std::slice::from_raw_parts(self.as_ptr().cast(), length) } - } - - fn as_bytes_mut(&mut self) -> &mut [u8] { - let length = std::mem::size_of_val(self); - - unsafe { - std::slice::from_raw_parts_mut(self.as_mut_ptr().cast(), length) - } - } -} diff --git a/support_vector_classifier/Cargo.toml b/support_vector_classifier/Cargo.toml index 82cfdac2528..e46b22e69e5 100644 --- a/support_vector_classifier/Cargo.toml +++ b/support_vector_classifier/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "support_vector_classifier" version = "0.12.3" -edition = "2018" +edition = "2021" description = "a binary classifier that uses an optimal hyperplane to separate the points in the input variable space by their class." # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -10,7 +10,6 @@ description = "a binary classifier that uses an optimal hyperplane to separate t hotg-rune-proc-blocks = { path = "../support" } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } smartcore = { git = "https://github.com/hotg-ai/smartcore", branch = "development" } -getrandom = { version = "0.2.6", default-features = false, features = ["custom"] } [lib] crate-type = ["cdylib", "rlib"] diff --git a/support_vector_classifier/src/lib.rs b/support_vector_classifier/src/lib.rs index 74dc20fa87d..03a978247ee 100644 --- a/support_vector_classifier/src/lib.rs +++ b/support_vector_classifier/src/lib.rs @@ -1,4 +1,11 @@ -use hotg_rune_proc_blocks::{ndarray, runtime_v1}; +use hotg_rune_proc_blocks::{ + guest::{ + parse, Argument, ArgumentMetadata, ArgumentType, CreateError, + ElementTypeConstraint, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, + }, + ndarray::{Array1, ArrayView1, ArrayView2}, +}; use smartcore::{ linalg::naive::dense_matrix::*, svm::{ @@ -6,288 +13,118 @@ use smartcore::{ Kernels, }, }; -use std::{convert::TryInto, fmt::Display, str::FromStr}; - -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; -use hotg_rune_proc_blocks::{runtime_v1::*, BufferExt, SliceExt}; -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: SupportVectorClassifier, +} -/// a binary classifier that uses an optimal hyperplane to separate the points -/// in the input variable space by their class. -struct ProcBlockV1; +fn metadata() -> Metadata { + // TODO: how to add an array of string: [linear, rbf, polynomial, + // polynomial_with_degree, sigmoid, sigmoiod_with_gamma]. + // Have to figure out how to how to change the parameter of polynomial, + // sigmoid, etc -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new( - " Support Vector Classifier", - env!("CARGO_PKG_VERSION"), - ); - metadata.set_description( + Metadata::new(" Support Vector Classifier", env!("CARGO_PKG_VERSION")) + .with_description( "a binary approach for modelling the relationship between a scalar response and one or more explanatory variables", - ); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("binary classifier"); - metadata.add_tag("analytics"); - - let element_type = ArgumentMetadata::new("element_type"); - element_type - .set_description("The type of tensor this proc-block will accept"); - element_type.set_default_value("f64"); - element_type.add_hint(&interpret_as_string_in_enum(&[ - "u8", "i8", "u16", "i16", "u32", "i32", "f32", "u64", "i64", "f64", - ])); - metadata.add_argument(&element_type); - - let epochs = ArgumentMetadata::new("epochs"); - epochs.set_description("Number of epochs"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Integer); - epochs.add_hint(&hint); - epochs.set_default_value("5"); - metadata.add_argument(&epochs); - - let c = ArgumentMetadata::new("c"); - c.set_description("Penalizing parameter"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Float); - c.add_hint(&hint); - c.set_default_value("200.0"); - metadata.add_argument(&c); - - let tol = ArgumentMetadata::new("tolerance"); - tol.set_description("Tolerance for stopping criterion"); - let hint = runtime_v1::supported_argument_type(ArgumentType::Float); - tol.add_hint(&hint); - tol.set_default_value("0.001"); - metadata.add_argument(&tol); - - // todo: how to add an array of string: [linear, rbf, polynomial, - // polynomial_with_degree, sigmoid, sigmoiod_with_gamma]. - // Have to figure out how to how to change the parameter of polynomial, - // sigmoid, etc - - // let kernel = ArgumentMetadata::new("kernel"); - // epochs.set_description( - // "Tolerance for stopping criterion", - // ); - // let hint = runtime_v1::supported_argument_type(ArgumentType::String); - // kernel.add_hint(&hint); - // kernel.set_default_value("linear"); - // metadata.add_argument(&kernel); - - let x_train = TensorMetadata::new("x_train"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0, 0])); - x_train.add_hint(&hint); - metadata.add_input(&x_train); - - let y_train = TensorMetadata::new("y_train"); - let hint = - supported_shapes(&[ElementType::F64], DimensionsParam::Fixed(&[0])); - y_train.add_hint(&hint); - metadata.add_input(&y_train); - - let x_test = TensorMetadata::new("x_test"); - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0, 0])); - x_test.add_hint(&hint); - metadata.add_input(&x_test); - - let y_test = TensorMetadata::new("y_test"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0])); - y_test.add_hint(&hint); - metadata.add_output(&y_test); - register_node(&metadata); - } - - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - let element_type = match ctx.get_argument("element_type").as_deref() { - Some("f64") => ElementType::F64, - Some(_) => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })); - }, - None => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::NotFound, - })) - }, - }; - - ctx.add_input_tensor( - "x_train", - element_type, - DimensionsParam::Fixed(&[0, 0]), - ); - - ctx.add_input_tensor( - "y_train", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_input_tensor( - "x_test", - element_type, - DimensionsParam::Fixed(&[0, 0]), - ); + ) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("binary classifier") + .with_tag("analytics") + .with_argument( + ArgumentMetadata::new("epochs") + .with_description("Number of epochs") + .with_hint(ArgumentType::Integer) + .with_default_value("5"), + ) + .with_argument( + ArgumentMetadata::new("c") + .with_description("Penalizing parameter") + .with_hint(ArgumentType::Float) + .with_default_value("200.0"), + ) + .with_argument( + ArgumentMetadata::new("tol") + .with_description("Tolerance for stopping criterion") + .with_hint(ArgumentType::Float) + .with_default_value("0.001"), + ) + .with_input(TensorMetadata::new("x_train")) + .with_input(TensorMetadata::new("y_train")) + .with_input(TensorMetadata::new("x_test")) + .with_output(TensorMetadata::new("y_test")) +} - ctx.add_output_tensor( - "y_test", - element_type, - DimensionsParam::Fixed(&[0]), - ); +/// a binary classifier that uses an optimal hyperplane to separate the points +/// in the input variable space by their class. +struct SupportVectorClassifier { + epochs: u32, + c: f64, + tol: f64, +} - Ok(()) +impl ProcBlock for SupportVectorClassifier { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![ + TensorConstraint::new( + "x_train", + ElementTypeConstraint::F64, + vec![0, 0], + ), + TensorConstraint::new( + "y_train", + ElementTypeConstraint::F64, + vec![0], + ), + TensorConstraint::new( + "x_test", + ElementTypeConstraint::F64, + vec![0, 0], + ), + ], + outputs: vec![TensorConstraint::new( + "y_test", + ElementTypeConstraint::F64, + vec![0], + )], + } } - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let epoch: u32 = get_args("epochs", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let c: f64 = get_args("c", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let tol: f64 = get_args("tolerance", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - // let _kernel: String = get_args("kernel", |n| ctx.get_argument(n)) - // .map_err(KernelError::InvalidArgument)?; - - let x_train = ctx.get_input_tensor("x_train").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "x_train".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - let _xtrain: ndarray::ArrayView2 = x_train - .buffer - .view(&x_train.dimensions) - .and_then(|t| t.into_dimensionality()) - .map_err(|e| { - KernelError::InvalidInput(InvalidInput { - name: "x_train".to_string(), - reason: BadInputReason::Other(e.to_string()), - }) - })?; - - let y_train = ctx.get_input_tensor("y_train").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "y_train".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - let _ytrain: ndarray::ArrayView1 = y_train - .buffer - .view(&y_train.dimensions) - .and_then(|t| t.into_dimensionality()) - .map_err(|e| { - KernelError::InvalidInput(InvalidInput { - name: "y_train".to_string(), - reason: BadInputReason::Other(e.to_string()), - }) - })?; + fn run(&self, inputs: Vec) -> Result, RunError> { + let x_train = Tensor::get_named(&inputs, "x_train")?.view_2d()?; + let y_train = Tensor::get_named(&inputs, "y_train")?.view_1d()?; + let x_test = Tensor::get_named(&inputs, "x_test")?.view_2d()?; - let x_test = ctx.get_input_tensor("x_test").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "x_test".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - let _xtest: ndarray::ArrayView2 = x_test - .buffer - .view(&x_test.dimensions) - .and_then(|t| t.into_dimensionality()) - .map_err(|e| { - KernelError::InvalidInput(InvalidInput { - name: "x_test".to_string(), - reason: BadInputReason::Other(e.to_string()), - }) - })?; + let output = + transform(x_train, y_train, x_test, self.c, self.epochs, self.tol)?; - let output = transform( - &x_train.buffer.elements(), - &x_train.dimensions, - &y_train.buffer.elements(), - &x_test.buffer.elements(), - &x_test.dimensions, - c, - epoch, - tol, - )?; - - let y_test_dimension = [x_test.dimensions[0]]; - - ctx.set_output_tensor( - "y_test", - TensorParam { - element_type: ElementType::F64, - dimensions: &y_test_dimension, - buffer: &output.to_vec().as_bytes(), - }, - ); - - Ok(()) + Ok(vec![Tensor::new("y_train", &output)]) } } -fn get_args( - name: &str, - get_argument: impl FnOnce(&str) -> Option, -) -> Result -where - T: FromStr, - ::Err: Display, -{ - get_argument(name) - .ok_or_else(|| InvalidArgument::not_found(name))? - .parse::() - .map_err(|e| InvalidArgument::invalid_value(name, e)) -} +impl TryFrom> for SupportVectorClassifier { + type Error = CreateError; -impl InvalidArgument { - fn not_found(name: impl Into) -> Self { - InvalidArgument { - name: name.into(), - reason: BadArgumentReason::NotFound, - } - } + fn try_from(value: Vec) -> Result { + let epochs = parse::optional_arg(&value, "epochs")?.unwrap_or(5); + let c = parse::optional_arg(&value, "c")?.unwrap_or(200.0); + let tol = parse::optional_arg(&value, "tol")?.unwrap_or(0.001); - fn invalid_value(name: impl Into, reason: impl Display) -> Self { - InvalidArgument { - name: name.into(), - reason: BadArgumentReason::InvalidValue(reason.to_string()), - } + Ok(SupportVectorClassifier { epochs, c, tol }) } } fn transform( - x_train: &[f64], - x_train_dim: &[u32], - y_train: &[f64], - x_test: &[f64], - x_test_dim: &[u32], + x_train: ArrayView2<'_, f64>, + y_train: ArrayView1<'_, f64>, + x_test: ArrayView2<'_, f64>, c: f64, epoch: u32, tol: f64, -) -> Result, KernelError> { +) -> Result, RunError> { // todo: let user change the kernel. Right now setting it to 'linear' let svc_parameters = SVCParameters::default() .with_c(c) @@ -295,55 +132,71 @@ fn transform( .with_kernel(Kernels::linear()) .with_tol(tol); - let x_train = DenseMatrix::from_array( - x_train_dim[0] as usize, - x_train_dim[1] as usize, - x_train, - ); + let (rows, columns) = x_train.dim(); + let x_train = + DenseMatrix::new(rows, columns, x_train.t().iter().copied().collect()); let model = SVC::fit(&x_train, &y_train.to_vec(), svc_parameters) - .map_err(|e| KernelError::Other(e.to_string()))?; + .map_err(RunError::other)?; - let x_test = DenseMatrix::from_array( - x_test_dim[0] as usize, - x_test_dim[1] as usize, - x_test, - ); + let (rows, columns) = x_test.dim(); + let x_test = + DenseMatrix::new(rows, columns, x_test.t().iter().copied().collect()); model .predict(&x_test) - .map_err(|e| KernelError::Other(e.to_string())) + .map(Array1::from_vec) + .map_err(RunError::other) } #[cfg(test)] mod tests { + use hotg_rune_proc_blocks::ndarray; + use super::*; #[test] fn check_model() { - let x_train = vec![ - 5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 4.7, 3.2, 1.3, 0.2, 4.6, - 3.1, 1.5, 0.2, 5.0, 3.6, 1.4, 0.2, 5.4, 3.9, 1.7, 0.4, 4.6, 3.4, - 1.4, 0.3, 5.0, 3.4, 1.5, 0.2, 4.4, 2.9, 1.4, 0.2, 4.9, 3.1, 1.5, - 0.1, 7.0, 3.2, 4.7, 1.4, 6.4, 3.2, 4.5, 1.5, 6.9, 3.1, 4.9, 1.5, - 5.5, 2.3, 4.0, 1.3, 6.5, 2.8, 4.6, 1.5, 5.7, 2.8, 4.5, 1.3, 6.3, - 3.3, 4.7, 1.6, 4.9, 2.4, 3.3, 1.0, 6.6, 2.9, 4.6, 1.3, 5.2, 2.7, - 3.9, 1.4, + let x_train = ndarray::array![ + [5.1, 3.5, 1.4, 0.2], + [4.9, 3.0, 1.4, 0.2], + [4.7, 3.2, 1.3, 0.2], + [4.6, 3.1, 1.5, 0.2], + [5.0, 3.6, 1.4, 0.2], + [5.4, 3.9, 1.7, 0.4], + [4.6, 3.4, 1.4, 0.3], + [5.0, 3.4, 1.5, 0.2], + [4.4, 2.9, 1.4, 0.2], + [4.9, 3.1, 1.5, 0.1], + [7.0, 3.2, 4.7, 1.4], + [6.4, 3.2, 4.5, 1.5], + [6.9, 3.1, 4.9, 1.5], + [5.5, 2.3, 4.0, 1.3], + [6.5, 2.8, 4.6, 1.5], + [5.7, 2.8, 4.5, 1.3], + [6.3, 3.3, 4.7, 1.6], + [4.9, 2.4, 3.3, 1.0], + [6.6, 2.9, 4.6, 1.3], + [5.2, 2.7, 3.9, 1.4], ]; - let y_train: Vec = vec![ + let y_train = ndarray::array![ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., ]; + let svc = SupportVectorClassifier { + epochs: 5, + c: 200.0, + tol: 0.001, + }; + let inputs = vec![ + Tensor::new("x_train", &x_train), + Tensor::new("y_train", &y_train), + Tensor::new("x_test", &x_train), + ]; - let dim: Vec = vec![20, 4]; - - let epoch: u32 = 5; - let c: f64 = 200.0; - let tol: f64 = 0.001; - - let y_pred = - transform(&x_train, &dim, &y_train, &x_train, &dim, c, epoch, tol); + let got = svc.run(inputs).unwrap(); - assert_eq!(y_pred.unwrap(), y_train); + let should_be = vec![Tensor::new("y_train", &y_train)]; + assert_eq!(got, should_be); } } diff --git a/support_vector_regression/Cargo.toml b/support_vector_regression/Cargo.toml index b6a1dbf7e5f..166907a3225 100644 --- a/support_vector_regression/Cargo.toml +++ b/support_vector_regression/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "support_vector_regression" version = "0.12.3" -edition = "2018" +edition = "2021" description = "a supervised learning models with associated learning algorithms that analyze data for regression analysis" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -10,7 +10,6 @@ description = "a supervised learning models with associated learning algorithms hotg-rune-proc-blocks = { path = "../support" } wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } smartcore = { git = "https://github.com/hotg-ai/smartcore", branch = "development" } -getrandom = { version = "0.2.6", default-features = false, features = ["custom"] } [lib] crate-type = ["cdylib", "rlib"] diff --git a/tensor_input/Cargo.toml b/tensor_input/Cargo.toml deleted file mode 100644 index 1cab1a2c43c..00000000000 --- a/tensor_input/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -name = "tensor_input" -version = "0.12.0" -edition = "2021" -description = "Read an input from the environment." - -[lib] -crate-type = ["cdylib", "rlib"] - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -hotg-rune-proc-blocks = { path = "../support" } -wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } - -[package.metadata.wapm] -namespace = "hotg-ai" -abi = "none" diff --git a/tensor_input/src/lib.rs b/tensor_input/src/lib.rs deleted file mode 100644 index f5e050bfa05..00000000000 --- a/tensor_input/src/lib.rs +++ /dev/null @@ -1,122 +0,0 @@ -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; -use hotg_rune_proc_blocks::{prelude::*, runtime_v1::*}; - -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("Tensor Input", env!("CARGO_PKG_VERSION")); - metadata.set_description(env!("CARGO_PKG_DESCRIPTION")); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("input"); - metadata.add_tag("raw"); - - let element_type = ArgumentMetadata::element_type(); - metadata.add_argument(&element_type); - - let output = TensorMetadata::new("output"); - let hint = supported_shapes( - &[ - ElementType::U8, - ElementType::I8, - ElementType::U16, - ElementType::I16, - ElementType::U32, - ElementType::I32, - ElementType::F32, - ElementType::U64, - ElementType::I64, - ElementType::F64, - ElementType::Utf8, - ], - DimensionsParam::Fixed(&[0]), - ); - output.add_hint(&hint); - metadata.add_output(&output); - - register_node(&metadata); - } - - fn graph(id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&id).ok_or_else(|| { - GraphError::Other("Unable to get the graph context".to_string()) - })?; - - let element_type: ElementType = - ctx.parse_argument_with_default("element_type", ElementType::F32)?; - - ctx.add_input_tensor("input", element_type, DimensionsParam::Dynamic); - ctx.add_output_tensor("output", element_type, DimensionsParam::Dynamic); - - Ok(()) - } - - fn kernel(id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&id).ok_or_else(|| { - KernelError::Other("Unable to get the kernel context".to_string()) - })?; - - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("input").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "input".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - // Dinesh, please don't look at this too closely otherwise you might - // notice we're literally copying a tensor into WebAssembly only to - // copy it back again 😅 - - ctx.set_output_tensor( - "output", - TensorParam { - element_type, - dimensions: &dimensions, - buffer: &buffer, - }, - ); - - Ok(()) - } -} - -impl ContextErrorExt for GraphError { - type InvalidArgument = InvalidArgument; - - fn invalid_argument(inner: InvalidArgument) -> Self { - GraphError::InvalidArgument(inner) - } -} - -impl InvalidArgumentExt for InvalidArgument { - fn other(name: &str, msg: impl std::fmt::Display) -> Self { - InvalidArgument { - name: name.to_string(), - reason: BadArgumentReason::Other(msg.to_string()), - } - } - - fn invalid_value(name: &str, error: impl std::fmt::Display) -> Self { - InvalidArgument { - name: name.to_string(), - reason: BadArgumentReason::InvalidValue(error.to_string()), - } - } - - fn not_found(name: &str) -> Self { - InvalidArgument { - name: name.to_string(), - reason: BadArgumentReason::NotFound, - } - } -} diff --git a/text_extractor/Cargo.toml b/text_extractor/Cargo.toml index cb652c95011..318af0ca839 100644 --- a/text_extractor/Cargo.toml +++ b/text_extractor/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "text_extractor" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "It parses the text from a start logit to an end logit" @@ -17,4 +17,4 @@ wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen" } [package.metadata.wapm] namespace = "hotg-ai" -abi = "none" \ No newline at end of file +abi = "none" diff --git a/text_extractor/src/lib.rs b/text_extractor/src/lib.rs index d4f7f759d30..daa7c6605ce 100644 --- a/text_extractor/src/lib.rs +++ b/text_extractor/src/lib.rs @@ -1,230 +1,143 @@ -use crate::proc_block_v1::*; use hotg_rune_proc_blocks::{ - ndarray, runtime_v1::*, string_tensor_from_ndarray, BufferExt, + guest::{ + Argument, ElementType, InvalidInput, Metadata, ProcBlock, RunError, + Tensor, TensorConstraint, TensorConstraints, TensorMetadata, + }, + ndarray, }; -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -#[macro_use] -extern crate alloc; -use alloc::{string::String, vec::Vec}; - -struct ProcBlockV1; +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: TextExtractor, +} -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = - Metadata::new("Text Extractor", env!("CARGO_PKG_VERSION")); - metadata.set_description( +fn metadata() -> Metadata { + Metadata::new("Text Extractor", env!("CARGO_PKG_VERSION")) + .with_description( "Given a body of text and some start/end indices, extract parts of the text (i.e. words/phrases) specified by those indices.", - ); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("nlp"); - - let text = TensorMetadata::new("text"); - text.set_description("A string of text."); - let hint = - supported_shapes(&[ElementType::U8], DimensionsParam::Fixed(&[0])); - text.add_hint(&hint); - metadata.add_input(&text); - - let start_logits = TensorMetadata::new("start_logits"); - start_logits.set_description( - "The indices for the start of each word/phrase to extract.", - ); - let hint = - supported_shapes(&[ElementType::U32], DimensionsParam::Fixed(&[0])); - start_logits.add_hint(&hint); - metadata.add_input(&start_logits); - - let end_logits = TensorMetadata::new("end_logits"); - end_logits.set_description( - "The indices for the end of each word/phrase to extract.", - ); - let hint = - supported_shapes(&[ElementType::U32], DimensionsParam::Fixed(&[0])); - end_logits.add_hint(&hint); - metadata.add_input(&end_logits); - - let phrases = TensorMetadata::new("phrases"); - phrases.set_description("The phrases that were extracted."); - let hint = supported_shapes( - &[ElementType::Utf8], - DimensionsParam::Fixed(&[0]), - ); - phrases.add_hint(&hint); - metadata.add_output(&phrases); - - register_node(&metadata); - } - - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - ctx.add_input_tensor( - "text", - ElementType::U8, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_input_tensor( - "start_logits", - ElementType::U32, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_input_tensor( - "end_logits", - ElementType::U32, - DimensionsParam::Fixed(&[0]), - ); - ctx.add_output_tensor( - "phrases", - ElementType::Utf8, - DimensionsParam::Fixed(&[0]), - ); + ) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("nlp") + .with_input( + TensorMetadata::new("text") + .with_description("The tokens making up this body of text."), + ) + .with_input( + TensorMetadata::new("start_logits") + .with_description("The indices for the start of each word/phrase to extract."), + ) + .with_input( + TensorMetadata::new("end_logits") + .with_description("The indices for the end of each word/phrase to extract."), + ) + .with_output( + TensorMetadata::new("phrases") + .with_description("The phrases that were extracted.") + ) +} - Ok(()) +#[derive(Debug, Default, Clone, PartialEq)] +struct TextExtractor; + +impl ProcBlock for TextExtractor { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![ + TensorConstraint::new("wordlist", ElementType::Utf8, [0]), + TensorConstraint::new("start_logits", ElementType::U32, [0]), + TensorConstraint::new("end_logits", ElementType::U32, [0]), + ], + outputs: vec![TensorConstraint::new( + "phrases", + ElementType::Utf8, + [0], + )], + } } - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let text = ctx.get_input_tensor("text").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "text".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - + fn run(&self, inputs: Vec) -> Result, RunError> { + let text = Tensor::get_named(&inputs, "text")?.string_view()?; let start_logits = - ctx.get_input_tensor("start_logits").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "start_logits".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - + Tensor::get_named(&inputs, "start_logits")?.view_1d::()?; let end_logits = - ctx.get_input_tensor("end_logits").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "end_logits".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - match text.element_type { - ElementType::U8 =>{ - text.buffer.view::(&text.dimensions) - .map_err(|e| KernelError::InvalidInput(InvalidInput{ name: "text".to_string(), reason: BadInputReason::InvalidValue(e.to_string()) }))?; - } - other => { - return Err(KernelError::Other(format!( - "The Object Filter proc-block doesn't support {:?} element type", - other, - ))) - }, - }; - match start_logits.element_type { - ElementType::U32 =>{ - start_logits.buffer.view::(&start_logits.dimensions) - .map_err(|e| KernelError::InvalidInput(InvalidInput{ name: "start_logits".to_string(), reason: BadInputReason::InvalidValue(e.to_string()) }))?; - } - other => { - return Err(KernelError::Other(format!( - "The Object Filter proc-block doesn't support {:?} element type", - other, - ))) - }, - }; - match end_logits.element_type { - ElementType::U32 =>{ - end_logits.buffer.view::(&start_logits.dimensions) - .map_err(|e| KernelError::InvalidInput(InvalidInput{ name: "end_logits".to_string(), reason: BadInputReason::InvalidValue(e.to_string()) }))?; - } - other => { - return Err(KernelError::Other(format!( - "The Object Filter proc-block doesn't support {:?} element type", - other, - ))) - }, - }; - let output = transform(( - text.buffer.elements(), - start_logits.buffer.elements(), - end_logits.buffer.elements(), - )); - - ctx.set_output_tensor( - "phrases", - TensorParam { - element_type: ElementType::Utf8, - dimensions: &[output.len() as u32], - buffer: &string_tensor_from_ndarray(&ndarray::arr1(&output)), - }, - ); - - Ok(()) - } -} - -fn transform<'a>(inputs: (&[u8], &[u32], &[u32])) -> Vec { - let (text, start_logits, end_logits) = inputs; - - let underlying_bytes: &[u8] = text.elements(); - let input_text = core::str::from_utf8(underlying_bytes) - .expect("Input tensor should be valid UTF8"); + Tensor::get_named(&inputs, "end_logits")?.view_1d::()?; + + let mut phrases = Vec::new(); + + for (i, (&start, &end)) in + start_logits.into_iter().zip(end_logits).enumerate() + { + let start = start as usize; + let end = end as usize; + + if start == 0 && end == 0 { + // No more logits + break; + } else if start > end { + return Err(RunError::other(format!("At index {i}, the start logit ({start}) is after the end ({end})"))); + } else if end > text.len() { + return Err(InvalidInput::invalid_value("end_logits", format!("The {i}'th logit, {end}, is out of bounds (num tokens: {})", text.len())).into()); + } else if start >= text.len() { + return Err(InvalidInput::invalid_value("start_logits", format!("The {i}'th logit, {start}, is out of bounds (num tokens: {})", text.len())).into()); + } - let input_text: Vec<&str> = input_text.lines().collect(); + dbg!(start, end); + let tokens = text.slice(ndarray::s!(start..=end)); + let phrase = merge_phrases(tokens.iter().copied()); + phrases.push(phrase); + } - let start_index: u32 = start_logits[0]; - let end_index: u32 = end_logits[0]; - if end_index <= start_index { - panic!( - "Start index: {} is greater than or equal to end index: {}", - start_index, end_index - ); + let phrases = ndarray::aview1(&phrases); + Ok(vec![Tensor::from_strings("phrases", &phrases)]) } +} - let v = &input_text[start_index as usize..end_index as usize + 1]; +impl From> for TextExtractor { + fn from(_: Vec) -> Self { TextExtractor::default() } +} +fn merge_phrases<'a>(tokens: impl Iterator) -> String { let mut buffer = String::new(); - for tok in v { - if let Some(s) = tok.strip_prefix("##") { - buffer.push_str(s); - } else { - if !buffer.is_empty() { - buffer.push_str(" "); - } - buffer.push_str(tok); + + for token in tokens { + match token.strip_prefix("##") { + Some(token) => buffer.push_str(token), + None => { + if !buffer.is_empty() { + buffer.push_str(" "); + } + buffer.push_str(token); + }, } } - let output_text = vec![(buffer)]; - - println!("output {:?}", &output_text); - - output_text + buffer } #[cfg(test)] mod tests { use super::*; + #[test] - fn test_token_extractor() { - let bytes: Vec = "[UNK]\n[UNK]\nuna\n##ffa\n##ble\nworld\n!" - .as_bytes() - .to_vec(); - // let bytes =(bytes); - let start_index = [2_u32]; - let end_index = [4_u32]; - let output = transform((&bytes, &start_index, &end_index)); - - let should_be = vec!["unaffable".to_string()]; - - assert_eq!(output, should_be); + fn known_inputs() { + let proc_block = TextExtractor::default(); + let words = ndarray::array![ + "[UNK]", "[UNK]", "una", "##ffa", "##ble", "world", "!" + ]; + let inputs = vec![ + Tensor::from_strings("text", &words), + Tensor::new_1d("start_logits", &[2_u32]), + Tensor::new_1d("end_logits", &[4_u32]), + ]; + + let got = proc_block.run(inputs).unwrap(); + + let should_be = vec![Tensor::from_strings( + "phrases", + &ndarray::array!["unaffable"], + )]; + assert_eq!(got, should_be); } } diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 0fa40929e53..67c0a24fc11 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "tokenizers" version = "0.12.0" -edition = "2018" +edition = "2021" publish = false repository = "https://github.com/hotg-ai/proc-blocks" description = "A proc-block takes a passage and a question as input and gives us BERT Encoding in form of input_ids, input_masks, segment_ids as output" diff --git a/train_test_split/Cargo.toml b/train_test_split/Cargo.toml index b5c448641b8..d6cf8ea5d87 100644 --- a/train_test_split/Cargo.toml +++ b/train_test_split/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "train_test_split" version = "0.12.1" -edition = "2018" +edition = "2021" description = "a random split into training and test sets" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/train_test_split/src/lib.rs b/train_test_split/src/lib.rs index 88e0d3bacf5..7a3c174b0ef 100644 --- a/train_test_split/src/lib.rs +++ b/train_test_split/src/lib.rs @@ -1,316 +1,177 @@ -use std::{fmt::Display, str::FromStr}; - -use smartcore::{ - linalg::{naive::dense_matrix::DenseMatrix, BaseMatrix}, - model_selection::train_test_split, -}; - -use crate::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; use hotg_rune_proc_blocks::{ - runtime_v1::{self, *}, - BufferExt, SliceExt, + guest::{ + parse, Argument, ArgumentMetadata, ArgumentType, CreateError, + ElementTypeConstraint, Metadata, ProcBlock, RunError, Tensor, + TensorConstraint, TensorConstraints, TensorMetadata, + }, + ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2}, +}; +use smartcore::{ + linalg::naive::dense_matrix::*, model_selection::train_test_split, }; -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -/// A proc block which can perform linear regression -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = - Metadata::new("Train-Test-Split", env!("CARGO_PKG_VERSION")); - metadata.set_description("a random split into training and test sets"); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("split"); - metadata.add_tag("data processing"); - metadata.add_tag("analytics"); - - let x = TensorMetadata::new("features"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0, 0])); - x.add_hint(&hint); - metadata.add_input(&x); - - // todo: have to make it dynamic size because y could be 1-d or 2-d - let y = TensorMetadata::new("targets"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0])); - y.add_hint(&hint); - metadata.add_input(&y); - - let test_size = ArgumentMetadata::new("test_size"); - test_size.set_description( - "the proportion of the dataset to include in the test split", - ); - let hint = runtime_v1::supported_argument_type(ArgumentType::Float); - test_size.add_hint(&hint); - test_size.set_default_value("0.2"); - metadata.add_argument(&test_size); - - let element_type = ArgumentMetadata::new("element_type"); - element_type - .set_description("The type of tensor this proc-block will accept"); - element_type.set_default_value("f64"); - element_type.add_hint(&runtime_v1::interpret_as_string_in_enum(&[ - "u8", "i8", "u16", "i16", "u32", "i32", "f32", "u64", "i64", "f64", - ])); - metadata.add_argument(&element_type); - - let x_train = TensorMetadata::new("x_train"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0, 0])); - x_train.add_hint(&hint); - metadata.add_output(&x_train); - - let y_train = TensorMetadata::new("y_train"); - let hint = - supported_shapes(&[ElementType::F64], DimensionsParam::Fixed(&[0])); - y_train.add_hint(&hint); - metadata.add_output(&y_train); - - let x_test = TensorMetadata::new("x_test"); - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0, 0])); - x_test.add_hint(&hint); - metadata.add_output(&x_test); - - let y_test = TensorMetadata::new("y_test"); - let supported_types = [ElementType::F64]; - let hint = - supported_shapes(&supported_types, DimensionsParam::Fixed(&[0])); - y_test.add_hint(&hint); - metadata.add_output(&y_test); - - register_node(&metadata); - } - - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - let element_type = match ctx.get_argument("element_type").as_deref() { - Some("f64") => ElementType::F64, - Some(_) => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::InvalidValue( - "Unsupported element type".to_string(), - ), - })); - }, - None => { - return Err(GraphError::InvalidArgument(InvalidArgument { - name: "element_type".to_string(), - reason: BadArgumentReason::NotFound, - })) - }, - }; - - ctx.add_input_tensor( - "features", - element_type, - DimensionsParam::Fixed(&[0, 0]), - ); - - ctx.add_input_tensor( - "targets", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_output_tensor( - "x_train", - element_type, - DimensionsParam::Fixed(&[0, 0]), - ); - - ctx.add_output_tensor( - "y_train", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_output_tensor( - "x_test", - element_type, - DimensionsParam::Fixed(&[0, 0]), - ); - - ctx.add_output_tensor( - "y_test", - element_type, - DimensionsParam::Fixed(&[0]), - ); - - Ok(()) - } - - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let x = ctx.get_input_tensor("features").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "features".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let y = ctx.get_input_tensor("targets").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "targets".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let test_size: f32 = get_args("test_size", |n| ctx.get_argument(n)) - .map_err(KernelError::InvalidArgument)?; - - let (x_train, x_test, y_train, y_test, train_dim, test_dim) = transform( - &x.buffer.elements(), - &x.dimensions, - (&y.buffer.elements()).to_vec(), - test_size, - ); - - ctx.set_output_tensor( - "x_train", - TensorParam { - element_type: ElementType::F64, - dimensions: &[train_dim.0 as u32, train_dim.1 as u32], - buffer: &x_train.as_bytes(), - }, - ); - - ctx.set_output_tensor( - "x_test", - TensorParam { - element_type: ElementType::F64, - dimensions: &[test_dim.0 as u32, test_dim.1 as u32], - buffer: &x_test.as_bytes(), - }, - ); - - ctx.set_output_tensor( - "y_train", - TensorParam { - element_type: ElementType::F64, - dimensions: &[train_dim.0 as u32], - buffer: &y_train.as_bytes(), - }, - ); - - ctx.set_output_tensor( - "y_test", - TensorParam { - element_type: ElementType::F64, - dimensions: &[test_dim.0 as u32], - buffer: &y_test.as_bytes(), - }, - ); +hotg_rune_proc_blocks::export_proc_block! { + metadata: metadata, + proc_block: TrainTestSplit, +} - Ok(()) - } +fn metadata() -> Metadata { + Metadata::new("Train Test Split", env!("CARGO_PKG_VERSION")) + .with_description( + "a random split into training and test sets", + ) + .with_repository(env!("CARGO_PKG_REPOSITORY")) + .with_homepage(env!("CARGO_PKG_HOMEPAGE")) + .with_tag("split") + .with_tag("data processing") + .with_tag("analytics") + .with_argument(ArgumentMetadata::new("test_size") + .with_default_value("0.2") + .with_description("the proportion of the dataset to include in the test split") + .with_hint(ArgumentType::Float)) + .with_input(TensorMetadata::new("features").with_description("features")) + .with_input(TensorMetadata::new("targets").with_description("targets")) + .with_output(TensorMetadata::new("x_train").with_description("training features")) + .with_output(TensorMetadata::new("y_train").with_description("training labels")) + .with_output(TensorMetadata::new("x_test").with_description("testing features")) + .with_output(TensorMetadata::new("y_test").with_description("testing labels")) } -fn get_args( - name: &str, - get_argument: impl FnOnce(&str) -> Option, -) -> Result -where - T: FromStr, - ::Err: Display, -{ - get_argument(name) - .ok_or_else(|| InvalidArgument::not_found(name))? - .parse::() - .map_err(|e| InvalidArgument::invalid_value(name, e)) +struct TrainTestSplit { + test_size: f32, } -impl InvalidArgument { - fn not_found(name: impl Into) -> Self { - InvalidArgument { - name: name.into(), - reason: BadArgumentReason::NotFound, +impl ProcBlock for TrainTestSplit { + fn tensor_constraints(&self) -> TensorConstraints { + TensorConstraints { + inputs: vec![ + TensorConstraint::new( + "features", + ElementTypeConstraint::F64, + vec![0, 0], + ), + TensorConstraint::new( + "targets", + ElementTypeConstraint::F64, + vec![0], + ), + ], + outputs: vec![ + TensorConstraint::new( + "x_train", + ElementTypeConstraint::F64, + vec![0, 0], + ), + TensorConstraint::new( + "y_train", + ElementTypeConstraint::F64, + vec![0], + ), + TensorConstraint::new( + "x_test", + ElementTypeConstraint::F64, + vec![0, 0], + ), + TensorConstraint::new( + "y_test", + ElementTypeConstraint::F64, + vec![0], + ), + ], } } - fn invalid_value(name: impl Into, reason: impl Display) -> Self { - InvalidArgument { - name: name.into(), - reason: BadArgumentReason::InvalidValue(reason.to_string()), - } + fn run(&self, inputs: Vec) -> Result, RunError> { + let features = Tensor::get_named(&inputs, "features")?.view_2d()?; + let targets = Tensor::get_named(&inputs, "targets")?.view_1d()?; + + let (x_train, y_train, x_test, y_test) = + transform(features, targets, self.test_size); + + Ok(vec![ + Tensor::new("x_train", &x_train), + Tensor::new("y_train", &y_train), + Tensor::new("x_test", &x_test), + Tensor::new("y_test", &y_test), + ]) } } fn transform( - x: &[f64], - x_dim: &[u32], - y: Vec, + x: ArrayView2<'_, f64>, + y: ArrayView1<'_, f64>, test_size: f32, -) -> ( - Vec, - Vec, - Vec, - Vec, - (usize, usize), - (usize, usize), -) { - let x = DenseMatrix::from_array(x_dim[0] as usize, x_dim[1] as usize, x); +) -> (Array2, Array1, Array2, Array1) { + let (rows, columns) = x.dim(); + let x = DenseMatrix::new(rows, columns, x.into_iter().copied().collect()); + + let y = y.to_vec(); let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, test_size, false); - let train_dim = x_train.shape(); - let test_dim = x_test.shape(); - let x_train: Vec = x_train.iter().map(|f| f).collect(); - let x_test: Vec = x_test.iter().map(|f| f).collect(); - (x_train, x_test, y_train, y_test, train_dim, test_dim) + let x_train: Array2 = + Array::from_shape_vec(x_train.shape(), x_train.iter().collect()) + .unwrap(); + let x_test: Array2 = + Array::from_shape_vec(x_test.shape(), x_test.iter().collect()).unwrap(); + let y_train: Array1 = + Array::from_shape_vec(y_train.len(), y_train).unwrap(); + let y_test: Array1 = + Array::from_shape_vec(y_test.len(), y_test).unwrap(); + + (x_train, y_train, x_test, y_test) +} + +impl TryFrom> for TrainTestSplit { + type Error = CreateError; + + fn try_from(args: Vec) -> Result { + let test_size = parse::optional_arg(&args, "test_size")?.unwrap_or(0.2); + + Ok(TrainTestSplit { test_size }) + } } #[cfg(test)] mod tests { use super::*; + use hotg_rune_proc_blocks::ndarray::array; #[test] fn check_test_dim() { - let x = [ - 5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 5.2, 2.7, 3.9, 1.4, 5.1, - 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 5.2, 2.7, 3.9, 1.4, + let x: Array2 = array![ + [5.1, 3.5, 1.4, 0.2], + [4.9, 3.0, 1.4, 0.2], + [5.2, 2.7, 3.9, 1.4], + [5.1, 3.5, 1.4, 0.2], + [4.9, 3.0, 1.4, 0.2], + [5.2, 2.7, 3.9, 1.4] ]; - let y: Vec = vec![0., 0., 1., 0., 0., 1.]; - - let dim: Vec = vec![6, 4]; + let y: Array1 = array![0., 0., 1., 0., 0., 1.]; - let (_x_train, _x_test, _y_train, _y_test, _train_dim, test_dim) = - transform(&x, &dim, y, 0.2); + let (_x_train, _y_train, x_test, _y_test) = + transform(x.view(), y.view(), 0.2); - let should_be = (1, 4); - - assert_eq!(test_dim, should_be); + assert_eq!(x_test.dim(), (1, 4)); } + #[test] fn check_train_dim() { - let x = [ - 5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 5.2, 2.7, 3.9, 1.4, 5.1, - 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 5.2, 2.7, 3.9, 1.4, + let x: Array2 = array![ + [5.1, 3.5, 1.4, 0.2], + [4.9, 3.0, 1.4, 0.2], + [5.2, 2.7, 3.9, 1.4], + [5.1, 3.5, 1.4, 0.2], + [4.9, 3.0, 1.4, 0.2], + [5.2, 2.7, 3.9, 1.4] ]; - let y: Vec = vec![0., 0., 1., 0., 0., 1.]; - - let dim: Vec = vec![6, 4]; + let y: Array1 = array![0., 0., 1., 0., 0., 1.]; - let (_x_train, _x_test, _y_train, _y_test, train_dim, _test_dim) = - transform(&x, &dim, y, 0.2); + let (x_train, y_train, _x_test, _y_test) = + transform(x.view(), y.view(), 0.2); - let should_be = (5, 4); - assert_eq!(train_dim, should_be); + assert_eq!(x_train.dim(), (5, 4)); + assert_eq!(y_train, array![0.0, 1.0, 0.0, 0.0, 1.0]); } } diff --git a/utf8_decode/Cargo.toml b/utf8_decode/Cargo.toml deleted file mode 100644 index 7c7b185a2a4..00000000000 --- a/utf8_decode/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "utf8_decode" -version = "0.12.0" -edition = "2018" -publish = false -repository = "https://github.com/hotg-ai/proc-blocks" -description = "convert back the u8 bytes to utf8" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[lib] -crate-type = ["cdylib", "rlib"] - -[dependencies] -hotg-rune-proc-blocks = { path = "../support" } -wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen"} - -[package.metadata.wapm] -namespace = "hotg-ai" -abi = "none" diff --git a/utf8_decode/src/lib.rs b/utf8_decode/src/lib.rs deleted file mode 100644 index a2272bd66ac..00000000000 --- a/utf8_decode/src/lib.rs +++ /dev/null @@ -1,147 +0,0 @@ -use crate::proc_block_v1::*; -use hotg_rune_proc_blocks::{ - ndarray::{s, ArrayView1}, - runtime_v1::*, - BufferExt, -}; - -wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); - -#[macro_use] -extern crate alloc; -use alloc::string::ToString; - -/// A proc block which can convert u8 bytes to utf8 -struct ProcBlockV1; - -impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { - fn register_metadata() { - let metadata = Metadata::new("UTF8 Decode", env!("CARGO_PKG_VERSION")); - metadata.set_description("Decode a string from UTF-8 bytes."); - metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); - metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); - metadata.add_tag("text"); - metadata.add_tag("nlp"); - metadata.add_tag("bytes"); - - let input = TensorMetadata::new("bytes"); - input.set_description("The string as UTF-8 encoded bytes"); - let hint = - supported_shapes(&[ElementType::U8], DimensionsParam::Fixed(&[0])); - input.add_hint(&hint); - metadata.add_input(&input); - - let output = TensorMetadata::new("string"); - output.set_description("The decoded text."); - let hint = supported_shapes( - &[ElementType::Utf8], - DimensionsParam::Fixed(&[1]), - ); - output.add_hint(&hint); - metadata.add_output(&output); - - register_node(&metadata); - } - - fn graph(node_id: String) -> Result<(), GraphError> { - let ctx = GraphContext::for_node(&node_id) - .ok_or(GraphError::MissingContext)?; - - ctx.add_input_tensor( - "bytes", - ElementType::U8, - DimensionsParam::Fixed(&[0]), - ); - - ctx.add_output_tensor( - "string", - ElementType::Utf8, - DimensionsParam::Fixed(&[1]), - ); - - Ok(()) - } - - fn kernel(node_id: String) -> Result<(), KernelError> { - let ctx = KernelContext::for_node(&node_id) - .ok_or(KernelError::MissingContext)?; - - let TensorResult { - element_type, - dimensions, - buffer, - } = ctx.get_input_tensor("bytes").ok_or_else(|| { - KernelError::InvalidInput(InvalidInput { - name: "bytes".to_string(), - reason: BadInputReason::NotFound, - }) - })?; - - let output = match element_type { - ElementType::U8 => { - let tensor = buffer - .view::(&dimensions) - .and_then(|t| t.into_dimensionality()) - .map_err(|e| { - KernelError::InvalidInput(InvalidInput { - name: "bytes".to_string(), - reason: BadInputReason::InvalidValue(e.to_string()), - }) - })?; - transform(tensor) - }, - other => { - return Err(KernelError::Other(format!( - "The Utf8 Decode proc-block doesn't support {:?} element type", - other, - ))) - }, - }; - - ctx.set_output_tensor( - "string", - TensorParam { - element_type: ElementType::Utf8, - dimensions: &[output.dim() as u32], - buffer: &output.to_vec(), - }, - ); - - Ok(()) - } -} - -fn transform(input: ArrayView1) -> ArrayView1 { - match input.iter().position(|&x| x == 0) { - Some(null_terminator) => input.slice_move(s![..null_terminator]), - None => input, - } -} - -#[cfg(test)] -mod tests { - use hotg_rune_proc_blocks::ndarray; - - use super::*; - - #[test] - fn test_for_utf8_decoding() { - let bytes = ndarray::array![ - 72_u8, 105, 44, 32, 117, 115, 101, 32, 109, 101, 32, 116, 111, 32, - 99, 111, 110, 118, 101, 114, 116, 32, 121, 111, 117, 114, 32, 117, - 56, 32, 98, 121, 116, 101, 115, 32, 116, 111, 32, 117, 116, 102, - 56, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]; // bytes encoding for "Hi, use me to convert your u8 bytes to utf8." - - let should_be = ndarray::array![ - 72_u8, 105, 44, 32, 117, 115, 101, 32, 109, 101, 32, 116, 111, 32, - 99, 111, 110, 118, 101, 114, 116, 32, 121, 111, 117, 114, 32, 117, - 56, 32, 98, 121, 116, 101, 115, 32, 116, 111, 32, 117, 116, 102, - 56, 46, - ]; - - let output = transform(bytes.view()); - - assert_eq!(output, should_be); - } -} diff --git a/wit-files b/wit-files index 4b1464187e3..0740d48cfbe 160000 --- a/wit-files +++ b/wit-files @@ -1 +1 @@ -Subproject commit 4b1464187e3467ca491ab96c1819d70187b9a40b +Subproject commit 0740d48cfbeb771f5cb7066693a72df1e685c18e diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index 6c68b69954c..6a1a90684f4 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "xtask" version = "0.1.0" -edition = "2018" +edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -11,14 +11,16 @@ cargo_metadata = "0.14.1" heck = "0.4.0" itertools = "0.10.3" once_cell = "1.10.0" +rand = "0.8.5" serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0.79" structopt = "0.3.26" +thiserror = "1.0.31" tracing = "0.1.30" tracing-subscriber = { version = "0.3.8", features = ["env-filter"] } walrus = "0.19.0" -wasmtime = "0.34.0" -wit-bindgen-wasmtime = { git = "https://github.com/bytecodealliance/wit-bindgen", rev = "e165c0226ad0f00a840ef786079b58d0910dbb6b" } +wasmer = "2.3.0" +wit-bindgen-wasmer = { git = "https://github.com/wasmerio/wit-bindgen", branch = "wasmer", version = "0.1.0", features = ["tracing"] } [dev-dependencies] tempfile = "3.3.0" diff --git a/xtask/src/bin/xtask.rs b/xtask/src/bin/xtask.rs index 35f331fd6cb..bd25fd70967 100644 --- a/xtask/src/bin/xtask.rs +++ b/xtask/src/bin/xtask.rs @@ -10,16 +10,14 @@ use anyhow::{Context, Error}; use once_cell::sync::Lazy; use structopt::StructOpt; use tracing_subscriber::EnvFilter; -use xtask::{runtime::Runtime, CompilationMode}; +use xtask::{runtime::ProcBlockModule, CompilationMode}; fn main() -> Result<(), Error> { + if std::env::var_os("RUST_LOG").is_none() { + std::env::set_var("RUST_LOG", "warn,xtask=debug,walrus=error"); + } tracing_subscriber::fmt::fmt() - .with_env_filter( - EnvFilter::from_default_env() - .add_directive("cranelift_codegen=warn".parse()?) - .add_directive("wasmtime_cranelift=warn".parse()?) - .add_directive("regalloc=warn".parse()?), - ) + .with_env_filter(EnvFilter::from_default_env()) .without_time() .init(); @@ -43,7 +41,6 @@ enum Command { Metadata(Metadata), /// Generate API documentation for one or more proc-blocks. Doc(Doc), - /// Graph(Graph), } @@ -56,6 +53,9 @@ struct Dist { /// optimisations. #[structopt(long)] debug: bool, + /// Do not abort the build as soon as there is an error. + #[structopt(short, long)] + keep_going: bool, /// Where to write compiled proc-blocks to. #[structopt(short, long, default_value = &*DIST_DIR)] out_dir: PathBuf, @@ -73,7 +73,7 @@ impl Dist { CompilationMode::Release }; - let mut wasm_modules = proc_blocks.compile(mode)?; + let mut wasm_modules = proc_blocks.compile(mode, self.keep_going)?; if !self.debug { tracing::info!("Stripping custom sections to reduce binary size"); @@ -113,7 +113,7 @@ impl Metadata { format!("Unable to read \"{}\"", self.proc_block.display()) })?; - let mut runtime = Runtime::load(&wasm) + let runtime = ProcBlockModule::load(&wasm) .context("Unable to load the WebAssembly module")?; let metadata = runtime @@ -146,14 +146,14 @@ impl Graph { format!("Unable to read \"{}\"", self.proc_block.display()) })?; - let mut runtime = Runtime::load(&wasm) + let runtime = ProcBlockModule::load(&wasm) .context("Unable to load the WebAssembly module")?; let arguments: HashMap<_, _> = self.args.into_iter().map(|a| (a.key, a.value)).collect(); let info = runtime - .graph(arguments) + .graph(&arguments) .context("Unable to infer the input and output tensors")?; let json = serde_json::to_string_pretty(&info) @@ -227,7 +227,7 @@ impl Doc { "Read the module into memory" ); - let mut r = Runtime::load(&wasm) + let r = ProcBlockModule::load(&wasm) .context("Unable to load the proc-block")?; let meta = r .metadata() diff --git a/xtask/src/bindings.rs b/xtask/src/bindings.rs new file mode 100644 index 00000000000..f2b1bc2da31 --- /dev/null +++ b/xtask/src/bindings.rs @@ -0,0 +1,337 @@ +pub mod runtime_v2 { + wit_bindgen_wasmer::export!("../wit-files/rune/runtime-v2.wit"); + #[doc(inline)] + pub use self::runtime_v2::*; +} + +pub mod proc_block_v2 { + use std::{error::Error, fmt::Display, num::NonZeroU32}; + + #[doc(inline)] + pub use proc_block_v2::*; + pub use TensorResult as Tensor; + + use serde::ser::{Serialize, SerializeSeq, SerializeStruct, Serializer}; + + wit_bindgen_wasmer::import!("../wit-files/rune/proc-block-v2.wit"); + + impl Serialize for Metadata { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let Metadata { + name, + version, + description, + repository, + homepage, + tags, + arguments, + inputs, + outputs, + } = self; + + let mut ser = serializer.serialize_struct("Metadata", 8)?; + + ser.serialize_field("name", name)?; + ser.serialize_field("version", version)?; + ser.serialize_field("description", description)?; + ser.serialize_field("repository", repository)?; + ser.serialize_field("homepage", homepage)?; + ser.serialize_field("tags", tags)?; + ser.serialize_field("arguments", arguments)?; + ser.serialize_field("inputs", inputs)?; + ser.serialize_field("outputs", outputs)?; + + ser.end() + } + } + + impl Serialize for TensorMetadata { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let TensorMetadata { + name, + description, + hints, + } = self; + let mut ser = serializer.serialize_struct("TensorMetadata", 2)?; + + ser.serialize_field("name", name)?; + ser.serialize_field("description", description)?; + ser.serialize_field("hints", hints)?; + + ser.end() + } + } + + impl Serialize for ArgumentMetadata { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let ArgumentMetadata { + name, + description, + hints, + default_value, + } = self; + let mut ser = serializer.serialize_struct("ArgumentMetadata", 2)?; + + ser.serialize_field("name", name)?; + ser.serialize_field("description", description)?; + ser.serialize_field("hints", hints)?; + ser.serialize_field("default_value", default_value)?; + + ser.end() + } + } + + impl Serialize for TensorConstraints { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let TensorConstraints { inputs, outputs } = self; + let mut ser = + serializer.serialize_struct("TensorConstraints", 2)?; + + ser.serialize_field("inputs", inputs)?; + ser.serialize_field("outputs", outputs)?; + + ser.end() + } + } + + impl Serialize for TensorConstraint { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let TensorConstraint { + name, + element_type, + dimensions, + } = self; + let mut ser = serializer.serialize_struct("TensorConstraint", 3)?; + + ser.serialize_field("name", name)?; + ser.serialize_field("element_type", element_type)?; + ser.serialize_field("dimensions", dimensions)?; + + ser.end() + } + } + + impl Serialize for TensorHint { + fn serialize(&self, ser: S) -> Result + where + S: Serializer, + { + #[derive(serde::Serialize)] + enum TensorHint<'a> { + Other(&'a str), + MediaType(MediaType), + } + + match self { + Self::Other(other) => TensorHint::Other(other).serialize(ser), + Self::MediaType(ty) => TensorHint::MediaType(*ty).serialize(ser), + } + } + } + + impl Serialize for ArgumentHint { + fn serialize(&self, ser: S) -> Result + where + S: Serializer, + { + #[derive(serde::Serialize)] + enum ArgumentHint<'a> { + Between((&'a str, &'a str)), + OneOf(&'a [String]), + NonNegativeNumber, + ArgumentType(ArgumentType), + } + + let arg = match self { + Self::Between((low, high)) => { + ArgumentHint::Between((low, high)) + }, + Self::OneOf(x) => ArgumentHint::OneOf(x), + Self::NonNegativeNumber => ArgumentHint::NonNegativeNumber, + Self::ArgumentType(ty) => ArgumentHint::ArgumentType(*ty), + }; + + arg.serialize(ser) + } + } + + impl Serialize for ArgumentType { + fn serialize(&self, ser: S) -> Result + where + S: Serializer, + { + format!("{self:?}").serialize(ser) + } + } + + impl Serialize for MediaType { + fn serialize(&self, ser: S) -> Result + where + S: Serializer, + { + format!("{self:?}").serialize(ser) + } + } + + impl Serialize for ElementTypeConstraint { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut ser = serializer + .serialize_seq(Some(self.bits().count_ones() as usize))?; + + if self.contains(ElementTypeConstraint::U8) { + ser.serialize_element(&ElementType::U8)?; + } + if self.contains(ElementTypeConstraint::I8) { + ser.serialize_element(&ElementType::I8)?; + } + if self.contains(ElementTypeConstraint::U16) { + ser.serialize_element(&ElementType::U16)?; + } + if self.contains(ElementTypeConstraint::I16) { + ser.serialize_element(&ElementType::I16)?; + } + if self.contains(ElementTypeConstraint::U32) { + ser.serialize_element(&ElementType::U32)?; + } + if self.contains(ElementTypeConstraint::I32) { + ser.serialize_element(&ElementType::I32)?; + } + if self.contains(ElementTypeConstraint::F32) { + ser.serialize_element(&ElementType::F32)?; + } + if self.contains(ElementTypeConstraint::U64) { + ser.serialize_element(&ElementType::U64)?; + } + if self.contains(ElementTypeConstraint::I64) { + ser.serialize_element(&ElementType::I64)?; + } + if self.contains(ElementTypeConstraint::F64) { + ser.serialize_element(&ElementType::F64)?; + } + if self.contains(ElementTypeConstraint::COMPLEX64) { + ser.serialize_element(&ElementType::Complex64)?; + } + if self.contains(ElementTypeConstraint::COMPLEX128) { + ser.serialize_element(&ElementType::Complex128)?; + } + if self.contains(ElementTypeConstraint::UTF8) { + ser.serialize_element(&ElementType::Utf8)?; + } + + ser.end() + } + } + + impl Serialize for ElementType { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + ElementType::U8 => "u8".serialize(serializer), + ElementType::I8 => "i8".serialize(serializer), + ElementType::U16 => "u16".serialize(serializer), + ElementType::I16 => "i16".serialize(serializer), + ElementType::U32 => "u32".serialize(serializer), + ElementType::I32 => "i32".serialize(serializer), + ElementType::F32 => "f32".serialize(serializer), + ElementType::U64 => "u64".serialize(serializer), + ElementType::I64 => "i64".serialize(serializer), + ElementType::F64 => "f64".serialize(serializer), + ElementType::Complex64 => "complex64".serialize(serializer), + ElementType::Complex128 => "complex128".serialize(serializer), + ElementType::Utf8 => "utf8".serialize(serializer), + } + } + } + + impl Serialize for Dimensions { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + #[derive(serde::Serialize)] + enum DimensionsWrapper { + Dynamic, + Fixed(Vec>), + } + + let dim = match self { + Dimensions::Dynamic => DimensionsWrapper::Dynamic, + Dimensions::Fixed(dims) => DimensionsWrapper::Fixed( + dims.iter().copied().map(NonZeroU32::new).collect(), + ), + }; + + dim.serialize(serializer) + } + } + + impl Display for CreateError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CreateError::Argument(_) => write!(f, "Unable to create the node because of an issue with the arguments"), + CreateError::Other(msg) => write!(f, "{msg}"), + } + } + } + + impl Error for CreateError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + CreateError::Argument(a) => Some(a), + CreateError::Other(_) => None, + } + } + } + + impl Display for ArgumentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let name = &self.name; + write!(f, "The \"{name}\" argument was invalid") + } + } + + impl Error for ArgumentError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(&self.reason) + } + } + + impl Display for ArgumentErrorReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ArgumentErrorReason::Other(msg) => write!(f, "{msg}"), + ArgumentErrorReason::NotFound => { + write!(f, "The argument wasn't provided") + }, + ArgumentErrorReason::InvalidValue(reason) => { + write!(f, "Invalid value: {reason}") + }, + ArgumentErrorReason::ParseFailed(reason) => { + write!(f, "Parse failed: {reason}") + }, + } + } + } + + impl Error for ArgumentErrorReason {} +} diff --git a/xtask/src/build.rs b/xtask/src/build.rs index 0a61813d695..7d1bdacf329 100644 --- a/xtask/src/build.rs +++ b/xtask/src/build.rs @@ -1,5 +1,5 @@ use anyhow::{Context, Error}; -use cargo_metadata::{CargoOpt, Metadata, MetadataCommand, Package}; +use cargo_metadata::{Metadata, MetadataCommand, Package}; use std::{ path::{Path, PathBuf}, process::Command, @@ -61,6 +61,7 @@ impl ProcBlocks { pub fn compile( &self, mode: CompilationMode, + keep_going: bool, ) -> Result, Error> { let _span = tracing::info_span!("Compile").entered(); tracing::info!("Compiling proc-blocks to WebAssembly"); @@ -71,15 +72,16 @@ impl ProcBlocks { let mut libs = Vec::new(); for package in &self.packages { + let _span = + tracing::info_span!("build", package=%package.name).entered(); + let mut cmd = Command::new(&cargo); cmd.arg("rustc") .arg("--manifest-path") .arg(&package.manifest_path) .arg("--lib") .arg("--target=wasm32-unknown-unknown") - .arg("--features=metadata") - .arg("-Zunstable-options") - .arg("--crate-type=cdylib"); + .arg("--locked"); match mode { CompilationMode::Release => { @@ -99,11 +101,16 @@ impl ProcBlocks { tracing::debug!(exit_code = ?status.code(), "Cargo build completed"); - if !status.success() { + if status.success() { + libs.push(&package.name); + } else if keep_going { + tracing::warn!( + "Compiling \"{}\" failed. Skipping...", + package.name, + ); + } else { anyhow::bail!("Compilation failed"); } - - libs.push(&package.name); } tracing::debug!(?libs); @@ -157,9 +164,7 @@ impl CompiledModule { fn manifest(manifest_path: &Path) -> Result { let mut cmd = MetadataCommand::new(); - - cmd.manifest_path(manifest_path) - .features(CargoOpt::SomeFeatures(vec!["metadata".to_string()])); + cmd.manifest_path(manifest_path); tracing::debug!( manifest = %manifest_path.display(), diff --git a/xtask/src/docs.rs b/xtask/src/docs.rs index 22b2e519192..c5ef490ff2a 100644 --- a/xtask/src/docs.rs +++ b/xtask/src/docs.rs @@ -1,11 +1,10 @@ use std::io::Write; use anyhow::Error; -use itertools::Itertools; -use crate::runtime::{ - runtime_v1::ArgumentType, ArgumentHint, ArgumentMetadata, Dimensions, - Metadata, TensorHint, TensorMetadata, +use crate::proc_block_v2::{ + ArgumentHint, ArgumentMetadata, ArgumentType, Metadata, TensorHint, + TensorMetadata, }; pub fn document(w: &mut dyn Write, meta: &Metadata) -> Result<(), Error> { @@ -106,28 +105,11 @@ fn render_tensor_hint( hint: &TensorHint, ) -> Result<(), Error> { match hint { - TensorHint::DisplayAs(ty) => { - writeln!(w, "- Display as `{ty}`")?; + TensorHint::MediaType(ty) => { + writeln!(w, "- Display as `{ty:?}`")?; }, - TensorHint::SupportedShape { - accepted_element_types, - dimensions, - } => { - write!(w, "- A ")?; - - match dimensions { - Dimensions::Dynamic => { - write!(w, "dynamically sized tensor")?; - }, - Dimensions::Fixed(fixed) => { - let dims = fixed.iter().join("`, `"); - write!(w, "fixed-size tensor with dimensions `[{dims}]`")?; - }, - } - - let elements = accepted_element_types.iter().join("`, `"); - - writeln!(w, " which may contain one of `{elements}`")?; + TensorHint::Other(msg) => { + writeln!(w, "- {msg}")?; }, } @@ -195,26 +177,26 @@ fn render_argument_hints( for hint in hints { match hint { ArgumentHint::NonNegativeNumber => write!(w, "- Non-negative")?, - ArgumentHint::StringEnum(variants) => { + ArgumentHint::OneOf(variants) => { let variants = variants.join("`, `"); writeln!(w, "- One of `\"{variants}\"`")?; }, - ArgumentHint::NumberInRange { max, min } => { + ArgumentHint::Between((max, min)) => { writeln!(w, "- A value between `{min}` and `{max}`")? }, - ArgumentHint::SupportedArgumentType(ArgumentType::Float) => { + ArgumentHint::ArgumentType(ArgumentType::Float) => { writeln!(w, "- A float")? }, - ArgumentHint::SupportedArgumentType(ArgumentType::Integer) => { + ArgumentHint::ArgumentType(ArgumentType::Integer) => { writeln!(w, "- An integer")? }, - ArgumentHint::SupportedArgumentType( - ArgumentType::UnsignedInteger, - ) => writeln!(w, "- An unsigned integer")?, - ArgumentHint::SupportedArgumentType(ArgumentType::String) => { + ArgumentHint::ArgumentType(ArgumentType::UnsignedInteger) => { + writeln!(w, "- An unsigned integer")? + }, + ArgumentHint::ArgumentType(ArgumentType::String) => { writeln!(w, "- A string")? }, - ArgumentHint::SupportedArgumentType(ArgumentType::LongString) => { + ArgumentHint::ArgumentType(ArgumentType::LongString) => { writeln!(w, "- A multi-line string")? }, } diff --git a/xtask/src/lib.rs b/xtask/src/lib.rs index e6aa10bdc5b..d946636a56b 100644 --- a/xtask/src/lib.rs +++ b/xtask/src/lib.rs @@ -1,9 +1,11 @@ +mod bindings; mod build; mod docs; mod manifest; pub mod runtime; pub use crate::{ + bindings::{proc_block_v2, runtime_v2}, build::{discover_proc_block_manifests, CompilationMode}, docs::document, manifest::{generate_manifest, Manifest}, diff --git a/xtask/src/manifest.rs b/xtask/src/manifest.rs index 7286c7986fb..ffc3cd8ce7e 100644 --- a/xtask/src/manifest.rs +++ b/xtask/src/manifest.rs @@ -1,6 +1,5 @@ use crate::{ - build::CompiledModule, - runtime::{Metadata, Runtime}, + build::CompiledModule, proc_block_v2::Metadata, runtime::ProcBlockModule, }; use anyhow::{Context, Error}; use serde::Serialize; @@ -40,7 +39,7 @@ pub fn generate_manifest( } fn extract_metadata(serialized: &[u8]) -> Result { - Runtime::load(serialized)?.metadata() + ProcBlockModule::load(serialized)?.metadata() } #[derive(Default)] @@ -63,9 +62,6 @@ impl Manifest { })?; } - save_json(dir.join("metadata.json"), &self.metadata) - .context("Unable to save the metadata")?; - let names: Vec<_> = self.metadata.keys().collect(); save_json(dir.join("manifest.json"), &names) .context("Unable to save the manifest")?; diff --git a/xtask/src/runtime.rs b/xtask/src/runtime.rs index d271375e273..afbddd26f9b 100644 --- a/xtask/src/runtime.rs +++ b/xtask/src/runtime.rs @@ -1,718 +1,90 @@ -use crate::runtime::proc_block_v1::{ - BadArgumentReason, BadInputReason, GraphError, InvalidArgument, - InvalidInput, KernelError, -}; - -use self::proc_block_v1::{ProcBlockV1, ProcBlockV1Data}; use anyhow::{Context, Error}; -use serde::{Deserialize, Serialize}; -use std::{ - collections::HashMap, - fmt::{self, Display, Formatter}, - num::NonZeroUsize, - sync::{Arc, Mutex}, +use rand::Rng; +use wasmer::{ImportObject, Module, Store, WasmerEnv}; + +use crate::{ + proc_block_v2::{Argument, Metadata, Node, ProcBlockV2, TensorConstraints}, + runtime_v2::{self, LogMetadata, LogValueMap}, }; -use wasmtime::{Engine, Linker, Module, Store}; +use std::collections::HashMap; -wit_bindgen_wasmtime::export!("../wit-files/rune/runtime-v1.wit"); -wit_bindgen_wasmtime::import!("../wit-files/rune/proc-block-v1.wit"); +pub struct ProcBlockModule(ProcBlockV2); -pub struct Runtime { - rune: ProcBlockV1, - store: Store, -} - -impl Runtime { - #[tracing::instrument(skip(wasm))] +impl ProcBlockModule { pub fn load(wasm: &[u8]) -> Result { - let engine = Engine::default(); - - tracing::debug!("Loading the WebAssembly module"); - - let module = Module::new(&engine, wasm) - .context("Unable to instantiate the module")?; - let mut store = Store::new(&engine, State::default()); + let store = Store::default(); - tracing::debug!("Setting up the host functions"); + let module = Module::new(&store, wasm) + .context("Unable to compile the WebAssembly module")?; - let mut linker = Linker::new(&engine); - runtime_v1::add_to_linker(&mut linker, |state: &mut State| { - (&mut state.runtime, &mut state.tables) - }) - .context("Unable to register the host functions")?; + let mut imports = ImportObject::new(); + runtime_v2::add_to_imports(&store, &mut imports, HostFunctions); - tracing::debug!("Instantiating the WebAssembly module"); + let (glue, _instance) = + ProcBlockV2::instantiate(&store, &module, &mut imports) + .context("Unable to instantiate the WebAssembly module")?; - let (rune, _) = ProcBlockV1::instantiate( - &mut store, - &module, - &mut linker, - |state: &mut State| &mut state.rune_v1_data, - ) - .context("Unable to instantiate the WebAssembly module")?; - - Ok(Runtime { rune, store }) + Ok(ProcBlockModule(glue)) } - #[tracing::instrument(skip(self))] - pub fn metadata(&mut self) -> Result { - tracing::debug!("Running the register_metadata() function"); - - self.rune.register_metadata(&mut self.store).context( - "Unable to run the WebAssembly module's register_metadata() function", - )?; - - self.store - .data_mut() - .runtime - .node - .take() - .context("The WebAssembly module didn't register any metadata") + pub fn metadata(&self) -> Result { + let meta = self.0.metadata()?; + Ok(meta) } - #[tracing::instrument(skip(self, args))] pub fn graph( - &mut self, - args: HashMap, - ) -> Result { - let ctx = GraphContext::new(args); - self.store.data_mut().runtime.graph_ctx = - Some(Arc::new(Mutex::new(ctx))); - - self.rune - .graph(&mut self.store, "") - .context("Unable to call the graph() function")??; - - let ctx = self.store.data_mut().runtime.graph_ctx.take().unwrap(); - let ctx = ctx.lock().unwrap(); - Ok(ctx.node.clone()) - } -} - -#[derive(Default)] -struct State { - runtime: RuntimeV1, - tables: runtime_v1::RuntimeV1Tables, - rune_v1_data: ProcBlockV1Data, -} - -#[derive(Default)] -struct RuntimeV1 { - node: Option, - graph_ctx: Option>>, - kernel_ctx: Option>>, -} - -#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct Metadata { - pub name: String, - pub version: String, - pub description: Option, - pub repository: Option, - pub homepage: Option, - pub tags: Vec, - pub arguments: Vec, - pub inputs: Vec, - pub outputs: Vec, -} - -#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct ArgumentMetadata { - pub name: String, - pub description: Option, - pub default_value: Option, - pub hints: Vec, -} - -#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct TensorMetadata { - pub name: String, - pub description: Option, - pub hints: Vec, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "kebab-case", tag = "type", content = "value")] -pub enum TensorHint { - DisplayAs(String), - SupportedShape { - accepted_element_types: Vec, - dimensions: Dimensions, - }, -} - -#[derive( - Debug, Copy, Clone, PartialEq, serde::Serialize, serde::Deserialize, -)] -#[serde(rename_all = "kebab-case")] -pub enum ElementType { - U8, - I8, - U16, - I16, - U32, - I32, - F32, - U64, - I64, - F64, - Utf8, -} + &self, + args: &HashMap, + ) -> Result { + let node = self.instantiate(args)?; + let constraints = self.0.node_tensor_constraints(&node)?; -impl Display for ElementType { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - ElementType::U8 => write!(f, "u8"), - ElementType::I8 => write!(f, "i8"), - ElementType::U16 => write!(f, "u16"), - ElementType::I16 => write!(f, "i16"), - ElementType::U32 => write!(f, "u32"), - ElementType::I32 => write!(f, "i32"), - ElementType::F32 => write!(f, "f32"), - ElementType::U64 => write!(f, "u64"), - ElementType::I64 => write!(f, "i64"), - ElementType::F64 => write!(f, "f64"), - ElementType::Utf8 => write!(f, "utf-8"), - } + Ok(constraints) } -} -impl From for ElementType { - fn from(e: runtime_v1::ElementType) -> Self { - match e { - runtime_v1::ElementType::U8 => ElementType::U8, - runtime_v1::ElementType::I8 => ElementType::I8, - runtime_v1::ElementType::U16 => ElementType::U16, - runtime_v1::ElementType::I16 => ElementType::I16, - runtime_v1::ElementType::U32 => ElementType::U32, - runtime_v1::ElementType::I32 => ElementType::I32, - runtime_v1::ElementType::F32 => ElementType::F32, - runtime_v1::ElementType::I64 => ElementType::I64, - runtime_v1::ElementType::U64 => ElementType::U64, - runtime_v1::ElementType::F64 => ElementType::F64, - runtime_v1::ElementType::Utf8 => ElementType::Utf8, - } - } -} - -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "kebab-case", tag = "type", content = "value")] -pub enum Dimensions { - Dynamic, - Fixed(Vec), -} + pub fn instantiate( + &self, + args: &HashMap, + ) -> Result { + let args: Vec<_> = args + .into_iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .map(|(name, value)| Argument { name, value }) + .collect(); -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "kebab-case", tag = "type", content = "value")] -pub enum Dimension { - Fixed(NonZeroUsize), - Dynamic, -} - -impl Display for Dimension { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Dimension::Fixed(fixed) => fixed.fmt(f), - Dimension::Dynamic => "*".fmt(f), - } - } -} + let instance = self.0.create_node(&args)??; -impl From> for Dimensions { - fn from(d: runtime_v1::DimensionsParam<'_>) -> Self { - match d { - runtime_v1::DimensionsParam::Dynamic => Dimensions::Dynamic, - runtime_v1::DimensionsParam::Fixed(dims) => Dimensions::Fixed( - dims.iter() - .map(|d| match NonZeroUsize::new(d.get() as usize) { - Some(d) => Dimension::Fixed(d), - None => Dimension::Dynamic, - }) - .collect(), - ), - } + Ok(instance) } } -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "kebab-case", tag = "type", content = "value")] -pub enum ArgumentHint { - NonNegativeNumber, - StringEnum(Vec), - NumberInRange { - max: String, - min: String, - }, - #[serde(with = "ArgumentTypeRepr")] - SupportedArgumentType(runtime_v1::ArgumentType), -} +#[derive(Debug, Default, Clone, WasmerEnv)] +struct HostFunctions; -#[derive(Serialize, Deserialize)] -#[serde(rename_all = "kebab-case")] -#[serde(remote = "runtime_v1::ArgumentType")] -enum ArgumentTypeRepr { - UnsignedInteger, - Integer, - Float, - String, - LongString, -} - -#[derive(Debug)] -pub struct GraphContext { - args: HashMap, - node: NodeInfo, -} +impl crate::runtime_v2::RuntimeV2 for HostFunctions { + fn abort(&mut self, msg: &str) { + #[derive(Debug, thiserror::Error)] + #[error("Abort: {_0}")] + struct Abort(String); -impl GraphContext { - pub fn new(args: HashMap) -> Self { - GraphContext { - args, - node: NodeInfo::default(), + // Safety: This will only ever be called by the WebAssembly guest + unsafe { + wasmer::raise_user_trap(Box::new(Abort(msg.to_string()))); } } -} - -#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct NodeInfo { - pub inputs: Vec, - pub outputs: Vec, -} -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct TensorInfo { - pub name: String, - pub element_type: ElementType, - pub dimensions: Dimensions, -} - -#[derive(Debug)] -pub struct KernelContext { - args: HashMap, -} - -impl runtime_v1::RuntimeV1 for RuntimeV1 { - type ArgumentHint = ArgumentHint; - type ArgumentMetadata = Mutex; - type GraphContext = Arc>; - type KernelContext = Arc>; - type Metadata = Mutex; - type Model = (); - type TensorHint = TensorHint; - type TensorMetadata = Mutex; - - fn metadata_new(&mut self, name: &str, version: &str) -> Self::Metadata { - Mutex::new(Metadata { - name: name.to_string(), - version: version.to_string(), - ..Default::default() - }) - } - - fn metadata_set_description( - &mut self, - self_: &Self::Metadata, - description: &str, - ) { - self_.lock().unwrap().description = Some(description.to_string()); - } - - fn metadata_set_repository(&mut self, self_: &Self::Metadata, url: &str) { - self_.lock().unwrap().repository = Some(url.to_string()); - } - - fn metadata_set_homepage(&mut self, self_: &Self::Metadata, url: &str) { - self_.lock().unwrap().homepage = Some(url.to_string()); - } - - fn metadata_add_tag(&mut self, self_: &Self::Metadata, tag: &str) { - self_.lock().unwrap().tags.push(tag.to_string()); - } - - fn metadata_add_argument( - &mut self, - self_: &Self::Metadata, - arg: &Self::ArgumentMetadata, - ) { - self_ - .lock() - .unwrap() - .arguments - .push(arg.lock().unwrap().clone()); - } - - fn metadata_add_input( - &mut self, - self_: &Self::Metadata, - metadata: &Self::TensorMetadata, - ) { - self_ - .lock() - .unwrap() - .inputs - .push(metadata.lock().unwrap().clone()); - } - - fn metadata_add_output( - &mut self, - self_: &Self::Metadata, - metadata: &Self::TensorMetadata, - ) { - self_ - .lock() - .unwrap() - .outputs - .push(metadata.lock().unwrap().clone()); - } - - fn argument_metadata_new(&mut self, name: &str) -> Self::ArgumentMetadata { - Mutex::new(ArgumentMetadata { - name: name.to_string(), - ..Default::default() - }) - } - - fn argument_metadata_set_description( - &mut self, - self_: &Self::ArgumentMetadata, - description: &str, - ) { - self_.lock().unwrap().description = Some(description.to_string()); - } - - fn argument_metadata_set_default_value( - &mut self, - self_: &Self::ArgumentMetadata, - default_value: &str, - ) { - self_.lock().unwrap().default_value = Some(default_value.to_string()); - } - - fn argument_metadata_add_hint( - &mut self, - self_: &Self::ArgumentMetadata, - hint: &Self::ArgumentHint, - ) { - self_.lock().unwrap().hints.push(hint.clone()); - } - - fn tensor_metadata_new(&mut self, name: &str) -> Self::TensorMetadata { - Mutex::new(TensorMetadata { - name: name.to_string(), - ..Default::default() - }) - } - - fn tensor_metadata_set_description( - &mut self, - self_: &Self::TensorMetadata, - description: &str, - ) { - self_.lock().unwrap().description = Some(description.to_string()); - } - - fn tensor_metadata_add_hint( - &mut self, - self_: &Self::TensorMetadata, - hint: &Self::TensorHint, - ) { - self_.lock().unwrap().hints.push(hint.clone()); - } - - fn interpret_as_image(&mut self) -> Self::TensorHint { - TensorHint::DisplayAs("image".to_string()) - } - - fn interpret_as_audio(&mut self) -> Self::TensorHint { - TensorHint::DisplayAs("audio".to_string()) - } - - fn supported_shapes( - &mut self, - supported_element_type: Vec, - dimensions: runtime_v1::DimensionsParam<'_>, - ) -> Self::TensorHint { - TensorHint::SupportedShape { - accepted_element_types: supported_element_type - .into_iter() - .map(ElementType::from) - .collect(), - dimensions: dimensions.into(), - } - } - - fn interpret_as_number_in_range( - &mut self, - min: &str, - max: &str, - ) -> Self::ArgumentHint { - ArgumentHint::NumberInRange { - min: min.to_string(), - max: max.to_string(), - } - } - - fn interpret_as_string_in_enum( - &mut self, - string_enum: Vec<&str>, - ) -> Self::ArgumentHint { - ArgumentHint::StringEnum( - string_enum.iter().map(|s| s.to_string()).collect(), - ) - } - - fn non_negative_number(&mut self) -> Self::ArgumentHint { - ArgumentHint::NonNegativeNumber - } - - fn supported_argument_type( - &mut self, - hint: runtime_v1::ArgumentType, - ) -> Self::ArgumentHint { - ArgumentHint::SupportedArgumentType(hint) - } - - fn register_node(&mut self, metadata: &Self::Metadata) { - self.node = Some(metadata.lock().unwrap().clone()); - } - - fn graph_context_for_node( - &mut self, - _name: &str, - ) -> Option { - self.graph_ctx.clone() - } - - fn graph_context_get_argument( - &mut self, - self_: &Self::GraphContext, - name: &str, - ) -> Option { - self_.lock().unwrap().args.get(name).cloned() - } - - fn graph_context_add_input_tensor( - &mut self, - self_: &Self::GraphContext, - name: &str, - element_type: runtime_v1::ElementType, - dimensions: runtime_v1::DimensionsParam<'_>, - ) { - self_.lock().unwrap().node.inputs.push(TensorInfo { - name: name.to_string(), - element_type: element_type.into(), - dimensions: dimensions.into(), - }) - } - - fn graph_context_add_output_tensor( - &mut self, - self_: &Self::GraphContext, - name: &str, - element_type: runtime_v1::ElementType, - dimensions: runtime_v1::DimensionsParam<'_>, - ) { - self_.lock().unwrap().node.outputs.push(TensorInfo { - name: name.to_string(), - element_type: element_type.into(), - dimensions: dimensions.into(), - }) - } - - fn kernel_context_for_node( - &mut self, - _name: &str, - ) -> Option { - self.kernel_ctx.clone() - } - - fn kernel_context_get_argument( - &mut self, - self_: &Self::KernelContext, - name: &str, - ) -> Option { - self_.lock().unwrap().args.get(name).cloned() - } - - fn kernel_context_get_input_tensor( - &mut self, - _self_: &Self::KernelContext, - _name: &str, - ) -> Option { - unimplemented!() - } - - fn kernel_context_set_output_tensor( - &mut self, - _self_: &Self::KernelContext, - _name: &str, - _tensor: runtime_v1::TensorParam<'_>, - ) { - unimplemented!() - } - - fn is_enabled(&mut self, _metadata: &Self::Metadata) -> bool { true } + #[tracing::instrument(skip_all, level = "debug")] + fn is_enabled(&mut self, _metadata: LogMetadata<'_>) -> bool { false } fn log( &mut self, - metadata: &Self::Metadata, - message: &str, - data: runtime_v1::LogValueMap<'_>, - ) { - tracing::info!(?metadata, ?data, message); - } - - fn kernel_context_get_global_input( - &mut self, - _self_: &Self::KernelContext, - _name: &str, - ) -> Option { - todo!() - } - - fn kernel_context_set_global_output( - &mut self, - _self_: &Self::KernelContext, - _name: &str, - _tensor: runtime_v1::TensorParam<'_>, + _metadata: LogMetadata<'_>, + _message: &str, + _data: LogValueMap<'_>, ) { - todo!() - } - - fn model_load( - &mut self, - _model_format: &str, - _model: &[u8], - _arguments: Vec<(&str, &str)>, - ) -> Result { - todo!() - } - - fn model_infer( - &mut self, - _self_: &Self::Model, - _inputs: Vec>, - ) -> Result, runtime_v1::ModelInferError> - { - todo!() - } - - fn model_inputs(&mut self, _self_: &Self::Model) -> Vec { - todo!() - } - - fn model_outputs( - &mut self, - _self_: &Self::Model, - ) -> Vec { - todo!() } -} -impl Display for GraphError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - GraphError::InvalidArgument(a) => a.fmt(f), - GraphError::MissingContext => { - write!(f, "The context wasn't passed in") - }, - GraphError::Other(msg) => write!(f, "{}", msg), - } + fn get_random(&mut self, buffer: &mut [u8]) { + rand::thread_rng().fill(buffer); } } - -impl std::error::Error for GraphError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - GraphError::InvalidArgument(a) => a.source(), - GraphError::MissingContext | GraphError::Other(_) => None, - } - } -} - -impl Display for KernelError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - KernelError::InvalidArgument(a) => a.fmt(f), - KernelError::InvalidInput(i) => i.fmt(f), - KernelError::MissingContext => { - write!(f, "The context wasn't passed in") - }, - KernelError::Other(msg) => write!(f, "{}", msg), - } - } -} - -impl std::error::Error for KernelError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - KernelError::InvalidArgument(a) => a.source(), - KernelError::InvalidInput(i) => i.source(), - KernelError::MissingContext | KernelError::Other(_) => None, - } - } -} - -impl Display for InvalidInput { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "The \"{}\" input tensor was invalid", self.name) - } -} - -impl std::error::Error for InvalidInput { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - Some(&self.reason) - } -} - -impl Display for InvalidArgument { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "The \"{}\" argument was invalid", self.name) - } -} - -impl std::error::Error for InvalidArgument { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - Some(&self.reason) - } -} - -impl Display for BadArgumentReason { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - BadArgumentReason::NotFound => { - write!(f, "The argument wasn't provided") - }, - BadArgumentReason::InvalidValue(reason) => { - write!(f, "{}", reason) - }, - BadArgumentReason::Other(msg) => write!(f, "{}", msg), - } - } -} - -impl std::error::Error for BadArgumentReason {} - -impl Display for BadInputReason { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - BadInputReason::NotFound => { - write!(f, "The input tensor wasn't provided") - }, - BadInputReason::InvalidValue(reason) => { - write!(f, "{}", reason) - }, - BadInputReason::UnsupportedShape => { - write!(f, "Unsupported shape") - }, - BadInputReason::Other(msg) => write!(f, "{}", msg), - } - } -} - -impl std::error::Error for BadInputReason {}