Eino: Chain/Graph 编排介绍

本文所有代码样例均位于：https://github.com/cloudwego/eino-examples/tree/main/compose

Graph 编排

Graph

package main

import (
    "context"
    "fmt"
    "io"

    "github.com/cloudwego/eino/components/model"
    "github.com/cloudwego/eino/components/prompt"
    "github.com/cloudwego/eino/compose"
    "github.com/cloudwego/eino/schema"
)

// Node keys under which the example registers and wires its graph nodes.
const (
	nodeOfPrompt = "prompt" // key of the ChatTemplate node
	nodeOfModel  = "model"  // key of the ChatModel node
)

// main wires a ChatTemplate node and a mock ChatModel node into a graph,
// compiles it, and runs it once with Invoke and once with Stream, printing
// the results.
func main() {
	ctx := context.Background()
	g := compose.NewGraph[map[string]any, *schema.Message]()

	// must aborts the example on any graph-building error instead of
	// silently discarding it.
	must := func(err error) {
		if err != nil {
			panic(err)
		}
	}

	pt := prompt.FromMessages(
		schema.FString,
		schema.UserMessage("what's the weather in {location}?"),
	)

	must(g.AddChatTemplateNode(nodeOfPrompt, pt))
	must(g.AddChatModelNode(nodeOfModel, &mockChatModel{}, compose.WithNodeName("ChatModel")))
	must(g.AddEdge(compose.START, nodeOfPrompt))
	must(g.AddEdge(nodeOfPrompt, nodeOfModel))
	must(g.AddEdge(nodeOfModel, compose.END))

	r, err := g.Compile(ctx, compose.WithMaxRunSteps(10))
	if err != nil {
		panic(err)
	}

	in := map[string]any{"location": "beijing"}

	// invoke: the error was previously ignored, which printed a nil result on failure
	ret, err := r.Invoke(ctx, in)
	if err != nil {
		panic(err)
	}
	fmt.Println("invoke result: ", ret)

	// stream
	s, err := r.Stream(ctx, in)
	if err != nil {
		panic(err)
	}
	defer s.Close()

	for {
		chunk, err := s.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			panic(err)
		}

		fmt.Println("stream chunk: ", chunk)
	}
}

// mockChatModel is a stand-in ChatModel that returns canned replies so the
// graph example can run without a real model backend.
type mockChatModel struct{}

// Generate ignores its input and options and returns a fixed assistant message.
func (m *mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	return schema.AssistantMessage("the weather is good", nil), nil
}

// Stream emits the same reply as Generate, split into two chunks. The first
// chunk keeps its trailing space so that concatenating the chunks yields
// exactly the Generate output ("the weather is good").
func (m *mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	sr, sw := schema.Pipe[*schema.Message](0)
	go func() {
		defer sw.Close()
		// Stop early if the consumer has already closed its reader.
		if closed := sw.Send(schema.AssistantMessage("the weather is ", nil), nil); closed {
			return
		}
		sw.Send(schema.AssistantMessage("good", nil), nil)
	}()
	return sr, nil
}

// BindTools is not supported by the mock. It reports an error instead of
// panicking so that a caller gets a recoverable failure.
func (m *mockChatModel) BindTools(tools []*schema.ToolInfo) error {
	return fmt.Errorf("mockChatModel: BindTools is not supported")
}

ToolCallAgent

go get github.com/cloudwego/eino-ext/components/model/openai@latest
go get github.com/cloudwego/eino@latest
package main

import (
    "context"
    "os"

    "github.com/cloudwego/eino-ext/components/model/openai"
    "github.com/cloudwego/eino/callbacks"
    "github.com/cloudwego/eino/components/prompt"
    "github.com/cloudwego/eino/components/tool"
    "github.com/cloudwego/eino/components/tool/utils"
    "github.com/cloudwego/eino/compose"
    "github.com/cloudwego/eino/schema"

    "github.com/cloudwego/eino-examples/internal/gptr"
    "github.com/cloudwego/eino-examples/internal/logs"
)

