Skip to content

Commit

Permalink
Add TURN server support
Browse files Browse the repository at this point in the history
  • Loading branch information
dkrivoruchko committed Aug 22, 2024
1 parent 9c6546f commit 26e9862
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import org.json.JSONArray
import org.json.JSONException
import org.json.JSONObject
import org.webrtc.IceCandidate
import org.webrtc.PeerConnection.IceServer
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit

Expand Down Expand Up @@ -67,7 +68,7 @@ internal class SocketSignaling(
fun onSocketDisconnected(reason: String)
fun onStreamCreated(streamId: StreamId)
fun onStreamRemoved()
fun onClientJoin(clientId: ClientId)
fun onClientJoin(clientId: ClientId, iceServers: List<IceServer>)
fun onClientLeave(clientId: ClientId)
fun onClientNotFound(clientId: ClientId, reason: String)
fun onClientAnswer(clientId: ClientId, answer: Answer)
Expand Down Expand Up @@ -100,6 +101,7 @@ internal class SocketSignaling(
const val STATUS = "status"
const val OK = "OK"
const val STREAM_ID = "streamId"
const val ICE_SERVERS = "iceServers"
const val PASSWORD_HASH = "passwordHash"
const val CLIENT_ID = "clientId"
const val OFFER = "offer"
Expand Down Expand Up @@ -232,7 +234,7 @@ internal class SocketSignaling(
payload.sendErrorAck(Payload.ERROR_EMPTY_OR_BAD_DATA)
} else if (passwordVerifier.isValid(payload.clientId, payload.passwordHash)) {
payload.sendOkAck()
eventListener.onClientJoin(payload.clientId)
eventListener.onClientJoin(payload.clientId, payload.iceServers)
} else {
XLog.w(getLog("onStreamCreated", "[${Event.STREAM_JOIN}] Wrong stream password"))
payload.sendErrorAck(Payload.ERROR_WRONG_STREAM_PASSWORD)
Expand Down Expand Up @@ -489,6 +491,24 @@ internal class SocketSignaling(
json?.optString(Payload.PASSWORD_HASH) ?: ""
}

val iceServers: List<IceServer> by lazy(LazyThreadSafetyMode.NONE) {
json?.optJSONArray(Payload.ICE_SERVERS)?.let { iceServersArray ->
(0 until iceServersArray.length()).mapNotNull { i ->
val iceServerJson = iceServersArray.optJSONObject(i) ?: return@mapNotNull null
val urls = iceServerJson.optString("urls").ifBlank { null } ?: return@mapNotNull null
val username = iceServerJson.optString("username").ifBlank { null }
val credential = iceServerJson.optString("credential").ifBlank { null }

IceServer.builder(urls).apply {
if (username != null && credential != null) {
setUsername(username)
setPassword(credential)
}
}.createIceServer()
}
} ?: emptyList()
}

val answer: Answer by lazy(LazyThreadSafetyMode.NONE) {
json?.optString(Payload.ANSWER)?.let { Answer(it) } ?: Answer("")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import java.util.zip.CRC32

internal class WebRtcClient(
internal val clientId: ClientId,
iceServers: List<IceServer>,
private val factory: PeerConnectionFactory,
private val videoCodecs: List<RtpCapabilities.CodecCapability>,
private val audioCodecs: List<RtpCapabilities.CodecCapability>,
Expand All @@ -49,6 +50,9 @@ internal class WebRtcClient(
private enum class State { CREATED, PENDING_OFFER, PENDING_OFFER_ACCEPT, OFFER_ACCEPTED }

private val id: String = "${clientId.value}#$publicId"
private val rtcConfig = RTCConfiguration(iceServers.ifEmpty { defaultIceServers }).apply {
tcpCandidatePolicy = PeerConnection.TcpCandidatePolicy.DISABLED
}
private val pendingIceCandidates: MutableList<IceCandidate> = Collections.synchronizedList(mutableListOf())

private val state: AtomicReference<State> = AtomicReference(State.CREATED)
Expand Down Expand Up @@ -89,7 +93,7 @@ internal class WebRtcClient(
}
if (it.mediaType == MediaStreamTrack.MediaType.MEDIA_TYPE_AUDIO) it.setCodecPreferences(audioCodecs)
}
//setBitrate(200_000, 4_000_000, 8_000_000) doesn't work
// TODO setBitrate(200_000, 2_000_000, 4_000_000)
}

mediaStreamId = mediaStream.id
Expand Down Expand Up @@ -248,17 +252,33 @@ internal class WebRtcClient(
}

// Signaling thread
private fun onCandidatePairChanged(event: CandidatePairChangeEvent) {
val msg = "Client: $id, MediaStream: $mediaStreamId, State: $state"
private fun onCandidatePairChanged() {
XLog.d(this@WebRtcClient.getLog("onCandidatePairChanged", "Client: $id, MediaStream: $mediaStreamId, State: $state"))

peerConnection?.getStats { report ->
val transport = report.statsMap.filter { it.value.type == "transport" }.values.firstOrNull() ?: return@getStats
val selectedCandidatePairId = transport.members.get("selectedCandidatePairId") as String? ?: return@getStats

val selectedCandidatePair = report.statsMap[selectedCandidatePairId] ?: return@getStats
val localCandidateId = selectedCandidatePair.members["localCandidateId"] as String? ?: return@getStats
val remoteCandidateId = selectedCandidatePair.members["remoteCandidateId"] as String? ?: return@getStats
val localCandidate = report.statsMap[localCandidateId] ?: return@getStats
val remoteCandidate = report.statsMap[remoteCandidateId] ?: return@getStats

val localNetworkType = localCandidate.members["networkType"] as String? ?: ""
val localCandidateType = (localCandidate.members["candidateType"] as String? ?: "").let { type ->
when {
type.equals("host", ignoreCase = true) -> "HOST"
type.equals("srflx", ignoreCase = true) -> "STUN"
type.equals("prflx", ignoreCase = true) -> "STUN"
type.equals("relay", ignoreCase = true) -> "TURN"
else -> null
}
}
val remoteIP = remoteCandidate.members["ip"] as String? ?: ""

if (state.get() == State.OFFER_ACCEPTED) {
XLog.d(this@WebRtcClient.getLog("onCandidatePairChanged", msg))
clientAddress.set(event.runCatching { remote.sdp.split(' ', limit = 6).drop(4).first() }
.map { if (regexIPv4.matches(it) || regexIPv6Standard.matches(it) || regexIPv6Compressed.matches(it)) it else "-" }
.getOrElse { "-" })
clientAddress.set("${localNetworkType.uppercase()}${localCandidateType?.let { " [$it]" }}\n$remoteIP")
eventListener.onClientAddress(clientId)
} else {
XLog.d(this@WebRtcClient.getLog("onCandidatePairChanged", "Ignoring"), IllegalStateException("onCandidatePairChanged: $msg"))
}
}

Expand All @@ -279,7 +299,7 @@ internal class WebRtcClient(
private class WebRTCPeerConnectionObserver(
private val clientId: ClientId,
private val onHostCandidate: (IceCandidate) -> Unit,
private val onCandidatePairChanged: (CandidatePairChangeEvent) -> Unit,
private val onCandidatePairChanged: () -> Unit,
private val onPeerDisconnected: () -> Unit
) : PeerConnection.Observer {

Expand Down Expand Up @@ -324,7 +344,7 @@ internal class WebRtcClient(

override fun onSelectedCandidatePairChanged(event: CandidatePairChangeEvent) {
XLog.v(getLog("onSelectedCandidatePairChanged", "Client: $clientId"))
onCandidatePairChanged(event)
onCandidatePairChanged()
}

override fun onAddStream(mediaStream: MediaStream?) {
Expand Down Expand Up @@ -366,26 +386,13 @@ internal class WebRtcClient(

private companion object {
@JvmStatic
private val iceServers = listOf(
IceServer.builder("stun:stun.l.google.com:19302").createIceServer(),
IceServer.builder("stun:stun1.l.google.com:19302").createIceServer(),
IceServer.builder("stun:stun2.l.google.com:19302").createIceServer(),
IceServer.builder("stun:stun3.l.google.com:19302").createIceServer(),
IceServer.builder("stun:stun4.l.google.com:19302").createIceServer()
)

@JvmStatic
private val rtcConfig = RTCConfiguration(iceServers.asSequence().shuffled().take(2).toList()).apply {
tcpCandidatePolicy = PeerConnection.TcpCandidatePolicy.DISABLED
}

@JvmStatic
private val regexIPv4 = "^(([0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])(\\.(?!\$)|\$)){4}\$".toRegex()

@JvmStatic
private val regexIPv6Standard = "^(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\$".toRegex()

@JvmStatic
private val regexIPv6Compressed = "^((?:[0-9A-Fa-f]{1,4}(?::[0-9A-Fa-f]{1,4})*)?)::((?:[0-9A-Fa-f]{1,4}(?::[0-9A-Fa-f]{1,4})*)?)\$".toRegex()
private val defaultIceServers
get() = sequenceOf(
"stun:stun.l.google.com:19302",
"stun:stun1.l.google.com:19302",
"stun:stun2.l.google.com:19302",
"stun:stun3.l.google.com:19302",
"stun:stun4.l.google.com:19302",
).shuffled().take(2).map { IceServer.builder(it).createIceServer() }.toList()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import org.koin.core.annotation.InjectedParam
import org.koin.core.annotation.Scope
import org.koin.core.annotation.Scoped
import org.webrtc.IceCandidate
import org.webrtc.PeerConnection.IceServer
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
import kotlin.math.pow
Expand Down Expand Up @@ -105,7 +106,7 @@ internal class WebRtcStreamingService(
data class OpenSocket(val token: PlayIntegrityToken) : InternalEvent(Priority.RECOVER_IGNORE)
data object StreamCreate : InternalEvent(Priority.RECOVER_IGNORE)
data class StreamCreated(val streamId: StreamId) : InternalEvent(Priority.RECOVER_IGNORE)
data class ClientJoin(val clientId: ClientId) : InternalEvent(Priority.RECOVER_IGNORE)
data class ClientJoin(val clientId: ClientId, val iceServers: List<IceServer>) : InternalEvent(Priority.RECOVER_IGNORE)
data class SocketSignalingError(val error: SocketSignaling.Error) : InternalEvent(Priority.RECOVER_IGNORE)

data object StartStream : InternalEvent(Priority.STOP_IGNORE)
Expand Down Expand Up @@ -157,9 +158,9 @@ internal class WebRtcStreamingService(
sendEvent(InternalEvent.StreamCreate)
}

override fun onClientJoin(clientId: ClientId) {
XLog.v(this@WebRtcStreamingService.getLog("SocketSignaling.onClientJoin", "$clientId"))
sendEvent(InternalEvent.ClientJoin(clientId))
override fun onClientJoin(clientId: ClientId, iceServers: List<IceServer>) {
XLog.v(this@WebRtcStreamingService.getLog("SocketSignaling.onClientJoin", "$clientId IceServers: ${iceServers.size}"))
sendEvent(InternalEvent.ClientJoin(clientId, iceServers))
}

override fun onClientLeave(clientId: ClientId) {
Expand Down Expand Up @@ -211,11 +212,13 @@ internal class WebRtcStreamingService(
}

override fun onError(clientId: ClientId, cause: Throwable) {
if (cause.message?.startsWith("onPeerDisconnected") == true)
XLog.e(this@WebRtcStreamingService.getLog("WebRTCClient.onError", "Client: $clientId: ${cause.message}"))
else
if (cause.message?.startsWith("onPeerDisconnected") == true) {
XLog.w(this@WebRtcStreamingService.getLog("WebRTCClient.onError", "Client: $clientId: ${cause.message}"))
sendEvent(WebRtcEvent.RemoveClient(clientId, false, "onError:${cause.message}"))
} else {
XLog.e(this@WebRtcStreamingService.getLog("WebRTCClient.onError", "Client: $clientId"), cause)
sendEvent(WebRtcEvent.RemoveClient(clientId, true, "onError:${cause.message}"))
sendEvent(WebRtcEvent.RemoveClient(clientId, true, "onError:${cause.message}"))
}
}
}

Expand Down Expand Up @@ -543,8 +546,11 @@ internal class WebRtcStreamingService(

val prj = projection!!
clients[event.clientId]?.stop()
clients[event.clientId] =
WebRtcClient(event.clientId, prj.peerConnectionFactory, prj.videoCodecs, prj.audioCodecs, webRtcClientEventListener)
clients[event.clientId] = WebRtcClient(
event.clientId, event.iceServers,
prj.peerConnectionFactory, prj.videoCodecs, prj.audioCodecs,
webRtcClientEventListener
)

if (isStreaming()) {
clients[event.clientId]?.start(prj.localMediaSteam!!)
Expand All @@ -558,8 +564,7 @@ internal class WebRtcStreamingService(
return
}

clients[event.clientId]?.stop()
clients.remove(event.clientId)
clients.remove(event.clientId)?.stop()
if (event.notifyServer)
requireNotNull(signaling) { "signaling==null" }
.sendRemoveClients(listOf(event.clientId), "RemoveClient:${event.reason}")
Expand Down

0 comments on commit 26e9862

Please sign in to comment.