Skip to content

Commit

Permalink
Refine geodata decoder (#322)
Browse files Browse the repository at this point in the history
* Refactor: geodata loader
* Feat: add membench & benchmark for geodata decoder
* Feat: log router memory usage when initialization
  • Loading branch information
Loyalsoldier authored May 6, 2021
1 parent 9c437b9 commit a9ce6d4
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 43 deletions.
47 changes: 42 additions & 5 deletions common/geodata/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@ import (
"io/fs"
"os"
"path/filepath"
"runtime"
"testing"

"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/common/geodata"
)

const (
geoipURL = "https://raw.githubusercontent.com/v2fly/geoip/release/geoip.dat"
geositeURL = "https://raw.githubusercontent.com/v2fly/domain-list-community/release/dlc.dat"
)

func init() {
const (
geoipURL = "https://raw.githubusercontent.com/v2fly/geoip/release/geoip.dat"
geositeURL = "https://raw.githubusercontent.com/v2fly/domain-list-community/release/dlc.dat"
)

wd, err := os.Getwd()
common.Must(err)

Expand Down Expand Up @@ -66,3 +67,39 @@ func TestDecodeGeoSite(t *testing.T) {
t.Errorf("failed to load geosite:test, expected: %v, got: %v", expected, result)
}
}

func BenchmarkLoadGeoIP(b *testing.B) {
m1 := runtime.MemStats{}
m2 := runtime.MemStats{}

loader := geodata.GetGeodataLoader()

runtime.ReadMemStats(&m1)
cn, _ := loader.LoadGeoIP("cn")
private, _ := loader.LoadGeoIP("private")
runtime.KeepAlive(cn)
runtime.KeepAlive(private)
runtime.ReadMemStats(&m2)

b.ReportMetric(float64(m2.Alloc-m1.Alloc)/1024, "KiB(GeoIP-Alloc)")
b.ReportMetric(float64(m2.TotalAlloc-m1.TotalAlloc)/1024/1024, "MiB(GeoIP-TotalAlloc)")
}

func BenchmarkLoadGeoSite(b *testing.B) {
m3 := runtime.MemStats{}
m4 := runtime.MemStats{}

loader := geodata.GetGeodataLoader()

runtime.ReadMemStats(&m3)
cn, _ := loader.LoadGeoSite("cn")
notcn, _ := loader.LoadGeoSite("geolocation-!cn")
private, _ := loader.LoadGeoSite("private")
runtime.KeepAlive(cn)
runtime.KeepAlive(notcn)
runtime.KeepAlive(private)
runtime.ReadMemStats(&m4)

b.ReportMetric(float64(m4.Alloc-m3.Alloc)/1024/1024, "MiB(GeoSite-Alloc)")
b.ReportMetric(float64(m4.TotalAlloc-m3.TotalAlloc)/1024/1024, "MiB(GeoSite-TotalAlloc)")
}
36 changes: 0 additions & 36 deletions common/geodata/load.go

This file was deleted.

52 changes: 52 additions & 0 deletions common/geodata/loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package geodata

import (
"runtime"

v2router "github.com/v2fly/v2ray-core/v4/app/router"
)

type geodataLoader interface {
LoadIP(filename, country string) ([]*v2router.CIDR, error)
LoadSite(filename, list string) ([]*v2router.Domain, error)
LoadGeoIP(country string) ([]*v2router.CIDR, error)
LoadGeoSite(list string) ([]*v2router.Domain, error)
}

func GetGeodataLoader() geodataLoader {
return &geodataCache{
make(map[string]*v2router.GeoIP),
make(map[string]*v2router.GeoSite),
}
}

type geodataCache struct {
GeoIPCache
GeoSiteCache
}

func (g *geodataCache) LoadIP(filename, country string) ([]*v2router.CIDR, error) {
geoip, err := g.GeoIPCache.Unmarshal(filename, country)
if err != nil {
return nil, err
}
runtime.GC()
return geoip.Cidr, nil
}

func (g *geodataCache) LoadSite(filename, list string) ([]*v2router.Domain, error) {
geosite, err := g.GeoSiteCache.Unmarshal(filename, list)
if err != nil {
return nil, err
}
runtime.GC()
return geosite.Domain, nil
}

func (g *geodataCache) LoadGeoIP(country string) ([]*v2router.CIDR, error) {
return g.LoadIP("geoip.dat", country)
}

func (g *geodataCache) LoadGeoSite(list string) ([]*v2router.Domain, error) {
return g.LoadSite("geosite.dat", list)
}
26 changes: 24 additions & 2 deletions tunnel/router/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net"
"regexp"
"runtime"
"strconv"
"strings"

Expand Down Expand Up @@ -264,6 +265,11 @@ func loadCode(cfg *Config, prefix string) []codeInfo {
}

func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) {
m1 := runtime.MemStats{}
m2 := runtime.MemStats{}
m3 := runtime.MemStats{}
m4 := runtime.MemStats{}

cfg := config.FromContext(ctx, Name).(*Config)
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
Expand Down Expand Up @@ -304,10 +310,14 @@ func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) {
return nil, common.NewError("unknown strategy: " + cfg.Router.DomainStrategy)
}

runtime.ReadMemStats(&m1)

geodataLoader := geodata.GetGeodataLoader()

ipCode := loadCode(cfg, "geoip:")
for _, c := range ipCode {
code := c.code
cidrs, err := geodata.LoadGeoIP(code)
cidrs, err := geodataLoader.LoadGeoIP(code)
if err != nil {
log.Error(err)
} else {
Expand All @@ -316,6 +326,8 @@ func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) {
}
}

runtime.ReadMemStats(&m2)

siteCode := loadCode(cfg, "geosite:")
for _, c := range siteCode {
code := c.code
Expand All @@ -334,7 +346,7 @@ func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) {
continue
}

domainList, err := geodata.LoadGeoSite(code)
domainList, err := geodataLoader.LoadGeoSite(code)
if err != nil {
log.Error(err)
} else {
Expand All @@ -360,6 +372,8 @@ func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) {
}
}

runtime.ReadMemStats(&m3)

domainInfo := loadCode(cfg, "domain:")
for _, info := range domainInfo {
client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{
Expand Down Expand Up @@ -433,5 +447,13 @@ func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) {
}

log.Info("router client created")

runtime.ReadMemStats(&m4)

log.Debugf("GeoIP rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m2.Alloc-m1.Alloc), common.HumanFriendlyTraffic(m2.TotalAlloc-m1.TotalAlloc))
log.Debugf("GeoSite rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m3.Alloc-m2.Alloc), common.HumanFriendlyTraffic(m3.TotalAlloc-m2.TotalAlloc))
log.Debugf("Plaintext rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m4.Alloc-m3.Alloc), common.HumanFriendlyTraffic(m4.TotalAlloc-m3.TotalAlloc))
log.Debugf("Total(router) -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m4.Alloc-m1.Alloc), common.HumanFriendlyTraffic(m4.TotalAlloc-m1.TotalAlloc))

return client, nil
}

0 comments on commit a9ce6d4

Please sign in to comment.