diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..4f2ee4d --- /dev/null +++ b/.travis.yml @@ -0,0 +1 @@ +language: go diff --git a/reflectwalk.go b/reflectwalk.go index 1f20665..ec0a623 100644 --- a/reflectwalk.go +++ b/reflectwalk.go @@ -5,6 +5,7 @@ package reflectwalk import ( + "errors" "reflect" ) @@ -18,6 +19,12 @@ Primitive(reflect.Value) error } +// InterfaceWalker implementations are able to handle interface values as they +// are encountered during the walk. +type InterfaceWalker interface { + Interface(reflect.Value) error +} + // MapWalker implementations are able to handle individual elements // found within a map structure. type MapWalker interface { @@ -54,6 +61,13 @@ PointerEnter(bool) error PointerExit(bool) error } + +// SkipEntry can be returned from walk functions to skip walking +// the value of this field. This is only valid in the following functions: +// +// - StructField: skips walking the struct value +// +var SkipEntry = errors.New("skip this entry") // Walk takes an arbitrary value and an interface and traverses the // value, calling callbacks on the interface if they are supported. @@ -79,23 +93,63 @@ func walk(v reflect.Value, w interface{}) (err error) { // Determine if we're receiving a pointer and if so notify the walker. + // The logic here is convoluted but very important (tests will fail if + // almost any part is changed). I will try to explain here. + // + // First, we check if the value is an interface, if so, we really need + // to check the interface's VALUE to see whether it is a pointer. + // + // Check whether the value is then a pointer. If so, then set pointer + // to true to notify the user. + // + // If we still have a pointer or an interface after the indirections, then + // we unwrap another level + // + // At this time, we also set "v" to be the dereferenced value. This is + // because once we've unwrapped the pointer we want to use that value. pointer := false - if v.Kind() == reflect.Ptr { - pointer = true - v = reflect.Indirect(v) - } - if pw, ok := w.(PointerWalker); ok { - if err = pw.PointerEnter(pointer); err != nil { - return - } - - defer func() { - if err != nil { + pointerV := v + + for { + if pointerV.Kind() == reflect.Interface { + if iw, ok := w.(InterfaceWalker); ok { + if err = iw.Interface(pointerV); err != nil { + return + } + } + + pointerV = pointerV.Elem() + } + + if pointerV.Kind() == reflect.Ptr { + pointer = true + v = reflect.Indirect(pointerV) + } + if pw, ok := w.(PointerWalker); ok { + if err = pw.PointerEnter(pointer); err != nil { return } - err = pw.PointerExit(pointer) - }() + defer func(pointer bool) { + if err != nil { + return + } + + err = pw.PointerExit(pointer) + }(pointer) + } + + if pointer { + pointerV = v + } + pointer = false + + // If we still have a pointer or interface we have to indirect another level. + switch pointerV.Kind() { + case reflect.Ptr, reflect.Interface: + continue + } + break } // We preserve the original value here because if it is an interface @@ -251,6 +305,12 @@ if sw, ok := w.(StructWalker); ok { err = sw.StructField(sf, f) + + // SkipEntry just pretends this field doesn't even exist + if err == SkipEntry { + continue + } + if err != nil { return } diff --git a/reflectwalk_test.go b/reflectwalk_test.go index 4ec1066..0ec68ae 100644 --- a/reflectwalk_test.go +++ b/reflectwalk_test.go @@ -1,6 +1,7 @@ package reflectwalk import ( + "fmt" "reflect" "testing" ) @@ -24,15 +25,27 @@ } type TestPointerWalker struct { - Ps []bool + pointers []bool + count int + enters int + exits int } func (t *TestPointerWalker) PointerEnter(v bool) error { - t.Ps = append(t.Ps, v) + t.pointers = append(t.pointers, v) + t.enters++ + if v { + t.count++ + } return nil } func (t *TestPointerWalker) PointerExit(v bool) error { + t.exits++ + if t.pointers[len(t.pointers)-1] != v { + return fmt.Errorf("bad pointer exit '%t' at exit %d", v, t.exits) + } + t.pointers = t.pointers[:len(t.pointers)-1] return nil } @@ -65,8 +78,8 @@ type TestMapWalker struct { MapVal reflect.Value - Keys []string - Values []string + Keys map[string]bool + Values map[string]bool } func (t *TestMapWalker) Map(m reflect.Value) error { @@ -76,12 +89,12 @@ func (t *TestMapWalker) MapElem(m, k, v reflect.Value) error { if t.Keys == nil { - t.Keys = make([]string, 0, 1) - t.Values = make([]string, 0, 1) - } - - t.Keys = append(t.Keys, k.Interface().(string)) - t.Values = append(t.Values, v.Interface().(string)) + t.Keys = make(map[string]bool) + t.Values = make(map[string]bool) + } + + t.Keys[k.Interface().(string)] = true + t.Values[v.Interface().(string)] = true return nil } @@ -189,6 +202,23 @@ } if data.Bar[0].([]string)[0] != "bar" { t.Fatalf("bad: %#v", data.Bar) + } +} + +func TestWalk_Basic_ReplaceInterface(t *testing.T) { + w := new(TestPrimitiveReplaceWalker) + + type S struct { + Foo []interface{} + } + + data := &S{ + Foo: []interface{}{"foo"}, + } + + err := Walk(data, w) + if err != nil { + t.Fatalf("err: %s", err) } } @@ -294,12 +324,12 @@ t.Fatalf("Bad: %#v", w.MapVal.Interface()) } - expectedK := []string{"foo", "bar"} + expectedK := map[string]bool{"foo": true, "bar": true} if !reflect.DeepEqual(w.Keys, expectedK) { t.Fatalf("Bad keys: %#v", w.Keys) } - expectedV := []string{"foov", "barv"} + expectedV := map[string]bool{"foov": true, "barv": true} if !reflect.DeepEqual(w.Values, expectedV) { t.Fatalf("Bad values: %#v", w.Values) } @@ -310,20 +340,57 @@ type S struct { Foo string - } - - data := &S{ - Foo: "foo", - } - - err := Walk(data, w) - if err != nil { - t.Fatalf("err: %s", err) - } - - expected := []bool{true, false} - if !reflect.DeepEqual(w.Ps, expected) { - t.Fatalf("bad: %#v", w.Ps) + Bar *string + Baz **string + } + + s := "" + sp := &s + + data := &S{ + Baz: &sp, + } + + err := Walk(data, w) + if err != nil { + t.Fatalf("err: %s", err) + } + + if w.enters != 5 { + t.Fatal("expected 4 values, saw", w.enters) + } + + if w.count != 4 { + t.Fatal("exptec 3 pointers, saw", w.count) + } + + if w.exits != w.enters { + t.Fatalf("number of enters (%d) and exits (%d) don't match", w.enters, w.exits) + } +} + +func TestWalk_PointerPointer(t *testing.T) { + w := new(TestPointerWalker) + + s := "" + sp := &s + pp := &sp + + err := Walk(pp, w) + if err != nil { + t.Fatalf("err: %s", err) + } + + if w.enters != 2 { + t.Fatal("expected 2 values, saw", w.enters) + } + + if w.count != 2 { + t.Fatal("expected 2 pointers, saw", w.count) + } + + if w.exits != w.enters { + t.Fatalf("number of enters (%d) and exits (%d) don't match", w.enters, w.exits) } } @@ -352,12 +419,86 @@ } } +func TestWalk_SliceWithPtr(t *testing.T) { + w := new(TestSliceWalker) + + // This is key, the panic only happened when the slice field was + // an interface! + type I interface{} + + type S struct { + Foo []I + } + + type Empty struct{} + + data := &S{ + Foo: []I{&Empty{}}, + } + + err := Walk(data, w) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !reflect.DeepEqual(w.SliceVal.Interface(), data.Foo) { + t.Fatalf("bad: %#v", w.SliceVal.Interface()) + } + + if w.Count != 1 { + t.Fatalf("Bad count: %d", w.Count) + } +} + +type testErr struct{} + +func (t *testErr) Error() string { + return "test error" +} + func TestWalk_Struct(t *testing.T) { w := new(TestStructWalker) + // This makes sure we can also walk over pointer-to-pointers, and the ever + // so rare pointer-to-interface + type S struct { + Foo string + Bar *string + Baz **string + Err *error + } + + bar := "ptr" + baz := &bar + e := error(&testErr{}) + + data := &S{ + Foo: "foo", + Bar: &bar, + Baz: &baz, + Err: &e, + } + + err := Walk(data, w) + if err != nil { + t.Fatalf("err: %s", err) + } + + expected := []string{"Foo", "Bar", "Baz", "Err"} + if !reflect.DeepEqual(w.Fields, expected) { + t.Fatalf("bad: %#v", w.Fields) + } +} + +// Very similar to above test but used to fail for #2, copied here for +// regression testing +func TestWalk_StructWithPtr(t *testing.T) { + w := new(TestStructWalker) + type S struct { Foo string Bar string + Baz *int } data := &S{ @@ -370,8 +511,130 @@ t.Fatalf("err: %s", err) } - expected := []string{"Foo", "Bar"} + expected := []string{"Foo", "Bar", "Baz"} if !reflect.DeepEqual(w.Fields, expected) { t.Fatalf("bad: %#v", w.Fields) } } + +type TestInterfaceMapWalker struct { + MapVal reflect.Value + Keys map[string]bool + Values map[interface{}]bool +} + +func (t *TestInterfaceMapWalker) Map(m reflect.Value) error { + t.MapVal = m + return nil +} + +func (t *TestInterfaceMapWalker) MapElem(m, k, v reflect.Value) error { + if t.Keys == nil { + t.Keys = make(map[string]bool) + t.Values = make(map[interface{}]bool) + } + + t.Keys[k.Interface().(string)] = true + t.Values[v.Interface()] = true + return nil +} + +func TestWalk_MapWithPointers(t *testing.T) { + w := new(TestInterfaceMapWalker) + + type S struct { + Foo map[string]interface{} + } + + a := "a" + b := "b" + + data := &S{ + Foo: map[string]interface{}{ + "foo": &a, + "bar": &b, + "baz": 11, + "zab": (*int)(nil), + }, + } + + err := Walk(data, w) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !reflect.DeepEqual(w.MapVal.Interface(), data.Foo) { + t.Fatalf("Bad: %#v", w.MapVal.Interface()) + } + + expectedK := map[string]bool{"foo": true, "bar": true, "baz": true, "zab": true} + if !reflect.DeepEqual(w.Keys, expectedK) { + t.Fatalf("Bad keys: %#v", w.Keys) + } + + expectedV := map[interface{}]bool{&a: true, &b: true, 11: true, (*int)(nil): true} + if !reflect.DeepEqual(w.Values, expectedV) { + t.Fatalf("Bad values: %#v", w.Values) + } +} + +type TestStructWalker_fieldSkip struct { + Skip bool + Fields int +} + +func (t *TestStructWalker_fieldSkip) Enter(l Location) error { + if l == StructField { + t.Fields++ + } + + return nil +} + +func (t *TestStructWalker_fieldSkip) Exit(Location) error { + return nil +} + +func (t *TestStructWalker_fieldSkip) Struct(v reflect.Value) error { + return nil +} + +func (t *TestStructWalker_fieldSkip) StructField(sf reflect.StructField, v reflect.Value) error { + if t.Skip && sf.Name[0] == '_' { + return SkipEntry + } + + return nil +} + +func TestWalk_StructWithSkipEntry(t *testing.T) { + data := &struct { + Foo, _Bar int + }{ + Foo: 1, + _Bar: 2, + } + + { + var s TestStructWalker_fieldSkip + if err := Walk(data, &s); err != nil { + t.Fatalf("err: %s", err) + } + + if s.Fields != 2 { + t.Fatalf("bad: %d", s.Fields) + } + } + + { + var s TestStructWalker_fieldSkip + s.Skip = true + if err := Walk(data, &s); err != nil { + t.Fatalf("err: %s", err) + } + + if s.Fields != 1 { + t.Fatalf("bad: %d", s.Fields) + } + } +}