Unverified Commit a5896c23 authored by Tariq Ibrahim's avatar Tariq Ibrahim
Browse files

remove context.TODO()s in external-dns

parent f400ded4
......@@ -103,8 +103,7 @@ type Controller struct {
}
// RunOnce runs a single iteration of a reconciliation loop.
func (c *Controller) RunOnce() error {
ctx := context.Background()
func (c *Controller) RunOnce(ctx context.Context) error {
records, err := c.Registry.Records(ctx)
if err != nil {
registryErrorsTotal.Inc()
......@@ -141,11 +140,11 @@ func (c *Controller) RunOnce() error {
}
// Run runs RunOnce in a loop with a delay until stopChan receives a value.
func (c *Controller) Run(stopChan <-chan struct{}) {
func (c *Controller) Run(ctx context.Context, stopChan <-chan struct{}) {
ticker := time.NewTicker(c.Interval)
defer ticker.Stop()
for {
err := c.RunOnce()
err := c.RunOnce(ctx)
if err != nil {
log.Error(err)
}
......
......@@ -146,7 +146,7 @@ func TestRunOnce(t *testing.T) {
Policy: &plan.SyncPolicy{},
}
assert.NoError(t, ctrl.RunOnce())
assert.NoError(t, ctrl.RunOnce(context.Background()))
// Validate that the mock source was called.
source.AssertExpectations(t)
......
......@@ -17,6 +17,7 @@ limitations under the License.
package main
import (
"context"
"net/http"
"os"
"os/signal"
......@@ -60,6 +61,8 @@ func main() {
}
log.SetLevel(ll)
ctx := context.Background()
stopChan := make(chan struct{}, 1)
go serveMetrics(cfg.MetricsAddress)
......@@ -144,9 +147,9 @@ func main() {
case "rcodezero":
p, err = provider.NewRcodeZeroProvider(domainFilter, cfg.DryRun, cfg.RcodezeroTXTEncrypt)
case "google":
p, err = provider.NewGoogleProvider(cfg.GoogleProject, domainFilter, zoneIDFilter, cfg.GoogleBatchChangeSize, cfg.GoogleBatchChangeInterval, cfg.DryRun)
p, err = provider.NewGoogleProvider(ctx, cfg.GoogleProject, domainFilter, zoneIDFilter, cfg.GoogleBatchChangeSize, cfg.GoogleBatchChangeInterval, cfg.DryRun)
case "digitalocean":
p, err = provider.NewDigitalOceanProvider(domainFilter, cfg.DryRun)
p, err = provider.NewDigitalOceanProvider(ctx, domainFilter, cfg.DryRun)
case "linode":
p, err = provider.NewLinodeProvider(domainFilter, cfg.DryRun, externaldns.Version)
case "dnsimple":
......@@ -197,6 +200,7 @@ func main() {
p, err = provider.NewDesignateProvider(domainFilter, cfg.DryRun)
case "pdns":
p, err = provider.NewPDNSProvider(
ctx,
provider.PDNSConfig{
DomainFilter: domainFilter,
DryRun: cfg.DryRun,
......@@ -266,14 +270,14 @@ func main() {
}
if cfg.Once {
err := ctrl.RunOnce()
err := ctrl.RunOnce(ctx)
if err != nil {
log.Fatal(err)
}
os.Exit(0)
}
ctrl.Run(stopChan)
ctrl.Run(ctx, stopChan)
}
func handleSigterm(stopChan chan struct{}) {
......
......@@ -146,9 +146,8 @@ func NewCloudFlareProvider(domainFilter DomainFilter, zoneIDFilter ZoneIDFilter,
}
// Zones returns the list of hosted zones.
func (p *CloudFlareProvider) Zones() ([]cloudflare.Zone, error) {
func (p *CloudFlareProvider) Zones(ctx context.Context) ([]cloudflare.Zone, error) {
result := []cloudflare.Zone{}
ctx := context.TODO()
p.PaginationOptions.Page = 1
for {
......@@ -177,7 +176,7 @@ func (p *CloudFlareProvider) Zones() ([]cloudflare.Zone, error) {
// Records returns the list of records.
func (p *CloudFlareProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones()
zones, err := p.Zones(ctx)
if err != nil {
return nil, err
}
......@@ -208,17 +207,17 @@ func (p *CloudFlareProvider) ApplyChanges(ctx context.Context, changes *plan.Cha
combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareUpdate, changes.UpdateNew, proxiedByDefault)...)
combinedChanges = append(combinedChanges, newCloudFlareChanges(cloudFlareDelete, changes.Delete, proxiedByDefault)...)
return p.submitChanges(combinedChanges)
return p.submitChanges(ctx, combinedChanges)
}
// submitChanges takes a zone and a collection of Changes and sends them as a single transaction.
func (p *CloudFlareProvider) submitChanges(changes []*cloudFlareChange) error {
func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloudFlareChange) error {
// return early if there is nothing to change
if len(changes) == 0 {
return nil
}
zones, err := p.Zones()
zones, err := p.Zones(ctx)
if err != nil {
return err
}
......
......@@ -477,7 +477,7 @@ func TestCloudFlareZones(t *testing.T) {
zoneIDFilter: NewZoneIDFilter([]string{""}),
}
zones, err := provider.Zones()
zones, err := provider.Zones(context.Background())
if err != nil {
t.Fatal(err)
}
......
......@@ -57,12 +57,12 @@ type DigitalOceanChange struct {
}
// NewDigitalOceanProvider initializes a new DigitalOcean DNS based Provider.
func NewDigitalOceanProvider(domainFilter DomainFilter, dryRun bool) (*DigitalOceanProvider, error) {
func NewDigitalOceanProvider(ctx context.Context, domainFilter DomainFilter, dryRun bool) (*DigitalOceanProvider, error) {
token, ok := os.LookupEnv("DO_TOKEN")
if !ok {
return nil, fmt.Errorf("No token found")
}
oauthClient := oauth2.NewClient(context.TODO(), oauth2.StaticTokenSource(&oauth2.Token{
oauthClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: token,
}))
client := godo.NewClient(oauthClient)
......@@ -76,10 +76,10 @@ func NewDigitalOceanProvider(domainFilter DomainFilter, dryRun bool) (*DigitalOc
}
// Zones returns the list of hosted zones.
func (p *DigitalOceanProvider) Zones() ([]godo.Domain, error) {
func (p *DigitalOceanProvider) Zones(ctx context.Context) ([]godo.Domain, error) {
result := []godo.Domain{}
zones, err := p.fetchZones()
zones, err := p.fetchZones(ctx)
if err != nil {
return nil, err
}
......@@ -95,13 +95,13 @@ func (p *DigitalOceanProvider) Zones() ([]godo.Domain, error) {
// Records returns the list of records in a given zone.
func (p *DigitalOceanProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones()
zones, err := p.Zones(ctx)
if err != nil {
return nil, err
}
endpoints := []*endpoint.Endpoint{}
for _, zone := range zones {
records, err := p.fetchRecords(zone.Name)
records, err := p.fetchRecords(ctx, zone.Name)
if err != nil {
return nil, err
}
......@@ -124,11 +124,11 @@ func (p *DigitalOceanProvider) Records(ctx context.Context) ([]*endpoint.Endpoin
return endpoints, nil
}
func (p *DigitalOceanProvider) fetchRecords(zoneName string) ([]godo.DomainRecord, error) {
func (p *DigitalOceanProvider) fetchRecords(ctx context.Context, zoneName string) ([]godo.DomainRecord, error) {
allRecords := []godo.DomainRecord{}
listOptions := &godo.ListOptions{}
for {
records, resp, err := p.Client.Records(context.TODO(), zoneName, listOptions)
records, resp, err := p.Client.Records(ctx, zoneName, listOptions)
if err != nil {
return nil, err
}
......@@ -149,11 +149,11 @@ func (p *DigitalOceanProvider) fetchRecords(zoneName string) ([]godo.DomainRecor
return allRecords, nil
}
func (p *DigitalOceanProvider) fetchZones() ([]godo.Domain, error) {
func (p *DigitalOceanProvider) fetchZones(ctx context.Context) ([]godo.Domain, error) {
allZones := []godo.Domain{}
listOptions := &godo.ListOptions{}
for {
zones, resp, err := p.Client.List(context.TODO(), listOptions)
zones, resp, err := p.Client.List(ctx, listOptions)
if err != nil {
return nil, err
}
......@@ -175,13 +175,13 @@ func (p *DigitalOceanProvider) fetchZones() ([]godo.Domain, error) {
}
// submitChanges takes a zone and a collection of Changes and sends them as a single transaction.
func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) error {
func (p *DigitalOceanProvider) submitChanges(ctx context.Context, changes []*DigitalOceanChange) error {
// return early if there is nothing to change
if len(changes) == 0 {
return nil
}
zones, err := p.Zones()
zones, err := p.Zones(ctx)
if err != nil {
return err
}
......@@ -189,7 +189,7 @@ func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) erro
// separate into per-zone change sets to be passed to the API.
changesByZone := digitalOceanChangesByZone(zones, changes)
for zoneName, changes := range changesByZone {
records, err := p.fetchRecords(zoneName)
records, err := p.fetchRecords(ctx, zoneName)
if err != nil {
log.Errorf("Failed to list records in the zone: %s", zoneName)
continue
......@@ -225,7 +225,7 @@ func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) erro
switch change.Action {
case DigitalOceanCreate:
_, _, err = p.Client.CreateRecord(context.TODO(), zoneName,
_, _, err = p.Client.CreateRecord(ctx, zoneName,
&godo.DomainRecordEditRequest{
Data: change.ResourceRecordSet.Data,
Name: change.ResourceRecordSet.Name,
......@@ -237,13 +237,13 @@ func (p *DigitalOceanProvider) submitChanges(changes []*DigitalOceanChange) erro
}
case DigitalOceanDelete:
recordID := p.getRecordID(records, change.ResourceRecordSet)
_, err = p.Client.DeleteRecord(context.TODO(), zoneName, recordID)
_, err = p.Client.DeleteRecord(ctx, zoneName, recordID)
if err != nil {
return err
}
case DigitalOceanUpdate:
recordID := p.getRecordID(records, change.ResourceRecordSet)
_, _, err = p.Client.EditRecord(context.TODO(), zoneName, recordID,
_, _, err = p.Client.EditRecord(ctx, zoneName, recordID,
&godo.DomainRecordEditRequest{
Data: change.ResourceRecordSet.Data,
Name: change.ResourceRecordSet.Name,
......@@ -267,7 +267,7 @@ func (p *DigitalOceanProvider) ApplyChanges(ctx context.Context, changes *plan.C
combinedChanges = append(combinedChanges, newDigitalOceanChanges(DigitalOceanUpdate, changes.UpdateNew)...)
combinedChanges = append(combinedChanges, newDigitalOceanChanges(DigitalOceanDelete, changes.Delete)...)
return p.submitChanges(combinedChanges)
return p.submitChanges(ctx, combinedChanges)
}
// newDigitalOceanChanges returns a collection of Changes based on the given records and action.
......
......@@ -413,7 +413,7 @@ func TestDigitalOceanZones(t *testing.T) {
domainFilter: NewDomainFilter([]string{"com"}),
}
zones, err := provider.Zones()
zones, err := provider.Zones(context.Background())
if err != nil {
t.Fatal(err)
}
......@@ -445,12 +445,12 @@ func TestDigitalOceanApplyChanges(t *testing.T) {
func TestNewDigitalOceanProvider(t *testing.T) {
_ = os.Setenv("DO_TOKEN", "xxxxxxxxxxxxxxxxx")
_, err := NewDigitalOceanProvider(NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true)
_, err := NewDigitalOceanProvider(context.Background(), NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true)
if err != nil {
t.Errorf("should not fail, %s", err)
}
_ = os.Unsetenv("DO_TOKEN")
_, err = NewDigitalOceanProvider(NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true)
_, err = NewDigitalOceanProvider(context.Background(), NewDomainFilter([]string{"ext-dns-test.zalando.to."}), true)
if err == nil {
t.Errorf("expected to fail")
}
......@@ -494,7 +494,7 @@ func TestDigitalOceanRecord(t *testing.T) {
Client: &mockDigitalOceanClient{},
}
records, err := provider.fetchRecords("example.com")
records, err := provider.fetchRecords(context.Background(), "example.com")
if err != nil {
t.Fatal(err)
}
......
......@@ -175,13 +175,13 @@ func (ep *ExoscaleProvider) ApplyChanges(ctx context.Context, changes *plan.Chan
func (ep *ExoscaleProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
endpoints := make([]*endpoint.Endpoint, 0)
domains, err := ep.client.GetDomains(context.TODO())
domains, err := ep.client.GetDomains(ctx)
if err != nil {
return nil, err
}
for _, d := range domains {
record, err := ep.client.GetRecords(context.TODO(), d.Name)
record, err := ep.client.GetRecords(ctx, d.Name)
if err != nil {
return nil, err
}
......
......@@ -116,11 +116,13 @@ type GoogleProvider struct {
managedZonesClient managedZonesServiceInterface
// A client for managing change sets
changesClient changesServiceInterface
// The context parameter to be passed for gcloud API calls.
ctx context.Context
}
// NewGoogleProvider initializes a new Google CloudDNS based Provider.
func NewGoogleProvider(project string, domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, batchChangeSize int, batchChangeInterval time.Duration, dryRun bool) (*GoogleProvider, error) {
gcloud, err := google.DefaultClient(context.TODO(), dns.NdevClouddnsReadwriteScope)
func NewGoogleProvider(ctx context.Context, project string, domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, batchChangeSize int, batchChangeInterval time.Duration, dryRun bool) (*GoogleProvider, error) {
gcloud, err := google.DefaultClient(ctx, dns.NdevClouddnsReadwriteScope)
if err != nil {
return nil, err
}
......@@ -132,7 +134,7 @@ func NewGoogleProvider(project string, domainFilter DomainFilter, zoneIDFilter Z
},
})
dnsClient, err := dns.NewService(context.TODO(), option.WithHTTPClient(gcloud))
dnsClient, err := dns.NewService(ctx, option.WithHTTPClient(gcloud))
if err != nil {
return nil, err
}
......@@ -155,13 +157,14 @@ func NewGoogleProvider(project string, domainFilter DomainFilter, zoneIDFilter Z
resourceRecordSetsClient: resourceRecordSetsService{dnsClient.ResourceRecordSets},
managedZonesClient: managedZonesService{dnsClient.ManagedZones},
changesClient: changesService{dnsClient.Changes},
ctx: ctx,
}
return provider, nil
}
// Zones returns the list of hosted zones.
func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) {
func (p *GoogleProvider) Zones(ctx context.Context) (map[string]*dns.ManagedZone, error) {
zones := make(map[string]*dns.ManagedZone)
f := func(resp *dns.ManagedZonesListResponse) error {
......@@ -178,7 +181,7 @@ func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) {
}
log.Debugf("Matching zones against domain filters: %v", p.domainFilter.filters)
if err := p.managedZonesClient.List(p.project).Pages(context.TODO(), f); err != nil {
if err := p.managedZonesClient.List(p.project).Pages(ctx, f); err != nil {
return nil, err
}
......@@ -199,7 +202,7 @@ func (p *GoogleProvider) Zones() (map[string]*dns.ManagedZone, error) {
// Records returns the list of records in all relevant zones.
func (p *GoogleProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
zones, err := p.Zones()
zones, err := p.Zones(ctx)
if err != nil {
return nil, err
}
......@@ -230,7 +233,7 @@ func (p *GoogleProvider) CreateRecords(endpoints []*endpoint.Endpoint) error {
change.Additions = append(change.Additions, p.newFilteredRecords(endpoints)...)
return p.submitChange(change)
return p.submitChange(p.ctx, change)
}
// UpdateRecords updates a given set of old records to a new set of records in a given hosted zone.
......@@ -240,7 +243,7 @@ func (p *GoogleProvider) UpdateRecords(records, oldRecords []*endpoint.Endpoint)
change.Additions = append(change.Additions, p.newFilteredRecords(records)...)
change.Deletions = append(change.Deletions, p.newFilteredRecords(oldRecords)...)
return p.submitChange(change)
return p.submitChange(p.ctx, change)
}
// DeleteRecords deletes a given set of DNS records in a given zone.
......@@ -249,7 +252,7 @@ func (p *GoogleProvider) DeleteRecords(endpoints []*endpoint.Endpoint) error {
change.Deletions = append(change.Deletions, p.newFilteredRecords(endpoints)...)
return p.submitChange(change)
return p.submitChange(p.ctx, change)
}
// ApplyChanges applies a given set of changes in a given zone.
......@@ -263,7 +266,7 @@ func (p *GoogleProvider) ApplyChanges(ctx context.Context, changes *plan.Changes
change.Deletions = append(change.Deletions, p.newFilteredRecords(changes.Delete)...)
return p.submitChange(change)
return p.submitChange(ctx, change)
}
// newFilteredRecords returns a collection of RecordSets based on the given endpoints and domainFilter.
......@@ -280,13 +283,13 @@ func (p *GoogleProvider) newFilteredRecords(endpoints []*endpoint.Endpoint) []*d
}
// submitChange takes a zone and a Change and sends it to Google.
func (p *GoogleProvider) submitChange(change *dns.Change) error {
func (p *GoogleProvider) submitChange(ctx context.Context, change *dns.Change) error {
if len(change.Additions) == 0 && len(change.Deletions) == 0 {
log.Info("All records are already up to date")
return nil
}
zones, err := p.Zones()
zones, err := p.Zones(ctx)
if err != nil {
return err
}
......
......@@ -194,7 +194,7 @@ func hasTrailingDot(target string) bool {
func TestGoogleZonesIDFilter(t *testing.T) {
provider := newGoogleProviderZoneOverlap(t, NewDomainFilter([]string{"cluster.local."}), NewZoneIDFilter([]string{"10002"}), false, []*endpoint.Endpoint{})
zones, err := provider.Zones()
zones, err := provider.Zones(context.Background())
require.NoError(t, err)
validateZones(t, zones, map[string]*dns.ManagedZone{
......@@ -205,7 +205,7 @@ func TestGoogleZonesIDFilter(t *testing.T) {
func TestGoogleZonesNameFilter(t *testing.T) {
provider := newGoogleProviderZoneOverlap(t, NewDomainFilter([]string{"cluster.local."}), NewZoneIDFilter([]string{"internal-2"}), false, []*endpoint.Endpoint{})
zones, err := provider.Zones()
zones, err := provider.Zones(context.Background())
require.NoError(t, err)
validateZones(t, zones, map[string]*dns.ManagedZone{
......@@ -216,7 +216,7 @@ func TestGoogleZonesNameFilter(t *testing.T) {
func TestGoogleZones(t *testing.T) {
provider := newGoogleProvider(t, NewDomainFilter([]string{"ext-dns-test-2.gcp.zalan.do."}), NewZoneIDFilter([]string{""}), false, []*endpoint.Endpoint{})
zones, err := provider.Zones()
zones, err := provider.Zones(context.Background())
require.NoError(t, err)
validateZones(t, zones, map[string]*dns.ManagedZone{
......@@ -777,7 +777,7 @@ func setupGoogleRecords(t *testing.T, provider *GoogleProvider, endpoints []*end
func clearGoogleRecords(t *testing.T, provider *GoogleProvider, zone string) {
recordSets := []*dns.ResourceRecordSet{}
require.NoError(t, provider.resourceRecordSetsClient.List(provider.project, zone).Pages(context.TODO(), func(resp *dns.ResourceRecordSetsListResponse) error {
require.NoError(t, provider.resourceRecordSetsClient.List(provider.project, zone).Pages(context.Background(), func(resp *dns.ResourceRecordSetsListResponse) error {
for _, r := range resp.Rrsets {
switch r.Type {
case endpoint.RecordTypeA, endpoint.RecordTypeCNAME:
......
......@@ -101,8 +101,8 @@ func NewLinodeProvider(domainFilter DomainFilter, dryRun bool, appVersion string
}
// Zones returns the list of hosted zones.
func (p *LinodeProvider) Zones() ([]*linodego.Domain, error) {
zones, err := p.fetchZones()
func (p *LinodeProvider) Zones(ctx context.Context) ([]*linodego.Domain, error) {
zones, err := p.fetchZones(ctx)
if err != nil {
return nil, err
}
......@@ -112,7 +112,7 @@ func (p *LinodeProvider) Zones() ([]*linodego.Domain, error) {
// Records returns the list of records in a given zone.
func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
zones, err := p.Zones()
zones, err := p.Zones(ctx)
if err != nil {
return nil, err
}
......@@ -120,7 +120,7 @@ func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, err
var endpoints []*endpoint.Endpoint
for _, zone := range zones {
records, err := p.fetchRecords(zone.ID)
records, err := p.fetchRecords(ctx, zone.ID)
if err != nil {
return nil, err
}
......@@ -143,8 +143,8 @@ func (p *LinodeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, err
return endpoints, nil
}
func (p *LinodeProvider) fetchRecords(domainID int) ([]*linodego.DomainRecord, error) {
records, err := p.Client.ListDomainRecords(context.TODO(), domainID, nil)
func (p *LinodeProvider) fetchRecords(ctx context.Context, domainID int) ([]*linodego.DomainRecord, error) {
records, err := p.Client.ListDomainRecords(ctx, domainID, nil)
if err != nil {
return nil, err
}
......@@ -152,10 +152,10 @@ func (p *LinodeProvider) fetchRecords(domainID int) ([]*linodego.DomainRecord, e
return records, nil
}
func (p *LinodeProvider) fetchZones() ([]*linodego.Domain, error) {
func (p *LinodeProvider) fetchZones(ctx context.Context) ([]*linodego.Domain, error) {
var zones []*linodego.Domain
allZones, err := p.Client.ListDomains(context.TODO(), linodego.NewListOptions(0, ""))
allZones, err := p.Client.ListDomains(ctx, linodego.NewListOptions(0, ""))
if err != nil {
return nil, err
......@@ -173,7 +173,7 @@ func (p *LinodeProvider) fetchZones() ([]*linodego.Domain, error) {
}
// submitChanges takes a zone and a collection of Changes and sends them as a single transaction.
func (p *LinodeProvider) submitChanges(changes LinodeChanges) error {
func (p *LinodeProvider) submitChanges(ctx context.Context, changes LinodeChanges) error {
for _, change := range changes.Creates {
logFields := log.Fields{
"record": change.Options.Name,
......@@ -187,7 +187,7 @@ func (p *LinodeProvider) submitChanges(changes LinodeChanges) error {
if p.DryRun {
log.WithFields(logFields).Info("Would create record.")
} else if _, err := p.Client.CreateDomainRecord(context.TODO(), change.Domain.ID, change.Options); err != nil {
} else if _, err := p.Client.CreateDomainRecord(ctx, change.Domain.ID, change.Options); err != nil {
log.WithFields(logFields).Errorf(
"Failed to Create record: %v",
err,
......@@ -208,7 +208,7 @@ func (p *LinodeProvider) submitChanges(changes LinodeChanges) error {
if p.DryRun {
log.WithFields(logFields).Info("Would delete record.")
} else if err := p.Client.DeleteDomainRecord(context.TODO(), change.Domain.ID, change.DomainRecord.ID); err != nil {
} else if err := p.Client.DeleteDomainRecord(ctx, change.Domain.ID, change.DomainRecord.ID); err != nil {
log.WithFields(logFields).Errorf(
"Failed to Delete record: %v",
err,
......@@ -229,7 +229,7 @@ func (p *LinodeProvider) submitChanges(changes LinodeChanges) error {
if p.DryRun {
log.WithFields(logFields).Info("Would update record.")
} else if _, err := p.Client.UpdateDomainRecord(context.TODO(), change.Domain.ID, change.DomainRecord.ID, change.Options); err != nil {
} else if _, err := p.Client.UpdateDomainRecord(ctx, change.Domain.ID, change.DomainRecord.ID, change.Options); err != nil {
log.WithFields(logFields).Errorf(
"Failed to Update record: %v",
err,
......@@ -259,7 +259,7 @@ func getPriority() *int {
func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
recordsByZoneID := make(map[string][]*linodego.DomainRecord)
zones, err := p.fetchZones()
zones, err := p.fetchZones(ctx)
if err != nil {
return err
......@@ -276,7 +276,7 @@ func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes
// Fetch records for each zone
for _, zone := range zones {
records, err := p.fetchRecords(zone.ID)
records, err := p.fetchRecords(ctx, zone.ID)
if err != nil {
return err
......@@ -484,7 +484,7 @@ func (p *LinodeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes
}
}
return p.submitChanges(LinodeChanges{
return p.submitChanges(ctx, LinodeChanges{
Creates: linodeCreates,
Deletes: linodeDeletes,
Updates: linodeUpdates,
......
......@@ -160,7 +160,7 @@ func TestLinodeStripRecordName(t *testing.T) {
}))
}
func TestLinodeFetchZonesNoFiilters(t *testing.T) {
func TestLinodeFetchZonesNoFilters(t *testing.T) {
mockDomainClient := MockDomainClient{}
provider := &LinodeProvider{
......@@ -176,7 +176,7 @@ func TestLinodeFetchZonesNoFiilters(t *testing.T) {
).Return(createZones(), nil).Once()
expected := createZones()
actual, err := provider.fetchZones()
actual, err := provider.fetchZones(context.Background())
require.NoError(t, err)
mockDomainClient.AssertExpectations(t)
......@@ -202,7 +202,7 @@ func TestLinodeFetchZonesWithFilter(t *testing.T) {
{ID: 1, Domain: "foo.com"},
{ID: 3, Domain: "baz.com"},
}
actual, err := provider.fetchZones()
actual, err := provider.fetchZones(context.Background())
require.NoError(t, err)
mockDomainClient.AssertExpectations(t)
......
......@@ -225,7 +225,7 @@ type PDNSProvider struct {
}
// NewPDNSProvider initializes a new PowerDNS based Provider.
func NewPDNSProvider(config PDNSConfig) (*PDNSProvider, error) {
func NewPDNSProvider(ctx context.Context, config PDNSConfig) (*PDNSProvider, error) {
// Do some input validation
......@@ -252,7 +252,7 @@ func NewPDNSProvider(config PDNSConfig) (*PDNSProvider, error) {
provider := &PDNSProvider{
client: &PDNSAPIClient{
dryRun: config.DryRun,
authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: config.APIKey}),
authCtx: context.WithValue(ctx, pgo.ContextAPIKey, pgo.APIKey{Key: config.APIKey}),
client: pgo.NewAPIClient(pdnsClientConfig),
domainFilter: config.DomainFilter,
},
......
......@@ -495,21 +495,21 @@ var (
DomainFilterEmptyClient = &PDNSAPIClient{
dryRun: false,
authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}),
authCtx: context.WithValue(context.Background(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}),
client: pgo.NewAPIClient(pgo.NewConfiguration()),
domainFilter: DomainFilterListEmpty,
}
DomainFilterSingleClient = &PDNSAPIClient{
dryRun: false,
authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}),
authCtx: context.WithValue(context.Background(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}),
client: pgo.NewAPIClient(pgo.NewConfiguration()),
domainFilter: DomainFilterListSingle,
}
DomainFilterMultipleClient = &PDNSAPIClient{
dryRun: false,
authCtx: context.WithValue(context.TODO(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}),
authCtx: context.WithValue(context.Background(), pgo.ContextAPIKey, pgo.APIKey{Key: "TEST-API-KEY"}),
client: pgo.NewAPIClient(pgo.NewConfiguration()),
domainFilter: DomainFilterListMultiple,
}
......@@ -639,124 +639,148 @@ type NewPDNSProviderTestSuite struct {
func (suite *NewPDNSProviderTestSuite) TestPDNSProviderCreate() {
_, err := NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
DomainFilter: NewDomainFilter([]string{""}),
})
_, err := NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
DomainFilter: NewDomainFilter([]string{""}),
})
assert.Error(suite.T(), err, "--pdns-api-key should be specified")
_, err = NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{"example.com", "example.org"}),
})
_, err = NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{"example.com", "example.org"}),
})
assert.Nil(suite.T(), err, "--domain-filter should raise no error")
_, err = NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
DryRun: true,
})
_, err = NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
DryRun: true,
})
assert.Error(suite.T(), err, "--dry-run should raise an error")
// This is our "regular" code path, no error should be thrown
_, err = NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
})
_, err = NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
})
assert.Nil(suite.T(), err, "Regular case should raise no error")
}
func (suite *NewPDNSProviderTestSuite) TestPDNSProviderCreateTLS() {
_, err := NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
})
_, err := NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
})
assert.Nil(suite.T(), err, "Omitted TLS Config case should raise no error")
_, err = NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: false,
},
})
_, err = NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: false,
},
})
assert.Nil(suite.T(), err, "Disabled TLS Config should raise no error")
_, err = NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: false,
CAFilePath: "/path/to/ca.crt",
ClientCertFilePath: "/path/to/cert.pem",
ClientCertKeyFilePath: "/path/to/cert-key.pem",
},
})
_, err = NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: false,
CAFilePath: "/path/to/ca.crt",
ClientCertFilePath: "/path/to/cert.pem",
ClientCertKeyFilePath: "/path/to/cert-key.pem",
},
})
assert.Nil(suite.T(), err, "Disabled TLS Config with additional flags should raise no error")
_, err = NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: true,
},
})
_, err = NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: true,
},
})
assert.Error(suite.T(), err, "Enabled TLS Config without --tls-ca should raise an error")
_, err = NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: true,
CAFilePath: "../internal/testresources/ca.pem",
},
})
_, err = NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: true,
CAFilePath: "../internal/testresources/ca.pem",
},
})
assert.Nil(suite.T(), err, "Enabled TLS Config with --tls-ca should raise no error")
_, err = NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: true,
CAFilePath: "../internal/testresources/ca.pem",
ClientCertFilePath: "../internal/testresources/client-cert.pem",
},
})
_, err = NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: true,
CAFilePath: "../internal/testresources/ca.pem",
ClientCertFilePath: "../internal/testresources/client-cert.pem",
},
})
assert.Error(suite.T(), err, "Enabled TLS Config with --tls-client-cert only should raise an error")
_, err = NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: true,
CAFilePath: "../internal/testresources/ca.pem",
ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem",
},
})
_, err = NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: true,
CAFilePath: "../internal/testresources/ca.pem",
ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem",
},
})
assert.Error(suite.T(), err, "Enabled TLS Config with --tls-client-cert-key only should raise an error")
_, err = NewPDNSProvider(PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: true,
CAFilePath: "../internal/testresources/ca.pem",
ClientCertFilePath: "../internal/testresources/client-cert.pem",
ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem",
},
})
_, err = NewPDNSProvider(
context.Background(),
PDNSConfig{
Server: "http://localhost:8081",
APIKey: "foo",
DomainFilter: NewDomainFilter([]string{""}),
TLSConfig: TLSConfig{
TLSEnabled: true,
CAFilePath: "../internal/testresources/ca.pem",
ClientCertFilePath: "../internal/testresources/client-cert.pem",
ClientCertKeyFilePath: "../internal/testresources/client-cert-key.pem",
},
})
assert.Nil(suite.T(), err, "Enabled TLS Config with all flags should raise no error")
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment