Skip to content

Commit 32d97a2

Browse files
committed
feat: add tools/ai/wasmtrans
Updates #1454 Signed-off-by: Aofei Sheng <[email protected]>
1 parent d94fa92 commit 32d97a2

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed

tools/ai/wasmtrans/promise.go

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//go:build js && wasm
2+
3+
package wasmtrans
4+
5+
import (
6+
"context"
7+
"errors"
8+
"fmt"
9+
"syscall/js"
10+
)
11+
12+
// awaitPromise waits for a JavaScript Promise to resolve or reject.
13+
func awaitPromise(ctx context.Context, promise js.Value) (js.Value, error) {
14+
if promise.IsUndefined() || promise.IsNull() {
15+
return js.Undefined(), errors.New("promise is undefined or null")
16+
}
17+
if promise.Type() != js.TypeObject || promise.Get("then").Type() != js.TypeFunction {
18+
return js.Undefined(), errors.New("value is not a Promise")
19+
}
20+
21+
resultChan := make(chan js.Value, 1)
22+
then := js.FuncOf(func(this js.Value, args []js.Value) any {
23+
result := js.Undefined()
24+
if len(args) > 0 {
25+
result = args[0]
26+
}
27+
resultChan <- result
28+
return nil
29+
})
30+
defer then.Release()
31+
32+
errChan := make(chan error, 1)
33+
catch := js.FuncOf(func(this js.Value, args []js.Value) any {
34+
errMsg := "promise rejected"
35+
if len(args) > 0 {
36+
errVal := args[0]
37+
if errVal.Type() == js.TypeObject && errVal.Get("message").Type() == js.TypeString {
38+
errMsg = fmt.Sprintf("promise rejected: %s", errVal.Get("message"))
39+
} else if errVal.Type() == js.TypeString {
40+
errMsg = fmt.Sprintf("promise rejected: %s", errVal)
41+
} else {
42+
errMsg = fmt.Sprintf("promise rejected: %v", errVal)
43+
}
44+
}
45+
errChan <- errors.New(errMsg)
46+
return nil
47+
})
48+
defer catch.Release()
49+
50+
promise.Call("then", then).Call("catch", catch)
51+
select {
52+
case result := <-resultChan:
53+
return result, nil
54+
case err := <-errChan:
55+
return js.Undefined(), err
56+
case <-ctx.Done():
57+
return js.Undefined(), fmt.Errorf("context cancelled while waiting for promise: %w", ctx.Err())
58+
}
59+
}

tools/ai/wasmtrans/wasmtrans.go

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
//go:build js && wasm
2+
3+
// Package wasmtrans provides a Transport implementation for AI interactions
4+
// within a WebAssembly (Wasm) environment, typically running in a browser.
5+
package wasmtrans
6+
7+
import (
8+
"context"
9+
"encoding/json"
10+
"fmt"
11+
"syscall/js"
12+
13+
"github.com/goplus/builder/tools/ai"
14+
)
15+
16+
// wasmTransport implements [ai.Transport] using JavaScript's fetch API.
17+
type wasmTransport struct {
18+
// endpoint is the URL for the AI interaction API.
19+
endpoint string
20+
21+
// tokenProvider is a function that returns the auth token (without "Bearer ").
22+
// It's called before each request. If it returns "", no auth header is sent.
23+
tokenProvider func() string
24+
}
25+
26+
// Option is a function type for configuring the [wasmTransport].
27+
type Option func(*wasmTransport)
28+
29+
// WithEndpoint sets a custom endpoint for the AI interaction API.
30+
func WithEndpoint(endpoint string) Option {
31+
return func(t *wasmTransport) {
32+
t.endpoint = endpoint
33+
}
34+
}
35+
36+
// WithTokenProvider sets a function that provides the Bearer token for
37+
// Authorization. The provider function will be called before each request to
38+
// get the current token. If the provider returns an empty string, no
39+
// Authorization header will be sent.
40+
func WithTokenProvider(provider func() string) Option {
41+
return func(t *wasmTransport) {
42+
t.tokenProvider = provider
43+
}
44+
}
45+
46+
// New creates a new [ai.Transport] suitable for Wasm environments. It uses
47+
// JavaScript interop (syscall/js) to make network requests. By default, it
48+
// uses "/api/ai/interaction" endpoint and sends no Authorization token.
49+
func New(opts ...Option) ai.Transport {
50+
t := &wasmTransport{
51+
endpoint: "/api/ai/interaction",
52+
tokenProvider: func() string { return "" },
53+
}
54+
for _, opt := range opts {
55+
opt(t)
56+
}
57+
return t
58+
}
59+
60+
// Interact implements [ai.Transport].
61+
func (t *wasmTransport) Interact(ctx context.Context, req ai.Request) (ai.Response, error) {
62+
reqBody, err := json.Marshal(req)
63+
if err != nil {
64+
return ai.Response{}, fmt.Errorf("failed to marshal request: %w", err)
65+
}
66+
67+
headers := map[string]any{
68+
"Content-Type": "application/json",
69+
}
70+
if t.tokenProvider != nil {
71+
if token := t.tokenProvider(); token != "" {
72+
headers["Authorization"] = "Bearer " + token
73+
}
74+
}
75+
76+
jsAbortController := js.Global().Get("AbortController").New()
77+
defer context.AfterFunc(ctx, func() {
78+
jsAbortController.Call("abort")
79+
})()
80+
jsAbortSignal := jsAbortController.Get("signal")
81+
82+
jsResp, err := awaitPromise(ctx, js.Global().Call("fetch", t.endpoint+"/turn", map[string]any{
83+
"method": "POST",
84+
"headers": headers,
85+
"body": string(reqBody),
86+
"signal": jsAbortSignal,
87+
}))
88+
if err != nil {
89+
return ai.Response{}, fmt.Errorf("failed to fetch: %w", err)
90+
}
91+
92+
if !jsResp.Get("ok").Bool() {
93+
status := jsResp.Get("status").Int()
94+
statusText := jsResp.Get("statusText").String()
95+
96+
bodyPromise := jsResp.Call("text")
97+
bodyTextVal, bodyErr := awaitPromise(ctx, bodyPromise)
98+
if bodyErr != nil {
99+
return ai.Response{}, fmt.Errorf("failed to fetch with status %d %s (and failed to read error body: %w)", status, statusText, bodyErr)
100+
}
101+
102+
bodyText := bodyTextVal.String()
103+
return ai.Response{}, fmt.Errorf("failed to fetch with status %d %s: %s", status, statusText, bodyText)
104+
}
105+
jsJSON, err := awaitPromise(ctx, jsResp.Call("json"))
106+
if err != nil {
107+
return ai.Response{}, fmt.Errorf("failed to process json response: %w", err)
108+
}
109+
jsonString := js.Global().Get("JSON").Call("stringify", jsJSON).String()
110+
111+
var aiResp ai.Response
112+
if err := json.Unmarshal([]byte(jsonString), &aiResp); err != nil {
113+
return ai.Response{}, fmt.Errorf("failed to unmarshal response json: %w", err)
114+
}
115+
return aiResp, nil
116+
}

0 commit comments

Comments
 (0)