Skip to content

Commit ef72225

Browse files
committedMar 5, 2025·
feat: add more AgentOption for react agent and a converter function to compose.Option
Change-Id: I3a069cd40b5fff9d701b35c4f1d383c303e2387f
1 parent 2d75c5c commit ef72225

File tree

12 files changed

+440
-48
lines changed

12 files changed

+440
-48
lines changed
 

‎_typos.toml

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Invokable = "Invokable"
66
invokable = "invokable"
77
InvokableLambda = "InvokableLambda"
88
InvokableRun = "InvokableRun"
9+
typ = "typ"
910

1011
[files]
1112
extend-exclude = ["go.mod", "go.sum", "check_branch_name.sh"]

‎components/tool/option.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
package tool
1818

1919
// Option defines call option for InvokableTool or StreamableTool component, which is part of component interface signature.
20-
// Each tool implementation could define its own options struct and option funcs within its own package,
21-
// then wrap the impl specific option funcs into this type, before passing to InvokableRun or StreamableRun.
20+
// Each tool implementation could define its own options struct and option functions within its own package,
21+
// then wrap the impl specific option functions into this type, before passing to InvokableRun or StreamableRun.
2222
type Option struct {
2323
implSpecificOptFn any
2424
}

‎compose/graph_call_options.go

+18
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,24 @@ func (o Option) DesignateNodeWithPath(path ...*NodePath) Option {
8080
return o
8181
}
8282

83+
// DesignateNodePrependPath prepends the prefix to the path of the node(s) to which the option will be applied to.
84+
// Useful when you already have an Option designated to a graph's node, and now you want to add this graph as a subgraph.
85+
// e.g.
86+
// Your subgraph has a Node with key "A", and your subgraph's NodeKey is "sub_graph", you can specify option to A using:
87+
//
88+
// option := WithCallbacks(...).DesignateNode("A").DesignateNodePrependPath("sub_graph")
89+
// Note: as an End User, you probably don't need to use this method, as DesignateNodeWithPath will be sufficient in most use cases.
90+
// Note: as a Flow author, if you define your own Option type, and at the same time your flow can be exported to graph and added as GraphNode,
91+
// you can use this method to prepend your Option's designated path with the GraphNode's path.
92+
func (o Option) DesignateNodePrependPath(prefix *NodePath) Option {
93+
for i := range o.paths {
94+
p := o.paths[i]
95+
p.path = append(prefix.path, p.path...)
96+
}
97+
98+
return o
99+
}
100+
83101
// WithEmbeddingOption is a functional option type for embedding component.
84102
// e.g.
85103
//

‎flow/agent/agent_option.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ package agent
1919
import "github.com/cloudwego/eino/compose"
2020

2121
// AgentOption is the common option type for various agent and multi-agent implementations.
22-
// For options intended to use with underlying graph or components, use WithComposeOptions to specify.
2322
// For options intended to use with particular agent/multi-agent implementations, use WrapImplSpecificOptFn to specify.
2423
type AgentOption struct {
2524
implSpecificOptFn any
2625
composeOptions []compose.Option
2726
}
2827

