Skip to content

Commit

Permalink
FEATURE: Add sasl command
Browse files Browse the repository at this point in the history
  • Loading branch information
namsic committed Nov 13, 2024
1 parent 6a52829 commit a2ecba8
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 16 deletions.
194 changes: 194 additions & 0 deletions memcached.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

#if defined(ENABLE_SASL) || defined(ENABLE_ISASL)
#define SASL_ENABLED
#define ASCII_SASL
#endif

#define ZK_CONNECTIONS 1
Expand Down Expand Up @@ -4329,6 +4330,175 @@ static void process_bin_complete_sasl_auth(conn *c)
}
}

#ifdef ASCII_SASL
static void process_sasl_command(conn *c, token_t *tokens, size_t ntokens) {
if (strcmp(tokens[1].value, "mech") == 0) {
init_sasl_conn(c);
const char *result_string = NULL;
unsigned int string_length = 0;
int result = sasl_listmech(c->sasl_conn, NULL,
"", /* What to prepend the string with */
" ", /* What to separate mechanisms with */
"", /* What to append to the string */
&result_string, &string_length,
NULL);
if (result != SASL_OK) {
/* Perhaps there's a better error for this... */
if (settings.verbose) {
mc_logger->log(EXTENSION_LOG_INFO, c,
"%d: Failed to list SASL mechanisms.\n", c->sfd);
}
out_string(c, "SERVER_ERROR internal");
return;
}

char buf[1024];
snprintf(buf, 1024, "VALUE mech 0 %u\r\n%s\r\nEND\r\n", string_length, result_string);
char *response = strdup(buf);
write_and_free(c, response, strlen(response));
return;
}

if (strcmp(tokens[1].value, "auth") == 0) {
char *key = NULL;
uint32_t ksize = 0;
uint32_t vsize = 0;
int read_ntokens = 2;

if (ntokens < 4) {
out_string(c, "CLIENT_ERROR bad command line format");
return;
}

if (ntokens > 4) {
if (tokens[read_ntokens].length > MAX_SASL_MECH_LEN) {
out_string(c, "CLIENT_ERROR bad command line format");
return;
}
key = tokens[read_ntokens].value;
ksize = tokens[read_ntokens].length;
read_ntokens++;
}

if (! safe_strtoul(tokens[read_ntokens].value, &vsize)) {
out_string(c, "CLIENT_ERROR bad command line format");
return;
}

struct sasl_tmp *stmp = calloc(sizeof(struct sasl_tmp) + ksize + vsize + 2, 1);;
if (!stmp) {
out_string(c, "SERVER_ERROR out of memory");
}
if (ksize > 0) {
memcpy(stmp->data, key, ksize);
}
stmp->ksize = ksize;
stmp->vsize = vsize;

c->item = stmp;
c->ritem = stmp->data + stmp->ksize;
c->rlbytes = stmp->vsize + 2;
c->rltotal = 0;
conn_set_state(c, conn_nread);
} else {
out_string(c, "CLIENT_ERROR bad command line format");
return;
}
}

static void process_sasl_complete(conn *c) {
int result;
char buf[1024];
char *response_header;
const char *out;
unsigned int outlen;

assert(c->item);
init_sasl_conn(c);

struct sasl_tmp *stmp = c->item;
uint32_t nkey = stmp->ksize;
uint32_t vlen = stmp->vsize;
const char *challenge = vlen == 0 ? NULL : stmp->data + nkey;

if (nkey > 0) {
char mech[MAX_SASL_MECH_LEN];
memcpy(mech, stmp->data, nkey);
mech[nkey] = 0x00;

result = sasl_server_start(c->sasl_conn, mech, challenge, vlen,
&out, &outlen);
} else {
result = sasl_server_step(c->sasl_conn, challenge, vlen,
&out, &outlen);
}

free(c->item);
c->item = NULL;
c->ritem = NULL;

if (settings.verbose) {
mc_logger->log(EXTENSION_LOG_INFO, c,
"%d: sasl_%s: %d\n", c->sfd, nkey ? "start" : "step", result);
}

switch (result) {
case SASL_OK:
c->authenticated = true;
out_string(c, "END");
auth_data_t data;
get_auth_data(c, &data);
perform_callbacks(ON_AUTH, (const void*)&data, c);
STATS_CMD_NOKEY(c, auth);
break;
case SASL_CONTINUE:
snprintf(buf, 1024, "VALUE sasl 0 %u\r\n", outlen);
response_header = strdup(buf);
c->write_and_free = response_header;

if (add_iov(c, response_header, strlen(response_header)) != 0 ||
add_iov(c, out, outlen) != 0 ||
add_iov(c, "\r\nEND\r\n", 7) != 0) {
out_string(c, "SERVER_ERROR out of memory writing response");
}
conn_set_state(c, conn_mwrite);
c->write_and_go = conn_new_cmd;
break;
default:
if (settings.verbose) {
mc_logger->log(EXTENSION_LOG_INFO, c,
"%d: Unknown sasl response: %d\n", c->sfd, result);
}
out_string(c, "SERVER_ERROR internal");
STATS_ERRORS_NOKEY(c, auth);
}

}

