ast分析

WayytWang · · 800 次点击 · · 开始浏览    
这是一个创建于 的文章,其中的信息可能已经有所发展或是发生改变。

AST

  • 许多自动化代码生成工具都离不开语法树分析。

ast.File结构

type File struct {
    Doc        *CommentGroup   // associated documentation; or nil
    Package    token.Pos       // position of "package" keyword
    Name       *Ident          // go文件的包名
    Decls      []Decl          // 最外层的声明
    Scope      *Scope          // package scope (this file only)
    Imports    []*ImportSpec   // 引入的外部包
    Unresolved []*Ident        // unresolved identifiers in this file
    Comments   []*CommentGroup // list of all comments in the source file
}

Node节点

  • 整个语法树都由不同的node组成

Node接口定义

// All node types implement the Node interface.
type Node interface {
    Pos() token.Pos // position of first character belonging to the node
    End() token.Pos // position of first character immediately after the node
}

Node分类

go ast.png

Expression and Type

// All expression nodes implement the Expr interface.
type Expr interface {
    Node
    exprNode()
}
expression node
  • Indent(identifier)表示一个标识符,比如示例中表示包名的Name字段(包名)就是一个expression node
type node
  •     // 结构体类型
          StructType struct {
              Struct     token.Pos  // position of "struct" keyword
              Fields     *FieldList // list of field declarations
              Incomplete bool       // true if (source) fields are missing in the Fields list
          }
      
          // 函数类型
          FuncType struct {
              Func    token.Pos  // position of "func" keyword (token.NoPos if there is no "func")
              Params  *FieldList // (incoming) parameters; non-nil
              Results *FieldList // (outgoing) results; or nil
          }
      
          // 接口类型
          InterfaceType struct {
              Interface  token.Pos  // position of "interface" keyword
              Methods    *FieldList // list of methods
              Incomplete bool       // true if (source) methods are missing in the Methods list
          }

Statement

  • 赋值语句,控制语句(if,else,for,select...)等均属于statement node。

      // 赋值语句
        AssignStmt struct {
            Lhs    []Expr
            TokPos token.Pos   // position of Tok
            Tok    token.Token // assignment token, DEFINE
            Rhs    []Expr
        }
    
        // 条件语句
        IfStmt struct {
            If   token.Pos // position of "if" keyword
            Init Stmt      // initialization statement; or nil
            Cond Expr      // condition
            Body *BlockStmt
            Else Stmt // else branch; or nil
        }

Spec

  • Spec node只有3种,分别是ImportSpecValueSpecTypeSpec

    // import部分
        ImportSpec struct {
            Doc     *CommentGroup // associated documentation; or nil
            Name    *Ident        // local package name (including "."); or nil
            Path    *BasicLit     // import path
            Comment *CommentGroup // line comments; or nil
            EndPos  token.Pos     // end of spec (overrides Path.Pos if nonzero)
        }
    
    // constant or variable declaration
        ValueSpec struct {
            Doc     *CommentGroup // associated documentation; or nil
            Names   []*Ident      // value names (len(Names) > 0)
            Type    Expr          // value type; or nil
            Values  []Expr        // initial values; or nil
            Comment *CommentGroup // line comments; or nil
        }
    
        // type声明
        TypeSpec struct {
            Doc     *CommentGroup // associated documentation; or nil
            Name    *Ident        // type name
            Assign  token.Pos     // position of '=', if any
            Type    Expr          // *Ident, *ParenExpr, *SelectorExpr, *StarExpr, or any of the *XxxTypes
            Comment *CommentGroup // line comments; or nil
        }

Declaration Node

  • ...
    type (
        // 表示一个有语法错误的节点
        BadDecl struct {
            From, To token.Pos // position range of bad declaration
        }
    
        // 用于表示import, const,type或变量声明
        GenDecl struct {
            Doc    *CommentGroup // associated documentation; or nil
            TokPos token.Pos     // position of Tok
            Tok    token.Token   // IMPORT, CONST, TYPE, VAR
            Lparen token.Pos     // position of '(', if any
            Specs  []Spec
            Rparen token.Pos // position of ')', if any
        }
    
        // 用于表示函数声明
        FuncDecl struct {
            Doc  *CommentGroup // associated documentation; or nil
            Recv *FieldList    // receiver (methods); or nil (functions)
            Name *Ident        // function/method name
            Type *FuncType     // function signature: parameters, results, and position of "func" keyword
            Body *BlockStmt    // function body; or nil for external (non-Go) function
        }
    )
    

