Skip to content

Commit

Permalink
feature #270: updated event repository to use the event repository to…
Browse files Browse the repository at this point in the history
… get the gorm db connection directly from the container
  • Loading branch information
IshikaGopie committed Jul 24, 2023
1 parent c0ba376 commit cfcfa61
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 27 deletions.
2 changes: 1 addition & 1 deletion controllers/rest/global_initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func DefaultEventStore(ctxt context.Context, tapi Container, swagger *openapi3.S
//if there is a projection then add the event handler as a subscriber to the event store
if api.gormConnection != nil {
var defaultEventStore model.EventRepository
defaultEventStore, err = model.NewBasicEventRepository(api.gormConnection, api.EchoInstance().Logger, false, "", "")
defaultEventStore, err = model.NewBasicEventRepository(api.gormConnection, api.EchoInstance().Logger, false, "", "", tapi)
err = defaultEventStore.Migrate(ctxt)
api.RegisterEventStore("Default", defaultEventStore)
}
Expand Down
50 changes: 28 additions & 22 deletions model/repositories.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ import (
)

type EventRepositoryGorm struct {
DB *gorm.DB
gormDB *gorm.DB
//DB *gorm.DB
//gormDB *gorm.DB
eventDispatcher DefaultEventDisptacher
logger Log
unitOfWork bool
AccountID string
ApplicationID string
GroupID string
UserID string
Container Container
}

type GormEvent struct {
Expand Down Expand Up @@ -66,7 +67,7 @@ func (e *EventRepositoryGorm) Persist(ctxt context.Context, entity AggregateInte
savePointID := "s" + ksuid.New().String() //NOTE the save point can't start with a number
e.logger.Infof("persisting %d events with save point %s", len(entities), savePointID)
if e.unitOfWork {
e.DB.SavePoint(savePointID)
e.DB().SavePoint(savePointID)
}

for _, entity := range entities {
Expand Down Expand Up @@ -95,7 +96,7 @@ func (e *EventRepositoryGorm) Persist(ctxt context.Context, entity AggregateInte
}
if e.unitOfWork {
e.logger.Debugf("rolling back saving events to %s", savePointID)
e.DB.RollbackTo(savePointID)
e.DB().RollbackTo(savePointID)
}

return event.GetErrors()[0]
Expand All @@ -107,7 +108,7 @@ func (e *EventRepositoryGorm) Persist(ctxt context.Context, entity AggregateInte
}
gormEvents = append(gormEvents, gormEvent)
}
db := e.DB.CreateInBatches(gormEvents, 2000)
db := e.DB().CreateInBatches(gormEvents, 2000)
if db.Error != nil {
return db.Error
}
Expand All @@ -124,7 +125,7 @@ func (e *EventRepositoryGorm) Persist(ctxt context.Context, entity AggregateInte
//GetByAggregate get events for a root aggregate
func (e *EventRepositoryGorm) GetByAggregate(ID string) ([]*Event, error) {
var events []GormEvent
result := e.DB.Order("sequence_no asc").Where("root_id = ?", ID).Find(&events)
result := e.DB().Order("sequence_no asc").Where("root_id = ?", ID).Find(&events)
if result.Error != nil {
return nil, result.Error
}
Expand Down Expand Up @@ -156,7 +157,7 @@ func (e *EventRepositoryGorm) GetByAggregate(ID string) ([]*Event, error) {
//events should now be retrieved by root id,entity type and entity id. Use GetByEntityAndAggregate instead
func (e *EventRepositoryGorm) GetByAggregateAndType(ID string, entityType string) ([]*Event, error) {
var events []GormEvent
result := e.DB.Order("sequence_no asc").Where("entity_id = ? AND entity_type = ?", ID, entityType).Find(&events)
result := e.DB().Order("sequence_no asc").Where("entity_id = ? AND entity_type = ?", ID, entityType).Find(&events)
if result.Error != nil {
return nil, result.Error
}
Expand Down Expand Up @@ -184,7 +185,7 @@ func (e *EventRepositoryGorm) GetByAggregateAndType(ID string, entityType string

func (e *EventRepositoryGorm) GetByEntityAndAggregate(EntityID string, Type string, RootID string) ([]*Event, error) {
var events []GormEvent
result := e.DB.Order("sequence_no asc").Where("entity_id = ? AND entity_type = ? AND root_id = ?", EntityID, Type, RootID).Find(&events)
result := e.DB().Order("sequence_no asc").Where("entity_id = ? AND entity_type = ? AND root_id = ?", EntityID, Type, RootID).Find(&events)
if result.Error != nil {
return nil, result.Error
}
Expand Down Expand Up @@ -213,7 +214,7 @@ func (e *EventRepositoryGorm) GetByEntityAndAggregate(EntityID string, Type stri
//GetAggregateSequenceNumber gets the latest sequence number for the aggregate entity
func (e *EventRepositoryGorm) GetAggregateSequenceNumber(ID string) (int64, error) {
var event GormEvent
result := e.DB.Order("sequence_no desc").Where("root_id = ?", ID).Find(&event)
result := e.DB().Order("sequence_no desc").Where("root_id = ?", ID).Find(&event)
if result.Error != nil {
return 0, result.Error
}
Expand All @@ -222,7 +223,7 @@ func (e *EventRepositoryGorm) GetAggregateSequenceNumber(ID string) (int64, erro

func (e *EventRepositoryGorm) GetByAggregateAndSequenceRange(ID string, start int64, end int64) ([]*Event, error) {
var events []GormEvent
result := e.DB.Order("sequence_no asc").Where("entity_id = ? AND sequence_no >=? AND sequence_no <= ?", ID, start, end).Find(&events)
result := e.DB().Order("sequence_no asc").Where("entity_id = ? AND sequence_no >=? AND sequence_no <= ?", ID, start, end).Find(&events)
if result.Error != nil {
return nil, result.Error
}
Expand Down Expand Up @@ -262,7 +263,7 @@ func (e *EventRepositoryGorm) Migrate(ctx context.Context) error {
if err != nil {
return err
}
err = e.DB.AutoMigrate(&event)
err = e.DB().AutoMigrate(&event)
if err != nil {
return err
}
Expand All @@ -271,24 +272,24 @@ func (e *EventRepositoryGorm) Migrate(ctx context.Context) error {
}

func (e *EventRepositoryGorm) Flush() error {
err := e.DB.Commit().Error
e.DB = e.gormDB.Begin()
err := e.DB().Commit().Error
e.DB().Begin()
return err
}

func (e *EventRepositoryGorm) Remove(entities []Entity) error {

savePointID := "s" + ksuid.New().String() //NOTE the save point can't start with a number
e.logger.Infof("persisting %d events with save point %s", len(entities), savePointID)
e.DB.SavePoint(savePointID)
e.DB().SavePoint(savePointID)
for _, event := range entities {
gormEvent, err := NewGormEvent(event.(*Event))
if err != nil {
return err
}
db := e.DB.Delete(gormEvent)
db := e.DB().Delete(gormEvent)
if db.Error != nil {
e.DB.RollbackTo(savePointID)
e.DB().RollbackTo(savePointID)
return db.Error
}
}
Expand All @@ -315,14 +316,14 @@ func (e *EventRepositoryGorm) ReplayEvents(ctxt context.Context, date time.Time,
var events []GormEvent

if date.IsZero() {
result := e.DB.Table("gorm_events").Order("created_at asc").Find(&events)
result := e.DB().Table("gorm_events").Order("created_at asc").Find(&events)
if result.Error != nil {
e.logger.Errorf("got error pulling events '%s'", result.Error)
errors = append(errors, result.Error)
return 0, 0, 0, errors
}
} else {
result := e.DB.Table("gorm_events").Where("created_at = ?", date).Find(&events)
result := e.DB().Table("gorm_events").Where("created_at = ?", date).Find(&events)
if result.Error != nil {
e.logger.Errorf("got error pulling events '%s'", result.Error)
errors = append(errors, result.Error)
Expand Down Expand Up @@ -368,10 +369,15 @@ func (e *EventRepositoryGorm) ReplayEvents(ctxt context.Context, date time.Time,
return totalEvents, successfulEvents, failedEvents, errors
}

func NewBasicEventRepository(gormDB *gorm.DB, logger Log, useUnitOfWork bool, accountID string, applicationID string) (EventRepository, error) {
func (e *EventRepositoryGorm) DB() *gorm.DB {
t, _ := e.Container.GetGormDBConnection("default")
return t
}

func NewBasicEventRepository(gormDB *gorm.DB, logger Log, useUnitOfWork bool, accountID string, applicationID string, api Container) (EventRepository, error) {
if useUnitOfWork {
transaction := gormDB.Begin()
return &EventRepositoryGorm{DB: transaction, gormDB: gormDB, logger: logger, unitOfWork: useUnitOfWork, AccountID: accountID, ApplicationID: applicationID}, nil
//transaction := gormDB.Begin()
return &EventRepositoryGorm{logger: logger, unitOfWork: useUnitOfWork, AccountID: accountID, ApplicationID: applicationID, Container: api}, nil
}
return &EventRepositoryGorm{DB: gormDB, logger: logger, AccountID: accountID, ApplicationID: applicationID}, nil
return &EventRepositoryGorm{logger: logger, AccountID: accountID, ApplicationID: applicationID, Container: api}, nil
}
8 changes: 4 additions & 4 deletions model/repositories_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ func TestEventRepository_ReplayEvents(t *testing.T) {
eventRepo.Persist(newContext, entity3)

t.Run("replay events - drop tables", func(t *testing.T) {
if eventRepo.DB.Migrator().HasTable("Blog") {
err = eventRepo.DB.Migrator().DropTable("Blog")
if eventRepo.DB().Migrator().HasTable("Blog") {
err = eventRepo.DB().Migrator().DropTable("Blog")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -123,12 +123,12 @@ func TestEventRepository_ReplayEvents(t *testing.T) {
t.Run("replay events - remove rows", func(t *testing.T) {
output := map[string]interface{}{}

searchResult := eventRepo.DB.Table("Blog").Where("weos_id = ?", "12345").Delete(&output)
searchResult := eventRepo.DB().Table("Blog").Where("weos_id = ?", "12345").Delete(&output)
if searchResult.Error != nil {
t.Fatal(searchResult.Error)
}

searchResult = eventRepo.DB.Table("Blog").Where("weos_id = ?", "123456").Delete(&output)
searchResult = eventRepo.DB().Table("Blog").Where("weos_id = ?", "123456").Delete(&output)
if searchResult.Error != nil {
t.Fatal(searchResult.Error)
}
Expand Down

0 comments on commit cfcfa61

Please sign in to comment.