2928
// GetComposeOptions returns all compose options from the given agent options.
29+
// Deprecated
3030
func GetComposeOptions(opts ...AgentOption) []compose.Option {
3131
var result []compose.Option
3232
for _, opt := range opts {
@@ -37,6 +37,7 @@ func GetComposeOptions(opts ...AgentOption) []compose.Option {
3737
}
3838

3939
// WithComposeOptions returns an agent option that specifies compose options.
40+
// Deprecated: use option functions defined by each agent flow implementation instead.
4041
func WithComposeOptions(opts ...compose.Option) AgentOption {
4142
return AgentOption{
4243
composeOptions: opts,

‎flow/agent/multiagent/host/callback.go

+73-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ type HandOffInfo struct {
3939
}
4040

4141
// ConvertCallbackHandlers converts []host.MultiAgentCallback to callbacks.Handler.
42+
// Deprecated: use ConvertOptions to convert agent.AgentOption to compose.Option when adding MultiAgent's Graph to another Graph.
4243
func ConvertCallbackHandlers(handlers ...MultiAgentCallback) callbacks.Handler {
4344
onChatModelEnd := func(ctx context.Context, info *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
4445
if output == nil || info == nil {
@@ -121,5 +122,76 @@ func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler {
121122
}
122123

123124
handlers := agentOptions.agentCallbacks
124-
return ConvertCallbackHandlers(handlers...)
125+
126+
onChatModelEnd := func(ctx context.Context, info *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
127+
if output == nil || info == nil {
128+
return ctx
129+
}
130+
131+
msg := output.Message
132+
if msg == nil || msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 {
133+
return ctx
134+
}
135+
136+
agentName := msg.ToolCalls[0].Function.Name
137+
argument := msg.ToolCalls[0].Function.Arguments
138+
139+
for _, cb := range handlers {
140+
ctx = cb.OnHandOff(ctx, &HandOffInfo{
141+
ToAgentName: agentName,
142+
Argument: argument,
143+
})
144+
}
145+
146+
return ctx
147+
}
148+
149+
onChatModelEndWithStreamOutput := func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context {
150+
if output == nil || info == nil {
151+
return ctx
152+
}
153+
154+
defer output.Close()
155+
156+
var msgs []*schema.Message
157+
for {
158+
oneOutput, err := output.Recv()
159+
if err == io.EOF {
160+
break
161+
}
162+
if err != nil {
163+
return ctx
164+
}
165+
166+
msg := oneOutput.Message
167+
if msg == nil {
168+
continue
169+
}
170+
171+
msgs = append(msgs, msg)
172+
}
173+
174+
msg, err := schema.ConcatMessages(msgs)
175+
if err != nil {
176+
return ctx
177+
}
178+
179+
if msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 {
180+
return ctx
181+
}
182+
183+
for _, cb := range handlers {
184+
ctx = cb.OnHandOff(ctx, &HandOffInfo{
185+
ToAgentName: msg.ToolCalls[0].Function.Name,
186+
Argument: msg.ToolCalls[0].Function.Arguments,
187+
})
188+
}
189+
190+
return ctx
191+
}
192+
193+
return template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{
194+
OnEnd: onChatModelEnd,
195+
OnEndWithStreamOutput: onChatModelEndWithStreamOutput,
196+
}).Handler()
125197
}

‎flow/agent/multiagent/host/compose_test.go

+97-8
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,15 @@ import (
2424
"github.com/stretchr/testify/assert"
2525
"go.uber.org/mock/gomock"
2626

27+
"github.com/cloudwego/eino/callbacks"
28+
chatmodel "github.com/cloudwego/eino/components/model"
2729
"github.com/cloudwego/eino/components/prompt"
2830
"github.com/cloudwego/eino/compose"
2931
"github.com/cloudwego/eino/flow/agent"
3032
"github.com/cloudwego/eino/internal/generic"
3133
"github.com/cloudwego/eino/internal/mock/components/model"
3234
"github.com/cloudwego/eino/schema"
35+
template "github.com/cloudwego/eino/utils/callbacks"
3336
)
3437

3538
func TestHostMultiAgent(t *testing.T) {
@@ -48,6 +51,14 @@ func TestHostMultiAgent(t *testing.T) {
4851

4952
specialist2 := &Specialist{
5053
Invokable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) {
54+
agentOpts := agent.GetImplSpecificOptions(&specialist2Options{}, opts...)
55+
if agentOpts.mockOutput != nil {
56+
return &schema.Message{
57+
Role: schema.Assistant,
58+
Content: *agentOpts.mockOutput,
59+
}, nil
60+
}
61+
5162
return &schema.Message{
5263
Role: schema.Assistant,
5364
Content: "specialist2 invoke answer",
@@ -92,11 +103,18 @@ func TestHostMultiAgent(t *testing.T) {
92103
Content: "direct answer",
93104
}
94105

95-
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(directAnswerMsg, nil).Times(1)
106+
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
107+
DoAndReturn(func(_ context.Context, input []*schema.Message, opts ...chatmodel.Option) (*schema.Message, error) {
108+
modelOpts := chatmodel.GetCommonOptions(&chatmodel.Options{}, opts...)
109+
assert.Equal(t, *modelOpts.Temperature, float32(0.7))
110+
return directAnswerMsg, nil
111+
}).
112+
Times(1)
96113

97114
mockCallback := &mockAgentCallback{}
98115

99-
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
116+
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback),
117+
WithAgentModelOptions(hostMA.HostNodeKey(), chatmodel.WithTemperature(0.7)))
100118
assert.NoError(t, err)
101119
assert.Equal(t, "direct answer", out.Content)
102120
assert.Empty(t, mockCallback.infos)
@@ -164,11 +182,18 @@ func TestHostMultiAgent(t *testing.T) {
164182
}
165183

166184
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
167-
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1)
185+
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
186+
DoAndReturn(func(_ context.Context, input []*schema.Message, opts ...chatmodel.Option) (*schema.Message, error) {
187+
modelOpts := chatmodel.GetCommonOptions(&chatmodel.Options{}, opts...)
188+
assert.Equal(t, *modelOpts.Temperature, float32(0.7))
189+
return specialistMsg, nil
190+
}).
191+
Times(1)
168192

169193
mockCallback := &mockAgentCallback{}
170194

171-
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
195+
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback),
196+
WithAgentModelOptions(specialist1.Name, chatmodel.WithTemperature(0.7)))
172197
assert.NoError(t, err)
173198
assert.Equal(t, "specialist 1 answer", out.Content)
174199
assert.Equal(t, []*HandOffInfo{
@@ -379,16 +404,41 @@ func TestHostMultiAgent(t *testing.T) {
379404
},
380405
}
381406

382-
specialistMsg := &schema.Message{
407+
specialist1Msg := &schema.Message{
383408
Role: schema.Assistant,
384409
Content: "Beijing",
385410
}
386411

387-
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
388-
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1)
412+
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(2)
413+
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
414+
DoAndReturn(func(_ context.Context, input []*schema.Message, opts ...chatmodel.Option) (*schema.Message, error) {
415+
modelOpts := chatmodel.GetCommonOptions(&chatmodel.Options{}, opts...)
416+
assert.Equal(t, *modelOpts.Temperature, float32(0.7))
417+
return specialist1Msg, nil
418+
}).
419+
Times(1)
389420

390421
mockCallback := &mockAgentCallback{}
391422

423+
var hostOutput, specialist1Output, specialist2Output string
424+
hostModelCallback := template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{
425+
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *chatmodel.CallbackOutput) context.Context {
426+
hostOutput = output.Message.ToolCalls[0].Function.Name
427+
return ctx
428+
},
429+
}).Handler()
430+
specialist1ModelCallback := template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{
431+
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *chatmodel.CallbackOutput) context.Context {
432+
specialist1Output = output.Message.Content
433+
return ctx
434+
},
435+
}).Handler()
436+
specialist2LambdaCallback := template.NewHandlerHelper().Lambda(callbacks.NewHandlerBuilder().OnEndFn(
437+
func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
438+
specialist2Output = output.(*schema.Message).Content
439+
return ctx
440+
}).Build()).Handler()
441+
392442
hostMA, err := NewMultiAgent(ctx, &MultiAgentConfig{
393443
Host: Host{
394444
ChatModel: mockHostLLM,
@@ -409,7 +459,14 @@ func TestHostMultiAgent(t *testing.T) {
409459
Compile(ctx)
410460
assert.NoError(t, err)
411461

412-
out, err := fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, compose.WithCallbacks(ConvertCallbackHandlers(mockCallback)).DesignateNodeWithPath(compose.NewNodePath("host_ma_node", hostMA.HostNodeKey())))
462+
convertedOptions := ConvertOptions(compose.NewNodePath("host_ma_node"), WithAgentCallbacks(mockCallback),
463+
WithAgentModelOptions(specialist1.Name, chatmodel.WithTemperature(0.7)),
464+
WithAgentModelCallbacks(hostMA.HostNodeKey(), hostModelCallback),
465+
WithAgentModelCallbacks(specialist1.Name, specialist1ModelCallback),
466+
WithSpecialistLambdaCallbacks(specialist2.Name, specialist2LambdaCallback),
467+
WithSpecialistLambdaOptions(specialist2.Name, withSpecialist2MockOutput("mock_city_name")))
468+
469+
out, err := fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, convertedOptions...)
413470
assert.NoError(t, err)
414471
assert.Equal(t, "Beijing", out.Content)
415472
assert.Equal(t, []*HandOffInfo{
@@ -418,6 +475,28 @@ func TestHostMultiAgent(t *testing.T) {
418475
Argument: `{"reason": "specialist 1 is the best"}`,
419476
},
420477
}, mockCallback.infos)
478+
assert.Equal(t, hostOutput, specialist1.Name)
479+
assert.Equal(t, specialist1Output, out.Content)
480+
assert.Equal(t, specialist2Output, "")
481+
482+
handOffMsg.ToolCalls[0].Function.Name = specialist2.Name
483+
handOffMsg.ToolCalls[0].Function.Arguments = `{"reason": "specialist 2 is even better"}`
484+
485+
out, err = fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, convertedOptions...)
486+
assert.NoError(t, err)
487+
assert.Equal(t, "mock_city_name", out.Content)
488+
assert.Equal(t, []*HandOffInfo{
489+
{
490+
ToAgentName: specialist1.Name,
491+
Argument: `{"reason": "specialist 1 is the best"}`,
492+
},
493+
{
494+
ToAgentName: specialist2.Name,
495+
Argument: `{"reason": "specialist 2 is even better"}`,
496+
},
497+
}, mockCallback.infos)
498+
assert.Equal(t, hostOutput, specialist2.Name)
499+
assert.Equal(t, specialist2Output, "mock_city_name")
421500
})
422501
}
423502

