diff --git a/database/postgres/pgdatabase.go b/database/postgres/pgdatabase.go index 5bdf6af6..2ac6f539 100644 --- a/database/postgres/pgdatabase.go +++ b/database/postgres/pgdatabase.go @@ -45,10 +45,11 @@ func NewPgDatabase(ctx context.Context, logger hclog.Logger, dsn string, sd sche var _ execution.Storage = (*PgDatabase)(nil) // Insert inserts all resources to given table, table and resources are assumed from same table. -func (p PgDatabase) Insert(ctx context.Context, t *schema.Table, resources schema.Resources) error { +func (p PgDatabase) Insert(ctx context.Context, t *schema.Table, resources schema.Resources, shouldCascade bool, cascadeDeleteFilters map[string]interface{}) error { if len(resources) == 0 { return nil } + // It is safe to assume that all resources have the same columns cols := quoteColumns(resources.ColumnNames()) psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar) @@ -76,10 +77,25 @@ func (p PgDatabase) Insert(ctx context.Context, t *schema.Table, resources schem if err != nil { return diag.NewBaseError(err, diag.DATABASE, diag.WithResourceName(t.Name), diag.WithSummary("bad insert SQL statement created"), diag.WithDetails("SQL statement %q is invalid", s)) } - _, err = p.pool.Exec(ctx, s, args...) + + err = p.pool.BeginTxFunc(ctx, pgx.TxOptions{ + IsoLevel: pgx.ReadCommitted, + AccessMode: pgx.ReadWrite, + DeferrableMode: pgx.Deferrable, + }, func(tx pgx.Tx) error { + if shouldCascade { + if err := deleteResourceByCQId(ctx, tx, resources, cascadeDeleteFilters); err != nil { + return err + } + } + + _, err := tx.Exec(ctx, s, args...) + return err + }) if err == nil { return nil } + if pgErr, ok := err.(*pgconn.PgError); ok { // This should rarely occur, but if it occurs we want to print the SQL to debug it further if pgerrcode.IsSyntaxErrororAccessRuleViolation(pgErr.Code) { @@ -104,16 +120,7 @@ func (p PgDatabase) CopyFrom(ctx context.Context, resources schema.Resources, sh DeferrableMode: pgx.Deferrable, }, func(tx pgx.Tx) error { if shouldCascade { - q := goqu.Dialect("postgres").Delete(resources.TableName()).Where(goqu.Ex{"cq_id": resources.GetIds()}) - for k, v := range cascadeDeleteFilters { - q = q.Where(goqu.Ex{k: goqu.Op{"eq": v}}) - } - sql, args, err := q.Prepared(true).ToSQL() - if err != nil { - return err - } - _, err = tx.Exec(ctx, sql, args...) - if err != nil { + if err := deleteResourceByCQId(ctx, tx, resources, cascadeDeleteFilters); err != nil { return err } } @@ -221,3 +228,16 @@ func quoteColumns(columns []string) []string { } return ret } + +func deleteResourceByCQId(ctx context.Context, tx pgx.Tx, resources schema.Resources, cascadeDeleteFilters map[string]interface{}) error { + q := goqu.Dialect("postgres").Delete(resources.TableName()).Where(goqu.Ex{"cq_id": resources.GetIds()}) + for k, v := range cascadeDeleteFilters { + q = q.Where(goqu.Ex{k: goqu.Op{"eq": v}}) + } + sql, args, err := q.Prepared(true).ToSQL() + if err != nil { + return err + } + _, err = tx.Exec(ctx, sql, args...) + return err +} diff --git a/provider/execution/execution.go b/provider/execution/execution.go index 55387832..e12a9f64 100644 --- a/provider/execution/execution.go +++ b/provider/execution/execution.go @@ -345,7 +345,7 @@ func (e TableExecutor) saveToStorage(ctx context.Context, resources schema.Resou e.Logger.Warn("failed copy-from to db", "error", err) // fallback insert, copy from sometimes does problems, so we fall back with bulk insert - err = e.Db.Insert(ctx, e.Table, resources) + err = e.Db.Insert(ctx, e.Table, resources, shouldCascade, e.extraFields) if err == nil { return resources, nil } @@ -355,7 +355,7 @@ func (e TableExecutor) saveToStorage(ctx context.Context, resources schema.Resou // Try to insert resource by resource if partial fetch is enabled and an error occurred partialFetchResources := make(schema.Resources, 0) for id := range resources { - if err := e.Db.Insert(ctx, e.Table, schema.Resources{resources[id]}); err != nil { + if err := e.Db.Insert(ctx, e.Table, schema.Resources{resources[id]}, shouldCascade, e.extraFields); err != nil { e.Logger.Error("failed to insert resource into db", "error", err, "resource_keys", resources[id].PrimaryKeyValues()) diags = diags.Add(ClassifyError(err, diag.WithType(diag.DATABASE))) continue diff --git a/provider/execution/mocks_test.go b/provider/execution/mocks_test.go index 1f35f662..d6ba8fae 100644 --- a/provider/execution/mocks_test.go +++ b/provider/execution/mocks_test.go @@ -67,12 +67,12 @@ func (_m *DatabaseMock) Exec(ctx context.Context, query string, args ...interfac } // Insert provides a mock function with given fields: ctx, t, instance -func (_m *DatabaseMock) Insert(ctx context.Context, t *schema.Table, instance schema.Resources) error { +func (_m *DatabaseMock) Insert(ctx context.Context, t *schema.Table, instance schema.Resources, shouldCascade bool, cascadeDeleteFilters map[string]interface{}) error { ret := _m.Called(ctx, t, instance) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *schema.Table, schema.Resources) error); ok { - r0 = rf(ctx, t, instance) + if rf, ok := ret.Get(0).(func(context.Context, *schema.Table, schema.Resources, bool, map[string]interface{}) error); ok { + r0 = rf(ctx, t, instance, shouldCascade, cascadeDeleteFilters) } else { r0 = ret.Error(0) } diff --git a/provider/execution/storage.go b/provider/execution/storage.go index 47beb586..064699d7 100644 --- a/provider/execution/storage.go +++ b/provider/execution/storage.go @@ -13,10 +13,10 @@ import ( type Storage interface { QueryExecer Copier - Insert(ctx context.Context, t *schema.Table, instance schema.Resources) error + Insert(ctx context.Context, t *schema.Table, instance schema.Resources, shouldCascade bool, cascadeDeleteFilters map[string]interface{}) error Delete(ctx context.Context, t *schema.Table, kvFilters []interface{}) error RemoveStaleData(ctx context.Context, t *schema.Table, executionStart time.Time, kvFilters []interface{}) error - CopyFrom(ctx context.Context, resources schema.Resources, shouldCascade bool, CascadeDeleteFilters map[string]interface{}) error + CopyFrom(ctx context.Context, resources schema.Resources, shouldCascade bool, cascadeDeleteFilters map[string]interface{}) error Close() Dialect() schema.Dialect } diff --git a/provider/execution/storage_test.go b/provider/execution/storage_test.go index 140100f2..95cf0e3b 100644 --- a/provider/execution/storage_test.go +++ b/provider/execution/storage_test.go @@ -21,7 +21,7 @@ func (f noopStorage) Exec(ctx context.Context, query string, args ...interface{} return nil } -func (f noopStorage) Insert(ctx context.Context, t *schema.Table, instance schema.Resources) error { +func (f noopStorage) Insert(ctx context.Context, t *schema.Table, instance schema.Resources, shouldCascade bool, cascadeDeleteFilters map[string]interface{}) error { return nil } @@ -33,7 +33,7 @@ func (f noopStorage) RemoveStaleData(ctx context.Context, t *schema.Table, execu return nil } -func (f noopStorage) CopyFrom(ctx context.Context, resources schema.Resources, shouldCascade bool, CascadeDeleteFilters map[string]interface{}) error { +func (f noopStorage) CopyFrom(ctx context.Context, resources schema.Resources, shouldCascade bool, cascadeDeleteFilters map[string]interface{}) error { return nil } diff --git a/provider/schema/mock/mock_storage.go b/provider/schema/mock/mock_storage.go index f6c80ad0..447194d8 100644 --- a/provider/schema/mock/mock_storage.go +++ b/provider/schema/mock/mock_storage.go @@ -112,17 +112,17 @@ func (mr *MockStorageMockRecorder) Exec(arg0, arg1 interface{}, arg2 ...interfac } // Insert mocks base method. -func (m *MockStorage) Insert(arg0 context.Context, arg1 *schema.Table, arg2 schema.Resources) error { +func (m *MockStorage) Insert(arg0 context.Context, arg1 *schema.Table, arg2 schema.Resources, arg3 bool, arg4 map[string]interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Insert", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "Insert", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) return ret0 } // Insert indicates an expected call of Insert. -func (mr *MockStorageMockRecorder) Insert(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) Insert(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockStorage)(nil).Insert), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockStorage)(nil).Insert), arg0, arg1, arg2, arg3, arg4) } // Query mocks base method.