Skip to content

Commit

Permalink
Experimental GATT refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
shermp committed Sep 27, 2024
1 parent e39644e commit 8b6d16a
Show file tree
Hide file tree
Showing 7 changed files with 428 additions and 264 deletions.
287 changes: 105 additions & 182 deletions src/asha_bt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@ void ScanResult::reset()
service_found = false;
ha = HA();
report = AdvertisingReport();
services.clear();
services_it = services.end();
}

static void handle_bt_audio_pending_worker([[maybe_unused]] async_context_t *context,
Expand All @@ -166,14 +164,69 @@ static void handle_bt_audio_pending_worker([[maybe_unused]] async_context_t *con
AudioBuffer::Volume vol = audio_buff.get_volume();
bool pcm_is_streaming = audio_buff.pcm_streaming.Load();
for (auto& ha : ha_mgr.hearing_aids) {
ha.avail_credits = l2cap_cbm_available_credits(ha.cid);
if (ha.state == L2Connected) {
switch(ha.state) {
case ServicesDiscovered:
ha.state = DiscoverASHAChars;
ha.discover_chars();
break;
case ASHACharsDiscovered:
ha.state = ReadROP;
ha.read_char();
break;
case ROPRead:
ha.state = ReadPSM;
ha.read_char();
break;
case PSMRead:
ha.state = DiscoverGAPChars;
ha.discover_chars();
break;
case GAPCharsDiscovered:
ha.state = DiscoverDISChars;
ha.discover_chars();
break;
case DISCharsDiscovered:
ha.state = ReadDeviceName;
ha.read_char();
break;
case DeviceNameRead:
ha.state = ReadManufacturerName;
ha.read_char();
break;
case ManufacturerNameRead:
ha.state = ReadModelNum;
ha.read_char();
break;
case ModelNumRead:
ha.state = ReadFWVers;
ha.read_char();
break;
case FWVersRead:
ha.state = SubscribeASPNotification;
ha.subscribe_to_asp_notification();
break;
case ASPNotificationSubscribed:
if (ha_mgr.set_complete()) {
runtime_settings.set_full_set_paired(true);
led_mgr.set_led(LEDManager::State::On);
scan_state = ScanState::Complete;
}
ha.state = L2Connecting;
ha.create_l2cap_channel();
break;
case L2Connected:
ha.avail_credits = l2cap_cbm_available_credits(ha.cid);
/* Ensure sufficient credits are available to (re)start
audio streaming */
if (pcm_is_streaming && ha.avail_credits >= 8) {
ha.write_acp_start();
}
} else if (ha.is_streaming_audio()) {
break;
case AudioPacketReady:
case AudioPacketSending:
case AudioPacketSent:
case ASPStartOk:
ha.avail_credits = l2cap_cbm_available_credits(ha.cid);
if (!pcm_is_streaming) {
LOG_INFO("%s: USB audio no longer streaming. Stopping ASHA stream", ha.side_str);
ha.write_acp_stop();
Expand All @@ -187,6 +240,9 @@ static void handle_bt_audio_pending_worker([[maybe_unused]] async_context_t *con
ha.set_write_index(write_index);
ha.set_volume(vol);
ha.send_audio_packet();
break;
default:
break;
}
}
}
Expand Down Expand Up @@ -298,6 +354,9 @@ extern "C" void bt_main()
async_context_add_when_pending_worker(ctx, &logging_pending_worker);
logging_ctx = ctx;

// gatt_worker.do_work = handle_gatt_worker;
// gatt_ctx = ctx;

led_mgr.set_ctx(ctx);

if (runtime_settings.hci_dump_enabled) {
Expand Down Expand Up @@ -385,22 +444,14 @@ static void set_data_length()
static void discover_services()
{
if (scan_state != ScanState::ServiceDiscovery) return;
HA cached = ha_mgr.get_from_cache(curr_scan.ha.addr);
if (cached) {
LOG_INFO("Hearing aid found in cache. Skipping discovery");
cached.conn_handle = curr_scan.ha.conn_handle;
curr_scan.ha = cached;
scan_state = ScanState::Finalizing;
finalise_curr_discovery();
} else {
LOG_INFO("Device paired. Discovering ASHA service");
auto res = gatt_client_discover_primary_services(&scan_gatt_event_handler, curr_scan.ha.conn_handle);
//auto res = gatt_client_discover_primary_services_by_uuid128(&scan_gatt_event_handler, curr_scan.ha.conn_handle, AshaUUID::service);
if (res != ERROR_CODE_SUCCESS) {
LOG_ERROR("Could not register service query: %d", static_cast<int>(res));
scan_state = ScanState::Disconnecting;
gap_disconnect(curr_scan.ha.conn_handle);
}

LOG_INFO("Device paired. Discovering ASHA service");
auto res = gatt_client_discover_primary_services(&scan_gatt_event_handler, curr_scan.ha.conn_handle);
//auto res = gatt_client_discover_primary_services_by_uuid128(&scan_gatt_event_handler, curr_scan.ha.conn_handle, AshaUUID::service);
if (res != ERROR_CODE_SUCCESS) {
LOG_ERROR("Could not register service query: %d", static_cast<int>(res));
scan_state = ScanState::Disconnecting;
gap_disconnect(curr_scan.ha.conn_handle);
}
}

Expand Down Expand Up @@ -678,24 +729,6 @@ static void l2cap_cbm_event_handler (uint8_t packet_type,
}
}

#define GATT_QUERY_ASSERT(cmd, msg) { \
auto res = (cmd); \
if (res != ERROR_CODE_SUCCESS) { \
LOG_ERROR(msg ": 0x%02x", static_cast<unsigned int>(res)); \
scan_state = ScanState::Disconnecting; \
gap_disconnect(curr_scan.ha.conn_handle); \
return; \
}}

#define GATT_COMPLETE_ASSERT(pkt, msg) { \
if (gatt_event_query_complete_get_att_status((pkt)) != ATT_ERROR_SUCCESS) { \
LOG_ERROR("%s: 0x%02x", (msg), static_cast<unsigned int>(gatt_event_query_complete_get_att_status((pkt)))); \
scan_state = ScanState::Disconnecting; \
gap_disconnect(curr_scan.ha.conn_handle); \
return; \
} \
}

/* Handler for reading GATT service and characteristic values
and subscribing to the AudioStatusPoint characteristic notification
during the connection process.
Expand All @@ -716,153 +749,36 @@ static void scan_gatt_event_handler ([[maybe_unused]] uint8_t packet_type,
if (service.uuid16 == AshaUUID::service16 || uuid_eq(service.uuid128, AshaUUID::service)) {
memcpy(&curr_scan.ha.asha_service.service, &service, sizeof(service));
curr_scan.service_found = true;
curr_scan.services.push_back(&curr_scan.ha.asha_service.service);
LOG_INFO("ASHA service found");
} else if (service.uuid16 == GapUUID::service16) {
LOG_INFO("GAP service found");
memcpy(&curr_scan.ha.gap_service.service, &service, sizeof(service));
curr_scan.services.push_back(&curr_scan.ha.gap_service.service);
} else if (service.uuid16 == DisUUID::service16) {
LOG_INFO("DIS service found");
memcpy(&curr_scan.ha.dis_service.service, &service, sizeof(service));
}
break;
case GATT_EVENT_QUERY_COMPLETE:
{
GATT_COMPLETE_ASSERT(packet, "ATT error discovering services");
if (gatt_event_query_complete_get_att_status(packet) != ATT_ERROR_SUCCESS) {
LOG_ERROR("Service discovery ATT error: 0x%02x", static_cast<unsigned int>(gatt_event_query_complete_get_att_status(packet)));
scan_state = ScanState::Disconnecting;
gap_disconnect(curr_scan.ha.conn_handle);
return;
}
// Older hearing aids may support MFI but not ASHA
if (!curr_scan.service_found || curr_scan.services.empty()) {
if (!curr_scan.service_found) {
LOG_INFO("ASHA service not found. Continuing scanning");
scan_state = ScanState::Disconnecting;
gap_disconnect(curr_scan.ha.conn_handle);
break;
}
curr_scan.services_it = curr_scan.services.begin();
scan_state = ScanState::CharDiscovery;
// Service found. Discover characteristics
LOG_INFO("Discovering characteristics for found services");
GATT_QUERY_ASSERT(gatt_client_discover_characteristics_for_service(
&scan_gatt_event_handler, curr_scan.ha.conn_handle, *curr_scan.services_it),
"Could not register GATT characteristics query");
break;
}
}
break;
}
case ScanState::CharDiscovery:
{
gatt_client_characteristic_t characteristic;
switch (hci_event_packet_get_type(packet)) {
case GATT_EVENT_CHARACTERISTIC_QUERY_RESULT:
gatt_event_characteristic_query_result_get_characteristic(packet, &characteristic);
if (characteristic.uuid16 == GapUUID::deviceName16) {
LOG_INFO("Got Device Name Characteristic");
curr_scan.ha.gap_service.device_name = characteristic;
} else if (uuid_eq(characteristic.uuid128, AshaUUID::readOnlyProps)) {
LOG_INFO("Got ROP Characteristic");
curr_scan.ha.asha_service.rop = characteristic;
} else if (uuid_eq(characteristic.uuid128, AshaUUID::audioControlPoint)) {
LOG_INFO("Got ACP Characteristic");
curr_scan.ha.asha_service.acp = characteristic;
} else if (uuid_eq(characteristic.uuid128, AshaUUID::audioStatus)) {
LOG_INFO("Got AUS Characteristic");
curr_scan.ha.asha_service.asp = characteristic;
} else if (uuid_eq(characteristic.uuid128, AshaUUID::volume)) {
LOG_INFO("Got VOL Characteristic");
curr_scan.ha.asha_service.vol = characteristic;
} else if (uuid_eq(characteristic.uuid128, AshaUUID::psm)) {
LOG_INFO("Got PSM Characteristic");
curr_scan.ha.asha_service.psm = characteristic;
}
// LOG_INFO("Characteristic handles: Start: 0x%04hx Value: 0x%04hx End: 0x%04hx",
// characteristic.start_handle, characteristic.value_handle, characteristic.end_handle);
break;
case GATT_EVENT_QUERY_COMPLETE:
GATT_COMPLETE_ASSERT(packet, "ATT error discovering characteristics");
LOG_INFO("Characteristic discovery complete");
curr_scan.services_it++;
if (curr_scan.services_it != curr_scan.services.end()) {
LOG_INFO("Continue discovering more characteristics");

GATT_QUERY_ASSERT(gatt_client_discover_characteristics_for_service(
&scan_gatt_event_handler, curr_scan.ha.conn_handle, *curr_scan.services_it
), "Could not register GATT characteristics query");
return;
} else {
LOG_INFO("Characteristic discovery complete");
}
scan_state = ScanState::ReadDeviceName;
GATT_QUERY_ASSERT(gatt_client_read_value_of_characteristic(
&scan_gatt_event_handler,
curr_scan.ha.conn_handle,
&curr_scan.ha.gap_service.device_name
), "Could not register read of device name");
break;
}
break;
}
case ScanState::ReadDeviceName:
{
switch (hci_event_packet_get_type(packet)) {
case GATT_EVENT_CHARACTERISTIC_VALUE_QUERY_RESULT:
{
LOG_INFO("Getting Device Name value");
size_t len = (size_t)gatt_event_characteristic_value_query_result_get_value_length(packet);
LOG_INFO("Name length: %u", len);
curr_scan.ha.name.clear();
curr_scan.ha.name.append((const char*)gatt_event_characteristic_value_query_result_get_value(packet), len);
break;
}
case GATT_EVENT_QUERY_COMPLETE:
GATT_COMPLETE_ASSERT(packet, "ATT error reading device name");
LOG_INFO("Completed value read of Device Name");
// Start reading the Read Only Properties characteristic
scan_state = ScanState::ReadROP;
GATT_QUERY_ASSERT(gatt_client_read_value_of_characteristic(
&scan_gatt_event_handler,
curr_scan.ha.conn_handle,
&curr_scan.ha.asha_service.rop
), "Could not register read of ROP");
break;
}
break;
}
case ScanState::ReadROP:
{
switch (hci_event_packet_get_type(packet)) {
case GATT_EVENT_CHARACTERISTIC_VALUE_QUERY_RESULT:
LOG_INFO("Getting ReadOnlyProperties value");
curr_scan.ha.rop.read(gatt_event_characteristic_value_query_result_get_value(packet));
//curr_scan.device.read_only_props.dump_values();
break;
case GATT_EVENT_QUERY_COMPLETE:
GATT_COMPLETE_ASSERT(packet, "ATT error reading ROP");
LOG_INFO("Completed value read of ReadOnlyProperties");
/* Next get the PSM value */
scan_state = ScanState::ReadPSM;
GATT_QUERY_ASSERT(gatt_client_read_value_of_characteristic(
&scan_gatt_event_handler,
curr_scan.ha.conn_handle,
&curr_scan.ha.asha_service.psm
), "Could not register read of PSM");
break;
}
break;
}
case ScanState::ReadPSM:
{
switch (hci_event_packet_get_type(packet)) {
case GATT_EVENT_CHARACTERISTIC_VALUE_QUERY_RESULT:
LOG_INFO("Getting PSM value");
curr_scan.ha.psm = gatt_event_characteristic_value_query_result_get_value(packet)[0];
LOG_INFO("PSM: %d", static_cast<int>(curr_scan.ha.psm));
break;
case GATT_EVENT_QUERY_COMPLETE:
GATT_COMPLETE_ASSERT(packet, "ATT error reading PSM");
LOG_INFO("Completed value read of PSM");
curr_scan.ha.rop.print_values();
scan_state = ScanState::Finalizing;
curr_scan.ha.state = HA::State::ServicesDiscovered;
finalise_curr_discovery();
break;
}
break;
default:
break;
};
}
default:
break;
Expand All @@ -875,6 +791,22 @@ static void connected_gatt_event_handler([[maybe_unused]] uint8_t packet_type,
[[maybe_unused]] uint16_t size)
{
switch (hci_event_packet_get_type(packet)) {
case GATT_EVENT_CHARACTERISTIC_QUERY_RESULT:
{
HA& ha = ha_mgr.get_by_conn_handle(gatt_event_characteristic_query_result_get_handle(packet));
if (ha) {
ha.on_char_discovered(packet);
}
break;
}
case GATT_EVENT_CHARACTERISTIC_VALUE_QUERY_RESULT:
{
HA& ha = ha_mgr.get_by_conn_handle(gatt_event_characteristic_value_query_result_get_handle(packet));
if (ha) {
ha.on_read_char_value(packet);
}
break;
}
case GATT_EVENT_NOTIFICATION:
{
HA& ha = ha_mgr.get_by_conn_handle(gatt_event_notification_get_handle(packet));
Expand All @@ -891,7 +823,7 @@ static void connected_gatt_event_handler([[maybe_unused]] uint8_t packet_type,
{
HA& ha = ha_mgr.get_by_conn_handle(gatt_event_query_complete_get_handle(packet));
if (ha) {
ha.on_gatt_event_query_complete(gatt_event_query_complete_get_att_status(packet));
ha.on_gatt_event_query_complete(packet);
}
break;
}
Expand All @@ -911,7 +843,6 @@ static void finalise_curr_discovery()
scan_state = ScanState::Complete;
return;
}
curr_scan.ha.state = HA::State::GATTConnected;
curr_scan.ha.l2cap_packet_handler = &l2cap_cbm_event_handler;
curr_scan.ha.gatt_packet_handler = &connected_gatt_event_handler;
auto& ha = ha_mgr.add(curr_scan.ha);
Expand All @@ -921,16 +852,8 @@ static void finalise_curr_discovery()
scan_state = ScanState::Scan;
return;
}
ha.on_gatt_connected();
if (ha_mgr.set_complete()) {
LOG_INFO("Connected to all aid(s) in set.");
runtime_settings.set_full_set_paired(true);
led_mgr.set_led(LEDManager::State::On);
scan_state = ScanState::Complete;
} else {
led_mgr.set_led_pattern(one_connected);
scan_state = ScanState::Scan;
}
led_mgr.set_led_pattern(one_connected);
scan_state = ScanState::Scan;
}

static void delete_paired_devices()
Expand Down
Loading

0 comments on commit 8b6d16a

Please sign in to comment.