Codebase list golang-github-anacrolix-tagflag / HEAD parser.go
HEAD

Tree @HEAD (Download .tar.gz)

parser.go @ HEAD · raw · history · blame

package tagflag

import (
	"fmt"
	"log"
	"reflect"
	"strings"

	"github.com/anacrolix/missinggo/slices"
	"github.com/huandu/xstrings"
)

// parser carries the state for parsing command-line arguments into cmd.
// parseCmd registers cmd's fields as flags and positional arguments; parse
// then assigns argument values to them. If excess is non-nil, parse stores
// there any arguments left once no positional argument can accept more.
type parser struct {
	// The value from which the parser is built, and values are assigned.
	cmd interface{}
	// Disables the default handling of -h and -help.
	noDefaultHelp bool
	program       string
	description   string

	// Positional arguments, in the order their fields were registered by addPos.
	posArgs []arg
	// Maps -K=V to map[K]arg(V)
	flags  map[string]arg
	excess *ExcessArgs

	// Count of positional arguments parsed so far. Used to locate the next
	// positional argument where it's non-trivial (non-unity arity).
	numPos int
}

// hasOptions reports whether at least one flag (as opposed to a positional
// argument) has been registered on this parser.
func (p *parser) hasOptions() bool {
	numFlags := len(p.flags)
	return numFlags > 0
}

// parse consumes args, assigning flag and positional values to the fields
// registered by parseCmd. A bare "--" ends flag processing; every argument
// after it is treated as positional. When the command declares ExcessArgs,
// the arguments remaining once no positional argument can accept more are
// stored there unparsed. A userError is returned if fewer positional values
// were supplied than the positional arguments' combined minimum arity.
func (p *parser) parse(args []string) (err error) {
	flagsDone := false
	for i := 0; i < len(args); i++ {
		if p.excess != nil && p.nextPosArg() == nil {
			*p.excess = args[i:]
			return nil
		}
		cur := args[i]
		switch {
		case !flagsDone && cur == "--":
			flagsDone = true
			continue
		case !flagsDone && isFlag(cur):
			err = p.parseFlag(cur[1:])
		default:
			err = p.parsePos(cur)
		}
		if err != nil {
			return err
		}
	}
	if p.numPos < p.minPos() {
		missing := p.indexPosArg(p.numPos)
		return userError{fmt.Sprintf("missing argument: %q", missing.name)}
	}
	return nil
}

// minPos returns the number of positional values that must be supplied:
// the sum of every positional argument's minimum arity.
func (p *parser) minPos() (min int) {
	total := 0
	for i := range p.posArgs {
		total += p.posArgs[i].arity.min
	}
	return total
}

// newParser builds a parser for cmd. Each option is applied in order before
// cmd's fields are inspected by parseCmd. The parser is returned even when
// parseCmd reports an error.
func newParser(cmd interface{}, opts ...parseOpt) (p *parser, err error) {
	p = new(parser)
	p.cmd = cmd
	for i := range opts {
		opts[i](p)
	}
	return p, p.parseCmd()
}

// parseCmd inspects p.cmd and registers its fields as flags and positional
// arguments. A nil p.cmd registers nothing. Otherwise p.cmd must be a non-nil
// pointer to a struct; anything else is reported as an error.
func (p *parser) parseCmd() error {
	if p.cmd == nil {
		return nil
	}
	v := reflect.ValueOf(p.cmd)
	// Value.Elem panics for non-pointer kinds, and for a nil pointer it yields
	// the zero Value, whose Type method panics. Report both as errors instead.
	if v.Kind() != reflect.Ptr {
		return fmt.Errorf("expected pointer to struct got %s", v.Type())
	}
	if v.IsNil() {
		return fmt.Errorf("expected non-nil pointer to struct got nil %s", v.Type())
	}
	s := v.Elem()
	if s.Kind() != reflect.Struct {
		return fmt.Errorf("expected struct got %s", s.Type())
	}
	// s is addressable (it is the target of a pointer), which parseStruct
	// relies on when it takes field addresses.
	return p.parseStruct(s, nil)
}

