diff --git a/src/main/scala/lc/Main.scala b/src/main/scala/lc/Main.scala index 18e4753..e37c75e 100644 --- a/src/main/scala/lc/Main.scala +++ b/src/main/scala/lc/Main.scala @@ -19,18 +19,32 @@ import java.util.concurrent._ import java.util.UUID import scala.Array -class Captcha(throttle: Int) { +class DBConn(){ val con: Connection = DriverManager.getConnection("jdbc:h2:./captcha", "sa", "") - val stmt: Statement = con.createStatement() - stmt.execute("CREATE TABLE IF NOT EXISTS challenge(token varchar, id varchar, secret varchar, provider varchar, contentType varchar, image blob, solved boolean default False, PRIMARY KEY(token))") - stmt.execute("CREATE TABLE IF NOT EXISTS mapId(uuid varchar, token varchar, PRIMARY KEY(uuid), FOREIGN KEY(token) REFERENCES challenge(token))") - stmt.execute("CREATE TABLE IF NOT EXISTS users(email varchar, hash int)") + val insertPstmt: PreparedStatement = con.prepareStatement("INSERT INTO challenge(token, id, secret, provider, contentType, image) VALUES (?, ?, ?, ?, ?, ?)") val mapPstmt: PreparedStatement = con.prepareStatement("INSERT INTO mapId(uuid, token) VALUES (?, ?)") val selectPstmt: PreparedStatement = con.prepareStatement("SELECT secret, provider FROM challenge WHERE token = ?") val imagePstmt: PreparedStatement = con.prepareStatement("SELECT image FROM challenge c, mapId m WHERE c.token=m.token AND m.uuid = ?") val updatePstmt: PreparedStatement = con.prepareStatement("UPDATE challenge SET solved = True WHERE token = (SELECT m.token FROM mapId m, challenge c WHERE m.token=c.token AND m.uuid = ?)") val userPstmt: PreparedStatement = con.prepareStatement("INSERT INTO users(email, hash) VALUES (?,?)") + val validatePstmt: PreparedStatement = con.prepareStatement("SELECT hash FROM users WHERE hash = ? LIMIT 1") + + def getConn(): Statement = { + con.createStatement() + } + + def closeConnection(): Unit = { + con.close() + } +} + +class Captcha(throttle: Int) extends DBConn { + + val stmt = getConn() + stmt.execute("CREATE TABLE IF NOT EXISTS challenge(token varchar, id varchar, secret varchar, provider varchar, contentType varchar, image blob, solved boolean default False, PRIMARY KEY(token))") + stmt.execute("CREATE TABLE IF NOT EXISTS mapId(uuid varchar, token varchar, PRIMARY KEY(uuid), FOREIGN KEY(token) REFERENCES challenge(token))") + stmt.execute("CREATE TABLE IF NOT EXISTS users(email varchar, hash int)") val providers = Map("FilterChallenge" -> new FilterChallenge, "FontFunCaptcha" -> new FontFunCaptcha, @@ -151,10 +165,6 @@ class Captcha(throttle: Int) { println(s"${token}\t\t${id}\t\t${secret}\t\t${solved}") } } - - def closeConnection(): Unit = { - con.close() - } } case class Size(height: Int, width: Int) @@ -163,18 +173,73 @@ case class Id(id: String) case class Answer(answer: String, id: String) case class Secret(token: Int) +class RateLimiter extends DBConn { + val stmt = getConn() + val userLastActive = collection.mutable.Map[Int, Long]() + val userAllowance = collection.mutable.Map[Int, Double]() + val rate = 2.0 + val per = 45.0 + val allowance = rate + + def validateUser(user: Int) : Boolean = { + synchronized { + val allow = if(userLastActive.contains(user)){ + true + } else { + validatePstmt.setInt(1, user) + val rs = validatePstmt.executeQuery() + val validated = if(rs.next()){ + val hash = rs.getInt("hash") + userLastActive(hash) = System.currentTimeMillis() + userAllowance(hash) = allowance + true + } else { + false + } + validated + } + allow + } + } + + def checkLimit(user: Int): Boolean = { + synchronized { + val current = System.currentTimeMillis() + val time_passed = (current - userLastActive(user)) / 1000000000 + userLastActive(user) = current + userAllowance(user) += time_passed * (rate/per) + if(userAllowance(user) > rate){ userAllowance(user) = rate } + val allow = if(userAllowance(user) < 1.0){ + false + } else { + userAllowance(user) -= 1.0 + true + } + allow + } + } + +} + class Server(port: Int){ val captcha = new Captcha(0) + val rateLimiter = new RateLimiter() val server = new HTTPServer(port) val host = server.getVirtualHost(null) implicit val formats = DefaultFormats host.addContext("/v1/captcha",(req, resp) => { - val body = req.getJson() - val json = parse(body) - val param = json.extract[Parameters] - val id = captcha.getChallenge(param) + val accessToken = Option(req.getHeaders().get("access-token")).map(_.toInt) + val access = accessToken.map(t => rateLimiter.validateUser(t) && rateLimiter.checkLimit(t)).getOrElse(false) + val id = if(access){ + val body = req.getJson() + val json = parse(body) + val param = json.extract[Parameters] + captcha.getChallenge(param) + } else { + "Not a valid user or rate limit reached!" + } resp.getHeaders().add("Content-Type","application/json") resp.send(200, write(id)) 0