static bool authenticated_ascii(conn *c, token_t *tokens, size_t ntokens)
{

if (c->authenticated) {
return true;
} else if (ntokens < 2) {
return false;
} else if (strcmp(tokens[0].value, "sasl") == 0) {
return true;
} else if (strcmp(tokens[0].value, "version") == 0) {
return true;
} else if (strcmp(tokens[0].value, "quit") == 0) {
return true;
}

if (settings.verbose > 1) {
mc_logger->log(EXTENSION_LOG_DEBUG, c,
"%d: authenticated_ascii() in cmd %s is false\n",
c->sfd, tokens[0].value);
}
return false;
}
#endif

static bool authenticated(conn *c)
{
bool rv = false;
Expand Down Expand Up @@ -7618,6 +7788,11 @@ static bool ascii_response_handler(const void *cookie,

static void complete_nread_ascii(conn *c)
{
#ifdef ASCII_SASL
if (!c->authenticated) {
process_sasl_complete(c);
} else
#endif
if (c->ascii_cmd != NULL) {
if (!c->ascii_cmd->execute(c->ascii_cmd->cookie, c, 0, NULL,
ascii_response_handler)) {
Expand Down Expand Up @@ -13106,6 +13281,13 @@ static void process_command_ascii(conn *c, char *command, int cmdlen)

ntokens = tokenize_command(command, cmdlen, tokens, MAX_TOKENS);

#ifdef ASCII_SASL
if (settings.require_sasl && !authenticated_ascii(c, tokens, ntokens)) {
out_string(c, "CLIENT_ERROR unauthenticated");
return;
}
#endif

if ((ntokens >= 3) && (strcmp(tokens[COMMAND_TOKEN].value, "get") == 0))
{
process_get_command(c, tokens, ntokens, false);
Expand Down Expand Up @@ -13249,6 +13431,12 @@ static void process_command_ascii(conn *c, char *command, int cmdlen)
{
process_shutdown_command(c, tokens, ntokens);
}
#ifdef ASCII_SASL
else if ((ntokens >= 3) && (strcmp(tokens[COMMAND_TOKEN].value, "sasl") == 0))
{
process_sasl_command(c, tokens, ntokens);
}
#endif
else /* no matching command */
{
if (settings.extensions.ascii != NULL) {
Expand Down Expand Up @@ -15077,7 +15265,9 @@ int main (int argc, char **argv)
int cache_memory_limit = 0;
int sticky_memory_limit = 0;

#ifndef ASCII_SASL
bool protocol_specified = false;
#endif
bool tcp_specified = false;
bool udp_specified = false;

Expand Down Expand Up @@ -15290,7 +15480,9 @@ int main (int argc, char **argv)
settings.backlog = atoi(optarg);
break;
case 'B':
#ifndef ASCII_SASL
protocol_specified = true;
#endif
if (strcmp(optarg, "auto") == 0) {
settings.binding_protocol = negotiating_prot;
} else if (strcmp(optarg, "binary") == 0) {
Expand Down Expand Up @@ -15415,6 +15607,7 @@ int main (int argc, char **argv)
}
}

#ifndef ASCII_SASL
if (settings.require_sasl) {
if (!protocol_specified) {
settings.binding_protocol = binary_prot;
Expand All @@ -15431,6 +15624,7 @@ int main (int argc, char **argv)
}
}
}
#endif

if (udp_specified && settings.udpport != 0 && !tcp_specified) {
settings.port = settings.udpport;
Expand Down
17 changes: 1 addition & 16 deletions t/binary-sasl.t.in
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ if (supports_sasl()) {
plan skip_all => "The binary 'saslpasswd' is missing from your system";
}
else {
plan tests => 33;
plan tests => 30;
}
} else {
plan tests => 1;
Expand All @@ -42,21 +42,6 @@ if (supports_sasl()) {
exit 0;
}

eval {
my $server = get_memcached($engine, "-S -B auto");
};
ok($@, "SASL shouldn't be used with protocol auto negotiate");

eval {
my $server = get_memcached($engine, "-S -B ascii");
};
ok($@, "SASL isn't implemented in the ascii protocol");

eval {
my $server = get_memcached($engine, "-S -B binary -B ascii");
};
ok($@, "SASL isn't implemented in the ascii protocol");

# Based almost 100% off testClient.py which is:
# Copyright (c) 2007 Dustin Sallings <[email protected]>

Expand Down

0 comments on commit a2ecba8

Please sign in to comment.