From 62b3a098bdbe71ecd33194677e58801195b73682 Mon Sep 17 00:00:00 2001 From: hrj Date: Sun, 5 Jul 2020 00:50:05 +0530 Subject: [PATCH] synchronise access to database statements --- src/main/scala/lc/Main.scala | 64 ++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/src/main/scala/lc/Main.scala b/src/main/scala/lc/Main.scala index 61eaa87..8496f2f 100644 --- a/src/main/scala/lc/Main.scala +++ b/src/main/scala/lc/Main.scala @@ -5,12 +5,15 @@ import java.io.ByteArrayInputStream import java.util.concurrent._ import java.util.UUID import java.sql.{Blob, ResultSet} +import java.util.concurrent.atomic.AtomicInteger case class Size(height: Int, width: Int) case class Parameters(level: String, media: String, input_type: String, size: Option[Size]) case class Id(id: String) case class Answer(answer: String, id: String) +case class ProviderSecret(provider: String, secret: String) + class Captcha(throttle: Int) extends DBConn { val stmt = getConn() @@ -55,20 +58,25 @@ class Captcha(throttle: Int) extends DBConn { } } + private val uniqueIntCount = new AtomicInteger() + def generateChallenge(param: Parameters): String = { //TODO: eval params to choose a provider val providerMap = getProvider() val provider = providers(providerMap) val challenge = provider.returnChallenge() val blob = new ByteArrayInputStream(challenge.content) - val token = scala.util.Random.nextInt(10000).toString - insertPstmt.setString(1, token) - insertPstmt.setString(2, provider.getId) - insertPstmt.setString(3, challenge.secret) - insertPstmt.setString(4, providerMap) - insertPstmt.setString(5, challenge.contentType) - insertPstmt.setBlob(6, blob) - insertPstmt.executeUpdate() + // val token = scala.util.Random.nextInt(100000).toString + val token = uniqueIntCount.incrementAndGet().toString + insertPstmt.synchronized { + insertPstmt.setString(1, token) + insertPstmt.setString(2, provider.getId) + insertPstmt.setString(3, challenge.secret) + insertPstmt.setString(4, providerMap) + insertPstmt.setString(5, challenge.contentType) + insertPstmt.setBlob(6, blob) + insertPstmt.executeUpdate() + } token } @@ -91,34 +99,42 @@ class Captcha(throttle: Int) extends DBConn { } def getChallenge(param: Parameters): Id = { - val rs = stmt.executeQuery("SELECT token FROM challenge WHERE solved=FALSE ORDER BY RAND() LIMIT 1") - val id = if(rs.next()){ - rs.getString("token") - } else { - generateChallenge(param) + val idOpt = stmt.synchronized { + val rs = stmt.executeQuery("SELECT token FROM challenge WHERE solved=FALSE ORDER BY RAND() LIMIT 1") + if(rs.next()) { + Some(rs.getString("token")) + } else { + None + } } + val id = idOpt.getOrElse(generateChallenge(param)) val uuid = getUUID(id) Id(uuid) } def getUUID(id: String): String = { val uuid = UUID.randomUUID().toString - mapPstmt.setString(1,uuid) - mapPstmt.setString(2,id) - mapPstmt.executeUpdate() + mapPstmt.synchronized { + mapPstmt.setString(1,uuid) + mapPstmt.setString(2,id) + mapPstmt.executeUpdate() + } uuid } def checkAnswer(answer: Answer): Boolean = { - selectPstmt.setString(1, answer.id) - val rs: ResultSet = selectPstmt.executeQuery() - if (rs.first()) { - val secret = rs.getString("secret") - val provider = rs.getString("provider") - providers(provider).checkAnswer(secret, answer.answer) - } else { - false + val psOpt:Option[ProviderSecret] = selectPstmt.synchronized { + selectPstmt.setString(1, answer.id) + val rs: ResultSet = selectPstmt.executeQuery() + if (rs.first()) { + val secret = rs.getString("secret") + val provider = rs.getString("provider") + Some(ProviderSecret(provider, secret)) + } else { + None + } } + psOpt.map(ps => providers(ps.provider).checkAnswer(ps.secret, answer.answer)).getOrElse(false) } def getHash(email: String): Int = {