Skip to content

Commit

Permalink
Merge remote-tracking branch 'Tim203/feat/nethernet' into playground/…
Browse files Browse the repository at this point in the history
…nethernet
  • Loading branch information
rtm516 committed Aug 19, 2024
2 parents aa357b0 + d26cdd3 commit 95c27f4
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,44 @@

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;

public class CustomDatagramTransport implements DatagramTransport {
private final ComponentSocket socket;
private final DatagramSocket socket;
private final Component component;
private final int receiveLimit = 1500; // Typically, a standard MTU size
private final int sendLimit = 1500; // Typically, a standard MTU size
private final int maxMessageSize = 262144; // vanilla

public CustomDatagramTransport(Component component) {
this.socket = component.getComponentSocket();
this.socket = component.getSocket();
this.component = component;
}

@Override
public int getReceiveLimit() {
return receiveLimit;
return maxMessageSize;
}

@Override
public int getSendLimit() {
return sendLimit;
return maxMessageSize;
}

@Override
public int receive(byte[] buf, int off, int len, int waitMillis) throws IOException {
System.out.println("receive! " + new String(buf, off, len));
DatagramPacket packet = new DatagramPacket(buf, off, len);
socket.receive(packet);
return packet.getLength();
}

@Override
public void send(byte[] buf, int off, int len) throws IOException {
DatagramPacket packet = new DatagramPacket(buf, off, len, component.getDefaultCandidate().getTransportAddress());
socket.send(packet);
System.out.println("send! " + new String(buf, off, len));
socket.send(new DatagramPacket(buf, off, len, component.getDefaultCandidate().getTransportAddress()));
}

@Override
public void close() throws IOException {
public void close() {
socket.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
import java.io.IOException;
import java.math.BigInteger;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.security.KeyPairGenerator;
import java.security.SecureRandom;
import java.security.Security;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -28,8 +29,9 @@
import javax.sdp.MediaDescription;

import org.bouncycastle.asn1.x509.X509Name;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.tls.AlertDescription;
import org.bouncycastle.tls.Certificate;
import org.bouncycastle.tls.CertificateRequest;
import org.bouncycastle.tls.DTLSClientProtocol;
import org.bouncycastle.tls.DefaultTlsClient;
Expand All @@ -38,7 +40,12 @@
import org.bouncycastle.tls.TlsCredentials;
import org.bouncycastle.tls.TlsFatalAlert;
import org.bouncycastle.tls.TlsServerCertificate;
import org.bouncycastle.tls.crypto.TlsCertificate;
import org.bouncycastle.tls.crypto.TlsCryptoParameters;
import org.bouncycastle.tls.crypto.impl.jcajce.JcaDefaultTlsCredentialedSigner;
import org.bouncycastle.tls.crypto.impl.jcajce.JcaTlsCertificate;
import org.bouncycastle.tls.crypto.impl.jcajce.JcaTlsCryptoProvider;
import org.bouncycastle.util.encoders.Hex;
import org.bouncycastle.x509.X509V3CertificateGenerator;
import org.ice4j.Transport;
import org.ice4j.TransportAddress;
Expand All @@ -59,10 +66,6 @@
* Handle the connection and authentication with the RTA websocket
*/
public class RtcWebsocketClient extends WebSocketClient {
static {
Security.addProvider(new BouncyCastleProvider());
}

private final Logger logger;

private RTCConfiguration rtcConfig;
Expand Down Expand Up @@ -132,33 +135,32 @@ private void handleDataAction(BigInteger from, String message) {
if ("CONNECTREQUEST".equals(type)) {
handleConnectRequest(from, sessionId, content);
} else if ("CANDIDATEADD".equals(type)) {
handleCandidateAdd(sessionId, content);
try {
handleCandidateAdd(sessionId, content);
} catch (UnknownHostException e) {
throw new RuntimeException(e);
}
}
}

private void handleConnectRequest(BigInteger from, String sessionId, String message) {
// agent.startCandidateTrickle(iceCandidates -> {
// iceCandidates
// });

try {
var factory = new NistSdpFactory();

var offer = factory.createSessionDescription(message);

String userFragment = "";
String password = "";
String fingerprint;
var stream = agent.createMediaStream("application");
String fingerprint = null;
for (Object mediaDescription : offer.getMediaDescriptions(false)) {
var description = (MediaDescription) mediaDescription;
for (Object descriptionAttribute : description.getAttributes(false)) {
var attribute = (Attribute) descriptionAttribute;
switch (attribute.getName()) {
case "ice-ufrag":
userFragment = attribute.getValue();
stream.setRemoteUfrag(attribute.getValue());
break;
case "ice-pwd":
password = attribute.getValue();
stream.setRemotePassword(attribute.getValue());
break;
case "fragment":
fingerprint = attribute.getValue();
Expand All @@ -167,8 +169,9 @@ private void handleConnectRequest(BigInteger from, String sessionId, String mess
}
}

component.getParentStream().setRemoteUfrag(userFragment);
component.getParentStream().setRemotePassword(password);
agent.startConnectivityEstablishment();

component = agent.createComponent(stream, 5000, 5000, 6000);

var keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
Expand All @@ -182,27 +185,34 @@ private void handleConnectRequest(BigInteger from, String sessionId, String mess
certGen.setSubjectDN(new X509Name("CN=Test Certificate"));
certGen.setPublicKey(keyPair.getPublic());
certGen.setSignatureAlgorithm("SHA256WithRSA");
var cert = certGen.generate(keyPair.getPrivate());

var crypto = new JcaTlsCryptoProvider().create(SecureRandom.getInstanceStrong());
var bcCert = new Certificate(new TlsCertificate[]{new JcaTlsCertificate(crypto, cert)});

String finalFingerprint = fingerprint;
var client = new DefaultTlsClient(crypto) {
@Override
public TlsAuthentication getAuthentication() throws IOException {
return new TlsAuthentication() {
@Override
public void notifyServerCertificate(TlsServerCertificate serverCertificate) throws IOException {
if (serverCertificate == null || serverCertificate.getCertificate() == null) {
throw new TlsFatalAlert(AlertDescription.handshake_failure);
if (serverCertificate == null || serverCertificate.getCertificate() == null || serverCertificate.getCertificate().isEmpty()) {
System.out.println("invalid cert: " + serverCertificate);
throw new TlsFatalAlert(AlertDescription.bad_certificate);
}
// var status =
var cert = serverCertificate.getCertificate().getCertificateAt(0).getEncoded();
var fp = fingerprintFor(cert);

// System.out.println("status type: " + serverCertificate.);
if (!fp.equals(finalFingerprint)) {
System.out.println("fingerprint does not match! expected " + finalFingerprint + " got " + fp);
throw new TlsFatalAlert(AlertDescription.bad_certificate);
}
}

@Override
public TlsCredentials getClientCredentials(CertificateRequest certificateRequest) throws IOException {
// return new JcaDefaultTlsCredentialedSigner();
return null;
public TlsCredentials getClientCredentials(CertificateRequest certificateRequest) {
return new JcaDefaultTlsCredentialedSigner(new TlsCryptoParameters(context), crypto, keyPair.getPrivate(), bcCert, null);
}
};
}
Expand All @@ -226,11 +236,7 @@ protected ProtocolVersion[] getSupportedVersions() {
});

var answer = factory.createSessionDescription();
long answerSessionId = new Random().nextLong();
while (answerSessionId < 0) {
answerSessionId = new Random().nextLong();
}
answer.setOrigin(factory.createOrigin("-", answerSessionId, 2L, "IN", "IP4", "127.0.0.1"));
answer.setOrigin(factory.createOrigin("-", Math.abs(new Random().nextLong()), 2L, "IN", "IP4", "127.0.0.1"));

var attributes = new Vector<>();
attributes.add(factory.createAttribute("group", "BUNDLE 0"));
Expand All @@ -243,6 +249,11 @@ protected ProtocolVersion[] getSupportedVersions() {
media.setAttribute("ice-ufrag", agent.getLocalUfrag());
media.setAttribute("ice-pwd", agent.getLocalPassword());
media.setAttribute("ice-options", "trickle");
media.setAttribute("fingerprint", "sha-256 " + fingerprintFor(cert.getEncoded()));
media.setAttribute("setup", "active");
media.setAttribute("mid", "0");
media.setAttribute("sctp-port", "5000");
media.setAttribute("max-message-size", "262144");
answer.setMediaDescriptions(new Vector<>(Collections.of(media)));

var json = Constants.GSON.toJson(new WsToMessage(
Expand All @@ -251,7 +262,6 @@ protected ProtocolVersion[] getSupportedVersions() {
System.out.println(json);
send(json);


component.getLocalCandidates().forEach(candidate -> {
var jsonAdd = Constants.GSON.toJson(new WsToMessage(
1, from, "CANDIDATEADD " + sessionId + " " + candidate.toString() + " generation 0 ufrag " + agent.getLocalUfrag() + " network-id " + candidate.getFoundation()
Expand All @@ -260,39 +270,73 @@ protected ProtocolVersion[] getSupportedVersions() {
send(jsonAdd);
});

new Thread(() -> {
agent.addStateChangeListener(evt -> {
System.out.println(evt + " " + evt.getPropertyName());
});

stream.addPairChangeListener(evt -> {
System.out.println("pair change! " + evt);
try {
Thread.sleep(2_500);
agent.startConnectivityEstablishment();
} catch (Exception e) {
e.printStackTrace();
new DTLSClientProtocol().connect(client, new CustomDatagramTransport(component));
} catch (IOException e) {
throw new RuntimeException(e);
}
}).start();
});

// new Thread(() -> {
// try {
// Thread.sleep(1000);
// component.getRemoteCandidates().forEach(remoteCandidate -> {
// System.out.println("remote candidate: " + remoteCandidate);
// });
// agent.startConnectivityEstablishment();
// } catch (InterruptedException e) {
// throw new RuntimeException(e);
// }
// }).start();

// } catch (SdpException | FileNotFoundException | CertificateException | NoSuchAlgorithmException e) {
} catch (Exception e) {
throw new RuntimeException(e);
}

System.out.println("LETS GOOOO CONNECTREQUEST!!!!!");
// var session = pendingSession;
// pendingSession = null;
//
// activeSessions.put(sessionId, session);
// session.receiveOffer(from, sessionId, message);
}

private void handleCandidateAdd(String sessionId, String message) {
private String fingerprintFor(byte[] input) {
var digest = new SHA256Digest();
digest.update(input, 0, input.length);
var result = new byte[digest.getDigestSize()];
digest.doFinal(result, 0);

var hexBytes = Hex.encode(result);
String hex = new String(hexBytes, StandardCharsets.US_ASCII).toUpperCase();

var fp = new StringBuilder();
int i = 0;
fp.append(hex, i, i + 2);
while ((i += 2) < hex.length())
{
fp.append(':');
fp.append(hex, i, i + 2);
}
return fp.toString();
}

private void handleCandidateAdd(String sessionId, String message) throws UnknownHostException {
// agent.candidate
// activeSessions.get(sessionId).addCandidate(message);

RemoteCandidate remoteCandidate = parseCandidate(message, component.getParentStream());
component.addRemoteCandidate(remoteCandidate);
component.addUpdateRemoteCandidates(parseCandidate(message, component.getParentStream()));
// component.updateRemoteCandidates();
}

public static RemoteCandidate parseCandidate(String value,
IceMediaStream stream)
{


public static RemoteCandidate parseCandidate(String value, IceMediaStream stream) {
StringTokenizer tokenizer = new StringTokenizer(value);

//XXX add exception handling.
Expand All @@ -303,8 +347,7 @@ public static RemoteCandidate parseCandidate(String value,
String address = tokenizer.nextToken();
int port = Integer.parseInt(tokenizer.nextToken());

TransportAddress transAddr
= new TransportAddress(address, port, transport);
TransportAddress transAddr = new TransportAddress(address, port, transport);

tokenizer.nextToken(); //skip the "typ" String
CandidateType type = CandidateType.parse(tokenizer.nextToken());
Expand All @@ -329,17 +372,13 @@ public static RemoteCandidate parseCandidate(String value,
tokenizer.nextToken(); // skip the rport element
int relatedPort = Integer.parseInt(tokenizer.nextToken());

TransportAddress raddr = new TransportAddress(
val, relatedPort, Transport.UDP);
TransportAddress raddr = new TransportAddress(val, relatedPort, Transport.UDP);

relatedCandidate = component.findRemoteCandidate(raddr);
}
}

RemoteCandidate cand = new RemoteCandidate(transAddr, component, type,
foundation, priority, relatedCandidate, ufrag);

return cand;
return new RemoteCandidate(transAddr, component, type, foundation, priority, relatedCandidate, ufrag);
}

private void initialize(JsonObject message) {
Expand All @@ -358,16 +397,7 @@ private void initialize(JsonObject message) {
// pendingSession = new PeerSession(this, rtcConfig);

agent = new Agent();
agent.addStateChangeListener(evt -> {
logger.info("ICE Agent state changed: " + evt);
});

try {
IceMediaStream stream = agent.createMediaStream("rtcmedia");
component = agent.createComponent(stream, 5000, 5000, 6000);
} catch (IOException e) {
throw new RuntimeException(e);
}
// agent.setTrickling(true);

for (JsonElement authServerElement : turnAuthServers) {
var authServer = authServerElement.getAsJsonObject();
Expand All @@ -379,17 +409,16 @@ private void initialize(JsonObject message) {
String host = parts[1];
int port = Integer.parseInt(parts[2]);

agent.addCandidateHarvester(switch (type) {
case "stun":
yield new StunCandidateHarvester(new TransportAddress(host, port, Transport.UDP));
case "turn":
yield new TurnCandidateHarvester(
if ("stun".equals(type)) {
agent.addCandidateHarvester(new StunCandidateHarvester(new TransportAddress(host, port, Transport.UDP)));
} else if ("turn".equals(type)) {
agent.addCandidateHarvester(new TurnCandidateHarvester(
new TransportAddress(host, port, Transport.UDP),
new LongTermCredential(username, password)
);
default:
throw new IllegalStateException("Unexpected value: " + type);
});
));
} else {
throw new IllegalStateException("Unexpected value: " + type);
}
});
}
}
Expand Down

0 comments on commit 95c27f4

Please sign in to comment.