// parseStruct registers the exported fields of st as flags or positional
// arguments. It recurses into nested struct fields that are not themselves
// marshalable. path holds the flag-name components of the enclosing structs
// and prefixes the names of flags registered here.
//
// Positional arguments are marked per struct: marshalable fields after a
// StartPos field become positional arguments, and earlier ones become flags.
// An ExcessArgs field records where leftover arguments are stored. Once one
// has been seen, any further exported field (other than these marker types)
// is rejected with ErrFieldsAfterExcessArgs.
func (p *parser) parseStruct(st reflect.Value, path []flagNameComponent) (err error) {
	posStarted := false
	foreachStructField(st, func(f reflect.Value, sf reflect.StructField) (stop bool) {
		// Marker types are recognized before the exported-field filter below.
		if !posStarted && f.Type() == reflect.TypeOf(StartPos{}) {
			posStarted = true
			return false
		}
		if f.Type() == reflect.TypeOf(ExcessArgs{}) {
			p.excess = f.Addr().Interface().(*ExcessArgs)
			return false
		}
		if sf.PkgPath != "" {
			// Unexported field: skip it.
			return false
		}
		if p.excess != nil {
			err = ErrFieldsAfterExcessArgs
			return true
		}
		if canMarshal(f) {
			if posStarted {
				err = p.addPos(f, sf, path)
			} else {
				err = p.addFlag(f, sf, path)
				if err != nil {
					// %w keeps the underlying cause available to errors.Is/As.
					err = fmt.Errorf("error adding flag in %s: %w", st.Type(), err)
				}
			}
			return err != nil
		}
		if f.Kind() == reflect.Struct {
			// The value type can't be marshalled but its pointer can: the field
			// was most likely meant to be declared with the pointer type.
			if canMarshal(f.Addr()) {
				err = fmt.Errorf("field %q has type %s, did you mean to use %s?", sf.Name, f.Type(), f.Addr().Type())
				return true
			}
			// Otherwise treat it as a group of fields whose flag names are
			// prefixed with this field's name component.
			err = p.parseStruct(f, append(path, structFieldFlagNameComponent(sf)))
			return err != nil
		}
		err = fmt.Errorf("field has bad type: %v", f.Type())
		return true
	})
	return
}

// newArg describes the field value v (declared by sf) as an argument called
// name. Its arity comes from fieldArity and its help text from the field's
// `help` struct tag.
func newArg(v reflect.Value, sf reflect.StructField, name string) arg {
	var a arg
	a.arity = fieldArity(v, sf)
	a.value = v
	a.name = name
	a.help = sf.Tag.Get("help")
	return a
}

// addPos appends field f as the next positional argument. Its display name
// is the Go field name converted to upper-case snake case. path is accepted
// for symmetry with addFlag but is not used. The returned error is always nil.
func (p *parser) addPos(f reflect.Value, sf reflect.StructField, path []flagNameComponent) error {
	display := xstrings.ToSnakeCase(sf.Name)
	display = strings.ToUpper(display)
	p.posArgs = append(p.posArgs, newArg(f, sf, display))
	return nil
}

// flagName converts each component to a string and joins them with dots,
// producing the full (possibly nested) flag name.
func flagName(comps []flagNameComponent) string {
	var parts []string
	slices.MakeInto(&parts, comps)
	joined := strings.Join(parts, ".")
	return joined
}

// addFlag registers field f as a flag. Its name is the enclosing path plus
// this field's own name component, joined by flagName. Registering the same
// name twice is an error.
func (p *parser) addFlag(f reflect.Value, sf reflect.StructField, path []flagNameComponent) error {
	name := flagName(append(path, structFieldFlagNameComponent(sf)))
	if p.flags == nil {
		// A freshly created map cannot already contain name.
		p.flags = make(map[string]arg)
	} else if _, dup := p.flags[name]; dup {
		return fmt.Errorf("flag %q defined more than once", name)
	}
	p.flags[name] = newArg(f, sf, name)
	return nil
}

// isFlag reports whether arg has the shape of a flag: a leading '-' followed
// by at least one more character. A lone "-" is not a flag, so parse hands it
// to the positional-argument path.
func isFlag(arg string) bool {
	if len(arg) < 2 {
		return false
	}
	return arg[0] == '-'
}

