Some more improvements to connection limit

This commit is contained in:
Harshavardhana 2015-04-29 18:01:49 -07:00
parent 58491d22fc
commit fd2203b1b7
2 changed files with 15 additions and 9 deletions

View File

@ -32,20 +32,24 @@ type connLimit struct {
limit int limit int
} }
func (c *connLimit) IsLimitExceeded(ip uint32) bool {
if c.connections[ip] >= c.limit {
return true
}
return false
}
func (c *connLimit) GetUsed(ip uint32) int { func (c *connLimit) GetUsed(ip uint32) int {
return c.connections[ip] return c.connections[ip]
} }
func (c *connLimit) TestAndAdd(ip uint32) bool { func (c *connLimit) Add(ip uint32) {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
count, _ := c.connections[ip] count := c.connections[ip]
if count >= c.limit {
return false
}
count = count + 1 count = count + 1
c.connections[ip] = count c.connections[ip] = count
return true return
} }
func (c *connLimit) Remove(ip uint32) { func (c *connLimit) Remove(ip uint32) {
@ -64,11 +68,13 @@ func (c *connLimit) Remove(ip uint32) {
func (c *connLimit) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (c *connLimit) ServeHTTP(w http.ResponseWriter, req *http.Request) {
host, _, _ := net.SplitHostPort(req.RemoteAddr) host, _, _ := net.SplitHostPort(req.RemoteAddr)
longIP := longIP{net.ParseIP(host)}.IptoUint32() longIP := longIP{net.ParseIP(host)}.IptoUint32()
if !c.TestAndAdd(longIP) { if c.IsLimitExceeded(longIP) {
hosts, _ := net.LookupAddr(uint32ToIP(longIP).String()) hosts, _ := net.LookupAddr(uint32ToIP(longIP).String())
log.Debug.Printf("Offending Host: %s, ConnectionsUSED: %d\n", hosts, c.GetUsed(longIP)) log.Debug.Printf("Connection limit reached - Host: %s, Total Connections: %d\n", hosts, c.GetUsed(longIP))
writeErrorResponse(w, req, ConnectionLimitExceeded, req.URL.Path) writeErrorResponse(w, req, ConnectionLimitExceeded, req.URL.Path)
return
} }
c.Add(longIP)
defer c.Remove(longIP) defer c.Remove(longIP)
c.handler.ServeHTTP(w, req) c.handler.ServeHTTP(w, req)
} }

View File

@ -58,11 +58,11 @@ const (
func writeErrorResponse(w http.ResponseWriter, req *http.Request, errorType int, resource string) { func writeErrorResponse(w http.ResponseWriter, req *http.Request, errorType int, resource string) {
error := getErrorCode(errorType) error := getErrorCode(errorType)
errorResponse := getErrorResponse(error, resource) errorResponse := getErrorResponse(error, resource)
encodedErrorResponse := encodeErrorResponse(errorResponse)
// set headers // set headers
writeErrorHeaders(w) writeErrorHeaders(w)
w.WriteHeader(error.HTTPStatusCode) w.WriteHeader(error.HTTPStatusCode)
// write body // write body
encodedErrorResponse := encodeErrorResponse(errorResponse)
w.Write(encodedErrorResponse) w.Write(encodedErrorResponse)
} }