Skip to content

Commit

Permalink
add rvlookup middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
IrineSistiana committed Oct 8, 2024
1 parent c2212e0 commit e3667da
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 0 deletions.
1 change: 1 addition & 0 deletions app/router/middleware/all.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ package middleware

import (
_ "github.com/IrineSistiana/mosproxy/app/router/middleware/limit"
_ "github.com/IrineSistiana/mosproxy/app/router/middleware/rvlookup"
)
201 changes: 201 additions & 0 deletions app/router/middleware/rvlookup/rvlookup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package rvlookup

import (
"context"
"errors"
"fmt"
"net/netip"
"strconv"
"time"
"unsafe"

"github.com/IrineSistiana/mosproxy/app/router"
"github.com/IrineSistiana/mosproxy/internal/utils"
"github.com/IrineSistiana/mosproxy/pkg/dnsmsg"
"github.com/maypok86/otter"
"github.com/prometheus/client_golang/prometheus"
)

type Args struct {
CacheSize int `yaml:"cache_size"`
CacheTtl int `yaml:"cache_ttl"`
}

const (
mwName = "rvlookup"
)

func init() {
router.RegMiddleware(mwName, NewRvLookup)
}

type RvLookup struct {
ctx router.PluginCtx
next router.Middleware
args Args

cache otter.Cache[netip.Addr, string]
cacheSize prometheus.GaugeFunc
strCache otter.Cache[string, string] // cache for same domain strings to save memory
}

func NewRvLookup(ctx router.PluginCtx, args map[string]any, next router.Middleware) (router.Middleware, error) {
a := Args{}
err := router.WakeDecode(&a, args, "yaml")
if err != nil {
return nil, fmt.Errorf("invalid args, %w", err)
}

h := &RvLookup{
ctx: ctx,
next: next,
args: a,
}

b, err := otter.NewBuilder[netip.Addr, string](a.CacheSize)
if err != nil {
return nil, err
}
c, err := b.WithTTL(time.Duration(a.CacheTtl) * time.Second).
Cost(func(key netip.Addr, value string) uint32 {
return uint32(unsafe.Sizeof(key)) + uint32(unsafe.Sizeof(value)) + uint32(len(value))
}).
Build()
if err != nil {
return nil, err
}
h.cache = c

sb, err := otter.NewBuilder[string, string](a.CacheSize)
if err != nil {
return nil, err
}
sc, err := sb.WithTTL(time.Duration(a.CacheTtl) * time.Second).
Cost(func(key, value string) uint32 { return uint32(unsafe.Sizeof(value)) + uint32(len(value)) }).
Build()
if err != nil {
return nil, err
}
h.strCache = sc

h.cacheSize = prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "cache_size",
Help: "The number of ptr (ip name pair) that is cached",
}, func() float64 { return float64(c.Size()) })
return h, nil
}

func (h *RvLookup) Handle(ctx context.Context, q *router.QueryCtx) {
qq := q.Question
if qq.Type == dnsmsg.TypePTR && qq.Class == dnsmsg.ClassINET {
rr := h.lookup(qq.Name)
if rr != nil {
resp := dnsmsg.NewMsg()
resp.Questions = append(resp.Questions, q.Question.Copy())
resp.Answers = append(resp.Answers, rr)
q.SetRespFrom(resp, mwName)
return
}
}
h.next.Handle(ctx, q)

if resp := q.Resp(); resp != nil && qq.Class == dnsmsg.ClassINET && (qq.Type == dnsmsg.TypeA || qq.Type == dnsmsg.TypeAAAA) {
for _, rr := range resp.Answers {
switch rr := rr.(type) {
case *dnsmsg.A:
h.saveNameAddr(rr.Name, netip.AddrFrom4(rr.A))
case *dnsmsg.AAAA:
h.saveNameAddr(rr.Name, netip.AddrFrom16(rr.AAAA).Unmap())
}
}
}
}

