Skip to content

Commit 3f1504b

Browse files
authored
Support gateways without public IPs on AWS (#1224)
1 parent 3965d6c commit 3f1504b

File tree

7 files changed

+88
-33
lines changed

7 files changed

+88
-33
lines changed

src/dstack/_internal/core/backends/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
BackendType.LAMBDA,
1616
BackendType.TENSORDOCK,
1717
]
18+
BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT = [BackendType.AWS]

src/dstack/_internal/core/backends/aws/compute.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,12 @@ def create_gateway(
204204
]
205205
if settings.DSTACK_VERSION is not None:
206206
tags.append({"Key": "dstack_version", "Value": settings.DSTACK_VERSION})
207+
vpc_id, subnet_id = get_vpc_id_subnet_id_or_error(
208+
ec2_client=ec2_client,
209+
config=self.config,
210+
region=configuration.region,
211+
allocate_public_ip=configuration.public_ip,
212+
)
207213
response = ec2.create_instances(
208214
**aws_resources.create_instances_struct(
209215
disk_size=10,
@@ -215,17 +221,24 @@ def create_gateway(
215221
security_group_id=aws_resources.create_gateway_security_group(
216222
ec2_client=ec2_client,
217223
project_id=configuration.project_name,
224+
vpc_id=vpc_id,
218225
),
219226
spot=False,
227+
subnet_id=subnet_id,
228+
allocate_public_ip=configuration.public_ip,
220229
)
221230
)
222231
instance = response[0]
223232
instance.wait_until_running()
224233
instance.reload() # populate instance.public_ip_address
234+
if configuration.public_ip:
235+
ip_address = instance.public_ip_address
236+
else:
237+
ip_address = instance.private_ip_address
225238
return LaunchedGatewayInfo(
226239
instance_id=instance.instance_id,
227240
region=configuration.region,
228-
ip_address=instance.public_ip_address,
241+
ip_address=ip_address,
229242
)
230243

231244

src/dstack/_internal/core/backends/aws/resources.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -171,20 +171,31 @@ def get_gateway_image_id(ec2_client: botocore.client.BaseClient) -> str:
171171
return image["ImageId"]
172172

173173

174-
def create_gateway_security_group(ec2_client: botocore.client.BaseClient, project_id: str) -> str:
174+
def create_gateway_security_group(
175+
ec2_client: botocore.client.BaseClient,
176+
project_id: str,
177+
vpc_id: Optional[str],
178+
) -> str:
175179
security_group_name = "dstack_gw_sg_" + project_id.replace("-", "_").lower()
176-
177-
response = ec2_client.describe_security_groups(
178-
Filters=[
180+
describe_security_groups_filters = [
181+
{
182+
"Name": "group-name",
183+
"Values": [security_group_name],
184+
},
185+
]
186+
if vpc_id is not None:
187+
describe_security_groups_filters.append(
179188
{
180-
"Name": "group-name",
181-
"Values": [security_group_name],
182-
},
183-
],
184-
)
189+
"Name": "vpc-id",
190+
"Values": [vpc_id],
191+
}
192+
)
193+
response = ec2_client.describe_security_groups(Filters=describe_security_groups_filters)
185194
if response.get("SecurityGroups"):
186195
return response["SecurityGroups"][0]["GroupId"]
187-
196+
create_security_group_kwargs = {}
197+
if vpc_id is not None:
198+
create_security_group_kwargs["VpcId"] = vpc_id
188199
security_group = ec2_client.create_security_group(
189200
Description="Generated by dstack",
190201
GroupName=security_group_name,
@@ -198,6 +209,7 @@ def create_gateway_security_group(ec2_client: botocore.client.BaseClient, projec
198209
],
199210
},
200211
],
212+
**create_security_group_kwargs,
201213
)
202214
group_id = security_group["GroupId"]
203215

src/dstack/_internal/core/models/gateways.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class GatewayConfiguration(CoreModel):
1717
domain: Annotated[
1818
Optional[str], Field(description="The gateway domain, e.g. `*.example.com`")
1919
] = None
20-
# public_ip: Annotated[bool, Field(description="Allocate public IP for the gateway")] = True
20+
public_ip: Annotated[bool, Field(description="Allocate public IP for the gateway")] = True
2121

2222

