synchronise access to database statements

This commit is contained in:
hrj 2020-07-05 00:50:05 +05:30
parent 0b1e902326
commit 62b3a098bd

View File

@ -5,12 +5,15 @@ import java.io.ByteArrayInputStream
import java.util.concurrent._ import java.util.concurrent._
import java.util.UUID import java.util.UUID
import java.sql.{Blob, ResultSet} import java.sql.{Blob, ResultSet}
import java.util.concurrent.atomic.AtomicInteger
case class Size(height: Int, width: Int) case class Size(height: Int, width: Int)
case class Parameters(level: String, media: String, input_type: String, size: Option[Size]) case class Parameters(level: String, media: String, input_type: String, size: Option[Size])
case class Id(id: String) case class Id(id: String)
case class Answer(answer: String, id: String) case class Answer(answer: String, id: String)
case class ProviderSecret(provider: String, secret: String)
class Captcha(throttle: Int) extends DBConn { class Captcha(throttle: Int) extends DBConn {
val stmt = getConn() val stmt = getConn()
@ -55,13 +58,17 @@ class Captcha(throttle: Int) extends DBConn {
} }
} }
private val uniqueIntCount = new AtomicInteger()
def generateChallenge(param: Parameters): String = { def generateChallenge(param: Parameters): String = {
//TODO: eval params to choose a provider //TODO: eval params to choose a provider
val providerMap = getProvider() val providerMap = getProvider()
val provider = providers(providerMap) val provider = providers(providerMap)
val challenge = provider.returnChallenge() val challenge = provider.returnChallenge()
val blob = new ByteArrayInputStream(challenge.content) val blob = new ByteArrayInputStream(challenge.content)
val token = scala.util.Random.nextInt(10000).toString // val token = scala.util.Random.nextInt(100000).toString
val token = uniqueIntCount.incrementAndGet().toString
insertPstmt.synchronized {
insertPstmt.setString(1, token) insertPstmt.setString(1, token)
insertPstmt.setString(2, provider.getId) insertPstmt.setString(2, provider.getId)
insertPstmt.setString(3, challenge.secret) insertPstmt.setString(3, challenge.secret)
@ -69,6 +76,7 @@ class Captcha(throttle: Int) extends DBConn {
insertPstmt.setString(5, challenge.contentType) insertPstmt.setString(5, challenge.contentType)
insertPstmt.setBlob(6, blob) insertPstmt.setBlob(6, blob)
insertPstmt.executeUpdate() insertPstmt.executeUpdate()
}
token token
} }
@ -91,35 +99,43 @@ class Captcha(throttle: Int) extends DBConn {
} }
def getChallenge(param: Parameters): Id = { def getChallenge(param: Parameters): Id = {
val idOpt = stmt.synchronized {
val rs = stmt.executeQuery("SELECT token FROM challenge WHERE solved=FALSE ORDER BY RAND() LIMIT 1") val rs = stmt.executeQuery("SELECT token FROM challenge WHERE solved=FALSE ORDER BY RAND() LIMIT 1")
val id = if(rs.next()){ if(rs.next()) {
rs.getString("token") Some(rs.getString("token"))
} else { } else {
generateChallenge(param) None
} }
}
val id = idOpt.getOrElse(generateChallenge(param))
val uuid = getUUID(id) val uuid = getUUID(id)
Id(uuid) Id(uuid)
} }
def getUUID(id: String): String = { def getUUID(id: String): String = {
val uuid = UUID.randomUUID().toString val uuid = UUID.randomUUID().toString
mapPstmt.synchronized {
mapPstmt.setString(1,uuid) mapPstmt.setString(1,uuid)
mapPstmt.setString(2,id) mapPstmt.setString(2,id)
mapPstmt.executeUpdate() mapPstmt.executeUpdate()
}
uuid uuid
} }
def checkAnswer(answer: Answer): Boolean = { def checkAnswer(answer: Answer): Boolean = {
val psOpt:Option[ProviderSecret] = selectPstmt.synchronized {
selectPstmt.setString(1, answer.id) selectPstmt.setString(1, answer.id)
val rs: ResultSet = selectPstmt.executeQuery() val rs: ResultSet = selectPstmt.executeQuery()
if (rs.first()) { if (rs.first()) {
val secret = rs.getString("secret") val secret = rs.getString("secret")
val provider = rs.getString("provider") val provider = rs.getString("provider")
providers(provider).checkAnswer(secret, answer.answer) Some(ProviderSecret(provider, secret))
} else { } else {
false None
} }
} }
psOpt.map(ps => providers(ps.provider).checkAnswer(ps.secret, answer.answer)).getOrElse(false)
}
def getHash(email: String): Int = { def getHash(email: String): Int = {
val secret = "" val secret = ""