@@ -429,3 +508,13 @@ func (m *mockAgentCallback) OnHandOff(ctx context.Context, info *HandOffInfo) co
429508
m.infos = append(m.infos, info)
430509
return ctx
431510
}
511+
512+
type specialist2Options struct {
513+
mockOutput *string
514+
}
515+
516+
func withSpecialist2MockOutput(mockOutput string) agent.AgentOption {
517+
return agent.WrapImplSpecificOptFn(func(o *specialist2Options) {
518+
o.mockOutput = &mockOutput
519+
})
520+
}

‎flow/agent/multiagent/host/options.go

+75-1
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,88 @@
1616

1717
package host
1818

19-
import "github.com/cloudwego/eino/flow/agent"
19+
import (
20+
"github.com/cloudwego/eino/callbacks"
21+
"github.com/cloudwego/eino/components/model"
22+
"github.com/cloudwego/eino/compose"
23+
"github.com/cloudwego/eino/flow/agent"
24+
)
2025

2126
type options struct {
2227
agentCallbacks []MultiAgentCallback
28+
composeOptions []compose.Option
2329
}
2430

2531
func WithAgentCallbacks(agentCallbacks ...MultiAgentCallback) agent.AgentOption {
2632
return agent.WrapImplSpecificOptFn(func(opts *options) {
2733
opts.agentCallbacks = append(opts.agentCallbacks, agentCallbacks...)
2834
})
2935
}
36+
37+
// WithAgentModelOptions returns an agent option that specifies model.Option for the given agent.
38+
// The given agentName should be the name of the agent in the graph.
39+
// e.g.
40+
//
41+
// if specifying model.Option for the host agent, use MultiAgent.HostNodeKey() as the agentName.
42+
// if specifying model.Option for the specialist agent, use the specialist agent's AgentMeta.Name as the agentName.
43+
func WithAgentModelOptions(agentName string, opts ...model.Option) agent.AgentOption {
44+
return agent.WrapImplSpecificOptFn(func(o *options) {
45+
o.composeOptions = append(o.composeOptions, compose.WithChatModelOption(opts...).DesignateNode(agentName))
46+
})
47+
}
48+
49+
// WithAgentModelCallbacks returns an agent option that specifies callbacks.Handler for the given agent's ChatModel.
50+
// The given agentName should be the name of the agent in the graph.
51+
// e.g.
52+
//
53+
// if specifying model.Option for the host agent, use MultiAgent.HostNodeKey() as the agentName.
54+
// if specifying model.Option for the specialist agent, use the specialist agent's AgentMeta.Name as the agentName.
55+
func WithAgentModelCallbacks(agentName string, cbs ...callbacks.Handler) agent.AgentOption {
56+
return agent.WrapImplSpecificOptFn(func(o *options) {
57+
o.composeOptions = append(o.composeOptions, compose.WithCallbacks(cbs...).DesignateNode(agentName))
58+
})
59+
}
60+
61+
// WithSpecialistLambdaOptions returns an agent option that specifies agent.AgentOption for the given specialist's Lambda.
62+
// The given specialistName should be the name of the specialist in the graph.
63+
func WithSpecialistLambdaOptions(specialistName string, opts ...agent.AgentOption) agent.AgentOption {
64+
anyOpts := make([]any, len(opts))
65+
for i, opt := range opts {
66+
anyOpts[i] = opt
67+
}
68+
return agent.WrapImplSpecificOptFn(func(o *options) {
69+
o.composeOptions = append(o.composeOptions, compose.WithLambdaOption(anyOpts...).DesignateNode(specialistName))
70+
})
71+
}
72+
73+
// WithSpecialistLambdaCallbacks returns an agent option that specifies callbacks.Handler for the given specialist's Lambda.
74+
// The given specialistName should be the name of the specialist in the graph.
75+
func WithSpecialistLambdaCallbacks(specialistName string, cbs ...callbacks.Handler) agent.AgentOption {
76+
return agent.WrapImplSpecificOptFn(func(o *options) {
77+
o.composeOptions = append(o.composeOptions, compose.WithCallbacks(cbs...).DesignateNode(specialistName))
78+
})
79+
}
80+
81+
// ConvertOptions converts agent options to compose options, and designate them to the Agent sub graph specified by nodePath.
82+
// Useful when adding MultiAgent's Graph to another Graph.
83+
// The parameter nodePath is the path to the MultiAgent sub graph within the whole Graph.
84+
// If nodePath == nil, then MultiAgent's Graph is treated as a stand-alone, top-level Graph.
85+
func ConvertOptions(nodePath *compose.NodePath, opts ...agent.AgentOption) []compose.Option {
86+
composeOpts := agent.GetImplSpecificOptions(&options{}, opts...).composeOptions
87+
if nodePath != nil {
88+
for i := range composeOpts {
89+
composeOpts[i] = composeOpts[i].DesignateNodePrependPath(nodePath)
90+
}
91+
}
92+
93+
convertedCallbackHandler := convertCallbacks(opts...)
94+
if convertedCallbackHandler != nil {
95+
callbackOpt := compose.WithCallbacks(convertedCallbackHandler).DesignateNode(defaultHostNodeKey)
96+
if nodePath != nil {
97+
return append(composeOpts, callbackOpt.DesignateNodePrependPath(nodePath))
98+
}
99+
return append(composeOpts, callbackOpt)
100+
}
101+
102+
return composeOpts
103+
}