func (h *RvLookup) lookup(n dnsmsg.Name) *dnsmsg.NAMEResource {
addr, err := parsePtr(n)
if err != nil {
return nil
}
addr = addr.Unmap()

raw, ok := h.cache.Get(addr)
if ok {
ptr := dnsmsg.NewNAME()
ptr.Name.CopyFrom(n)
ptr.Class = dnsmsg.ClassINET
ptr.Type = dnsmsg.TypePTR
ptr.TTL = 10
err := dnsmsg.ParseNameRaw(&ptr.NameData, raw)
if err != nil {
h.ctx.Logger.Err(err).Msg("internel err: invalid raw name")
return nil
}
return ptr
}
return nil
}

var (
errInvalidPtrZone = errors.New("invalid ptr zone")
errInvalidLabelLen = errors.New("invalid label len")
)

func parsePtr(n dnsmsg.Name) (netip.Addr, error) {
labels := n.Labels()
if len(labels) < 2 {
return netip.Addr{}, errInvalidLabelLen
}

if string(labels[len(labels)-1]) != "arpa" {
return netip.Addr{}, errInvalidPtrZone
}

switch string(labels[len(labels)-2]) {
case "in-addr":
return parseIpv4(labels[:len(labels)-2])
case "ip6":
return parseIpv16(labels[:len(labels)-2])
default:
return netip.Addr{}, errInvalidPtrZone
}
}

func parseIpv4(labels [][]byte) (netip.Addr, error) {
if len(labels) != 4 {
return netip.Addr{}, errInvalidLabelLen
}
var buf [4]byte
for i := 0; i < 4; i++ {
label := labels[3-i]
u, err := strconv.ParseUint(utils.Bytes2StrUnsafe(label), 10, 8)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid label %s: %s", label, err)
}
buf[i] = byte(u)
}
return netip.AddrFrom4(buf), nil
}

func parseIpv16(labels [][]byte) (netip.Addr, error) {
if len(labels) != 32 {
return netip.Addr{}, errInvalidLabelLen
}
var buf [16]byte
for i := 0; i < 32; i++ {
label := labels[31-i]
u, err := strconv.ParseUint(utils.Bytes2StrUnsafe(label), 16, 4)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid label %s: %s", label, err)
}
buf[(i / 2)] += byte(u) << ((1 - i%2) * 4)
}
return netip.AddrFrom16(buf), nil
}

func (h *RvLookup) saveNameAddr(name dnsmsg.Name, addr netip.Addr) {
s, ok := h.strCache.Get(utils.Bytes2StrUnsafe(name.Data()))
if !ok { // concurrent set is possible, but it's ok
s = string(name.Data())
}
h.cache.Set(addr, s)
}
46 changes: 46 additions & 0 deletions app/router/middleware/rvlookup/rvlookup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package rvlookup

import (
"net/netip"
"testing"

"github.com/IrineSistiana/mosproxy/pkg/dnsmsg"
"github.com/stretchr/testify/require"
)

func Test_parsePtr(t *testing.T) {
r := require.New(t)
testFn := func(n string, want string) {
var name dnsmsg.Name
err := name.Parse(n)
r.NoError(err)
wantAddr, err := netip.ParseAddr(want)
r.NoError(err)

addr, err := parsePtr(name)
r.NoError(err)
r.Truef(addr == wantAddr, "got=%s, want=%s", addr, want)
}

testFn("0.0.0.0.in-addr.arpa", "0.0.0.0")
testFn("4.4.8.8.in-addr.arpa", "8.8.4.4")
testFn("b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa", "2001:db8::567:89ab")

testErrFn := func(n string) {
var name dnsmsg.Name
err := name.Parse(n)
r.NoError(err)
addr, err := parsePtr(name)
r.Error(err)
r.False(addr.IsValid())
}

testErrFn("x.x.x")
testErrFn("0.0.0.0.in-addr.arpa.xxx") // invalid domain
testErrFn("0.0.0.0.in-addrxxx.arpa") // invalid domain
testErrFn("0.0.0.in-addr.arpa") // 3 labels
testErrFn("0.0.0.0.0.in-addr.arpa") // 5 labels
testErrFn("b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6_xxx.arpa") // invalid domain
testErrFn("9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa") // 30 labels
testErrFn("0.0.b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa") // 34 labels
}

0 comments on commit e3667da

Please sign in to comment.