Skip to content

Commit

Permalink
test with on fast import with local (moto_server) s3 & kms
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoBjorn committed Jan 20, 2025
1 parent 8cd8164 commit 01604b3
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 4 deletions.
25 changes: 24 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Jinja2 = "^3.1.5"
types-requests = "^2.31.0.0"
types-psycopg2 = "^2.9.21.20241019"
boto3 = "^1.34.11"
boto3-stubs = {extras = ["s3"], version = "^1.26.16"}
boto3-stubs = {extras = ["s3", "kms"], version = "^1.26.16"}
moto = {extras = ["server"], version = "^5.0.6"}
backoff = "^2.2.1"
pytest-lazy-fixture = "^0.6.3"
Expand Down
27 changes: 27 additions & 0 deletions test_runner/fixtures/neon_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import asyncpg
import backoff
import boto3
import httpx
import psycopg2
import psycopg2.sql
Expand All @@ -36,6 +37,8 @@
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from jwcrypto import jwk
from mypy_boto3_kms import KMSClient
from mypy_boto3_s3 import S3Client

# Type-related stuff
from psycopg2.extensions import connection as PgConnection
Expand Down Expand Up @@ -198,6 +201,30 @@ def mock_s3_server(port_distributor: PortDistributor) -> Iterator[MockS3Server]:
mock_s3_server.kill()


@pytest.fixture(scope="session")
def mock_kms(mock_s3_server: MockS3Server) -> Iterator[KMSClient]:
yield boto3.client(
"kms",
endpoint_url=mock_s3_server.endpoint(),
region_name=mock_s3_server.region(),
aws_access_key_id=mock_s3_server.access_key(),
aws_secret_access_key=mock_s3_server.secret_key(),
aws_session_token=mock_s3_server.session_token(),
)


@pytest.fixture(scope="session")
def mock_s3_client(mock_s3_server: MockS3Server) -> Iterator[S3Client]:
yield boto3.client(
"s3",
endpoint_url=mock_s3_server.endpoint(),
region_name=mock_s3_server.region(),
aws_access_key_id=mock_s3_server.access_key(),
aws_secret_access_key=mock_s3_server.secret_key(),
aws_session_token=mock_s3_server.session_token(),
)


class PgProtocol:
"""Reusable connection logic"""

Expand Down
74 changes: 72 additions & 2 deletions test_runner/regress/test_import_pgdata.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import json
import re
import time
Expand All @@ -17,8 +18,11 @@
)
from fixtures.pg_version import PgVersion
from fixtures.port_distributor import PortDistributor
from fixtures.remote_storage import RemoteStorageKind
from fixtures.remote_storage import MockS3Server, RemoteStorageKind
from fixtures.utils import run_only_on_postgres
from mypy_boto3_kms import KMSClient
from mypy_boto3_kms.type_defs import EncryptResponseTypeDef
from mypy_boto3_s3 import S3Client
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
Expand Down Expand Up @@ -369,7 +373,73 @@ def test_fast_import_restore_to_connstring(
)
vanilla_pg.stop()

# database name and user are hardcoded in fast_import binary, and they are different from normal vanilla postgres
res = restore_vanilla_pg.safe_psql("SELECT count(*) FROM foo;")
log.info(f"Result: {res}")
assert res[0][0] == 10


def test_fast_import_restore_to_connstring_from_s3_spec(
test_output_dir,
vanilla_pg: VanillaPostgres,
port_distributor: PortDistributor,
fast_import: FastImport,
pg_distrib_dir: Path,
pg_version: PgVersion,
mock_s3_server: MockS3Server,
mock_kms: KMSClient,
mock_s3_client: S3Client,
):
# Prepare KMS and S3
key_response = mock_kms.create_key(
Description="Test key",
KeyUsage="ENCRYPT_DECRYPT",
Origin="AWS_KMS",
)
key_id = key_response["KeyMetadata"]["KeyId"]

def encrypt(x: str) -> EncryptResponseTypeDef:
return mock_kms.encrypt(KeyId=key_id, Plaintext=x)

# Start source postgres and ingest data
vanilla_pg.start()
vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")

# Start target postgres
pgdatadir = test_output_dir / "restore-pgdata"
pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version)
port = port_distributor.get_port()
with VanillaPostgres(pgdatadir, pg_bin, port) as restore_vanilla_pg:
restore_vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"])
restore_vanilla_pg.start()

# Encrypt connstrings and put spec into S3
source_connstring_encrypted = encrypt(vanilla_pg.connstr())
restore_connstring_encrypted = encrypt(restore_vanilla_pg.connstr())
spec = {
"encryption_secret": {"KMS": {"key_id": key_id}},
"source_connstring_ciphertext_base64": base64.b64encode(
source_connstring_encrypted["CiphertextBlob"]
).decode("utf-8"),
"restore_connstring_ciphertext_base64": base64.b64encode(
restore_connstring_encrypted["CiphertextBlob"]
).decode("utf-8"),
}

mock_s3_client.create_bucket(Bucket="test-bucket")
mock_s3_client.put_object(
Bucket="test-bucket", Key="test-prefix/spec.json", Body=json.dumps(spec)
)

# Run fast_import
fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key()
fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key()
fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token()
fast_import.extra_env["AWS_REGION"] = mock_s3_server.region()
fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint()
fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug"
fast_import.run(s3prefix="s3://test-bucket/test-prefix")
vanilla_pg.stop()

res = restore_vanilla_pg.safe_psql("SELECT count(*) FROM foo;")
log.info(f"Result: {res}")
assert res[0][0] == 10
Expand Down

0 comments on commit 01604b3

Please sign in to comment.