// parseFlag handles one flag argument s, given without its leading '-'. s is
// either "K" or "K=V". Whether an '=' was present is passed to the flag's
// marshal method, since "-K" and "-K=" otherwise both yield an empty value.
// An unknown flag yields a userError, except "h" and "help", which yield
// ErrDefaultHelp unless default help handling is disabled.
func (p *parser) parseFlag(s string) error {
	i := strings.IndexByte(s, '=')
	k := s
	v := ""
	if i != -1 {
		k = s[:i]
		v = s[i+1:]
	}
	flag, ok := p.flags[k]
	if !ok {
		if (k == "help" || k == "h") && !p.noDefaultHelp {
			return ErrDefaultHelp
		}
		return userError{fmt.Sprintf("unknown flag: %q", k)}
	}
	if err := flag.marshal(v, i != -1); err != nil {
		// %w keeps the marshal error (e.g. a userError) reachable via
		// errors.Is/As; the message text is unchanged.
		return fmt.Errorf("error setting flag %q: %w", k, err)
	}
	return nil
}

// indexPosArg returns the positional argument that the i-th (zero-based)
// positional value belongs to. Each argument's maximum arity counts as the
// number of values it can take. It returns nil when i lies beyond every
// argument's capacity. The pointer refers to a copy of the p.posArgs entry,
// not to the slice element itself.
func (p *parser) indexPosArg(i int) *arg {
	remaining := i
	for n := 0; n < len(p.posArgs); n++ {
		candidate := p.posArgs[n]
		if remaining < candidate.arity.max {
			return &candidate
		}
		remaining -= candidate.arity.max
	}
	return nil
}

// nextPosArg returns the positional argument that the next positional value
// would be assigned to. It returns nil when no positional argument can
// accept another value.
func (p *parser) nextPosArg() *arg {
	next := p.numPos
	return p.indexPosArg(next)
}

// parsePos assigns s to the next positional argument. The positional count
// advances only when the value is marshalled successfully. A userError is
// returned when no positional argument can take another value.
func (p *parser) parsePos(s string) (err error) {
	target := p.nextPosArg()
	if target == nil {
		return userError{fmt.Sprintf("excess argument: %q", s)}
	}
	if err = target.marshal(s, true); err == nil {
		p.numPos++
	}
	return err
}

// flagNameComponent is one dot-separated segment of a flag name. It comes
// from a struct field's `name` tag, or is derived from the field name.
type flagNameComponent string

// structFieldFlagNameComponent returns the flag-name segment for sf. It is
// the value of the field's `name` struct tag when that is non-empty;
// otherwise it is derived from the Go field name by fieldFlagName.
func structFieldFlagNameComponent(sf reflect.StructField) flagNameComponent {
	switch tagged := sf.Tag.Get("name"); tagged {
	case "":
		return fieldFlagName(sf.Name)
	default:
		return flagNameComponent(tagged)
	}
}

// posIndexValue finds the reflect.Value that receives the _i-th (zero-based)
// positional argument within v. It also returns the index still left to
// consume. Lookup rules:
//   - pointers are followed;
//   - in a struct, only fields after a StartPos marker are examined, in order;
//   - a slice absorbs every remaining index;
//   - any other value consumes exactly one index.
//
// An invalid ret means the index lies beyond v. In that case the returned i
// has been reduced by the number of positional slots found in v.
func posIndexValue(v reflect.Value, _i int) (ret reflect.Value, i int) {
	// Debug tracing; set to true when investigating positional lookup.
	const debug = false
	i = _i
	if debug {
		log.Println("posIndexValue", v.Type(), i)
	}
	switch v.Kind() {
	case reflect.Ptr:
		return posIndexValue(v.Elem(), i)
	case reflect.Struct:
		posStarted := false
		// The callback's result means "stop iterating", as in parseStruct.
		// Fields up to and including StartPos are skipped. Iteration stops as
		// soon as a field yields the target value.
		foreachStructField(v, func(fv reflect.Value, sf reflect.StructField) (stop bool) {
			if debug {
				log.Println("posIndexValue struct field", fv, sf)
			}
			if !posStarted {
				if fv.Type() == reflect.TypeOf(StartPos{}) {
					posStarted = true
				}
				return false
			}
			ret, i = posIndexValue(fv, i)
			return ret.IsValid()
		})
		return
	case reflect.Slice:
		ret = v
		return
	default:
		if i == 0 {
			ret = v
			return
		}
		i--
		return
	}
}

// posWithHelp returns the positional arguments that have non-empty help
// text, in registration order. It returns nil when none do.
func (p *parser) posWithHelp() (ret []arg) {
	for i := range p.posArgs {
		if p.posArgs[i].help == "" {
			continue
		}
		ret = append(ret, p.posArgs[i])
	}
	return ret
}