diff --git a/spec/test-utils/webrtc.ts b/spec/test-utils/webrtc.ts index 09849d4f32b..ac9148212b9 100644 --- a/spec/test-utils/webrtc.ts +++ b/spec/test-utils/webrtc.ts @@ -104,12 +104,12 @@ export class MockRTCPeerConnection { private negotiationNeededListener: () => void; public iceCandidateListener?: (e: RTCPeerConnectionIceEvent) => void; public onTrackListener?: (e: RTCTrackEvent) => void; - private needsNegotiation = false; + public needsNegotiation = false; public readyToNegotiate: Promise; private onReadyToNegotiate: () => void; localDescription: RTCSessionDescription; signalingState: RTCSignalingState = "stable"; - public senders: MockRTCRtpSender[] = []; + public transceivers: MockRTCRtpTransceiver[] = []; public static triggerAllNegotiations(): void { for (const inst of this.instances) { @@ -169,12 +169,23 @@ export class MockRTCPeerConnection { } close() { } getStats() { return []; } - addTrack(track: MockMediaStreamTrack): MockRTCRtpSender { + addTransceiver(track: MockMediaStreamTrack): MockRTCRtpTransceiver { this.needsNegotiation = true; this.onReadyToNegotiate(); + const newSender = new MockRTCRtpSender(track); - this.senders.push(newSender); - return newSender; + const newReceiver = new MockRTCRtpReceiver(track); + + const newTransceiver = new MockRTCRtpTransceiver(this); + newTransceiver.sender = newSender as unknown as RTCRtpSender; + newTransceiver.receiver = newReceiver as unknown as RTCRtpReceiver; + + this.transceivers.push(newTransceiver); + + return newTransceiver; + } + addTrack(track: MockMediaStreamTrack): MockRTCRtpSender { + return this.addTransceiver(track).sender as unknown as MockRTCRtpSender; } removeTrack() { @@ -182,9 +193,8 @@ export class MockRTCPeerConnection { this.onReadyToNegotiate(); } - getSenders(): MockRTCRtpSender[] { return this.senders; } - - getTransceivers = jest.fn().mockReturnValue([]); + getTransceivers(): MockRTCRtpTransceiver[] { return this.transceivers; } + getSenders(): MockRTCRtpSender[] { return this.transceivers.map(t => t.sender as unknown as MockRTCRtpSender); } doNegotiation() { if (this.needsNegotiation && this.negotiationNeededListener) { @@ -198,7 +208,23 @@ export class MockRTCRtpSender { constructor(public track: MockMediaStreamTrack) { } replaceTrack(track: MockMediaStreamTrack) { this.track = track; } - setCodecPreferences(prefs: RTCRtpCodecCapability[]): void {} +} + +export class MockRTCRtpReceiver { + constructor(public track: MockMediaStreamTrack) { } +} + +export class MockRTCRtpTransceiver { + constructor(private peerConn: MockRTCPeerConnection) {} + + public sender: RTCRtpSender; + public receiver: RTCRtpReceiver; + + public set direction(_: string) { + this.peerConn.needsNegotiation = true; + } + + setCodecPreferences = jest.fn(); } export class MockMediaStreamTrack { diff --git a/spec/unit/webrtc/call.spec.ts b/spec/unit/webrtc/call.spec.ts index e592cba9b67..df9c2aee09b 100644 --- a/spec/unit/webrtc/call.spec.ts +++ b/spec/unit/webrtc/call.spec.ts @@ -41,7 +41,6 @@ import { installWebRTCMocks, MockRTCPeerConnection, SCREENSHARE_STREAM_ID, - MockRTCRtpSender, } from "../../test-utils/webrtc"; import { CallFeed } from "../../../src/webrtc/callFeed"; import { EventType, IContent, ISendEventResponse, MatrixEvent, Room } from "../../../src"; @@ -370,17 +369,15 @@ describe('Call', function() { ).typed(), ); - const usermediaSenders: Array = (call as any).usermediaSenders; + // XXX: Lots of inspecting the prvate state of the call object here + const transceivers: Map = (call as any).transceivers; expect(call.localUsermediaStream.id).toBe("stream"); expect(call.localUsermediaStream.getAudioTracks()[0].id).toBe("new_audio_track"); expect(call.localUsermediaStream.getVideoTracks()[0].id).toBe("video_track"); - expect(usermediaSenders.find((sender) => { - return sender?.track?.kind === "audio"; - }).track.id).toBe("new_audio_track"); - expect(usermediaSenders.find((sender) => { - return sender?.track?.kind === "video"; - }).track.id).toBe("video_track"); + // call has a function for generating these but we hardcode here to avoid exporting it + expect(transceivers.get("m.usermedia:audio").sender.track.id).toBe("new_audio_track"); + expect(transceivers.get("m.usermedia:video").sender.track.id).toBe("video_track"); }); it("should handle upgrade to video call", async () => { @@ -400,16 +397,13 @@ describe('Call', function() { // setLocalVideoMuted probably? await (call as any).upgradeCall(false, true); - const usermediaSenders: Array = (call as any).usermediaSenders; + // XXX: More inspecting private state of the call object + const transceivers: Map = (call as any).transceivers; expect(call.localUsermediaStream.getAudioTracks()[0].id).toBe("usermedia_audio_track"); expect(call.localUsermediaStream.getVideoTracks()[0].id).toBe("usermedia_video_track"); - expect(usermediaSenders.find((sender) => { - return sender?.track?.kind === "audio"; - }).track.id).toBe("usermedia_audio_track"); - expect(usermediaSenders.find((sender) => { - return sender?.track?.kind === "video"; - }).track.id).toBe("usermedia_video_track"); + expect(transceivers.get("m.usermedia:audio").sender.track.id).toBe("usermedia_audio_track"); + expect(transceivers.get("m.usermedia:video").sender.track.id).toBe("usermedia_video_track"); }); it("should handle SDPStreamMetadata changes", async () => { @@ -479,6 +473,23 @@ describe('Call', function() { }); describe("should deduce the call type correctly", () => { + beforeEach(async () => { + // start an incoming call, but add no feeds + await call.initWithInvite({ + getContent: jest.fn().mockReturnValue({ + version: "1", + call_id: "call_id", + party_id: "remote_party_id", + lifetime: CALL_LIFETIME, + offer: { + sdp: DUMMY_SDP, + }, + }), + getSender: () => "@test:foo", + getLocalAge: () => 1, + } as unknown as MatrixEvent); + }); + it("if no video", async () => { call.getOpponentMember = jest.fn().mockReturnValue({ userId: "@bob:bar.uk" }); @@ -1057,9 +1068,24 @@ describe('Call', function() { }); describe("Screen sharing", () => { + const waitNegotiateFunc = resolve => { + mockSendEvent.mockImplementationOnce(() => { + // Note that the peer connection here is a dummy one and always returns + // dummy SDP, so there's not much point returning the content: the SDP will + // always be the same. + resolve(); + return Promise.resolve({ event_id: "foo" }); + }); + }; + beforeEach(async () => { await startVoiceCall(client, call); + const sendNegotiatePromise = new Promise(waitNegotiateFunc); + + MockRTCPeerConnection.triggerAllNegotiations(); + await sendNegotiatePromise; + await call.onAnswerReceived(makeMockEvent("@test:foo", { "version": 1, "call_id": call.callId, @@ -1090,12 +1116,7 @@ describe('Call', function() { ).toHaveLength(1); mockSendEvent.mockReset(); - const sendNegotiatePromise = new Promise(resolve => { - mockSendEvent.mockImplementationOnce(() => { - resolve(); - return Promise.resolve({ event_id: "foo" }); - }); - }); + const sendNegotiatePromise = new Promise(waitNegotiateFunc); MockRTCPeerConnection.triggerAllNegotiations(); await sendNegotiatePromise; @@ -1130,29 +1151,52 @@ describe('Call', function() { headerExtensions: [], }); - const prom = new Promise(resolve => { - const mockPeerConn = call.peerConn as unknown as MockRTCPeerConnection; - mockPeerConn.addTrack = jest.fn().mockImplementation((track: MockMediaStreamTrack) => { - const mockSender = new MockRTCRtpSender(track); - mockPeerConn.getTransceivers.mockReturnValue([{ - sender: mockSender, - setCodecPreferences: (prefs: RTCRtpCodecCapability[]) => { - expect(prefs).toEqual([ - expect.objectContaining({ mimeType: "video/somethingelse" }), - ]); - - resolve(); - }, - }]); + mockSendEvent.mockReset(); + const sendNegotiatePromise = new Promise(waitNegotiateFunc); - return mockSender; - }); - }); + await call.setScreensharingEnabled(true); + MockRTCPeerConnection.triggerAllNegotiations(); + + await sendNegotiatePromise; + + const mockPeerConn = call.peerConn as unknown as MockRTCPeerConnection; + expect( + mockPeerConn.transceivers[mockPeerConn.transceivers.length - 1].setCodecPreferences, + ).toHaveBeenCalledWith([expect.objectContaining({ mimeType: "video/somethingelse" })]); + }); + + it("re-uses transceiver when screen sharing is re-enabled", async () => { + const mockPeerConn = call.peerConn as unknown as MockRTCPeerConnection; + + // sanity check: we should start with one transciever (user media audio) + expect(mockPeerConn.transceivers.length).toEqual(1); + + const screenshareOnProm1 = new Promise(waitNegotiateFunc); + + await call.setScreensharingEnabled(true); + MockRTCPeerConnection.triggerAllNegotiations(); + + await screenshareOnProm1; + + // we should now have another transciever for the screenshare + expect(mockPeerConn.transceivers.length).toEqual(2); + + const screenshareOffProm = new Promise(waitNegotiateFunc); + await call.setScreensharingEnabled(false); + MockRTCPeerConnection.triggerAllNegotiations(); + await screenshareOffProm; + + // both transceivers should still be there + expect(mockPeerConn.transceivers.length).toEqual(2); + const screenshareOnProm2 = new Promise(waitNegotiateFunc); await call.setScreensharingEnabled(true); MockRTCPeerConnection.triggerAllNegotiations(); + await screenshareOnProm2; - await prom; + // should still be two, ie. another one should not have been created + // when re-enabling the screen share. + expect(mockPeerConn.transceivers.length).toEqual(2); }); }); diff --git a/src/webrtc/call.ts b/src/webrtc/call.ts index ce1cae8cd08..1fb6a4c2293 100644 --- a/src/webrtc/call.ts +++ b/src/webrtc/call.ts @@ -308,6 +308,16 @@ export type CallEventHandlerMap = { [CallEvent.SendVoipEvent]: (event: Record) => void; }; +// The key of the transceiver map (purpose + media type, separated by ':') +type TransceiverKey = string; + +// generates keys for the map of transceivers +// kind is unfortunately a string rather than MediaType as this is the type of +// track.kind +function getTransceiverKey(purpose: SDPStreamMetadataPurpose, kind: TransceiverKey): string { + return purpose + ':' + kind; +} + /** * Construct a new Matrix Call. * @constructor @@ -345,8 +355,10 @@ export class MatrixCall extends TypedEventEmitter = []; - private usermediaSenders: Array = []; - private screensharingSenders: Array = []; + + // our transceivers for each purpose and type of media + private transceivers = new Map(); + private inviteOrAnswerSent = false; private waitForLocalAVStream: boolean; private successor: MatrixCall; @@ -634,6 +646,18 @@ export class MatrixCall extends TypedEventEmitter t.receiver.track == track); + this.transceivers.set(getTransceiverKey(purpose, track.kind), transceiver); + } + } + this.emit(CallEvent.FeedsChanged, this.feeds); logger.info( @@ -675,6 +699,12 @@ export class MatrixCall extends TypedEventEmitter t.receiver.track == track); + this.transceivers.set(getTransceiverKey(purpose, track.kind), transceiver); + } + this.emit(CallEvent.FeedsChanged, this.feeds); logger.info(`Call ${this.callId} pushed remote stream (id="${stream.id}", active="${stream.active}")`); @@ -722,11 +752,6 @@ export class MatrixCall extends TypedEventEmitter { return track.kind === "video"; }); - const sender = this.usermediaSenders.find((sender) => { - return sender.track?.kind === "video"; - }); + + const sender = this.transceivers.get(getTransceiverKey( + SDPStreamMetadataPurpose.Usermedia, "video", + )).sender; + sender.replaceTrack(track); this.pushNewLocalFeed(stream, SDPStreamMetadataPurpose.Screenshare, false); @@ -1183,9 +1243,9 @@ export class MatrixCall extends TypedEventEmitter { return track.kind === "video"; }); - const sender = this.usermediaSenders.find((sender) => { - return sender.track?.kind === "video"; - }); + const sender = this.transceivers.get(getTransceiverKey( + SDPStreamMetadataPurpose.Usermedia, "video", + )).sender; sender.replaceTrack(track); this.client.getMediaHandler().stopScreensharingStream(this.localScreensharingStream); @@ -1219,28 +1279,30 @@ export class MatrixCall extends TypedEventEmitter { - return sender.track?.kind === track.kind; - }); + const tKey = getTransceiverKey(SDPStreamMetadataPurpose.Usermedia, track.kind); - let newSender: RTCRtpSender; + const oldSender = this.transceivers.get(tKey)?.sender; + let added = false; + if (oldSender) { + try { + logger.info( + `Call ${this.callId} `+ + `Replacing track (` + + `id="${track.id}", ` + + `kind="${track.kind}", ` + + `streamId="${stream.id}", ` + + `streamPurpose="${callFeed.purpose}"` + + `) to peer connection`, + ); + await oldSender.replaceTrack(track); + added = true; + } catch (error) { + logger.warn(`replaceTrack failed: adding new transceiver instead`, error); + } + } - try { - logger.info( - `Call ${this.callId} `+ - `Replacing track (` + - `id="${track.id}", ` + - `kind="${track.kind}", ` + - `streamId="${stream.id}", ` + - `streamPurpose="${callFeed.purpose}"` + - `) to peer connection`, - ); - await oldSender.replaceTrack(track); - newSender = oldSender; - } catch (error) { + if (!added) { logger.info( `Call ${this.callId} `+ `Adding track (` + @@ -1250,13 +1312,13 @@ export class MatrixCall extends TypedEventEmitter => { diff --git a/src/webrtc/groupCall.ts b/src/webrtc/groupCall.ts index 350215b9a0f..c20592f52f9 100644 --- a/src/webrtc/groupCall.ts +++ b/src/webrtc/groupCall.ts @@ -607,7 +607,9 @@ export class GroupCall extends TypedEventEmitter< return false; } } else { - await Promise.all(this.calls.map(call => call.removeLocalFeed(call.localScreensharingFeed))); + await Promise.all(this.calls.map(call => { + if (call.localScreensharingFeed) call.removeLocalFeed(call.localScreensharingFeed); + })); this.client.getMediaHandler().stopScreensharingStream(this.localScreenshareFeed.stream); this.removeScreenshareFeed(this.localScreenshareFeed); this.localScreenshareFeed = undefined;