Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assorted quality of life tweaks #760

Merged
merged 10 commits into from
May 30, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ target_sources(${QUOTIENT_LIB_NAME} PUBLIC FILE_SET HEADERS BASE_DIRS .
Quotient/networksettings.h
Quotient/converters.h
Quotient/util.h
Quotient/ranges_extras.h
Quotient/eventitem.h
Quotient/accountregistry.h
Quotient/mxcreply.h
Expand Down
42 changes: 26 additions & 16 deletions Quotient/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ QFuture<void> Connection::Private::ensureHomeserver(const QString& userId,
return result;
}

void Connection::logout()
QFuture<void> Connection::logout()
{
// If there's an ongoing sync job, stop it (this also suspends sync loop)
const auto wasSyncing = bool(d->syncJob);
Expand All @@ -368,9 +368,11 @@ void Connection::logout()
}

d->logoutJob = callApi<LogoutJob>();
emit stateChanged(); // isLoggedIn() == false from now
Q_ASSERT(!isLoggedIn()); // Because d->logoutJob is running
emit stateChanged();

connect(d->logoutJob, &LogoutJob::finished, this, [this, wasSyncing] {
QFutureInterface<void> p;
connect(d->logoutJob.get(), &BaseJob::finished, this, [this, wasSyncing, p]() mutable {
if (d->logoutJob->status().good()
|| d->logoutJob->error() == BaseJob::Unauthorised
|| d->logoutJob->error() == BaseJob::ContentAccessError) {
Expand All @@ -381,11 +383,15 @@ void Connection::logout()
emit loggedOut();
deleteLater();
} else { // logout() somehow didn't proceed - restore the session state
Q_ASSERT(isLoggedIn());
emit stateChanged();
if (wasSyncing)
syncLoopIteration(); // Resume sync loop (or a single sync)
p.cancel();
}
p.reportFinished();
});
return p.future();
}

void Connection::sync(int timeout)
Expand Down Expand Up @@ -608,8 +614,14 @@ void Connection::Private::consumePresenceData(Events&& presenceData)

void Connection::Private::consumeToDeviceEvents(Events&& toDeviceEvents)
{
if (encryptionData)
encryptionData->consumeToDeviceEvents(std::move(toDeviceEvents));
if (toDeviceEvents.empty())
return;

qCDebug(E2EE) << "Consuming" << toDeviceEvents.size() << "to-device events";
for (auto&& tdEvt : std::move(toDeviceEvents)) {
if (encryptionData)
encryptionData->consumeToDeviceEvent(std::move(tdEvt));
}
}

void Connection::stopSync()
Expand Down Expand Up @@ -769,13 +781,14 @@ JobHandle<CreateRoomJob> Connection::createRoom(
alias, name, topic, invites, invite3pids, roomVersion,
creationContent, initialState, presetName, isDirect)
.then(this, [this, invites, isDirect](auto* j) {
if (auto* room = provideRoom(j->roomId(), JoinState::Join)) {
emit createdRoom(room);
if (isDirect)
for (const auto& i : invites)
addToDirectChats(room, i);
} else
Q_ASSERT_X(false, "Connection::createRoom", "Failed to create a room");
auto* room = provideRoom(j->roomId(), JoinState::Join);
if (ALARM_X(!room, "Failed to create a room"))
return;

emit createdRoom(room);
if (isDirect)
for (const auto& i : invites)
addToDirectChats(room, i);
});
}

Expand All @@ -787,11 +800,8 @@ void Connection::requestDirectChat(const QString& userId)
QFuture<Room*> Connection::getDirectChat(const QString& otherUserId)
{
auto* u = user(otherUserId);
if (u == nullptr) {
qCCritical(MAIN) << "Connection::getDirectChat: Couldn't get a user object for"
<< otherUserId;
if (ALARM_X(!u, u"Couldn't get a user object for" % otherUserId))
return {};
}

// There can be more than one DC; find the first valid (existing and
// not left), and delete inexistent (forgotten?) ones along the way.
Expand Down
2 changes: 1 addition & 1 deletion Quotient/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ public Q_SLOTS:
//! Explicitly request capabilities from the server
void reloadCapabilities();

void logout();
QFuture<void> logout();

void sync(int timeout = -1);
void syncLoop(int timeout = 30000);
Expand Down
113 changes: 50 additions & 63 deletions Quotient/connectionencryptiondata_p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,30 +309,24 @@ void ConnectionEncryptionData::loadOutdatedUserDevices()
});
}

