diff --git a/build.sbt b/build.sbt index ef6d585..3899dfc 100644 --- a/build.sbt +++ b/build.sbt @@ -2,16 +2,16 @@ lazy val root = (project in file(".")). settings( inThisBuild(List( organization := "com.example", - scalaVersion := "2.13.2", + scalaVersion := "2.13.3", version := "0.1.0-SNAPSHOT")), name := "LibreCaptcha", libraryDependencies += "com.sksamuel.scrimage" % "scrimage-core" % "4.0.5", libraryDependencies += "com.sksamuel.scrimage" % "scrimage-filters" % "4.0.5", - + libraryDependencies += "org.json4s" % "json4s-jackson_2.13" % "3.6.9" - + ) unmanagedResourceDirectories in Compile += {baseDirectory.value / "lib"} diff --git a/project/build.properties b/project/build.properties index 654fe70..0837f7a 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.3.12 +sbt.version=1.3.13 diff --git a/src/main/scala/lc/DB.scala b/src/main/scala/lc/DB.scala index ca83557..e444e4f 100644 --- a/src/main/scala/lc/DB.scala +++ b/src/main/scala/lc/DB.scala @@ -5,15 +5,7 @@ import java.sql._ class DBConn(){ val con: Connection = DriverManager.getConnection("jdbc:h2:./captcha", "sa", "") - lazy val insertPstmt: PreparedStatement = con.prepareStatement("INSERT INTO challenge(token, id, secret, provider, contentType, image) VALUES (?, ?, ?, ?, ?, ?)") - lazy val mapPstmt: PreparedStatement = con.prepareStatement("INSERT INTO mapId(uuid, token) VALUES (?, ?)") - lazy val selectPstmt: PreparedStatement = con.prepareStatement("SELECT secret, provider FROM challenge WHERE token = (SELECT m.token FROM mapId m, challenge c WHERE m.token=c.token AND m.uuid = ?)") - lazy val imagePstmt: PreparedStatement = con.prepareStatement("SELECT image FROM challenge c, mapId m WHERE c.token=m.token AND m.uuid = ?") - lazy val updatePstmt: PreparedStatement = con.prepareStatement("UPDATE challenge SET solved = True WHERE token = (SELECT m.token FROM mapId m, challenge c WHERE m.token=c.token AND m.uuid = ?)") - lazy val userPstmt: PreparedStatement = con.prepareStatement("INSERT INTO users(email, hash) VALUES (?,?)") - lazy val validatePstmt: PreparedStatement = con.prepareStatement("SELECT hash FROM users WHERE hash = ? LIMIT 1") - - def getConn(): Statement = { + def getStatement(): Statement = { con.createStatement() } diff --git a/src/main/scala/lc/Main.scala b/src/main/scala/lc/Main.scala index f03ae4b..fd5e8ae 100644 --- a/src/main/scala/lc/Main.scala +++ b/src/main/scala/lc/Main.scala @@ -14,13 +14,7 @@ case class Answer(answer: String, id: String) case class ProviderSecret(provider: String, secret: String) -class Captcha(throttle: Int) extends DBConn { - - val stmt = getConn() - stmt.execute("CREATE TABLE IF NOT EXISTS challenge(token varchar, id varchar, secret varchar, provider varchar, contentType varchar, image blob, solved boolean default False, PRIMARY KEY(token))") - stmt.execute("CREATE TABLE IF NOT EXISTS mapId(uuid varchar, token varchar, PRIMARY KEY(uuid), FOREIGN KEY(token) REFERENCES challenge(token))") - stmt.execute("CREATE TABLE IF NOT EXISTS users(email varchar, hash int)") - +object CaptchaProviders { val providers = Map( "FilterChallenge" -> new FilterChallenge, "FontFunCaptcha" -> new FontFunCaptcha, @@ -30,6 +24,28 @@ class Captcha(throttle: Int) extends DBConn { "LabelCaptcha" -> new LabelCaptcha ) + def generateChallengeSamples() = { + providers.map {case (key, provider) => + (key, provider.returnChallenge()) + } + } +} + +class Captcha(throttle: Int, dbConn: DBConn) { + import CaptchaProviders._ + + private val stmt = dbConn.getStatement() + stmt.execute("CREATE TABLE IF NOT EXISTS challenge(token varchar, id varchar, secret varchar, provider varchar, contentType varchar, image blob, solved boolean default False, PRIMARY KEY(token))") + stmt.execute("CREATE TABLE IF NOT EXISTS mapId(uuid varchar, token varchar, PRIMARY KEY(uuid), FOREIGN KEY(token) REFERENCES challenge(token))") + stmt.execute("CREATE TABLE IF NOT EXISTS users(email varchar, hash int)") + + private val insertPstmt = dbConn.con.prepareStatement("INSERT INTO challenge(token, id, secret, provider, contentType, image) VALUES (?, ?, ?, ?, ?, ?)") + private val mapPstmt = dbConn.con.prepareStatement("INSERT INTO mapId(uuid, token) VALUES (?, ?)") + private val selectPstmt = dbConn.con.prepareStatement("SELECT secret, provider FROM challenge WHERE token = (SELECT m.token FROM mapId m, challenge c WHERE m.token=c.token AND m.uuid = ?)") + private val imagePstmt = dbConn.con.prepareStatement("SELECT image FROM challenge c, mapId m WHERE c.token=m.token AND m.uuid = ?") + private val updatePstmt = dbConn.con.prepareStatement("UPDATE challenge SET solved = True WHERE token = (SELECT m.token FROM mapId m, challenge c WHERE m.token=c.token AND m.uuid = ?)") + private val userPstmt = dbConn.con.prepareStatement("INSERT INTO users(email, hash) VALUES (?,?)") + def getProvider(): String = { val random = new scala.util.Random val keys = providers.keys @@ -57,12 +73,6 @@ class Captcha(throttle: Int) extends DBConn { imageOpt } - def generateChallengeSamples() = { - providers.map {case (key, provider) => - (key, provider.returnChallenge()) - } - } - private val uniqueIntCount = new AtomicInteger() def generateChallenge(param: Parameters): String = { @@ -167,8 +177,9 @@ class Captcha(throttle: Int) extends DBConn { object LCFramework{ def main(args: scala.Array[String]) { - val captcha = new Captcha(2) - val server = new Server(8888) + val dbConn = new DBConn() + val captcha = new Captcha(2, dbConn) + val server = new Server(8888, captcha, dbConn) captcha.beginThread(2) server.start() } @@ -176,8 +187,7 @@ object LCFramework{ object MakeSamples { def main(args: scala.Array[String]) { - val captcha = new Captcha(2) - val samples = captcha.generateChallengeSamples() + val samples = CaptchaProviders.generateChallengeSamples() samples.foreach {case (key, sample) => val extensionMap = Map("image/png" -> "png", "image/gif" -> "gif") println(key + ": " + sample) diff --git a/src/main/scala/lc/Server.scala b/src/main/scala/lc/Server.scala index d40e957..3120325 100644 --- a/src/main/scala/lc/Server.scala +++ b/src/main/scala/lc/Server.scala @@ -9,57 +9,62 @@ import lc.HTTPServer._ case class Secret(token: Int) -class RateLimiter extends DBConn { - val stmt = getConn() - val userLastActive = collection.mutable.Map[Int, Long]() - val userAllowance = collection.mutable.Map[Int, Double]() - val rate = 8.0 - val per = 45.0 - val allowance = rate +class RateLimiter(dbConn: DBConn) { + private val userLastActive = collection.mutable.Map[Int, Long]() + private val userAllowance = collection.mutable.Map[Int, Double]() + private val rate = 800000.0 + private val per = 45.0 + private val allowance = rate - def validateUser(user: Int) : Boolean = { - synchronized { - val allow = if(userLastActive.contains(user)){ + private val validatePstmt = dbConn.con.prepareStatement("SELECT hash FROM users WHERE hash = ? LIMIT 1") + + private def validateUser(user: Int) : Boolean = { + val allow = if(userLastActive.contains(user)){ + true + } else { + validatePstmt.setInt(1, user) + val rs = validatePstmt.executeQuery() + val validated = if(rs.next()){ + val hash = rs.getInt("hash") + userLastActive(hash) = System.currentTimeMillis() + userAllowance(hash) = allowance true } else { - validatePstmt.setInt(1, user) - val rs = validatePstmt.executeQuery() - val validated = if(rs.next()){ - val hash = rs.getInt("hash") - userLastActive(hash) = System.currentTimeMillis() - userAllowance(hash) = allowance - true - } else { - false - } - validated - } - allow - } - } - - def checkLimit(user: Int): Boolean = { - synchronized { - val current = System.currentTimeMillis() - val time_passed = (current - userLastActive(user)) / 1000 - userLastActive(user) = current - userAllowance(user) += time_passed * (rate/per) - if(userAllowance(user) > rate){ userAllowance(user) = rate } - val allow = if(userAllowance(user) < 1.0){ false - } else { - userAllowance(user) -= 1.0 - true } - allow + validated } + allow } + private def checkLimit(user: Int): Boolean = { + val current = System.currentTimeMillis() + val time_passed = (current - userLastActive(user)) / 1000 + userLastActive(user) = current + userAllowance(user) += time_passed * (rate/per) + if(userAllowance(user) > rate){ userAllowance(user) = rate } + val allow = if(userAllowance(user) < 1.0){ + false + } else { + userAllowance(user) -= 1.0 + true + } + allow + } + + def checkUserAccess(token: Int) : Boolean = { + synchronized { + if (validateUser(token)) { + return checkLimit(token) + } else { + return false + } + } + } } -class Server(port: Int){ - val captcha = new Captcha(0) - val rateLimiter = new RateLimiter() +class Server(port: Int, captcha: Captcha, dbConn: DBConn){ + val rateLimiter = new RateLimiter(dbConn) val server = new HTTPServer(port) val host = server.getVirtualHost(null) @@ -67,7 +72,7 @@ class Server(port: Int){ host.addContext("/v1/captcha",(req, resp) => { val accessToken = Option(req.getHeaders().get("access-token")).map(_.toInt) - val access = accessToken.map(t => rateLimiter.validateUser(t) && rateLimiter.checkLimit(t)).getOrElse(false) + val access = accessToken.map(rateLimiter.checkUserAccess).getOrElse(false) if(access){ val body = req.getJson() val json = parse(body)