Skip to content

Commit df93610

Browse files
feat: workflow can create FieldMapping from or to fields of maps (#95)
Change-Id: I20e74d123d5d3d0687bbfab13da3ce404c58e93d
1 parent 2d75c5c commit df93610

File tree

3 files changed

+128
-33
lines changed

3 files changed

+128
-33
lines changed

_typos.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Invokable = "Invokable"
66
invokable = "invokable"
77
InvokableLambda = "InvokableLambda"
88
InvokableRun = "InvokableRun"
9+
typ = "typ"
910

1011
[files]
1112
extend-exclude = ["go.mod", "go.sum", "check_branch_name.sh"]

compose/field_mapping.go

Lines changed: 98 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ func (m *FieldMapping) String() string {
4747
sb.WriteString("(field) of ")
4848
}
4949

50+
sb.WriteString(m.fromNodeKey)
51+
5052
if m.to != "" {
5153
sb.WriteString(" to ")
5254
sb.WriteString(m.to)
@@ -60,20 +62,23 @@ func (m *FieldMapping) String() string {
6062
// FromField creates a FieldMapping that maps a single predecessor field to the entire successor input.
6163
// This is an exclusive mapping - once set, no other field mappings can be added since the successor input
6264
// has already been fully mapped.
65+
// Field: either the field of a struct, or the key of a map.
6366
func FromField(from string) *FieldMapping {
6467
return &FieldMapping{
6568
from: from,
6669
}
6770
}
6871

69-
// ToField creates a FieldMapping that maps the entire predecessor output to a single successor field
72+
// ToField creates a FieldMapping that maps the entire predecessor output to a single successor field.
73+
// Field: either the field of a struct, or the key of a map.
7074
func ToField(to string) *FieldMapping {
7175
return &FieldMapping{
7276
to: to,
7377
}
7478
}
7579

76-
// MapFields creates a FieldMapping that maps a single predecessor field to a single successor field
80+
// MapFields creates a FieldMapping that maps a single predecessor field to a single successor field.
81+
// Field: either the field of a struct, or the key of a map.
7782
func MapFields(from, to string) *FieldMapping {
7883
return &FieldMapping{
7984
from: from,
@@ -160,6 +165,16 @@ func assignOne[T any](dest T, taken any, to string) (T, error) {
160165

161166
toSet := reflect.ValueOf(taken)
162167

168+
if destValue.Kind() == reflect.Map {
169+
key, err := checkAndExtractToMapKey(to, destValue, toSet)
170+
if err != nil {
171+
return dest, err
172+
}
173+
174+
destValue.SetMapIndex(key, toSet)
175+
return destValue.Interface().(T), nil
176+
}
177+
163178
field, err := checkAndExtractToField(to, destValue, toSet)
164179
if err != nil {
165180
return dest, err
@@ -171,30 +186,44 @@ func assignOne[T any](dest T, taken any, to string) (T, error) {
171186
}
172187

173188
func checkAndExtractFromField(fromField string, input reflect.Value) (reflect.Value, error) {
174-
if input.Kind() == reflect.Ptr {
175-
input = input.Elem()
176-
}
177-
178-
if input.Kind() != reflect.Struct {
179-
return reflect.Value{}, fmt.Errorf("mapping has from but input is not struct or struct ptr, type= %v", input.Type())
180-
}
181-
182189
f := input.FieldByName(fromField)
183190
if !f.IsValid() {
184-
return reflect.Value{}, fmt.Errorf("mapping has from not found. field=%v, inputType=%v", fromField, input.Type())
191+
return reflect.Value{}, fmt.Errorf("field mapping from a struct field, but field not found. field=%v, inputType=%v", fromField, input.Type())
185192
}
186193

187194
if !f.CanInterface() {
188-
return reflect.Value{}, fmt.Errorf("mapping has from not exported. field= %v, inputType=%v", fromField, input.Type())
195+
return reflect.Value{}, fmt.Errorf("field mapping from a struct field, but field not exported. field= %v, inputType=%v", fromField, input.Type())
189196
}
190197

191198
return f, nil
192199
}
193200

201+
func checkAndExtractFromMapKey(fromMapKey string, input reflect.Value) (reflect.Value, error) {
202+
if !reflect.TypeOf(fromMapKey).AssignableTo(input.Type().Key()) {
203+
return reflect.Value{}, fmt.Errorf("field mapping from a map key, but input is not a map with string key, type=%v", input.Type())
204+
}
205+
206+
v := input.MapIndex(reflect.ValueOf(fromMapKey))
207+
if !v.IsValid() {
208+
return reflect.Value{}, fmt.Errorf("field mapping from a map key, but key not found in input. key=%s, inputType= %v", fromMapKey, input.Type())
209+
}
210+
211+
return v, nil
212+
}
213+
194214
func checkAndExtractFieldType(field string, typ reflect.Type) (reflect.Type, error) {
195215
if len(field) == 0 {
196216
return typ, nil
197217
}
218+
219+
if typ.Kind() == reflect.Map {
220+
if typ.Key() != strType {
221+
return nil, fmt.Errorf("type[%v] is not a map with string key", typ)
222+
}
223+
224+
return typ.Elem(), nil
225+
}
226+
198227
for typ.Kind() == reflect.Ptr {
199228
typ = typ.Elem()
200229
}
@@ -215,31 +244,49 @@ func checkAndExtractFieldType(field string, typ reflect.Type) (reflect.Type, err
215244
return f.Type, nil
216245
}
217246

247+
var strType = reflect.TypeOf("")
248+
218249
func checkAndExtractToField(toField string, output, toSet reflect.Value) (reflect.Value, error) {
219250
for output.Kind() == reflect.Ptr {
220251
output = output.Elem()
221252
}
222253

223254
if output.Kind() != reflect.Struct {
224-
return reflect.Value{}, fmt.Errorf("mapping has to but output is not a struct, type=%v", output.Type())
255+
return reflect.Value{}, fmt.Errorf("field mapping to a struct field but output is not a struct, type=%v", output.Type())
225256
}
226257

227258
field := output.FieldByName(toField)
228259
if !field.IsValid() {
229-
return reflect.Value{}, fmt.Errorf("mapping has to not found. field=%v, outputType=%v", toField, output.Type())
260+
return reflect.Value{}, fmt.Errorf("field mapping to a struct field, but field not found. field=%v, outputType=%v", toField, output.Type())
230261
}
231262

232263
if !field.CanSet() {
233-
return reflect.Value{}, fmt.Errorf("mapping has to not exported. field=%v, outputType=%v", toField, output.Type())
264+
return reflect.Value{}, fmt.Errorf("field mapping to a struct field, but field not exported. field=%v, outputType=%v", toField, output.Type())
234265
}
235266

236267
if !toSet.Type().AssignableTo(field.Type()) {
237-
return reflect.Value{}, fmt.Errorf("mapping to has a mismatched type. field=%s, from=%v, to=%v", toField, toSet.Type(), field.Type())
268+
return reflect.Value{}, fmt.Errorf("field mapping to a struct field, but field has a mismatched type. field=%s, from=%v, to=%v", toField, toSet.Type(), field.Type())
238269
}
239270

240271
return field, nil
241272
}
242273

274+
func checkAndExtractToMapKey(toMapKey string, output, toSet reflect.Value) (reflect.Value, error) {
275+
if output.Kind() != reflect.Map {
276+
return reflect.Value{}, fmt.Errorf("field mapping to a map key but output is not a map, type=%v", output.Type())
277+
}
278+
279+
if !reflect.TypeOf(toMapKey).AssignableTo(output.Type().Key()) {
280+
return reflect.Value{}, fmt.Errorf("field mapping to a map key but output is not a map with string key, type=%v", output.Type())
281+
}
282+
283+
if !toSet.Type().AssignableTo(output.Type().Elem()) {
284+
return reflect.Value{}, fmt.Errorf("field mapping to a map key but map value has a mismatched type. key=%s, from=%v, to=%v", toMapKey, toSet.Type(), output.Type().Elem())
285+
}
286+
287+
return reflect.ValueOf(toMapKey), nil
288+
}
289+
243290
func fieldMap(mappings []*FieldMapping) func(any) (map[string]any, error) {
244291
return func(input any) (map[string]any, error) {
245292
result := make(map[string]any, len(mappings))
@@ -262,14 +309,26 @@ func streamFieldMap(mappings []*FieldMapping) func(streamReader) streamReader {
262309
}
263310
}
264311

265-
func takeOne(input any, from string) (any, error) {
312+
func takeOne(input any, from string) (taken any, err error) {
266313
if len(from) == 0 {
267314
return input, nil
268315
}
269316

270317
inputValue := reflect.ValueOf(input)
271318

272-
f, err := checkAndExtractFromField(from, inputValue)
319+
var f reflect.Value
320+
switch inputValue.Kind() {
321+
case reflect.Map:
322+
f, err = checkAndExtractFromMapKey(from, inputValue)
323+
case reflect.Ptr:
324+
inputValue = inputValue.Elem()
325+
fallthrough
326+
case reflect.Struct:
327+
f, err = checkAndExtractFromField(from, inputValue)
328+
default:
329+
return reflect.Value{}, fmt.Errorf("field mapping from a field, but input is not struct, struct ptr or map, type= %v", inputValue.Type())
330+
}
331+
273332
if err != nil {
274333
return nil, err
275334
}
@@ -295,11 +354,18 @@ func isToAll(mappings []*FieldMapping) bool {
295354
return false
296355
}
297356

298-
func validateStruct(t reflect.Type) bool {
299-
for t.Kind() == reflect.Ptr {
357+
func validateStructOrMap(t reflect.Type) bool {
358+
switch t.Kind() {
359+
case reflect.Map:
360+
return true
361+
case reflect.Ptr:
300362
t = t.Elem()
363+
fallthrough
364+
case reflect.Struct:
365+
return true
366+
default:
367+
return false
301368
}
302-
return t.Kind() != reflect.Struct
303369
}
304370

305371
func validateFieldMapping(predecessorType reflect.Type, successorType reflect.Type, mappings []*FieldMapping) (*handlerPair, error) {
@@ -308,20 +374,25 @@ func validateFieldMapping(predecessorType reflect.Type, successorType reflect.Ty
308374
// check if mapping is legal
309375
if isFromAll(mappings) && isToAll(mappings) {
310376
return nil, fmt.Errorf("invalid field mappings: from all fields to all, use common edge instead")
311-
} else if !isToAll(mappings) && validateStruct(successorType) {
377+
} else if !isToAll(mappings) && !validateStructOrMap(successorType) {
312378
// if user has not provided a specific struct type, graph cannot construct any struct in the runtime
313-
return nil, fmt.Errorf("static check fail: upstream input type should be struct, actual: %v", successorType)
314-
} else if !isFromAll(mappings) && validateStruct(predecessorType) {
379+
return nil, fmt.Errorf("static check fail: successor input type should be struct or map, actual: %v", successorType)
380+
} else if !isFromAll(mappings) && !validateStructOrMap(predecessorType) {
315381
// TODO: should forbid?
316-
return nil, fmt.Errorf("static check fail: downstream output type should be struct, actual: %v", predecessorType)
382+
return nil, fmt.Errorf("static check fail: predecessor output type should be struct or map, actual: %v", predecessorType)
317383
}
318384

385+
var (
386+
predecessorFieldType, successorFieldType reflect.Type
387+
err error
388+
)
389+
319390
for _, mapping := range mappings {
320-
predecessorFieldType, err := checkAndExtractFieldType(mapping.from, predecessorType)
391+
predecessorFieldType, err = checkAndExtractFieldType(mapping.from, predecessorType)
321392
if err != nil {
322393
return nil, fmt.Errorf("static check failed for mapping %s: %w", mapping, err)
323394
}
324-
successorFieldType, err := checkAndExtractFieldType(mapping.to, successorType)
395+
successorFieldType, err = checkAndExtractFieldType(mapping.to, successorType)
325396
if err != nil {
326397
return nil, fmt.Errorf("static check failed for mapping %s: %w", mapping, err)
327398
}

compose/workflow_test.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,29 @@ func TestWorkflow(t *testing.T) {
241241
}
242242
}
243243

244+
func TestWorkflowWithMap(t *testing.T) {
245+
ctx := context.Background()
246+
247+
type structA struct {
248+
F1 any
249+
}
250+
251+
wf := NewWorkflow[map[string]any, map[string]any]()
252+
wf.AddLambdaNode("lambda1", InvokableLambda(func(ctx context.Context, in map[string]any) (map[string]any, error) {
253+
return in, nil
254+
})).AddInput(START, MapFields("map_key", "lambda1_key"))
255+
wf.AddLambdaNode("lambda2", InvokableLambda(func(ctx context.Context, in *structA) (*structA, error) {
256+
return in, nil
257+
})).AddInput(START, MapFields("map_key", "F1"))
258+
wf.AddEnd("lambda1", MapFields("lambda1_key", "end_lambda1"))
259+
wf.AddEnd("lambda2", MapFields("F1", "end_lambda2"))
260+
r, err := wf.Compile(ctx)
261+
assert.NoError(t, err)
262+
out, err := r.Invoke(ctx, map[string]any{"map_key": "value"})
263+
assert.NoError(t, err)
264+
assert.Equal(t, map[string]any{"end_lambda1": "value", "end_lambda2": "value"}, out)
265+
}
266+
244267
func TestWorkflowCompile(t *testing.T) {
245268
ctx := context.Background()
246269
ctrl := gomock.NewController(t)
@@ -260,32 +283,32 @@ func TestWorkflowCompile(t *testing.T) {
260283
assert.ErrorContains(t, err, "mismatch")
261284
})
262285

263-
t.Run("upstream not struct/struct ptr, mapping has FromField", func(t *testing.T) {
286+
t.Run("predecessor's output not struct/struct ptr/map, mapping has FromField", func(t *testing.T) {
264287
w := NewWorkflow[[]*schema.Document, []string]()
265288

266289
w.AddIndexerNode("indexer", indexer.NewMockIndexer(ctrl)).AddInput(START, FromField("F1"))
267290
w.AddEnd("indexer")
268291
_, err := w.Compile(ctx)
269-
assert.ErrorContains(t, err, "downstream output type should be struct")
292+
assert.ErrorContains(t, err, "predecessor output type should be struct")
270293
})
271294

272-
t.Run("downstream not struct/struct ptr, mapping has ToField", func(t *testing.T) {
295+
t.Run("successor's input not struct/struct ptr/map, mapping has ToField", func(t *testing.T) {
273296
w := NewWorkflow[[]string, [][]float64]()
274297
w.AddEmbeddingNode("embedder", embedding.NewMockEmbedder(ctrl)).AddInput(START, ToField("F1"))
275298
w.AddEnd("embedder")
276299
_, err := w.Compile(ctx)
277-
assert.ErrorContains(t, err, "upstream input type should be struct")
300+
assert.ErrorContains(t, err, "successor input type should be struct")
278301
})
279302

280-
t.Run("map to non existing field in upstream", func(t *testing.T) {
303+
t.Run("map to non existing field in predecessor", func(t *testing.T) {
281304
w := NewWorkflow[*schema.Message, []*schema.Message]()
282305
w.AddToolsNode("tools_node", &ToolsNode{}).AddInput(START, FromField("non_exist"))
283306
w.AddEnd("tools_node")
284307
_, err := w.Compile(ctx)
285308
assert.ErrorContains(t, err, "type[schema.Message] has no field[non_exist]")
286309
})
287310

288-
t.Run("map to not exported field in downstream", func(t *testing.T) {
311+
t.Run("map to not exported field in successor", func(t *testing.T) {
289312
w := NewWorkflow[string, *FieldMapping]()
290313
w.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
291314
return input, nil

0 commit comments

Comments
 (0)