net.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package critbitgo
  2. import (
  3. "net"
  4. )
  5. var (
  6. mask32 = net.IPMask{0xff, 0xff, 0xff, 0xff}
  7. mask128 = net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
  8. )
  9. // IP routing table.
  10. type Net struct {
  11. trie *Trie
  12. }
  13. // Add a route.
  14. // If `r` is not IPv4/IPv6 network, returns an error.
  15. func (n *Net) Add(r *net.IPNet, value interface{}) (err error) {
  16. var ip net.IP
  17. if ip, _, err = netValidateIPNet(r); err == nil {
  18. n.trie.Set(netIPNetToKey(ip, r.Mask), value)
  19. }
  20. return
  21. }
  22. // Add a route.
  23. // If `s` is not CIDR notation, returns an error.
  24. func (n *Net) AddCIDR(s string, value interface{}) (err error) {
  25. var r *net.IPNet
  26. if _, r, err = net.ParseCIDR(s); err == nil {
  27. n.Add(r, value)
  28. }
  29. return
  30. }
  31. // Delete a specific route.
  32. // If `r` is not IP4/IPv6 network or a route is not found, `ok` is false.
  33. func (n *Net) Delete(r *net.IPNet) (value interface{}, ok bool, err error) {
  34. var ip net.IP
  35. if ip, _, err = netValidateIPNet(r); err == nil {
  36. value, ok = n.trie.Delete(netIPNetToKey(ip, r.Mask))
  37. }
  38. return
  39. }
  40. // Delete a specific route.
  41. // If `s` is not CIDR notation or a route is not found, `ok` is false.
  42. func (n *Net) DeleteCIDR(s string) (value interface{}, ok bool, err error) {
  43. var r *net.IPNet
  44. if _, r, err = net.ParseCIDR(s); err == nil {
  45. value, ok, err = n.Delete(r)
  46. }
  47. return
  48. }
  49. // Get a specific route.
  50. // If `r` is not IPv4/IPv6 network or a route is not found, `ok` is false.
  51. func (n *Net) Get(r *net.IPNet) (value interface{}, ok bool, err error) {
  52. var ip net.IP
  53. if ip, _, err = netValidateIPNet(r); err == nil {
  54. value, ok = n.trie.Get(netIPNetToKey(ip, r.Mask))
  55. }
  56. return
  57. }
  58. // Get a specific route.
  59. // If `s` is not CIDR notation or a route is not found, `ok` is false.
  60. func (n *Net) GetCIDR(s string) (value interface{}, ok bool, err error) {
  61. var r *net.IPNet
  62. if _, r, err = net.ParseCIDR(s); err == nil {
  63. value, ok, err = n.Get(r)
  64. }
  65. return
  66. }
  67. // Return a specific route by using the longest prefix matching.
  68. // If `r` is not IPv4/IPv6 network or a route is not found, `route` is nil.
  69. func (n *Net) Match(r *net.IPNet) (route *net.IPNet, value interface{}, err error) {
  70. var ip net.IP
  71. if ip, _, err = netValidateIP(r.IP); err == nil {
  72. if k, v := n.match(netIPNetToKey(ip, r.Mask)); k != nil {
  73. route = netKeyToIPNet(k)
  74. value = v
  75. }
  76. }
  77. return
  78. }
  79. // Return a specific route by using the longest prefix matching.
  80. // If `s` is not CIDR notation, or a route is not found, `route` is nil.
  81. func (n *Net) MatchCIDR(s string) (route *net.IPNet, value interface{}, err error) {
  82. var r *net.IPNet
  83. if _, r, err = net.ParseCIDR(s); err == nil {
  84. route, value, err = n.Match(r)
  85. }
  86. return
  87. }
  88. // Return a bool indicating whether a route would be found
  89. func (n *Net) ContainedIP(ip net.IP) (contained bool, err error) {
  90. k, _, err := n.matchIP(ip)
  91. contained = k != nil
  92. return
  93. }
  94. // Return a specific route by using the longest prefix matching.
  95. // If `ip` is invalid IP, or a route is not found, `route` is nil.
  96. func (n *Net) MatchIP(ip net.IP) (route *net.IPNet, value interface{}, err error) {
  97. k, v, err := n.matchIP(ip)
  98. if k != nil {
  99. route = netKeyToIPNet(k)
  100. value = v
  101. }
  102. return
  103. }
  104. func (n *Net) matchIP(ip net.IP) (k []byte, v interface{}, err error) {
  105. var isV4 bool
  106. ip, isV4, err = netValidateIP(ip)
  107. if err != nil {
  108. return
  109. }
  110. var mask net.IPMask
  111. if isV4 {
  112. mask = mask32
  113. } else {
  114. mask = mask128
  115. }
  116. k, v = n.match(netIPNetToKey(ip, mask))
  117. return
  118. }
  119. func (n *Net) match(key []byte) ([]byte, interface{}) {
  120. if n.trie.size > 0 {
  121. if node := lookup(&n.trie.root, key, false); node != nil {
  122. return node.external.key, node.external.value
  123. }
  124. }
  125. return nil, nil
  126. }
