AST
- 许多自动化代码生成工具都离不开语法树分析。
ast.File结构
type File struct {
Doc *CommentGroup // associated documentation; or nil
Package token.Pos // position of "package" keyword
Name *Ident // go文件的包名
Decls []Decl // 最外层的声明
Scope *Scope // package scope (this file only)
Imports []*ImportSpec // 引入的外部包
Unresolved []*Ident // unresolved identifiers in this file
Comments []*CommentGroup // list of all comments in the source file
}
Node节点
- 整个语法树都由不同的node组成
Node接口定义
// All node types implement the Node interface.
type Node interface {
Pos() token.Pos // position of first character belonging to the node
End() token.Pos // position of first character immediately after the node
}
Node分类
Expression and Type
// All expression nodes implement the Expr interface.
type Expr interface {
Node
exprNode()
}
expression node
-
Indent
(identifier)表示一个标识符,比如示例中表示包名的Name
字段(包名)就是一个expression node
type node
// 结构体类型 StructType struct { Struct token.Pos // position of "struct" keyword Fields *FieldList // list of field declarations Incomplete bool // true if (source) fields are missing in the Fields list } // 函数类型 FuncType struct { Func token.Pos // position of "func" keyword (token.NoPos if there is no "func") Params *FieldList // (incoming) parameters; non-nil Results *FieldList // (outgoing) results; or nil } // 接口类型 InterfaceType struct { Interface token.Pos // position of "interface" keyword Methods *FieldList // list of methods Incomplete bool // true if (source) methods are missing in the Methods list }
Statement
-
赋值语句,控制语句(if,else,for,select...)等均属于statement node。
// 赋值语句 AssignStmt struct { Lhs []Expr TokPos token.Pos // position of Tok Tok token.Token // assignment token, DEFINE Rhs []Expr } // 条件语句 IfStmt struct { If token.Pos // position of "if" keyword Init Stmt // initialization statement; or nil Cond Expr // condition Body *BlockStmt Else Stmt // else branch; or nil }
Spec
-
Spec node只有3种,分别是
ImportSpec
,ValueSpec
和TypeSpec
// import部分 ImportSpec struct { Doc *CommentGroup // associated documentation; or nil Name *Ident // local package name (including "."); or nil Path *BasicLit // import path Comment *CommentGroup // line comments; or nil EndPos token.Pos // end of spec (overrides Path.Pos if nonzero) } // constant or variable declaration ValueSpec struct { Doc *CommentGroup // associated documentation; or nil Names []*Ident // value names (len(Names) > 0) Type Expr // value type; or nil Values []Expr // initial values; or nil Comment *CommentGroup // line comments; or nil } // type声明 TypeSpec struct { Doc *CommentGroup // associated documentation; or nil Name *Ident // type name Assign token.Pos // position of '=', if any Type Expr // *Ident, *ParenExpr, *SelectorExpr, *StarExpr, or any of the *XxxTypes Comment *CommentGroup // line comments; or nil }
Declaration Node
... type ( // 表示一个有语法错误的节点 BadDecl struct { From, To token.Pos // position range of bad declaration } // 用于表示import, const,type或变量声明 GenDecl struct { Doc *CommentGroup // associated documentation; or nil TokPos token.Pos // position of Tok Tok token.Token // IMPORT, CONST, TYPE, VAR Lparen token.Pos // position of '(', if any Specs []Spec Rparen token.Pos // position of ')', if any } // 用于表示函数声明 FuncDecl struct { Doc *CommentGroup // associated documentation; or nil Recv *FieldList // receiver (methods); or nil (functions) Name *Ident // function/method name Type *FuncType // function signature: parameters, results, and position of "func" keyword Body *BlockStmt // function body; or nil for external (non-Go) function } )
Common Node
- 除去上述四种类别划分的node,还有一些node不属于上面四种类别:
// File 表示一个文件节点 type File struct { ... } // Package 表示一个包节点 type Package struct { ... } // Field 字段节点, 可以代表结构体定义中的字段,接口定义中的方法列表,函数前面中的入参和返回值字段 type Field struct { ... } ... // FieldList 包含多个Field type FieldList struct { ... } // 表示基础类型 type BasicLit struct { ValuePos token.Pos Kind token.Token Value string }
demo.go示列
demo.go
package main
import (
"context"
"fmt"
)
const (
name = "call duck"
)
// Duck 结构体
type Duck struct {
name string
}
// Animal 接口
type Animal interface {
Name(ctx context.Context, n string) string
}
func (d Duck) Name(ctx context.Context, n string) string {
d.name = d.name + n
return d.name
}
// main方法
func main() {
var duck Animal
duck = Duck{name: name}
ctx := context.Background()
fmt.Println(duck.Name(ctx, "cute "))
}
查看demo.go的语法树结构
package main
import (
"go/ast"
"go/parser"
"go/token"
"log"
"path/filepath"
)
func main() {
fSet := token.NewFileSet()
path, _ := filepath.Abs("demo.go")
f, err := parser.ParseFile(fSet, path, nil, parser.AllErrors)
if err != nil {
log.Println(err)
return
}
// 打印语法树
_ = ast.Print(fSet, f)
}
ast分析
- 即使demo.go文件中的代码已经非常少,但是对应的语法树也非常复杂。下图仅展示一个结构。
- 其中最为复杂的是Decls,它包括了代码中所有的声明。
利用语法树帮接口方法加上context参数
Walk
-
语法树层级比较复杂,为方便使用者利用语法树,ast包提供了便利方法
func Walk(v Visitor, node Node) type Visitor interface { Visit(node Node) (w Visitor) }
- Walk方法会按照深度优先搜索方法(depth-first order)遍历整个语法树,所以使用者只需按照业务需要,实现Visitor接口即可。
- Walk每遍历一个节点就会调用Visitor.Visit方法,传入当前节点。直到返回的Visitor为nil,则停止遍历当前节点的子节点。
待添加文件demo.go
package main
type Animal interface {
Name(n string)
Eat()
Sleep()
}
type Duck interface {
Name(n string)
Eat()
Sleep()
Color() string
}
帮demo.go中入参数加context参数,返回参数加error类型
package main
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"log"
"path/filepath"
"strconv"
)
type AddContextVisitor struct {
pkgContext string // 如果有引入context包情况下,context包的别名
}
func (vi *AddContextVisitor) Visit(node ast.Node) ast.Visitor {
switch node.(type) {
case *ast.File:
file := node.(*ast.File)
// 没有任何import
if len(file.Imports) == 0 {
vi.addImportWithoutAnyImport(file)
}
case *ast.GenDecl:
genDecl := node.(*ast.GenDecl)
// 有import
if genDecl.Tok == token.IMPORT {
vi.addImport(genDecl)
}
case *ast.InterfaceType:
// 遍历所有的接口类型
interfaceType := node.(*ast.InterfaceType)
vi.addContextAndError(interfaceType)
return nil
}
return vi
}
func (vi *AddContextVisitor) addImport(genDecl *ast.GenDecl) {
hasImportedContext := false
for _, value := range genDecl.Specs {
// 如果已经包含"context"
importSpec := value.(*ast.ImportSpec)
if importSpec.Path.Value == strconv.Quote("context") {
hasImportedContext = true
// 有别名的情况下记录别名
if importSpec.Name != nil {
vi.pkgContext = importSpec.Name.Name
}
}
}
if hasImportedContext {
return
}
if !hasImportedContext {
genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote("context"),
},
})
}
}
// 没有import情况下 引入context
func (vi *AddContextVisitor) addImportWithoutAnyImport(file *ast.File) {
genDecl := &ast.GenDecl{
Tok: token.IMPORT,
Specs: []ast.Spec{
&ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote("context"),
},
},
},
}
list := []ast.Decl{genDecl}
file.Decls = append(list, file.Decls...)
}
// 为接口方法添加参数
func (vi *AddContextVisitor) addContextAndError(interfaceType *ast.InterfaceType) {
// 接口方法不为空是,遍历接口方法
if interfaceType.Methods != nil || interfaceType.Methods.List != nil {
for _, v := range interfaceType.Methods.List {
ft := v.Type.(*ast.FuncType)
hasContext := false
hasError := false
// 判断参数中是否包含context.Context类型
for _, value := range ft.Params.List {
if expr, ok := value.Type.(*ast.SelectorExpr); ok {
if ident, ok := expr.X.(*ast.Ident); ok {
if ident.Name == "context" {
hasContext = true
}
}
}
}
if ft.Results != nil && ft.Results.List != nil {
// 判断返回参数中是否包含error类型
for i, value := range ft.Results.List {
if ident, ok := value.Type.(*ast.Ident); ok {
if ident.Name == "error" {
ft.Results.List[i].Names = []*ast.Ident{
ast.NewIdent("err"),
}
hasError = true
}
}
}
}
if !hasError {
errField := &ast.Field{
Names: []*ast.Ident{
ast.NewIdent("err"),
},
Type: ast.NewIdent("error"),
}
if ft.Results == nil {
ft.Results = &ast.FieldList{}
}
ft.Results.List = append(ft.Results.List, errField)
}
// 为没有context参数的方法添加context参数
if !hasContext {
x := "context"
if vi.pkgContext != "" {
x = vi.pkgContext
}
ctxField := &ast.Field{
Names: []*ast.Ident{
ast.NewIdent("ctx"),
},
Type: &ast.SelectorExpr{
X: ast.NewIdent(x),
Sel: ast.NewIdent("Context"),
},
}
list := []*ast.Field{
ctxField,
}
ft.Params.List = append(list, ft.Params.List...)
}
}
}
}
func main() {
fSet := token.NewFileSet()
path, _ := filepath.Abs("zCode/ast_stu/example/demo.go")
f, err := parser.ParseFile(fSet, path, nil, parser.ParseComments)
if err != nil {
log.Println(err)
return
}
v := &AddContextVisitor{}
ast.Walk(v, f)
var output []byte
buffer := bytes.NewBuffer(output)
err = format.Node(buffer, fSet, f)
if err != nil {
log.Fatal(err)
}
// 输出Go代码
b, _ := format.Source(buffer.Bytes()) // fmt
//b, _ = imports.Process("", b, nil) // imports
fmt.Printf("%s\n", b)
}
生成结果
package main
import "context"
type Animal interface {
Name(ctx context.Context, n string) (err error)
Eat(ctx context.Context) (err error)
Sleep(ctx context.Context) (err error)
}
type Duck interface {
Name(ctx context.Context, n string) (err error)
Eat(ctx context.Context) (err error)
Sleep(ctx context.Context) (err error)
Color(ctx context.Context) (string, err error)
}
有疑问加站长微信联系(非本文作者)