mullvad-best-server/main.go
Thorsten Schubert 086106da30
All checks were successful
/ build (push) Successful in 44s
Add verbose output
2024-05-08 20:18:43 +02:00

341 lines
8.9 KiB
Go

package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"runtime"
"sort"
"strings"
"sync"
"time"
probe "github.com/prometheus-community/pro-bing"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
// Data types: the subset of the Mullvad relay API document decoded by this
// tool, plus the tool's own configuration and result types.
type (
	// EndpointResponse is the top-level JSON document returned by the relay
	// API. Only its "wireguard" section is decoded.
	EndpointResponse struct {
		Wireguard Wireguard `json:"wireguard"`
	}

	// Wireguard holds the list of WireGuard relays from the API response.
	Wireguard struct {
		Relays []Relay `json:"relays"`
	}

	// Relay describes one WireGuard relay as reported by the API. The field
	// order is also the key order of json.Marshal output (used by -json), so
	// keep it stable.
	Relay struct {
		Hostname         string `json:"hostname"`
		Location         string `json:"location"`
		Provider         string `json:"provider"`
		Ipv4AddrIn       string `json:"ipv4_addr_in"`
		PublicKey        string `json:"public_key"`
		Ipv6AddrIn       string `json:"ipv6_addr_in"`
		Weight           int    `json:"weight"`
		Active           bool   `json:"active"`
		Owned            bool   `json:"owned"`
		Stboot           bool   `json:"stboot"`
		IncludeInCountry bool   `json:"include_in_country"`
		SameIP           bool   `json:"same_ip"`
	}

	// Config carries the options parsed from the command line.
	Config struct {
		Country  string        // country prefix of Relay.Location; empty means any
		City     string        // city suffix of Relay.Location; empty means any
		Provider string        // exact Relay.Provider; empty means any
		Interval time.Duration // delay between ping packets
		Rounds   int           // ping packets sent per relay
		Threads  int           // number of concurrent measuring workers
		Timeout  time.Duration // per-packet share of the total ping deadline
		MinRtt   bool          // rank by minimum RTT instead of average RTT
		Warmup   bool          // send one warmup ping before measuring
		Verbose  bool          // print every measured relay, not just the winner
	}

	// Measurement pairs a relay with its measured latency. Callers build it
	// positionally, so Relay must stay first and Latency second.
	Measurement struct {
		Relay   *Relay
		Latency time.Duration
	}
)
// main is the command-line entry point. It does the following:
//   - parses the flags and downloads the Mullvad relay list;
//   - keeps the relays matching the -country, -city and -provider filters;
//   - pings every remaining relay;
//   - prints the relay with the lowest latency to stdout: its hostname by
//     default, or the full relay object as JSON with -json.
//
// Invalid flag values and the absence of any reachable matching relay are
// fatal.
func main() {
	var (
		apiTimeFlag  = flag.Duration("api-timeout", 10*time.Second, "API timeout")
		countryFlag  = flag.String("country", "", "Relay country code (ISO 3166 ALPHA-2), e.g. 'de' for Germany")
		cityFlag     = flag.String("city", "", "City code of the relay location (the part after the country code), e.g. 'lon' for London")
		intervalFlag = flag.Duration("interval", 200*time.Millisecond, "Packet interval")
		jsonFlag     = flag.Bool("json", false, "Output result as JSON")
		logFlag      = flag.String("log", "warn", "Log level. Allowed values: trace, debug, info, warn, error, fatal, panic")
		numPktFlag   = flag.Int("packets", 16, "Number of packets per relay")
		providerFlag = flag.String("provider", "", "Relay provider, e.g. 31173 for mullvad owned relays")
		rttFlag      = flag.String("rtt", "avg", "Minimum (min) or Average (avg) time for RTT latency")
		threadsFlag  = flag.Int("threads", runtime.NumCPU(), "Number of relays processed concurrently (batch)")
		timeoutFlag  = flag.Duration("timeout", 200*time.Millisecond, "Ping timeout")
		warmupFlag   = flag.Bool("warmup", true, "Warmup phase")
		verboseFlag  = flag.Bool("verbose", false, "More detailed output")
	)
	flag.Parse()

	// Set up logging before validating anything else so that every fatal
	// error below is rendered by the console writer on stderr.
	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339})
	level, err := zerolog.ParseLevel(*logFlag)
	if err != nil {
		log.Fatal().Err(err).Msg("Unable to set log level")
	}
	zerolog.SetGlobalLevel(level)

	var minRtt bool
	switch strings.ToLower(*rttFlag) {
	case "min":
		minRtt = true
	case "avg":
		minRtt = false
	default:
		log.Fatal().Msg("rtt arguments must be either min or avg")
	}

	// Reject values the measurement pipeline cannot work with.
	//   - Fewer than one packet, or a non-positive interval, produces a
	//     zero/negative ping deadline or ping interval.
	//   - A negative timeout shrinks the per-relay deadline.
	//   - Fewer than one thread leaves no worker to consume relays. Negative
	//     counts even panic in make/WaitGroup.Add.
	switch {
	case *numPktFlag < 1:
		log.Fatal().Int("packets", *numPktFlag).Msg("packets must be at least 1")
	case *threadsFlag < 1:
		log.Fatal().Int("threads", *threadsFlag).Msg("threads must be at least 1")
	case *intervalFlag <= 0:
		log.Fatal().Dur("interval", *intervalFlag).Msg("interval must be positive")
	case *timeoutFlag < 0:
		log.Fatal().Dur("timeout", *timeoutFlag).Msg("timeout must not be negative")
	}

	config := &Config{
		Country:  *countryFlag,
		City:     *cityFlag,
		Interval: *intervalFlag,
		MinRtt:   minRtt,
		Provider: *providerFlag,
		Rounds:   *numPktFlag,
		Threads:  *threadsFlag,
		Timeout:  *timeoutFlag,
		Warmup:   *warmupFlag,
		Verbose:  *verboseFlag,
	}

	// Produces Relays
	relays := queryRelaysAPI(*apiTimeFlag)
	log.Debug().Int("numRelays", len(relays)).Msg("Total number of relays")
	relays = filterRelays(relays, config)
	log.Debug().Int("numFiltered", len(relays)).Msg("Number of filtered relays")

	// Consumes Relays concurrently
	latency := findFastestRelay(relays, config)
	if latency.Relay == nil {
		// The miss may be caused by any of the filters (or by every matching
		// relay losing packets), so report all of them.
		log.Fatal().
			Str("country", *countryFlag).
			Str("city", *cityFlag).
			Str("provider", *providerFlag).
			Msg("No reachable relay matching the provided filters found")
		panic("unreachable") // log.Fatal exits; control never continues with a nil relay
	}

	log.Info().Interface("relay", latency.Relay).Dur("latency",
		latency.Latency).Msg("Lowest latency relay found")

	if *jsonFlag {
		relayJSON, err := json.Marshal(latency.Relay)
		if err != nil {
			log.Fatal().Err(err).Msg("Could not marshal relay information to JSON")
		}
		fmt.Println(string(relayJSON))
	} else {
		fmt.Println(latency.Relay.Hostname)
	}
}
// queryRelaysAPI downloads the relay list from the Mullvad API and returns
// its WireGuard relays. The whole request, including reading the body, is
// bounded by timeout through the request context.
//
// Every failure is fatal and terminates the program:
//   - a request or transport error;
//   - a non-200 status;
//   - an error reading the body;
//   - malformed JSON.
func queryRelaysAPI(timeout time.Duration) []Relay {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet,
		"https://api.mullvad.net/app/v1/relays", nil)
	if err != nil {
		log.Fatal().Err(err).Msg("Could not build relay API request")
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal().Err(err).Msg("Could not retrieve relays from endpoint")
	}
	defer resp.Body.Close()

	// An error page (rate limit, outage, ...) must not be parsed as a relay
	// list; it would silently decode to zero relays.
	if resp.StatusCode != http.StatusOK {
		log.Fatal().Int("status", resp.StatusCode).Msg("Unexpected status from relay API")
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		// The event must be finalized with Msg/Send. Without it zerolog
		// neither writes the event nor exits.
		log.Fatal().Err(err).Msg("Could not read relay API response")
	}

	var endpoint EndpointResponse
	if err := json.Unmarshal(respBody, &endpoint); err != nil {
		log.Fatal().Err(err).Msg("Could not parse JSON")
	}
	return endpoint.Wireguard.Relays
}
// filterRelays returns the active relays that satisfy every non-empty filter
// in conf:
//   - Country: relay.Location starts with "<country>-"
//   - City: relay.Location ends with "-<city>"
//   - Provider: relay.Provider equals conf.Provider exactly
//
// Country and city are compared case-insensitively, so "DE" and "de" select
// the same relays.
//
// Filtering happens in place. The result shares the backing array of relays
// and overwrites its elements, so callers should use only the returned slice
// afterwards.
func filterRelays(relays []Relay, conf *Config) []Relay {
	countryPrefix := strings.ToLower(conf.Country) + "-"
	citySuffix := "-" + strings.ToLower(conf.City)

	// Writing index k while reading index i >= k is safe for in-place filtering.
	filtered := relays[:0]
	for _, relay := range relays {
		if !relay.Active {
			continue
		}
		location := strings.ToLower(relay.Location)
		if conf.Country != "" && !strings.HasPrefix(location, countryPrefix) {
			continue
		}
		if conf.Provider != "" && relay.Provider != conf.Provider {
			continue
		}
		if conf.City != "" && !strings.HasSuffix(location, citySuffix) {
			continue
		}
		filtered = append(filtered, relay)
		log.Trace().Interface("appending", relay).Msg("Appending relay")
	}
	// Log the count after filtering, not the input size.
	log.Trace().Int("filtered", len(filtered)).Msg("Filtered relays")
	return filtered
}
// warmup sends a single ping to the relay before it is measured. The
// original author's intent is to "build ARP cache on the remote", so that
// one-time setup along the path does not skew the first measured packet.
// It is best effort: failures are logged at debug level and otherwise
// ignored.
func warmup(relay Relay) {
	// Check the constructor error before touching the pinger at all.
	pinger, err := probe.NewPinger(relay.Ipv4AddrIn)
	if err != nil {
		log.Debug().Err(err).Send()
		return
	}
	pinger.Timeout = 1 * time.Second
	pinger.Count = 1
	pinger.RecordRtts = false
	if err := pinger.Run(); err != nil {
		log.Debug().Err(err).Send()
	}
}
// measureLatencyWorker consumes relays from in until the channel is closed.
// It pings each relay conf.Rounds times and sends a Measurement to out for
// every relay that answered all packets.
//
// The reported latency is the minimum RTT when conf.MinRtt is set, otherwise
// the average RTT. A relay that cannot be pinged or loses packets is skipped
// (logged at debug level), and the worker moves on to the next relay.
//
// wg.Done is called when the worker exits. wid only tags this worker's log
// output.
func measureLatencyWorker(wg *sync.WaitGroup, conf *Config, in <-chan Relay, out chan<- Measurement, wid int) {
	defer wg.Done()
	traceWid := func() *zerolog.Event {
		return log.Trace().Int("wid", wid)
	}
	traceWid().Int("input", len(in)).Int("output", len(out)).Msg("measureLatencyWorker entry")

	for relay := range in {
		// Give each iteration its own copy. The Measurement below carries a
		// pointer to it, and under pre-Go-1.22 loop semantics &relay would
		// alias the single loop variable. The next iteration would then
		// overwrite the variable while the consumer still holds that pointer.
		relay := relay
		traceWid().Interface("relay", relay).Msg("Relay received")
		if conf.Warmup {
			warmup(relay)
		}

		pinger, err := probe.NewPinger(relay.Ipv4AddrIn)
		if err != nil {
			// Skip only this relay. Returning here would stop this worker for
			// good. Once every worker has failed, the remaining relays would
			// never be measured and the feeding goroutine would block forever.
			log.Debug().Err(err).Send()
			continue
		}
		// Timeout is the total timeout value for all packets and interval
		// https://github.com/prometheus-community/pro-bing/issues/19
		pinger.Timeout = time.Duration(int64(conf.Rounds) * (int64(conf.Interval) + int64(conf.Timeout)))
		traceWid().Dur("timeout", pinger.Timeout).Msg("Timeout (ms)")
		pinger.Count = conf.Rounds
		pinger.Interval = conf.Interval
		pinger.RecordRtts = false

		// duration stays 0 (= discard) unless OnFinish reports a loss-free run.
		var duration time.Duration
		pinger.OnFinish = func(stats *probe.Statistics) {
			log.Debug().Interface("stats", stats).Send()
			switch {
			case stats.PacketsRecv < stats.PacketsSent:
				log.Debug().Int("wid", wid).Interface("relay", relay).Msg("Packets lost")
				duration = 0
			case conf.MinRtt:
				duration = stats.MinRtt
			default:
				duration = stats.AvgRtt
			}
		}
		pinger.OnRecv = func(pkt *probe.Packet) {
			traceWid().Interface("packet", pkt).Send()
		}

		if err := pinger.Run(); err != nil {
			log.Debug().Err(err).Send()
			continue
		}
		if duration != 0 {
			out <- Measurement{&relay, duration}
		}
	}
	traceWid().Msg("measureLatencyWorker exit")
}
// findFastestRelay pings every relay in relays using a pool of conf.Threads
// worker goroutines (at least one) and returns the lowest-latency result.
//
// The returned Measurement is never nil. Its Relay is nil when no relay
// produced a measurement below the initial 60s ceiling, for example when
// relays is empty or every relay lost packets.
//
// When conf.Verbose is set, every measured relay is printed to stdout as
// "hostname, latency", slowest first. A line of "=" as long as the winner's
// hostname follows; it is omitted when there is no winner.
func findFastestRelay(relays []Relay, conf *Config) *Measurement {
	log.Trace().Int("numRelays", len(relays)).Msg("Num relays")

	// A negative size would panic in make and WaitGroup.Add, and zero
	// workers would never drain inRelay, so always run at least one worker.
	threads := conf.Threads
	if threads < 1 {
		threads = 1
	}

	var wg sync.WaitGroup
	inRelay := make(chan Relay, threads)
	outMeasurement := make(chan Measurement, threads)

	// Pass Relays to worker threads; closing inRelay tells them to exit.
	go func() {
		for _, relay := range relays {
			inRelay <- relay
		}
		close(inRelay)
	}()

	done := make(chan struct{})
	// Only results strictly faster than this ceiling can be selected.
	measurement := &Measurement{nil, 60 * time.Second}

	// Process results until outMeasurement is closed.
	go func() {
		results := make(map[string]time.Duration)
		for res := range outMeasurement {
			log.Debug().Interface("relay", res.Relay).Dur("latency",
				res.Latency).Send()
			if measurement.Latency > res.Latency {
				log.Info().Interface("lowest", res.Relay).Msg("New lowest latency")
				measurement.Relay = res.Relay
				measurement.Latency = res.Latency
			}
			if conf.Verbose {
				results[res.Relay.Hostname] = res.Latency
			}
		}

		if conf.Verbose {
			keys := make([]string, 0, len(results))
			for key := range results {
				keys = append(keys, key)
			}
			// Slowest first, so the fastest relay ends up right above the
			// separator. Equal latencies are ordered by hostname so that
			// the output is deterministic.
			sort.Slice(keys, func(i, j int) bool {
				if results[keys[i]] != results[keys[j]] {
					return results[keys[i]] > results[keys[j]]
				}
				return keys[i] < keys[j]
			})
			for _, key := range keys {
				fmt.Printf("%s, %s\n", key, results[key])
			}
			// With no measurement there is no winner to size the separator
			// after, and dereferencing the nil Relay would panic.
			if measurement.Relay != nil {
				fmt.Printf("%s\n", strings.Repeat("=", len(measurement.Relay.Hostname)))
			}
		}
		// Unblock main thread with empty struct for signaling
		done <- struct{}{}
	}()

	// Spawn the workers; each consumes relays until inRelay is closed.
	wg.Add(threads)
	for i := 0; i < threads; i++ {
		go measureLatencyWorker(&wg, conf, inRelay, outMeasurement, i+1)
	}
	// Workers send their measurements before their deferred wg.Done, so once
	// Wait returns nobody sends on outMeasurement any more and it can be closed.
	wg.Wait()
	close(outMeasurement)
	// Block until the result processor has consumed everything.
	<-done
	return measurement
}