package cidranger
import (
"net"
"sort"
"testing"
"github.com/stretchr/testify/assert"
)
// TestInsert verifies that Insert stores IPv4 and IPv6 entries in their
// respective per-family maps, keyed by the network's CIDR string.
func TestInsert(t *testing.T) {
	ranger := newBruteRanger().(*bruteRanger)
	// The CIDR literals below are known-valid, so parse errors are ignored.
	_, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24")
	_, networkIPv6, _ := net.ParseCIDR("8000::/96")
	entryIPv4 := NewBasicRangerEntry(*networkIPv4)
	entryIPv6 := NewBasicRangerEntry(*networkIPv6)

	// Check Insert's result directly rather than inferring failure from map sizes.
	assert.NoError(t, ranger.Insert(entryIPv4))
	assert.NoError(t, ranger.Insert(entryIPv6))

	assert.Len(t, ranger.ipV4Entries, 1)
	assert.Equal(t, entryIPv4, ranger.ipV4Entries["0.0.1.0/24"])
	assert.Len(t, ranger.ipV6Entries, 1)
	assert.Equal(t, entryIPv6, ranger.ipV6Entries["8000::/96"])
}
// TestInsertError checks that Insert rejects an entry whose network address
// has been extended to an invalid byte length.
func TestInsertError(t *testing.T) {
	br := newBruteRanger().(*bruteRanger)
	_, cidr, _ := net.ParseCIDR("0.0.1.0/24")
	// Tack an extra byte onto the address so its length matches neither an
	// IPv4 nor an IPv6 address.
	cidr.IP = append(cidr.IP, byte(4))
	entry := NewBasicRangerEntry(*cidr)
	gotErr := br.Insert(entry)
	assert.Equal(t, ErrInvalidNetworkInput, gotErr)
}
// TestRemove verifies that Remove returns the entry previously inserted for a
// network and deletes it from the per-family map. It also checks that removing
// a network that was never inserted yields a nil entry and no error.
func TestRemove(t *testing.T) {
	ranger := newBruteRanger().(*bruteRanger)
	// The CIDR literals below are known-valid, so parse errors are ignored.
	_, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24")
	_, networkIPv6, _ := net.ParseCIDR("8000::/96")
	// Deliberately distinct from every inserted network, so this exercises the
	// "absent" path rather than a second removal of networkIPv6.
	_, notInserted, _ := net.ParseCIDR("9000::/96")
	insertIPv4 := NewBasicRangerEntry(*networkIPv4)
	insertIPv6 := NewBasicRangerEntry(*networkIPv6)

	assert.NoError(t, ranger.Insert(insertIPv4))
	deletedIPv4, err := ranger.Remove(*networkIPv4)
	assert.NoError(t, err)
	assert.Equal(t, insertIPv4, deletedIPv4)
	assert.Len(t, ranger.ipV4Entries, 0)

	assert.NoError(t, ranger.Insert(insertIPv6))
	deletedIPv6, err := ranger.Remove(*networkIPv6)
	assert.NoError(t, err)
	assert.Equal(t, insertIPv6, deletedIPv6)
	assert.Len(t, ranger.ipV6Entries, 0)

	entry, err := ranger.Remove(*notInserted)
	assert.NoError(t, err)
	assert.Nil(t, entry)
}
// TestRemoveError checks that Remove rejects a network whose address has been
// extended to an invalid byte length.
func TestRemoveError(t *testing.T) {
	br := newBruteRanger().(*bruteRanger)
	_, bad, _ := net.ParseCIDR("0.0.1.0/24")
	// One extra byte leaves the address matching neither IPv4 nor IPv6 length.
	bad.IP = append(bad.IP, byte(4))
	_, gotErr := br.Remove(*bad)
	assert.Equal(t, ErrInvalidNetworkInput, gotErr)
}
// TestContains checks Contains for addresses inside and outside an inserted
// IPv4 /24 and IPv6 /112. It also checks that an address of invalid length is
// rejected with ErrInvalidNetworkInput.
func TestContains(t *testing.T) {
	r := newBruteRanger().(*bruteRanger)
	// The CIDR literals below are known-valid, so parse errors are ignored.
	_, network, _ := net.ParseCIDR("0.0.1.0/24")
	_, network1, _ := net.ParseCIDR("8000::/112")
	assert.NoError(t, r.Insert(NewBasicRangerEntry(*network)))
	assert.NoError(t, r.Insert(NewBasicRangerEntry(*network1)))

	// Subtest names must be unique: t.Run silently suffixes duplicates (#01),
	// which makes -run filtering and failure output misleading.
	cases := []struct {
		ip       net.IP
		contains bool
		err      error
		name     string
	}{
		{net.ParseIP("0.0.1.255"), true, nil, "IPv4 should contain"},
		{net.ParseIP("0.0.0.255"), false, nil, "IPv4 shouldn't contain"},
		{net.ParseIP("8000::ffff"), true, nil, "IPv6 should contain"},
		{net.ParseIP("8000::1:ffff"), false, nil, "IPv6 shouldn't contain"},
		// ParseIP yields 16 bytes; appending one gives an invalid 17-byte address.
		{append(net.ParseIP("8000::1:ffff"), byte(0)), false, ErrInvalidNetworkInput, "Invalid IP"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			contains, err := r.Contains(tc.ip)
			if tc.err != nil {
				assert.Equal(t, tc.err, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.contains, contains)
		})
	}
}
// TestContainingNetworks checks that ContainingNetworks returns every inserted
// network that contains the queried address, in either order. It uses nested
// IPv4 (/24, /25) and IPv6 (/112, /113) pairs and an invalid-length address.
func TestContainingNetworks(t *testing.T) {
	r := newBruteRanger().(*bruteRanger)
	// The CIDR literals below are known-valid, so parse errors are ignored.
	_, network1, _ := net.ParseCIDR("0.0.1.0/24")
	_, network2, _ := net.ParseCIDR("0.0.1.0/25")
	_, network3, _ := net.ParseCIDR("8000::/112")
	_, network4, _ := net.ParseCIDR("8000::/113")
	entry1 := NewBasicRangerEntry(*network1)
	entry2 := NewBasicRangerEntry(*network2)
	entry3 := NewBasicRangerEntry(*network3)
	entry4 := NewBasicRangerEntry(*network4)
	for _, entry := range []RangerEntry{entry1, entry2, entry3, entry4} {
		assert.NoError(t, r.Insert(entry))
	}

	cases := []struct {
		ip                 net.IP
		containingNetworks []RangerEntry
		err                error
		name               string
	}{
		{net.ParseIP("0.0.1.255"), []RangerEntry{entry1}, nil, "IPv4 should contain"},
		{net.ParseIP("0.0.1.127"), []RangerEntry{entry1, entry2}, nil, "IPv4 should contain both"},
		{net.ParseIP("0.0.0.127"), []RangerEntry{}, nil, "IPv4 should contain none"},
		{net.ParseIP("8000::ffff"), []RangerEntry{entry3}, nil, "IPv6 should contain"},
		{net.ParseIP("8000::7fff"), []RangerEntry{entry3, entry4}, nil, "IPv6 should contain both"},
		{net.ParseIP("8000::1:7fff"), []RangerEntry{}, nil, "IPv6 should contain none"},
		// ParseIP yields 16 bytes; appending one gives an invalid 17-byte address.
		{append(net.ParseIP("8000::1:7fff"), byte(0)), nil, ErrInvalidNetworkInput, "Invalid IP"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			networks, err := r.ContainingNetworks(tc.ip)
			if tc.err != nil {
				assert.Equal(t, tc.err, err)
				return
			}
			assert.NoError(t, err)
			// Same length plus membership of every expected entry: order-insensitive.
			assert.Len(t, networks, len(tc.containingNetworks))
			for _, network := range tc.containingNetworks {
				assert.Contains(t, networks, network)
			}
		})
	}
}
// TestCoveredNetworks runs the shared coveredNetworkTests table against the
// brute-force ranger. Each case inserts its networks and queries
// CoveredNetworks for the search network. It then compares the results, as
// sorted CIDR strings, with the expected networks.
func TestCoveredNetworks(t *testing.T) {
	for _, tc := range coveredNetworkTests {
		t.Run(tc.name, func(t *testing.T) {
			ranger := newBruteRanger()
			for _, insert := range tc.inserts {
				// Table inputs are not literals here, so surface a bad fixture
				// as a clear failure instead of a nil-pointer panic below.
				_, network, err := net.ParseCIDR(insert)
				if !assert.NoError(t, err, "parsing insert %q", insert) {
					return
				}
				assert.NoError(t, ranger.Insert(NewBasicRangerEntry(*network)), "inserting %q", insert)
			}

			// Copy before sorting so the shared table is not mutated. When
			// tc.networks is empty this stays nil, matching an empty result.
			var expectedEntries []string
			for _, network := range tc.networks {
				expectedEntries = append(expectedEntries, network)
			}
			sort.Strings(expectedEntries)

			_, snet, err := net.ParseCIDR(tc.search)
			if !assert.NoError(t, err, "parsing search %q", tc.search) {
				return
			}
			networks, err := ranger.CoveredNetworks(*snet)
			assert.NoError(t, err)

			var results []string
			for _, result := range networks {
				// Named ipNet rather than net to avoid shadowing the net package.
				ipNet := result.Network()
				results = append(results, ipNet.String())
			}
			sort.Strings(results)
			assert.Equal(t, expectedEntries, results)
		})
	}
}