api/vendor/github.com/cweill/gotests/internal/goparser/goparser.go

314 lines
7.2 KiB
Go

// Package goparser contains logic for parsing Go files. Specifically, it parses
// source and test files into domain models for generating tests.
package goparser
import (
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"go/types"
"io/ioutil"
"strings"
"github.com/cweill/gotests/internal/models"
)
var (
	// ErrEmptyFile is the error reported when the file to parse has no
	// content at all (zero bytes).
	ErrEmptyFile = errors.New("file is empty")
)
// Result represents a parsed Go file.
type Result struct {
	// Header holds the package comments, package name and imports of the Go
	// file, plus the code that follows its imports block.
	Header *models.Header
	// Funcs holds all the functions and methods declared in the Go file, in
	// declaration order.
	Funcs []*models.Function
}
// Parser turns Go source files into the domain models used for test
// generation. Importer is handed to the go/types checker (types.Config) so it
// can resolve the packages named by import paths while type-checking.
type Parser struct{ Importer types.Importer }
// Parse parses the Go file at srcPath, along with those of files that belong
// to the same package, into a domain model for generating tests.
//
// The source file is read exactly once. The AST and the Header.Code slice are
// both derived from that single buffer, so AST positions always index it
// consistently. Any read error (including ErrEmptyFile for a zero-byte file)
// or parse error aborts with a nil Result.
func (p *Parser) Parse(srcPath string, files []models.Path) (*Result, error) {
	b, err := p.readFile(srcPath)
	if err != nil {
		return nil, err
	}
	fset := token.NewFileSet()
	// Parse the bytes already in memory instead of letting go/parser re-read
	// srcPath. goCode indexes b with positions taken from f, so both must
	// describe the same contents even if the file changes on disk meanwhile.
	f, err := parser.ParseFile(fset, srcPath, b, parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("target parser.ParseFile(): %v", err)
	}
	fs, err := p.parseFiles(fset, f, files)
	if err != nil {
		return nil, err
	}
	return &Result{
		Header: &models.Header{
			Comments: parsePkgComment(f, f.Package),
			Package:  f.Name.String(),
			Imports:  parseImports(f.Imports),
			Code:     goCode(b, f),
		},
		Funcs: p.parseFunctions(fset, f, fs),
	}, nil
}
// readFile loads the entire file at srcPath. A read failure is wrapped with
// context, and a file containing zero bytes is rejected with ErrEmptyFile.
func (p *Parser) readFile(srcPath string) ([]byte, error) {
	content, readErr := ioutil.ReadFile(srcPath)
	switch {
	case readErr != nil:
		return nil, fmt.Errorf("ioutil.ReadFile: %v", readErr)
	case len(content) == 0:
		return nil, ErrEmptyFile
	}
	return content, nil
}
// parseFile reads the Go file at srcPath from disk, parses it with comments
// retained, and records it in fset. Parse errors are wrapped with context.
func (p *Parser) parseFile(fset *token.FileSet, srcPath string) (*ast.File, error) {
	file, parseErr := parser.ParseFile(fset, srcPath, nil, parser.ParseComments)
	if parseErr == nil {
		return file, nil
	}
	return nil, fmt.Errorf("target parser.ParseFile(): %v", parseErr)
}
// parseFiles parses every path in files into fset (without comments) and
// keeps only those whose package name matches f's. A parse failure in any
// file aborts the whole operation.
func (p *Parser) parseFiles(fset *token.FileSet, f *ast.File, files []models.Path) ([]*ast.File, error) {
	want := f.Name.String()
	var siblings []*ast.File
	for _, path := range files {
		parsed, err := parser.ParseFile(fset, string(path), nil, 0)
		if err != nil {
			return nil, fmt.Errorf("other file parser.ParseFile: %v", err)
		}
		if parsed.Name.String() == want {
			siblings = append(siblings, parsed)
		}
	}
	return siblings, nil
}
// parseFunctions returns a models.Function for every function and method
// declared in f, in declaration order. The type information used to resolve
// underlying types and receiver struct fields comes from type-checking fs.
func (p *Parser) parseFunctions(fset *token.FileSet, f *ast.File, fs []*ast.File) []*models.Function {
	underlyings, structExprs := p.parseTypes(fset, fs)
	var out []*models.Function
	for _, decl := range f.Decls {
		if fn, isFunc := decl.(*ast.FuncDecl); isFunc {
			out = append(out, parseFunc(fn, underlyings, structExprs))
		}
	}
	return out
}
// parseTypes type-checks fs as one package and indexes every typed expression.
//
// It returns:
//   - a map from each recorded type's string form to its underlying type;
//   - a map from every *types.Struct that appears as an expression's type to
//     one such expression (map iteration order decides which one when several
//     expressions share the same struct type).
//
// Type checking is best-effort. Errors are discarded, so both tables may be
// incomplete for code that does not type-check.
func (p *Parser) parseTypes(fset *token.FileSet, fs []*ast.File) (map[string]types.Type, map[*types.Struct]ast.Expr) {
	info := &types.Info{
		Types: make(map[ast.Expr]types.TypeAndValue),
	}
	cfg := &types.Config{
		Importer: p.Importer,
		// A no-op Error handler keeps the checker going past errors, which
		// gives best-effort type checking.
		// https://godoc.org/golang.org/x/tools/go/types#Config
		Error: func(error) {},
	}
	// The package and error results are intentionally dropped. Whatever
	// partial information reached info is still useful; it is optional data.
	_, _ = cfg.Check("", fset, fs, info)

	byName := make(map[string]types.Type)
	structExprs := make(map[*types.Struct]ast.Expr)
	for expr, tv := range info.Types {
		byName[tv.Type.String()] = tv.Type.Underlying()
		// Remember where struct types occur so receiver fields can be listed.
		if st, isStruct := tv.Type.(*types.Struct); isStruct {
			structExprs[st] = expr
		}
	}
	return byName, structExprs
}
// parsePkgComment returns the comments that precede the package clause at
// pkgPos, in source order. Spacer strings are interleaved to preserve the
// blank lines between comments.
//
// Each comment is one element. When g > 0 extra bytes (blank lines, for LF
// files) separate a comment from the previous one, an element of g-1
// newlines is emitted before it. The same is done for the gap between the
// last comment and the package clause.
// NOTE(review): the g-1 sizing presumes the consumer writes every element
// followed by a newline; confirm against the header template.
//
// token.Pos values are compared directly with byte counts, which assumes f
// is the first file of its token.FileSet (Pos = offset + 1).
func parsePkgComment(f *ast.File, pkgPos token.Pos) []string {
	var comments []string
	// count is where the current comment would end (in token.Pos units) if
	// it sat on the line right after the previous one. A larger c.End()
	// means blank lines precede c.
	var count int
	for _, comment := range f.Comments {
		if comment.End() >= pkgPos {
			break
		}
		for _, c := range comment.List {
			count += len(c.Text) + 1 // +1 for '\n'
			if count < int(c.End()) {
				n := int(c.End()) - count - 1
				comments = append(comments, strings.Repeat("\n", n))
				// Resynchronize with the source. The gap was n+1 bytes, so
				// advancing by only one let the error accumulate and produced
				// spurious blank lines before later comments and the package
				// clause.
				count = int(c.End())
			}
			comments = append(comments, c.Text)
		}
	}
	if int(pkgPos)-count > 1 {
		comments = append(comments, strings.Repeat("\n", int(pkgPos)-count-2))
	}
	return comments
}
// goCode returns the Go code below the package clause and imports block of
// b. It returns an empty slice when nothing follows the header, including
// when the file ends right after the package name or the last import with no
// trailing newline. That case previously sliced past len(b) and panicked.
//
// Positions from f are used directly as indices into b. This assumes f was
// parsed from exactly these bytes as the first file of a fresh token.FileSet,
// so that Pos = offset + 1.
func goCode(b []byte, f *ast.File) []byte {
	furthestPos := f.Name.End()
	for _, node := range f.Imports {
		if pos := node.End(); pos > furthestPos {
			furthestPos = pos
		}
	}
	end := token.Pos(len(b))
	if furthestPos >= end {
		// The header reaches EOF; End() may even be len(b)+1 here.
		return b[len(b):]
	}
	// Because Pos is offset+1, indexing with End() already skips the byte
	// right after the header; skip one more.
	furthestPos++
	// Avoid wrong output on windows-encoded files
	if b[furthestPos-2] == '\r' && b[furthestPos-1] == '\n' && furthestPos < end {
		furthestPos++
	}
	return b[furthestPos:]
}
// parseFunc converts a function or method declaration into a models.Function.
// Results whose type prints as "error" are left out of Results and instead set
// ReturnsError. The remaining results are re-indexed from 0 in order.
func parseFunc(fDecl *ast.FuncDecl, ul map[string]types.Type, el map[*types.Struct]ast.Expr) *models.Function {
	fn := &models.Function{
		Name:       fDecl.Name.String(),
		IsExported: fDecl.Name.IsExported(),
		Receiver:   parseReceiver(fDecl.Recv, ul, el),
		Parameters: parseFieldList(fDecl.Type.Params, ul),
	}
	next := 0
	for _, res := range parseFieldList(fDecl.Type.Results, ul) {
		if res.Type.String() != "error" {
			res.Index = next
			fn.Results = append(fn.Results, res)
			next++
			continue
		}
		fn.ReturnsError = true
	}
	return fn
}
// parseImports converts import specs into models.Import values. Name is set
// only when the spec carries an explicit name; Path is the spec's raw path
// literal value, unmodified. No specs yields nil.
func parseImports(imps []*ast.ImportSpec) []*models.Import {
	var out []*models.Import
	for i := range imps {
		spec := imps[i]
		alias := ""
		if spec.Name != nil {
			alias = spec.Name.String()
		}
		out = append(out, &models.Import{Name: alias, Path: spec.Path.Value})
	}
	return out
}
// parseReceiver describes a method receiver. It returns nil when fl is nil
// (a plain function) or declares no field at all, e.g. `func () M()`. That
// form parses successfully and previously panicked on an index out of range.
//
// When the receiver's type name resolves through ul to a struct type, and el
// records a struct type literal for it, that struct's fields are collected
// into Fields. The names of the first s.NumFields() entries are then
// overwritten with the type checker's field names.
func parseReceiver(fl *ast.FieldList, ul map[string]types.Type, el map[*types.Struct]ast.Expr) *models.Receiver {
	if fl == nil {
		return nil
	}
	fields := parseFieldList(fl, ul)
	if len(fields) == 0 {
		return nil
	}
	r := &models.Receiver{
		Field: fields[0],
	}
	t, ok := ul[r.Type.Value]
	if !ok {
		return r
	}
	s, ok := t.(*types.Struct)
	if !ok {
		return r
	}
	expr, found := el[s]
	if !found {
		return r
	}
	// el maps a struct type to any expression of that type (e.g. an
	// identifier or composite literal), not necessarily its declaration.
	// Only a struct type literal exposes the field list, so check rather
	// than assert.
	st, ok := expr.(*ast.StructType)
	if !ok {
		return r
	}
	r.Fields = append(r.Fields, parseFieldList(st.Fields, ul)...)
	for i, f := range r.Fields {
		// https://github.com/cweill/gotests/issues/69
		if i >= s.NumFields() {
			break
		}
		f.Name = s.Field(i).Name()
	}
	return r
}
// parseFieldList flattens a field list into one models.Field per declared
// name (or one per unnamed field). Each gets an Index running across the
// whole list. A nil or empty list yields nil.
func parseFieldList(fl *ast.FieldList, ul map[string]types.Type) []*models.Field {
	if fl == nil {
		return nil
	}
	var out []*models.Field
	for _, field := range fl.List {
		for _, mf := range parseFields(field, ul) {
			mf.Index = len(out)
			out = append(out, mf)
		}
	}
	return out
}
// parseFields expands a single AST field into models.Fields: one per name in
// `a, b T`, or a single unnamed entry when the field declares no names. All
// entries share the same parsed type expression.
func parseFields(f *ast.Field, ul map[string]types.Type) []*models.Field {
	typ := parseExpr(f.Type, ul)
	if len(f.Names) == 0 {
		return []*models.Field{{Type: typ}}
	}
	fields := make([]*models.Field, 0, len(f.Names))
	for _, ident := range f.Names {
		fields = append(fields, &models.Field{Name: ident.Name, Type: typ})
	}
	return fields
}
// parseExpr describes a type expression as a models.Expression:
//   - *X: Value is X's text and IsStar is set;
//   - ...X: X's description with IsVariadic set;
//   - anything else: Value is the expression text, and IsWriter is set when
//     that text is exactly "io.Writer".
//
// Underlying is looked up in ul by Value ("" when unknown).
// NOTE(review): *io.Writer and ...io.Writer do not set IsWriter; confirm
// this is intended.
func parseExpr(e ast.Expr, ul map[string]types.Type) *models.Expression {
	if star, ok := e.(*ast.StarExpr); ok {
		name := types.ExprString(star.X)
		return &models.Expression{Value: name, IsStar: true, Underlying: underlying(name, ul)}
	}
	if ell, ok := e.(*ast.Ellipsis); ok {
		elem := parseExpr(ell.Elt, ul)
		return &models.Expression{
			Value:      elem.Value,
			IsStar:     elem.IsStar,
			IsVariadic: true,
			Underlying: underlying(elem.Value, ul),
		}
	}
	text := types.ExprString(e)
	return &models.Expression{
		Value:      text,
		Underlying: underlying(text, ul),
		IsWriter:   text == "io.Writer",
	}
}
// underlying returns the string form of the type recorded for val in ul, or
// "" when val has no entry or maps to a nil type.
func underlying(val string, ul map[string]types.Type) string {
	t, found := ul[val]
	if !found || t == nil {
		return ""
	}
	return t.String()
}