‎flow/agent/multiagent/host/types.go

+2-16
Original file line numberDiff line numberDiff line change
@@ -38,25 +38,11 @@ type MultiAgent struct {
3838
}
3939

4040
func (ma *MultiAgent) Generate(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) {
41-
composeOptions := agent.GetComposeOptions(opts...)
42-
43-
handler := convertCallbacks(opts...)
44-
if handler != nil {
45-
composeOptions = append(composeOptions, compose.WithCallbacks(handler).DesignateNode(ma.HostNodeKey()))
46-
}
47-
48-
return ma.runnable.Invoke(ctx, input, composeOptions...)
41+
return ma.runnable.Invoke(ctx, input, ConvertOptions(nil, opts...)...)
4942
}
5043

5144
func (ma *MultiAgent) Stream(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.StreamReader[*schema.Message], error) {
52-
composeOptions := agent.GetComposeOptions(opts...)
53-
54-
handler := convertCallbacks(opts...)
55-
if handler != nil {
56-
composeOptions = append(composeOptions, compose.WithCallbacks(handler).DesignateNode(ma.HostNodeKey()))
57-
}
58-
59-
return ma.runnable.Stream(ctx, input, composeOptions...)
45+
return ma.runnable.Stream(ctx, input, ConvertOptions(nil, opts...)...)
6046
}
6147