// main builds a ChatTemplate -> ChatModel -> ToolsNode graph in which the
// model is bound (via BindForcedTools) to the user_info tool. It runs the
// graph once and logs the messages produced by the tools node.
//
// Model access is configured through OPENAI_BASE_URL, OPENAI_API_KEY and
// MODEL_NAME.
func main() {
	openAIBaseURL := os.Getenv("OPENAI_BASE_URL")
	openAIAPIKey := os.Getenv("OPENAI_API_KEY")
	modelName := os.Getenv("MODEL_NAME")

	ctx := context.Background()

	callbacks.InitCallbackHandlers([]callbacks.Handler{&loggerCallbacks{}})

	// 1. create an instance of ChatTemplate as 1st Graph Node
	systemTpl := `你是一名房产经纪人,结合用户的薪酬和工作,使用 user_info API,为其提供相关的房产信息。邮箱是必须的`
	chatTpl := prompt.FromMessages(schema.FString,
		schema.SystemMessage(systemTpl),
		schema.MessagesPlaceholder("message_histories", true),
		schema.UserMessage("{user_query}"),
	)

	// NOTE: ByAzure and APIVersion target an Azure OpenAI deployment; adjust
	// them when pointing at a different OpenAI-compatible endpoint.
	modelConf := &openai.ChatModelConfig{
		BaseURL:     openAIBaseURL,
		APIKey:      openAIAPIKey,
		ByAzure:     true,
		Model:       modelName,
		Temperature: gptr.Of(float32(0.7)),
		APIVersion:  "2024-06-01",
	}

	// 2. create an instance of ChatModel as 2nd Graph Node
	chatModel, err := openai.NewChatModel(ctx, modelConf)
	if err != nil {
		logs.Errorf("NewChatModel failed, err=%v", err)
		return
	}

	// 3. create an instance of tool.InvokableTool for Intent recognition and execution
	userInfoTool := utils.NewTool(
		&schema.ToolInfo{
			Name: "user_info",
			Desc: "根据用户的姓名和邮箱,查询用户的公司、职位、薪酬信息",
			ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
				"name": {
					Type: "string",
					Desc: "用户的姓名",
				},
				"email": {
					Type: "string",
					Desc: "用户的邮箱",
				},
			}),
		},
		// The tool echoes name/email from the request and fills the rest with canned data.
		func(ctx context.Context, input *userInfoRequest) (output *userInfoResponse, err error) {
			return &userInfoResponse{
				Name:     input.Name,
				Email:    input.Email,
				Company:  "Bytedance",
				Position: "CEO",
				Salary:   "9999",
			}, nil
		})

	info, err := userInfoTool.Info(ctx)
	if err != nil {
		logs.Errorf("Get ToolInfo failed, err=%v", err)
		return
	}

	// 4. bind ToolInfo to ChatModel. ToolInfo will remain in effect until the next BindTools.
	err = chatModel.BindForcedTools([]*schema.ToolInfo{info})
	if err != nil {
		logs.Errorf("BindForcedTools failed, err=%v", err)
		return
	}

	// 5. create an instance of ToolsNode as 3rd Graph Node
	toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{
		Tools: []tool.BaseTool{userInfoTool},
	})
	if err != nil {
		logs.Errorf("NewToolNode failed, err=%v", err)
		return
	}

	const (
		nodeKeyOfTemplate  = "template"
		nodeKeyOfChatModel = "chat_model"
		nodeKeyOfTools     = "tools"
	)

	// 6. create an instance of Graph
	// input type is 1st Graph Node's input type, that is ChatTemplate's input type: map[string]any
	// output type is last Graph Node's output type, that is ToolsNode's output type: []*schema.Message
	g := compose.NewGraph[map[string]any, []*schema.Message]()

	// 7. add ChatTemplate into graph
	if err = g.AddChatTemplateNode(nodeKeyOfTemplate, chatTpl); err != nil {
		logs.Errorf("AddChatTemplateNode failed, err=%v", err)
		return
	}

	// 8. add ChatModel into graph
	if err = g.AddChatModelNode(nodeKeyOfChatModel, chatModel); err != nil {
		logs.Errorf("AddChatModelNode failed, err=%v", err)
		return
	}

	// 9. add ToolsNode into graph
	if err = g.AddToolsNode(nodeKeyOfTools, toolsNode); err != nil {
		logs.Errorf("AddToolsNode failed, err=%v", err)
		return
	}

	// 10. add connection between nodes: START -> template -> chat_model -> tools -> END
	edges := [][2]string{
		{compose.START, nodeKeyOfTemplate},
		{nodeKeyOfTemplate, nodeKeyOfChatModel},
		{nodeKeyOfChatModel, nodeKeyOfTools},
		{nodeKeyOfTools, compose.END},
	}
	for _, e := range edges {
		if err = g.AddEdge(e[0], e[1]); err != nil {
			logs.Errorf("AddEdge(%s -> %s) failed, err=%v", e[0], e[1], err)
			return
		}
	}

	// 11. compile Graph[I, O] to Runnable[I, O]
	r, err := g.Compile(ctx)
	if err != nil {
		logs.Errorf("Compile failed, err=%v", err)
		return
	}

	// 12. run the graph once and log every message produced by the tools node
	out, err := r.Invoke(ctx, map[string]any{
		"message_histories": []*schema.Message{},
		"user_query":        "我叫 zhangsan, 邮箱是 zhangsan@bytedance.com, 帮我推荐一处房产",
	})
	if err != nil {
		logs.Errorf("Invoke failed, err=%v", err)
		return
	}
	logs.Infof("Generation: %v Messages", len(out))
	for _, msg := range out {
		logs.Infof("    %v", msg)
	}
}

