Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into tx-config
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed Feb 4, 2025
2 parents 956e89c + 79d6c86 commit bda5cb3
Show file tree
Hide file tree
Showing 59 changed files with 3,057 additions and 1,814 deletions.
398 changes: 190 additions & 208 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion edb/buildmeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
# The merge conflict there is a nice reminder that you probably need
# to write a patch in edb/pgsql/patches.py, and then you should preserve
# the old value.
EDGEDB_CATALOG_VERSION = 2024_01_29_00_00
EDGEDB_CATALOG_VERSION = 2024_02_04_00_00
EDGEDB_MAJOR_VERSION = 7


Expand Down
12 changes: 12 additions & 0 deletions edb/common/ast/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def generic_visit(self, node):
changes = {}

for field, old_value in base.iter_fields(node, include_meta=False):
field_spec = node._fields[field]
if self.skip_hidden and field_spec.hidden:
continue
if field in self.extra_skips:
continue

old_value = getattr(node, field, None)

if typeutils.is_container(old_value):
Expand All @@ -79,6 +85,12 @@ def generic_visit(self, node):

else:
for field, old_value in base.iter_fields(node, include_meta=False):
field_spec = node._fields[field]
if self.skip_hidden and field_spec.hidden:
continue
if field in self.extra_skips:
continue

old_value = getattr(node, field, None)

