Skip to content

Commit

Permalink
Fix bugs in CloudFlare update logic, fixes #1514
Browse files Browse the repository at this point in the history
  • Loading branch information
sheerun committed Apr 22, 2020
1 parent cbb84be commit c004731
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 93 deletions.
139 changes: 76 additions & 63 deletions provider/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"
"fmt"
"os"
"sort"
"strconv"
"strings"

Expand Down Expand Up @@ -111,8 +110,8 @@ type CloudFlareProvider struct {

// cloudFlareChange differentiates between ChangActions
type cloudFlareChange struct {
Action string
ResourceRecordSet []cloudflare.DNSRecord
Action string
ResourceRecord cloudflare.DNSRecord
}

// NewCloudFlareProvider initializes a new CloudFlare DNS based Provider.
Expand Down Expand Up @@ -199,15 +198,39 @@ func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint,

// ApplyChanges applies a given set of changes in a given zone.
func (p *CloudFlareProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
proxiedByDefault := p.proxiedByDefault
cloudflareChanges := []*cloudFlareChange{}

combinedChanges := make([]*cloudFlareChange, 0, len(changes.Create)+len(changes.UpdateNew)+len(changes.Delete))
for _, endpoint := range changes.Create {
for _, target := range endpoint.Targets {
cloudflareChanges = append(cloudflareChanges, p.newCloudFlareChange(cloudFlareCreate, endpoint, target))
}
}

for i, desired := range changes.UpdateNew {
current := changes.UpdateOld[i]

add, remove, leave := difference(current.Targets, desired.Targets)

for _, a := range add {
cloudflareChanges = append(cloudflareChanges, p.newCloudFlareChange(cloudFlareCreate, desired, a))
}

for _, a := range leave {
cloudflareChanges = append(cloudflareChanges, p.newCloudFlareChange(cloudFlareUpdate, desired, a))
}

combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareCreate, changes.Create, proxiedByDefault)...)
combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareUpdate, changes.UpdateNew, proxiedByDefault)...)
combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareDelete, changes.Delete, proxiedByDefault)...)
for _, a := range remove {
cloudflareChanges = append(cloudflareChanges, p.newCloudFlareChange(cloudFlareDelete, current, a))
}
}

for _, endpoint := range changes.Delete {
for _, target := range endpoint.Targets {
cloudflareChanges = append(cloudflareChanges, p.newCloudFlareChange(cloudFlareDelete, endpoint, target))
}
}

return p.submitChanges(ctx, combinedChanges)
return p.submitChanges(ctx, cloudflareChanges)
}

// submitChanges takes a zone and a collection of Changes and sends them as a single transaction.
Expand All @@ -231,12 +254,11 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud
}
for _, change := range changes {
logFields := log.Fields{
"record": change.ResourceRecordSet[0].Name,
"type": change.ResourceRecordSet[0].Type,
"ttl": change.ResourceRecordSet[0].TTL,
"targets": len(change.ResourceRecordSet),
"action": change.Action,
"zone": zoneID,
"record": change.ResourceRecord.Name,
"type": change.ResourceRecord.Type,
"ttl": change.ResourceRecord.TTL,
"action": change.Action,
"zone": zoneID,
}

log.WithFields(logFields).Info("Changing record.")
Expand All @@ -245,24 +267,30 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud
continue
}

recordIDs := p.getRecordIDs(records, change.ResourceRecordSet[0])