6248
// ExportGraph exports the underlying graph from MultiAgent, along with the []compose.GraphAddNodeOpt to be used when adding this graph to another graph.

‎flow/agent/react/callback.go

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import (
2727
// callback := BuildAgentCallback(modelHandler, toolHandler)
2828
// agent, err := react.NewAgent(ctx, &AgentConfig{})
2929
// agent.Generate(ctx, input, agent.WithComposeOptions(compose.WithCallbacks(callback)))
30+
//
31+
// Deprecated: use WithModelCallbacks and WithToolCallbacks instead.
3032
func BuildAgentCallback(modelHandler *template.ModelCallbackHandler, toolHandler *template.ToolCallbackHandler) callbacks.Handler {
3133
return template.NewHandlerHelper().ChatModel(modelHandler).Tool(toolHandler).Handler()
3234
}

‎flow/agent/react/options.go

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright 2025 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package react
18+
19+
import (
20+
"github.com/cloudwego/eino/callbacks"
21+
"github.com/cloudwego/eino/components/model"
22+
"github.com/cloudwego/eino/components/tool"
23+
"github.com/cloudwego/eino/compose"
24+
"github.com/cloudwego/eino/flow/agent"
25+
)
26+
27+
type options struct {
28+
composeOptions []compose.Option
29+
}
30+
31+
// WithChatModelOptions returns an agent option that specifies model.Option for the ChatModel node.
32+
func WithChatModelOptions(opts ...model.Option) agent.AgentOption {
33+
return agent.WrapImplSpecificOptFn(func(o *options) {
34+
o.composeOptions = append(o.composeOptions, compose.WithChatModelOption(opts...))
35+
})
36+
}
37+
38+
// WithToolOptions returns an agent option that specifies tool.Option for the ToolsNode node.
39+
func WithToolOptions(opts ...tool.Option) agent.AgentOption {
40+
return agent.WrapImplSpecificOptFn(func(o *options) {
41+
o.composeOptions = append(o.composeOptions, compose.WithToolsNodeOption(compose.WithToolOption(opts...)))
42+
})
43+
}
44+
45+
// WithToolList returns an agent option that specifies tool list for the ToolsNode node, overriding default tool list.
46+
func WithToolList(tool ...tool.BaseTool) agent.AgentOption {
47+
return agent.WrapImplSpecificOptFn(func(o *options) {
48+
o.composeOptions = append(o.composeOptions, compose.WithToolsNodeOption(compose.WithToolList(tool...)))
49+
})
50+
}
51+
52+
// WithRuntimeMaxSteps returns an agent option that specifies max steps for the agent, overriding default configuration.
53+
func WithRuntimeMaxSteps(maxSteps int) agent.AgentOption {
54+
return agent.WrapImplSpecificOptFn(func(o *options) {
55+
o.composeOptions = append(o.composeOptions, compose.WithRuntimeMaxSteps(maxSteps))
56+
})
57+
}
58+
59+
// WithModelCallbacks returns an agent option that specifies callback handlers for the ChatModel node.
60+
func WithModelCallbacks(cbs ...callbacks.Handler) agent.AgentOption {
61+
return agent.WrapImplSpecificOptFn(func(o *options) {
62+
o.composeOptions = append(o.composeOptions, compose.WithCallbacks(cbs...).DesignateNode(nodeKeyModel))
63+
})
64+
}
65+
66+
// WithToolCallbacks returns an agent option that specifies callback handlers for the tools executed by ToolsNode node.
67+
func WithToolCallbacks(cbs ...callbacks.Handler) agent.AgentOption {
68+
return agent.WrapImplSpecificOptFn(func(o *options) {
69+
o.composeOptions = append(o.composeOptions, compose.WithCallbacks(cbs...).DesignateNode(nodeKeyTools))
70+
})
71+
}
72+
73+
// ConvertOptions converts agent options to compose options, and designate them to the Agent sub graph specified by nodePath.
74+
// Useful when adding Agent's Graph to another Graph.
75+
// The parameter nodePath is the path to the Agent sub graph within the whole Graph.
76+
// If nodePath == nil, then Agent's Graph is treated as a stand-alone, top-level Graph.
77+
func ConvertOptions(nodePath *compose.NodePath, opts ...agent.AgentOption) []compose.Option {
78+
composeOpts := agent.GetImplSpecificOptions(&options{}, opts...).composeOptions
79+
if nodePath != nil {
80+
for i := range composeOpts {
81+
composeOpts[i] = composeOpts[i].DesignateNodePrependPath(nodePath)
82+
}
83+
}
84+
85+
return composeOpts
86+
}

‎flow/agent/react/react.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ func getReturnDirectlyToolCallID(input *schema.Message, toolReturnDirectly map[s
308308

309309
// Generate generates a response from the agent.
310310
func (r *Agent) Generate(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (output *schema.Message, err error) {
311-
output, err = r.runnable.Invoke(ctx, input, agent.GetComposeOptions(opts...)...)
311+
output, err = r.runnable.Invoke(ctx, input, ConvertOptions(nil, opts...)...)
312312
if err != nil {
313313
return nil, err
314314
}
@@ -319,7 +319,7 @@ func (r *Agent) Generate(ctx context.Context, input []*schema.Message, opts ...a
319319
// Stream calls the agent and returns a stream response.
320320
func (r *Agent) Stream(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (
321321
output *schema.StreamReader[*schema.Message], err error) {
322-
res, err := r.runnable.Stream(ctx, input, agent.GetComposeOptions(opts...)...)
322+
res, err := r.runnable.Stream(ctx, input, ConvertOptions(nil, opts...)...)
323323
if err != nil {
324324
return nil, err
325325
}

‎flow/agent/react/react_test.go

+80-17
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/stretchr/testify/assert"
2929
"go.uber.org/mock/gomock"
3030

31+
"github.com/cloudwego/eino/callbacks"
3132
"github.com/cloudwego/eino/components/model"
3233
"github.com/cloudwego/eino/components/tool"
3334
"github.com/cloudwego/eino/compose"
@@ -53,6 +54,11 @@ func TestReact(t *testing.T) {
5354
times := 0
5455
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
5556
DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
57+
modelOpts := model.GetCommonOptions(&model.Options{}, opts...)
58+
if modelOpts.Temperature == nil || *modelOpts.Temperature != 0.7 {
59+
return nil, errors.New("temperature not match")
60+
}
61+
5662
times++
5763
if times <= 2 {
5864
info, _ := fakeTool.Info(ctx)
@@ -95,13 +101,10 @@ func TestReact(t *testing.T) {
95101
Role: schema.User,
96102
Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!",
97103
},
98-
}, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest)))
104+
}, WithChatModelOptions(model.WithTemperature(0.7)),
105+
WithToolOptions(withMockOutput("mock_output")))
99106
assert.Nil(t, err)
100107

101-
if out != nil {
102-
t.Log(out.Content)
103-
}
104-
105108
// test return directly
106109
times = 0
107110
a, err = NewAgent(ctx, &AgentConfig{
@@ -123,12 +126,37 @@ func TestReact(t *testing.T) {
123126
Role: schema.User,
124127
Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!",
125128
},
126-
}, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest)))
129+
}, WithChatModelOptions(model.WithTemperature(0.7)))
127130
assert.Nil(t, err)
128131

