diff --git a/src/main/scala/lc/Server.scala b/src/main/scala/lc/Server.scala index d75e45d..2d838ab 100644 --- a/src/main/scala/lc/Server.scala +++ b/src/main/scala/lc/Server.scala @@ -10,50 +10,55 @@ import lc.HTTPServer._ case class Secret(token: Int) class RateLimiter extends DBConn { - val userLastActive = collection.mutable.Map[Int, Long]() - val userAllowance = collection.mutable.Map[Int, Double]() - val rate = 8.0 - val per = 45.0 - val allowance = rate + private val userLastActive = collection.mutable.Map[Int, Long]() + private val userAllowance = collection.mutable.Map[Int, Double]() + private val rate = 800000.0 + private val per = 45.0 + private val allowance = rate - def validateUser(user: Int) : Boolean = { - synchronized { - val allow = if(userLastActive.contains(user)){ + private def validateUser(user: Int) : Boolean = { + val allow = if(userLastActive.contains(user)){ + true + } else { + validatePstmt.setInt(1, user) + val rs = validatePstmt.executeQuery() + val validated = if(rs.next()){ + val hash = rs.getInt("hash") + userLastActive(hash) = System.currentTimeMillis() + userAllowance(hash) = allowance true } else { - validatePstmt.setInt(1, user) - val rs = validatePstmt.executeQuery() - val validated = if(rs.next()){ - val hash = rs.getInt("hash") - userLastActive(hash) = System.currentTimeMillis() - userAllowance(hash) = allowance - true - } else { - false - } - validated - } - allow - } - } - - def checkLimit(user: Int): Boolean = { - synchronized { - val current = System.currentTimeMillis() - val time_passed = (current - userLastActive(user)) / 1000 - userLastActive(user) = current - userAllowance(user) += time_passed * (rate/per) - if(userAllowance(user) > rate){ userAllowance(user) = rate } - val allow = if(userAllowance(user) < 1.0){ false - } else { - userAllowance(user) -= 1.0 - true } - allow + validated } + allow } + private def checkLimit(user: Int): Boolean = { + val current = System.currentTimeMillis() + val time_passed = (current - userLastActive(user)) / 1000 + userLastActive(user) = current + userAllowance(user) += time_passed * (rate/per) + if(userAllowance(user) > rate){ userAllowance(user) = rate } + val allow = if(userAllowance(user) < 1.0){ + false + } else { + userAllowance(user) -= 1.0 + true + } + allow + } + + def checkUserAccess(token: Int) : Boolean = { + synchronized { + if (validateUser(token)) { + return checkLimit(token) + } else { + return false + } + } + } } class Server(port: Int){ @@ -66,7 +71,7 @@ class Server(port: Int){ host.addContext("/v1/captcha",(req, resp) => { val accessToken = Option(req.getHeaders().get("access-token")).map(_.toInt) - val access = accessToken.map(t => rateLimiter.validateUser(t) && rateLimiter.checkLimit(t)).getOrElse(false) + val access = accessToken.map(rateLimiter.checkUserAccess).getOrElse(false) if(access){ val body = req.getJson() val json = parse(body)