// userInfoRequest is the input of the user_info tool. Its json tags match the
// parameter names ("name", "email") declared in the tool's ToolInfo.
type userInfoRequest struct {
    Name  string `json:"name"`
    Email string `json:"email"`
}

// userInfoResponse is the output of the user_info tool. In this example the
// tool copies Name and Email from the request and fills Company, Position and
// Salary with hard-coded values.
type userInfoResponse struct {
    Name     string `json:"name"`
    Email    string `json:"email"`
    Company  string `json:"company"`
    Position string `json:"position"`
    Salary   string `json:"salary"` // kept as a string, e.g. "9999"
}

// loggerCallbacks is a callbacks.Handler that logs the name, type, component
// and non-streaming input/output/error of every node it observes.
type loggerCallbacks struct{}

// OnStart logs a component's (non-streaming) input.
func (l *loggerCallbacks) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
	logs.Infof("name: %v, type: %v, component: %v, input: %v", info.Name, info.Type, info.Component, input)
	return ctx
}

// OnEnd logs a component's (non-streaming) output.
func (l *loggerCallbacks) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
	logs.Infof("name: %v, type: %v, component: %v, output: %v", info.Name, info.Type, info.Component, output)
	return ctx
}

// OnError logs a component failure at error level.
func (l *loggerCallbacks) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
	logs.Errorf("name: %v, type: %v, component: %v, error: %v", info.Name, info.Type, info.Component, err)
	return ctx
}

// OnStartWithStreamInput does not log streamed input. It still closes the
// reader it receives: per the Eino callbacks documentation, a handler must
// close the stream copies handed to it, otherwise the copy is never released.
func (l *loggerCallbacks) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
	if input != nil {
		input.Close()
	}
	return ctx
}

// OnEndWithStreamOutput does not log streamed output; it closes the received
// reader for the same reason as OnStartWithStreamInput.
func (l *loggerCallbacks) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
	if output != nil {
		output.Close()
	}
	return ctx
}

State Graph

package main

import (
    "context"
    "errors"
    "io"
    "runtime/debug"
    "strings"
    "unicode/utf8"

    "github.com/cloudwego/eino/compose"
    "github.com/cloudwego/eino/schema"
    "github.com/cloudwego/eino/utils/safe"

    "github.com/cloudwego/eino-examples/internal/logs"
)

