Skip to content

Commit

Permalink
Tests pass again
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Sep 3, 2024
1 parent 4b2d25d commit b6f2118
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 35 deletions.
3 changes: 3 additions & 0 deletions edb/server/_rust_native/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ def get_thread_channel():
channel = create_to_python_channel()
_channels[loop] = channel
return channel

def init_async():
pass
28 changes: 14 additions & 14 deletions edb/server/pgconnparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import platform
import ssl as ssl_module
import warnings
# from edb.server._rust_native.pgrust import parse_dsn as parse_dsn_native
from edb.server._rust_native.module._pg_rust import parse_dsn as parse_dsn_native

class SSLMode(enum.IntEnum):
disable = 0
Expand Down Expand Up @@ -100,7 +100,7 @@ def parse_dsn(
ConnectionParameters,
]:
try:
parsed, ssl_paths = parse_dsn_native(getpass.getuser(),
parsed = parse_dsn_native(getpass.getuser(),
str(get_pg_home_directory()),
dsn)
except Exception as e:
Expand All @@ -119,18 +119,18 @@ def parse_dsn(
if sslmode < SSLMode.require:
ssl.verify_mode = ssl_module.CERT_NONE
else:
if ssl_paths['rootcert']:
ssl.load_verify_locations(ssl_paths['rootcert'])
if ssl_config['rootcert']:
ssl.load_verify_locations(ssl_config['rootcert'])
ssl.verify_mode = ssl_module.CERT_REQUIRED
else:
if sslmode == SSLMode.require:
ssl.verify_mode = ssl_module.CERT_NONE
if ssl_paths['crl']:
ssl.load_verify_locations(ssl_paths['crl'])
if ssl_config['crl']:
ssl.load_verify_locations(ssl_config['crl'])
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
if ssl_paths['key'] and ssl_paths['cert']:
ssl.load_cert_chain(ssl_paths['cert'],
ssl_paths['key'],
if ssl_config['key'] and ssl_config['cert']:
ssl.load_cert_chain(ssl_config['cert'],
ssl_config['key'],
ssl_config['password'] or '')
if ssl_config['max_protocol_version']:
ssl.maximum_version = _parse_tls_version(
Expand All @@ -145,21 +145,21 @@ def parse_dsn(

# Extract hosts from the dict
addrs: List[Tuple[str, int]] = []
for host in parsed['hosts']:
for host, port in parsed['hosts']:
if 'Hostname' in host:
host, port = host['Hostname']
host = host['Hostname']
addrs.append((host, port))
elif 'IP' in host:
ip, port, scope = host['IP']
ip, scope = host['IP']
# Reconstruct the scope ID
if scope:
ip = f'{ip}%{scope}'
addrs.append((ip, port))
elif 'Path' in host:
path, port = host['Path']
path = host['Path']
addrs.append((path, port))
elif 'Abstract' in host:
path, port = host['Abstract']
path = host['Abstract']
addrs.append((path, port))

# Database/user/password/connect_timeout
Expand Down
35 changes: 24 additions & 11 deletions edb/server/pgrust/src/connection/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,7 @@ impl ToString for SslMode {
}
}
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SslVersion {
Tls1,
Tls1_1,
Expand All @@ -421,22 +420,36 @@ pub enum SslVersion {
impl ToString for SslVersion {
fn to_string(&self) -> String {
match self {
SslVersion::Tls1 => "tls_1".to_string(),
SslVersion::Tls1_1 => "tls_1.1".to_string(),
SslVersion::Tls1_2 => "tls_1.2".to_string(),
SslVersion::Tls1_3 => "tls_1.3".to_string(),
SslVersion::Tls1 => "TLSv1".to_string(),
SslVersion::Tls1_1 => "TLSv1.1".to_string(),
SslVersion::Tls1_2 => "TLSv1.2".to_string(),
SslVersion::Tls1_3 => "TLSv1.3".to_string(),
}
}
}

impl serde::Serialize for SslVersion {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(match self {
SslVersion::Tls1 => "TLSv1",
SslVersion::Tls1_1 => "TLSv1.1",
SslVersion::Tls1_2 => "TLSv1.2",
SslVersion::Tls1_3 => "TLSv1.3",
})
}
}

impl<'a> TryFrom<Cow<'a, str>> for SslVersion {
type Error = ParseError;
fn try_from(value: Cow<str>) -> Result<SslVersion, Self::Error> {
Ok(match value.as_ref() {
"tls_1" => SslVersion::Tls1,
"tls_1.1" => SslVersion::Tls1_1,
"tls_1.2" => SslVersion::Tls1_2,
"tls_1.3" => SslVersion::Tls1_3,
Ok(match value.to_lowercase().as_ref() {
"tls_1" | "tlsv1" => SslVersion::Tls1,
"tls_1.1" | "tlsv1.1" => SslVersion::Tls1_1,
"tls_1.2" | "tlsv1.2" => SslVersion::Tls1_2,
"tls_1.3" | "tlsv1.3" => SslVersion::Tls1_3,
_ => return Err(ParseError::InvalidTLSVersion(value.to_string())),
})
}
Expand Down
21 changes: 11 additions & 10 deletions tests/test_backend_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,8 @@ class TestConnectParams(tb.TestCase):
{
'name': 'dsn_combines_env_multi_host',
'env': {
'PGHOST': 'host1:1111,host2:2222',
'PGHOST': 'host1,host2',
'PGPORT': '1111,2222',
'PGUSER': 'foo',
},
'dsn': 'postgresql:///db',
Expand All @@ -574,7 +575,7 @@ class TestConnectParams(tb.TestCase):
'env': {
'PGUSER': 'foo',
},
'dsn': 'postgresql:///db?host=host1:1111,host2:2222',
'dsn': 'postgresql:///db?host=host1,host2&port=1111,2222',
'result': ([('host1', 1111), ('host2', 2222)], {
'database': 'db',
'user': 'foo',
Expand All @@ -598,11 +599,11 @@ class TestConnectParams(tb.TestCase):
'dsn': 'postgresql://me:[email protected]:888/'
'db?param=sss&param=123&host=testhost&user=testuser'
'&port=2222&database=testdb&sslmode=require',
'result': ([('127.0.0.1', 888)], {
'result': ([('testhost', 2222)], {
'server_settings': {'param': '123'},
'user': 'me',
'user': 'testuser',
'password': 'ask',
'database': 'db',
'database': 'testdb',
'ssl': True,
'sslmode': SSLMode.require})
},
Expand All @@ -613,11 +614,11 @@ class TestConnectParams(tb.TestCase):
'db?param=sss&param=123&host=testhost&user=testuser'
'&port=2222&database=testdb&sslmode=verify_full'
'&aa=bb',
'result': ([('127.0.0.1', 888)], {
'result': ([('testhost', 2222)], {
'server_settings': {'aa': 'bb', 'param': '123'},
'user': 'me',
'user': 'testuser',
'password': 'ask',
'database': 'db',
'database': 'testdb',
'sslmode': SSLMode.verify_full,
'ssl': True})
},
Expand Down Expand Up @@ -738,8 +739,8 @@ class TestConnectParams(tb.TestCase):
},
{
'name': 'dsn_only_cloudsql_unix_and_tcp',
'dsn': 'postgres:///db?host=127.0.0.1:5432,/cloudsql/'
'project:region:instance-name,localhost:5433&user=spam',
'dsn': 'postgres:///db?host=127.0.0.1,/cloudsql/'
'project:region:instance-name,localhost&port=5432,,5433&user=spam',
'result': (
[
('127.0.0.1', 5432),
Expand Down

0 comments on commit b6f2118

Please sign in to comment.