From 3becee9e5d0c9c41baaea2afa9a3531b46c693e7 Mon Sep 17 00:00:00 2001 From: Harshavardhana Date: Tue, 10 Aug 2021 21:20:09 -0700 Subject: [PATCH] use sync map instead of local DNS cache (#12925) also enable PreferGo only resolver for FIPS builds --- internal/http/dial_dnscache.go | 46 +++++++++---------- internal/http/dial_dnscache_test.go | 68 ++++++++++++++--------------- 2 files changed, 54 insertions(+), 60 deletions(-) diff --git a/internal/http/dial_dnscache.go b/internal/http/dial_dnscache.go index 60b7011db..667113cf9 100644 --- a/internal/http/dial_dnscache.go +++ b/internal/http/dial_dnscache.go @@ -78,11 +78,6 @@ func DialContextWithDNSCache(cache *DNSCache, baseDialCtx DialContext) DialConte } } -const ( - // cacheSize is initial size of addr and IP list cache map. - cacheSize = 64 -) - // defaultFreq is default frequency a resolver refreshes DNS cache. var ( defaultFreq = 3 * time.Second @@ -91,13 +86,11 @@ var ( // DNSCache is DNS cache resolver which cache DNS resolve results in memory. type DNSCache struct { - sync.RWMutex - - lookupHostFn func(ctx context.Context, host string) ([]string, error) + resolver *net.Resolver lookupTimeout time.Duration loggerOnce func(ctx context.Context, err error, id interface{}, errKind ...interface{}) - cache map[string][]string + cache sync.Map doneOnce sync.Once doneCh chan struct{} } @@ -114,11 +107,18 @@ func NewDNSCache(freq time.Duration, lookupTimeout time.Duration, loggerOnce fun lookupTimeout = defaultLookupTimeout } + // PreferGo controls whether Go's built-in DNS resolver + // is preferred on platforms where it's available, since + // we do not compile with CGO, FIPS builds are CGO based + // enable this to enforce Go resolver. + defaultResolver := &net.Resolver{ + PreferGo: true, + } + r := &DNSCache{ - lookupHostFn: net.DefaultResolver.LookupHost, + resolver: defaultResolver, lookupTimeout: lookupTimeout, loggerOnce: loggerOnce, - cache: make(map[string][]string, cacheSize), doneCh: make(chan struct{}), } @@ -149,38 +149,32 @@ func NewDNSCache(freq time.Duration, lookupTimeout time.Duration, loggerOnce fun // LookupHost lookups address list from DNS server, persist the results // in-memory cache. `Fetch` is used to obtain the values for a given host. func (r *DNSCache) LookupHost(ctx context.Context, host string) ([]string, error) { - addrs, err := r.lookupHostFn(ctx, host) + addrs, err := r.resolver.LookupHost(ctx, host) if err != nil { return nil, err } - r.Lock() - r.cache[host] = addrs - r.Unlock() - + r.cache.Store(host, addrs) return addrs, nil } // Fetch fetches IP list from the cache. If IP list of the given addr is not in the cache, // then it lookups from DNS server by `Lookup` function. func (r *DNSCache) Fetch(ctx context.Context, host string) ([]string, error) { - r.RLock() - addrs, ok := r.cache[host] - r.RUnlock() + addrs, ok := r.cache.Load(host) if ok { - return addrs, nil + return addrs.([]string), nil } return r.LookupHost(ctx, host) } // Refresh refreshes IP list cache, automatically. func (r *DNSCache) Refresh() { - r.RLock() - hosts := make([]string, 0, len(r.cache)) - for host := range r.cache { - hosts = append(hosts, host) - } - r.RUnlock() + var hosts []string + r.cache.Range(func(k, v interface{}) bool { + hosts = append(hosts, k.(string)) + return true + }) for _, host := range hosts { ctx, cancelF := context.WithTimeout(context.Background(), r.lookupTimeout) diff --git a/internal/http/dial_dnscache_test.go b/internal/http/dial_dnscache_test.go index 00ae0357b..18e6edcf7 100644 --- a/internal/http/dial_dnscache_test.go +++ b/internal/http/dial_dnscache_test.go @@ -23,6 +23,7 @@ import ( "fmt" "math/rand" "net" + "runtime" "testing" "time" ) @@ -42,16 +43,12 @@ func testDNSCache(t *testing.T) *DNSCache { } func TestDialContextWithDNSCache(t *testing.T) { - resolver := &DNSCache{ - cache: map[string][]string{ - "play.min.io": { - "127.0.0.1", - "127.0.0.2", - "127.0.0.3", - }, - }, - } - + resolver := &DNSCache{} + resolver.cache.Store("play.min.io", []string{ + "127.0.0.1", + "127.0.0.2", + "127.0.0.3", + }) cases := []struct { permF func(n int) []int dialF DialContext @@ -113,15 +110,12 @@ func TestDialContextWithDNSCacheRand(t *testing.T) { rand.Seed(1) }() - resolver := &DNSCache{ - cache: map[string][]string{ - "play.min.io": { - "127.0.0.1", - "127.0.0.2", - "127.0.0.3", - }, - }, - } + resolver := &DNSCache{} + resolver.cache.Store("play.min.io", []string{ + "127.0.0.1", + "127.0.0.2", + "127.0.0.3", + }) count := make(map[string]int) dialF := func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -153,33 +147,39 @@ func TestDialContextWithDNSCacheScenario1(t *testing.T) { // Verify if the host lookup function failed to return addresses func TestDialContextWithDNSCacheScenario2(t *testing.T) { + if runtime.GOOS == "windows" { + // Windows doesn't use Dial to connect + // so there is no way this test will work + // as expected. + t.Skip() + } + res := testDNSCache(t) - originalFunc := res.lookupHostFn + originalResolver := res.resolver defer func() { - res.lookupHostFn = originalFunc + res.resolver = originalResolver }() - res.lookupHostFn = func(ctx context.Context, host string) ([]string, error) { - return nil, fmt.Errorf("err") + res.resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, fmt.Errorf("err") + }, } if _, err := DialContextWithDNSCache(res, nil)(context.Background(), "tcp", "min.io:443"); err == nil { - t.Fatalf("exect to fail") + t.Fatalf("expect to fail") } } // Verify we always return the first error from net.Dial failure func TestDialContextWithDNSCacheScenario3(t *testing.T) { - resolver := &DNSCache{ - cache: map[string][]string{ - "min.io": { - "1.1.1.1", - "2.2.2.2", - "3.3.3.3", - }, - }, - } - + resolver := &DNSCache{} + resolver.cache.Store("min.io", []string{ + "1.1.1.1", + "2.2.2.2", + "3.3.3.3", + }) origFunc := randPerm randPerm = func(n int) []int { return []int{0, 1, 2}