Skip to content

Commit

Permalink
Merge pull request #14 from budougumi0617/add-inspect
Browse files Browse the repository at this point in the history
add inspect
  • Loading branch information
budougumi0617 authored Feb 1, 2021
2 parents 94d1559 + 769f7e8 commit 68a1164
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 14 deletions.
51 changes: 51 additions & 0 deletions inspect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package nrseg

import (
"errors"
"go/ast"
"go/parser"
"go/token"
)

func (nrseg *nrseg) Inspect(filename string, src []byte) error {
if len(src) != 0 && c.Match(src) {
return nil
}
fs := token.NewFileSet()
f, err := parser.ParseFile(fs, filename, src, parser.ParseComments)
if err != nil {
return err
}
// import newrelic pkg
pkg := "newrelic"
name, err := findImport(fs, f) // importされたpkgの名前
if err != nil && !errors.Is(err, ErrNoImportNrPkg) {
return err
}
if len(name) != 0 {
// change name if named import.
pkg = name
}

ast.Inspect(f, func(n ast.Node) bool {
if fd, ok := n.(*ast.FuncDecl); ok {
if findIgnoreComment(fd.Doc) {
return false
}
if fd.Body != nil && len(fd.Body.List) > 0 {
if _, t := parseParams(fd.Type); !(t == TypeContext || t == TypeHttpRequest) {
return false
}

if !existFromContext(pkg, fd.Body.List[0]) {
nrseg.errFlag = true
nrseg.reportf(filename, fs, fd.Pos(), fd)
}
return false
}
}
return true
})

return nil
}
115 changes: 104 additions & 11 deletions nrseg.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"flag"
"fmt"
"go/ast"
"go/token"
"io"
"io/ioutil"
"os"
Expand All @@ -19,9 +21,11 @@ var (
)

type nrseg struct {
inspectMode bool
in, dist string
ignoreDirs []string
outStream, errStream io.Writer
errFlag bool
}

func fill(args []string, outStream, errStream io.Writer, version, revision string) (*nrseg, error) {
Expand Down Expand Up @@ -78,6 +82,61 @@ func fill(args []string, outStream, errStream io.Writer, version, revision strin
}, nil
}

func fill2(args []string, outStream, errStream io.Writer, version, revision string) (*nrseg, error) {
cn := args[0]
flags := flag.NewFlagSet(cn, flag.ContinueOnError)
flags.SetOutput(errStream)
flags.Usage = func() {
fmt.Fprintf(
flag.CommandLine.Output(),
"Insert function segments into any function/method for Newrelic APM.\n\nUsage of %s:\n",
os.Args[0],
)
flags.PrintDefaults()
}

var v bool
vdesc := "print version information and quit."
flags.BoolVar(&v, "version", false, vdesc)
flags.BoolVar(&v, "v", false, vdesc)

var ignoreDirs string
idesc := "ignore directory names. ex: foo,bar,baz\n(testdata directory is always ignored.)"
flags.StringVar(&ignoreDirs, "ignore", "", idesc)
flags.StringVar(&ignoreDirs, "i", "", idesc)

if err := flags.Parse(args[2:]); err != nil {
return nil, err
}
if v {
fmt.Fprintf(errStream, "%s version %q, revison %q\n", cn, version, revision)
return nil, ErrShowVersion
}

dirs := []string{"testdata"}
if len(ignoreDirs) != 0 {
dirs = append(dirs, strings.Split(ignoreDirs, ",")...)
}

dir := "./"
nargs := flags.Args()
if len(nargs) > 2 {
msg := "execution path must be only one or no-set(current directory)."
return nil, fmt.Errorf(msg)
}
if len(nargs) == 2 {
dir = nargs[1]
}

return &nrseg{
inspectMode: true,
in: dir,
ignoreDirs: dirs,
outStream: outStream,
errStream: errStream,
}, nil
}

var c = regexp.MustCompile("(?m)^// Code generated .* DO NOT EDIT\\.$")