void ConnectionEncryptionData::consumeToDeviceEvents(Events&& toDeviceEvents)
void ConnectionEncryptionData::consumeToDeviceEvent(EventPtr toDeviceEvent)
{
if (!toDeviceEvents.empty()) {
qCDebug(E2EE) << "Consuming" << toDeviceEvents.size()
<< "to-device events";
for (auto&& tdEvt : std::move(toDeviceEvents)) {
if (processIfVerificationEvent(*tdEvt, false))
continue;
if (auto&& event = eventCast<EncryptedEvent>(std::move(tdEvt))) {
if (event->algorithm() != OlmV1Curve25519AesSha2AlgoKey) {
qCDebug(E2EE) << "Unsupported algorithm" << event->id()
<< "for event" << event->algorithm();
continue;
}
if (isKnownCurveKey(event->senderId(), event->senderKey())) {
handleEncryptedToDeviceEvent(*event);
continue;
}
trackedUsers += event->senderId();
outdatedUsers += event->senderId();
encryptionUpdateRequired = true;
pendingEncryptedEvents.push_back(std::move(event));
}
if (processIfVerificationEvent(*toDeviceEvent, false))
return;
if (auto&& event = eventCast<EncryptedEvent>(std::move(toDeviceEvent))) {
if (event->algorithm() != OlmV1Curve25519AesSha2AlgoKey) {
qCDebug(E2EE) << "Unsupported algorithm" << event->id()
<< "for event" << event->algorithm();
return;
}
if (isKnownCurveKey(event->senderId(), event->senderKey())) {
handleEncryptedToDeviceEvent(*event);
return;
}
trackedUsers += event->senderId();
outdatedUsers += event->senderId();
encryptionUpdateRequired = true;
pendingEncryptedEvents.push_back(std::move(event));
}
}

Expand Down Expand Up @@ -362,8 +356,15 @@ bool ConnectionEncryptionData::processIfVerificationEvent(const Event& evt,
false);
}

void ConnectionEncryptionData::handleEncryptedToDeviceEvent(
const EncryptedEvent& event)
class SecretSendEvent : public Event {
public:
using Event::Event;
QUO_EVENT(SecretSendEvent, "m.secret.send")
QUO_CONTENT_GETTER(QString, requestId)
QUO_CONTENT_GETTER(QString, secret)
};

void ConnectionEncryptionData::handleEncryptedToDeviceEvent(const EncryptedEvent& event)
{
const auto [decryptedEvent, olmSessionId] = sessionDecryptMessage(event);
if (!decryptedEvent) {
Expand All @@ -374,27 +375,22 @@ void ConnectionEncryptionData::handleEncryptedToDeviceEvent(

if (processIfVerificationEvent(*decryptedEvent, true))
return;
switchOnType(
*decryptedEvent,
[this, &event,
olmSessionId = olmSessionId](const RoomKeyEvent& roomKeyEvent) {
decryptedEvent->switchOnType(
[this, &event, olmSessionId](const RoomKeyEvent& roomKeyEvent) {
if (auto* detectedRoom = q->room(roomKeyEvent.roomId())) {
detectedRoom->handleRoomKeyEvent(roomKeyEvent, event.senderId(),
olmSessionId, event.senderKey().toLatin1());
detectedRoom->handleRoomKeyEvent(roomKeyEvent, event.senderId(), olmSessionId,
event.senderKey().toLatin1());
} else {
qCDebug(E2EE)
<< "Encrypted event room id" << roomKeyEvent.roomId()
<< "is not found at the connection" << q->objectName();
}
},
[this](const Event& evt) {
//TODO create an event subclass for this
if (evt.matrixType() == "m.secret.send"_ls) {
emit q->secretReceived(evt.contentPart<QString>("request_id"_ls), evt.contentPart<QString>("secret"_ls));
return;
}
qCWarning(E2EE) << "Skipping encrypted to_device event, type"
<< evt.matrixType();
[this](const SecretSendEvent& sse) {
emit q->secretReceived(sse.requestId(), sse.secret());
},
[](const Event& evt) {
qCWarning(E2EE) << "Skipping encrypted to_device event, type" << evt.matrixType();
});
}

Expand Down Expand Up @@ -830,29 +826,22 @@ std::pair<EventPtr, QByteArray> ConnectionEncryptionData::sessionDecryptMessage(
query.bindValue(":curveKey"_ls, encryptedEvent.senderKey());
database.execute(query);
if (!query.next()) {
qCWarning(E2EE) << "Unknown device while trying to recover from "
"broken olm session";
qCWarning(E2EE) << "Unknown device while trying to recover from broken olm session";
return {};
}
auto senderId = encryptedEvent.senderId();
auto deviceId = query.value("deviceId"_ls).toString();
QHash<QString, QHash<QString, QString>> hash{
{ encryptedEvent.senderId(),
{ { deviceId, "signed_curve25519"_ls } } }
};
auto job = q->callApi<ClaimKeysJob>(hash);
QObject::connect(
job, &BaseJob::finished, q, [this, deviceId, job, senderId] {
if (triedDevices.contains({ senderId, deviceId })) {
return;
}
triedDevices += { senderId, deviceId };
qDebug(E2EE)
<< "Sending dummy event to" << senderId << deviceId;
createOlmSession(senderId, deviceId,
job->oneTimeKeys()[senderId][deviceId]);
q->sendToDevice(senderId, deviceId, DummyEvent(), true);
});
QHash<QString, QHash<QString, QString>> hash{ { encryptedEvent.senderId(),
{ { deviceId, "signed_curve25519"_ls } } } };
q->callApi<ClaimKeysJob>(hash).then(q, [this, deviceId, senderId](const auto* job) {
if (triedDevices.contains({ senderId, deviceId })) {
return;
}
triedDevices += { senderId, deviceId };
qDebug(E2EE) << "Sending dummy event to" << senderId << deviceId;
createOlmSession(senderId, deviceId, job->oneTimeKeys()[senderId][deviceId]);
q->sendToDevice(senderId, deviceId, DummyEvent(), true);
});
return {};
}

Expand All @@ -861,8 +850,8 @@ std::pair<EventPtr, QByteArray> ConnectionEncryptionData::sessionDecryptMessage(

if (auto sender = decryptedEvent->fullJson()[SenderKey].toString();
sender != encryptedEvent.senderId()) {
qWarning(E2EE) << "Found user" << sender << "instead of sender"
<< encryptedEvent.senderId() << "in Olm plaintext";
qWarning(E2EE) << "Found user" << sender << "instead of sender" << encryptedEvent.senderId()
<< "in Olm plaintext";
return {};
}

Expand All @@ -872,12 +861,10 @@ std::pair<EventPtr, QByteArray> ConnectionEncryptionData::sessionDecryptMessage(
query.bindValue(":curveKey"_ls, senderKey);
database.execute(query);
if (!query.next()) {
qWarning(E2EE) << "Received olm message from unknown device"
<< senderKey;
qWarning(E2EE) << "Received olm message from unknown device" << senderKey;
return {};
}
if (auto edKey =
decryptedEvent->fullJson()["keys"_ls][Ed25519Key].toString();
if (auto edKey = decryptedEvent->fullJson()["keys"_ls][Ed25519Key].toString();
edKey.isEmpty() || query.value("edKey"_ls).toString() != edKey) //
{
qDebug(E2EE) << "Received olm message with invalid ed key";
Expand Down
5 changes: 2 additions & 3 deletions Quotient/connectionencryptiondata_p.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace _impl {

void onSyncSuccess(SyncData &syncResponse);
void loadOutdatedUserDevices();
void consumeToDeviceEvents(Events&& toDeviceEvents);
void consumeToDeviceEvent(EventPtr toDeviceEvent);
void encryptionUpdate(const QList<QString>& forUsers);

bool createOlmSession(const QString& targetUserId,
Expand All @@ -68,8 +68,7 @@ namespace _impl {
std::pair<QByteArray, QByteArray> sessionDecryptMessage(
const QJsonObject& personalCipherObject,
const QByteArray& senderKey);
std::pair<EventPtr, QByteArray> sessionDecryptMessage(
const EncryptedEvent& encryptedEvent);
std::pair<EventPtr, QByteArray> sessionDecryptMessage(const EncryptedEvent& encryptedEvent);

QJsonObject assembleEncryptedContent(
QJsonObject payloadJson, const QString& targetUserId,
Expand Down
46 changes: 29 additions & 17 deletions Quotient/converters.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,26 @@
class QVariant;

namespace Quotient {

inline void editSubobject(QJsonObject& json, auto key, std::invocable<QJsonObject&> auto visitor)
{
auto subObject = json.take(key).toObject();
visitor(subObject);
json.insert(key, subObject);
}

inline void replaceSubvalue(QJsonObject& json, auto topLevelKey, auto subKey, QJsonValue subValue)
{
editSubobject(json, topLevelKey, [subKey, subValue](QJsonObject& innerJson) {
innerJson.insert(subKey, subValue);
});
}

template <typename T>
struct JsonObjectConverter {
// To be implemented in specialisations
static void dumpTo(QJsonObject&, const T&) = delete;
static void fillFrom(const QJsonObject&, T&) = delete;
};
struct JsonObjectConverter;
// Specialisations should implement either or both of:
//static void dumpTo(QJsonObject&, const T&); // For toJson() and fillJson() to work
//static void fillFrom(const QJsonObject&, T&); // For fromJson() and fillFromJson() to work

template <typename PodT, typename JsonT>
PodT fromJson(const JsonT&);
Expand Down Expand Up @@ -133,9 +147,8 @@ inline void fillFromJson(const QJsonValue& jv, T& pod)
}

namespace _impl {
void warnUnknownEnumValue(const QString& stringValue,
const char* enumTypeName);
void reportEnumOutOfBounds(uint32_t v, const char* enumTypeName);
QUOTIENT_API void warnUnknownEnumValue(const QString& stringValue, const char* enumTypeName);
QUOTIENT_API void reportEnumOutOfBounds(uint32_t v, const char* enumTypeName);
}

//! \brief Facility string-to-enum converter
Expand All @@ -147,12 +160,11 @@ namespace _impl {
//! matching respective enum values, 0-based.
//! \sa enumToJsonString
template <typename EnumT, typename EnumStringValuesT>
EnumT enumFromJsonString(const QString& s, const EnumStringValuesT& enumValues,
EnumT defaultValue)
inline EnumT enumFromJsonString(const QString& s, const EnumStringValuesT& enumValues,
EnumT defaultValue)
{
static_assert(std::is_unsigned_v<std::underlying_type_t<EnumT>>);
if (const auto it = std::find(cbegin(enumValues), cend(enumValues), s);
it != cend(enumValues))
if (const auto it = std::ranges::find(enumValues, s); it != cend(enumValues))
return static_cast<EnumT>(it - cbegin(enumValues));

if (!s.isEmpty())
Expand All @@ -170,7 +182,7 @@ EnumT enumFromJsonString(const QString& s, const EnumStringValuesT& enumValues,
//! }</tt> (mind the gap at value 0, in particular).
//! \sa enumFromJsonString
template <typename EnumT, typename EnumStringValuesT>
QString enumToJsonString(EnumT v, const EnumStringValuesT& enumValues)
inline QString enumToJsonString(EnumT v, const EnumStringValuesT& enumValues)
{
static_assert(std::is_unsigned_v<std::underlying_type_t<EnumT>>);
if (v < size(enumValues))
Expand All @@ -193,8 +205,8 @@ QString enumToJsonString(EnumT v, const EnumStringValuesT& enumValues)
//! \note This function does not support flag combinations.
//! \sa QUO_DECLARE_FLAGS, QUO_DECLARE_FLAGS_NS
template <typename FlagT, typename FlagStringValuesT>
FlagT flagFromJsonString(const QString& s, const FlagStringValuesT& flagValues,
FlagT defaultValue = FlagT(0U))
inline FlagT flagFromJsonString(const QString& s, const FlagStringValuesT& flagValues,
FlagT defaultValue = FlagT(0U))
{
// Enums based on signed integers don't make much sense for flag types
static_assert(std::is_unsigned_v<std::underlying_type_t<FlagT>>);
Expand All @@ -207,7 +219,7 @@ FlagT flagFromJsonString(const QString& s, const FlagStringValuesT& flagValues,
}

template <typename FlagT, typename FlagStringValuesT>
QString flagToJsonString(FlagT v, const FlagStringValuesT& flagValues)
inline QString flagToJsonString(FlagT v, const FlagStringValuesT& flagValues)
{
static_assert(std::is_unsigned_v<std::underlying_type_t<FlagT>>);
if (const auto offset = std::countr_zero(std::to_underlying(v)); offset < ssize(flagValues))
Expand All @@ -220,7 +232,7 @@ QString flagToJsonString(FlagT v, const FlagStringValuesT& flagValues)

// Specialisations

template<>
template <>
inline bool fromJson(const QJsonValue& jv) { return jv.toBool(); }

template <>
Expand Down
Loading