123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241 |
- package critbitgo
- import (
- "net"
- )
// Host masks used by matchIP to look up a single address as a /32 (IPv4)
// or /128 (IPv6) network: every bit of the mask is set.
var (
	mask32  = net.CIDRMask(32, 32)
	mask128 = net.CIDRMask(128, 128)
)
- // IP routing table.
- type Net struct {
- trie *Trie
- }
- // Add a route.
- // If `r` is not IPv4/IPv6 network, returns an error.
- func (n *Net) Add(r *net.IPNet, value interface{}) (err error) {
- var ip net.IP
- if ip, _, err = netValidateIPNet(r); err == nil {
- n.trie.Set(netIPNetToKey(ip, r.Mask), value)
- }
- return
- }
- // Add a route.
- // If `s` is not CIDR notation, returns an error.
- func (n *Net) AddCIDR(s string, value interface{}) (err error) {
- var r *net.IPNet
- if _, r, err = net.ParseCIDR(s); err == nil {
- n.Add(r, value)
- }
- return
- }
- // Delete a specific route.
- // If `r` is not IP4/IPv6 network or a route is not found, `ok` is false.
- func (n *Net) Delete(r *net.IPNet) (value interface{}, ok bool, err error) {
- var ip net.IP
- if ip, _, err = netValidateIPNet(r); err == nil {
- value, ok = n.trie.Delete(netIPNetToKey(ip, r.Mask))
- }
- return
- }
- // Delete a specific route.
- // If `s` is not CIDR notation or a route is not found, `ok` is false.
- func (n *Net) DeleteCIDR(s string) (value interface{}, ok bool, err error) {
- var r *net.IPNet
- if _, r, err = net.ParseCIDR(s); err == nil {
- value, ok, err = n.Delete(r)
- }
- return
- }
- // Get a specific route.
- // If `r` is not IPv4/IPv6 network or a route is not found, `ok` is false.
- func (n *Net) Get(r *net.IPNet) (value interface{}, ok bool, err error) {
- var ip net.IP
- if ip, _, err = netValidateIPNet(r); err == nil {
- value, ok = n.trie.Get(netIPNetToKey(ip, r.Mask))
- }
- return
- }
- // Get a specific route.
- // If `s` is not CIDR notation or a route is not found, `ok` is false.
- func (n *Net) GetCIDR(s string) (value interface{}, ok bool, err error) {
- var r *net.IPNet
- if _, r, err = net.ParseCIDR(s); err == nil {
- value, ok, err = n.Get(r)
- }
- return
- }
- // Return a specific route by using the longest prefix matching.
- // If `r` is not IPv4/IPv6 network or a route is not found, `route` is nil.
- func (n *Net) Match(r *net.IPNet) (route *net.IPNet, value interface{}, err error) {
- var ip net.IP
- if ip, _, err = netValidateIP(r.IP); err == nil {
- if k, v := n.match(netIPNetToKey(ip, r.Mask)); k != nil {
- route = netKeyToIPNet(k)
- value = v
- }
- }
- return
- }
- // Return a specific route by using the longest prefix matching.
- // If `s` is not CIDR notation, or a route is not found, `route` is nil.
- func (n *Net) MatchCIDR(s string) (route *net.IPNet, value interface{}, err error) {
- var r *net.IPNet
- if _, r, err = net.ParseCIDR(s); err == nil {
- route, value, err = n.Match(r)
- }
- return
- }
- // Return a bool indicating whether a route would be found
- func (n *Net) ContainedIP(ip net.IP) (contained bool, err error) {
- k, _, err := n.matchIP(ip)
- contained = k != nil
- return
- }
- // Return a specific route by using the longest prefix matching.
- // If `ip` is invalid IP, or a route is not found, `route` is nil.
- func (n *Net) MatchIP(ip net.IP) (route *net.IPNet, value interface{}, err error) {
- k, v, err := n.matchIP(ip)
- if k != nil {
- route = netKeyToIPNet(k)
- value = v
- }
- return
- }
- func (n *Net) matchIP(ip net.IP) (k []byte, v interface{}, err error) {
- var isV4 bool
- ip, isV4, err = netValidateIP(ip)
- if err != nil {
- return
- }
- var mask net.IPMask
- if isV4 {
- mask = mask32
- } else {
- mask = mask128
- }
- k, v = n.match(netIPNetToKey(ip, mask))
- return
- }
- func (n *Net) match(key []byte) ([]byte, interface{}) {
- if n.trie.size > 0 {
- if node := lookup(&n.trie.root, key, false); node != nil {
- return node.external.key, node.external.value
- }
- }
- return nil, nil
- }
- func lookup(p *node, key []byte, backtracking bool) *node {
- if p.internal != nil {
- var direction int
- if p.internal.offset == len(key)-1 {
- // selecting the larger side when comparing the mask
- direction = 1
- } else if backtracking {
- direction = 0
- } else {
- direction = p.internal.direction(key)
- }
- if c := lookup(&p.internal.child[direction], key, backtracking); c != nil {
- return c
- }
- if direction == 1 {
- // search other node
- return lookup(&p.internal.child[0], key, true)
- }
- return nil
- } else {
- nlen := len(p.external.key)
- if nlen != len(key) {
- return nil
- }
- // check mask
- mask := p.external.key[nlen-1]
- if mask > key[nlen-1] {
- return nil
- }
- // compare both keys with mask
- div := int(mask >> 3)
- for i := 0; i < div; i++ {
- if p.external.key[i] != key[i] {
- return nil
- }
- }
- if mod := uint(mask & 0x07); mod > 0 {
- bit := 8 - mod
- if p.external.key[div] != key[div]&(0xff>>bit<<bit) {
- return nil
- }
- }
- return p
- }
- }
- // Deletes all routes.
- func (n *Net) Clear() {
- n.trie.Clear()
- }
- // Returns number of routes.
- func (n *Net) Size() int {
- return n.trie.Size()
- }
- // Create IP routing table
- func NewNet() *Net {
- return &Net{NewTrie()}
- }
// netValidateIP normalizes ip for use in a trie key.
//   - IPv4 addresses, including IPv4-mapped IPv6 ones, come back in 4-byte
//     form with isV4 set.
//   - Other addresses accepted by To16 come back unchanged.
//   - Anything else yields a *net.AddrError and a nil nIP.
func netValidateIP(ip net.IP) (nIP net.IP, isV4 bool, err error) {
	if v4 := ip.To4(); v4 != nil {
		return v4, true, nil
	}
	if ip.To16() == nil {
		return nil, false, &net.AddrError{Err: "Invalid IP address", Addr: ip.String()}
	}
	return ip, false, nil
}
- func netValidateIPNet(r *net.IPNet) (nIP net.IP, isV4 bool, err error) {
- if r == nil {
- err = &net.AddrError{Err: "IP network is nil"}
- return
- }
- return netValidateIP(r.IP)
- }
// netIPNetToKey builds the trie key for a route: the address bytes followed
// by one byte holding the prefix length.
//
//	+--------------+------+
//	| ip address.. | mask |
//	+--------------+------+
//
// The key is freshly allocated and never shares memory with ip. A bare
// append(ip, ...) would write past len(ip) into the caller's backing array
// whenever cap(ip) > len(ip), and the stored key would then change along
// with the caller's IP.
//
// A 16-byte mask paired with a 4-byte address is reduced to its last four
// bytes, as package net does for such an IPNet. This keeps the prefix
// length within the address width; without it, lookup indexes past the key.
// A non-canonical mask gives prefix length 0 (see net.IPMask.Size).
func netIPNetToKey(ip net.IP, mask net.IPMask) []byte {
	if len(ip) == net.IPv4len && len(mask) == net.IPv6len {
		mask = mask[12:]
	}
	ones, _ := mask.Size()
	key := make([]byte, len(ip)+1)
	copy(key, ip)
	key[len(ip)] = byte(ones)
	return key
}
// netKeyToIPNet converts a trie key (address bytes plus a trailing
// prefix-length byte, as built by netIPNetToKey) back into a *net.IPNet.
//
// The returned IP is a copy of the address bytes. Returning k[:iplen]
// directly would share the trie's stored key, so a caller modifying
// route.IP would corrupt the trie. That slice also has spare capacity, so
// appending to it would overwrite the prefix-length byte.
func netKeyToIPNet(k []byte) *net.IPNet {
	iplen := len(k) - 1
	ip := make(net.IP, iplen)
	copy(ip, k[:iplen])
	return &net.IPNet{
		IP:   ip,
		Mask: net.CIDRMask(int(k[iplen]), iplen*8),
	}
}
|