func (n *nrseg) skipDir(p string) bool {
Expand Down Expand Up @@ -114,17 +173,22 @@ func (n *nrseg) run() error {
if err != nil {
return err
}
got, err := Process(path, org)
if err != nil {
return err
}
if !bytes.Equal(org, got) {
if len(n.dist) != 0 && n.in != n.dist {
return n.writeOtherPath(n.in, n.dist, path, got)
}
if _, err := f.WriteAt(got, 0); err != nil {

if n.inspectMode {
return n.Inspect(path, org)
} else {
got, err := Process(path, org)
if err != nil {
return err
}
if !bytes.Equal(org, got) {
if len(n.dist) != 0 && n.in != n.dist {
return n.writeOtherPath(n.in, n.dist, path, got)
}
if _, err := f.WriteAt(got, 0); err != nil {
return err
}
}
}
return nil
})
Expand Down Expand Up @@ -162,11 +226,40 @@ func (n *nrseg) writeOtherPath(in, dist, path string, got []byte) error {
return err
}

func (n *nrseg) reportf(filename string, fs *token.FileSet, pos token.Pos, fd *ast.FuncDecl) {
var rcv string
if fd.Recv != nil && len(fd.Recv.List) > 0 {
if rn, ok := fd.Recv.List[0].Type.(*ast.StarExpr); ok {
if idt, ok := rn.X.(*ast.Ident); ok {
rcv = idt.Name
}
} else if idt, ok := fd.Recv.List[0].Type.(*ast.Ident); ok {
rcv = idt.Name
}
}

if len(rcv) != 0 {
fmt.Fprintf(n.errStream, "%s:%d:1: %s.%s no insert segment\n", filename, fs.File(pos).Line(pos), rcv, fd.Name.Name)
return
}
fmt.Fprintf(n.errStream, "%s:%d:1: %s no insert segment\n", filename, fs.File(pos).Line(pos), fd.Name.Name)
}

// Run is entry point.
func Run(args []string, outStream, errStream io.Writer, version, revision string) error {
nrseg, err := fill(args, outStream, errStream, version, revision)
var nrseg *nrseg
var err error
if len(args) >= 2 && args[1] == "inspect" {
nrseg, err = fill2(args, outStream, errStream, version, revision)
} else {
nrseg, err = fill(args, outStream, errStream, version, revision)
}
if err != nil {
return err
}
return nrseg.run()
err = nrseg.run()
if nrseg.errFlag {
err = errors.New("find error")
}
return err
}
21 changes: 18 additions & 3 deletions process.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nrseg

import (
"bytes"
"errors"
"go/ast"
"go/format"
"go/parser"
Expand Down Expand Up @@ -39,7 +40,7 @@ func Process(filename string, src []byte) ([]byte, error) {
if findIgnoreComment(fd.Doc) {
return false
}
if fd.Body != nil {
if fd.Body != nil && len(fd.Body.List) > 0 {
sn := getSegName(fd)
vn, t := parseParams(fd.Type)
var ds ast.Stmt
Expand Down Expand Up @@ -79,6 +80,21 @@ func Process(filename string, src []byte) ([]byte, error) {
const NewRelicV3Pkg = "github.com/newrelic/go-agent/v3/newrelic"

func addImport(fs *token.FileSet, f *ast.File) (string, error) {
pkg, err := findImport(fs, f)
if err == nil {
return pkg, nil
}
if errors.Is(err, ErrNoImportNrPkg) {
astutil.AddImport(fs, f, NewRelicV3Pkg)
return "", nil
}

return "", err
}

var ErrNoImportNrPkg = errors.New("not import newrelic pkg")

func findImport(fs *token.FileSet, f *ast.File) (string, error) {
for _, spec := range f.Imports {
path, err := strconv.Unquote(spec.Path.Value)
if err != nil {
Expand All @@ -92,8 +108,7 @@ func addImport(fs *token.FileSet, f *ast.File) (string, error) {
return "", nil
}
}
astutil.AddImport(fs, f, NewRelicV3Pkg)
return "", nil
return "", ErrNoImportNrPkg
}

var nrignoreReg = regexp.MustCompile("(?m)^// nrseg:ignore .*$")
Expand Down

0 comments on commit 68a1164

Please sign in to comment.