Codebase list golang-github-bruth-assert / HEAD assert.go
HEAD

Tree @HEAD (Download .tar.gz)

assert.go @HEADraw · history · blame

package assert

import (
    "fmt"
    "github.com/kr/pretty"
    "reflect"
    "runtime"
    "strings"
    "testing"
)

var errorPrefix = "! "

// -- Assertion handlers

func assert(t *testing.T, success bool, f func(), callDepth int) {
    if !success {
        _, file, line, _ := runtime.Caller(callDepth + 1)
        t.Errorf("%s:%d", file, line)
        f()
        t.FailNow()
    }
}

func equal(t *testing.T, expected, got interface{}, callDepth int, messages ...interface{}) {
    fn := func() {
        for _, desc := range pretty.Diff(expected, got) {
            t.Error(errorPrefix, desc)
        }
        if len(messages) > 0 {
            t.Error(errorPrefix, "-", fmt.Sprint(messages...))
        }
    }
    assert(t, isEqual(expected, got), fn, callDepth+1)
}

func notEqual(t *testing.T, expected, got interface{}, callDepth int, messages ...interface{}) {
    fn := func() {
        t.Errorf("%s Unexpected: %#v", errorPrefix, got)
        if len(messages) > 0 {
            t.Error(errorPrefix, "-", fmt.Sprint(messages...))
        }
    }
    assert(t, !isEqual(expected, got), fn, callDepth+1)
}

func contains(t *testing.T, expected, got string, callDepth int, messages ...interface{}) {
    fn := func() {
        t.Errorf("%s Expected to find: %#v", errorPrefix, expected)
        t.Errorf("%s in: %#v", errorPrefix, got)
        if len(messages) > 0 {
            t.Error(errorPrefix, "-", fmt.Sprint(messages...))
        }
    }
    assert(t, strings.Contains(got, expected), fn, callDepth+1)
}

func notContains(t *testing.T, unexpected, got string, callDepth int, messages ...interface{}) {
    fn := func() {
        t.Errorf("%s Expected not to find: %#v", errorPrefix, unexpected)
        t.Errorf("%s in: %#v", errorPrefix, got)
        if len(messages) > 0 {
            t.Error(errorPrefix, "-", fmt.Sprint(messages...))
        }
    }
    assert(t, !strings.Contains(got, unexpected), fn, callDepth+1)
}

// -- Matching

func isEqual(expected, got interface{}) bool {
    if expected == nil {
        return isNil(got)
    }
    return reflect.DeepEqual(expected, got)
}

func isNil(got interface{}) bool {
    if got == nil {
        return true
    }
    value := reflect.ValueOf(got)
    switch value.Kind() {
    case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
        return value.IsNil()
    }
    return false
}

// -- Public API

func Equal(t *testing.T, expected, got interface{}, messages ...interface{}) {
    equal(t, expected, got, 1, messages...)
}

func NotEqual(t *testing.T, expected, got interface{}, messages ...interface{}) {
    notEqual(t, expected, got, 1, messages...)
}

func True(t *testing.T, got interface{}, messages ...interface{}) {
    equal(t, true, got, 1, messages...)
}

func False(t *testing.T, got interface{}, messages ...interface{}) {
    equal(t, false, got, 1, messages...)
}

func Nil(t *testing.T, got interface{}, messages ...interface{}) {
    equal(t, nil, got, 1, messages...)
}

func NotNil(t *testing.T, got interface{}, messages ...interface{}) {
    notEqual(t, nil, got, 1, messages...)
}

func Contains(t *testing.T, expected, got string, messages ...interface{}) {
    contains(t, expected, got, 1, messages...)
}

func NotContains(t *testing.T, unexpected, got string, messages ...interface{}) {
    notContains(t, unexpected, got, 1, messages...)
}

func Panic(t *testing.T, err interface{}, fn func(), messages ...interface{}) {
    defer func() {
        equal(t, err, recover(), 3, messages...)
    }()
    fn()
}