2323
class GatewayComputeConfiguration(CoreModel):

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import dstack._internal.server.services.jobs as jobs_services
1313
import dstack._internal.utils.random_names as random_names
14+
from dstack._internal.core.backends import BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT
1415
from dstack._internal.core.backends.base.compute import (
1516
Compute,
1617
get_dstack_gateway_wheel,
@@ -87,17 +88,27 @@ async def get_project_default_gateway(
8788

8889

8990
async def create_gateway_compute(
91+
project_name: str,
9092
backend_compute: Compute,
91-
configuration: GatewayComputeConfiguration,
93+
configuration: GatewayConfiguration,
9294
backend_id: Optional[uuid.UUID] = None,
9395
) -> GatewayComputeModel:
9496
private_bytes, public_bytes = generate_rsa_key_pair_bytes()
9597
gateway_ssh_private_key = private_bytes.decode()
9698
gateway_ssh_public_key = public_bytes.decode()
9799

100+
compute_configuration = GatewayComputeConfiguration(
101+
project_name=project_name,
102+
instance_name=configuration.name,
103+
backend=configuration.backend,
104+
region=configuration.region,
105+
public_ip=configuration.public_ip,
106+
ssh_key_pub=gateway_ssh_public_key,
107+
)
108+
98109
info = await run_async(
99110
backend_compute.create_gateway,
100-
configuration,
111+
compute_configuration,
101112
)
102113

103114
return GatewayComputeModel(
@@ -122,6 +133,15 @@ async def create_gateway(
122133
else:
123134
raise ResourceNotExistsError()
124135

136+
if (
137+
not configuration.public_ip
138+
and configuration.backend not in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT
139+
):
140+
raise GatewayError(
141+
f"Private gateways are not supported for {configuration.backend.value} backend. "
142+
f"Supported backends: {[b.value for b in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT]}."
143+
)
144+
125145
if configuration.name is None:
126146
configuration.name = await generate_gateway_name(session=session, project=project)
127147

@@ -139,19 +159,11 @@ async def create_gateway(
139159
if project.default_gateway is None or configuration.default:
140160
await set_default_gateway(session=session, project=project, name=configuration.name)
141161

142-
compute_configuration = GatewayComputeConfiguration(
143-
project_name=project.name,
144-
instance_name=gateway.name,
145-
backend=configuration.backend,
146-
region=configuration.region,
147-
public_ip=True,
148-
ssh_key_pub=project.name,
149-
)
150-
151162
try:
152163
gateway.gateway_compute = await create_gateway_compute(
153164
backend_compute=backend.compute(),
154-
configuration=compute_configuration,
165+
project_name=project.name,
166+
configuration=configuration,
155167
backend_id=backend_model.id,
156168
)
157169
session.add(gateway)
@@ -321,13 +333,6 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) ->
321333
async def register_service(session: AsyncSession, run_model: RunModel):
322334
run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
323335

324-
service_https = run_spec.configuration.https
325-
service_protocol = "https" if service_https else "http"
326-
327-
# Currently, gateway endpoint is always https
328-
gateway_https = True
329-
gateway_protocol = "https" if gateway_https else "http"
330-
331336
# TODO(egor-s): allow to configure gateway name
332337
gateway_name: Optional[str] = None
333338
if gateway_name is None:
@@ -343,6 +348,21 @@ async def register_service(session: AsyncSession, run_model: RunModel):
343348
if gateway.gateway_compute is None:
344349
raise ServerClientError("Gateway has no instance associated with it")
345350

351+
service_https = run_spec.configuration.https
352+
service_protocol = "https" if service_https else "http"
353+
354+
gateway_configuration = None
355+
if gateway.configuration is not None:
356+
gateway_configuration = GatewayConfiguration.__response__.parse_raw(gateway.configuration)
357+
if service_https and not gateway_configuration.public_ip:
358+
raise ServerClientError("Cannot run HTTPS service on gateway without public IP")
359+
360+
gateway_https = True
361+
if gateway_configuration is not None:
362+
# Currently, https is always False for private gateways
363+
gateway_https = gateway_configuration.public_ip
364+
gateway_protocol = "https" if gateway_https else "http"
365+
346366
wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None
347367
if wildcard_domain is None:
348368
raise ServerClientError("Domain is required for gateway")

src/dstack/_internal/server/services/gateways/connection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import aiorwlock
77

8+
from dstack._internal.core.services.ssh.ports import PortsLock
89
from dstack._internal.server.services.gateways.client import (
910
GATEWAY_MANAGEMENT_PORT,
1011
GatewayClient,
@@ -29,9 +30,10 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int):
2930
self._lock = aiorwlock.RWLock()
3031
self.stats: Dict[str, Dict[int, Stat]] = {}
3132
self.ip_address = ip_address
32-
33+
self.ports_lock = PortsLock(restrictions={server_port: 0}).acquire()
34+
local_port = self.ports_lock.dict()[server_port]
3335
args = ["-L", "{temp_dir}/gateway:localhost:%d" % GATEWAY_MANAGEMENT_PORT]
34-
args += ["-R", f"localhost:8001:localhost:{server_port}"]
36+
args += ["-R", f"localhost:{local_port}:localhost:{server_port}"]
3537
self.tunnel = AsyncSSHTunnel(
3638
f"ubuntu@{ip_address}",
3739
id_rsa,

src/tests/_internal/server/routers/test_gateways.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ async def test_list(self, test_db, session: AsyncSession):
8080
"region": gateway.region,
8181
"domain": gateway.wildcard_domain,
8282
"default": False,
83+
"public_ip": True,
8384
},
8485
}
8586
]
@@ -124,6 +125,7 @@ async def test_get(self, test_db, session: AsyncSession):
124125
"region": gateway.region,
125126
"domain": gateway.wildcard_domain,
126127
"default": False,
128+
"public_ip": True,
127129
},
128130
}
129131

@@ -203,6 +205,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession):
203205
"region": "us",
204206
"domain": None,
205207
"default": True,
208+
"public_ip": True,
206209
},
207210
}
208211

@@ -257,6 +260,7 @@ async def test_create_gateway_without_name(self, test_db, session: AsyncSession)
257260
"region": "us",
258261
"domain": None,
259262
"default": True,
263+
"public_ip": True,
260264
},
261265
}
262266

@@ -391,6 +395,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession):
391395
"region": gateway.region,
392396
"domain": gateway.wildcard_domain,
393397
"default": True,
398+
"public_ip": True,
394399
},
395400
}
396401

@@ -498,6 +503,7 @@ def get_backend(_, backend_type):
498503
"region": gateway_gcp.region,
499504
"domain": gateway_gcp.wildcard_domain,
500505
"default": False,
506+
"public_ip": True,
501507
},
502508
}
503509
]
@@ -557,6 +563,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession):
557563
"region": gateway.region,
558564
"domain": "test.com",
559565
"default": False,
566+
"public_ip": True,
560567
},
561568
}
562569

0 commit comments

Comments
 (0)