129132
if out != nil {
130133
t.Log(out.Content)
131134
}
135+
136+
// test max steps
137+
times = 0
138+
out, err = a.Generate(ctx, []*schema.Message{
139+
{
140+
Role: schema.User,
141+
Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!",
142+
},
143+
}, WithRuntimeMaxSteps(1),
144+
WithChatModelOptions(model.WithTemperature(0.7)))
145+
assert.ErrorContains(t, err, "max step")
146+
147+
// test change tools
148+
times = 0
149+
fakeStreamTool := &fakeStreamToolGreetForTest{
150+
tarCount: 20,
151+
}
152+
out, err = a.Generate(ctx, []*schema.Message{
153+
{
154+
Role: schema.User,
155+
Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!",
156+
},
157+
}, WithChatModelOptions(model.WithTemperature(0.7)),
158+
WithToolList(fakeStreamTool))
159+
assert.ErrorContains(t, err, "tool greet not found in toolsNode indexes")
132160
}
133161

134162
func TestReactStream(t *testing.T) {
@@ -224,7 +252,7 @@ func TestReactStream(t *testing.T) {
224252
Role: schema.User,
225253
Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!",
226254
},
227-
}, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest)))
255+
})
228256
if err != nil {
229257
t.Fatal(err)
230258
}
@@ -275,7 +303,7 @@ func TestReactStream(t *testing.T) {
275303
Role: schema.User,
276304
Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!",
277305
},
278-
}, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest)))
306+
})
279307
if err != nil {
280308
t.Fatal(err)
281309
}
@@ -311,7 +339,7 @@ func TestReactStream(t *testing.T) {
311339
Role: schema.User,
312340
Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!",
313341
},
314-
}, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest)))
342+
})
315343
assert.NoError(t, err)
316344

