Commit abfc612b authored by Daniel Wolf's avatar Daniel Wolf
Browse files

Added TLS socket pooling and reuse

Closes #213
parent cc0b918e
...@@ -124,7 +124,7 @@ dependencies { ...@@ -124,7 +124,7 @@ dependencies {
implementation 'com.frostnerd.utilskt:preferences:1.5.17' // https://git.frostnerd.com/AndroidUtils/preferenceskt implementation 'com.frostnerd.utilskt:preferences:1.5.17' // https://git.frostnerd.com/AndroidUtils/preferenceskt
implementation 'com.frostnerd.utilskt:navigationdraweractivity:1.3.29' // https://git.frostnerd.com/AndroidUtils/navigationdraweractivity implementation 'com.frostnerd.utilskt:navigationdraweractivity:1.3.29' // https://git.frostnerd.com/AndroidUtils/navigationdraweractivity
implementation 'com.frostnerd.utilskt:encrypteddnstunnelproxy:1.5.181' // https://git.frostnerd.com/AndroidUtils/encrypteddnstunnelproxy implementation 'com.frostnerd.utilskt:encrypteddnstunnelproxy:1.5.182' // https://git.frostnerd.com/AndroidUtils/encrypteddnstunnelproxy
implementation 'com.frostnerd.utilskt:general:1.0.19' // https://git.frostnerd.com/AndroidUtils/generalkt implementation 'com.frostnerd.utilskt:general:1.0.19' // https://git.frostnerd.com/AndroidUtils/generalkt
implementation 'com.frostnerd.utilskt:adapters:1.1.6' // https://git.frostnerd.com/AndroidUtils/Adapters implementation 'com.frostnerd.utilskt:adapters:1.1.6' // https://git.frostnerd.com/AndroidUtils/Adapters
......
package com.frostnerd.smokescreen.util.speedtest package com.frostnerd.smokescreen.util.speedtest
import androidx.annotation.IntRange import androidx.annotation.IntRange
import cn.danielw.fop.ObjectFactory
import cn.danielw.fop.ObjectPool
import cn.danielw.fop.PoolConfig
import cn.danielw.fop.Poolable
import com.frostnerd.dnstunnelproxy.DnsServerInformation import com.frostnerd.dnstunnelproxy.DnsServerInformation
import com.frostnerd.dnstunnelproxy.UpstreamAddress import com.frostnerd.dnstunnelproxy.UpstreamAddress
import com.frostnerd.encrypteddnstunnelproxy.HttpsDnsServerInformation import com.frostnerd.encrypteddnstunnelproxy.HttpsDnsServerInformation
...@@ -11,12 +15,14 @@ import okhttp3.OkHttpClient ...@@ -11,12 +15,14 @@ import okhttp3.OkHttpClient
import okhttp3.Request import okhttp3.Request
import okhttp3.RequestBody.Companion.toRequestBody import okhttp3.RequestBody.Companion.toRequestBody
import okhttp3.Response import okhttp3.Response
import okhttp3.internal.closeQuietly
import org.minidns.dnsmessage.DnsMessage import org.minidns.dnsmessage.DnsMessage
import org.minidns.dnsmessage.Question import org.minidns.dnsmessage.Question
import org.minidns.record.Record import org.minidns.record.Record
import java.io.DataInputStream import java.io.DataInputStream
import java.io.DataOutputStream import java.io.DataOutputStream
import java.net.* import java.net.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import javax.net.ssl.SSLSocketFactory import javax.net.ssl.SSLSocketFactory
import kotlin.random.Random import kotlin.random.Random
...@@ -56,6 +62,24 @@ class DnsSpeedTest(val server: DnsServerInformation<*>, ...@@ -56,6 +62,24 @@ class DnsSpeedTest(val server: DnsServerInformation<*>,
it.urlCreator.address it.urlCreator.address
}) })
} }
private val connectionPool = ConcurrentHashMap<TLSUpstreamAddress, ObjectPool<Socket>>()
private lateinit var poolConfig:PoolConfig
private val poolFactory = object: ObjectFactory<Socket> {
private val sslSocketFactory = SSLSocketFactory.getDefault()
override fun validate(t: Socket): Boolean {
return !t.isConnected
}
override fun destroy(t: Socket) {
t.closeQuietly()
}
override fun create(): Socket {
return sslSocketFactory.createSocket()
}
}
companion object { companion object {
val testDomains = listOf("google.com", "frostnerd.com", "amazon.com", "youtube.com", "github.com", val testDomains = listOf("google.com", "frostnerd.com", "amazon.com", "youtube.com", "github.com",
"stackoverflow.com", "stackexchange.com", "spotify.com", "material.io", "reddit.com", "android.com") "stackoverflow.com", "stackexchange.com", "spotify.com", "material.io", "reddit.com", "android.com")
...@@ -67,6 +91,13 @@ class DnsSpeedTest(val server: DnsServerInformation<*>, ...@@ -67,6 +91,13 @@ class DnsSpeedTest(val server: DnsServerInformation<*>,
*/ */
fun runTest(@IntRange(from = 1) passes: Int): Int? { fun runTest(@IntRange(from = 1) passes: Int): Int? {
var ttl = 0 var ttl = 0
poolConfig = PoolConfig().apply {
this.maxSize = 2
this.minSize = 1
this.partitionSize = 1
this.maxIdleMilliseconds = 60*1000*5
}
for (i in 0 until passes) { for (i in 0 until passes) {
if (server is HttpsDnsServerInformation) { if (server is HttpsDnsServerInformation) {
server.serverConfigurations.values.forEach { server.serverConfigurations.values.forEach {
...@@ -129,18 +160,30 @@ class DnsSpeedTest(val server: DnsServerInformation<*>, ...@@ -129,18 +160,30 @@ class DnsSpeedTest(val server: DnsServerInformation<*>,
} }
} }
private fun obtainTlsSocket(address: TLSUpstreamAddress): Poolable<Socket>? {
return try {
connectionPool.getOrPut(address) {
ObjectPool(poolConfig, poolFactory)
}.borrowObject()
} catch (e: RuntimeException) {
null
}
}
private fun testTls(address: TLSUpstreamAddress): Int? { private fun testTls(address: TLSUpstreamAddress): Int? {
val addr = val addr =
address.addressCreator.resolveOrGetResultOrNull(retryIfError = true, runResolveNow = true) ?: run { address.addressCreator.resolveOrGetResultOrNull(retryIfError = true, runResolveNow = true) ?: run {
log("DoT test failed once for ${server.name}: Address failed to resolve ($address)") log("DoT test failed once for ${server.name}: Address failed to resolve ($address)")
return null return null
} }
var socket: Socket? = null var socketPooled: Poolable<Socket>? = null
var socket:Socket? = null
try { try {
socket = SSLSocketFactory.getDefault().createSocket() socketPooled = obtainTlsSocket(address)
socket = socketPooled?.`object` ?: SSLSocketFactory.getDefault().createSocket()
val msg = createTestDnsPacket() val msg = createTestDnsPacket()
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
socket.connect(InetSocketAddress(addr[0], address.port), connectTimeout) socket!!.connect(InetSocketAddress(addr[0], address.port), connectTimeout)
socket.soTimeout = readTimeout socket.soTimeout = readTimeout
val data: ByteArray = msg.toArray() val data: ByteArray = msg.toArray()
val outputStream = DataOutputStream(socket.getOutputStream()) val outputStream = DataOutputStream(socket.getOutputStream())
...@@ -155,7 +198,6 @@ class DnsSpeedTest(val server: DnsServerInformation<*>, ...@@ -155,7 +198,6 @@ class DnsSpeedTest(val server: DnsServerInformation<*>,
inStream.read(readData) inStream.read(readData)
val time = (System.currentTimeMillis() - start).toInt() val time = (System.currentTimeMillis() - start).toInt()
socket.close()
socket = null socket = null
if(!testResponse(DnsMessage(readData))) { if(!testResponse(DnsMessage(readData))) {
log("DoT test failed once for ${server.name}: Testing the response for valid dns message failed") log("DoT test failed once for ${server.name}: Testing the response for valid dns message failed")
...@@ -167,6 +209,7 @@ class DnsSpeedTest(val server: DnsServerInformation<*>, ...@@ -167,6 +209,7 @@ class DnsSpeedTest(val server: DnsServerInformation<*>,
return null return null
} finally { } finally {
socket?.close() socket?.close()
socketPooled?.returnObject()
} }
} }
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment