diff --git a/v1/subscriber/subscriber.go b/v1/subscriber/subscriber.go index 744b1b8..2096aa1 100644 --- a/v1/subscriber/subscriber.go +++ b/v1/subscriber/subscriber.go @@ -328,19 +328,32 @@ func (p *API) DeleteSubscription(clientID uuid.UUID, subscriptionID string) erro return nil } -// DeleteAllSubscriptions delete all subscriptionOne information -func (p *API) DeleteAllSubscriptions(clientID uuid.UUID) error { - if subStore, ok := p.SubscriberStore.Get(clientID); ok { - if err := deleteAllFromFile(fmt.Sprintf("%s/%s", p.storeFilePath, fmt.Sprintf("%s.json", clientID))); err != nil { - return err +// DeleteAllSubscriptionsForClient delete all subscriptions for the client +func (p *API) DeleteAllSubscriptionsForClient(clientID uuid.UUID) (int, error) { + var err error + var numSubDeleted, numSubToDelete int + if sub, ok := p.SubscriberStore.Get(clientID); ok { + numSubToDelete = len(sub.SubStore.Store) + if err = p.DeleteClient(clientID); err != nil { + return 0, err } - subStore.SubStore = &store.PubSubStore{ - RWMutex: sync.RWMutex{}, - Store: map[string]*pubsub.PubSub{}, + numSubDeleted += numSubToDelete + } + return numSubDeleted, nil +} + +// DeleteAllSubscriptions delete all subscriptions in store +func (p *API) DeleteAllSubscriptions() (int, error) { + var err error + var numSubDeleted, numSubToDelete int + for clientID, subs := range p.SubscriberStore.Store { + numSubToDelete = len(subs.SubStore.Store) + if err = p.DeleteClient(clientID); err != nil { + return numSubDeleted, err } - p.SubscriberStore.Set(clientID, subStore) + numSubDeleted += numSubToDelete } - return nil + return numSubDeleted, nil } // DeleteClient delete all subscriptionOne information diff --git a/v1/subscriber/subscriber_test.go b/v1/subscriber/subscriber_test.go index 1f08efc..30bd7f3 100644 --- a/v1/subscriber/subscriber_test.go +++ b/v1/subscriber/subscriber_test.go @@ -129,7 +129,7 @@ func TestAPI_DeleteAllSubscriptions(t *testing.T) { assert.Nil(t, e) assert.NotEmpty(t, s.ClientID) assert.NotNil(t, s.SubStore.Store) - e = globalInstance.DeleteAllSubscriptions(clientID) + _, e = globalInstance.DeleteAllSubscriptionsForClient(clientID) assert.Nil(t, e) b, e := globalInstance.GetSubscriptionsFromFile(clientID) assert.Nil(t, e) @@ -217,7 +217,7 @@ func Test_Concurrency(t *testing.T) { } func clean() { - _ = globalInstance.DeleteAllSubscriptions(clientID) + globalInstance.DeleteAllSubscriptionsForClient(clientID) //nolint } func TestTeardown(*testing.T) {