diff --git a/syntax/printer.go b/syntax/printer.go index 4419e40fd..ea7cb2deb 100644 --- a/syntax/printer.go +++ b/syntax/printer.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "strings" + "text/tabwriter" "unicode" ) @@ -53,7 +54,6 @@ func Minify(p *Printer) { p.minify = true } func NewPrinter(options ...func(*Printer)) *Printer { p := &Printer{ bufWriter: bufio.NewWriter(nil), - lenPrinter: new(Printer), tabsPrinter: new(Printer), } for _, opt := range options { @@ -70,6 +70,20 @@ func NewPrinter(options ...func(*Printer)) *Printer { // *File is used. func (p *Printer) Print(w io.Writer, node Node) error { p.reset() + + // TODO: consider adding a raw mode to skip the tab writer, much like in + // go/printer. + twmode := tabwriter.DiscardEmptyColumns | tabwriter.StripEscape + tabwidth := 8 + if p.indentSpaces == 0 { + // indenting with tabs + twmode |= tabwriter.TabIndent + } else { + // indenting with spaces + tabwidth = int(p.indentSpaces) + } + w = tabwriter.NewWriter(w, 0, tabwidth, 1, ' ', twmode) + p.bufWriter.Reset(w) switch x := node.(type) { case *File: @@ -89,7 +103,17 @@ func (p *Printer) Print(w io.Writer, node Node) error { } p.flushHeredocs() p.flushComments() - return p.bufWriter.Flush() + + // flush the writers + if err := p.bufWriter.Flush(); err != nil { + return err + } + if tw, _ := w.(*tabwriter.Writer); tw != nil { + if err := tw.Flush(); err != nil { + return err + } + } + return nil } type bufWriter interface { @@ -153,8 +177,6 @@ type Printer struct { wantNewline bool wroteSemi bool - commentPadding uint - // pendingComments are any comments in the current line or statement // that we have yet to print. This is useful because that way, we can // ensure that all comments are written immediately before a newline. @@ -180,17 +202,12 @@ type Printer struct { // pendingHdocs is the list of pending heredocs to write. pendingHdocs []*Redirect - // used in stmtCols to align comments - lenPrinter *Printer - lenCounter byteCounter - // used when printing <<- heredocs with tab indentation tabsPrinter *Printer } func (p *Printer) reset() { p.wantSpace, p.wantNewline = false, false - p.commentPadding = 0 p.pendingComments = p.pendingComments[:0] // minification uses its own newline logic @@ -272,6 +289,23 @@ func (p *Printer) semiOrNewl(s string, pos Pos) { p.wantSpace = true } +func (p *Printer) writeLit(s string) { + if !strings.Contains(s, "\t") { + p.WriteString(s) + return + } + for i := 0; i < len(s); i++ { + b := s[i] + if b != '\t' { + p.WriteByte(b) + continue + } + p.WriteByte('\xff') + p.WriteByte(b) + p.WriteByte('\xff') + } +} + func (p *Printer) incLevel() { inc := false if p.level <= p.lastLevel || len(p.levelIncs) == 0 { @@ -299,9 +333,11 @@ func (p *Printer) indent() { switch { case p.level == 0: case p.indentSpaces == 0: + p.WriteByte('\xff') for i := uint(0); i < p.level; i++ { p.WriteByte('\t') } + p.WriteByte('\xff') default: p.spaces(p.indentSpaces * p.level) } @@ -345,8 +381,7 @@ func (p *Printer) flushHeredocs() { p.line++ p.WriteByte('\n') p.wantNewline, p.wantSpace = false, false - if r.Op == DashHdoc && p.indentSpaces == 0 && - !p.minify && p.tabsPrinter != nil { + if r.Op == DashHdoc && p.indentSpaces == 0 && !p.minify { if r.Hdoc != nil { extra := extraIndenter{ bufWriter: p.bufWriter, @@ -433,7 +468,7 @@ func (p *Printer) flushComments() { if p.keepPadding { p.spacePad(c.Pos()) } else { - p.spaces(p.commentPadding + 1) + p.WriteByte('\t') } } // don't go back one line, which may happen in some edge cases @@ -441,7 +476,7 @@ func (p *Printer) flushComments() { p.line = cline } p.WriteByte('#') - p.WriteString(strings.TrimRightFunc(c.Text, unicode.IsSpace)) + p.writeLit(strings.TrimRightFunc(c.Text, unicode.IsSpace)) p.wantNewline = true } p.pendingComments = nil @@ -467,13 +502,13 @@ func (p *Printer) wordParts(wps []WordPart) { func (p *Printer) wordPart(wp, next WordPart) { switch x := wp.(type) { case *Lit: - p.WriteString(x.Value) + p.writeLit(x.Value) case *SglQuoted: if x.Dollar { p.WriteByte('$') } p.WriteByte('\'') - p.WriteString(x.Value) + p.writeLit(x.Value) p.WriteByte('\'') p.line = x.End().Line() case *DblQuoted: @@ -527,7 +562,7 @@ func (p *Printer) wordPart(wp, next WordPart) { p.WriteString("))") case *ExtGlob: p.WriteString(x.Op.String()) - p.WriteString(x.Pattern.Value) + p.writeLit(x.Pattern.Value) p.WriteByte(')') case *ProcSubst: // avoid conflict with << and others @@ -564,13 +599,13 @@ func (p *Printer) wroteIndex(index ArithmExpr) bool { func (p *Printer) paramExp(pe *ParamExp) { if pe.nakedIndex() { // arr[x] - p.WriteString(pe.Param.Value) + p.writeLit(pe.Param.Value) p.wroteIndex(pe.Index) return } if pe.Short { // $var p.WriteByte('$') - p.WriteString(pe.Param.Value) + p.writeLit(pe.Param.Value) return } // ${var...} @@ -583,7 +618,7 @@ func (p *Printer) paramExp(pe *ParamExp) { case pe.Excl: p.WriteByte('!') } - p.WriteString(pe.Param.Value) + p.writeLit(pe.Param.Value) p.wroteIndex(pe.Index) switch { case pe.Slice != nil: @@ -606,7 +641,7 @@ func (p *Printer) paramExp(pe *ParamExp) { p.word(pe.Repl.With) } case pe.Names != 0: - p.WriteString(pe.Names.String()) + p.writeLit(pe.Names.String()) case pe.Exp != nil: p.WriteString(pe.Exp.Op.String()) if pe.Exp.Word != nil { @@ -619,7 +654,7 @@ func (p *Printer) paramExp(pe *ParamExp) { func (p *Printer) loop(loop Loop) { switch x := loop.(type) { case *WordIter: - p.WriteString(x.Name.Value) + p.writeLit(x.Name.Value) if x.InPos.IsValid() { p.spacedString(" in", Pos{}) p.wordJoin(x.Items) @@ -710,7 +745,7 @@ func (p *Printer) unquotedWord(w *Word) { for _, wp := range w.Parts { switch x := wp.(type) { case *SglQuoted: - p.WriteString(x.Value) + p.writeLit(x.Value) case *DblQuoted: p.wordParts(x.Parts) case *Lit: @@ -818,7 +853,7 @@ func (p *Printer) stmt(s *Stmt) { p.spacePad(r.Pos()) } if r.N != nil { - p.WriteString(r.N.Value) + p.writeLit(r.N.Value) } p.WriteString(r.Op.String()) if p.spaceRedirects && (r.Op != DplIn && r.Op != DplOut) { @@ -868,7 +903,7 @@ func (p *Printer) command(cmd Command, redirs []*Redirect) (startRedirs int) { p.spacePad(r.Pos()) } if r.N != nil { - p.WriteString(r.N.Value) + p.writeLit(r.N.Value) } p.WriteString(r.Op.String()) if p.spaceRedirects && (r.Op != DplIn && r.Op != DplOut) { @@ -959,7 +994,7 @@ func (p *Printer) command(cmd Command, redirs []*Redirect) (startRedirs int) { if x.RsrvWord { p.WriteString("function ") } - p.WriteString(x.Name.Value) + p.writeLit(x.Name.Value) p.WriteString("()") if !p.minify { p.space() @@ -1040,7 +1075,7 @@ func (p *Printer) command(cmd Command, redirs []*Redirect) (startRedirs int) { p.spacedString("coproc", x.Pos()) if x.Name != nil { p.space() - p.WriteString(x.Name.Value) + p.writeLit(x.Name.Value) } p.space() p.stmt(x.Stmt) @@ -1101,21 +1136,10 @@ func startsWithLparen(s *Stmt) bool { return false } -func (p *Printer) hasInline(s *Stmt) bool { - for _, c := range s.Comments { - if c.Pos().Line() == s.End().Line() { - return true - } - } - return false -} - func (p *Printer) stmtList(sl StmtList) { sep := p.wantNewline || (len(sl.Stmts) > 0 && sl.Stmts[0].Pos().Line() > p.line) - inlineIndent := 0 - lastIndentedLine := uint(0) - for i, s := range sl.Stmts { + for _, s := range sl.Stmts { pos := s.Pos() var midComs, endComs []Comment for _, c := range s.Comments { @@ -1123,7 +1147,7 @@ func (p *Printer) stmtList(sl StmtList) { endComs = append(endComs, c) break } - if c.Pos().After(s.Pos()) { + if c.Pos().After(pos) { midComs = append(midComs, c) continue } @@ -1133,35 +1157,8 @@ func (p *Printer) stmtList(sl StmtList) { p.newlines(pos) } p.line = pos.Line() - if !p.hasInline(s) { - inlineIndent = 0 - p.commentPadding = 0 - p.comments(midComs...) - p.stmt(s) - p.wantNewline = true - continue - } p.comments(midComs...) p.stmt(s) - if s.Pos().Line() > lastIndentedLine+1 { - inlineIndent = 0 - } - if inlineIndent == 0 { - for _, s2 := range sl.Stmts[i:] { - if !p.hasInline(s2) { - break - } - if l := p.stmtCols(s2); l > inlineIndent { - inlineIndent = l - } - } - } - if inlineIndent > 0 { - if l := p.stmtCols(s); l > 0 { - p.commentPadding = uint(inlineIndent - l) - } - lastIndentedLine = p.line - } p.comments(endComs...) p.wantNewline = true } @@ -1171,34 +1168,6 @@ func (p *Printer) stmtList(sl StmtList) { p.comments(sl.Last...) } -type byteCounter int - -func (c *byteCounter) WriteByte(b byte) error { - switch { - case *c < 0: - case b == '\n': - *c = -1 - default: - *c++ - } - return nil -} -func (c *byteCounter) Write(p []byte) (int, error) { - return c.WriteString(string(p)) -} -func (c *byteCounter) WriteString(s string) (int, error) { - switch { - case *c < 0: - case strings.Contains(s, "\n"): - *c = -1 - default: - *c += byteCounter(len(s)) - } - return 0, nil -} -func (c *byteCounter) Reset(io.Writer) { *c = 0 } -func (c *byteCounter) Flush() error { return nil } - // extraIndenter ensures that all lines in a '<<-' heredoc body have at least // baseIndent leading tabs. Those that had more tab indentation than the first // heredoc line will keep that relative indentation. @@ -1216,8 +1185,9 @@ func (e *extraIndenter) WriteByte(b byte) error { if b != '\n' { return nil } - trimmed := bytes.TrimLeft(e.curLine, "\t") - lineIndent := len(e.curLine) - len(trimmed) + trimmed := bytes.TrimLeft(e.curLine, "\xff\t") + // divided by 3, as each tab is escaped via "\xff\t\xff" + lineIndent := (len(e.curLine) - len(trimmed)) / 3 if e.firstIndent < 0 { e.firstIndent = lineIndent e.firstChange = e.baseIndent - lineIndent @@ -1244,21 +1214,6 @@ func (e *extraIndenter) WriteString(s string) (int, error) { return len(s), nil } -// stmtCols reports the length that s will take when formatted in a -// single line. If it will span multiple lines, stmtCols will return -1. -func (p *Printer) stmtCols(s *Stmt) int { - if p.lenPrinter == nil { - return -1 // stmtCols call within stmtCols, bail - } - *p.lenPrinter = Printer{ - bufWriter: &p.lenCounter, - line: s.Pos().Line(), - } - p.lenPrinter.bufWriter.Reset(nil) - p.lenPrinter.stmt(s) - return int(p.lenCounter) -} - func (p *Printer) nestedStmts(sl StmtList, closing Pos) { p.incLevel() switch { @@ -1294,7 +1249,7 @@ func (p *Printer) assigns(assigns []*Assign) { p.spacePad(a.Pos()) } if a.Name != nil { - p.WriteString(a.Name.Value) + p.writeLit(a.Name.Value) p.wroteIndex(a.Index) if a.Append { p.WriteByte('+') diff --git a/syntax/printer_test.go b/syntax/printer_test.go index ca7a82746..c90aee4f4 100644 --- a/syntax/printer_test.go +++ b/syntax/printer_test.go @@ -190,7 +190,7 @@ var printTests = []printCase{ }, { "{ a; } #x\nbbb #y\n{ #z\n}", - "{ a; } #x\nbbb #y\n{ #z\n}", + "{ a; } #x\nbbb #y\n{ #z\n}", }, { "foo; foooo # 1", @@ -204,7 +204,7 @@ var printTests = []printCase{ "a #1\nbbb; c #2\nd #3", "a #1\nbbb\nc #2\nd #3", }, - samePrint("aa #c1\n{ #c2\n\tb\n}"), + samePrint("aa #c1\n{ #c2\n\tb\n}"), { "aa #c1\n{ b; c; } #c2", "aa #c1\n{\n\tb\n\tc\n} #c2", @@ -502,6 +502,10 @@ var printTests = []printCase{ "case i in\nx)\n\ta\n\t;;\n\t#a\n#b\n\t#c\ny) ;;\nesac", "case i in\nx)\n\ta\n\t;;\n\t#a\n\t#b\n\t#c\ny) ;;\nesac", }, + samePrint("'foo\tbar'\n'foooo\tbar'"), + samePrint("\"foo\tbar\"\n\"foooo\tbar\""), + samePrint("foo\\\tbar\nfoooo\\\tbar"), + samePrint("#foo\tbar\n#foooo\tbar"), } func TestPrintWeirdFormat(t *testing.T) { @@ -559,6 +563,7 @@ func TestPrintMultiline(t *testing.T) { } func BenchmarkPrint(b *testing.B) { + b.ReportAllocs() prog := parsePath(b, canonicalPath) printer := NewPrinter() for i := 0; i < b.N; i++ { @@ -589,6 +594,11 @@ func TestPrintSpaces(t *testing.T) { "{\nfoo \\\nbar\n}", "{\n foo \\\n bar\n}", }, + { + 2, + "if foo; then # inline1\nbar # inline2\n# withfi\nfi", + "if foo; then # inline1\n bar # inline2\n# withfi\nfi", + }, } parser := NewParser(KeepComments) @@ -608,7 +618,6 @@ func (b badWriter) Write(p []byte) (int, error) { return 0, errBadWriter } func TestWriteErr(t *testing.T) { t.Parallel() - _ = (*byteCounter)(nil).Flush() f := &File{StmtList: StmtList{Stmts: []*Stmt{ { Redirs: []*Redirect{{ @@ -907,7 +916,7 @@ func printTest(t *testing.T, parser *Parser, printer *Printer, in, want string) t.Fatal(err) } if got != wantNewl { - t.Fatalf("Print mismatch:\nin:\n%s\nwant:\n%sgot:\n%s", + t.Fatalf("Print mismatch:\nin:\n%q\nwant:\n%q\ngot:\n%q", in, wantNewl, got) } _, err = parser.Parse(strings.NewReader(want), "")