diff --git a/.circleci/Dockerfile b/.circleci/Dockerfile index cd0406c9f..461e2d82f 100644 --- a/.circleci/Dockerfile +++ b/.circleci/Dockerfile @@ -29,8 +29,8 @@ RUN echo 'export PATH="/usr/local/bin:$PATH"' >> $HOME/.bashrc # Install Rust RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y RUN echo 'export PATH="$HOME/.cargo/bin:$PATH"' >> /home/circleci/.bashrc -RUN ~/.cargo/bin/rustup install 1.81 -RUN ~/.cargo/bin/rustup default 1.81 +RUN ~/.cargo/bin/rustup install 1.84 +RUN ~/.cargo/bin/rustup default 1.84 RUN ~/.cargo/bin/rustup target add wasm32-unknown-unknown # Install Deno diff --git a/.circleci/config.yml b/.circleci/config.yml index dca83de7e..440934596 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,7 +8,7 @@ orbs: jobs: build-and-test: docker: - - image: coasys/ad4m-ci-linux:latest@sha256:2dd8206db5ee73ae58f48e2c2db595320ab88b94f5e0c0ddc99698191a03114f + - image: coasys/ad4m-ci-linux:latest@sha256:f6499125645a0df59bacf1790d9c68ab3e4872d0ca435bfc96d78bd5d857e114 resource_class: xlarge steps: - checkout diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 2fcf45100..79b14e3f5 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -61,7 +61,7 @@ jobs: uses: actions-rs/toolchain@v1 with: override: true - toolchain: 1.81.0 + toolchain: 1.84.0 - run: rustup target add wasm32-unknown-unknown - name: Install Rust targets for macOS diff --git a/.github/workflows/publish_staging.yml b/.github/workflows/publish_staging.yml index f6d41d62c..d55c89670 100644 --- a/.github/workflows/publish_staging.yml +++ b/.github/workflows/publish_staging.yml @@ -61,7 +61,7 @@ jobs: uses: actions-rs/toolchain@v1 with: override: true - toolchain: 1.81.0 + toolchain: 1.84.0 - run: rustup target add wasm32-unknown-unknown - name: Install Rust targets for macOS diff --git a/CHANGELOG b/CHANGELOG index d3a463c1a..e1a6bee3f 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -23,6 +23,7 @@ This project _loosely_ adheres to [Semantic Versioning](https://semver.org/spec/ - Fix for syncing broken after some time: Prevent p-diff-sync diff entries from exceeding HC entry size limit of 4MB (which would happen at some point through snapshot creation) [PR#553](https://github.com/coasys/ad4m/pull/553) - Fix for crash when removing .ad4m directory after using multiple agent feature [PR#556](https://github.com/coasys/ad4m/pull/556) - Fix error after spawning AI task [PR#559](https://github.com/coasys/ad4m/pull/559) +- Fix some problems with perspective.removeLinks() with a proper implementation [PR#563](https://github.com/coasys/ad4m/pull/563) ### Added - Prolog predicates needed in new Flux mention notification trigger: @@ -45,7 +46,9 @@ This project _loosely_ adheres to [Semantic Versioning](https://semver.org/spec/ - Make `SubjectEntity.#perspective` protected to enable subclasses to implement complex fields and methods [PR#557](https://github.com/coasys/ad4m/pull/557) - Add AI client to PerspectiveProxy to enable SubjectEntity sub-classes (subject classes) to use AI processes without having to rely on ad4m-connect or similar to access the AI client. [PR#558](https://github.com/coasys/ad4m/pull/558) - SubjectEntity.query() with `where: { propertyName: "value" }` and `where: { condition: 'triple(Base, _, "...")'} [PR#560](https://github.com/coasys/ad4m/pull/560) - +- Update Kalosm and candle to latest versions and add very recent open-source models like DeepSeek and Qwen to the model picker. 
Also add AI task delete button to launcher UI. Set Qwen 2.5 coder instruct 7b as default for local LLM [PR#561](https://github.com/coasys/ad4m/pull/561) +- Models can now be added directly from Huggingface without changing the code, by just providing repo and filename in the launcher [PR#562](https://github.com/coasys/ad4m/pull/562) +- ### Changed - Partially migrated the Runtime service to Rust. (DM language installation for agents is pending.) [PR#466](https://github.com/coasys/ad4m/pull/466) - Improved performance of SDNA / SubjectClass functions by moving code from client into executor and saving a lot of client <-> executor roundtrips [PR#480](https://github.com/coasys/ad4m/pull/480) @@ -68,6 +71,7 @@ This project _loosely_ adheres to [Semantic Versioning](https://semver.org/spec/ - Added ability to handle multiple agents in launcher. [PR#459](https://github.com/coasys/ad4m/pull/459) - Added a way to show & add new `AgentInfo` in launcher. [PR#463](https://github.com/coasys/ad4m/pull/463) - `ad4m-executor` binary prints capability request challange to stdout to enable app hand-shake [PR#471](https://github.com/coasys/ad4m/pull/471) + - Add ability to select Whisper model size [PR#564](https://github.com/coasys/ad4m/pull/564) ### Changed - Much improved ADAM Launcher setup flow [PR#440](https://github.com/coasys/ad4m/pull/440) and [PR#444](https://github.com/coasys/ad4m/pull/444): diff --git a/Cargo.lock b/Cargo.lock index 11e7e080e..02e2fac61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,7 +20,7 @@ checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d" [[package]] name = "ad4m" -version = "0.10.1-rc7" +version = "0.10.1-rc8" dependencies = [ "ad4m-client", "ad4m-executor", @@ -46,7 +46,7 @@ dependencies = [ [[package]] name = "ad4m-client" -version = "0.10.1-rc7" +version = "0.10.1-rc8" dependencies = [ "anyhow", "async-tungstenite", @@ -70,9 +70,10 @@ dependencies = [ [[package]] name = "ad4m-executor" -version = "0.10.1-rc7" +version = "0.10.1-rc8" dependencies = [ "ad4m-client", + "anyhow", "argon2", "base64 0.21.7", "candle-core", @@ -120,7 +121,7 @@ dependencies = [ "regex", "reqwest 0.11.20", "rocket", - "rodio", + "rodio 0.17.3", "rusqlite", "rust-embed", "rustls 0.23.12", @@ -144,7 +145,7 @@ dependencies = [ [[package]] name = "ad4m-launcher" -version = "0.10.1-rc7" +version = "0.10.1-rc8" dependencies = [ "ad4m-client", "ad4m-executor", @@ -453,9 +454,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" dependencies = [ "backtrace", ] @@ -1757,9 +1758,9 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.7.2" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e1a39b963e261c58017edf2007e5b63425ad21538aaaf51fe23d1da41703701" +checksum = "855dfedff437d2681d68e1f34ae559d88b0dd84aa5a6b63f2c8e75ebdd875bbf" dependencies = [ "accelerate-src", "byteorder", @@ -1778,24 +1779,27 @@ dependencies = [ "rayon", "safetensors", "thiserror 1.0.63", + "ug", + "ug-cuda", + "ug-metal", "yoke", "zip 1.1.4", ] [[package]] name = "candle-kernels" -version = "0.7.2" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "539cbfbf2d1d68a6ed97115e579c77c98f8ed0cfe7edbc6d7d30d2ac0c9e3d50" +checksum = 
"53343628fa470b7075c28c589b98735b4220b464e37ddbb8e117040e199f4787" dependencies = [ "bindgen_cuda", ] [[package]] name = "candle-metal-kernels" -version = "0.7.2" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "166a92826d615d98b205e346e52128fa0439f2ab3302587403fdc558b4219e19" +checksum = "50fa64274a009a5d95c542b10bf3a4ea809bd394654c6ae99233bcc35b3a33ef" dependencies = [ "metal 0.27.0", "once_cell", @@ -1805,9 +1809,9 @@ dependencies = [ [[package]] name = "candle-nn" -version = "0.7.2" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "898f8d21b8bdf559a1c8635e2db8386b2134015cd3003c18c1a30a22a67daec6" +checksum = "ddd3c6b2ee0dfd64af12ae5b07e4b7c517898981cdaeffcb10b71d7dd5c8f359" dependencies = [ "accelerate-src", "candle-core", @@ -1823,9 +1827,9 @@ dependencies = [ [[package]] name = "candle-transformers" -version = "0.7.2" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06b8a130a8ac1d1e20696d89f7a52948902e037ad0eec085fceb77021007cfee" +checksum = "4270cc692c4a3df2051c2e8c3c4da3a189746af7ca3a547b99ecd335582b92e1" dependencies = [ "accelerate-src", "byteorder", @@ -3112,7 +3116,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b28bfe653d79bd16c77f659305b195b82bb5ce0c0eb2a4846b82ddbd77586813" dependencies = [ "bitflags 2.6.0", - "libloading 0.8.5", + "libloading 0.7.4", "winapi 0.3.9", ] @@ -4595,7 +4599,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "330c60081dcc4c72131f8eb70510f1ac07223e5d4163db481a04a0befcffa412" dependencies = [ - "libloading 0.8.5", + "libloading 0.7.4", ] [[package]] @@ -4803,6 +4807,27 @@ dependencies = [ "memmap2 0.5.10", ] +[[package]] +name = "dynosaur" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92fac44672fabad44990176319b9e94393f3a38b960b5ca2af6cd90f5ecd1497" +dependencies = [ + "dynosaur_derive", + "trait-variant", +] + +[[package]] +name = "dynosaur_derive" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16c187d1e575ef546d24f0fcd7701cc04abfe6b5e7e2758aabc450b99e835ac3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "ecb" version = "0.1.2" @@ -5672,9 +5697,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -5682,9 +5707,9 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-enum" @@ -5755,9 +5780,9 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" @@ -6382,7 +6407,7 @@ version = "0.1.0" source = 
"git+https://github.com/s3bk/glyphmatcher#7bd5d40aaa8893fa4f2e07c758fa34127e176c8c" dependencies = [ "font", - "istring 0.4.1", + "istring 0.4.2", "itertools 0.13.0", "pathfinder_content", "pathfinder_geometry", @@ -8691,8 +8716,8 @@ dependencies = [ [[package]] name = "istring" -version = "0.4.1" -source = "git+https://github.com/s3bk/istring#ac7b821c94cf5a1295b4712e60b65419e4c05352" +version = "0.4.2" +source = "git+https://github.com/s3bk/istring#bb736be4b5afda8a273b09ad7b27884c79ff09a5" dependencies = [ "serde", ] @@ -8911,9 +8936,8 @@ dependencies = [ [[package]] name = "kalosm" version = "0.3.2" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ - "anyhow", "arroy", "async-trait", "candle-core", @@ -8933,7 +8957,7 @@ dependencies = [ "once_cell", "rand 0.8.5", "serde", - "thiserror 2.0.3", + "thiserror 2.0.11", "tokio", "tracing", ] @@ -8941,9 +8965,8 @@ dependencies = [ [[package]] name = "kalosm-common" version = "0.3.3" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ - "anyhow", "candle-core", "candle-nn", "dirs 5.0.1", @@ -8953,6 +8976,7 @@ dependencies = [ "metal 0.29.0", "once_cell", "reqwest 0.11.20", + "thiserror 2.0.11", "tokio", "tracing", ] @@ -8960,7 +8984,7 @@ dependencies = [ [[package]] name = "kalosm-language" version = "0.3.3" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ "anyhow", "arroy", @@ -8984,7 +9008,6 @@ dependencies = [ "kalosm-streams", "llm-samplers", "log", - "meval", "once_cell", "pdf", "pdf_text", @@ -8994,9 +9017,7 @@ dependencies = [ "readability", "reqwest 0.11.20", "roaring", - "rphi", "rss", - "rustc-hash 1.1.0", "scraper", "serde", "serde_json", @@ -9004,8 +9025,7 @@ dependencies = [ "srx", "tempfile", "texting_robots", - "thiserror 2.0.3", - "tokenizers", + "thiserror 2.0.11", "tokio", "tokio-util", "tracing", @@ -9016,37 +9036,31 @@ dependencies = [ [[package]] name = "kalosm-language-model" version = "0.3.3" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ "anyhow", - "async-trait", + "async-lock 3.4.0", "candle-core", + "dynosaur", + "futures-channel", "futures-util", "kalosm-common", "kalosm-sample", - "kalosm-streams", "llm-samplers", "log", "lru", "once_cell", - "postcard", "rand 0.8.5", - "rayon", - "safetensors", "serde", - "thiserror 1.0.63", - "tokenizers", - "tokio", + "thiserror 2.0.11", "tracing", ] [[package]] name = "kalosm-llama" version = "0.3.3" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ - "anyhow", - "async-trait", "candle-core", "candle-nn", "candle-transformers", @@ -9054,10 +9068,14 @@ dependencies = [ "kalosm-common", "kalosm-language-model", "kalosm-sample", - "kalosm-streams", 
"llm-samplers", + "minijinja", + "minijinja-contrib", "once_cell", "rand 0.8.5", + "rayon", + "safetensors", + "thiserror 2.0.11", "tokenizers", "tokio", "tracing", @@ -9066,9 +9084,8 @@ dependencies = [ [[package]] name = "kalosm-ocr" version = "0.3.2" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ - "anyhow", "candle-core", "candle-nn", "candle-transformers", @@ -9077,6 +9094,7 @@ dependencies = [ "kalosm-common", "serde", "serde_json", + "thiserror 2.0.11", "tokenizers", "tokio", "tracing", @@ -9085,7 +9103,7 @@ dependencies = [ [[package]] name = "kalosm-parse-macro" version = "0.3.2" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ "proc-macro2", "quote", @@ -9095,9 +9113,8 @@ dependencies = [ [[package]] name = "kalosm-sample" version = "0.3.2" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ - "anyhow", "kalosm-parse-macro", "regex-automata 0.4.7", ] @@ -9105,10 +9122,8 @@ dependencies = [ [[package]] name = "kalosm-sound" version = "0.3.4" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ - "anyhow", - "async-trait", "byteorder", "candle-core", "candle-nn", @@ -9125,10 +9140,9 @@ dependencies = [ "ort", "ort-sys", "rand 0.8.5", - "rodio", + "rodio 0.20.1", "rwhisper", "serde_json", - "tokenizers", "tokio", "tracing", "voice_activity_detector", @@ -9137,7 +9151,7 @@ dependencies = [ [[package]] name = "kalosm-streams" version = "0.3.2" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ "futures-util", "image 0.24.9", @@ -9148,7 +9162,7 @@ dependencies = [ [[package]] name = "kalosm-vision" version = "0.3.2" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ "image 0.24.9", "kalosm-ocr", @@ -10331,6 +10345,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15" +[[package]] +name = "memo-map" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" + [[package]] name = "memoffset" version = "0.6.5" @@ -10403,16 +10423,6 @@ dependencies = [ "paste", ] -[[package]] -name = "meval" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f79496a5651c8d57cd033c5add8ca7ee4e3d5f7587a4777484640d9cb60392d9" -dependencies = [ - "fnv", - "nom 1.2.4", -] - [[package]] name = "mime" version = "0.3.17" @@ -10429,6 +10439,28 @@ 
dependencies = [ "unicase", ] +[[package]] +name = "minijinja" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff7b8df5e85e30b87c2b0b3f58ba3a87b68e133738bf512a7713769326dbca9" +dependencies = [ + "memo-map", + "self_cell", + "serde", + "serde_json", +] + +[[package]] +name = "minijinja-contrib" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ac3e47a9006ed0500425a092c9f8b2e56d10f8aeec8ce870c5e8a7c6ef2d7c3" +dependencies = [ + "minijinja", + "serde", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -10990,12 +11022,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0f889fb66f7acdf83442c35775764b51fed3c606ab9cee51500dbde2cf528ca" -[[package]] -name = "nom" -version = "1.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5b8c256fd9471521bcb84c3cdba98921497f1a331cbc15b8030fc63b82050ce" - [[package]] name = "nom" version = "5.1.3" @@ -11252,7 +11278,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" dependencies = [ - "proc-macro-crate 3.1.0", + "proc-macro-crate 2.0.0", "proc-macro2", "quote", "syn 2.0.87", @@ -12179,7 +12205,7 @@ dependencies = [ [[package]] name = "pdf" version = "0.9.0" -source = "git+https://github.com/pdf-rs/pdf#b69e0753c9d6adff567b911494309798e0954e6a" +source = "git+https://github.com/pdf-rs/pdf#2eda40101944f78211b9725107e761203d6df375" dependencies = [ "aes", "bitflags 2.6.0", @@ -12207,7 +12233,7 @@ dependencies = [ [[package]] name = "pdf_derive" version = "0.2.0" -source = "git+https://github.com/pdf-rs/pdf#b69e0753c9d6adff567b911494309798e0954e6a" +source = "git+https://github.com/pdf-rs/pdf#2eda40101944f78211b9725107e761203d6df375" dependencies = [ "proc-macro2", "quote", @@ -12234,7 +12260,7 @@ dependencies = [ [[package]] name = "pdf_render" version = "0.1.0" -source = "git+https://github.com/pdf-rs/pdf_render#24d48551c546ac285e8582a27b0ab2753eaf6a45" +source = "git+https://github.com/pdf-rs/pdf_render#9a31988b091495c65a54236fd6cdce8f1fa2afd0" dependencies = [ "custom_debug_derive", "font", @@ -12242,7 +12268,7 @@ dependencies = [ "glyphmatcher", "image 0.24.9", "instant", - "istring 0.4.1", + "istring 0.4.2", "itertools 0.13.0", "log", "once_cell", @@ -12259,7 +12285,7 @@ dependencies = [ [[package]] name = "pdf_text" version = "0.1.0" -source = "git+https://github.com/pdf-rs/pdf_text#03d4789b4d23b24a991a18d31a352c8400237afc" +source = "git+https://github.com/pdf-rs/pdf_text#fbe582d29ad4af5a4f84f6098352fcc033be6127" dependencies = [ "font", "itertools 0.13.0", @@ -13234,7 +13260,7 @@ dependencies = [ "rustc-hash 2.0.0", "rustls 0.23.12", "socket2 0.5.7", - "thiserror 2.0.3", + "thiserror 2.0.11", "tokio", "tracing", ] @@ -13273,7 +13299,7 @@ dependencies = [ "rustls 0.23.12", "rustls-pki-types", "slab", - "thiserror 2.0.3", + "thiserror 2.0.11", "tinyvec", "tracing", "web-time", @@ -13625,10 +13651,8 @@ dependencies = [ [[package]] name = "rbert" version = "0.3.3" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ - "anyhow", - "async-trait", "candle-core", "candle-nn", "candle-transformers", @@ -13637,6 +13661,7 @@ dependencies = [ "metal 0.27.0", "serde", 
"serde_json", + "thiserror 2.0.11", "tokenizers", "tokio", "tracing", @@ -14202,6 +14227,19 @@ dependencies = [ "symphonia", ] +[[package]] +name = "rodio" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ceb6607dd738c99bc8cb28eff249b7cd5c8ec88b9db96c0608c1480d140fb1" +dependencies = [ + "claxon", + "cpal", + "hound", + "lewton", + "symphonia", +] + [[package]] name = "ron" version = "0.8.1" @@ -14231,28 +14269,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "rphi" -version = "0.3.2" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" -dependencies = [ - "anyhow", - "async-trait", - "candle-core", - "candle-nn", - "candle-transformers", - "kalosm-common", - "kalosm-language-model", - "kalosm-sample", - "kalosm-streams", - "llm-samplers", - "rand 0.8.5", - "serde_json", - "tokenizers", - "tokio", - "tracing", -] - [[package]] name = "rsa" version = "0.9.6" @@ -14664,11 +14680,9 @@ dependencies = [ [[package]] name = "rwhisper" version = "0.3.5" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ "accelerate-src", - "anyhow", - "async-trait", "byteorder", "candle-core", "candle-nn", @@ -14681,8 +14695,9 @@ dependencies = [ "kalosm-language-model", "kalosm-streams", "rand 0.8.5", - "rodio", + "rodio 0.20.1", "serde_json", + "thiserror 2.0.11", "tokenizers", "tokio", "tracing", @@ -14691,10 +14706,8 @@ dependencies = [ [[package]] name = "rwuerstchen" version = "0.3.2" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ - "anyhow", - "async-trait", "candle-core", "candle-nn", "candle-transformers", @@ -14732,9 +14745,9 @@ dependencies = [ [[package]] name = "safetensors" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7725d4d98fa515472f43a6e2bbf956c48e06b89bb50593a040e5945160214450" +checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6" dependencies = [ "serde", "serde_json", @@ -15039,14 +15052,14 @@ dependencies = [ [[package]] name = "segment-anything-rs" version = "0.3.2" -source = "git+https://github.com/coasys/floneum.git?branch=coasys-1#0b704ea2f28df596e9a22ef6c18c66748b0f7262" +source = "git+https://github.com/coasys/floneum.git?branch=coasys-2#1c2603fcc3bd89e5e93c47f8147e0168825b8378" dependencies = [ - "anyhow", "candle-core", "candle-nn", "candle-transformers", "hf-hub", "image 0.24.9", + "thiserror 2.0.11", "tracing", ] @@ -16904,7 +16917,7 @@ dependencies = [ "tauri-runtime", "tauri-runtime-wry", "tauri-utils", - "thiserror 2.0.3", + "thiserror 2.0.11", "tokio", "tray-icon", "url", @@ -16957,7 +16970,7 @@ dependencies = [ "sha2 0.10.8", "syn 2.0.87", "tauri-utils", - "thiserror 2.0.3", + "thiserror 2.0.11", "time 0.3.36", "url", "uuid 1.10.0", @@ -17213,7 +17226,7 @@ dependencies = [ "serde", "serde_json", "tauri-utils", - "thiserror 2.0.3", + "thiserror 2.0.11", "url", "windows 0.58.0", ] @@ -17273,7 +17286,7 @@ dependencies = [ "serde_json", "serde_with", "swift-rs", - "thiserror 2.0.3", + "thiserror 2.0.11", "toml 0.8.19", "url", "urlpattern 0.3.0", @@ -17477,11 +17490,11 @@ dependencies = [ 
[[package]] name = "thiserror" -version = "2.0.3" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" dependencies = [ - "thiserror-impl 2.0.3", + "thiserror-impl 2.0.11", ] [[package]] @@ -17497,9 +17510,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.3" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", @@ -17636,9 +17649,9 @@ dependencies = [ [[package]] name = "tokenizers" -version = "0.19.1" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd" +checksum = "9ecededfed68a69bc657e486510089e255e53c3d38cc7d4d59c8742668ca2cae" dependencies = [ "aho-corasick", "derive_builder 0.20.1", @@ -18078,6 +18091,17 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "trait-variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "transpose" version = "0.2.3" @@ -18271,7 +18295,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" dependencies = [ "cfg-if 1.0.0", - "rand 0.8.5", + "rand 0.6.5", "static_assertions", ] @@ -18493,6 +18517,48 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "ug" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4eef2ebfc18c67a6dbcacd9d8a4d85e0568cc58c82515552382312c2730ea13" +dependencies = [ + "half 2.4.1", + "num", + "serde", + "serde_json", + "thiserror 1.0.63", +] + +[[package]] +name = "ug-cuda" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c4dcab280ad0ef3957e153a82dcad608c954d02cf253b695322f502d1f8902e" +dependencies = [ + "cudarc", + "half 2.4.1", + "serde", + "serde_json", + "thiserror 1.0.63", + "ug", +] + +[[package]] +name = "ug-metal" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e4ed1df2c20a1a138f993041f650cc84ff27aaefb4342b7f986e77d00e80799" +dependencies = [ + "half 2.4.1", + "metal 0.29.0", + "objc", + "serde", + "serde_json", + "thiserror 1.0.63", + "ug", +] + [[package]] name = "unarray" version = "0.1.4" @@ -19595,7 +19661,7 @@ dependencies = [ "log", "naga", "once_cell", - "parking_lot 0.12.3", + "parking_lot 0.11.2", "profiling", "raw-window-handle 0.6.2", "ron", @@ -19630,14 +19696,14 @@ dependencies = [ "js-sys", "khronos-egl", "libc", - "libloading 0.8.5", + "libloading 0.7.4", "log", "metal 0.28.0", "naga", "ndk-sys 0.5.0+25.2.9519653", "objc", "once_cell", - "parking_lot 0.12.3", + "parking_lot 0.11.2", "profiling", "range-alloc", "raw-window-handle 0.6.2", diff --git a/ad4m-hooks/helpers/package.json b/ad4m-hooks/helpers/package.json index bb80fefa0..2c860c94a 100644 --- a/ad4m-hooks/helpers/package.json +++ b/ad4m-hooks/helpers/package.json @@ -18,5 +18,5 @@ "@coasys/ad4m-connect": "*", "uuid": "*" }, - "version": "0.10.1-rc7" + 
"version": "0.10.1-rc8" } diff --git a/ad4m-hooks/react/package.json b/ad4m-hooks/react/package.json index 178293ee2..55f034374 100644 --- a/ad4m-hooks/react/package.json +++ b/ad4m-hooks/react/package.json @@ -24,5 +24,5 @@ "preact": "*", "react": "*" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/ad4m-hooks/vue/package.json b/ad4m-hooks/vue/package.json index 281d78d08..0cd8ebdb3 100644 --- a/ad4m-hooks/vue/package.json +++ b/ad4m-hooks/vue/package.json @@ -19,5 +19,5 @@ "peerDependencies": { "vue": "^3.2.47" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/bootstrap-languages/agent-language/package.json b/bootstrap-languages/agent-language/package.json index 3e58f5158..c02977ff0 100644 --- a/bootstrap-languages/agent-language/package.json +++ b/bootstrap-languages/agent-language/package.json @@ -44,5 +44,5 @@ "md5": "^2.3.0", "postcss": "^8.2.1" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/bootstrap-languages/direct-message-language/package.json b/bootstrap-languages/direct-message-language/package.json index 8f729eea9..970a903f3 100644 --- a/bootstrap-languages/direct-message-language/package.json +++ b/bootstrap-languages/direct-message-language/package.json @@ -35,5 +35,5 @@ "dependencies": { "@types/node": "^18.0.0" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/bootstrap-languages/neighbourhood-language/package.json b/bootstrap-languages/neighbourhood-language/package.json index a1fee873b..540da8ccb 100644 --- a/bootstrap-languages/neighbourhood-language/package.json +++ b/bootstrap-languages/neighbourhood-language/package.json @@ -8,5 +8,5 @@ }, "author": "joshuadparkin@gmail.com", "license": "ISC", - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/bootstrap-languages/p-diff-sync/package.json b/bootstrap-languages/p-diff-sync/package.json index 8722ab91e..a6a8ff1c0 100644 --- a/bootstrap-languages/p-diff-sync/package.json +++ b/bootstrap-languages/p-diff-sync/package.json @@ -38,5 +38,5 @@ "devDependencies": { "run-script-os": "^1.1.6" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/bootstrap-languages/perspective-language/package.json b/bootstrap-languages/perspective-language/package.json index 898dad30e..83377ceef 100644 --- a/bootstrap-languages/perspective-language/package.json +++ b/bootstrap-languages/perspective-language/package.json @@ -30,5 +30,5 @@ "typescript": "^4.5.5", "uint8arrays": "^3.0.0" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/cli/Cargo.toml b/cli/Cargo.toml index c128da5e7..dcafd4fb6 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "ad4m" -version = "0.10.1-rc7" +version = "0.10.1-rc8" edition = "2021" authors = ["Nicolas Luck "] @@ -21,11 +21,14 @@ path = "src/ad4m.rs" name = "ad4m-executor" path = "src/ad4m_executor.rs" -[dependencies] - +[features] +# Pass metal and cuda features through to ad4m-executor +metal = ["ad4m-executor/metal"] +cuda = ["ad4m-executor/cuda"] -ad4m-client = { path = "../rust-client", version="0.10.1-rc7" } -ad4m-executor = { path = "../rust-executor", version="0.10.1-rc7" } +[dependencies] +ad4m-client = { path = "../rust-client", version="0.10.1-rc8" } +ad4m-executor = { path = "../rust-executor", version="0.10.1-rc8" } anyhow = "1.0.65" clap = { version = "4.0.8", features = ["derive"] } futures = "0.3" diff --git a/connect/package.json b/connect/package.json index 4de9f805c..d79d91dd5 100644 --- a/connect/package.json +++ 
b/connect/package.json @@ -47,7 +47,7 @@ }, "devDependencies": { "@apollo/client": "3.7.10", - "@coasys/ad4m": "workspace:0.10.1-rc7", + "@coasys/ad4m": "workspace:0.10.1-rc8", "@types/node": "^16.11.11", "esbuild": "^0.15.5", "esbuild-plugin-lit": "^0.0.10", @@ -66,5 +66,5 @@ "esbuild-plugin-replace": "^1.4.0", "lit": "^2.3.1" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/core/package.json b/core/package.json index 056919634..857d5c1a6 100644 --- a/core/package.json +++ b/core/package.json @@ -69,5 +69,5 @@ "graphql@15.7.2": "patches/graphql@15.7.2.patch" } }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/core/src/Ad4mClient.test.ts b/core/src/Ad4mClient.test.ts index 1d88db8c5..3fae10a41 100644 --- a/core/src/Ad4mClient.test.ts +++ b/core/src/Ad4mClient.test.ts @@ -1158,9 +1158,26 @@ describe('Ad4mClient', () => { expect(models).toBeDefined(); expect(Array.isArray(models)).toBe(true); if (models.length > 0) { - expect(models[0]).toHaveProperty('name'); - expect(models[0]).toHaveProperty('api'); - expect(models[0]).toHaveProperty('local'); + const model = models[0]; + expect(model).toHaveProperty('name'); + expect(model).toHaveProperty('id'); + expect(model).toHaveProperty('modelType'); + if (model.api) { + expect(model.api).toHaveProperty('baseUrl'); + expect(model.api).toHaveProperty('apiKey'); + expect(model.api).toHaveProperty('model'); + expect(model.api).toHaveProperty('apiType'); + } + if (model.local) { + expect(model.local).toHaveProperty('fileName'); + if (model.local.tokenizerSource) { + expect(model.local.tokenizerSource).toHaveProperty('repo'); + expect(model.local.tokenizerSource).toHaveProperty('revision'); + expect(model.local.tokenizerSource).toHaveProperty('fileName'); + } + expect(model.local).toHaveProperty('huggingfaceRepo'); + expect(model.local).toHaveProperty('revision'); + } } }) @@ -1175,8 +1192,13 @@ describe('Ad4mClient', () => { }, local: { fileName: "new-test-model.bin", - tokenizerSource: "new-test-tokenizer", - modelParameters: "{}" + tokenizerSource: { + repo: "test-repo", + revision: "main", + fileName: "tokenizer.json" + }, + huggingfaceRepo: "test-repo", + revision: "main" }, modelType: "LLM" }; @@ -1196,8 +1218,13 @@ describe('Ad4mClient', () => { }, local: { fileName: "updated-test-model.bin", - tokenizerSource: "updated-test-tokenizer", - modelParameters: "{}" + tokenizerSource: { + repo: "test-repo", + revision: "main", + fileName: "tokenizer.json" + }, + huggingfaceRepo: "test-repo", + revision: "main" }, modelType: "LLM" }; @@ -1222,18 +1249,23 @@ describe('Ad4mClient', () => { const setResult = await ad4mClient.ai.setDefaultModel(modelType, modelName); expect(setResult).toBe(true); - const model = await ad4mClient.ai.getDefaultModel(modelType); - expect(model).toBeDefined(); - expect(model.name).toBe("Default Test Model"); - expect(model.api).toBeDefined(); - expect(model.api.baseUrl).toBe("https://api.example.com"); - expect(model.api.apiKey).toBe("test-api-key"); - expect(model.api.apiType).toBe("OpenAi"); - expect(model.local).toBeDefined(); - expect(model.local.fileName).toBe("test-model.bin"); - expect(model.local.tokenizerSource).toBe("test-tokenizer"); - expect(model.local.modelParameters).toBe("{}"); - expect(model.modelType).toBe(modelType); + const defaultModel = await ad4mClient.ai.getDefaultModel(modelType); + expect(defaultModel).toBeDefined(); + expect(defaultModel.name).toBe("Default Test Model"); + expect(defaultModel.api).toBeDefined(); + 
expect(defaultModel.api.baseUrl).toBe("https://api.example.com"); + expect(defaultModel.api.apiKey).toBe("test-api-key"); + expect(defaultModel.api.apiType).toBe("OpenAi"); + expect(defaultModel.local).toBeDefined(); + expect(defaultModel.local.fileName).toBe("test-model.bin"); + if (defaultModel.local.tokenizerSource) { + expect(defaultModel.local.tokenizerSource.repo).toBe("test-repo"); + expect(defaultModel.local.tokenizerSource.revision).toBe("main"); + expect(defaultModel.local.tokenizerSource.fileName).toBe("tokenizer.json"); + } + expect(defaultModel.local.huggingfaceRepo).toBe("test-repo"); + expect(defaultModel.local.revision).toBe("main"); + expect(defaultModel.modelType).toBe(modelType); }) it('embed()', async () => { diff --git a/core/src/ai/AIClient.ts b/core/src/ai/AIClient.ts index 87e844638..3f97a38a1 100644 --- a/core/src/ai/AIClient.ts +++ b/core/src/ai/AIClient.ts @@ -28,8 +28,13 @@ export class AIClient { } local { fileName - tokenizerSource - modelParameters + tokenizerSource { + repo + revision + fileName + } + huggingfaceRepo + revision } modelType } @@ -102,8 +107,13 @@ export class AIClient { } local { fileName - tokenizerSource - modelParameters + tokenizerSource { + repo + revision + fileName + } + huggingfaceRepo + revision } modelType } diff --git a/core/src/ai/AIResolver.ts b/core/src/ai/AIResolver.ts index ee3c579b9..ad1af964f 100644 --- a/core/src/ai/AIResolver.ts +++ b/core/src/ai/AIResolver.ts @@ -24,15 +24,30 @@ export class ModelApi { } @ObjectType() -export class LocalModel { +export class TokenizerSource { @Field() - fileName: string; + repo: string; @Field() - tokenizerSource: string; + revision: string; @Field() - modelParameters: string; + fileName: string; +} + +@ObjectType() +export class LocalModel { + @Field() + fileName: string; + + @Field({ nullable: true }) + tokenizerSource?: TokenizerSource; + + @Field({ nullable: true }) + huggingfaceRepo?: string; + + @Field({ nullable: true }) + revision?: string; } export type ModelType = "LLM" | "EMBEDDING" | "TRANSCRIPTION"; @@ -71,15 +86,30 @@ export class ModelApiInput { } @InputType() -export class LocalModelInput { +export class TokenizerSourceInput { @Field() - fileName: string; + repo: string; + + @Field() + revision: string; @Field() - tokenizerSource: string; + fileName: string; +} +@InputType() +export class LocalModelInput { @Field() - modelParameters: string; + fileName: string; + + @Field({ nullable: true }) + tokenizerSource?: TokenizerSourceInput; + + @Field({ nullable: true }) + huggingfaceRepo?: string; + + @Field({ nullable: true }) + revision?: string; } @InputType() @@ -113,8 +143,13 @@ export default class AIResolver { }, local: { fileName: "test-model.bin", - tokenizerSource: "test-tokenizer", - modelParameters: "{}" + tokenizerSource: { + repo: "test-repo", + revision: "main", + fileName: "tokenizer.json" + }, + huggingfaceRepo: "test-repo", + revision: "main" }, modelType: "LLM" } @@ -165,8 +200,13 @@ export default class AIResolver { }, local: { fileName: "test-model.bin", - tokenizerSource: "test-tokenizer", - modelParameters: "{}" + tokenizerSource: { + repo: "test-repo", + revision: "main", + fileName: "tokenizer.json" + }, + huggingfaceRepo: "test-repo", + revision: "main" }, modelType: modelType } diff --git a/docs/package.json b/docs/package.json index 1d71f3804..e0e38af49 100644 --- a/docs/package.json +++ b/docs/package.json @@ -22,5 +22,5 @@ "typedoc-plugin-markdown": "^3.15.2", "typescript": "^4.9.3" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git 
a/executor/package.json b/executor/package.json index 37e23bb6e..074f1a10e 100644 --- a/executor/package.json +++ b/executor/package.json @@ -78,5 +78,5 @@ "tmp": "^0.2.1", "uuid": "*" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/executor/src/core/Config.ts b/executor/src/core/Config.ts index baa2c881a..f031058db 100644 --- a/executor/src/core/Config.ts +++ b/executor/src/core/Config.ts @@ -2,7 +2,7 @@ import * as path from 'node:path'; import * as fs from 'node:fs'; import { Address, Expression } from '@coasys/ad4m'; -export let ad4mExecutorVersion = "0.10.1-rc7"; +export let ad4mExecutorVersion = "0.10.1-rc8"; export let agentLanguageAlias = "did"; export let languageLanguageAlias = "lang"; export let neighbourhoodLanguageAlias = "neighbourhood"; diff --git a/package.json b/package.json index 6754d39cc..e4b73c144 100644 --- a/package.json +++ b/package.json @@ -94,5 +94,5 @@ "safer-buffer@2.1.2": "patches/safer-buffer@2.1.2.patch" } }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 643c41e1c..afcc60c5e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -699,7 +699,7 @@ importers: devDependencies: '@coasys/ad4m-executor': specifier: '*' - version: 0.10.1-rc6(ws@8.13.0(bufferutil@4.0.8)(utf-8-validate@5.0.10)) + version: 0.10.1-rc7(ws@8.13.0(bufferutil@4.0.8)(utf-8-validate@5.0.10)) connect: dependencies: @@ -729,7 +729,7 @@ importers: specifier: 3.7.10 version: 3.7.10(graphql-ws@5.12.0(graphql@15.7.2(patch_hash=nr4gprddtjag7fz5nm4wirqs4q)))(graphql@15.7.2(patch_hash=nr4gprddtjag7fz5nm4wirqs4q))(react-dom@18.2.0(react@18.2.0))(react@18.2.0) '@coasys/ad4m': - specifier: workspace:0.10.1-rc7 + specifier: workspace:0.10.1-rc8 version: link:../core '@types/node': specifier: ^16.11.11 @@ -2276,8 +2276,8 @@ packages: '@coasys/ad4m-connect@0.8.1': resolution: {integrity: sha512-pbyeescsVOVAnXjn2Uh4BV0oWbuiseHT0HvGmDISdaJztaFv1MNglka2oOFf9xYmwz7PBExYNprnE5g75YVyNQ==} - '@coasys/ad4m-executor@0.10.1-rc6': - resolution: {integrity: sha512-kSMEReTb6QKNwIrN9tSjUJyienNY5QtY7kW51GR+F3PNheGm5/4T5tJA1bzzhVzshtsatZMpWBSJ823YS0aHwg==} + '@coasys/ad4m-executor@0.10.1-rc7': + resolution: {integrity: sha512-kcNuq5n98HXTqT8HIwaIvdGHf/WgXC5IJBoOpjtRuUfN3C7Z2kfcQIy9Jv84iiiarmpm1VOiFuOTMt0K1Rz5Cw==} '@coasys/ad4m@0.8.1': resolution: {integrity: sha512-2or0Ykc+F+geDumBABYD/us5Iu9Se2xVRdVcj4h0l7etd9Zx3u5m/Q32YHywJz61DbPQV8Q+zTkYiqF54y9GrA==} @@ -15736,7 +15736,7 @@ snapshots: - esbuild - supports-color - '@coasys/ad4m-executor@0.10.1-rc6(ws@8.13.0(bufferutil@4.0.8)(utf-8-validate@5.0.10))': + '@coasys/ad4m-executor@0.10.1-rc7(ws@8.13.0(bufferutil@4.0.8)(utf-8-validate@5.0.10))': dependencies: '@coasys/ad4m': link:../core '@holochain/client': https://codeload.github.com/coasys/holochain-client-js/tar.gz/2f3a436b6d28344b0aca883ef3dc229cd042c04b(ws@8.13.0(bufferutil@4.0.8)(utf-8-validate@5.0.10)) diff --git a/rust-client/Cargo.toml b/rust-client/Cargo.toml index 7c7b190fb..2fb71603d 100644 --- a/rust-client/Cargo.toml +++ b/rust-client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ad4m-client" -version = "0.10.1-rc7" +version = "0.10.1-rc8" edition = "2021" authors = ["Nicolas Luck "] description = "Client library wrapping AD4M's GraphQL interface" diff --git a/rust-executor/Cargo.toml b/rust-executor/Cargo.toml index 9c49e0b0b..e9d60771e 100644 --- a/rust-executor/Cargo.toml +++ b/rust-executor/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ad4m-executor" -version = "0.10.1-rc7" +version = "0.10.1-rc8" edition = "2021" authors = ["Nicolas 
Luck "] description = "Runtime implementation of AD4M as library - https://ad4m.dev" @@ -92,7 +92,7 @@ kitsune_p2p_types = { version = "0.3.6", git = "https://github.com/coasys/holoch scryer-prolog = { version = "0.9.4", git = "https://github.com/coasys/scryer-prolog" } # scryer-prolog = { path = "../../scryer-prolog", features = ["multi_thread"] } -ad4m-client = { path = "../rust-client", version="0.10.1-rc7" } +ad4m-client = { path = "../rust-client", version="0.10.1-rc8" } reqwest = { version = "0.11.20", features = ["json", "native-tls"] } rusqlite = { version = "0.29.0", git = "https://github.com/coasys/rusqlite.git", rev = "12ec1330bd4b46411ab9895364da4a3e172d0fbb", features = ["bundled"] } @@ -106,8 +106,8 @@ rustls = "0.23" tokio-rustls = "0.26" rustls-pemfile = "2" -kalosm = { version = "0.3.2", git = "https://github.com/coasys/floneum.git", branch = "coasys-1", features = ["language", "sound"] } -candle-core = "0.7" +kalosm = { version = "0.3.2", git = "https://github.com/coasys/floneum.git", branch = "coasys-2", features = ["language", "sound"] } +candle-core = "0.8.2" deflate = "1.0.0" futures-core = "0.3.30" futures-util = "0.3.30" @@ -115,6 +115,7 @@ futures-channel = "0.3.30" rodio = "*" libc = "0.2" chat-gpt-lib-rs = { version = "0.5.1", git = "https://github.com/coasys/chat-gpt-lib-rs" } +anyhow = "1.0.95" [dev-dependencies] maplit = "1.0.2" diff --git a/rust-executor/package.json b/rust-executor/package.json index d065aa1fd..63c933675 100644 --- a/rust-executor/package.json +++ b/rust-executor/package.json @@ -31,5 +31,5 @@ "@coasys/ad4m-executor": "link:../core" }, "dependencies": {}, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/rust-executor/src/ai_service/mod.rs b/rust-executor/src/ai_service/mod.rs index f53cea74b..2313f7c52 100644 --- a/rust-executor/src/ai_service/mod.rs +++ b/rust-executor/src/ai_service/mod.rs @@ -18,7 +18,6 @@ use kalosm::sound::TextStream; use kalosm::sound::*; use std::collections::HashMap; use std::future::Future; -use std::panic::catch_unwind; use std::pin::Pin; use std::sync::Arc; use std::thread; @@ -77,12 +76,19 @@ struct LLMTaskRemoveRequest { pub result_sender: oneshot::Sender<()>, } +#[allow(dead_code)] +#[derive(Debug)] +struct LLMTaskShutdownRequest { + pub result_sender: oneshot::Sender<()>, +} + #[allow(dead_code)] #[derive(Debug)] enum LLMTaskRequest { Spawn(LLMTaskSpawnRequest), Prompt(LLMTaskPromptRequest), Remove(LLMTaskRemoveRequest), + Shutdown(LLMTaskShutdownRequest), } enum LlmModel { @@ -170,8 +176,9 @@ impl AIService { model_type: ModelType::Embedding, local: Some(LocalModel { file_name: "bert".to_string(), - tokenizer_source: String::new(), - model_parameters: String::new(), + tokenizer_source: None, + huggingface_repo: None, + revision: None, }), api: None, }) @@ -203,9 +210,9 @@ impl AIService { async fn init_model(&self, model: crate::types::Model) -> Result<()> { match model.model_type { - ModelType::Llm => self.spawn_llm_model(model).await?, + ModelType::Llm => self.spawn_llm_model(model, None).await?, ModelType::Embedding => self.spawn_embedding_model(model).await, - ModelType::Transcription => Self::load_transcriber_model(&model).await, + ModelType::Transcription => Self::load_transcriber_model(model.id.clone()).await, }; Ok(()) } @@ -281,47 +288,76 @@ impl AIService { Device::Cpu } } - async fn build_local_llama_from_string( - model_id: String, - model_size_string: String, - ) -> Result { + async fn build_local_llama(model_id: String, model_config: LocalModel) -> Result { 
publish_model_status(model_id.clone(), 0.0, "Loading", false, false).await; - - let llama = match model_size_string.as_str() { - // Local TinyLlama models - "llama_tiny" => Llama::builder().with_source(LlamaSource::tiny_llama_1_1b()), - "llama_7b" => Llama::builder().with_source(LlamaSource::llama_7b()), - "llama_7b_chat" => Llama::builder().with_source(LlamaSource::llama_7b_chat()), - "llama_7b_code" => Llama::builder().with_source(LlamaSource::llama_7b_code()), - "llama_8b" => Llama::builder().with_source(LlamaSource::llama_8b()), - "llama_8b_chat" => Llama::builder().with_source(LlamaSource::llama_8b_chat()), - "llama_3_1_8b_chat" => Llama::builder().with_source(LlamaSource::llama_3_1_8b_chat()), - "llama_13b" => Llama::builder().with_source(LlamaSource::llama_13b()), - "llama_13b_chat" => Llama::builder().with_source(LlamaSource::llama_13b_chat()), - "llama_13b_code" => Llama::builder().with_source(LlamaSource::llama_13b_code()), - "llama_34b_code" => Llama::builder().with_source(LlamaSource::llama_34b_code()), - "llama_70b" => Llama::builder().with_source(LlamaSource::llama_70b()), - "mistral_7b" => Llama::builder().with_source(LlamaSource::mistral_7b()), - "mistral_7b_instruct" => { - Llama::builder().with_source(LlamaSource::mistral_7b_instruct()) - } - "mistral_7b_instruct_2" => { - Llama::builder().with_source(LlamaSource::mistral_7b_instruct_2()) - } - "solar_10_7b" => Llama::builder().with_source(LlamaSource::solar_10_7b()), - "solar_10_7b_instruct" => { - Llama::builder().with_source(LlamaSource::solar_10_7b_instruct()) - } + let llama = Llama::builder().with_source(match model_config.file_name.as_str() { + // First check model name shortcuts + "Qwen2.5.1-Coder-7B-Instruct" => LlamaSource::new(FileSource::huggingface( + "bartowski/Qwen2.5.1-Coder-7B-Instruct-GGUF".to_string(), + "main".to_string(), + "Qwen2.5.1-Coder-7B-Instruct-Q4_K_M.gguf".to_string(), + )), + "deepseek_r1_distill_qwen_1_5b" => LlamaSource::deepseek_r1_distill_qwen_1_5b(), + "deepseek_r1_distill_qwen_7b" => LlamaSource::deepseek_r1_distill_qwen_7b(), + "deepseek_r1_distill_qwen_14b" => LlamaSource::deepseek_r1_distill_qwen_14b(), + "deepseek_r1_distill_llama_8b" => LlamaSource::deepseek_r1_distill_llama_8b(), + "llama_tiny" => LlamaSource::tiny_llama_1_1b(), + "llama_tiny_1_1b_chat" => LlamaSource::tiny_llama_1_1b_chat(), + "llama_7b" => LlamaSource::llama_7b(), + "llama_7b_chat" => LlamaSource::llama_7b_chat(), + "llama_7b_code" => LlamaSource::llama_7b_code(), + "llama_8b" => LlamaSource::llama_8b(), + "llama_8b_chat" => LlamaSource::llama_8b_chat(), + "llama_3_1_8b_chat" => LlamaSource::llama_3_1_8b_chat(), + "llama_13b" => LlamaSource::llama_13b(), + "llama_13b_chat" => LlamaSource::llama_13b_chat(), + "llama_13b_code" => LlamaSource::llama_13b_code(), + "llama_34b_code" => LlamaSource::llama_34b_code(), + "llama_70b" => LlamaSource::llama_70b(), + "mistral_7b" => LlamaSource::mistral_7b(), + "mistral_7b_instruct" => LlamaSource::mistral_7b_instruct(), + "mistral_7b_instruct_2" => LlamaSource::mistral_7b_instruct_2(), + "solar_10_7b" => LlamaSource::solar_10_7b(), + "solar_10_7b_instruct" => LlamaSource::solar_10_7b_instruct(), // Handle unknown models _ => { - log::error!("Unknown model string: {}", model_size_string); - return Err(anyhow::anyhow!( - "Unknown model string: {}", - model_size_string - )); + if let Some(repo) = model_config.huggingface_repo.clone() { + log::info!("Trying to load model from Huggingface:\n + model_config.file_name: {:?}\n + model_config.huggingface_repo: {:?}\n + 
model_config.revision: {:?}", model_config.file_name, model_config.huggingface_repo, model_config.revision); + let mut builder = LlamaSource::new(FileSource::huggingface( + repo, + model_config.revision.unwrap_or("main".to_string()), + model_config.file_name, + )); + if let Some(tokenizer_source) = model_config.tokenizer_source { + log::info!("Trying to load tokenizer from Huggingface:\n + tokenizer_source.repo: {:?}\n + tokenizer_source.revision: {:?}\n + tokenizer_source.file_name: {:?}", tokenizer_source.repo, tokenizer_source.revision, tokenizer_source.file_name); + builder = builder.with_tokenizer(FileSource::huggingface( + tokenizer_source.repo, + tokenizer_source.revision, + tokenizer_source.file_name, + )); + } + builder + } else { + log::error!( + "Unknown model string: {} and no Huggingface repo provided. Don't know where to get model weights from for: {}", + model_config.file_name, + model_id + ); + return Err(anyhow::anyhow!( + "Unknown model string: {} and no Huggingface repo provided. Don't know where to get model weights from for: {}", + model_config.file_name, + model_id + )); + } } - }; + }); // Build the local Llama model let llama = llama @@ -339,7 +375,11 @@ impl AIService { Ok(llama) } - async fn build_remote_gpt4(model_id: String, api_key: String, base_url: Url) -> ChatGPTClient { + async fn build_remote_client( + model_id: String, + api_key: String, + base_url: Url, + ) -> ChatGPTClient { let mut url = base_url; if let Some(segments) = url.path_segments() { if segments.clone().next() == Some("v1") { @@ -352,7 +392,11 @@ impl AIService { client } - async fn spawn_llm_model(&self, model_config: crate::types::Model) -> Result<()> { + async fn spawn_llm_model( + &self, + model_config: crate::types::Model, + model_ready_sender: Option>, + ) -> Result<()> { if model_config.local.is_none() && model_config.api.is_none() { return Err(anyhow!( "AI model definition {} doesn't have a body, nothing to spawn!", @@ -361,20 +405,30 @@ impl AIService { } let (llama_tx, mut llama_rx) = mpsc::unbounded_channel::(); - let model_id = model_config.id.clone(); + self.llm_channel + .lock() + .await + .insert(model_config.id.clone(), llama_tx); thread::spawn({ move || { let model_id = model_config.id.clone(); let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(publish_model_status( + model_config.id.clone(), + 100.0, + "Spawning model thread...", + true, + false, + )); let maybe_model = rt .block_on(async { if let Some(local_model) = model_config.local { - Self::build_local_llama_from_string(model_id, local_model.file_name) + Self::build_local_llama(model_id, local_model) .await .map(LlmModel::Local) } else if let Some(api) = model_config.api { Ok(LlmModel::Remote(( - Self::build_remote_gpt4(model_id, api.api_key, api.base_url).await, + Self::build_remote_client(model_id, api.api_key, api.base_url).await, api.model ))) } else { @@ -387,7 +441,7 @@ impl AIService { Err(e) => { error!("Failed to build LLM model: {}", e); rt.block_on(publish_model_status( - model_config.id, + model_config.id.clone(), 100.0, &format!("Failed to build LLM model: {}", e), true, @@ -397,7 +451,7 @@ impl AIService { } }; - let mut tasks = HashMap::::new(); + let mut tasks = HashMap::>::new(); let mut task_descriptions = HashMap::::new(); let idle_delay = Duration::from_millis(1); @@ -409,6 +463,10 @@ impl AIService { true, )); + if let Some(model_ready_sender) = model_ready_sender { + let _ = model_ready_sender.send(()); + } + loop { match rt.block_on(async { tokio::select! 
{ @@ -419,6 +477,19 @@ impl AIService { Err(_timeout) => std::thread::sleep(idle_delay * 5), Ok(None) => break, Ok(Some(task_request)) => match task_request { + LLMTaskRequest::Shutdown(shutdown_request) => { + rt.block_on(publish_model_status( + model_config.id.clone(), + 100.0, + "Shutting down", + true, + false, + )); + + // Send confirmation before breaking + let _ = shutdown_request.result_sender.send(()); + break; + } LLMTaskRequest::Spawn(spawn_request) => match model { LlmModel::Remote(_) => { task_descriptions.insert( @@ -436,42 +507,23 @@ impl AIService { true, )); let task_description = spawn_request.task; - let task = - Task::builder(task_description.system_prompt.clone()) - .with_examples( - task_description - .prompt_examples - .clone() - .into_iter() - .map(|example| (example.input, example.output)) - .collect::>(), - ) - .build(); - - let mut task_run = false; - let mut tries = 0; - while !task_run && tries < 20 { - tries += 1; - - match catch_unwind(|| { - rt.block_on(task.run("Test example prompt", llama).all_text()) - }) { - Err(e) => log::error!( - "Llama panicked during task spawn with: {:?}. Trying again..", - e - ), - Ok(_) => task_run = true, - } - } - if task_run { - tasks.insert(task_description.task_id.clone(), task); - let _ = spawn_request.result_sender.send(Ok(())); - } else { - let _ = spawn_request - .result_sender - .send(Err(anyhow!("Couldn't run task without panics"))); - } + let task = llama + .task(task_description.system_prompt.clone()) + .with_examples( + task_description + .prompt_examples + .clone() + .into_iter() + .map(|example| (example.input, example.output)) + .collect::>(), + ); + + rt.block_on(task.run("Test example prompt").all_text()); + + tasks.insert(task_description.task_id.clone(), task); + let _ = spawn_request.result_sender.send(Ok(())); + rt.block_on(publish_model_status( model_config.id.clone(), 100.0, @@ -545,7 +597,7 @@ impl AIService { ))); } } - LlmModel::Local(ref mut llama) => { + LlmModel::Local(_) => { if let Some(task) = tasks.get(&prompt_request.task_id) { rt.block_on(publish_model_status( model_config.id.clone(), @@ -554,34 +606,10 @@ impl AIService { true, true, )); - let mut maybe_result: Option = None; - let mut tries = 0; - while maybe_result.is_none() && tries < 20 { - tries += 1; - - match catch_unwind(|| { - rt.block_on(async { - task.run(prompt_request.prompt.clone(), llama) - .all_text() - .await - }) - }) { - Err(e) => { - log::error!( - "Llama panicked with: {:?}. Trying again..", - e - ); - rt.block_on(publish_model_status( - model_config.id.clone(), - 100.0, - "Panicked while running inference - trying again...", - true, - true, - )); - } - Ok(result) => maybe_result = Some(result), - } - } + + let result = rt.block_on(async { + task.run(prompt_request.prompt.clone()).all_text().await + }); rt.block_on(publish_model_status( model_config.id.clone(), @@ -591,11 +619,7 @@ impl AIService { true, )); - if let Some(result) = maybe_result { - let _ = prompt_request.result_sender.send(Ok(result)); - } else { - let _ = prompt_request.result_sender.send(Err(anyhow!("Unable to get response from Llama model. 
Giving up after 20 retries"))); - } + let _ = prompt_request.result_sender.send(Ok(result)); } else { let _ = prompt_request.result_sender.send(Err(anyhow!( "Task with ID {} not spawned", @@ -616,7 +640,6 @@ impl AIService { } }); - self.llm_channel.lock().await.insert(model_id, llama_tx); Ok(()) } @@ -779,7 +802,7 @@ impl AIService { publish_model_status(model_id.clone(), 0.0, "Loading", false, false).await; let bert = Bert::builder() - .with_device(Device::Cpu) + .with_device(Self::new_candle_device()) .build_with_loading_handler({ let model_id = model_id.clone(); move |progress| { @@ -807,7 +830,8 @@ impl AIService { Ok(Some(request)) => { let result: Result> = rt .block_on(async { model.embed(request.prompt).await }) - .map(|tensor| tensor.to_vec()); + .map(|tensor| tensor.to_vec()) + .map_err(|bert_error| anyhow!(bert_error)); let _ = request.result_sender.send(result); } } @@ -844,7 +868,66 @@ impl AIService { // Whisper / Transcription // ------------------------------------- - pub async fn open_transcription_stream(&self, _model_id: String) -> Result { + fn whisper_string_to_model(whisper_string: String) -> Result { + match whisper_string.as_str() { + "whisper_tiny" => Ok(WhisperSource::Tiny), + "whisper_tiny_quantized" => Ok(WhisperSource::QuantizedTiny), + "whisper_tiny_en" => Ok(WhisperSource::TinyEn), + "whisper_tiny_en_quantized" => Ok(WhisperSource::QuantizedTinyEn), + "whisper_base" => Ok(WhisperSource::Base), + "whisper_base_en" => Ok(WhisperSource::BaseEn), + "whisper_small" => Ok(WhisperSource::Small), + "whisper_small_en" => Ok(WhisperSource::SmallEn), + "whisper_medium" => Ok(WhisperSource::Medium), + "whisper_medium_en" => Ok(WhisperSource::MediumEn), + "whisper_medium_en_quantized_distil" => Ok(WhisperSource::QuantizedDistilMediumEn), + "whisper_large" => Ok(WhisperSource::Large), + "whisper_large_v2" => Ok(WhisperSource::LargeV2), + "whisper_distil_medium_en" => Ok(WhisperSource::DistilMediumEn), + "whisper_distil_large_v2" => Ok(WhisperSource::DistilLargeV2), + "whisper_distil_large_v3" => Ok(WhisperSource::DistilLargeV3), + "whisper_distil_large_v3_quantized" => Ok(WhisperSource::QuantizedDistilLargeV3), + "whisper_large_v3_turbo_quantized" => Ok(WhisperSource::QuantizedLargeV3Turbo), + _ => Err(anyhow!("Unknown whisper model: {}", whisper_string)), + } + } + + fn get_whisper_model_size(model_id: String) -> Result { + // Try to treat string as model size string first + if let Ok(model) = Self::whisper_string_to_model(model_id.clone()) { + return Ok(model); + } + + // Try to get model from DB by ID + if let Ok(Some(model)) = Ad4mDb::with_global_instance(|db| db.get_model(model_id.clone())) { + if model.model_type != ModelType::Transcription { + return Err(anyhow!("Model '{}' is not a transcription model", model_id)); + } + // Use filename from local model config + if let Some(local) = model.local { + return Self::whisper_string_to_model(local.file_name); + } + } + + // if nothing above works, see if we have a transcription model in the DB and use that + // Try to find first transcription model in DB + if let Ok(models) = Ad4mDb::with_global_instance(|db| db.get_models()) { + if let Some(model) = models + .into_iter() + .find(|m| m.model_type == ModelType::Transcription) + { + if let Some(local) = model.local { + return Self::whisper_string_to_model(local.file_name); + } + } + } + + // Default to tiny if nothing found + Ok(WhisperSource::Tiny) + } + + pub async fn open_transcription_stream(&self, model_id: String) -> Result { + let model_size = 
Self::get_whisper_model_size(model_id)?; let stream_id = uuid::Uuid::new_v4().to_string(); let stream_id_clone = stream_id.clone(); let (samples_tx, samples_rx) = futures_channel::mpsc::unbounded::>(); @@ -857,8 +940,8 @@ impl AIService { rt.block_on(async { let maybe_model = WhisperBuilder::default() - .with_source(WHISPER_MODEL) - .with_device(Device::Cpu) + .with_source(model_size) + .with_device(Self::new_candle_device()) .build() .await; @@ -950,30 +1033,160 @@ impl AIService { } } - async fn load_transcriber_model(model: &crate::types::Model) { - let id = &model.id; - publish_model_status(id.clone(), 0.0, "Loading", false, false).await; + async fn load_transcriber_model(model_id: String) { + publish_model_status(model_id.clone(), 0.0, "Loading", false, false).await; + + let model_size = Self::get_whisper_model_size(model_id.clone()) + .ok() + .unwrap_or(WHISPER_MODEL); let _ = WhisperBuilder::default() - .with_source(WHISPER_MODEL) - .with_device(Device::Cpu) + .with_source(model_size) + .with_device(Self::new_candle_device()) .build_with_loading_handler({ - let name = id.clone(); + let name = model_id.clone(); move |progress| { tokio::spawn(handle_progress(name.clone(), progress)); } }) .await; - publish_model_status(id.clone(), 100.0, "Loaded", true, false).await; + publish_model_status(model_id.clone(), 100.0, "Loaded", true, false).await; + } + + pub async fn update_model(&self, model_id: String, model_config: ModelInput) -> Result<()> { + log::info!("Updating model: {} with: {:?}", model_id, model_config); + // First get the existing model to determine its type + let existing_model = Ad4mDb::with_global_instance(|db| db.get_model(model_id.clone())) + .map_err(|e| anyhow!("Database error: {}", e))? + .ok_or_else(|| anyhow!("Model not found: {}", model_id))?; + + // Update the model in the database + Ad4mDb::with_global_instance(|db| db.update_model(&model_id, &model_config)) + .map_err(|e| anyhow!("Failed to update model in database: {}", e))?; + + // Get the updated model from the database + let updated_model = Ad4mDb::with_global_instance(|db| db.get_model(model_id.clone())) + .map_err(|e| anyhow!("Database error: {}", e))? + .ok_or_else(|| anyhow!("Model not found after update: {}", model_id))?; + + log::info!("Updated model in DB: {:?}", updated_model); + + match existing_model.model_type { + ModelType::Llm => { + // Shutdown the existing model thread + { + let mut llm_channel = self.llm_channel.lock().await; + if let Some(sender) = llm_channel.get(&model_id) { + log::info!("Shutting down LLM model thread for {}", model_id); + let (tx, rx) = oneshot::channel(); + if let Ok(()) = + sender.send(LLMTaskRequest::Shutdown(LLMTaskShutdownRequest { + result_sender: tx, + })) + { + // Wait for the thread to confirm shutdown + let _ = rx.await; + log::info!("LLM model thread for {} confirmed shutdown", model_id); + } + + // Remove the channel from the map + llm_channel.remove(&model_id); + } else { + log::info!( + "LLM model thread for {} not found. 
Nothing to shutdown", + model_id + ); + } + } + + // Spawn the model with new configuration + log::info!( + "Spawning new LLM model thread for {} with updated config", + model_id + ); + let (model_ready_tx, model_ready_rx) = oneshot::channel(); + self.spawn_llm_model(updated_model, Some(model_ready_tx)) + .await?; + model_ready_rx.await?; + + // Respawn all tasks for this model + let tasks = Ad4mDb::with_global_instance(|db| db.get_tasks()) + .map_err(|e| AIServiceError::DatabaseError(e.to_string()))?; + + for task in tasks.into_iter().filter(|t| t.model_id == model_id) { + self.spawn_task(task).await?; + } + } + ModelType::Embedding => { + // TODO: Handle embedding model updates + } + ModelType::Transcription => { + Self::load_transcriber_model(updated_model.id.clone()).await; + } + } + + Ok(()) + } + + pub async fn remove_model(&self, model_id: String) -> Result<()> { + // First get the existing model to determine its type + let existing_model = Ad4mDb::with_global_instance(|db| db.get_model(model_id.clone())) + .map_err(|e| anyhow!("Database error: {}", e))? + .ok_or_else(|| anyhow!("Model not found: {}", model_id))?; + + match existing_model.model_type { + ModelType::Llm => { + log::info!("Shutting down LLM model thread for {}", model_id); + // Shutdown the existing model thread + let mut llm_channel = self.llm_channel.lock().await; + if let Some(sender) = llm_channel.get(&model_id) { + let (tx, rx) = oneshot::channel(); + if let Ok(()) = sender.send(LLMTaskRequest::Shutdown(LLMTaskShutdownRequest { + result_sender: tx, + })) { + // Wait for the thread to confirm shutdown + let _ = rx.await; + } + + log::info!("LLM model thread for {} confirmed shutdown", model_id); + + // Remove the channel from the map + llm_channel.remove(&model_id); + } else { + log::warn!( + "LLM model thread for {} not found. Nothing to shutdown", + model_id + ); + } + } + ModelType::Embedding => { + // TODO: Handle embedding model removal + } + ModelType::Transcription => { + // TODO: Handle transcription model removal + } + } + + // Remove the model from the database + Ad4mDb::with_global_instance(|db| db.remove_model(&model_id)) + .map_err(|e| anyhow!("Failed to remove model from database: {}", e))?; + + Ok(()) } } #[cfg(test)] mod tests { - use crate::graphql::graphql_types::AIPromptExamplesInput; - use super::*; + use crate::graphql::graphql_types::{AIPromptExamplesInput, LocalModelInput}; + + // TODO: We ignore these tests because they need a GPU to not take ages to run + // BUT: the model lifecycle and update tests show another problem: + // We can't run them in parallel with each other or other tests because + // the one global DB gets reseted for each test. + // -> need to refactor this so that services like AIService or PerspectiveInstance + // get an DB reference passed in, so we can write proper unit tests. 
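The updated AIService above coordinates model replacement through a small handshake: update_model and remove_model push a Shutdown request carrying a oneshot sender into the model's worker channel, wait for the worker to acknowledge on that sender, and only then drop the channel and (for updates) respawn the model thread and its tasks. What follows is a minimal, self-contained sketch of that handshake pattern only — it assumes a plain std::sync::mpsc channel, a tokio runtime, and made-up TaskRequest / spawn_worker names rather than the real AIService types.

// Sketch of the shutdown handshake (illustrative names, not the real AIService code).
// Assumes the `tokio` crate with the "macros", "rt-multi-thread" and "sync" features.
use std::sync::mpsc;
use tokio::sync::oneshot;

enum TaskRequest {
    Prompt(String),
    Shutdown(oneshot::Sender<()>),
}

fn spawn_worker() -> mpsc::Sender<TaskRequest> {
    let (tx, rx) = mpsc::channel::<TaskRequest>();
    std::thread::spawn(move || {
        while let Ok(request) = rx.recv() {
            match request {
                TaskRequest::Prompt(prompt) => println!("running prompt: {prompt}"),
                TaskRequest::Shutdown(confirm) => {
                    // Acknowledge *before* breaking so the controller can await the ack.
                    let _ = confirm.send(());
                    break;
                }
            }
        }
    });
    tx
}

#[tokio::main]
async fn main() {
    let worker = spawn_worker();
    worker
        .send(TaskRequest::Prompt("hello".into()))
        .expect("worker alive");

    let (ack_tx, ack_rx) = oneshot::channel();
    worker
        .send(TaskRequest::Shutdown(ack_tx))
        .expect("worker alive");

    // Wait for the worker thread to confirm it has stopped before
    // dropping its channel and spawning a replacement.
    let _ = ack_rx.await;
    println!("worker confirmed shutdown");
}

Sending the confirmation before breaking out of the loop is what lets the controller block on the oneshot receiver without racing the thread teardown, mirroring the "Send confirmation before breaking" step in the Shutdown arm above.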
#[ignore] #[tokio::test] @@ -1059,4 +1272,109 @@ mod tests { println!("Responses: {}", response); assert!(!response.is_empty()) } + + #[ignore] + #[tokio::test] + async fn test_model_lifecycle() { + Ad4mDb::init_global_instance(":memory:").expect("Ad4mDb to initialize"); + let service = AIService::new().expect("initialization to work"); + + // Add a model + let model_input = ModelInput { + name: "Test Model".into(), + model_type: ModelType::Llm, + local: Some(LocalModelInput { + file_name: "llama_tiny_1_1b_chat".into(), + tokenizer_source: None, + huggingface_repo: None, + revision: None, + }), + api: None, + }; + + let model_id = service + .add_model(model_input.clone()) + .await + .expect("model to be added"); + + // Update the model + let updated_model = ModelInput { + name: "Updated Test Model".into(), + ..model_input.clone() + }; + service + .update_model(model_id.clone(), updated_model) + .await + .expect("model to be updated"); + + // Verify the update + let model = Ad4mDb::with_global_instance(|db| db.get_model(model_id.clone())) + .expect("to get model"); + assert_eq!(model.unwrap().name, "Updated Test Model"); + + // Remove the model + service + .remove_model(model_id.clone()) + .await + .expect("model to be removed"); + + // Verify removal + let model = Ad4mDb::with_global_instance(|db| db.get_model(model_id.clone())) + .expect("to get model"); + assert!(model.is_none()); + } + + #[ignore] + #[tokio::test] + async fn test_model_update_with_tasks() { + Ad4mDb::init_global_instance(":memory:").expect("Ad4mDb to initialize"); + let service = AIService::new().expect("initialization to work"); + + // Add a model + let model_input = ModelInput { + name: "Test Model".into(), + model_type: ModelType::Llm, + local: Some(LocalModelInput { + file_name: "llama_tiny_1_1b_chat".into(), + tokenizer_source: None, + huggingface_repo: None, + revision: None, + }), + api: None, + }; + + let model_id = service + .add_model(model_input.clone()) + .await + .expect("model to be added"); + + // Create a task using this model + let task = service + .add_task(AITaskInput { + name: "Test task".into(), + model_id: model_id.clone(), + system_prompt: "Test prompt".into(), + prompt_examples: vec![], + meta_data: None, + }) + .await + .expect("task to be created"); + + // Update the model + let updated_model = ModelInput { + name: "Updated Test Model".into(), + ..model_input.clone() + }; + service + .update_model(model_id.clone(), updated_model) + .await + .expect("model to be updated"); + + // Verify the task still works + let response = service + .prompt(task.task_id.clone(), "Test input".into()) + .await + .expect("prompt to work after model update"); + assert!(!response.is_empty()); + } } diff --git a/rust-executor/src/db.rs b/rust-executor/src/db.rs index 7550663e1..3fb31cede 100644 --- a/rust-executor/src/db.rs +++ b/rust-executor/src/db.rs @@ -4,7 +4,7 @@ use crate::graphql::graphql_types::{ }; use crate::types::{ AIPromptExamples, AITask, Expression, ExpressionProof, Link, LinkExpression, LocalModel, Model, - ModelApi, ModelApiType, ModelType, Notification, PerspectiveDiff, + ModelApi, ModelApiType, ModelType, Notification, PerspectiveDiff, TokenizerSource, }; use deno_core::anyhow::anyhow; use deno_core::error::AnyError; @@ -13,7 +13,6 @@ use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; use std::str::FromStr; use url::Url; -use uuid::Uuid; #[derive(Serialize, Deserialize)] struct LinkSchema { @@ -189,8 +188,11 @@ impl Ad4mDb { model TEXT, api_type TEXT, local_file_name TEXT, - 
local_tokenizer_source TEXT, - local_model_parameters TEXT, + local_tokenizer_repo TEXT, + local_tokenizer_revision TEXT, + local_tokenizer_file_name TEXT, + local_huggingface_repo TEXT, + local_revision TEXT, type TEXT NOT NULL )", [], @@ -1050,7 +1052,7 @@ impl Ad4mDb { } // Reduce count and try again - count = count / 2; + count /= 2; } } @@ -1099,10 +1101,10 @@ impl Ad4mDb { } pub fn add_model(&self, model: &ModelInput) -> Ad4mDbResult { - let id = Uuid::new_v4().to_string(); + let id = uuid::Uuid::new_v4().to_string(); self.conn.execute( - "INSERT INTO models (id, name, api_base_url, api_key, model, api_type, local_file_name, local_tokenizer_source, local_model_parameters, type) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", + "INSERT INTO models (id, name, api_base_url, api_key, model, api_type, local_file_name, local_tokenizer_repo, local_tokenizer_revision, local_tokenizer_file_name, local_huggingface_repo, local_revision, type) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)", params![ id, model.name, @@ -1111,8 +1113,11 @@ impl Ad4mDb { model.api.as_ref().map(|api| api.model.clone()), model.api.as_ref().map(|api| serde_json::to_string(&api.api_type).unwrap()), model.local.as_ref().map(|local| local.file_name.clone()), - model.local.as_ref().map(|local| local.tokenizer_source.clone()), - model.local.as_ref().map(|local| local.model_parameters.clone()), + model.local.as_ref().and_then(|local| local.tokenizer_source.as_ref().map(|ts| ts.repo.clone())), + model.local.as_ref().and_then(|local| local.tokenizer_source.as_ref().map(|ts| ts.revision.clone())), + model.local.as_ref().and_then(|local| local.tokenizer_source.as_ref().map(|ts| ts.file_name.clone())), + model.local.as_ref().and_then(|local| local.huggingface_repo.clone()), + model.local.as_ref().and_then(|local| local.revision.clone()), serde_json::to_string(&model.model_type).unwrap(), ], )?; @@ -1139,25 +1144,39 @@ impl Ad4mDb { None }; - let local = - if let (Some(file_name), Some(tokenizer_source), Some(model_parameters)) = - (row.get(6)?, row.get(7)?, row.get(8)?) - { - Some(LocalModel { + let local: Option = if let Some(file_name) = + row.get::<_, Option>(6)? + { + let tokenizer_source = if let (Some(repo), Some(revision), Some(file_name)) = ( + row.get::<_, Option>(7)?, + row.get::<_, Option>(8)?, + row.get::<_, Option>(9)?, + ) { + Some(TokenizerSource { + repo, + revision, file_name, - tokenizer_source, - model_parameters, }) } else { None }; + Some(LocalModel { + file_name, + tokenizer_source, + huggingface_repo: row.get::<_, Option>(10)?, + revision: row.get::<_, Option>(11)?, + }) + } else { + None + }; + Ok(Model { id: row.get(0)?, name: row.get(1)?, api, local, - model_type: serde_json::from_str(&row.get::<_, String>(9)?).unwrap(), + model_type: serde_json::from_str(&row.get::<_, String>(12)?).unwrap(), }) }) .optional()?; @@ -1183,24 +1202,38 @@ impl Ad4mDb { None }; - let local = if let (Some(file_name), Some(tokenizer_source), Some(model_parameters)) = - (row.get(6)?, row.get(7)?, row.get(8)?) - { - Some(LocalModel { - file_name, - tokenizer_source, - model_parameters, - }) - } else { - None - }; + let local: Option = + if let Some(file_name) = row.get::<_, Option>(6)? 
{ + let tokenizer_source = if let (Some(repo), Some(revision), Some(file_name)) = ( + row.get::<_, Option>(7)?, + row.get::<_, Option>(8)?, + row.get::<_, Option>(9)?, + ) { + Some(TokenizerSource { + repo, + revision, + file_name, + }) + } else { + None + }; + + Some(LocalModel { + file_name, + tokenizer_source, + huggingface_repo: row.get::<_, Option>(10)?, + revision: row.get::<_, Option>(11)?, + }) + } else { + None + }; Ok(Model { id: row.get(0)?, name: row.get(1)?, api, local, - model_type: serde_json::from_str(&row.get::<_, String>(9)?).unwrap(), + model_type: serde_json::from_str(&row.get::<_, String>(12)?).unwrap(), }) })?; @@ -1220,11 +1253,15 @@ impl Ad4mDb { let local_tokenizer = model .local .as_ref() - .map(|local| local.tokenizer_source.clone()); - let local_params = model + .and_then(|local| local.tokenizer_source.as_ref()); + let local_huggingface_repo = model + .local + .as_ref() + .and_then(|local| local.huggingface_repo.clone()); + let local_revision = model .local .as_ref() - .map(|local| local.model_parameters.clone()); + .and_then(|local| local.revision.clone()); self.conn.execute( "UPDATE models SET @@ -1234,10 +1271,13 @@ impl Ad4mDb { model = ?4, api_type = ?5, local_file_name = ?6, - local_tokenizer_source = ?7, - local_model_parameters = ?8, - type = ?9 - WHERE id = ?10", + local_tokenizer_repo = ?7, + local_tokenizer_revision = ?8, + local_tokenizer_file_name = ?9, + local_huggingface_repo = ?10, + local_revision = ?11, + type = ?12 + WHERE id = ?13", params![ model.name, api_base_url, @@ -1245,8 +1285,11 @@ impl Ad4mDb { api_model, api_type, local_file_name, - local_tokenizer, - local_params, + local_tokenizer.as_ref().map(|ts| ts.repo.clone()), + local_tokenizer.as_ref().map(|ts| ts.revision.clone()), + local_tokenizer.as_ref().map(|ts| ts.file_name.clone()), + local_huggingface_repo, + local_revision, serde_json::to_string(&model.model_type).unwrap(), id ], @@ -1300,7 +1343,7 @@ impl Ad4mDb { mod tests { use super::*; use crate::{ - graphql::graphql_types::{LocalModelInput, ModelApiInput}, + graphql::graphql_types::{LocalModelInput, ModelApiInput, TokenizerSourceInput}, types::{ExpressionProof, Link, LinkExpression, ModelApiType, ModelType}, }; use chrono::Utc; @@ -1653,8 +1696,13 @@ mod tests { api: None, local: Some(LocalModelInput { file_name: "test_model.bin".to_string(), - tokenizer_source: "test_tokenizer".to_string(), - model_parameters: "test_parameters".to_string(), + tokenizer_source: Some(TokenizerSourceInput { + repo: "test_repo".to_string(), + revision: "main".to_string(), + file_name: "tokenizer.json".to_string(), + }), + huggingface_repo: Some("huggingface/test".to_string()), + revision: Some("main".to_string()), }), model_type: ModelType::Llm, }; @@ -1700,16 +1748,21 @@ mod tests { .local .as_ref() .unwrap() - .tokenizer_source, - "test_tokenizer" + .tokenizer_source + .as_ref() + .unwrap() + .repo, + "test_repo" ); assert_eq!( retrieved_model_local .local .as_ref() .unwrap() - .model_parameters, - "test_parameters" + .huggingface_repo + .as_ref() + .unwrap(), + "huggingface/test" ); assert_eq!(retrieved_model_local.model_type, ModelType::Llm); @@ -1798,8 +1851,13 @@ mod tests { name: "Test Embedding Model".to_string(), local: Some(LocalModelInput { file_name: "embedding.bin".to_string(), - tokenizer_source: "embedding_tokenizer".to_string(), - model_parameters: "{}".to_string(), + tokenizer_source: Some(TokenizerSourceInput { + repo: "test_repo".to_string(), + revision: "main".to_string(), + file_name: "tokenizer.json".to_string(), + }), + 
huggingface_repo: Some("huggingface/test".to_string()), + revision: Some("main".to_string()), }), api: None, model_type: ModelType::Embedding, @@ -1941,8 +1999,13 @@ mod tests { api: None, local: Some(LocalModelInput { file_name: "model.bin".to_string(), - tokenizer_source: "tokenizer".to_string(), - model_parameters: "{}".to_string(), + tokenizer_source: Some(TokenizerSourceInput { + repo: "test_repo".to_string(), + revision: "main".to_string(), + file_name: "tokenizer.json".to_string(), + }), + huggingface_repo: Some("huggingface/test".to_string()), + revision: Some("main".to_string()), }), model_type: ModelType::Transcription, }; diff --git a/rust-executor/src/globals.rs b/rust-executor/src/globals.rs index cdf0ca9a0..9b1f5706c 100644 --- a/rust-executor/src/globals.rs +++ b/rust-executor/src/globals.rs @@ -2,7 +2,7 @@ use lazy_static::lazy_static; lazy_static! { /// The current version of AD4M - pub static ref AD4M_VERSION: String = String::from("0.10.1-rc7"); + pub static ref AD4M_VERSION: String = String::from("0.10.1-rc8"); } /// Struct representing oldest supported version and indicator if state should be cleared if update is required diff --git a/rust-executor/src/graphql/graphql_types.rs b/rust-executor/src/graphql/graphql_types.rs index 68e4a6535..6fa92e4a2 100644 --- a/rust-executor/src/graphql/graphql_types.rs +++ b/rust-executor/src/graphql/graphql_types.rs @@ -631,8 +631,17 @@ impl From for AIPromptExamplesInput { #[serde(rename_all = "camelCase")] pub struct LocalModelInput { pub file_name: String, - pub tokenizer_source: String, - pub model_parameters: String, + pub tokenizer_source: Option, + pub huggingface_repo: Option, + pub revision: Option, +} + +#[derive(GraphQLInputObject, Default, Debug, Deserialize, Serialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct TokenizerSourceInput { + pub repo: String, + pub revision: String, + pub file_name: String, } #[derive(GraphQLInputObject, Default, Debug, Deserialize, Serialize, Clone)] diff --git a/rust-executor/src/graphql/mutation_resolvers.rs b/rust-executor/src/graphql/mutation_resolvers.rs index d579b3743..ef6ccd328 100644 --- a/rust-executor/src/graphql/mutation_resolvers.rs +++ b/rust-executor/src/graphql/mutation_resolvers.rs @@ -834,12 +834,11 @@ impl Mutation { &perspective_update_capability(vec![uuid.clone()]), )?; let mut perspective = get_perspective_with_uuid_field_error(&uuid)?; - let mut removed_links = Vec::new(); - for link in links.into_iter() { - let link = crate::types::LinkExpression::try_from(link)?; - removed_links.push(perspective.remove_link(link).await?); - } - + let links = links + .into_iter() + .map(LinkExpression::try_from) + .collect::, _>>()?; + let removed_links = perspective.remove_links(links).await?; Ok(removed_links) } @@ -1229,8 +1228,19 @@ impl Mutation { model: ModelInput, ) -> FieldResult { check_capability(&context.capabilities, &AGENT_UPDATE_CAPABILITY)?; - Ad4mDb::with_global_instance(|db| db.update_model(&model_id, &model)) - .map_err(|e| e.to_string())?; + + // Update the model using AIService + AIService::global_instance() + .await? 
+ .update_model(model_id, model) + .await + .map_err(|e| { + FieldError::new( + "Failed to update model", + graphql_value!({ "error": e.to_string() }), + ) + })?; + Ok(true) } @@ -1240,7 +1250,19 @@ impl Mutation { model_id: String, ) -> FieldResult { check_capability(&context.capabilities, &AGENT_UPDATE_CAPABILITY)?; - Ad4mDb::with_global_instance(|db| db.remove_model(&model_id)).map_err(|e| e.to_string())?; + + // Remove the model using AIService + AIService::global_instance() + .await? + .remove_model(model_id) + .await + .map_err(|e| { + FieldError::new( + "Failed to remove model", + graphql_value!({ "error": e.to_string() }), + ) + })?; + Ok(true) } diff --git a/rust-executor/src/perspectives/perspective_instance.rs b/rust-executor/src/perspectives/perspective_instance.rs index 404e52247..fa65bf652 100644 --- a/rust-executor/src/perspectives/perspective_instance.rs +++ b/rust-executor/src/perspectives/perspective_instance.rs @@ -498,7 +498,7 @@ impl PerspectiveInstance { let ok = match commit_result { Ok(Some(rev)) => { - if rev.trim().len() == 0 { + if rev.trim().is_empty() { log::warn!("Committed but got no revision from LinkLanguage!\nStoring in pending diffs for later"); false } else { @@ -846,6 +846,69 @@ impl PerspectiveInstance { } } + pub async fn remove_links( + &mut self, + link_expressions: Vec, + ) -> Result, AnyError> { + let handle = self.persisted.lock().await.clone(); + + // Filter to only existing links and collect their statuses + let mut existing_links = Vec::new(); + for link in link_expressions { + if let Some((link_from_db, status)) = + Ad4mDb::with_global_instance(|db| db.get_link(&handle.uuid, &link))? + { + existing_links.push((link_from_db, status)); + } + } + + // Skip if no links found + if existing_links.is_empty() { + return Ok(Vec::new()); + } + + // Split into links and statuses + let (links, statuses): (Vec<_>, Vec<_>) = existing_links.into_iter().unzip(); + + // Create diff from links that exist + let diff = PerspectiveDiff::from_removals(links.clone()); + + // Create decorated versions + let decorated_links: Vec = links + .into_iter() + .zip(statuses.iter()) + .map(|(link, status)| DecoratedLinkExpression::from((link, status.clone()))) + .collect(); + + let decorated_diff = DecoratedPerspectiveDiff::from_removals(decorated_links.clone()); + + // Remove from DB + for link in diff.removals.iter() { + Ad4mDb::with_global_instance(|db| db.remove_link(&handle.uuid, link))?; + } + + self.spawn_prolog_facts_update(decorated_diff.clone()); + self.pubsub_publish_diff(decorated_diff).await; + + // Only commit shared links by filtering decorated_links + let shared_links: Vec = decorated_links + .iter() + .filter(|link| link.status == Some(LinkStatus::Shared)) + .map(|link| link.clone().into()) + .collect(); + + if !shared_links.is_empty() { + let shared_diff = PerspectiveDiff { + additions: vec![], + removals: shared_links, + }; + self.spawn_commit_and_handle_error(&shared_diff); + } + + *(self.links_have_changed.lock().await) = true; + Ok(decorated_links) + } + async fn get_links_local( &self, query: &LinkQuery, @@ -1553,8 +1616,10 @@ impl PerspectiveInstance { let mut object: HashMap = HashMap::new(); // Get author and timestamp from the first link mentioning base as source - let mut base_query = LinkQuery::default(); - base_query.source = Some(base_expression.clone()); + let base_query = LinkQuery { + source: Some(base_expression.clone()), + ..Default::default() + }; let base_links = self.get_links(&base_query).await?; let first_link = base_links .first() 
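The remove_links method added above replaces the old one-link-at-a-time removal: it looks up every requested link once, silently skips entries that aren't in the database, unzips the survivors into link data and status, and commits a single removal diff containing only the links whose status is Shared. Below is a rough, runnable sketch of that filter/unzip/partition idea using stand-in types — Link, LinkStatus and the lookup function here are placeholders, not the real LinkExpression, LinkStatus or Ad4mDb API.

// Sketch of the batching idea in remove_links, with simplified stand-in types.
#[derive(Clone, Debug, PartialEq)]
enum LinkStatus {
    Shared,
    Local,
}

#[derive(Clone, Debug)]
struct Link(String);

// Stand-in for the database lookup: pretend only "a" and "b" exist.
fn lookup(link: &Link) -> Option<(Link, LinkStatus)> {
    match link.0.as_str() {
        "a" => Some((link.clone(), LinkStatus::Shared)),
        "b" => Some((link.clone(), LinkStatus::Local)),
        _ => None,
    }
}

fn main() {
    let requested = vec![Link("a".into()), Link("missing".into()), Link("b".into())];

    // Links that are not (or no longer) in the DB simply drop out here
    // instead of raising an error.
    let existing: Vec<(Link, LinkStatus)> = requested.iter().filter_map(lookup).collect();

    // Split into the link data and its status, mirroring the unzip in remove_links.
    let (links, statuses): (Vec<_>, Vec<_>) = existing.into_iter().unzip();

    // Only links that were shared via the link language need a removal commit.
    let shared: Vec<&Link> = links
        .iter()
        .zip(statuses.iter())
        .filter(|(_, status)| **status == LinkStatus::Shared)
        .map(|(link, _)| link)
        .collect();

    println!("removing locally: {links:?}");
    println!("committing removals for: {shared:?}");
}

This is the behaviour the new perspective.ts test further below exercises by passing the same link expression twice to removeLinks and expecting no error.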
diff --git a/rust-executor/src/prolog_service/prolog_service_extension.rs b/rust-executor/src/prolog_service/prolog_service_extension.rs index 9081cb52d..9ad6d8c3c 100644 --- a/rust-executor/src/prolog_service/prolog_service_extension.rs +++ b/rust-executor/src/prolog_service/prolog_service_extension.rs @@ -20,7 +20,7 @@ pub fn prolog_value_to_json_tring(value: Value) -> String { Value::Integer(i) => format!("{}", i), Value::Float(f) => format!("{}", f), Value::Rational(r) => format!("{}", r), - Value::Atom(a) => format!("{}", a.as_str()), + Value::Atom(a) => a, Value::String(s) => { if let Err(_e) = serde_json::from_str::(s.as_str()) { //treat as string literal diff --git a/rust-executor/src/types.rs b/rust-executor/src/types.rs index c819dd74c..89f8a96f7 100644 --- a/rust-executor/src/types.rs +++ b/rust-executor/src/types.rs @@ -477,8 +477,17 @@ pub struct ModelApi { #[serde(rename_all = "camelCase")] pub struct LocalModel { pub file_name: String, - pub tokenizer_source: String, - pub model_parameters: String, + pub tokenizer_source: Option, + pub huggingface_repo: Option, + pub revision: Option, +} + +#[derive(GraphQLObject, Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct TokenizerSource { + pub repo: String, + pub revision: String, + pub file_name: String, } #[derive(Debug, Clone, Serialize, Deserialize, GraphQLEnum, PartialEq, Default)] diff --git a/test-runner/package.json b/test-runner/package.json index 0493d4371..67aaf56ed 100644 --- a/test-runner/package.json +++ b/test-runner/package.json @@ -63,5 +63,5 @@ "bugs": { "url": "https://github.com/coasys/ad4m/issues" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/tests/js/package.json b/tests/js/package.json index a6e84a78a..6df3d59e0 100644 --- a/tests/js/package.json +++ b/tests/js/package.json @@ -65,5 +65,5 @@ "fluent-ffmpeg": "^2.1.3", "uuid": "*" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/tests/js/tests/ai.ts b/tests/js/tests/ai.ts index 3bf8290f3..efb2aaba5 100644 --- a/tests/js/tests/ai.ts +++ b/tests/js/tests/ai.ts @@ -9,6 +9,11 @@ import { ModelInput } from '@coasys/ad4m/lib/src/ai/AIResolver'; export default function aiTests(testContext: TestContext) { return () => { describe('AI service', () => { + // This is used in the skipped tests below + // They are skipped for CI, run on local device with GPU + let testModelFileName: string = "llama_3_1_8b_chat" + let testModelId: string = "" + it("can perform Model CRUD operations", async () => { const ad4mClient = testContext.ad4mClient! 
@@ -32,8 +37,13 @@ export default function aiTests(testContext: TestContext) { name: "TestLocalModel", local: { fileName: "test_model.bin", - tokenizerSource: "test_tokenizer.json", - modelParameters: JSON.stringify({ param1: "value1", param2: "value2" }) + tokenizerSource: { + repo: "test-repo", + revision: "main", + fileName: "tokenizer.json" + }, + huggingfaceRepo: "test-repo", + revision: "main" }, modelType: "EMBEDDING" } @@ -58,8 +68,11 @@ export default function aiTests(testContext: TestContext) { expect(addedLocalModel).to.exist expect(addedLocalModel?.id).to.equal(addLocalResult) expect(addedLocalModel?.local?.fileName).to.equal("test_model.bin") - expect(addedLocalModel?.local?.tokenizerSource).to.equal("test_tokenizer.json") - expect(addedLocalModel?.local?.modelParameters).to.deep.equal(JSON.stringify({ param1: "value1", param2: "value2" })) + expect(addedLocalModel?.local?.tokenizerSource?.repo).to.equal("test-repo") + expect(addedLocalModel?.local?.tokenizerSource?.revision).to.equal("main") + expect(addedLocalModel?.local?.tokenizerSource?.fileName).to.equal("tokenizer.json") + expect(addedLocalModel?.local?.huggingfaceRepo).to.equal("test-repo") + expect(addedLocalModel?.local?.revision).to.equal("main") // Test removing models const removeApiResult = await ad4mClient.ai.removeModel(addedApiModel!.id) @@ -102,18 +115,51 @@ export default function aiTests(testContext: TestContext) { expect(addedModel).to.exist // Create updated model data - const updatedModel: ModelInput = { + const bogusModelUrls: ModelInput = { name: "UpdatedModel", local: { fileName: "updated_model.bin", - tokenizerSource: "updated_tokenizer.json", - modelParameters: JSON.stringify({ updated: "value" }) + tokenizerSource: { + repo: "updated-repo", + revision: "main", + fileName: "updated_tokenizer.json" + }, + huggingfaceRepo: "updated-repo", + revision: "main" }, modelType: "EMBEDDING" } // Update the model - const updateResult = await ad4mClient.ai.updateModel(addedModel!.id, updatedModel) + let updateResult = false + let error = {} + try { + updateResult = await ad4mClient.ai.updateModel(addedModel!.id, bogusModelUrls) + }catch(e) { + //@ts-ignore + error = e + console.log(error) + } + expect(updateResult).to.be.false + expect(error).to.have.property('message') + //@ts-ignore + expect(error.message).to.include('Failed to update model') + + + // Create updated model data + const updatedModel: ModelInput = { + name: "UpdatedModel", + api: { + baseUrl: "https://api.example.com/v2", + apiKey: "updated-api-key", + model: "gpt-4", + apiType: "OPEN_AI" + }, + modelType: "LLM" + } + + updateResult = await ad4mClient.ai.updateModel(addedModel!.id, updatedModel) + expect(updateResult).to.be.true // Verify the update @@ -121,17 +167,89 @@ export default function aiTests(testContext: TestContext) { const retrievedModel = updatedModels.find(model => model.id === addedModel!.id) expect(retrievedModel).to.exist expect(retrievedModel?.name).to.equal("UpdatedModel") - expect(retrievedModel?.api).to.be.null - expect(retrievedModel?.local?.fileName).to.equal("updated_model.bin") - expect(retrievedModel?.local?.tokenizerSource).to.equal("updated_tokenizer.json") - expect(retrievedModel?.local?.modelParameters).to.equal(JSON.stringify({ updated: "value" })) - expect(retrievedModel?.modelType).to.equal("EMBEDDING") + expect(retrievedModel?.local).to.be.null + expect(retrievedModel?.api?.baseUrl).to.equal("https://api.example.com/v2") + expect(retrievedModel?.api?.apiKey).to.equal("updated-api-key") + 
expect(retrievedModel?.api?.model).to.equal("gpt-4") + expect(retrievedModel?.api?.apiType).to.equal("OPEN_AI") + expect(retrievedModel?.modelType).to.equal("LLM") // Clean up const removeResult = await ad4mClient.ai.removeModel(addedModel!.id) expect(removeResult).to.be.true }) + it.skip('can update model and verify it works', async () => { + const ad4mClient = testContext.ad4mClient! + + // Create initial model + const initialModel: ModelInput = { + name: "TestModel", + local: { + fileName: "llama_tiny_1_1b_chat" + }, + modelType: "LLM" + } + + // Add initial model + const modelId = await ad4mClient.ai.addModel(initialModel) + expect(modelId).to.be.a.string + + // Wait for model to be loaded + let status; + do { + status = await ad4mClient.ai.modelLoadingStatus(modelId); + await new Promise(resolve => setTimeout(resolve, 1000)); // Wait 1 second between checks + } while (status.progress < 100); + + testModelId = modelId + + // Create task using "default" as model_id + const task = await ad4mClient.ai.addTask( + "test-task", + modelId, + "You are a helpful assistant", + [{ input: "Say hi", output: "Hello!" }] + ) + + // Test that initial model works + const prompt = "Say hello" + const initialResponse = await ad4mClient.ai.prompt(task.taskId, prompt) + expect(initialResponse).to.be.a.string + expect(initialResponse.length).to.be.greaterThan(0) + + // Create updated model config + const updatedModel: ModelInput = { + name: "UpdatedTestModel", + local: { fileName: testModelFileName }, + modelType: "LLM" + } + + // Update the model + const updateResult = await ad4mClient.ai.updateModel(modelId, updatedModel) + expect(updateResult).to.be.true + + // Wait for model to be loaded + do { + status = await ad4mClient.ai.modelLoadingStatus(modelId); + await new Promise(resolve => setTimeout(resolve, 1000)); // Wait 1 second between checks + } while (status.progress < 100); + + // Verify model was updated in DB + const models = await ad4mClient.ai.getModels() + const retrievedModel = models.find(m => m.id === modelId) + expect(retrievedModel).to.exist + expect(retrievedModel?.name).to.equal("UpdatedTestModel") + expect(retrievedModel?.local?.fileName).to.equal(testModelFileName) + + // Test that updated model still works + const updatedResponse = await ad4mClient.ai.prompt(task.taskId, prompt) + expect(updatedResponse).to.be.a.string + expect(updatedResponse.length).to.be.greaterThan(0) + + // keep model around for other tests + }) + it ('AI model status', async () => { const ad4mClient = testContext.ad4mClient! const status = await ad4mClient.ai.modelLoadingStatus("bert-id"); @@ -166,37 +284,18 @@ export default function aiTests(testContext: TestContext) { expect(defaultModel.api?.baseUrl).to.equal("https://api.example.com/") // Clean up - await ad4mClient.ai.removeModel("TestDefaultApiModel") + await ad4mClient.ai.removeModel(id) }) it.skip('can use "default" as model_id in tasks and prompting works', async () => { const ad4mClient = testContext.ad4mClient! 
- - // Create a test model and set as default - const modelInput: ModelInput = { - name: "TestDefaultModel", - local: { - fileName: "llama_tiny", - tokenizerSource: "test_tokenizer.json", - modelParameters: JSON.stringify({ param1: "value1" }) - }, - modelType: "LLM" - } - const modelId = await ad4mClient.ai.addModel(modelInput) - await ad4mClient.ai.setDefaultModel("LLM", modelId) - - // Wait for model to be loaded - let status; - do { - status = await ad4mClient.ai.modelLoadingStatus(modelId); - await new Promise(resolve => setTimeout(resolve, 1000)); - } while (status.progress < 100); + await ad4mClient.ai.setDefaultModel("LLM", testModelId) // Create task using "default" as model_id const task = await ad4mClient.ai.addTask( "default-model-task", "default", - "You are a helpful assistant", + "You are a helpful assistant. Whatever you say, it will include 'hello'", [{ input: "Say hi", output: "Hello!" }] ) expect(task).to.have.property('taskId') @@ -210,12 +309,8 @@ export default function aiTests(testContext: TestContext) { // Create another test model const newModelInput: ModelInput = { - name: "TestDefaultModel2", - local: { - fileName: "llama_tiny", - tokenizerSource: "test_tokenizer.json", - modelParameters: JSON.stringify({ param1: "value1" }) - }, + name: "TestDefaultModel2", + local: { fileName: "llama_3_1_8b_chat" }, modelType: "LLM" } const newModelId = await ad4mClient.ai.addModel(newModelInput) @@ -239,43 +334,24 @@ export default function aiTests(testContext: TestContext) { expect(response2).to.be.a('string') expect(response2.toLowerCase()).to.include('hello') - // Clean up new model - await ad4mClient.ai.removeModel(newModelId) - // Clean up await ad4mClient.ai.removeTask(task.taskId) - await ad4mClient.ai.removeModel(modelId) + await ad4mClient.ai.removeModel(newModelId) }) it.skip('can do Tasks CRUD', async() => { const ad4mClient = testContext.ad4mClient! - const llamaDescription: ModelInput = { - name: "Llama tiny", - local: { - fileName: "llama_tiny", - tokenizerSource: "test_tokenizer.json", - modelParameters: JSON.stringify({ param1: "value1", param2: "value2" }) - }, - modelType: "LLM" - } - let llamaId = await ad4mClient.ai.addModel(llamaDescription) - // Wait for model to be loaded - let status; - do { - status = await ad4mClient.ai.modelLoadingStatus(llamaId); - await new Promise(resolve => setTimeout(resolve, 1000)); // Wait 1 second between checks - } while (status.progress < 100); // Add a task const newTask = await ad4mClient.ai.addTask( "test-name", - llamaId, + testModelId, "This is a test system prompt", [{ input: "Test input", output: "Test output" }] ); expect(newTask).to.have.property('taskId'); expect(newTask.name).to.equal('test-name'); - expect(newTask.modelId).to.equal(llamaId); + expect(newTask.modelId).to.equal(testModelId); expect(newTask.systemPrompt).to.equal("This is a test system prompt"); expect(newTask.promptExamples).to.deep.equal([{ input: "Test input", output: "Test output" }]); @@ -306,28 +382,11 @@ export default function aiTests(testContext: TestContext) { it.skip('can prompt a task', async () => { const ad4mClient = testContext.ad4mClient! 
- const llamaDescription: ModelInput = { - name: "Llama tiny", - local: { - fileName: "llama_tiny", - tokenizerSource: "test_tokenizer.json", - modelParameters: JSON.stringify({ param1: "value1", param2: "value2" }) - }, - modelType: "LLM" - } - let llamaId = await ad4mClient.ai.addModel(llamaDescription) - - // Wait for model to be loaded - let status; - do { - status = await ad4mClient.ai.modelLoadingStatus(llamaId); - await new Promise(resolve => setTimeout(resolve, 1000)); // Wait 1 second between checks - } while (status.progress < 100); // Create a new task const newTask = await ad4mClient.ai.addTask( "test-name", - llamaId, + testModelId, "You are inside a test. Please ALWAYS respond with 'works', plus something else.", [ { input: "What's the capital of France?", output: "works. Also that is Paris" }, @@ -361,7 +420,7 @@ export default function aiTests(testContext: TestContext) { // Create a new task const newTask = await ad4mClient.ai.addTask( "test-name", - "llama", + testModelId, "You are inside a test. Please respond with a short, unique message each time.", [ { input: "Test long 1", output: "This is a much longer response that includes various details. It talks about the weather being sunny, the importance of staying hydrated, and even mentions a recipe for chocolate chip cookies. The response goes on to discuss the benefits of regular exercise, the plot of a popular novel, and concludes with a fun fact about the migration patterns of monarch butterflies." }, diff --git a/tests/js/tests/perspective.ts b/tests/js/tests/perspective.ts index 806abd849..744777950 100644 --- a/tests/js/tests/perspective.ts +++ b/tests/js/tests/perspective.ts @@ -130,6 +130,28 @@ export default function perspectiveTests(testContext: TestContext) { expect(linksPostMutation.length).to.equal(2); }) + it(`doesn't error when duplicate entries passed to removeLinks`, async () => { + const ad4mClient = testContext.ad4mClient!; + const perspective = await ad4mClient.perspective.add('test-duplicate-link-removal'); + expect(perspective.name).to.equal('test-duplicate-link-removal'); + + // create link + const link = { source: 'root', predicate: 'p', target: 'abc' }; + const addLink = await perspective.add(link); + expect(addLink.data.target).to.equal("abc"); + + // get link expression + const linkExpression = (await perspective.get(new LinkQuery(link)))[0]; + expect(linkExpression.data.target).to.equal("abc"); + + // attempt to remove link twice (currently errors and prevents further execution of code) + await perspective.removeLinks([linkExpression, linkExpression]) + + // check link is removed + const links = await perspective.get(new LinkQuery(link)); + expect(links.length).to.equal(0); + }) + it('test local perspective links - time query', async () => { const ad4mClient = testContext.ad4mClient! 
diff --git a/ui/package.json b/ui/package.json index a7ab4dc7f..eb8f7d0f4 100644 --- a/ui/package.json +++ b/ui/package.json @@ -82,5 +82,5 @@ "resolutions": { "react-error-overlay": "6.0.9" }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/ui/src-tauri/Cargo.toml b/ui/src-tauri/Cargo.toml index 6e80a3ea8..e54eec983 100644 --- a/ui/src-tauri/Cargo.toml +++ b/ui/src-tauri/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ad4m-launcher" -version = "0.10.1-rc7" +version = "0.10.1-rc8" description = "Administration of ad4m services" authors = ["Kaichao Sun"] license = "" diff --git a/ui/src-tauri/tauri.conf.json b/ui/src-tauri/tauri.conf.json index 8929bce0c..8cfdb91c1 100644 --- a/ui/src-tauri/tauri.conf.json +++ b/ui/src-tauri/tauri.conf.json @@ -72,5 +72,5 @@ } } }, - "version": "0.10.1-rc7" + "version": "0.10.1-rc8" } diff --git a/ui/src/components/AI.tsx b/ui/src/components/AI.tsx index 1ec44da0e..381b191a9 100644 --- a/ui/src/components/AI.tsx +++ b/ui/src/components/AI.tsx @@ -66,6 +66,21 @@ const AI = () => { setModels(newModels); } + function deleteTask(modelId: string, taskId: string) { + const newModels = modelsRef.current.map((model) => { + if (model.id === modelId) { + client!.ai.removeTask(taskId) + + model.tasks = model.tasks.filter((task: any) => { + return (task.taskId !== taskId) + }); + } + return model; + }); + modelsRef.current = newModels; + setModels(newModels); + } + useEffect(() => { if (client) getData(); }, [client, getData]); @@ -148,6 +163,7 @@ const AI = () => { removeModel={() => removeModel(model)} setDefaultModel={() => setDefaultModel(model)} toggleTask={toggleTask} + deleteTask={deleteTask} /> ))} diff --git a/ui/src/components/Login.tsx b/ui/src/components/Login.tsx index 3d312772c..11f1086c5 100644 --- a/ui/src/components/Login.tsx +++ b/ui/src/components/Login.tsx @@ -155,15 +155,13 @@ const Login = () => { } async function saveModels() { + let whisperModel = "whisper_small"; // add llm model if (aiMode !== "None") { const llm = { name: "LLM Model 1", modelType: "LLM" } as ModelInput; if (aiMode === "Local") { - llm.local = { - fileName: "solar_10_7b_instruct", - tokenizerSource: "", - modelParameters: "", - }; + llm.local = { fileName: "Qwen2.5.1-Coder-7B-Instruct" }; + whisperModel = "whisper_large_v3_turbo_quantized"; } else { llm.api = { baseUrl: apiUrl, @@ -179,23 +177,16 @@ const Login = () => { // add embedding model client!.ai.addModel({ name: "bert", - local: { - fileName: "bert", - tokenizerSource: "", - modelParameters: "", - }, + local: { fileName: "bert" }, modelType: "EMBEDDING", }); - // add transcription model + // add medium whisper model client!.ai.addModel({ - name: "Transcription Model 1", - local: { - fileName: "whisper", - tokenizerSource: "", - modelParameters: "", - }, + name: "Whisper", + local: { fileName: whisperModel }, modelType: "TRANSCRIPTION", }); + setCurrentIndex(6); } @@ -493,39 +484,6 @@ const Login = () => { Is your computer capabale of running Large Language Models locally? - - Regardless of your choice here, we will always download and use - small AI models (such as{" "} - - open("https://huggingface.co/openai/whisper-small") - } - style={{ cursor: "pointer" }} - > - Whisper small - {" "} - and an{" "} - - open( - "https://huggingface.co/Snowflake/snowflake-arctic-embed-xs" - ) - } - style={{ cursor: "pointer" }} - > - Embedding model - - ) to handle basic tasks on all devices. -

-              When it comes to LLMs, it depends on you having either an Apple Silicon mac (M1 or better) or an nVidia GPU (with enough vRAM).
-              Alternatively, you can configure ADAM to out-source LLM tasks to a remote API. If you unsure, you can select "None" now and add, remove or change model settings later-on in the AI tab.
@@ -594,17 +552,43 @@ const Login = () => { style={{ marginTop: 30, maxWidth: 350 }} > - This will download{" "} - - open( - "https://huggingface.co/TheBloke/SOLAR-10.7B-Instruct-v1.0-GGUF" - ) - } - style={{ cursor: "pointer" }} - > - SOLAR 10.7b instruct - + This will download +

+ + open( + "https://huggingface.co/bartowski/Qwen2.5.1-Coder-7B-Instruct-GGUF" + ) + } + style={{ cursor: "pointer" }} + > + Qwen2.5 Coder 7B Instruct (4.68GB) + +

+ and +

+ + open( + "https://huggingface.co/openai/whisper-large-v3-turbo" + ) + } + style={{ cursor: "pointer" }} + >Whisper large v3 turbo (809MB) +

+ and +

+ + open( + "https://huggingface.co/Snowflake/snowflake-arctic-embed-xs" + ) + } + style={{ cursor: "pointer" }} + > + Bert Embedding model (90MB) + +

)} @@ -689,6 +673,33 @@ const Login = () => { )} + + This will still download +

+ + open( + "https://huggingface.co/openai/whisper-small" + ) + } + style={{ cursor: "pointer" }} + >Whisper small (244MB) +

+ and +

+ + open( + "https://huggingface.co/Snowflake/snowflake-arctic-embed-xs" + ) + } + style={{ cursor: "pointer" }} + > + Bert Embedding model (90MB) + +

+
+ {apiValid && ( @@ -729,6 +740,7 @@ const Login = () => { )} + )} {(!apiModelValid || !apiValid) && ( @@ -767,6 +779,33 @@ const Login = () => { Selecting None here and not having any LLM configured might result in new Synergy features not working in Flux... + + + This will still download +

+ + open( + "https://huggingface.co/openai/whisper-small" + ) + } + style={{ cursor: "pointer" }} + >Whisper small (244MB) +

+ and +

+ + open( + "https://huggingface.co/Snowflake/snowflake-arctic-embed-xs" + ) + } + style={{ cursor: "pointer" }} + > + Bert Embedding model (90MB) + +

+
)} diff --git a/ui/src/components/ModelCard.tsx b/ui/src/components/ModelCard.tsx index b5227f9da..50e299a9f 100644 --- a/ui/src/components/ModelCard.tsx +++ b/ui/src/components/ModelCard.tsx @@ -7,8 +7,9 @@ export default function ModelCard(props: { removeModel: () => void; setDefaultModel: () => void; toggleTask: (modelId: string, taskId: string) => void; + deleteTask: (modelId: string, taskId: string) => void; }) { - const { model, editModel, removeModel, setDefaultModel, toggleTask } = props; + const { model, editModel, removeModel, setDefaultModel, toggleTask, deleteTask } = props; const { id, name, @@ -77,6 +78,11 @@ export default function ModelCard(props: { )} + {modelType == "TRANSCRIPTION" && ( + + + + )} @@ -157,6 +163,15 @@ export default function ModelCard(props: { name={`chevron-${task.collapsed ? "down" : "up"}`} /> + deleteTask(id, task.taskId)} + > + Delete + + ))} diff --git a/ui/src/components/ModelModal.tsx b/ui/src/components/ModelModal.tsx index 1a5374c06..040b421b8 100644 --- a/ui/src/components/ModelModal.tsx +++ b/ui/src/components/ModelModal.tsx @@ -6,7 +6,12 @@ import "../index.css"; const AITypes = ["LLM", "EMBEDDING", "TRANSCRIPTION"]; const llmModels = [ "External API", - // "tiny_llama_1_1b", + "Custom Hugging Face Model", + "Qwen2.5.1-Coder-7B-Instruct", + "deepseek_r1_distill_qwen_1_5b", + "deepseek_r1_distill_qwen_7b", + "deepseek_r1_distill_qwen_14b", + "deepseek_r1_distill_llama_8b", "mistral_7b", "mistral_7b_instruct", "mistral_7b_instruct_2", @@ -24,7 +29,26 @@ const llmModels = [ "llama_34b_code", "llama_70b", ]; -const transcriptionModels = ["whisper"]; +const transcriptionModels = [ + "whisper_tiny", + "whisper_tiny_quantized", + "whisper_tiny_en", + "whisper_tiny_en_quantized", + "whisper_base", + "whisper_base_en", + "whisper_small", + "whisper_small_en", + "whisper_medium", + "whisper_medium_en", + "whisper_medium_en_quantized_distil", + "whisper_large", + "whisper_large_v2", + "whisper_distil_medium_en", + "whisper_distil_large_v2", + "whisper_distil_large_v3", + "whisper_distil_large_v3_quantized", + "whisper_large_v3_turbo_quantized" +]; const embeddingModels = ["bert"]; export default function ModelModal(props: { close: () => void; oldModel?: any }) { @@ -50,6 +74,17 @@ export default function ModelModal(props: { close: () => void; oldModel?: any }) const [apiModels, setApiModels] = useState([]); const apiUrlRef = useRef("https://api.openai.com/v1"); const apiKeyRef = useRef(""); + const [useCustomTokenizer, setUseCustomTokenizer] = useState(false); + const [customHfModel, setCustomHfModel] = useState({ + huggingfaceRepo: "", + revision: "main", + fileName: "", + tokenizerSource: { + repo: "", + revision: "main", + fileName: "" + } + }); function closeMenu(menuId: string) { const menu = document.getElementById(menuId); @@ -160,19 +195,26 @@ export default function ModelModal(props: { close: () => void; oldModel?: any }) apiType: "OPEN_AI", model: apiModel, }; + } else if (newModel === "Custom Hugging Face Model") { + model.local = { + fileName: customHfModel.fileName, + huggingfaceRepo: customHfModel.huggingfaceRepo, + revision: customHfModel.revision, + tokenizerSource: useCustomTokenizer ? 
customHfModel.tokenizerSource : undefined + }; } else { model.local = { fileName: newModel, - tokenizerSource: "", - modelParameters: "", }; } if (oldModel) client!.ai.updateModel(oldModel.id, model); else { const newModelId = await client!.ai.addModel(model); // if no default LLM set, mark new model as default - const defaultLLM = await client!.ai.getDefaultModel("LLM"); - if (!defaultLLM) client!.ai.setDefaultModel("LLM", newModelId); + if (newModelType === "LLM") { + const defaultLLM = await client!.ai.getDefaultModel("LLM"); + if (!defaultLLM) client!.ai.setDefaultModel("LLM", newModelId); + } } close(); } @@ -186,13 +228,34 @@ export default function ModelModal(props: { close: () => void; oldModel?: any }) if (oldModel.modelType === "LLM") { setNewModels(llmModels); - setNewModel(oldModel.api ? "External API" : oldModel.local.fileName); + if (oldModel.api) { + setNewModel("External API"); + setApiUrl(oldModel.api.baseUrl); + apiUrlRef.current = oldModel.api.baseUrl; + setApiKey(oldModel.api.apiKey); + apiKeyRef.current = oldModel.api.apiKey; + } else if (oldModel.local?.huggingfaceRepo) { + setNewModel("Custom Hugging Face Model"); + setCustomHfModel({ + huggingfaceRepo: oldModel.local.huggingfaceRepo, + revision: oldModel.local.revision || "main", + fileName: oldModel.local.fileName, + tokenizerSource: oldModel.local.tokenizerSource || { + repo: "", + revision: "main", + fileName: "" + } + }); + setUseCustomTokenizer(!!oldModel.local.tokenizerSource); + } else { + setNewModel(oldModel.local.fileName); + } } else if (oldModel.modelType === "EMBEDDING") { setNewModels(embeddingModels); setNewModel(oldModel.local.fileName); } else { setNewModels(transcriptionModels); - setNewModel(oldModel.local.fileName); + setNewModel(oldModel.local.fileName || "whisper_small"); } if (oldModel.api) { @@ -256,7 +319,7 @@ export default function ModelModal(props: { close: () => void; oldModel?: any }) setNewModel("bert"); } else { setNewModels(transcriptionModels); - setNewModel("whisper"); + setNewModel("whisper_small"); } closeMenu("ai-types"); }} @@ -404,6 +467,140 @@ export default function ModelModal(props: { close: () => void; oldModel?: any }) )} )} + + {newModel === "Custom Hugging Face Model" && ( + + + Note: The model file must be a GGUF file format, which typically includes the tokenizer. + + + + + Repository: + + setCustomHfModel({ + ...customHfModel, + huggingfaceRepo: e.target.value + })} + style={{ width: "100%" }} + /> + + + + + Branch/Revision: + + setCustomHfModel({ + ...customHfModel, + revision: e.target.value + })} + style={{ width: "100%" }} + /> + + + + + Filename: + + setCustomHfModel({ + ...customHfModel, + fileName: e.target.value + })} + style={{ width: "100%" }} + /> + + + + setUseCustomTokenizer(e.target.checked)} + > + Use Custom Tokenizer (Optional) + + + + {useCustomTokenizer && ( + + + + + Tokenizer Repo: + + setCustomHfModel({ + ...customHfModel, + tokenizerSource: { + ...customHfModel.tokenizerSource, + repo: e.target.value + } + })} + style={{ width: "100%" }} + /> + + + + + Tokenizer Branch: + + setCustomHfModel({ + ...customHfModel, + tokenizerSource: { + ...customHfModel.tokenizerSource, + revision: e.target.value + } + })} + style={{ width: "100%" }} + /> + + + + + Tokenizer File: + + setCustomHfModel({ + ...customHfModel, + tokenizerSource: { + ...customHfModel.tokenizerSource, + fileName: e.target.value + } + })} + style={{ width: "100%" }} + /> + + + + )} + + )} {newModel === "External API" && (!apiModelValid || !apiValid) ? (