// main demonstrates a graph with local state. Three lambda nodes
// (invokable -> streamable -> transformable) are chained, and their state
// pre/post handlers append intermediate values to testState.ms. The compiled
// graph is then run in Invoke, Stream and Transform modes.
func main() {
	ctx := context.Background()

	const (
		nodeOfL1 = "invokable"
		nodeOfL2 = "streamable"
		nodeOfL3 = "transformable"
	)

	type testState struct {
		ms []string
	}

	// gen is passed to WithGenLocalState to create the state that the
	// handlers below read and append to.
	gen := func(ctx context.Context) *testState {
		return &testState{}
	}

	sg := compose.NewGraph[string, string](compose.WithGenLocalState(gen))

	l1 := compose.InvokableLambda(func(ctx context.Context, in string) (out string, err error) {
		return "InvokableLambda: " + in, nil
	})

	// record l1's input (pre handler) and output (post handler) in the state
	l1StateToInput := func(ctx context.Context, in string, state *testState) (string, error) {
		state.ms = append(state.ms, in)
		return in, nil
	}

	l1StateToOutput := func(ctx context.Context, out string, state *testState) (string, error) {
		state.ms = append(state.ms, out)
		return out, nil
	}

	_ = sg.AddLambdaNode(nodeOfL1, l1,
		compose.WithStatePreHandler(l1StateToInput), compose.WithStatePostHandler(l1StateToOutput))

	// l2 turns one input string into a stream of whitespace-separated words.
	l2 := compose.StreamableLambda(func(ctx context.Context, input string) (output *schema.StreamReader[string], err error) {
		outStr := "StreamableLambda: " + input

		sr, sw := schema.Pipe[string](utf8.RuneCountInString(outStr))

		// nolint: byted_goroutine_recover
		go func() {
			defer sw.Close()
			for _, field := range strings.Fields(outStr) {
				sw.Send(field+" ", nil)
			}
		}()

		return sr, nil
	})

	l2StateToOutput := func(ctx context.Context, out string, state *testState) (string, error) {
		state.ms = append(state.ms, out)
		return out, nil
	}

	_ = sg.AddLambdaNode(nodeOfL2, l2, compose.WithStatePostHandler(l2StateToOutput))

	// l3 prefixes the incoming stream with the words of "TransformableLambda: "
	// and then forwards every input chunk unchanged.
	l3 := compose.TransformableLambda(func(ctx context.Context, input *schema.StreamReader[string]) (
		output *schema.StreamReader[string], err error) {

		prefix := "TransformableLambda: "
		sr, sw := schema.Pipe[string](20)

		go func() {
			// Defers run LIFO: the recover handler runs first and can still
			// report a panic through sw; then sw is closed so the consumer
			// is never left blocked; finally the consumed input is closed.
			defer input.Close()
			defer sw.Close()
			defer func() {
				if rec := recover(); rec != nil {
					perr := safe.NewPanicErr(rec, debug.Stack())
					logs.Errorf("panic occurs: %v\n", perr)
					// surface the panic to the reader instead of only logging it
					sw.Send("", perr)
				}
			}()

			for _, field := range strings.Fields(prefix) {
				// stop producing once the downstream reader has gone away
				if closed := sw.Send(field+" ", nil); closed {
					return
				}
			}

			for {
				chunk, recvErr := input.Recv()
				if errors.Is(recvErr, io.EOF) {
					return
				}
				if recvErr != nil {
					// TODO: how to trace this kind of error in the goroutine of processing sw
					sw.Send(chunk, recvErr)
					return
				}

				if closed := sw.Send(chunk, nil); closed {
					return
				}
			}
		}()

		return sr, nil
	})

	// l3's post handler records its output and dumps the accumulated state.
	l3StateToOutput := func(ctx context.Context, out string, state *testState) (string, error) {
		state.ms = append(state.ms, out)
		logs.Infof("state result: ")
		for idx, m := range state.ms {
			logs.Infof("    %vth: %v", idx, m)
		}
		return out, nil
	}

	_ = sg.AddLambdaNode(nodeOfL3, l3, compose.WithStatePostHandler(l3StateToOutput))

	_ = sg.AddEdge(compose.START, nodeOfL1)

	_ = sg.AddEdge(nodeOfL1, nodeOfL2)

	_ = sg.AddEdge(nodeOfL2, nodeOfL3)

	_ = sg.AddEdge(nodeOfL3, compose.END)

	run, err := sg.Compile(ctx)
	if err != nil {
		logs.Errorf("sg.Compile failed, err=%v", err)
		return
	}

	// Invoke mode
	out, err := run.Invoke(ctx, "how are you")
	if err != nil {
		logs.Errorf("run.Invoke failed, err=%v", err)
		return
	}
	logs.Infof("invoke result: %v", out)

	// Stream mode
	stream, err := run.Stream(ctx, "how are you")
	if err != nil {
		logs.Errorf("run.Stream failed, err=%v", err)
		return
	}

	for {
		chunk, err := stream.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			logs.Errorf("stream.Recv() failed, err=%v", err)
			break
		}

		logs.Tokenf("%v", chunk)
	}
	stream.Close()

	// Transform mode: feed the input as a single-chunk stream
	sr, sw := schema.Pipe[string](1)
	sw.Send("how are you", nil)
	sw.Close()

	stream, err = run.Transform(ctx, sr)
	if err != nil {
		logs.Errorf("run.Transform failed, err=%v", err)
		return
	}

	for {
		chunk, err := stream.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			logs.Errorf("stream.Recv() failed, err=%v", err)
			break
		}

		logs.Infof("%v", chunk)
	}
	stream.Close()
}

Chain

Chain 可以视为 Graph 的简化封装：各节点按 Append 的顺序依次串联，无需手动添加边。

package main

