diff --git a/internal/dnsprovider/dnsprovider.go b/internal/dnsprovider/dnsprovider.go index 627e142..0318ae8 100644 --- a/internal/dnsprovider/dnsprovider.go +++ b/internal/dnsprovider/dnsprovider.go @@ -43,7 +43,7 @@ func Init(config configuration.Config) (provider.Provider, error) { } log.Info(createMsg) - mikrotikConfig := mikrotik.Config{} + mikrotikConfig := mikrotik.MikrotikConnectionConfig{} if err := env.Parse(&mikrotikConfig); err != nil { return nil, fmt.Errorf("reading mikrotik configuration failed: %v", err) } diff --git a/internal/mikrotik/client.go b/internal/mikrotik/client.go index 6957bba..2d0e8e2 100644 --- a/internal/mikrotik/client.go +++ b/internal/mikrotik/client.go @@ -17,10 +17,9 @@ import ( "sigs.k8s.io/external-dns/endpoint" ) -// Config holds the connection details for the API client -type Config struct { - Host string `env:"MIKROTIK_HOST,notEmpty"` - Port string `env:"MIKROTIK_PORT,notEmpty" envDefault:"443"` +// MikrotikConnectionConfig holds the connection details for the API client +type MikrotikConnectionConfig struct { + BaseUrl string `env:"MIKROTIK_BASEURL,notEmpty"` Username string `env:"MIKROTIK_USERNAME,notEmpty"` Password string `env:"MIKROTIK_PASSWORD,notEmpty"` SkipTLSVerify bool `env:"MIKROTIK_SKIP_TLS_VERIFY" envDefault:"false"` @@ -28,13 +27,13 @@ type Config struct { // MikrotikApiClient encapsulates the client configuration and HTTP client type MikrotikApiClient struct { - *Config + *MikrotikConnectionConfig *http.Client } -// SystemInfo represents MikroTik system information +// MikrotikSystemInfo represents MikroTik system information // https://help.mikrotik.com/docs/display/ROS/Resource -type SystemInfo struct { +type MikrotikSystemInfo struct { ArchitectureName string `json:"architecture-name"` BadBlocks string `json:"bad-blocks"` BoardName string `json:"board-name"` @@ -56,7 +55,7 @@ type SystemInfo struct { } // NewMikrotikClient creates a new instance of MikrotikApiClient -func NewMikrotikClient(config *Config) (*MikrotikApiClient, error) { +func NewMikrotikClient(config *MikrotikConnectionConfig) (*MikrotikApiClient, error) { log.Infof("creating a new Mikrotik API Client") jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) @@ -66,7 +65,7 @@ func NewMikrotikClient(config *Config) (*MikrotikApiClient, error) { } client := &MikrotikApiClient{ - Config: config, + MikrotikConnectionConfig: config, Client: &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ @@ -77,29 +76,23 @@ func NewMikrotikClient(config *Config) (*MikrotikApiClient, error) { }, } - info, err := client.GetSystemInfo() - if err != nil { - log.Errorf("failed to connect to the MikroTik RouterOS API Endpoint: %v", err) - return nil, err - } - - log.Infof("connected to board %s running RouterOS version %s (%s)", info.BoardName, info.Version, info.ArchitectureName) return client, nil } // GetSystemInfo fetches system information from the MikroTik API -func (c *MikrotikApiClient) GetSystemInfo() (*SystemInfo, error) { +func (c *MikrotikApiClient) GetSystemInfo() (*MikrotikSystemInfo, error) { log.Debugf("fetching system information.") - resp, err := c._doRequest(http.MethodGet, "system/resource", nil) + // Send the request + resp, err := c.doRequest(http.MethodGet, "system/resource", nil) if err != nil { - log.Errorf("error getching system info: %v", err) + log.Errorf("error fetching system info: %v", err) return nil, err } - defer resp.Body.Close() - var info SystemInfo + // Parse the response + var info MikrotikSystemInfo if err = json.NewDecoder(resp.Body).Decode(&info); err != nil { log.Errorf("error decoding response body: %v", err) return nil, err @@ -113,26 +106,29 @@ func (c *MikrotikApiClient) GetSystemInfo() (*SystemInfo, error) { func (c *MikrotikApiClient) CreateDNSRecord(endpoint *endpoint.Endpoint) (*DNSRecord, error) { log.Infof("creating DNS record: %+v", endpoint) + // Convert ExternalDNS to Mikrotik DNS record, err := NewDNSRecord(endpoint) if err != nil { log.Errorf("error converting ExternalDNS endpoint to Mikrotik DNS Record: %v", err) return nil, err } + // Serialize the data to JSON to be sent to the API jsonBody, err := json.Marshal(record) if err != nil { log.Errorf("error marshalling DNS record: %v", err) return nil, err } - resp, err := c._doRequest(http.MethodPut, "ip/dns/static", bytes.NewReader(jsonBody)) + // Send the request + resp, err := c.doRequest(http.MethodPut, "ip/dns/static", bytes.NewReader(jsonBody)) if err != nil { log.Errorf("error creating DNS record: %v", err) return nil, err } - defer resp.Body.Close() + // Parse the response if err = json.NewDecoder(resp.Body).Decode(&record); err != nil { log.Errorf("Error decoding response body: %v", err) return nil, err @@ -146,14 +142,15 @@ func (c *MikrotikApiClient) CreateDNSRecord(endpoint *endpoint.Endpoint) (*DNSRe func (c *MikrotikApiClient) GetAllDNSRecords() ([]DNSRecord, error) { log.Infof("fetching all DNS records") - resp, err := c._doRequest(http.MethodGet, "ip/dns/static", nil) + // Send the request + resp, err := c.doRequest(http.MethodGet, "ip/dns/static", nil) if err != nil { log.Errorf("error fetching DNS records: %v", err) return nil, err } - defer resp.Body.Close() + // Parse the response var records []DNSRecord if err = json.NewDecoder(resp.Body).Decode(&records); err != nil { log.Errorf("error decoding response body: %v", err) @@ -168,24 +165,27 @@ func (c *MikrotikApiClient) GetAllDNSRecords() ([]DNSRecord, error) { func (c *MikrotikApiClient) DeleteDNSRecord(endpoint *endpoint.Endpoint) error { log.Infof("deleting DNS record: %+v", endpoint) - record, err := c._lookupDNSRecord(endpoint.DNSName, endpoint.RecordType) + // Send the request + record, err := c.lookupDNSRecord(endpoint.DNSName, endpoint.RecordType) if err != nil { log.Errorf("failed lookup for DNS record: %+v", err) return err } - _, err = c._doRequest(http.MethodDelete, fmt.Sprintf("ip/dns/static/%s", record.ID), nil) + // Parse the response + resp, err := c.doRequest(http.MethodDelete, fmt.Sprintf("ip/dns/static/%s", record.ID), nil) if err != nil { log.Errorf("error deleting DNS record: %+v", err) return err } + defer resp.Body.Close() log.Infof("record deleted") return nil } -// _lookupDNSRecord searches for a DNS record by key and type -func (c *MikrotikApiClient) _lookupDNSRecord(key, recordType string) (*DNSRecord, error) { +// lookupDNSRecord searches for a DNS record by key and type +func (c *MikrotikApiClient) lookupDNSRecord(key, recordType string) (*DNSRecord, error) { log.Infof("Searching for DNS record: Key: %s, RecordType: %s", key, recordType) searchParams := fmt.Sprintf("name=%s", key) @@ -194,13 +194,14 @@ func (c *MikrotikApiClient) _lookupDNSRecord(key, recordType string) (*DNSRecord } log.Debugf("Search params: %s", searchParams) - resp, err := c._doRequest(http.MethodGet, fmt.Sprintf("ip/dns/static?%s", searchParams), nil) + // Send the request + resp, err := c.doRequest(http.MethodGet, fmt.Sprintf("ip/dns/static?%s", searchParams), nil) if err != nil { return nil, err } - defer resp.Body.Close() + // Parse the response var record []DNSRecord if err = json.NewDecoder(resp.Body).Decode(&record); err != nil { log.Errorf("Error decoding response body: %v", err) @@ -215,9 +216,9 @@ func (c *MikrotikApiClient) _lookupDNSRecord(key, recordType string) (*DNSRecord return &record[0], nil } -// _doRequest sends an HTTP request to the MikroTik API with credentials -func (c *MikrotikApiClient) _doRequest(method, path string, body io.Reader) (*http.Response, error) { - endpoint_url := fmt.Sprintf("https://%s:%s/rest/%s", c.Config.Host, c.Config.Port, path) +// doRequest sends an HTTP request to the MikroTik API with credentials +func (c *MikrotikApiClient) doRequest(method, path string, body io.Reader) (*http.Response, error) { + endpoint_url := fmt.Sprintf("%s/rest/%s", c.MikrotikConnectionConfig.BaseUrl, path) log.Debugf("sending %s request to: %s", method, endpoint_url) req, err := http.NewRequest(method, endpoint_url, body) @@ -226,7 +227,7 @@ func (c *MikrotikApiClient) _doRequest(method, path string, body io.Reader) (*ht return nil, err } - req.SetBasicAuth(c.Config.Username, c.Config.Password) + req.SetBasicAuth(c.MikrotikConnectionConfig.Username, c.MikrotikConnectionConfig.Password) resp, err := c.Client.Do(req) if err != nil { diff --git a/internal/mikrotik/client_test.go b/internal/mikrotik/client_test.go index 960b3b6..ea93cc4 100644 --- a/internal/mikrotik/client_test.go +++ b/internal/mikrotik/client_test.go @@ -1,72 +1,787 @@ +// client_test.go package mikrotik -// import ( -// "testing" - -// "github.com/caarlos0/env/v11" -// "github.com/stretchr/testify/assert" -// "sigs.k8s.io/external-dns/endpoint" -// ) - -// func TestCRUDRecord(t *testing.T) { -// // Fetch configuration from environment variables -// config := &Config{} -// err := env.Parse(config) -// if err != nil { -// t.Fatalf("failed to parse config from environment variables: %v", err) -// } - -// // Attempt connection -// client, err := NewMikrotikClient(config) -// assert.Nil(t, err) -// assert.NotNil(t, client) - -// // Define the endpoint to create -// newEndpoint := &endpoint.Endpoint{ -// DNSName: "new.example.com", -// RecordType: "A", -// Targets: endpoint.Targets{"9.10.11.12"}, -// RecordTTL: 3600, -// } - -// // Fetch all records -// records1, err := client.GetAll() -// assert.Nil(t, err) -// assert.NotEmpty(t, records1) - -// // Call the Create function -> should work -// record1, err := client.Create(newEndpoint) -// assert.Nil(t, err) -// assert.NotNil(t, record1) - -// // Call the Create function again -> should fail, record already exists -// record2, err := client.Create(newEndpoint) -// assert.NotNil(t, err) -// assert.Nil(t, record2) - -// // Fetch all records after creation -// records2, err := client.GetAll() -// assert.Nil(t, err) -// assert.NotEmpty(t, records2) - -// // Ensure new records list is longer than the old one by 1 -// // and that the new record is present -// assert.True(t, len(records2)-len(records1) == 1) - -// var found bool -// for _, rec := range records2 { -// if rec.Name == newEndpoint.DNSName && rec.Address == newEndpoint.Targets[0] { -// found = true -// break -// } -// } -// assert.True(t, found) - -// // Delete record -> should work -// err = client.Delete(newEndpoint) -// assert.Nil(t, err) - -// // Delete record again -> should fail, not found -// err = client.Delete(newEndpoint) -// assert.NotNil(t, err) -// } +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "sigs.k8s.io/external-dns/endpoint" +) + +var ( + mockUsername = "testuser" + mockPassword = "testpass" +) + +func TestNewMikrotikClient(t *testing.T) { + config := &MikrotikConnectionConfig{ + BaseUrl: "https://192.168.88.1:443", + Username: "admin", + Password: "password", + SkipTLSVerify: true, + } + + client, err := NewMikrotikClient(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if client.MikrotikConnectionConfig != config { + t.Errorf("Expected config to be %v, got %v", config, client.MikrotikConnectionConfig) + } + + if client.Client == nil { + t.Errorf("Expected HTTP client to be initialized") + } + + transport, ok := client.Client.Transport.(*http.Transport) + if !ok { + t.Errorf("Expected Transport to be *http.Transport") + } + + if transport.TLSClientConfig == nil { + t.Errorf("Expected TLSClientConfig to be set") + } else if !transport.TLSClientConfig.InsecureSkipVerify { + t.Errorf("Expected InsecureSkipVerify to be true") + } +} + +func TestGetSystemInfo(t *testing.T) { + mockServerInfo := MikrotikSystemInfo{ + ArchitectureName: "arm64", + BadBlocks: "0.1", + BoardName: "RB5009UG+S+", + BuildTime: "2024-09-20 13:00:27", + CPU: "ARM64", + CPUCount: "4", + CPUFrequency: "1400", + CPULoad: "0", + FactorySoftware: "7.4.1", + FreeHDDSpace: "1019346944", + FreeMemory: "916791296", + Platform: "MikroTik", + TotalHDDSpace: "1073741824", + TotalMemory: "1073741824", + Uptime: "4d19h9m34s", + Version: "7.16 (stable)", + WriteSectSinceReboot: "5868", + WriteSectTotal: "131658", + } + + // Set up mock server + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Validate the Basic Auth header + username, password, ok := r.BasicAuth() + if !ok { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + if username != mockUsername || password != mockPassword { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + // Return dummy data for /rest/system/resource + if r.URL.Path == "/rest/system/resource" && r.Method == http.MethodGet { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(mockServerInfo) + if err != nil { + t.Errorf("error json encoding server info") + } + return + } + + // Return 404 for any other path + http.NotFound(w, r) + })) + defer server.Close() + + // Define test cases + testCases := []struct { + name string + config MikrotikConnectionConfig + expectedError bool + }{ + { + name: "Valid credentials", + config: MikrotikConnectionConfig{ + BaseUrl: server.URL, + Username: mockUsername, + Password: mockPassword, + SkipTLSVerify: true, + }, + expectedError: false, + }, + { + name: "Incorrect password", + config: MikrotikConnectionConfig{ + BaseUrl: server.URL, + Username: mockUsername, + Password: "wrongpass", + SkipTLSVerify: true, + }, + expectedError: true, + }, + { + name: "Incorrect username", + config: MikrotikConnectionConfig{ + BaseUrl: server.URL, + Username: "wronguser", + Password: mockPassword, + SkipTLSVerify: true, + }, + expectedError: true, + }, + { + name: "Incorrect username and password", + config: MikrotikConnectionConfig{ + BaseUrl: server.URL, + Username: "wronguser", + Password: "wrongpass", + SkipTLSVerify: true, + }, + expectedError: true, + }, + { + name: "Missing credentials", + config: MikrotikConnectionConfig{ + BaseUrl: server.URL, + Username: "", + Password: "", + SkipTLSVerify: true, + }, + expectedError: true, + }, + } + + // Run test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + config := &tc.config + + client, err := NewMikrotikClient(config) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + info, err := client.GetSystemInfo() + + if tc.expectedError { + if err == nil { + t.Fatalf("Expected error due to unauthorized access, got none") + } + if info != nil { + t.Errorf("Expected no system info, got %v", info) + } + } else { + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if info.ArchitectureName != mockServerInfo.ArchitectureName { + t.Errorf("Expected ArchitectureName %s, got %s", mockServerInfo.ArchitectureName, info.ArchitectureName) + } + if info.Version != mockServerInfo.Version { + t.Errorf("Expected Version %s, got %s", mockServerInfo.Version, info.Version) + } + // i think there's no point in checking any more fields + } + }) + } +} + +func TestCreateDNSRecord(t *testing.T) { + testCases := []struct { + name string + initialRecords map[string]DNSRecord + endpoint *endpoint.Endpoint + expectedError bool + }{ + { + name: "Valid A record creation", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "test-a.example.com", + RecordType: "A", + Targets: endpoint.Targets{"192.0.2.1"}, + }, + expectedError: false, + }, + { + name: "Valid AAAA record creation", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "test-aaaa.example.com", + RecordType: "AAAA", + Targets: endpoint.Targets{"2001:db8::1"}, + }, + expectedError: false, + }, + { + name: "Valid CNAME record creation", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "test-cname.example.com", + RecordType: "CNAME", + Targets: endpoint.Targets{"example.com"}, + }, + expectedError: false, + }, + { + name: "Valid TXT record creation", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "test-txt.example.com", + RecordType: "TXT", + Targets: endpoint.Targets{"\"some text record value\""}, + }, + expectedError: false, + }, + { + name: "Record already exists", + initialRecords: map[string]DNSRecord{ + "exists.example.com|A": { + ID: "*EXISTING", + Name: "exists.example.com", + Type: "A", + Address: "192.0.2.1", + }, + }, + endpoint: &endpoint.Endpoint{ + DNSName: "exists.example.com", + RecordType: "A", + Targets: endpoint.Targets{"192.0.2.1"}, + }, + expectedError: true, + }, + { + name: "Invalid record (missing DNSName)", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "", + RecordType: "A", + Targets: endpoint.Targets{"192.0.2.1"}, + }, + expectedError: true, + }, + { + name: "Empty target for A record", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "empty-target.example.com", + RecordType: "A", + Targets: endpoint.Targets{""}, + }, + expectedError: true, + }, + { + name: "Malformed IP address for A record", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "malformed-ip.example.com", + RecordType: "A", + Targets: endpoint.Targets{"999.999.999.999"}, + }, + expectedError: true, + }, + { + name: "Malformed IP address for AAAA record", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "malformed-ipv6.example.com", + RecordType: "AAAA", + Targets: endpoint.Targets{"gggg::1"}, + }, + expectedError: true, + }, + { + name: "Empty target for CNAME record", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "empty-cname.example.com", + RecordType: "CNAME", + Targets: endpoint.Targets{""}, + }, + expectedError: true, + }, + // { //! we dont have any kind of cname validation so this will always pass + // name: "Malformed domain name for CNAME record", + // initialRecords: map[string]DNSRecord{}, + // endpoint: &endpoint.Endpoint{ + // DNSName: "bad-cname.example.com", + // RecordType: "CNAME", + // Targets: endpoint.Targets{"1234!"}, + // }, + // expectedError: true, + // }, + { + name: "Empty text for TXT record", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "empty-txt.example.com", + RecordType: "TXT", + Targets: endpoint.Targets{""}, + }, + expectedError: true, + }, + { + name: "Invalid record type", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "invalid-type.example.com", + RecordType: "INVALID", + Targets: endpoint.Targets{"some target"}, + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Initialize an in-memory store for DNS records for this test case + recordStore := make(map[string]DNSRecord) + + // Pre-populate recordStore with initialRecords + for k, v := range tc.initialRecords { + recordStore[k] = v + } + + // Set up mock server + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Validate the Basic Auth header + username, password, ok := r.BasicAuth() + if !ok || username != mockUsername || password != mockPassword { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Handle DNS record creation + if r.URL.Path == "/rest/ip/dns/static" && r.Method == http.MethodPut { + var record DNSRecord + if err := json.NewDecoder(r.Body).Decode(&record); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + // Check if record already exists + key := record.Name + "|" + record.Type + if _, exists := recordStore[key]; exists { + http.Error(w, "Conflict: Record already exists", http.StatusConflict) + return + } + + // Simulate assigning an ID and storing the record + record.ID = "*NEW" + recordStore[key] = record + + // Return the created record + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(record); err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + return + } + + // Return 404 for any other path + http.NotFound(w, r) + })) + defer server.Close() + + // Set up the client with correct credentials + config := &MikrotikConnectionConfig{ + BaseUrl: server.URL, + Username: mockUsername, + Password: mockPassword, + SkipTLSVerify: true, + } + + client, err := NewMikrotikClient(config) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + record, err := client.CreateDNSRecord(tc.endpoint) + + if tc.expectedError { + if err == nil { + t.Fatalf("Expected error, got none") + } + return + } + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Verify that the record was stored in the mock server + key := tc.endpoint.DNSName + "|" + tc.endpoint.RecordType + storedRecord, exists := recordStore[key] + if !exists { + t.Fatalf("Expected record to be stored, but it was not found") + } + + // Verify that the client received the correct record + if record.ID != storedRecord.ID { + t.Errorf("Expected ID '%s', got '%s'", storedRecord.ID, record.ID) + } + + // Additional checks specific to record type + switch tc.endpoint.RecordType { + case "A", "AAAA": + if storedRecord.Address != tc.endpoint.Targets[0] { + t.Errorf("Expected Address '%s', got '%s'", tc.endpoint.Targets[0], storedRecord.Address) + } + case "CNAME": + if storedRecord.CName != tc.endpoint.Targets[0] { + t.Errorf("Expected CName '%s', got '%s'", tc.endpoint.Targets[0], storedRecord.CName) + } + case "TXT": + if storedRecord.Text != tc.endpoint.Targets[0] { + t.Errorf("Expected Text '%s', got '%s'", tc.endpoint.Targets[0], storedRecord.Text) + } + default: + t.Errorf("Unsupported RecordType '%s' in test case", tc.endpoint.RecordType) + } + }) + } +} + +func TestDeleteDNSRecord(t *testing.T) { + testCases := []struct { + name string + initialRecords map[string]DNSRecord + endpoint *endpoint.Endpoint + expectedError bool + }{ + { + name: "Delete existing A record", + initialRecords: map[string]DNSRecord{ + "test.example.com|A": { + ID: "*1", + Name: "test.example.com", + Type: "A", + Address: "192.0.2.1", + }, + }, + endpoint: &endpoint.Endpoint{ + DNSName: "test.example.com", + RecordType: "A", + }, + expectedError: false, + }, + { + name: "Delete existing AAAA record", + initialRecords: map[string]DNSRecord{ + "ipv6.example.com|AAAA": { + ID: "*2", + Name: "ipv6.example.com", + Type: "AAAA", + Address: "2001:db8::1", + }, + }, + endpoint: &endpoint.Endpoint{ + DNSName: "ipv6.example.com", + RecordType: "AAAA", + }, + expectedError: false, + }, + { + name: "Delete existing CNAME record", + initialRecords: map[string]DNSRecord{ + "alias.example.com|CNAME": { + ID: "*3", + Name: "alias.example.com", + Type: "CNAME", + CName: "example.com", + }, + }, + endpoint: &endpoint.Endpoint{ + DNSName: "alias.example.com", + RecordType: "CNAME", + }, + expectedError: false, + }, + { + name: "Delete existing TXT record", + initialRecords: map[string]DNSRecord{ + "text.example.com|TXT": { + ID: "*4", + Name: "text.example.com", + Type: "TXT", + Text: "some text", + }, + }, + endpoint: &endpoint.Endpoint{ + DNSName: "text.example.com", + RecordType: "TXT", + }, + expectedError: false, + }, + { + name: "Delete non-existent record", + initialRecords: map[string]DNSRecord{}, + endpoint: &endpoint.Endpoint{ + DNSName: "nonexistent.example.com", + RecordType: "A", + }, + expectedError: true, + }, + { + name: "Delete record with missing DNSName", + initialRecords: map[string]DNSRecord{ + "missingname.example.com|A": { + ID: "*5", + Name: "missingname.example.com", + Type: "A", + Address: "192.0.2.2", + }, + }, + endpoint: &endpoint.Endpoint{ + DNSName: "", + RecordType: "A", + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Initialize an in-memory store for DNS records for this test case + recordStore := make(map[string]DNSRecord) + + // Pre-populate recordStore with initialRecords + for k, v := range tc.initialRecords { + recordStore[k] = v + } + + // Set up mock server + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok || username != mockUsername || password != mockPassword { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Handle DNS record fetching (for the lookup method) + if r.Method == http.MethodGet && r.URL.Path == "/rest/ip/dns/static" { + query := r.URL.Query() + name := query.Get("name") + recordType := query.Get("type") + if recordType == "" { + recordType = "A" + } + key := name + "|" + recordType + record := recordStore[key] + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode([]DNSRecord{record}) + if err != nil { + t.Errorf("error json encoding dns record") + } + return + } + + // Handle DNS record deletion + if r.Method == http.MethodDelete && strings.HasPrefix(r.URL.Path, "/rest/ip/dns/static/") { + id := strings.TrimPrefix(r.URL.Path, "/rest/ip/dns/static/") + var foundKey string + for key, record := range recordStore { + if record.ID == id { + foundKey = key + break + } + } + if foundKey != "" { + delete(recordStore, foundKey) + w.WriteHeader(http.StatusOK) + } else { + http.Error(w, "Not Found", http.StatusNotFound) + } + return + } + + http.NotFound(w, r) + })) + defer server.Close() + + config := &MikrotikConnectionConfig{ + BaseUrl: server.URL, + Username: mockUsername, + Password: mockPassword, + SkipTLSVerify: true, + } + client, err := NewMikrotikClient(config) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + err = client.DeleteDNSRecord(tc.endpoint) + + if tc.expectedError { + if err == nil { + t.Fatalf("Expected error, got none") + } + } else { + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + key := tc.endpoint.DNSName + "|" + tc.endpoint.RecordType + if _, exists := recordStore[key]; exists { + t.Fatalf("Expected record to be deleted, but it still exists") + } + } + }) + } +} + +func TestGetAllDNSRecords(t *testing.T) { + testCases := []struct { + name string + records []DNSRecord + expectError bool + unauthorized bool + }{ + { + name: "Multiple DNS records", + records: []DNSRecord{ + { + ID: "*1", + Address: "192.168.88.1", + Comment: "defconf", + Name: "router.lan", + TTL: "1d", + Type: "A", + }, + { + ID: "*3", + Address: "1.2.3.4", + Comment: "test A-Record", + Name: "example.com", + TTL: "1d", + Type: "A", + }, + { + ID: "*4", + CName: "example.com", + Comment: "test CNAME", + Name: "subdomain.example.com", + TTL: "1d", + Type: "CNAME", + }, + { + ID: "*5", + Address: "::1", + Comment: "test AAAA", + Name: "test quad-A", + TTL: "1d", + Type: "AAAA", + }, + { + ID: "*6", + Comment: "test TXT", + Name: "example.com", + Text: "lorem ipsum", + TTL: "1d", + Type: "TXT", + }, + }, + }, + { + name: "No DNS records", + records: []DNSRecord{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Basic Auth validation + username, password, ok := r.BasicAuth() + if !ok || username != mockUsername || password != mockPassword || tc.unauthorized { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Handle GET requests to /rest/ip/dns/static + if r.Method == http.MethodGet && r.URL.Path == "/rest/ip/dns/static" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(tc.records); err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + return + } + + // Return 404 for any other path + http.NotFound(w, r) + })) + defer server.Close() + + // Set up the client + config := &MikrotikConnectionConfig{ + BaseUrl: server.URL, + Username: mockUsername, + Password: mockPassword, + SkipTLSVerify: true, + } + client, err := NewMikrotikClient(config) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + records, err := client.GetAllDNSRecords() + + if tc.expectError { + if err == nil { + t.Fatalf("Expected error, got none") + } + } else { + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Verify the number of records + if len(records) != len(tc.records) { + t.Fatalf("Expected %d records, got %d", len(tc.records), len(records)) + } + + // Compare records if there are any + if len(tc.records) > 0 { + expectedRecordsMap := make(map[string]DNSRecord) + for _, rec := range tc.records { + key := rec.Name + "|" + rec.Type + expectedRecordsMap[key] = rec + } + + for _, record := range records { + key := record.Name + "|" + record.Type + expectedRecord, exists := expectedRecordsMap[key] + if !exists { + t.Errorf("Unexpected record found: %v", record) + continue + } + // Compare fields + if record.ID != expectedRecord.ID { + t.Errorf("Expected ID '%s', got '%s' for record %s", expectedRecord.ID, record.ID, key) + } + switch record.Type { + case "A", "AAAA": + if record.Address != expectedRecord.Address { + t.Errorf("Expected Address '%s', got '%s' for record %s", expectedRecord.Address, record.Address, key) + } + case "CNAME": + if record.CName != expectedRecord.CName { + t.Errorf("Expected CName '%s', got '%s' for record %s", expectedRecord.CName, record.CName, key) + } + case "TXT": + if record.Text != expectedRecord.Text { + t.Errorf("Expected Text '%s', got '%s' for record %s", expectedRecord.Text, record.Text, key) + } + default: + t.Errorf("Unsupported RecordType '%s' for record %s", record.Type, key) + } + } + } + } + }) + } +} diff --git a/internal/mikrotik/provider.go b/internal/mikrotik/provider.go index 0d40a87..59bcc02 100644 --- a/internal/mikrotik/provider.go +++ b/internal/mikrotik/provider.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + log "github.com/sirupsen/logrus" "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan" "sigs.k8s.io/external-dns/provider" @@ -17,15 +18,25 @@ type MikrotikProvider struct { domainFilter endpoint.DomainFilter } -// NewMikrotikProvider initializes a new DNSProvider. -func NewMikrotikProvider(domainFilter endpoint.DomainFilter, config *Config) (provider.Provider, error) { - c, err := NewMikrotikClient(config) +// NewMikrotikProvider initializes a new DNSProvider, of the Mikrotik variety +func NewMikrotikProvider(domainFilter endpoint.DomainFilter, config *MikrotikConnectionConfig) (provider.Provider, error) { + // Create the Mikrotik API Client + client, err := NewMikrotikClient(config) if err != nil { return nil, fmt.Errorf("failed to create the MikroTik client: %w", err) } + // Ensure the Client can connect to the API by fetching system info + info, err := client.GetSystemInfo() + if err != nil { + log.Errorf("failed to connect to the MikroTik RouterOS API Endpoint: %v", err) + return nil, err + } + log.Infof("connected to board %s running RouterOS version %s (%s)", info.BoardName, info.Version, info.ArchitectureName) + + // If the client connects properly, create the DNS Provider p := &MikrotikProvider{ - client: c, + client: client, domainFilter: domainFilter, }