// to simplify bookkeeping for multiple records, an update is executed as delete+create
if change.Action == cloudFlareDelete || change.Action == cloudFlareUpdate {
for _, recordID := range recordIDs {
err := p.Client.DeleteDNSRecord(zoneID, recordID)
if err != nil {
log.WithFields(logFields).Errorf("failed to delete record: %v", err)
}
if change.Action == cloudFlareUpdate {
recordID := p.getRecordID(records, change.ResourceRecord)
if recordID == "" {
log.WithFields(logFields).Errorf("failed to find previous record: %v", change.ResourceRecord)
continue
}
}

if change.Action == cloudFlareCreate || change.Action == cloudFlareUpdate {
for _, record := range change.ResourceRecordSet {
_, err := p.Client.CreateDNSRecord(zoneID, record)
if err != nil {
log.WithFields(logFields).Errorf("failed to create record: %v", err)
}
err := p.Client.UpdateDNSRecord(zoneID, recordID, change.ResourceRecord)
if err != nil {
log.WithFields(logFields).Errorf("failed to delete record: %v", err)
}
} else if change.Action == cloudFlareDelete {
recordID := p.getRecordID(records, change.ResourceRecord)
if recordID == "" {
log.WithFields(logFields).Errorf("failed to find previous record: %v", change.ResourceRecord)
continue
}
err := p.Client.DeleteDNSRecord(zoneID, recordID)
if err != nil {
log.WithFields(logFields).Errorf("failed to delete record: %v", err)
}
} else if change.Action == cloudFlareCreate {
_, err := p.Client.CreateDNSRecord(zoneID, change.ResourceRecord)
if err != nil {
log.WithFields(logFields).Errorf("failed to create record: %v", err)
}
}
}
Expand All @@ -281,9 +309,9 @@ func (p *CloudFlareProvider) changesByZone(zones []cloudflare.Zone, changeSet []
}

for _, c := range changeSet {
zoneID, _ := zoneNameIDMapper.FindZone(c.ResourceRecordSet[0].Name)
zoneID, _ := zoneNameIDMapper.FindZone(c.ResourceRecord.Name)
if zoneID == "" {
log.Debugf("Skipping record %s because no hosted zone matching record DNS Name was detected", c.ResourceRecordSet[0].Name)
log.Debugf("Skipping record %s because no hosted zone matching record DNS Name was detected", c.ResourceRecord.Name)
continue
}
changes[zoneID] = append(changes[zoneID], c)
Expand All @@ -292,51 +320,36 @@ func (p *CloudFlareProvider) changesByZone(zones []cloudflare.Zone, changeSet []
return changes
}

func (p *CloudFlareProvider) getRecordIDs(records []cloudflare.DNSRecord, record cloudflare.DNSRecord) []string {
recordIDs := make([]string, 0)
func (p *CloudFlareProvider) getRecordID(records []cloudflare.DNSRecord, record cloudflare.DNSRecord) string {
for _, zoneRecord := range records {
if zoneRecord.Name == record.Name && zoneRecord.Type == record.Type {
recordIDs = append(recordIDs, zoneRecord.ID)
if zoneRecord.Name == record.Name && zoneRecord.Type == record.Type && zoneRecord.Content == record.Content {
return zoneRecord.ID
}
}
sort.Strings(recordIDs)
return recordIDs
}

// newCloudFlareChanges returns a collection of Changes based on the given records and action.
func newCloudFlareChanges(action string, endpoints []*endpoint.Endpoint, proxiedByDefault bool) []*cloudFlareChange {
changes := make([]*cloudFlareChange, 0, len(endpoints))

for _, endpoint := range endpoints {
changes = append(changes, newCloudFlareChange(action, endpoint, proxiedByDefault))
}

return changes
return ""
}

func newCloudFlareChange(action string, endpoint *endpoint.Endpoint, proxiedByDefault bool) *cloudFlareChange {
func (p *CloudFlareProvider) newCloudFlareChange(action string, endpoint *endpoint.Endpoint, target string) *cloudFlareChange {
ttl := defaultCloudFlareRecordTTL
proxied := shouldBeProxied(endpoint, proxiedByDefault)
proxied := shouldBeProxied(endpoint, p.proxiedByDefault)

if endpoint.RecordTTL.IsConfigured() {
ttl = int(endpoint.RecordTTL)
}

resourceRecordSet := make([]cloudflare.DNSRecord, len(endpoint.Targets))
if len(endpoint.Targets) > 1 {
log.Errorf("Updates should have just one target")
}

for i := range endpoint.Targets {
resourceRecordSet[i] = cloudflare.DNSRecord{
return &cloudFlareChange{
Action: action,
ResourceRecord: cloudflare.DNSRecord{
Name: endpoint.DNSName,
TTL: ttl,
Proxied: proxied,
Type: endpoint.RecordType,
Content: endpoint.Targets[i],
}
}

return &cloudFlareChange{
Action: action,
ResourceRecordSet: resourceRecordSet,
Content: target,
},
}
}

Expand Down
87 changes: 57 additions & 30 deletions provider/cloudflare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ var ExampleDomain = []cloudflare.DNSRecord{
Content: "1.2.3.4",
Proxied: false,
},
{
ID: "2345678901",
ZoneID: "001",
Name: "foobar.bar.com",
Type: endpoint.RecordTypeA,
TTL: 120,
Content: "3.4.5.6",
Proxied: false,
},
{
ID: "1231231233",
ZoneID: "002",
Expand Down Expand Up @@ -655,29 +664,46 @@ func TestCloudflareGetRecordID(t *testing.T) {
p := &CloudFlareProvider{}
records := []cloudflare.DNSRecord{
{
Name: "foo.com",
Type: endpoint.RecordTypeCNAME,
ID: "1",
Name: "foo.com",
Type: endpoint.RecordTypeCNAME,
Content: "foobar",
ID: "1",
},
{
Name: "bar.de",
Type: endpoint.RecordTypeA,
ID: "2",
},
}

assert.Len(t, p.getRecordIDs(records, cloudflare.DNSRecord{
Name: "foo.com",
Type: endpoint.RecordTypeA,
}), 0)
assert.Len(t, p.getRecordIDs(records, cloudflare.DNSRecord{
Name: "bar.de",
Type: endpoint.RecordTypeA,
}), 1)
assert.Equal(t, "2", p.getRecordIDs(records, cloudflare.DNSRecord{
Name: "bar.de",
Type: endpoint.RecordTypeA,
})[0])
Name: "bar.de",
Type: endpoint.RecordTypeA,
Content: "1.2.3.4",
ID: "2",
},
}

assert.Equal(t, "", p.getRecordID(records, cloudflare.DNSRecord{
Name: "foo.com",
Type: endpoint.RecordTypeA,
Content: "foobar",
}))

assert.Equal(t, "", p.getRecordID(records, cloudflare.DNSRecord{
Name: "foo.com",
Type: endpoint.RecordTypeCNAME,
Content: "fizfuz",
}))

assert.Equal(t, "1", p.getRecordID(records, cloudflare.DNSRecord{
Name: "foo.com",
Type: endpoint.RecordTypeCNAME,
Content: "foobar",
}))
assert.Equal(t, "", p.getRecordID(records, cloudflare.DNSRecord{
Name: "bar.de",
Type: endpoint.RecordTypeA,
Content: "2.3.4.5",
}))
assert.Equal(t, "2", p.getRecordID(records, cloudflare.DNSRecord{
Name: "bar.de",
Type: endpoint.RecordTypeA,
Content: "1.2.3.4",
}))
}

func TestCloudflareGroupByNameAndType(t *testing.T) {
Expand Down Expand Up @@ -947,32 +973,33 @@ func TestCloudflareComplexUpdate(t *testing.T) {
}

td.CmpDeeply(t, client.Actions, []MockAction{
MockAction{
Name: "Delete",
ZoneId: "001",
RecordId: "1234567890",
},
MockAction{
Name: "Create",
ZoneId: "001",
RecordData: cloudflare.DNSRecord{
Name: "foobar.bar.com",
Type: "A",
Content: "1.2.3.4",
Content: "2.3.4.5",
TTL: 1,
Proxied: true,
},
},
MockAction{
Name: "Create",
ZoneId: "001",
Name: "Update",
ZoneId: "001",
RecordId: "1234567890",
RecordData: cloudflare.DNSRecord{
Name: "foobar.bar.com",
Type: "A",
Content: "2.3.4.5",
Content: "1.2.3.4",
TTL: 1,
Proxied: true,
},
},
MockAction{
Name: "Delete",
ZoneId: "001",
RecordId: "2345678901",
},
})
}
22 changes: 22 additions & 0 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,25 @@ func ensureTrailingDot(hostname string) string {

return strings.TrimSuffix(hostname, ".") + "."
}

// Tells which entries need to be respectively
// added, removed, or left untouched for "current" to be transformed to "desired"
func difference(current, desired []string) (add []string, remove []string, leave []string) {
index := make(map[string]struct{}, len(current))
for _, x := range current {
index[x] = struct{}{}
}
for _, x := range desired {
if _, found := index[x]; found {
leave = append(leave, x)
delete(index, x)
} else {
add = append(add, x)
delete(index, x)
}
}
for x := range index {
remove = append(remove, x)
}
return
}
11 changes: 11 additions & 0 deletions provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"testing"

log "github.com/sirupsen/logrus"

"github.com/stretchr/testify/assert"
)

func TestMain(m *testing.M) {
Expand All @@ -44,3 +46,12 @@ func TestEnsureTrailingDot(t *testing.T) {
}
}
}

func TestDifference(t *testing.T) {
current := []string{"foo", "bar"}
desired := []string{"bar", "baz"}
add, remove, leave := difference(current, desired)
assert.Equal(t, add, []string{"baz"})
assert.Equal(t, remove, []string{"foo"})
assert.Equal(t, leave, []string{"bar"})
}

0 comments on commit c004731

Please sign in to comment.