// lookup searches the subtree rooted at p for a stored network key that
// covers key, and returns that external (leaf) node, or nil if none does.
// Keys use the layout built by netIPNetToKey: address bytes followed by one
// prefix-length byte, so index len(key)-1 holds the prefix length.
//
// match calls it with backtracking == false. Once the search falls back from
// a 1 child to the 0 child, backtracking is true. From then on, internal
// nodes are descended on their 0 side only; the node whose critical byte is
// the prefix length still tries its 1 child first, then its 0 child.
// NOTE(review): the longest-prefix result promised by Match/MatchIP depends
// on this visiting order and on the bit convention of internal.direction,
// which is defined outside this file section; confirm before changing either.
func lookup(p *node, key []byte, backtracking bool) *node {
	if p.internal != nil {
		var direction int
		if p.internal.offset == len(key)-1 {
			// selecting the larger side when comparing the mask
			// (the critical byte is the prefix length itself: try the 1
			// child first and fall back to the 0 child below)
			direction = 1
		} else if backtracking {
			// already backtracking: descend the 0 side only
			direction = 0
		} else {
			// follow the branch selected by key at this node
			direction = p.internal.direction(key)
		}
		if c := lookup(&p.internal.child[direction], key, backtracking); c != nil {
			return c
		}
		if direction == 1 {
			// search other node: nothing matched on the 1 side, so retry
			// on the 0 side in backtracking mode
			return lookup(&p.internal.child[0], key, true)
		}
		return nil
	} else {
		// External node: decide whether this stored network covers key.
		nlen := len(p.external.key)
		if nlen != len(key) {
			// keys of different length (e.g. IPv4 vs IPv6) never match
			return nil
		}
		// check mask: a stored prefix longer than key's prefix cannot cover it
		mask := p.external.key[nlen-1]
		if mask > key[nlen-1] {
			return nil
		}
		// compare both keys with mask: the first mask/8 bytes must be equal...
		div := int(mask >> 3)
		for i := 0; i < div; i++ {
			if p.external.key[i] != key[i] {
				return nil
			}
		}
		// ...and so must the top mask%8 bits of the next byte. The stored
		// byte is compared whole against key's masked byte, so this only
		// matches when the stored key's bits beyond its prefix are zero.
		if mod := uint(mask & 0x07); mod > 0 {
			bit := 8 - mod
			if p.external.key[div] != key[div]&(0xff>>bit<<bit) {
				return nil
			}
		}
		return p
	}
}
  172. // Deletes all routes.
  173. func (n *Net) Clear() {
  174. n.trie.Clear()
  175. }
  176. // Returns number of routes.
  177. func (n *Net) Size() int {
  178. return n.trie.Size()
  179. }
  180. // Create IP routing table
  181. func NewNet() *Net {
  182. return &Net{NewTrie()}
  183. }
  184. func netValidateIP(ip net.IP) (nIP net.IP, isV4 bool, err error) {
  185. if v4 := ip.To4(); v4 != nil {
  186. nIP = v4
  187. isV4 = true
  188. } else if ip.To16() != nil {
  189. nIP = ip
  190. } else {
  191. err = &net.AddrError{Err: "Invalid IP address", Addr: ip.String()}
  192. }
  193. return
  194. }
  195. func netValidateIPNet(r *net.IPNet) (nIP net.IP, isV4 bool, err error) {
  196. if r == nil {
  197. err = &net.AddrError{Err: "IP network is nil"}
  198. return
  199. }
  200. return netValidateIP(r.IP)
  201. }
  202. func netIPNetToKey(ip net.IP, mask net.IPMask) []byte {
  203. // +--------------+------+
  204. // | ip address.. | mask |
  205. // +--------------+------+
  206. ones, _ := mask.Size()
  207. return append(ip, byte(ones))
  208. }
  209. func netKeyToIPNet(k []byte) *net.IPNet {
  210. iplen := len(k) - 1
  211. return &net.IPNet{
  212. IP: net.IP(k[:iplen]),
  213. Mask: net.CIDRMask(int(k[iplen]), iplen*8),
  214. }
  215. }