317345
defer out.Close()
@@ -385,7 +413,7 @@ func TestReactWithModifier(t *testing.T) {
385413
Role: schema.User,
386414
Content: "hello",
387415
},
388-
}, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest)))
416+
})
389417
if err != nil {
390418
t.Fatal(err)
391419
}
@@ -452,8 +480,7 @@ func TestAgentInGraph(t *testing.T) {
452480
r, err := chain.Compile(ctx)
453481
assert.Nil(t, err)
454482

455-
res, err := r.Invoke(ctx, []*schema.Message{{Role: schema.User, Content: "hello"}},
456-
compose.WithCallbacks(callbackForTest))
483+
res, err := r.Invoke(ctx, []*schema.Message{{Role: schema.User, Content: "hello"}}, compose.WithCallbacks(callbackForTest))
457484
assert.Nil(t, err)
458485

459486
t.Log(res)
@@ -510,16 +537,35 @@ func TestAgentInGraph(t *testing.T) {
510537
assert.Nil(t, err)
511538

512539
chain.
513-
AppendGraph(agentGraph, opts...).
540+
AppendGraph(agentGraph, append(opts, compose.WithNodeKey("react_agent"))...).
514541
AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (string, error) {
515542
t.Log("got agent response: ", input.Content)
516543
return input.Content, nil
517544
}))
518545
r, err := chain.Compile(ctx)
519546
assert.Nil(t, err)
520547