if typeutils.is_container(old_value):
Expand Down
4 changes: 4 additions & 0 deletions edb/edgeql/compiler/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,10 @@ def compile_FunctionCall(

class ArgumentInliner(ast.NodeTransformer):

# Don't look through hidden nodes, they may contain references to nodes
# which should not be modified. For example, irast.Stmt.parent_stmt.
skip_hidden = True

mapped_args: dict[irast.PathId, irast.PathId]
inlined_arg_keys: list[int | str]

Expand Down
6 changes: 3 additions & 3 deletions edb/lib/std/20-genericfuncs.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ std::assert_single(
CREATE ANNOTATION std::description :=
"Check that the input set contains at most one element, raise
CardinalityViolationError otherwise.";
SET volatility := 'Stable';
SET volatility := 'Immutable';
SET preserves_optionality := true;
USING SQL EXPRESSION;
};
Expand All @@ -49,7 +49,7 @@ std::assert_exists(
CREATE ANNOTATION std::description :=
"Check that the input set contains at least one element, raise
CardinalityViolationError otherwise.";
SET volatility := 'Stable';
SET volatility := 'Immutable';
SET preserves_upper_cardinality := true;
USING SQL EXPRESSION;
};
Expand All @@ -67,7 +67,7 @@ std::assert_distinct(
CREATE ANNOTATION std::description :=
"Check that the input set is a proper set, i.e. all elements
are unique";
SET volatility := 'Stable';
SET volatility := 'Immutable';
SET preserves_optionality := true;
SET preserves_upper_cardinality := true;
USING SQL EXPRESSION;
Expand Down
20 changes: 20 additions & 0 deletions edb/pgsql/metaschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8209,6 +8209,26 @@ async def generate_support_functions(
return trampoline_functions(cmds)


async def regenerate_config_support_functions(
conn: PGConnection,
config_spec: edbconfig.Spec,
) -> None:
# Regenerate functions dependent on config spec.
commands = dbops.CommandGroup()

funcs = [
ApplySessionConfigFunction(config_spec),
PostgresJsonConfigValueToFrontendConfigValueFunction(config_spec),
]

cmds = [dbops.CreateFunction(func, or_replace=True) for func in funcs]
commands.add_commands(cmds)

block = dbops.PLTopBlock()
commands.generate(block)
await _execute_block(conn, block)


async def generate_more_support_functions(
conn: PGConnection,
compiler: edbcompiler.Compiler,
Expand Down
1 change: 1 addition & 0 deletions edb/pgsql/resolver/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def eval_TypeCast(
'pg_last_wal_replay_lsn',
'pg_current_wal_flush_lsn',
'pg_relation_is_publishable',
'pg_show_all_settings',
}
)

Expand Down
7 changes: 5 additions & 2 deletions edb/schema/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1940,13 +1940,16 @@ def _create_begin(

if check_default_type:
default_type = ir_default.stype
if not default_type.assignment_castable_to(p_type, schema):
if not default_type.assignment_castable_to(
p_type, ir_default.schema
):
raise errors.InvalidFunctionDefinitionError(
f'cannot create the `{signature}` function: '
f'invalid declaration of parameter '
f'{p.get_displayname(schema)!r}: '
f'unexpected type of the default expression: '
f'{default_type.get_displayname(schema)}, expected '
f'{default_type.get_displayname(ir_default.schema)}, '
f'expected '
f'{p_type.get_displayname(schema)}',
span=self.span)

Expand Down
2 changes: 1 addition & 1 deletion edb/schema/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def canonicalize_attributes(
span = self.get_attribute_span(field)
raise errors.SchemaDefinitionError(
f'{vname} expression for {pol_name} is of invalid type: '
f'{expr_type.get_displayname(schema)}, '
f'{expr_type.get_displayname(expression.irast.schema)}, '
f'expected {target.get_displayname(schema)}',
span=self.span,
)
Expand Down
4 changes: 2 additions & 2 deletions edb/schema/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def canonicalize_attributes(
raise errors.SchemaDefinitionError(
f'{vname} expression for {trig_name} is of invalid '
f'type: '
f'{expr_type.get_displayname(schema)}, '
f'expected {target.get_displayname(schema)}',
f'{expr_type.get_displayname(expression.irast.schema)}'
f', expected {target.get_displayname(schema)}',
span=span,
)

Expand Down
4 changes: 3 additions & 1 deletion edb/server/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,8 +1781,10 @@ async def _init_stdlib(
await conn.sql_execute(testmode_sql.encode("utf-8"))
trampolines.extend(new_trampolines)
# _testmode includes extra config settings, so make sure
# those are picked up.
# those are picked up...
config_spec = config.load_spec_from_schema(stdlib.stdschema)
# ...and that config functions dependent on it are regenerated
await metaschema.regenerate_config_support_functions(conn, config_spec)

logger.info('Finalizing database setup...')

Expand Down
21 changes: 4 additions & 17 deletions edb/testbase/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3002,23 +3002,10 @@ def get_cases_by_shard(cases, selected_shard, total_shards, verbosity, stats):
return _merge_results(cases)


def find_available_port(max_value=None) -> int:
if max_value is None:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("localhost", 0))
return sock.getsockname()[1]
elif max_value > 1024:
port = max_value
while port > 1024:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("localhost", port))
return port
except IOError:
port -= 1
raise RuntimeError("cannot find an available port")
else:
raise ValueError("max_value must be greater than 1024")
def find_available_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("localhost", 0))
return sock.getsockname()[1]


def _needs_factoring(weakly):
Expand Down
5 changes: 4 additions & 1 deletion rust/auth/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,21 @@ pub enum AuthType {
ScramSha256,
}

#[derive(Debug, Clone)]
#[derive(derive_more::Debug, Clone)]
pub enum CredentialData {
/// A credential that always succeeds, regardless of input password. Due to
/// the design of SCRAM-SHA-256, this cannot be used with that auth type.
Trust,
/// A credential that always fails, regardless of the input password.
Deny,
/// A plain-text password.
#[debug("Plain(...)")]
Plain(String),
/// A stored MD5 hash + salt.
#[debug("Md5(...)")]
Md5(md5::StoredHash),
/// A stored SCRAM-SHA-256 key.
#[debug("Scram(...)")]
Scram(scram::StoredKey),
}

Expand Down
34 changes: 23 additions & 11 deletions rust/gel-stream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ description = "A library for streaming data between clients and servers."
[features]
# rustls or openssl imply tokio, and tokio is the only stream we support
# at this time.
default = ["tokio", "rustls"]
default = ["tokio"]
client = []
server = []
tokio = ["dep:tokio"]
rustls = ["tokio", "dep:rustls", "dep:rustls-tokio-stream", "dep:rustls-platform-verifier", "dep:webpki"]
openssl = ["tokio", "dep:openssl", "dep:tokio-openssl", "dep:foreign-types", "dep:openssl-sys"]
Expand All @@ -17,29 +19,39 @@ __manual_tests = []
[dependencies]
derive_more = { version = "1", features = ["full"] }
thiserror = "2"
rustls-pki-types = "1"
futures = "0.3"

tokio = { version = "1", optional = true, features = ["full"] }
# Given that this library may be used in multiple contexts, we want to limit the
# features we enable by default.

rustls-pki-types = { version = "1", default-features = false, features = ["std"] }

# feature = "tokio"
tokio = { version = "1", optional = true, default-features = false, features = ["net", "rt"] }
hickory-resolver = { version = "0.24.2", optional = true, default-features = false, features = ["tokio-runtime", "system-config"] }

# feature = "rustls"
rustls = { version = "0.23", optional = true, default-features = false, features = ["ring", "logging", "std", "tls12"] }
openssl = { version = "0.10.55", optional = true }
tokio-openssl = { version = "0.6.5", optional = true }
hickory-resolver = { version = "0.24.2", optional = true }
rustls-tokio-stream = { version = "0.3.0", optional = true }
rustls-tokio-stream = { version = "0.5.0", optional = true }
rustls-platform-verifier = { version = "0.5.0", optional = true }
webpki = { version = "0.22", optional = true }

# feature = "openssl"
openssl = { version = "0.10.55", optional = true, default-features = false }
tokio-openssl = { version = "0.6.5", optional = true, default-features = false }
# Get these from openssl
foreign-types = { version = "*", optional = true }
openssl-sys = { version = "*", optional = true }
foreign-types = { version = "*", optional = true, default-features = false }
openssl-sys = { version = "*", optional = true, default-features = false }

[dev-dependencies]
# Run tests with all features enabled
gel-stream = { workspace = true, features = ["client", "server", "tokio", "rustls", "openssl"] }

tokio = { version = "1", features = ["full"] }
tempfile = "3"
ntest = "0.9.3"
rustls-pemfile = "2"

rstest = "0.24.0"
rustls-tokio-stream = "0.3.0"

[lints]
workspace = true
Expand Down
31 changes: 21 additions & 10 deletions rust/gel-stream/src/client/connection.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,37 @@
use std::marker::PhantomData;
use std::net::SocketAddr;

use super::stream::UpgradableStream;
use super::target::{MaybeResolvedTarget, ResolvedTarget};
use super::tokio_stream::Resolver;
use super::{ConnectionError, Ssl, Target, TlsInit};
use crate::common::tokio_stream::{Resolver, TokioStream};
use crate::{ConnectionError, Ssl, StreamUpgrade, TlsDriver, UpgradableStream};
use crate::{MaybeResolvedTarget, ResolvedTarget, Target};

type Connection = UpgradableStream<super::Stream, Option<super::Ssl>>;
type Connection<S, D> = UpgradableStream<S, D>;

/// A connector can be used to connect multiple times to the same target.
pub struct Connector {
#[allow(private_bounds)]
pub struct Connector<D: TlsDriver = Ssl> {
target: Target,
resolver: Resolver,
driver: PhantomData<D>,
}

impl Connector {
impl Connector<Ssl> {
pub fn new(target: Target) -> Result<Self, std::io::Error> {
Self::new_explicit(target)
}
}

#[allow(private_bounds)]
impl<D: TlsDriver> Connector<D> {
pub fn new_explicit(target: Target) -> Result<Self, std::io::Error> {
Ok(Self {
target,
resolver: Resolver::new()?,
driver: PhantomData,
})
}

pub async fn connect(&self) -> Result<Connection, ConnectionError> {
pub async fn connect(&self) -> Result<Connection<TokioStream, D>, ConnectionError> {
let stream = match self.target.maybe_resolved() {
MaybeResolvedTarget::Resolved(target) => target.connect().await?,
MaybeResolvedTarget::Unresolved(host, port, _) => {
Expand All @@ -36,13 +46,14 @@ impl Connector {
};

if let Some(ssl) = self.target.maybe_ssl() {
let mut stm = UpgradableStream::new(stream, Some(Ssl::init(ssl, self.target.name())?));
let ssl = D::init_client(ssl, self.target.name())?;
let mut stm = UpgradableStream::new_client(stream, Some(ssl));
if !self.target.is_starttls() {
stm.secure_upgrade().await?;
}
Ok(stm)
} else {
Ok(UpgradableStream::new(stream, None))
Ok(UpgradableStream::new_client(stream, None))
}
}
}
Loading

0 comments on commit bda5cb3

Please sign in to comment.