Skip to content

Commit

Permalink
fix: add missing Host header on server discovery
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcosSpessatto committed Dec 12, 2024
1 parent 127bc3e commit 66934aa
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 29 deletions.
47 changes: 29 additions & 18 deletions packages/homeserver/src/helpers/server-discovery/discovery.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -341,22 +341,26 @@ describe('#getWellKnownCachedAddress()', () => {
});

describe('#resolveHostAddressByServerName()', () => {
const localHomeServerName = 'rc1';
const localHomeServerNameWithPort = 'rc1:443';

afterEach(() => {
wellKnownCache.clear();
});

it('should resolve IP literal addresses directly', async () => {
const result = await resolveHostAddressByServerName('192.168.1.1');
expect(result).toBe('192.168.1.1:8448');
const { address, headers } = await resolveHostAddressByServerName('192.168.1.1', localHomeServerName);
expect(address).toBe('192.168.1.1:8448');
expect(headers).toEqual({ Host: localHomeServerNameWithPort });
});

it('should resolve addresses with explicit ports directly', async () => {
mockResolver.resolveAny.mockResolvedValueOnce([
{ type: 'AAAA', address: '2001:0db8:85a3:0000:0000:8a2e:0370:7334', ttl: 300 }
]);
const result = await resolveHostAddressByServerName('example.com:8080');
expect(result).toBe('[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080');
const { address, headers } = await resolveHostAddressByServerName('example.com:8080', localHomeServerName);
expect(address).toBe('[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080');
expect(headers).toEqual({ Host: localHomeServerNameWithPort });
});

it('should return cached address if available and valid', async () => {
Expand All @@ -371,8 +375,9 @@ describe('#resolveHostAddressByServerName()', () => {
};
wellKnownCache.set(serverName, cachedData);

const result = await resolveHostAddressByServerName(serverName);
expect(result).toBe('[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8448');
const { address, headers } = await resolveHostAddressByServerName(serverName, localHomeServerName);
expect(address).toBe('[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8448');
expect(headers).toEqual({ Host: 'cached.example.com:8448' });
});

it('should resolve using well-known address if not cached', async () => {
Expand All @@ -388,8 +393,9 @@ describe('#resolveHostAddressByServerName()', () => {
};
global.fetch = jest.fn().mockResolvedValueOnce(mockResponse);

const result = await resolveHostAddressByServerName('example.com');
expect(result).toBe('[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8448');
const { address, headers } = await resolveHostAddressByServerName('example.com', localHomeServerName);
expect(address).toBe('[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8448');
expect(headers).toEqual({ Host: 'example.com:8448' });
});

it('should fallback to SRV records if well-known address is not available', async () => {
Expand All @@ -399,8 +405,9 @@ describe('#resolveHostAddressByServerName()', () => {
{ type: 'A', address: '192.168.1.1', ttl: 300 }
]);

const result = await resolveHostAddressByServerName('example.com');
expect(result).toBe('192.168.1.1:8448');
const { address, headers } = await resolveHostAddressByServerName('example.com', localHomeServerName);
expect(address).toBe('192.168.1.1:8448');
expect(headers).toEqual({ Host: 'example.com' });
});

it('should return the provided address with the default port when the request did not throw but ok is false', async () => {
Expand All @@ -416,8 +423,9 @@ describe('#resolveHostAddressByServerName()', () => {
};
global.fetch = jest.fn().mockResolvedValueOnce(mockResponse);

const result = await resolveHostAddressByServerName('example.com');
expect(result).toBe('example.com:8448');
const { address, headers } = await resolveHostAddressByServerName('example.com', localHomeServerName);
expect(address).toBe('example.com:8448');
expect(headers).toEqual({ Host: 'example.com' });
});

it('should return the provided address with the default port when the request did not throw but json() threw', async () => {
Expand All @@ -433,24 +441,27 @@ describe('#resolveHostAddressByServerName()', () => {
};
global.fetch = jest.fn().mockResolvedValueOnce(mockResponse);

const result = await resolveHostAddressByServerName('example.com');
expect(result).toBe('example.com:8448');
const { address, headers } = await resolveHostAddressByServerName('example.com', localHomeServerName);
expect(address).toBe('example.com:8448');
expect(headers).toEqual({ Host: 'example.com' });
});

it('should fallback to default port if no SRV, CNAME, AAAA, not A records are found', async () => {
global.fetch = jest.fn().mockRejectedValueOnce(new Error('Fetch error'));
mockResolver.resolveSrv = jest.fn().mockResolvedValue([]);
mockResolver.resolveAny.mockResolvedValueOnce([]);

const result = await resolveHostAddressByServerName('example.com');
expect(result).toBe('example.com:8448');
const { address, headers } = await resolveHostAddressByServerName('example.com', localHomeServerName);
expect(address).toBe('example.com:8448');
expect(headers).toEqual({ Host: 'example.com' });
});

it('should handle errors gracefully and return address with default port', async () => {
global.fetch = jest.fn().mockRejectedValueOnce(new Error('Fetch error'));
mockResolver.resolveSrv = jest.fn().mockRejectedValue(new Error('DNS resolution error'));

const result = await resolveHostAddressByServerName('example.com');
expect(result).toBe('example.com:8448');
const { address, headers } = await resolveHostAddressByServerName('example.com', localHomeServerName);
expect(address).toBe('example.com:8448');
expect(headers).toEqual({ Host: 'example.com' });
});
});
25 changes: 18 additions & 7 deletions packages/homeserver/src/helpers/server-discovery/discovery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,23 +170,34 @@ const getAddressFromWellKnownData = async (serverName: string): Promise<string>
return address;
}

export const resolveHostAddressByServerName = async (serverName: string): Promise<string> => {
const defaultOwnServerAddress = (ownServerName: string): string => {
return `${ownServerName}:443`;
}

export const resolveHostAddressByServerName = async (serverName: string, ownServerName: string): Promise<{ address: string; headers: { Host: string } }> => {
try {
if (isIpLiteral(serverName)) {
return await resolveWhenServerNameIsIpAddress(serverName);
const address = await resolveWhenServerNameIsIpAddress(serverName);
return { address, headers: { Host: defaultOwnServerAddress(ownServerName) } };
}

if (addressHasExplicitPort(serverName)) {
return await resolveWhenServerNameIsAddressWithPort(serverName);
const address = await resolveWhenServerNameIsAddressWithPort(serverName);
return { address, headers: { Host: defaultOwnServerAddress(ownServerName) } };
}

const address = await getAddressFromWellKnownData(serverName);
const rawAddress = await getAddressFromWellKnownData(serverName);
const address = await resolveFollowingWellKnownRules(rawAddress);

return await resolveFollowingWellKnownRules(address);
return { address, headers: { Host: rawAddress } };
} catch (error) {
if (error instanceof Error && error.message === 'No address found') {
return resolveUsingSRVRecordsOrFallbackToOtherRecords(serverName).catch(() => addressWithDefaultPort(serverName));
const address = await resolveUsingSRVRecordsOrFallbackToOtherRecords(serverName).catch(() => addressWithDefaultPort(serverName));

return { address, headers: { Host: serverName } };
}
return addressWithDefaultPort(serverName);
const address = await addressWithDefaultPort(serverName);

return { address, headers: { Host: address } };
}
}
15 changes: 11 additions & 4 deletions packages/homeserver/src/makeRequest.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import type { HomeServerRoutes } from "./app";
import { authorizationHeaders, computeHash } from "./authentication";
import { cache } from "./cache";
import { resolveHostAddressByServerName } from "./helpers/server-discovery/discovery";
import { extractURIfromURL } from "./helpers/url";
import type { SigningKey } from "./keys";
Expand Down Expand Up @@ -40,7 +39,8 @@ export const makeSignedRequest = async <
signingName: string;
queryString?: string;
}) => {
const url = new URL(`https://${await resolveHostAddressByServerName(domain)}${uri}`);
const { address, headers } = await resolveHostAddressByServerName(domain, signingName);
const url = new URL(`https://${address}${uri}`);
if (queryString) {
url.search = queryString;
}
Expand All @@ -67,6 +67,7 @@ export const makeSignedRequest = async <
...(queryString && { search: queryString }),
headers: {
Authorization: auth,
...headers,
},
});

Expand All @@ -84,16 +85,19 @@ export const makeRequest = async <
domain,
uri,
body,
signingName,
options = {},
queryString,
}: (B extends Record<string, unknown> ? { body: B } : { body?: never }) & {
method: M;
domain: string;
uri: U;
signingName: string;
options?: Record<string, any>;
queryString?: string;
}) => {
const url = new URL(`https://${await resolveHostAddressByServerName(domain)}${uri}`);
const { address, headers } = await resolveHostAddressByServerName(domain, signingName);
const url = new URL(`https://${address}${uri}`);
if (queryString) {
url.search = queryString;
}
Expand All @@ -103,6 +107,7 @@ export const makeRequest = async <
...(body && { body: JSON.stringify(body) }),
method,
...(queryString && { search: queryString }),
headers,
});

return response.json() as Promise<
Expand Down Expand Up @@ -142,7 +147,8 @@ export const makeUnsignedRequest = async <
options.body,
);

const url = new URL(`https://${await resolveHostAddressByServerName(domain)}${uri}`);
const { address, headers } = await resolveHostAddressByServerName(domain, signingName);
const url = new URL(`https://${address}${uri}`);
if (queryString) {
url.search = queryString;
}
Expand All @@ -152,6 +158,7 @@ export const makeUnsignedRequest = async <
method,
headers: {
Authorization: auth,
...headers,
},
});

Expand Down
1 change: 1 addition & 0 deletions packages/homeserver/src/plugins/validateHeaderSignature.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ export const validateHeaderSignature = async ({
method: "GET",
domain: origin.origin,
uri: "/_matrix/key/v2/server",
signingName: context.config.name,
});
if (result.valid_until_ts < Date.now()) {
throw new Error("Expired remote public key");
Expand Down

0 comments on commit 66934aa

Please sign in to comment.