521-
outStream, err := r.Stream(ctx, []*schema.Message{{Role: schema.User, Content: "hello"}},
522-
compose.WithCallbacks(callbackForTest))
548+
var modelCallbackCnt, toolCallbackCnt int
549+
modelCallback := template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{
550+
OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context {
551+
modelCallbackCnt++
552+
return ctx
553+
},
554+
}).Handler()
555+
toolCallback := template.NewHandlerHelper().Tool(&template.ToolCallbackHandler{
556+
OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *tool.CallbackInput) context.Context {
557+
toolCallbackCnt++
558+
return ctx
559+
},
560+
}).Handler()
561+
562+
agentOptions := []agent.AgentOption{
563+
WithChatModelOptions(model.WithTemperature(0.7)),
564+
WithModelCallbacks(modelCallback),
565+
WithToolCallbacks(toolCallback),
566+
}
567+
convertedOptions := ConvertOptions(compose.NewNodePath("react_agent"), agentOptions...)
568+
outStream, err := r.Stream(ctx, []*schema.Message{{Role: schema.User, Content: "hello"}}, convertedOptions...)
523569
if err != nil {
524570
t.Fatal(err)
525571
}
@@ -541,8 +587,10 @@ func TestAgentInGraph(t *testing.T) {
541587
}
542588

543589
t.Log(msg)
544-
})
545590

591+
assert.Equal(t, 3, modelCallbackCnt)
592+
assert.Equal(t, 2, toolCallbackCnt)
593+
})
546594
}
547595

548596
type fakeStreamToolGreetForTest struct {
@@ -572,6 +620,16 @@ type fakeToolGreetForTest struct {
572620
curCount int
573621
}
574622

623+
type toolOption struct {
624+
mockOutput *string
625+
}
626+
627+
func withMockOutput(output string) tool.Option {
628+
return tool.WrapImplSpecificOptFn(func(opt *toolOption) {
629+
opt.mockOutput = &output
630+
})
631+
}
632+
575633
func (t *fakeToolGreetForTest) Info(_ context.Context) (*schema.ToolInfo, error) {
576634
return &schema.ToolInfo{
577635
Name: "greet",
@@ -602,7 +660,12 @@ func (t *fakeStreamToolGreetForTest) Info(_ context.Context) (*schema.ToolInfo,
602660
}, nil
603661
}
604662

605-
func (t *fakeToolGreetForTest) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) {
663+
func (t *fakeToolGreetForTest) InvokableRun(_ context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
664+
toolOpts := tool.GetImplSpecificOptions(&toolOption{}, opts...)
665+
if toolOpts.mockOutput != nil {
666+
return *toolOpts.mockOutput, nil
667+
}
668+
606669
p := &fakeToolInput{}
607670
err := sonic.UnmarshalString(argumentsInJSON, p)
608671
if err != nil {

0 commit comments

Comments
 (0)
Please sign in to comment.