use sync map instead of local DNS cache (#12925)

also enable PreferGo only resolver for FIPS builds
This commit is contained in:
Harshavardhana 2021-08-10 21:20:09 -07:00 committed by GitHub
parent 40a2fa8e81
commit 3becee9e5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 60 deletions

View File

@ -78,11 +78,6 @@ func DialContextWithDNSCache(cache *DNSCache, baseDialCtx DialContext) DialConte
} }
} }
const (
// cacheSize is initial size of addr and IP list cache map.
cacheSize = 64
)
// defaultFreq is default frequency a resolver refreshes DNS cache. // defaultFreq is default frequency a resolver refreshes DNS cache.
var ( var (
defaultFreq = 3 * time.Second defaultFreq = 3 * time.Second
@ -91,13 +86,11 @@ var (
// DNSCache is DNS cache resolver which cache DNS resolve results in memory. // DNSCache is DNS cache resolver which cache DNS resolve results in memory.
type DNSCache struct { type DNSCache struct {
sync.RWMutex resolver *net.Resolver
lookupHostFn func(ctx context.Context, host string) ([]string, error)
lookupTimeout time.Duration lookupTimeout time.Duration
loggerOnce func(ctx context.Context, err error, id interface{}, errKind ...interface{}) loggerOnce func(ctx context.Context, err error, id interface{}, errKind ...interface{})
cache map[string][]string cache sync.Map
doneOnce sync.Once doneOnce sync.Once
doneCh chan struct{} doneCh chan struct{}
} }
@ -114,11 +107,18 @@ func NewDNSCache(freq time.Duration, lookupTimeout time.Duration, loggerOnce fun
lookupTimeout = defaultLookupTimeout lookupTimeout = defaultLookupTimeout
} }
// PreferGo controls whether Go's built-in DNS resolver
// is preferred on platforms where it's available, since
// we do not compile with CGO, FIPS builds are CGO based
// enable this to enforce Go resolver.
defaultResolver := &net.Resolver{
PreferGo: true,
}
r := &DNSCache{ r := &DNSCache{
lookupHostFn: net.DefaultResolver.LookupHost, resolver: defaultResolver,
lookupTimeout: lookupTimeout, lookupTimeout: lookupTimeout,
loggerOnce: loggerOnce, loggerOnce: loggerOnce,
cache: make(map[string][]string, cacheSize),
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
} }
@ -149,38 +149,32 @@ func NewDNSCache(freq time.Duration, lookupTimeout time.Duration, loggerOnce fun
// LookupHost lookups address list from DNS server, persist the results // LookupHost lookups address list from DNS server, persist the results
// in-memory cache. `Fetch` is used to obtain the values for a given host. // in-memory cache. `Fetch` is used to obtain the values for a given host.
func (r *DNSCache) LookupHost(ctx context.Context, host string) ([]string, error) { func (r *DNSCache) LookupHost(ctx context.Context, host string) ([]string, error) {
addrs, err := r.lookupHostFn(ctx, host) addrs, err := r.resolver.LookupHost(ctx, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r.Lock() r.cache.Store(host, addrs)
r.cache[host] = addrs
r.Unlock()
return addrs, nil return addrs, nil
} }
// Fetch fetches IP list from the cache. If IP list of the given addr is not in the cache, // Fetch fetches IP list from the cache. If IP list of the given addr is not in the cache,
// then it lookups from DNS server by `Lookup` function. // then it lookups from DNS server by `Lookup` function.
func (r *DNSCache) Fetch(ctx context.Context, host string) ([]string, error) { func (r *DNSCache) Fetch(ctx context.Context, host string) ([]string, error) {
r.RLock() addrs, ok := r.cache.Load(host)
addrs, ok := r.cache[host]
r.RUnlock()
if ok { if ok {
return addrs, nil return addrs.([]string), nil
} }
return r.LookupHost(ctx, host) return r.LookupHost(ctx, host)
} }
// Refresh refreshes IP list cache, automatically. // Refresh refreshes IP list cache, automatically.
func (r *DNSCache) Refresh() { func (r *DNSCache) Refresh() {
r.RLock() var hosts []string
hosts := make([]string, 0, len(r.cache)) r.cache.Range(func(k, v interface{}) bool {
for host := range r.cache { hosts = append(hosts, k.(string))
hosts = append(hosts, host) return true
} })
r.RUnlock()
for _, host := range hosts { for _, host := range hosts {
ctx, cancelF := context.WithTimeout(context.Background(), r.lookupTimeout) ctx, cancelF := context.WithTimeout(context.Background(), r.lookupTimeout)

View File

@ -23,6 +23,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"runtime"
"testing" "testing"
"time" "time"
) )
@ -42,16 +43,12 @@ func testDNSCache(t *testing.T) *DNSCache {
} }
func TestDialContextWithDNSCache(t *testing.T) { func TestDialContextWithDNSCache(t *testing.T) {
resolver := &DNSCache{ resolver := &DNSCache{}
cache: map[string][]string{ resolver.cache.Store("play.min.io", []string{
"play.min.io": { "127.0.0.1",
"127.0.0.1", "127.0.0.2",
"127.0.0.2", "127.0.0.3",
"127.0.0.3", })
},
},
}
cases := []struct { cases := []struct {
permF func(n int) []int permF func(n int) []int
dialF DialContext dialF DialContext
@ -113,15 +110,12 @@ func TestDialContextWithDNSCacheRand(t *testing.T) {
rand.Seed(1) rand.Seed(1)
}() }()
resolver := &DNSCache{ resolver := &DNSCache{}
cache: map[string][]string{ resolver.cache.Store("play.min.io", []string{
"play.min.io": { "127.0.0.1",
"127.0.0.1", "127.0.0.2",
"127.0.0.2", "127.0.0.3",
"127.0.0.3", })
},
},
}
count := make(map[string]int) count := make(map[string]int)
dialF := func(ctx context.Context, network, addr string) (net.Conn, error) { dialF := func(ctx context.Context, network, addr string) (net.Conn, error) {
@ -153,33 +147,39 @@ func TestDialContextWithDNSCacheScenario1(t *testing.T) {
// Verify if the host lookup function failed to return addresses // Verify if the host lookup function failed to return addresses
func TestDialContextWithDNSCacheScenario2(t *testing.T) { func TestDialContextWithDNSCacheScenario2(t *testing.T) {
if runtime.GOOS == "windows" {
// Windows doesn't use Dial to connect
// so there is no way this test will work
// as expected.
t.Skip()
}
res := testDNSCache(t) res := testDNSCache(t)
originalFunc := res.lookupHostFn originalResolver := res.resolver
defer func() { defer func() {
res.lookupHostFn = originalFunc res.resolver = originalResolver
}() }()
res.lookupHostFn = func(ctx context.Context, host string) ([]string, error) { res.resolver = &net.Resolver{
return nil, fmt.Errorf("err") PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, fmt.Errorf("err")
},
} }
if _, err := DialContextWithDNSCache(res, nil)(context.Background(), "tcp", "min.io:443"); err == nil { if _, err := DialContextWithDNSCache(res, nil)(context.Background(), "tcp", "min.io:443"); err == nil {
t.Fatalf("exect to fail") t.Fatalf("expect to fail")
} }
} }
// Verify we always return the first error from net.Dial failure // Verify we always return the first error from net.Dial failure
func TestDialContextWithDNSCacheScenario3(t *testing.T) { func TestDialContextWithDNSCacheScenario3(t *testing.T) {
resolver := &DNSCache{ resolver := &DNSCache{}
cache: map[string][]string{ resolver.cache.Store("min.io", []string{
"min.io": { "1.1.1.1",
"1.1.1.1", "2.2.2.2",
"2.2.2.2", "3.3.3.3",
"3.3.3.3", })
},
},
}
origFunc := randPerm origFunc := randPerm
randPerm = func(n int) []int { randPerm = func(n int) []int {
return []int{0, 1, 2} return []int{0, 1, 2}