import (
    "context"
    "fmt"
    "log"
    "math/rand"
    "os"

    "github.com/cloudwego/eino-ext/components/model/openai"
    "github.com/cloudwego/eino/components/prompt"
    "github.com/cloudwego/eino/compose"
    "github.com/cloudwego/eino/schema"

    "github.com/cloudwego/eino-examples/internal/gptr"
    "github.com/cloudwego/eino-examples/internal/logs"
)

// main builds and runs a chain that:
//  1. randomly routes its input through one of two branch lambdas, which set
//     "role" to "cat" or "dog";
//  2. fans out through a parallel node producing "role" and "input";
//  3. feeds those into a role-playing ChatTemplate + ChatModel sub-chain;
//  4. returns and logs the model reply's content.
//
// Model access is configured through OPENAI_BASE_URL, OPENAI_API_KEY and
// MODEL_NAME.
func main() {
	openAIBaseURL := os.Getenv("OPENAI_BASE_URL")
	openAIAPIKey := os.Getenv("OPENAI_API_KEY")
	modelName := os.Getenv("MODEL_NAME")

	ctx := context.Background()

	// build branch func: picks "b1" or "b2" with equal probability
	const randLimit = 2
	branchCond := func(ctx context.Context, input map[string]any) (string, error) { // nolint: byted_all_nil_return
		if rand.Intn(randLimit) == 1 {
			return "b1", nil
		}

		return "b2", nil
	}

	b1 := compose.InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
		logs.Infof("hello in branch lambda 01")
		if kvs == nil {
			return nil, fmt.Errorf("nil map")
		}

		kvs["role"] = "cat"
		return kvs, nil
	})

	b2 := compose.InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
		logs.Infof("hello in branch lambda 02")
		if kvs == nil {
			return nil, fmt.Errorf("nil map")
		}

		kvs["role"] = "dog"
		return kvs, nil
	})

	// build parallel node: its lambdas are registered under the keys "role"
	// and "input", matching the {role}/{input} placeholders of the template below
	parallel := compose.NewParallel()
	parallel.
		AddLambda("role", compose.InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
			// the role could be derived from other input keys as well (e.g. dentist/doctor...)
			role, ok := kvs["role"].(string)
			if !ok || role == "" {
				role = "bird"
			}

			return role, nil
		})).
		AddLambda("input", compose.InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) {
			return "你的叫声是怎样的?", nil
		}))

	// NOTE: ByAzure and APIVersion target an Azure OpenAI deployment; adjust
	// them when pointing at a different OpenAI-compatible endpoint.
	modelConf := &openai.ChatModelConfig{
		BaseURL:     openAIBaseURL,
		APIKey:      openAIAPIKey,
		ByAzure:     true,
		Model:       modelName,
		Temperature: gptr.Of(float32(0.7)),
		APIVersion:  "2024-06-01",
	}

	// create chat model node
	cm, err := openai.NewChatModel(ctx, modelConf)
	if err != nil {
		log.Panic(err) // log.Panic does not return, so no trailing return is needed
	}

	rolePlayerChain := compose.NewChain[map[string]any, *schema.Message]()
	rolePlayerChain.
		AppendChatTemplate(prompt.FromMessages(schema.FString, schema.SystemMessage(`You are a {role}.`), schema.UserMessage(`{input}`))).
		AppendChatModel(cm)

	// =========== build chain ===========
	chain := compose.NewChain[map[string]any, string]()
	chain.
		AppendLambda(compose.InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
			// do some logic to prepare kv as input val for next node
			// just pass through
			logs.Infof("in view lambda: %v", kvs)
			return kvs, nil
		})).
		AppendBranch(compose.NewChainBranch(branchCond).AddLambda("b1", b1).AddLambda("b2", b2)). // nolint: byted_use_receiver_without_nilcheck
		AppendPassthrough().
		AppendParallel(parallel).
		AppendGraph(rolePlayerChain).
		AppendLambda(compose.InvokableLambda(func(ctx context.Context, m *schema.Message) (string, error) {
			// do some logic to check the output or something
			logs.Infof("in view of messages: %v", m.Content)
			return m.Content, nil
		}))

	// compile
	r, err := chain.Compile(ctx)
	if err != nil {
		log.Panic(err)
	}

	output, err := r.Invoke(ctx, map[string]any{})
	if err != nil {
		log.Panic(err)
	}

	logs.Infof("output is : %v", output)
}

最后修改 January 20, 2025 : doc: eino doc update (#1214) (b324caa)