Common Node

  • 除去上述四种类别划分的node,还有一些node不属于上面四种类别:
  • // File 表示一个文件节点
    type File struct {
        ...
    }
    
    // Package 表示一个包节点
    type Package struct {
        ...
    }
    
    // Field 字段节点, 可以代表结构体定义中的字段,接口定义中的方法列表,函数前面中的入参和返回值字段
    type Field struct {
        ...
    }
    ...
    // FieldList 包含多个Field
    type FieldList struct {
        ...
    }
    
    // 表示基础类型
    type    BasicLit struct {
            ValuePos token.Pos   
            Kind     token.Token 
            Value    string     
    }
    

demo.go示列

demo.go

package main

import (
    "context"
    "fmt"
)

const (
    name = "call duck"
)

// Duck 结构体
type Duck struct {
    name string
}

// Animal 接口
type Animal interface {
    Name(ctx context.Context, n string) string
}

func (d Duck) Name(ctx context.Context, n string) string {
    d.name = d.name + n
    return d.name
}

// main方法
func main() {
    var duck Animal
    duck = Duck{name: name}
    ctx := context.Background()
    fmt.Println(duck.Name(ctx, "cute "))
}

查看demo.go的语法树结构

package main

import (
    "go/ast"
    "go/parser"
    "go/token"
    "log"
    "path/filepath"
)

func main() {
    fSet := token.NewFileSet()
    path, _ := filepath.Abs("demo.go")
    f, err := parser.ParseFile(fSet, path, nil, parser.AllErrors)
    if err != nil {
        log.Println(err)
        return
    }
    // 打印语法树
    _ = ast.Print(fSet, f)
}

ast分析

  • 即使demo.go文件中的代码已经非常少,但是对应的语法树也非常复杂。下图仅展示一个结构。
  • demo_ast.png
  • 其中最为复杂的是Decls,它包括了代码中所有的声明。
  • demo_detail.png

利用语法树帮接口方法加上context参数

Walk

  • 语法树层级比较复杂,为方便使用者利用语法树,ast包提供了便利方法

    func Walk(v Visitor, node Node) 
    
    type Visitor interface {
        Visit(node Node) (w Visitor)
    }
  • Walk方法会按照深度优先搜索方法(depth-first order)遍历整个语法树,所以使用者只需按照业务需要,实现Visitor接口即可。
  • Walk每遍历一个节点就会调用Visitor.Visit方法,传入当前节点。直到返回的Visitor为nil,则停止遍历当前节点的子节点。

待添加文件demo.go

package main

type Animal interface {
    Name(n string)
    Eat()
    Sleep()
}

type Duck interface {
    Name(n string)
    Eat()
    Sleep()
    Color() string
}

帮demo.go中入参数加context参数,返回参数加error类型

package main

import (
    "bytes"
    "fmt"
    "go/ast"
    "go/format"
    "go/parser"
    "go/token"
    "log"
    "path/filepath"
    "strconv"
)

type AddContextVisitor struct {
    pkgContext string // 如果有引入context包情况下,context包的别名
}

func (vi *AddContextVisitor) Visit(node ast.Node) ast.Visitor {

    switch node.(type) {
    case *ast.File:
        file := node.(*ast.File)
        // 没有任何import
        if len(file.Imports) == 0 {
            vi.addImportWithoutAnyImport(file)
        }
    case *ast.GenDecl:
        genDecl := node.(*ast.GenDecl)
        // 有import
        if genDecl.Tok == token.IMPORT {
            vi.addImport(genDecl)
        }
    case *ast.InterfaceType:
        // 遍历所有的接口类型
        interfaceType := node.(*ast.InterfaceType)
        vi.addContextAndError(interfaceType)
        return nil
    }
    return vi
}

