Adding quotas based upon type

This commit is contained in:
Frederick F. Kautz IV 2015-04-25 18:01:52 -07:00
parent 5a2fb9741a
commit 3a48f9fe4d
2 changed files with 37 additions and 26 deletions

View File

@ -92,5 +92,5 @@ func HTTPHandler(domain string, driver drivers.Driver) http.Handler {
h := validateHandler(conf, ignoreResourcesHandler(mux)) h := validateHandler(conf, ignoreResourcesHandler(mux))
// quota handler is always last // quota handler is always last
return quota.Handler(h, int64(100*1024*1024)) return quota.BandwidthCap(h, int64(100*1024*1024))
} }

View File

@ -17,8 +17,6 @@
package quota package quota
import ( import (
"encoding/binary"
"log"
"net" "net"
"net/http" "net/http"
"sync" "sync"
@ -34,6 +32,8 @@ type quotaMap struct {
} }
func (q *quotaMap) Add(ip uint32, size uint64) bool { func (q *quotaMap) Add(ip uint32, size uint64) bool {
q.Lock()
defer q.Unlock()
currentMinute := time.Now().Unix() / q.duration currentMinute := time.Now().Unix() / q.duration
expiredQuotas := (time.Now().Unix() / q.duration) - 5 expiredQuotas := (time.Now().Unix() / q.duration) - 5
for time := range q.data { for time := range q.data {
@ -41,19 +41,23 @@ func (q *quotaMap) Add(ip uint32, size uint64) bool {
delete(q.data, time) delete(q.data, time)
} }
} }
log.Println(currentMinute)
if _, ok := q.data[currentMinute]; !ok { if _, ok := q.data[currentMinute]; !ok {
q.data[currentMinute] = make(map[uint32]uint64) q.data[currentMinute] = make(map[uint32]uint64)
} }
currentData, _ := q.data[currentMinute][ip] currentData, _ := q.data[currentMinute][ip]
q.data[currentMinute][ip] = currentData + size proposedDataSize := currentData + size
return false if proposedDataSize > q.limit {
return false
}
q.data[currentMinute][ip] = proposedDataSize
return true
} }
// HttpQuotaHandler // HttpQuotaHandler
type httpQuotaHandler struct { type httpQuotaHandler struct {
handler http.Handler handler http.Handler
quotas *quotaMap quotas *quotaMap
adder func(uint64) uint64
} }
type longIP struct { type longIP struct {
@ -61,37 +65,30 @@ type longIP struct {
} }
// []byte to uint32 representation // []byte to uint32 representation
func (p longIP) IptoUint32() uint32 { func (p longIP) IptoUint32() (result uint32) {
ip := p.To4() ip := p.To4()
if ip == nil { if ip == nil {
return 0 return 0
} }
// golang net.IP is BigEndian q0 := uint32(ip[0]) << 24
return binary.BigEndian.Uint32([]byte(ip)) q1 := uint32(ip[1]) << 16
} q2 := uint32(ip[2]) << 8
q3 := uint32(ip[3])
// any uint32 back to IP representation result = q0 + q1 + q2 + q3
func uint32ToIP(ip uint32) net.IP { return
addr := net.IP{0, 0, 0, 0}
binary.BigEndian.PutUint32(addr, ip)
return addr
} }
// ServeHTTP is an http.Handler ServeHTTP method // ServeHTTP is an http.Handler ServeHTTP method
func (h *httpQuotaHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (h *httpQuotaHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
host, _, _ := net.SplitHostPort(req.RemoteAddr) host, _, _ := net.SplitHostPort(req.RemoteAddr)
log.Println(host) longIP := longIP{net.ParseIP(host)}.IptoUint32()
if h.quotas.Add(longIP, h.adder(uint64(req.ContentLength))) {
longIP := longIP{net.ParseIP(host)} h.handler.ServeHTTP(w, req)
h.quotas.Add(longIP.IptoUint32(), uint64(req.ContentLength)) }
log.Println("quota called")
log.Println(h.quotas)
h.handler.ServeHTTP(w, req)
} }
// Handler implements quotas // BandwidthCap sets a quote based upon bandwidth used
func Handler(h http.Handler, limit int64) http.Handler { func BandwidthCap(h http.Handler, limit int64) http.Handler {
return &httpQuotaHandler{ return &httpQuotaHandler{
handler: h, handler: h,
quotas: &quotaMap{ quotas: &quotaMap{
@ -99,5 +96,19 @@ func Handler(h http.Handler, limit int64) http.Handler {
limit: uint64(limit), limit: uint64(limit),
duration: int64(60), duration: int64(60),
}, },
adder: func(count uint64) uint64 { return count },
}
}
// RequestLimit sets a quota based upon request count
func RequestLimit(h http.Handler, limit int64) http.Handler {
return &httpQuotaHandler{
handler: h,
quotas: &quotaMap{
data: make(map[int64]map[uint32]uint64),
limit: uint64(limit),
duration: int64(60),
},
adder: func(count uint64) uint64 { return 1 },
} }
} }