package main
import (
"fmt"
"go/ast"
)
// Visitor is the callback interface driven by WalkReplace.
//
// Visit is invoked for each node reached by the traversal. The replace
// argument lets the visitor substitute another node for the one being
// visited. The substitute is stored into the parent in place of the visited
// node, and it must fit the parent's field type (for example an ast.Expr for
// an expression slot), or the walk panics on a type assertion. Once replace
// has been called, the walk does not descend into the visited node's
// children, and no trailing Visit(nil, ...) call is made for it.
//
// If the returned visitor w is not nil (and no replacement was made), the
// children of node are walked with w, followed by a call of
// w.Visit(nil, replace). Calling replace during that trailing nil visit
// panics.
type Visitor interface {
	Visit(node ast.Node, replace func(ast.Node)) Visitor
}
// Helper functions for common node lists. They may be empty.

// walkIdentList walks every identifier in list with v. Whenever the walk
// reports a node for position i (the visitor's substitute, or the walker's
// copy), that node is stored back into list[i], so list is updated in place.
// The reported node must be an *ast.Ident; otherwise the assertion panics.
func walkIdentList(v Visitor, list []*ast.Ident) {
	for i, ident := range list {
		store := func(n ast.Node) { list[i] = n.(*ast.Ident) }
		walkToReplace(v, ident, store)
	}
}
// walkExprList walks every expression in list with v and writes the node
// reported for position i back into list[i]. The reported node must
// implement ast.Expr; otherwise the assertion panics.
func walkExprList(v Visitor, list []ast.Expr) {
	for i := range list {
		expr := list[i]
		walkToReplace(v, expr, func(repl ast.Node) {
			list[i] = repl.(ast.Expr)
		})
	}
}
// walkStmtList walks every statement in list with v and writes the node
// reported for position i back into list[i]. The reported node must
// implement ast.Stmt; otherwise the assertion panics.
func walkStmtList(v Visitor, list []ast.Stmt) {
	for pos, stmt := range list {
		setStmt := func(n ast.Node) {
			list[pos] = n.(ast.Stmt)
		}
		walkToReplace(v, stmt, setStmt)
	}
}
// walkDeclList walks every declaration in list with v and writes the node
// reported for position i back into list[i]. The reported node must
// implement ast.Decl; otherwise the assertion panics.
func walkDeclList(v Visitor, list []ast.Decl) {
	for idx, decl := range list {
		walkToReplace(v, decl, func(got ast.Node) {
			list[idx] = got.(ast.Decl)
		})
	}
}
// WalkReplace traverses an AST in depth-first order and returns the
// (possibly replaced) root. It starts by calling v.Visit(node, replace);
// node must not be nil. If the visitor w returned by v.Visit is not nil,
// the walk continues recursively with visitor w for each of the non-nil
// children of node, followed by a call of w.Visit(nil, ...).
//
// Nodes are shallow-copied as they are walked, and the copies are linked
// into the returned tree. A visitor substitutes a node by calling the replace
// function it is handed; the substitute is not walked further.
//
// The result is the copy of node, or the node the visitor substituted for
// it. If v is nil, nothing is visited and node itself is returned.
func WalkReplace(v Visitor, node ast.Node) (replacement ast.Node) {
	if v == nil {
		// walkToReplace returns without reporting anything for a nil
		// visitor, which would otherwise leave the caller with a nil result
		// for a tree that was not changed at all.
		return node
	}
	walkToReplace(v, node, func(r ast.Node) {
		replacement = r
	})
	return replacement
}
// walkToReplace is the recursive worker behind WalkReplace.
//
// It calls v.Visit(node, repl). Unless the visitor substituted node, it then
// shallow-copies node, walks each non-nil child of the copy with the visitor
// that Visit returned, and reports the copy to the parent through replace.
// When v is nil it returns immediately without calling replace, so the
// parent keeps its original child.
//
// Contract details:
//   - If the visitor calls repl, that node is passed straight to replace.
//     Neither the original's children nor the substitute are walked, and no
//     trailing Visit(nil) call is made for this node.
//   - If Visit returns nil (without replacing), node is still shallow-copied
//     and reported, but its children are not walked, so the copy shares its
//     child nodes with the original tree.
//   - Slices and maps that the walk writes children back into are copied
//     first, so the original node's slices and maps are never written to.
//   - A replacement must fit the parent's field type (e.g. ast.Expr for an
//     expression slot); otherwise the parent's type assertion panics.
//   - Unknown node types panic in the default case.
//
// NOTE(review): generic syntax (ast.IndexListExpr, FuncType.TypeParams,
// TypeSpec.TypeParams) is not handled here; such trees hit the default-case
// panic or have their type parameters left unwalked — confirm if needed.
func walkToReplace(v Visitor, node ast.Node, replace func(ast.Node)) {
	if v == nil {
		return
	}
	var replacement ast.Node
	repl := func(r ast.Node) {
		replacement = r
		replace(r)
	}
	v = v.Visit(node, repl)
	if replacement != nil {
		// The visitor substituted this node: do not descend into the
		// original's children, and skip the trailing Visit(nil).
		return
	}

	// walk children
	// (the order of the cases matches the order
	// of the corresponding node types in ast.go)
	switch n := node.(type) {
	// These are all leaves, so there's no sub-walk to do.
	// We just need to replace them on their parent with a copy.
	case *ast.Comment:
		cpy := *n
		replace(&cpy)
	case *ast.BadExpr:
		cpy := *n
		replace(&cpy)
	case *ast.Ident:
		cpy := *n
		replace(&cpy)
	case *ast.BasicLit:
		cpy := *n
		replace(&cpy)
	case *ast.BadDecl:
		cpy := *n
		replace(&cpy)
	case *ast.EmptyStmt:
		cpy := *n
		replace(&cpy)
	case *ast.BadStmt:
		cpy := *n
		replace(&cpy)

	// Comments and fields
	case *ast.CommentGroup:
		cpy := *n
		if n.List != nil {
			cpy.List = make([]*ast.Comment, len(n.List))
			copy(cpy.List, n.List)
		}
		for i, c := range cpy.List {
			walkToReplace(v, c, func(r ast.Node) {
				cpy.List[i] = r.(*ast.Comment)
			})
		}
		replace(&cpy)
	case *ast.Field:
		cpy := *n
		if n.Names != nil {
			cpy.Names = make([]*ast.Ident, len(n.Names))
			copy(cpy.Names, n.Names)
		}
		if cpy.Doc != nil {
			walkToReplace(v, cpy.Doc, func(r ast.Node) {
				cpy.Doc = r.(*ast.CommentGroup)
			})
		}
		walkIdentList(v, cpy.Names)
		walkToReplace(v, cpy.Type, func(r ast.Node) {
			cpy.Type = r.(ast.Expr)
		})
		if cpy.Tag != nil {
			walkToReplace(v, cpy.Tag, func(r ast.Node) {
				cpy.Tag = r.(*ast.BasicLit)
			})
		}
		if cpy.Comment != nil {
			walkToReplace(v, cpy.Comment, func(r ast.Node) {
				cpy.Comment = r.(*ast.CommentGroup)
			})
		}
		replace(&cpy)
	case *ast.FieldList:
		cpy := *n
		if n.List != nil {
			cpy.List = make([]*ast.Field, len(n.List))
			copy(cpy.List, n.List)
		}
		for i, f := range cpy.List {
			walkToReplace(v, f, func(r ast.Node) {
				cpy.List[i] = r.(*ast.Field)
			})
		}
		replace(&cpy)

	// Expressions
	case *ast.Ellipsis:
		cpy := *n
		if cpy.Elt != nil {
			walkToReplace(v, cpy.Elt, func(r ast.Node) {
				cpy.Elt = r.(ast.Expr)
			})
		}
		replace(&cpy)
	case *ast.FuncLit:
		cpy := *n
		walkToReplace(v, cpy.Type, func(r ast.Node) {
			cpy.Type = r.(*ast.FuncType)
		})
		walkToReplace(v, cpy.Body, func(r ast.Node) {
			cpy.Body = r.(*ast.BlockStmt)
		})
		replace(&cpy)
	case *ast.CompositeLit:
		cpy := *n
		if n.Elts != nil {
			cpy.Elts = make([]ast.Expr, len(n.Elts))
			copy(cpy.Elts, n.Elts)
		}
		if cpy.Type != nil {
			walkToReplace(v, cpy.Type, func(r ast.Node) {
				cpy.Type = r.(ast.Expr)
			})
		}
		walkExprList(v, cpy.Elts)
		replace(&cpy)
	case *ast.ParenExpr:
		cpy := *n
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		replace(&cpy)
	case *ast.SelectorExpr:
		cpy := *n
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		walkToReplace(v, cpy.Sel, func(r ast.Node) {
			cpy.Sel = r.(*ast.Ident)
		})
		replace(&cpy)
	case *ast.IndexExpr:
		cpy := *n
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		walkToReplace(v, cpy.Index, func(r ast.Node) {
			cpy.Index = r.(ast.Expr)
		})
		replace(&cpy)
	case *ast.SliceExpr:
		cpy := *n
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		if cpy.Low != nil {
			walkToReplace(v, cpy.Low, func(r ast.Node) {
				cpy.Low = r.(ast.Expr)
			})
		}
		if cpy.High != nil {
			walkToReplace(v, cpy.High, func(r ast.Node) {
				cpy.High = r.(ast.Expr)
			})
		}
		if cpy.Max != nil {
			walkToReplace(v, cpy.Max, func(r ast.Node) {
				cpy.Max = r.(ast.Expr)
			})
		}
		replace(&cpy)
	case *ast.TypeAssertExpr:
		cpy := *n
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		if cpy.Type != nil {
			walkToReplace(v, cpy.Type, func(r ast.Node) {
				cpy.Type = r.(ast.Expr)
			})
		}
		replace(&cpy)
	case *ast.CallExpr:
		cpy := *n
		if n.Args != nil {
			cpy.Args = make([]ast.Expr, len(n.Args))
			copy(cpy.Args, n.Args)
		}
		walkToReplace(v, cpy.Fun, func(r ast.Node) {
			cpy.Fun = r.(ast.Expr)
		})
		walkExprList(v, cpy.Args)
		replace(&cpy)
	case *ast.StarExpr:
		cpy := *n
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		replace(&cpy)
	case *ast.UnaryExpr:
		cpy := *n
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		replace(&cpy)
	case *ast.BinaryExpr:
		cpy := *n
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		walkToReplace(v, cpy.Y, func(r ast.Node) {
			cpy.Y = r.(ast.Expr)
		})
		replace(&cpy)
	case *ast.KeyValueExpr:
		cpy := *n
		walkToReplace(v, cpy.Key, func(r ast.Node) {
			cpy.Key = r.(ast.Expr)
		})
		walkToReplace(v, cpy.Value, func(r ast.Node) {
			cpy.Value = r.(ast.Expr)
		})
		replace(&cpy)

	// Types
	case *ast.ArrayType:
		cpy := *n
		if cpy.Len != nil {
			walkToReplace(v, cpy.Len, func(r ast.Node) {
				cpy.Len = r.(ast.Expr)
			})
		}
		walkToReplace(v, cpy.Elt, func(r ast.Node) {
			cpy.Elt = r.(ast.Expr)
		})
		replace(&cpy)
	case *ast.StructType:
		cpy := *n
		walkToReplace(v, cpy.Fields, func(r ast.Node) {
			cpy.Fields = r.(*ast.FieldList)
		})
		replace(&cpy)
	case *ast.FuncType:
		cpy := *n
		if cpy.Params != nil {
			walkToReplace(v, cpy.Params, func(r ast.Node) {
				cpy.Params = r.(*ast.FieldList)
			})
		}
		if cpy.Results != nil {
			walkToReplace(v, cpy.Results, func(r ast.Node) {
				cpy.Results = r.(*ast.FieldList)
			})
		}
		replace(&cpy)
	case *ast.InterfaceType:
		cpy := *n
		walkToReplace(v, cpy.Methods, func(r ast.Node) {
			cpy.Methods = r.(*ast.FieldList)
		})
		replace(&cpy)
	case *ast.MapType:
		cpy := *n
		walkToReplace(v, cpy.Key, func(r ast.Node) {
			cpy.Key = r.(ast.Expr)
		})
		walkToReplace(v, cpy.Value, func(r ast.Node) {
			cpy.Value = r.(ast.Expr)
		})
		replace(&cpy)
	case *ast.ChanType:
		cpy := *n
		walkToReplace(v, cpy.Value, func(r ast.Node) {
			cpy.Value = r.(ast.Expr)
		})
		replace(&cpy)

	// Statements
	case *ast.DeclStmt:
		cpy := *n
		walkToReplace(v, cpy.Decl, func(r ast.Node) {
			cpy.Decl = r.(ast.Decl)
		})
		replace(&cpy)
	case *ast.LabeledStmt:
		cpy := *n
		walkToReplace(v, cpy.Label, func(r ast.Node) {
			cpy.Label = r.(*ast.Ident)
		})
		walkToReplace(v, cpy.Stmt, func(r ast.Node) {
			cpy.Stmt = r.(ast.Stmt)
		})
		replace(&cpy)
	case *ast.ExprStmt:
		cpy := *n
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		replace(&cpy)
	case *ast.SendStmt:
		cpy := *n
		walkToReplace(v, cpy.Chan, func(r ast.Node) {
			cpy.Chan = r.(ast.Expr)
		})
		walkToReplace(v, cpy.Value, func(r ast.Node) {
			cpy.Value = r.(ast.Expr)
		})
		replace(&cpy)
	case *ast.IncDecStmt:
		cpy := *n
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		replace(&cpy)
	case *ast.AssignStmt:
		cpy := *n
		if n.Lhs != nil {
			cpy.Lhs = make([]ast.Expr, len(n.Lhs))
			copy(cpy.Lhs, n.Lhs)
		}
		if n.Rhs != nil {
			cpy.Rhs = make([]ast.Expr, len(n.Rhs))
			copy(cpy.Rhs, n.Rhs)
		}
		walkExprList(v, cpy.Lhs)
		walkExprList(v, cpy.Rhs)
		replace(&cpy)
	case *ast.GoStmt:
		cpy := *n
		walkToReplace(v, cpy.Call, func(r ast.Node) {
			cpy.Call = r.(*ast.CallExpr)
		})
		replace(&cpy)
	case *ast.DeferStmt:
		cpy := *n
		walkToReplace(v, cpy.Call, func(r ast.Node) {
			cpy.Call = r.(*ast.CallExpr)
		})
		replace(&cpy)
	case *ast.ReturnStmt:
		cpy := *n
		if n.Results != nil {
			cpy.Results = make([]ast.Expr, len(n.Results))
			copy(cpy.Results, n.Results)
		}
		walkExprList(v, cpy.Results)
		replace(&cpy)
	case *ast.BranchStmt:
		cpy := *n
		if cpy.Label != nil {
			walkToReplace(v, cpy.Label, func(r ast.Node) {
				cpy.Label = r.(*ast.Ident)
			})
		}
		replace(&cpy)
	case *ast.BlockStmt:
		cpy := *n
		if n.List != nil {
			cpy.List = make([]ast.Stmt, len(n.List))
			copy(cpy.List, n.List)
		}
		walkStmtList(v, cpy.List)
		replace(&cpy)
	case *ast.IfStmt:
		cpy := *n
		if cpy.Init != nil {
			walkToReplace(v, cpy.Init, func(r ast.Node) {
				cpy.Init = r.(ast.Stmt)
			})
		}
		walkToReplace(v, cpy.Cond, func(r ast.Node) {
			cpy.Cond = r.(ast.Expr)
		})
		walkToReplace(v, cpy.Body, func(r ast.Node) {
			cpy.Body = r.(*ast.BlockStmt)
		})
		if cpy.Else != nil {
			walkToReplace(v, cpy.Else, func(r ast.Node) {
				cpy.Else = r.(ast.Stmt)
			})
		}
		replace(&cpy)
	case *ast.CaseClause:
		cpy := *n
		if n.List != nil {
			cpy.List = make([]ast.Expr, len(n.List))
			copy(cpy.List, n.List)
		}
		if n.Body != nil {
			cpy.Body = make([]ast.Stmt, len(n.Body))
			copy(cpy.Body, n.Body)
		}
		walkExprList(v, cpy.List)
		walkStmtList(v, cpy.Body)
		replace(&cpy)
	case *ast.SwitchStmt:
		cpy := *n
		if cpy.Init != nil {
			walkToReplace(v, cpy.Init, func(r ast.Node) {
				cpy.Init = r.(ast.Stmt)
			})
		}
		if cpy.Tag != nil {
			walkToReplace(v, cpy.Tag, func(r ast.Node) {
				cpy.Tag = r.(ast.Expr)
			})
		}
		walkToReplace(v, cpy.Body, func(r ast.Node) {
			cpy.Body = r.(*ast.BlockStmt)
		})
		replace(&cpy)
	case *ast.TypeSwitchStmt:
		cpy := *n
		if cpy.Init != nil {
			walkToReplace(v, cpy.Init, func(r ast.Node) {
				cpy.Init = r.(ast.Stmt)
			})
		}
		walkToReplace(v, cpy.Assign, func(r ast.Node) {
			cpy.Assign = r.(ast.Stmt)
		})
		walkToReplace(v, cpy.Body, func(r ast.Node) {
			cpy.Body = r.(*ast.BlockStmt)
		})
		replace(&cpy)
	case *ast.CommClause:
		cpy := *n
		if n.Body != nil {
			cpy.Body = make([]ast.Stmt, len(n.Body))
			copy(cpy.Body, n.Body)
		}
		if cpy.Comm != nil {
			walkToReplace(v, cpy.Comm, func(r ast.Node) {
				cpy.Comm = r.(ast.Stmt)
			})
		}
		walkStmtList(v, cpy.Body)
		replace(&cpy)
	case *ast.SelectStmt:
		cpy := *n
		walkToReplace(v, cpy.Body, func(r ast.Node) {
			cpy.Body = r.(*ast.BlockStmt)
		})
		replace(&cpy)
	case *ast.ForStmt:
		cpy := *n
		if cpy.Init != nil {
			walkToReplace(v, cpy.Init, func(r ast.Node) {
				cpy.Init = r.(ast.Stmt)
			})
		}
		if cpy.Cond != nil {
			walkToReplace(v, cpy.Cond, func(r ast.Node) {
				cpy.Cond = r.(ast.Expr)
			})
		}
		if cpy.Post != nil {
			walkToReplace(v, cpy.Post, func(r ast.Node) {
				cpy.Post = r.(ast.Stmt)
			})
		}
		walkToReplace(v, cpy.Body, func(r ast.Node) {
			cpy.Body = r.(*ast.BlockStmt)
		})
		replace(&cpy)
	case *ast.RangeStmt:
		cpy := *n
		if cpy.Key != nil {
			walkToReplace(v, cpy.Key, func(r ast.Node) {
				cpy.Key = r.(ast.Expr)
			})
		}
		if cpy.Value != nil {
			walkToReplace(v, cpy.Value, func(r ast.Node) {
				cpy.Value = r.(ast.Expr)
			})
		}
		walkToReplace(v, cpy.X, func(r ast.Node) {
			cpy.X = r.(ast.Expr)
		})
		walkToReplace(v, cpy.Body, func(r ast.Node) {
			cpy.Body = r.(*ast.BlockStmt)
		})
		replace(&cpy)

	// Declarations
	case *ast.ImportSpec:
		cpy := *n
		if cpy.Doc != nil {
			walkToReplace(v, cpy.Doc, func(r ast.Node) {
				cpy.Doc = r.(*ast.CommentGroup)
			})
		}
		if cpy.Name != nil {
			walkToReplace(v, cpy.Name, func(r ast.Node) {
				cpy.Name = r.(*ast.Ident)
			})
		}
		walkToReplace(v, cpy.Path, func(r ast.Node) {
			cpy.Path = r.(*ast.BasicLit)
		})
		if cpy.Comment != nil {
			walkToReplace(v, cpy.Comment, func(r ast.Node) {
				cpy.Comment = r.(*ast.CommentGroup)
			})
		}
		replace(&cpy)
	case *ast.ValueSpec:
		cpy := *n
		if n.Names != nil {
			cpy.Names = make([]*ast.Ident, len(n.Names))
			copy(cpy.Names, n.Names)
		}
		if n.Values != nil {
			cpy.Values = make([]ast.Expr, len(n.Values))
			copy(cpy.Values, n.Values)
		}
		if cpy.Doc != nil {
			walkToReplace(v, cpy.Doc, func(r ast.Node) {
				cpy.Doc = r.(*ast.CommentGroup)
			})
		}
		walkIdentList(v, cpy.Names)
		if cpy.Type != nil {
			walkToReplace(v, cpy.Type, func(r ast.Node) {
				cpy.Type = r.(ast.Expr)
			})
		}
		walkExprList(v, cpy.Values)
		if cpy.Comment != nil {
			walkToReplace(v, cpy.Comment, func(r ast.Node) {
				cpy.Comment = r.(*ast.CommentGroup)
			})
		}
		replace(&cpy)
	case *ast.TypeSpec:
		cpy := *n
		if cpy.Doc != nil {
			walkToReplace(v, cpy.Doc, func(r ast.Node) {
				cpy.Doc = r.(*ast.CommentGroup)
			})
		}
		walkToReplace(v, cpy.Name, func(r ast.Node) {
			cpy.Name = r.(*ast.Ident)
		})
		walkToReplace(v, cpy.Type, func(r ast.Node) {
			cpy.Type = r.(ast.Expr)
		})
		if cpy.Comment != nil {
			walkToReplace(v, cpy.Comment, func(r ast.Node) {
				cpy.Comment = r.(*ast.CommentGroup)
			})
		}
		replace(&cpy)
	case *ast.GenDecl:
		cpy := *n
		if n.Specs != nil {
			cpy.Specs = make([]ast.Spec, len(n.Specs))
			copy(cpy.Specs, n.Specs)
		}
		if cpy.Doc != nil {
			walkToReplace(v, cpy.Doc, func(r ast.Node) {
				cpy.Doc = r.(*ast.CommentGroup)
			})
		}
		for i, s := range cpy.Specs {
			walkToReplace(v, s, func(r ast.Node) {
				cpy.Specs[i] = r.(ast.Spec)
			})
		}
		replace(&cpy)
	case *ast.FuncDecl:
		cpy := *n
		if cpy.Doc != nil {
			walkToReplace(v, cpy.Doc, func(r ast.Node) {
				cpy.Doc = r.(*ast.CommentGroup)
			})
		}
		if cpy.Recv != nil {
			walkToReplace(v, cpy.Recv, func(r ast.Node) {
				cpy.Recv = r.(*ast.FieldList)
			})
		}
		walkToReplace(v, cpy.Name, func(r ast.Node) {
			cpy.Name = r.(*ast.Ident)
		})
		walkToReplace(v, cpy.Type, func(r ast.Node) {
			cpy.Type = r.(*ast.FuncType)
		})
		if cpy.Body != nil {
			walkToReplace(v, cpy.Body, func(r ast.Node) {
				cpy.Body = r.(*ast.BlockStmt)
			})
		}
		replace(&cpy)

	// Files and packages
	case *ast.File:
		cpy := *n
		// Copy Decls before walking it, like every other slice above.
		// Without this, walkDeclList writes the walked declarations into
		// the original file's backing array and mutates the input AST.
		if n.Decls != nil {
			cpy.Decls = make([]ast.Decl, len(n.Decls))
			copy(cpy.Decls, n.Decls)
		}
		if cpy.Doc != nil {
			walkToReplace(v, cpy.Doc, func(r ast.Node) {
				cpy.Doc = r.(*ast.CommentGroup)
			})
		}
		walkToReplace(v, cpy.Name, func(r ast.Node) {
			cpy.Name = r.(*ast.Ident)
		})
		walkDeclList(v, cpy.Decls)
		// don't walk cpy.Comments - they have been
		// visited already through the individual
		// nodes
		// NOTE(review): cpy.Imports and cpy.Comments still point at the
		// original nodes, not at the copies created while walking
		// cpy.Decls; rebuild them if callers rely on them after a rewrite.
		replace(&cpy)
	case *ast.Package:
		cpy := *n
		// The copy always gets its own map so that storing the walked
		// files never touches n.Files.
		cpy.Files = make(map[string]*ast.File, len(n.Files))
		for name, f := range n.Files {
			cpy.Files[name] = f
			walkToReplace(v, f, func(r ast.Node) {
				cpy.Files[name] = r.(*ast.File)
			})
		}
		replace(&cpy)
	default:
		panic(fmt.Sprintf("walkToReplace: unexpected node type %T", n))
	}

	if v != nil {
		// Signal "leaving node" to the visitor. There is no parent slot to
		// fill for this synthetic nil node, so replacing it is a bug.
		v.Visit(nil, func(ast.Node) { panic("can't replace the go-up nil") })
	}
}