fast import: full e2e with pageserver and mock s3
NanoBjorn committed Jan 31, 2025
1 parent 12a061c commit 04ab6b3
Showing 2 changed files with 225 additions and 3 deletions.
8 changes: 5 additions & 3 deletions test_runner/fixtures/fast_import.py
@@ -23,6 +23,7 @@ def __init__(
pg_distrib_dir: Path,
pg_version: PgVersion,
workdir: Path,
cleanup: bool = True,
):
if extra_env is None:
env_vars = {}
@@ -47,6 +48,7 @@ def __init__(
if not workdir.exists():
raise Exception(f"Working directory '{workdir}' does not exist")
self.workdir = workdir
self.cleanup = cleanup

def run(
self,
@@ -81,7 +83,7 @@ def __enter__(self):
return self

def __exit__(self, *args):
if self.workdir.exists():
if self.workdir.exists() and self.cleanup:
shutil.rmtree(self.workdir)


@@ -92,8 +94,8 @@ def fast_import(
neon_binpath: Path,
pg_distrib_dir: Path,
) -> Iterator[FastImport]:
workdir = Path(tempfile.mkdtemp())
with FastImport(None, neon_binpath, pg_distrib_dir, pg_version, workdir) as fi:
workdir = Path(tempfile.mkdtemp(dir=test_output_dir, prefix="fast_import_"))
with FastImport(None, neon_binpath, pg_distrib_dir, pg_version, workdir, cleanup=False) as fi:
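# cleanup=False keeps the workdir (now created under test_output_dir) for post-test inspection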
yield fi

if fi.cmd is None:
220 changes: 220 additions & 0 deletions test_runner/regress/test_import_pgdata.py
@@ -44,6 +44,7 @@ class RelBlockSize(Enum):


@pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params)
@pytest.mark.timeout(10 * 60)
def test_pgdata_import_smoke(
vanilla_pg: VanillaPostgres,
neon_env_builder: NeonEnvBuilder,
@@ -340,6 +341,225 @@ def validate_vanilla_equivalence(ep):
br_initdb_endpoint.safe_psql("select * from othertable")


@pytest.mark.timeout(10 * 60)
def test_fast_import_with_pageserver_ingest(
test_output_dir,
vanilla_pg: VanillaPostgres,
port_distributor: PortDistributor,
fast_import: FastImport,
pg_distrib_dir: Path,
pg_version: PgVersion,
mock_s3_server: MockS3Server,
mock_kms: KMSClient,
mock_s3_client: S3Client,
neon_env_builder: NeonEnvBuilder,
make_httpserver: HTTPServer,
):
# Prepare KMS and S3
key_response = mock_kms.create_key(
Description="Test key",
KeyUsage="ENCRYPT_DECRYPT",
Origin="AWS_KMS",
)
key_id = key_response["KeyMetadata"]["KeyId"]

def encrypt(x: str) -> EncryptResponseTypeDef:
return mock_kms.encrypt(KeyId=key_id, Plaintext=x)
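# The ciphertext produced here goes into the spec below; fast_import is expected to decrypt it via KMS using the referenced key_id.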

# Start the source postgres and load some data
vanilla_pg.start()
vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

# Set up the pageserver and a fake control plane for import progress
def handler(request: Request) -> Response:
log.info(f"control plane request: {request.json}")
return Response(json.dumps({}), status=200)

cplane_mgmt_api_server = make_httpserver
cplane_mgmt_api_server.expect_request(re.compile(".*")).respond_with_handler(handler)
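# The stub answers any path with HTTP 200; the pageserver reports import progress to the import_pgdata_upcall_api configured below.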

neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3)
env = neon_env_builder.init_start()

env.pageserver.patch_config_toml_nonrecursive(
{
"import_pgdata_upcall_api": f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/path/to/mgmt/api",
# the import_pgdata code uses this endpoint, not the one in the common remote storage config
# TODO: maybe use common remote_storage config in pageserver?
"import_pgdata_aws_endpoint_url": env.s3_mock_server.endpoint(),
}
)
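# Restart the pageserver so the patched config takes effect.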
env.pageserver.stop()
env.pageserver.start()

# Encrypt the source connstring and upload the spec to S3
source_connstring_encrypted = encrypt(vanilla_pg.connstr())
spec = {
"encryption_secret": {"KMS": {"key_id": key_id}},
"source_connstring_ciphertext_base64": base64.b64encode(
source_connstring_encrypted["CiphertextBlob"]
).decode("utf-8"),
"project_id": "someproject",
"branch_id": "somebranch",
}

bucket = "test-bucket"
key_prefix = "test-prefix"
mock_s3_client.create_bucket(Bucket=bucket)
mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec))
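# Both fast_import (via s3prefix) and the pageserver import task (via the AwsS3 location) are pointed at this prefix; the spec carries the encrypted source connstring.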

# Create timeline with import_pgdata
tenant_id = TenantId.generate()
env.storage_controller.tenant_create(tenant_id)

timeline_id = TimelineId.generate()
log.info("starting import")
start = time.monotonic()

idempotency = ImportPgdataIdemptencyKey.random()
log.info(f"idempotency key {idempotency}")
# TODO: teach neon_local CLI about the idempotency & 429 error so we can run inside the loop
# and check for 429

import_branch_name = "imported"
env.storage_controller.timeline_create(
tenant_id,
{
"new_timeline_id": str(timeline_id),
"import_pgdata": {
"idempotency_key": str(idempotency),
"location": {
"AwsS3": {
"region": env.s3_mock_server.region(),
"bucket": bucket,
"key": key_prefix,
}
},
},
},
)
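# The import runs asynchronously: timeline_detail returns 404/429 until it completes, which the polling loop below treats as "in progress".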
env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id)

# Run fast_import
if fast_import.extra_env is None:
fast_import.extra_env = {}
fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key()
fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key()
fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token()
fast_import.extra_env["AWS_REGION"] = mock_s3_server.region()
fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint()
fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug"
pg_port = port_distributor.get_port()
fast_import.run(pg_port=pg_port, s3prefix=f"s3://{bucket}/{key_prefix}")
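# fast_import restores the source database into workdir/pgdata; with s3prefix set it should also upload the result to S3 for the pageserver to ingest.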
vanilla_pg.stop()

def validate_vanilla_equivalence(ep):
res = ep.safe_psql("SELECT count(*), sum(a) FROM foo;", dbname="neondb")
assert res[0] == (10, 55), f"got result: {res}"

# Sanity-check that the data in pgdata is as expected:
pgbin = PgBin(test_output_dir, fast_import.pg_distrib_dir, fast_import.pg_version)
with VanillaPostgres(
fast_import.workdir / "pgdata", pgbin, pg_port, False
) as new_pgdata_vanilla_pg:
new_pgdata_vanilla_pg.start()

# The database name and user are hardcoded in the fast_import binary and differ from the vanilla postgres defaults
conn = PgProtocol(dsn=f"postgresql://cloud_admin@localhost:{pg_port}/neondb")
validate_vanilla_equivalence(conn)

# Poll shard states on the pageservers and cross-check the per-shard status files in S3
while True:
locations = env.storage_controller.locate(tenant_id)
active_count = 0
for location in locations:
shard_id = TenantShardId.parse(location["shard_id"])
ps = env.get_pageserver(location["node_id"])
try:
detail = ps.http_client().timeline_detail(shard_id, timeline_id)
log.info(f"timeline {tenant_id}/{timeline_id} detail: {detail}")
state = detail["state"]
log.info(f"shard {shard_id} state: {state}")
if state == "Active":
active_count += 1
except PageserverApiException as e:
if e.status_code == 404:
log.info("not found, import is in progress")
continue
elif e.status_code == 429:
log.info("import is in progress")
continue
else:
raise

if state == "Active":
key = f"{key_prefix}/status/shard-{shard_id.shard_index}"
shard_status_file_contents = (
mock_s3_client.get_object(Bucket=bucket, Key=key)["Body"].read().decode("utf-8")
)
shard_status = json.loads(shard_status_file_contents)
assert shard_status["done"] is True

if active_count == len(locations):
log.info("all shards are active")
break
time.sleep(0.5)

import_duration = time.monotonic() - start
log.info(f"import complete; duration={import_duration:.2f}s")

ep = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id)

# check that data is there
validate_vanilla_equivalence(ep)

# check that we can do basic ops
ep.safe_psql("create table othertable(values text)", dbname="neondb")
rw_lsn = Lsn(ep.safe_psql_scalar("select pg_current_wal_flush_lsn()"))
ep.stop()

# ... at the tip
_ = env.create_branch(
new_branch_name="br-tip",
ancestor_branch_name=import_branch_name,
tenant_id=tenant_id,
ancestor_start_lsn=rw_lsn,
)
br_tip_endpoint = env.endpoints.create_start(
branch_name="br-tip", endpoint_id="br-tip-ro", tenant_id=tenant_id
)
validate_vanilla_equivalence(br_tip_endpoint)
br_tip_endpoint.safe_psql("select * from othertable", dbname="neondb")
br_tip_endpoint.stop()

# ... at the initdb lsn
locations = env.storage_controller.locate(tenant_id)
[shard_zero] = [
loc for loc in locations if TenantShardId.parse(loc["shard_id"]).shard_number == 0
]
shard_zero_ps = env.get_pageserver(shard_zero["node_id"])
shard_zero_timeline_info = shard_zero_ps.http_client().timeline_detail(
shard_zero["shard_id"], timeline_id
)
initdb_lsn = Lsn(shard_zero_timeline_info["initdb_lsn"])
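# For an imported timeline, initdb_lsn marks the state of the ingested pgdata snapshot, before any user writes such as othertable.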
_ = env.create_branch(
new_branch_name="br-initdb",
ancestor_branch_name=import_branch_name,
tenant_id=tenant_id,
ancestor_start_lsn=initdb_lsn,
)
br_initdb_endpoint = env.endpoints.create_start(
branch_name="br-initdb", endpoint_id="br-initdb-ro", tenant_id=tenant_id
)
validate_vanilla_equivalence(br_initdb_endpoint)
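# othertable was created after the import, so it must not exist at the initdb LSN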
with pytest.raises(psycopg2.errors.UndefinedTable):
br_initdb_endpoint.safe_psql("select * from othertable", dbname="neondb")
br_initdb_endpoint.stop()

env.pageserver.stop(immediate=True)


def test_fast_import_binary(
test_output_dir,
vanilla_pg: VanillaPostgres,