func (vi *AddContextVisitor) addImport(genDecl *ast.GenDecl) {
    hasImportedContext := false
    for _, value := range genDecl.Specs {
        // 如果已经包含"context"
        importSpec := value.(*ast.ImportSpec)
        if importSpec.Path.Value == strconv.Quote("context") {
            hasImportedContext = true
            // 有别名的情况下记录别名
            if importSpec.Name != nil {
                vi.pkgContext = importSpec.Name.Name
            }
        }
    }
    if hasImportedContext {
        return
    }
    if !hasImportedContext {
        genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
            Path: &ast.BasicLit{
                Kind:  token.STRING,
                Value: strconv.Quote("context"),
            },
        })
    }
}

// 没有import情况下 引入context
func (vi *AddContextVisitor) addImportWithoutAnyImport(file *ast.File) {
    genDecl := &ast.GenDecl{
        Tok: token.IMPORT,
        Specs: []ast.Spec{
            &ast.ImportSpec{
                Path: &ast.BasicLit{
                    Kind:  token.STRING,
                    Value: strconv.Quote("context"),
                },
            },
        },
    }
    list := []ast.Decl{genDecl}
    file.Decls = append(list, file.Decls...)
}

// 为接口方法添加参数
func (vi *AddContextVisitor) addContextAndError(interfaceType *ast.InterfaceType) {
    // 接口方法不为空是,遍历接口方法
    if interfaceType.Methods != nil || interfaceType.Methods.List != nil {
        for _, v := range interfaceType.Methods.List {
            ft := v.Type.(*ast.FuncType)
            hasContext := false
            hasError := false
            // 判断参数中是否包含context.Context类型
            for _, value := range ft.Params.List {
                if expr, ok := value.Type.(*ast.SelectorExpr); ok {
                    if ident, ok := expr.X.(*ast.Ident); ok {
                        if ident.Name == "context" {
                            hasContext = true
                        }
                    }
                }
            }

            if ft.Results != nil && ft.Results.List != nil {
                // 判断返回参数中是否包含error类型
                for i, value := range ft.Results.List {
                    if ident, ok := value.Type.(*ast.Ident); ok {
                        if ident.Name == "error" {
                            ft.Results.List[i].Names = []*ast.Ident{
                                ast.NewIdent("err"),
                            }
                            hasError = true
                        }
                    }
                }
            }

            if !hasError {
                errField := &ast.Field{
                    Names: []*ast.Ident{
                        ast.NewIdent("err"),
                    },
                    Type: ast.NewIdent("error"),
                }
                if ft.Results == nil {
                    ft.Results = &ast.FieldList{}
                }
                ft.Results.List = append(ft.Results.List, errField)
            }

            // 为没有context参数的方法添加context参数
            if !hasContext {
                x := "context"
                if vi.pkgContext != "" {
                    x = vi.pkgContext
                }
                ctxField := &ast.Field{
                    Names: []*ast.Ident{
                        ast.NewIdent("ctx"),
                    },
                    Type: &ast.SelectorExpr{
                        X:   ast.NewIdent(x),
                        Sel: ast.NewIdent("Context"),
                    },
                }
                list := []*ast.Field{
                    ctxField,
                }
                ft.Params.List = append(list, ft.Params.List...)
            }
        }
    }
}

func main() {
    fSet := token.NewFileSet()

    path, _ := filepath.Abs("zCode/ast_stu/example/demo.go")
    f, err := parser.ParseFile(fSet, path, nil, parser.ParseComments)
    if err != nil {
        log.Println(err)
        return
    }

    v := &AddContextVisitor{}
    ast.Walk(v, f)

    var output []byte
    buffer := bytes.NewBuffer(output)
    err = format.Node(buffer, fSet, f)
    if err != nil {
        log.Fatal(err)
    }
    // 输出Go代码
    b, _ := format.Source(buffer.Bytes()) // fmt
    //b, _ = imports.Process("", b, nil)    // imports
    fmt.Printf("%s\n", b)
}

生成结果

package main

import "context"

type Animal interface {
        Name(ctx context.Context, n string) (err error)
        Eat(ctx context.Context) (err error)
        Sleep(ctx context.Context) (err error)
}

type Duck interface {
        Name(ctx context.Context, n string) (err error)
        Eat(ctx context.Context) (err error)
        Sleep(ctx context.Context) (err error)
        Color(ctx context.Context) (string, err error)
}

有疑问加站长微信联系(非本文作者)

本文来自:Segmentfault

感谢作者:WayytWang

查看原文:ast分析

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

800 次点击  
加入收藏 微博
暂无回复
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传