diff --git a/src/wireguard/dns-malware.go b/src/wireguard/dns-malware.go new file mode 100644 index 0000000..b422908 --- /dev/null +++ b/src/wireguard/dns-malware.go @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2022 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +// #include +import "C" + +import ( + "fmt" + + "github.com/miekg/dns" +) + + +// parseDNSResponse parses the raw DNS response from a byte slice. +func parseDNSResponse(data []byte) *dns.Msg { + tag := cstring("WireGuard/GoBackend/parseDNSResponse") + // Create a new DNS message structure + dnsMsg := new(dns.Msg) + + // Unpack the DNS message from the raw data + err := dnsMsg.Unpack(data) + if err != nil { + C.__android_log_write(C.ANDROID_LOG_DEBUG, tag, cstring(fmt.Sprintf("Failed to unpack DNS message: %v", err))) + return nil + } + + // Check if this is a DNS response (QR bit set) + if dnsMsg.Response { + return dnsMsg + } + + // If not a response, return nil + return nil +} + + +/* + * WasDNSMalwareBlocked checks if the DNS packet contains a "blocked:m" TXT + * record under the "explanation.invalid" domain. If such a record is found, + * it returns true along with the domain being blocked. + * That record is found when our the DDG (malware) DNS server blocks domains that + * serve malware + */ +func WasDNSMalwareBlocked(dnsData []byte) (bool, string) { + dnsResponse := parseDNSResponse(dnsData) + if dnsResponse == nil { + return false, "" + } + + // Look for "explanation.invalid" in the additional section + for _, rr := range dnsResponse.Extra { + if txtRecord, ok := rr.(*dns.TXT); ok { + if txtRecord.Hdr.Name == "explanation.invalid." { + for _, txt := range txtRecord.Txt { + if txt == "blocked:m" { + // If there is a matching TXT record, return the blocked domain + if len(dnsResponse.Question) > 0 { + blockedDomain := dnsResponse.Question[0].Name + return true, blockedDomain + } + } + } + } + } + } + + return false, "" +} \ No newline at end of file diff --git a/src/wireguard/go.mod b/src/wireguard/go.mod index e9797b0..72c53ac 100644 --- a/src/wireguard/go.mod +++ b/src/wireguard/go.mod @@ -3,12 +3,16 @@ module golang.zx2c4.com/wireguard/android go 1.18 require ( - golang.org/x/net v0.10.0 - golang.org/x/sys v0.8.0 + golang.org/x/net v0.27.0 + golang.org/x/sys v0.22.0 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b ) require ( - golang.org/x/crypto v0.9.0 // indirect + github.com/miekg/dns v1.1.62 // indirect + golang.org/x/crypto v0.25.0 // indirect + golang.org/x/mod v0.18.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/tools v0.22.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect ) diff --git a/src/wireguard/go.sum b/src/wireguard/go.sum index 0b905bf..229c257 100644 --- a/src/wireguard/go.sum +++ b/src/wireguard/go.sum @@ -1,11 +1,25 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= +github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= diff --git a/src/wireguard/jni.c b/src/wireguard/jni.c index c90875b..de74f3e 100644 --- a/src/wireguard/jni.c +++ b/src/wireguard/jni.c @@ -18,6 +18,7 @@ extern char *wgVersion(); static JavaVM *JVM = NULL; static jobject GO_BACKEND = NULL; static jmethodID MID_SHOULD_ALLOW = NULL; +static jmethodID MID_RECORD_MALWARE_BLOCK = NULL; static jint SDK = 0; JNIEXPORT jint JNICALL Java_com_wireguard_android_backend_GoBackend_wgTurnOn(JNIEnv *env, jobject goBackend, jstring ifname, @@ -39,6 +40,7 @@ JNIEXPORT jint JNICALL Java_com_wireguard_android_backend_GoBackend_wgTurnOn(JNI jclass clsGoBackend = (*env)->GetObjectClass(env, goBackend); const char *shouldAllowSig = "(ILjava/lang/String;ILjava/lang/String;ILjava/lang/String;I)Z"; MID_SHOULD_ALLOW = jniGetMethodID(env, clsGoBackend, "shouldAllow", shouldAllowSig); + MID_RECORD_MALWARE_BLOCK = jniGetMethodID(env, clsGoBackend, "recordMalwareBlock", "(Ljava/lang/String;)V"); (*env)->DeleteLocalRef(env, clsGoBackend); // Call Go method @@ -137,6 +139,31 @@ int is_pkt_allowed(const uint8_t *buffer, int length) { return 1; } +int record_malware_block(const char *domain) { + if (MID_RECORD_MALWARE_BLOCK == NULL) { + log_print(PLATFORM_LOG_PRIORITY_ERROR, "MID_RECORD_MALWARE_BLOCK method not found"); + return 1; + } + + JNIEnv *env; + jint rs = (*JVM)->AttachCurrentThread(JVM, &env, NULL); + if (rs != JNI_OK) { + log_print(PLATFORM_LOG_PRIORITY_ERROR, "Could not attach to JVM thread"); + return 1; + } + + // Prep call to Kotlin + jstring jdomain = (*env)->NewStringUTF(env, domain); + (*env)->CallVoidMethod(env, GO_BACKEND, MID_RECORD_MALWARE_BLOCK, jdomain); + jniCheckException(env); + + // cleanup + (*env)->DeleteLocalRef(env, jdomain); + + return 0; + +} + JNIEXPORT void JNICALL Java_com_wireguard_android_backend_GoBackend_wgTurnOff(JNIEnv *env, jclass c, jint handle) { wgTurnOff(handle); diff --git a/src/wireguard/tun_android.go b/src/wireguard/tun_android.go index a98805a..b48d7ad 100644 --- a/src/wireguard/tun_android.go +++ b/src/wireguard/tun_android.go @@ -22,8 +22,10 @@ Implementation of the TUN device interface for Android (wraps linux one) // #include +// #include // For C.free and C string functions // extern int is_pkt_allowed(char *buffer, int length); // extern int wg_write_pcap(char *buffer, int length); +// extern int record_malware_block(const char *domain); import "C" import ( @@ -50,10 +52,53 @@ func (tunWrapper *NativeTunWrapper) Name() (string, error) { return tunWrapper.nativeTun.Name() } -func (tunWrapper *NativeTunWrapper) Write(buf [][]byte, offset int) (int, error) { - pktLen, err := tunWrapper.nativeTun.Write(buf, offset) +func (tunWrapper *NativeTunWrapper) Write(bufs [][]byte, offset int) (int, error) { + tag := cstring("WireGuard/GoBackend/Write") -// tag := cstring("WireGuard/GoBackend/Write") + for _, buf := range bufs { + // Check if it's an IPv4 packet + if len(buf) <= offset { + C.__android_log_write(C.ANDROID_LOG_DEBUG, tag, cstring("Skipping invalid packet, too short")) + continue + } + switch buf[offset] >> 4 { + case ipv4.Version: + if len(buf) < ipv4.HeaderLen { + C.__android_log_write(C.ANDROID_LOG_DEBUG, tag, cstring("Skipping bad IPv4 packet")) + continue + } + + // Check if it's a UDP packet + protocol := buf[offset+9] + if protocol == 0x11 { // UDP + // Extract the ports (skip IP and check transport layer headers) + srcPort := (uint16(buf[offset+ipv4.HeaderLen]) << 8) | uint16(buf[offset+ipv4.HeaderLen+1]) + dstPort := (uint16(buf[offset+ipv4.HeaderLen+2]) << 8) | uint16(buf[offset+ipv4.HeaderLen+3]) + + if srcPort == 53 || dstPort == 53 { + // Extract the DNS data (skip IP and UDP headers) + dnsData := buf[offset+ipv4.HeaderLen+8:] + + // Call the helper function to check if the DNS packet should be blocked + shouldBlock, blockedDomain := WasDNSMalwareBlocked(dnsData) + if shouldBlock { + logMessage := "DNS malware was blocked for domain: " + blockedDomain + " due to 'blocked:m' TXT record" + C.__android_log_write(C.ANDROID_LOG_DEBUG, tag, cstring(logMessage)) + + // call back into JVM and let the packet flow normally + cBlockedDomain := C.CString(blockedDomain) + defer C.free(unsafe.Pointer(cBlockedDomain)) + C.record_malware_block(cBlockedDomain) // ignore return code + } + } + } + default: + // Not an IPv4 packet + C.__android_log_write(C.ANDROID_LOG_DEBUG, tag, cstring("Invalid IP")) + } + } + + pktLen, err := tunWrapper.nativeTun.Write(bufs, offset) // PCAP recording // pcap_res := int(C.wg_write_pcap((*C.char)(unsafe.Pointer(&buf[offset])), C.int(pktLen+offset)))