Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid overwriting go mod/sum files on module generation #7194

Merged
merged 5 commits into from May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/codegen/codegen.go
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"

"dagger.io/dagger"
Expand Down Expand Up @@ -49,6 +50,9 @@ func Generate(ctx context.Context, cfg generator.Config, dag *dagger.Client) (er

for _, cmd := range generated.PostCommands {
cmd.Dir = cfg.OutputDir
if cfg.ModuleName != "" {
cmd.Dir = filepath.Join(cfg.OutputDir, cfg.ModuleContextPath)
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
fmt.Fprintln(logsW, "running post-command:", strings.Join(cmd.Args, " "))
Expand Down
19 changes: 10 additions & 9 deletions cmd/codegen/generator/generator.go
Expand Up @@ -24,17 +24,18 @@ const (
)

type Config struct {
// Language supported by this codegen infra.
// Lang is the language supported by this codegen infra.
Lang SDKLang

// Destination directory for generated code.
// OutputDir is the path to place generated code.
OutputDir string

// Name of the module to generate code for
ModuleName string
// ModuleName is the module name to generate code for.
ModuleName string
// ModuleContextPath is the subpath where a module can be found.
ModuleContextPath string

// Optional pre-computed introspection json string
// IntrospectionJSON is an optional pre-computed introspection json string
IntrospectionJSON string
}

Expand All @@ -53,10 +54,10 @@ type GeneratedState struct {
// Go code.
PostCommands []*exec.Cmd

// NeedSync indicates that the code needs to be generated again. This can
// happen if the codegen spat out templates that depend on generated types.
// In that case the codegen needs to be run again with both the templates and
// the initially generated types available.
// NeedRegenerate indicates that the code needs to be generated again. This
// can happen if the codegen spat out templates that depend on generated
// types. In that case the codegen needs to be run again with both the
// templates and the initially generated types available.
NeedRegenerate bool
}

Expand Down
194 changes: 110 additions & 84 deletions cmd/codegen/generator/go/generator.go
Expand Up @@ -3,6 +3,7 @@ package gogenerator
import (
"bytes"
"context"
"errors"
"fmt"
"go/format"
"go/token"
Expand Down Expand Up @@ -52,14 +53,19 @@ func (g *GoGenerator) Generate(ctx context.Context, schema *introspection.Schema
// 2b. add stub main.go
// 3. load package, generate dagger.gen.go (possibly again)

outDir := "."
if g.Config.ModuleName != "" {
outDir = g.Config.ModuleContextPath
}

mfs := memfs.New()

var overlay fs.FS = mfs
if g.Config.ModuleName != "" {
overlay = layerfs.New(
mfs,
&MountedFS{FS: dagger.QueryBuilder, Name: "internal"},
&MountedFS{FS: dagger.Telemetry, Name: "internal"},
&MountedFS{FS: dagger.QueryBuilder, Name: filepath.Join(outDir, "internal")},
&MountedFS{FS: dagger.Telemetry, Name: filepath.Join(outDir, "internal")},
)
}

Expand All @@ -70,7 +76,7 @@ func (g *GoGenerator) Generate(ctx context.Context, schema *introspection.Schema
exec.Command("go", "mod", "tidy"),
},
}
if _, err := os.Stat(filepath.Join(g.Config.ModuleContextPath, "go.work")); err == nil {
if _, err := os.Stat(filepath.Join(g.Config.OutputDir, "go.work")); err == nil {
// run "go work use ." after generating if we had a go.work at the root
genSt.PostCommands = append(genSt.PostCommands, exec.Command("go", "work", "use", "."))
}
Expand All @@ -79,15 +85,21 @@ func (g *GoGenerator) Generate(ctx context.Context, schema *introspection.Schema
if err != nil {
return nil, fmt.Errorf("bootstrap package: %w", err)
}
if outDir != "." {
mfs.MkdirAll(outDir, 0700)
fs, err := mfs.Sub(outDir)
if err != nil {
return nil, err
}
mfs = fs.(*memfs.FS)
}

outDir := g.Config.OutputDir

initialGoFiles, err := filepath.Glob(filepath.Join(outDir, "*.go"))
initialGoFiles, err := filepath.Glob(filepath.Join(g.Config.OutputDir, outDir, "*.go"))
if err != nil {
return nil, fmt.Errorf("glob go files: %w", err)
}

genFile := filepath.Join(outDir, ClientGenFile)
genFile := filepath.Join(g.Config.OutputDir, outDir, ClientGenFile)
if _, err := os.Stat(genFile); err != nil {
// assume package main, default for modules
pkgInfo.PackageName = "main"
Expand Down Expand Up @@ -115,7 +127,7 @@ func (g *GoGenerator) Generate(ctx context.Context, schema *introspection.Schema
return genSt, nil
}

pkg, fset, err := loadPackage(ctx, outDir)
pkg, fset, err := loadPackage(ctx, filepath.Join(g.Config.OutputDir, outDir))
if err != nil {
return nil, fmt.Errorf("load package %q: %w", outDir, err)
}
Expand All @@ -136,107 +148,121 @@ type PackageInfo struct {
}

func (g *GoGenerator) bootstrapMod(ctx context.Context, mfs *memfs.FS) (*PackageInfo, bool, error) {
var needsRegen bool

outDir := g.Config.OutputDir

info := &PackageInfo{}

// use embedded go.mod as basis for pinning versions
sdkMod, err := modfile.Parse("go.mod", dagger.GoMod, nil)
if err != nil {
return nil, false, fmt.Errorf("parse embedded go.mod: %w", err)
// don't mess around go.mod if we're not building modules
if g.Config.ModuleName == "" {
if pkg, _, err := loadPackage(ctx, g.Config.OutputDir); err == nil {
return &PackageInfo{
PackageName: pkg.Name,
PackageImport: pkg.Module.Path,
}, false, nil
}
return nil, false, fmt.Errorf("no module name configured and no existing package found")
}

newMod := new(modfile.File)
var needsRegen bool

var modPath string
var mod *modfile.File

if content, err := os.ReadFile(filepath.Join(outDir, "go.mod")); err == nil {
// respect existing go.mod
// check for a go.mod already for the dagger module
if content, err := os.ReadFile(filepath.Join(g.Config.OutputDir, g.Config.ModuleContextPath, "go.mod")); err == nil {
modPath = g.Config.ModuleContextPath

currentMod, err := modfile.Parse("go.mod", content, nil)
mod, err = modfile.ParseLax("go.mod", content, nil)
if err != nil {
return nil, false, fmt.Errorf("parse go.mod: %w", err)
}
currentModGoVersion, err := semver.Parse(currentMod.Go.Version)
if err != nil {
var err2 error
currentModGoVersion, err2 = semver.Parse(currentMod.Go.Version + ".0")
if err2 != nil {
return nil, false, fmt.Errorf("parse go.mod version %q: %w", currentMod.Go.Version, err)
}
}
if currentModGoVersion.GT(goVersion) {
return nil, false, fmt.Errorf("existing go.mod has unsupported version %v (highest supported version is %v)", currentMod.Go.Version, goVersion)
}
newMod = currentMod

for _, req := range sdkMod.Require {
newMod.AddRequire(req.Mod.Path, req.Mod.Version)
}

info.PackageImport = currentMod.Module.Mod.Path
} else {
if g.Config.ModuleName != "" {
outDir, err := filepath.Abs(outDir)
if err != nil {
return nil, false, fmt.Errorf("get absolute path: %w", err)
}
rootDir := g.Config.ModuleContextPath
subdirRelPath, err := filepath.Rel(rootDir, outDir)
}
// if no go.mod is available, check the root output directory instead
//
// this is a necessary part of bootstrapping: SDKs such as the Go SDK
// will want to have a runtime module that lives in the same Go module as
// the generated client, which typically lives in the parent directory.
if mod == nil {
if content, err := os.ReadFile(filepath.Join(g.Config.OutputDir, "go.mod")); err == nil {
modPath = "."
mod, err = modfile.ParseLax("go.mod", content, nil)
if err != nil {
return nil, false, fmt.Errorf("failed to get output dir rel path: %w", err)
}

// when a module is configured, look for a go.mod in its root dir instead
//
// this is a necessary part of bootstrapping: SDKs such as the Go SDK
// will want to have a runtime module that lives in the same Go module as
// the generated client, which typically lives in the parent directory.
if pkg, _, err := loadPackage(ctx, rootDir); err == nil {
return &PackageInfo{
// leave package name blank
// TODO: maybe we don't even need to return it?
PackageImport: path.Join(pkg.Module.Path, subdirRelPath),
}, false, nil
return nil, false, fmt.Errorf("parse go.mod: %w", err)
}
}
}
// could not find a go.mod, so we can init a basic one
if mod == nil {
modPath = g.Config.ModuleContextPath
mod = new(modfile.File)

// bootstrap go.mod using dependencies from the embedded Go SDK

newModName := fmt.Sprintf("dagger/%s", strcase.ToKebab(g.Config.ModuleName))

newMod.AddModuleStmt(newModName)
newMod.AddGoStmt(goVersion.String())
newMod.SetRequire(sdkMod.Require)
modname := fmt.Sprintf("dagger/%s", strcase.ToKebab(g.Config.ModuleName))

info.PackageImport = newModName
mod.AddModuleStmt(modname)
mod.AddGoStmt(goVersion.String())

needsRegen = true
} else {
// no module; assume client-only codegen
needsRegen = true
}

if pkg, _, err := loadPackage(ctx, outDir); err == nil {
return &PackageInfo{
PackageName: pkg.Name,
PackageImport: pkg.Module.Path,
}, false, nil
}
// sanity check the parsed go version
// if this fails, then the go.mod version is too high! and in that case, we
// won't be able to load the resulting package
modGoVersion, err := semver.Parse(mod.Go.Version)
if err != nil {
var err2 error
modGoVersion, err2 = semver.Parse(mod.Go.Version + ".0")
if err2 != nil {
return nil, false, fmt.Errorf("parse go.mod version %q: %w", mod.Go.Version, err)
}
}
if modGoVersion.GT(goVersion) {
return nil, false, fmt.Errorf("existing go.mod has unsupported version %v (highest supported version is %v)", mod.Go.Version, goVersion)
}

return nil, false, fmt.Errorf("no module name configured and no existing package found")
// use dagger's embedded go.mod as basis for pinning versions
daggerMod, err := modfile.Parse("go.mod", dagger.GoMod, nil)
if err != nil {
return nil, false, fmt.Errorf("parse embedded go.mod: %w", err)
}
modRequires := make(map[string]*modfile.Require)
for _, req := range mod.Require {
modRequires[req.Mod.Path] = req
}
for _, req := range daggerMod.Require {
if _, ok := modRequires[req.Mod.Path]; ok {
// check if mod already includes this
continue
}
mod.AddNewRequire(req.Mod.Path, req.Mod.Version, req.Indirect)
}

// try and find a go.sum next to the go.mod, and use that to pin
sum, err := os.ReadFile(filepath.Join(g.Config.OutputDir, modPath, "go.sum"))
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, false, fmt.Errorf("could not read go.sum: %w", err)
}
sum = append(sum, '\n')
sum = append(sum, dagger.GoSum...)

modBody, err := newMod.Format()
modBody, err := mod.Format()
if err != nil {
return nil, false, fmt.Errorf("format go.mod: %w", err)
}
if err := mfs.WriteFile("go.mod", modBody, 0600); err != nil {

if err := mfs.MkdirAll(modPath, 0700); err != nil {
return nil, false, err
}
if err := mfs.WriteFile(filepath.Join(modPath, "go.mod"), modBody, 0600); err != nil {
return nil, false, err
}
if err := mfs.WriteFile("go.sum", dagger.GoSum, 0600); err != nil {
if err := mfs.WriteFile(filepath.Join(modPath, "go.sum"), sum, 0600); err != nil {
return nil, false, err
}

return info, needsRegen, nil
packageImport, err := filepath.Rel(modPath, g.Config.ModuleContextPath)
if err != nil {
return nil, false, err
}
return &PackageInfo{
// PackageName is unknown until we load the package
PackageImport: path.Join(mod.Module.Mod.Path, packageImport),
}, needsRegen, nil
}

func generateCode(
Expand Down
@@ -1,8 +1,10 @@
import (
"context"
"context"
"log/slog"

"{{.PackageImport}}/internal/dagger"
"{{.PackageImport}}/internal/telemetry"
"{{.PackageImport}}/internal/querybuilder"

"go.opentelemetry.io/otel/trace"
)
Expand Down
10 changes: 9 additions & 1 deletion cmd/codegen/generator/typescript/generator.go
Expand Up @@ -3,6 +3,7 @@ package typescriptgenerator
import (
"bytes"
"context"
"path/filepath"
"sort"

"github.com/psanford/memfs"
Expand Down Expand Up @@ -51,7 +52,14 @@ func (g *TypeScriptGenerator) Generate(_ context.Context, schema *introspection.

mfs := memfs.New()

if err := mfs.WriteFile(ClientGenFile, b.Bytes(), 0600); err != nil {
target := ClientGenFile
if g.Config.ModuleName != "" {
target = filepath.Join(g.Config.ModuleContextPath, "sdk/api", ClientGenFile)
}
if err := mfs.MkdirAll(filepath.Dir(target), 0700); err != nil {
return nil, err
}
if err := mfs.WriteFile(target, b.Bytes(), 0600); err != nil {
return nil, err
}

Expand Down