Replication - Part 2

That perhaps the real me died a long time ago and I'm a replicant made with a cyborg body and computer brain.

Ghost in the Shell (1995)

At the end of the last chapter we left the system with a working handshake, but there is still some work to do to achieve replication. In this chapter we will make the server propagate write commands to all connected replicas, so that clients can correctly read data from them as well.

Midway through the chapter we take a short detour to refactor the connection functions. They currently assume a concrete TcpStream, which makes them difficult to test. We turn them into generic functions bounded by the right async traits, and we fix a handful of edge cases that the existing implementation doesn't cover.

In the second half of the chapter we will learn how the Redis master checks that its replicas are in sync, and implement that mechanism ourselves. The final feature is WAIT, the command a client uses to block until its writes have reached a given number of replicas. To implement it, we will once again use the powerful actor model and learn how to manage timeouts in async tasks.

Step 8.1 - Single-replica propagation#

The first step to achieve replication is to make sure the master propagates "write" commands to its replicas. This might look a bit daunting, but let's have a look at the state of the system:

  • Master and replicas contain the same logic, which means the latter can already receive and process Redis commands out of the box.
  • Replicas open a connection with the master, which they use to perform the handshake.

Given these two elements, propagating commands is truly just a matter of sending to replicas the same binary RESP message received by the master. As we already split the logic of each command into separate files, we just need to choose which ones will be sent to replicas and add the relevant code there.

To send commands to connected replicas, however, the master has to know how to reach each one of them. The first change to the server is therefore a list of objects representing the open channels to the replicas. When a replica sends a message to the master during the handshake, each request carries a sender field, which is exactly what we need.

Let's add the list to the server

src/server.rs
use crate::commands::{echo, get, info, ping, psync, replconf, set};
use crate::connection::{stream_send_receive_resp, ConnectionMessage};
use crate::replication::ReplicationConfig;
use crate::request::Request;
use crate::server_result::{ServerError, ServerMessage, ServerResult, ServerValue};
use crate::storage::Storage;
use crate::RESP;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::{io::AsyncWriteExt, net::TcpStream};

...

pub struct Server {
    pub info: ServerInfo,
    pub storage: Option<Storage>,
    pub replication: ReplicationConfig,
    pub replica_senders: Vec<mpsc::Sender<ServerMessage>>,
}

and since the handshake terminates with a PSYNC, we can add the replica to the pool at the end of that command

src/commands/psync.rs
pub async fn command(server: &mut Server, request: &Request, _command: &Vec<String>) {
    // Reset the master replica offset.
    server.replication.repl_offset = 0;

    // Create the FULLRESYNC message.
    let resp = ServerValue::RESP(RESP::SimpleString(format!(
        "FULLRESYNC {} {}",
        server.replication.replid.clone(),
        server.replication.repl_offset.to_string()
    )));

    request.data(resp).await;

    // Generate the RDB data for the server.
    let rdb = server.generate_rdb();

    // Calculate the length of the RDB data.
    let rdb_len = RESP::RDBPrefix(rdb.len());

    // Send the RDB length.
    request.data(ServerValue::RESP(rdb_len)).await;

    // Send the RDB data.
    request.data(ServerValue::Binary(rdb)).await;

    // Add the replica to the pool.
    server.replica_senders.push(request.sender.clone());
}

Now the server can terminate any write command with a loop over replica_senders to propagate the command itself.

Before we implement that, however, we need a way to capture the binary form of an incoming client message. The master could in theory transform decoded RESP commands back into their binary form, but storing the original binary payload looks like the simplest way to propagate incoming commands exactly as they are, reducing the risk of bugs in the decoding/encoding process.

First, let's add the original binary payload to the request

src/request.rs
#[derive(Debug)]
pub struct Request {
    pub value: RESP,
    pub sender: mpsc::Sender<ServerMessage>,
    pub binary: Vec<u8>,
}

The request is created in the function handle_connection, so we need to update the struct initialisation there to include the new field.

src/connection.rs
pub async fn handle_connection(

    loop {

                        // Process the bytes in the buffer according to
                        // the content and extract the request. Update the index.
                        let resp = match bytes_to_resp(&buffer[..size].to_vec(), &mut index) {
                            Ok(v) => v,
                            Err(e) => {
                                eprintln!("Error: {}", e);
                                return;
                            }
                        };

                        // Create the request.
                        let request = Request {
                            value: resp,
                            sender: connection_sender.clone(),
                            binary: buffer[..size].to_vec(),
                        };

                        // Send the request to the server.
                        match server_sender.send(ConnectionMessage::Request(request)).await {
                            Ok(()) => {},
                            Err(e) => {
                                eprintln!("Error sending request: {}", e);
                                return;
                            }
                        }

For now, we can store all the bytes we received using buffer[..size], since the client sends one command at a time. Later, the challenge will introduce multiple commands sent together and we will need to make some adjustments.

This change to the structure Request makes it necessary to update all the tests that create a Request.

src/commands/echo.rs
mod tests {

    async fn test_command() {

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&server, &request, &cmd).await;
src/commands/get.rs
mod tests {

    async fn test_command() {

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&mut server, &request, &cmd).await;


    async fn test_storage_not_initialised() {

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&mut server, &request, &cmd).await;


    async fn test_wrong_syntax_missing_key() {

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&mut server, &request, &cmd).await;
src/commands/info.rs
mod tests {

    async fn test_command_master() {

        let (request_channel_tx, mut request_channel_rx) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: request_channel_tx.clone(),
            binary: Vec::new(),
        };

        command(&mut server, &request, &cmd).await;


    async fn test_command_replica() {

        let (request_channel_tx, mut request_channel_rx) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: request_channel_tx.clone(),
            binary: Vec::new(),
        };

        command(&mut server, &request, &cmd).await;
src/commands/ping.rs
mod tests {

    async fn test_command_ping() {

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&server, &request, &cmd).await;


    async fn test_command_ping_uppercase() {

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&server, &request, &cmd).await;
src/commands/psync.rs
mod tests {

    async fn test_command() {

        let (request_channel_tx, mut request_channel_rx) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: request_channel_tx.clone(),
            binary: Vec::new(),
        };

        server.replication.replid = String::from("some_repl_id");
src/commands/replconf.rs
mod tests {

    async fn test_command_replconf_listening_port() {

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&server, &request, &cmd).await;


    async fn test_command_replconf_capa() {
    
        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&server, &request, &cmd).await;
src/commands/set.rs
mod tests {

    async fn test_command() {

        let (request_channel_tx, mut request_channel_rx) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: request_channel_tx.clone(),
            binary: Vec::new(),
        };

        command(&mut server, &request, &cmd).await;


    async fn test_storage_not_initialised() {

        let (request_channel_tx, mut request_channel_rx) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: request_channel_tx.clone(),
            binary: Vec::new(),
        };

        command(&mut server, &request, &cmd).await;


    async fn test_wrong_syntax_missing_key() {

        let (request_channel_tx, mut request_channel_rx) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: request_channel_tx.clone(),
            binary: Vec::new(),
        };

        command(&mut server, &request, &cmd).await;
src/server.rs
mod tests {

    async fn test_process_request_ping() {
    
        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Array(vec![RESP::BulkString(String::from("PING"))]),
            sender: connection_sender,
            binary: Vec::new(),
        };

        let storage = Storage::new();


    async fn test_process_request_echo() {
    
        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Array(vec![
                RESP::BulkString(String::from("ECHO")),
                RESP::BulkString(String::from("42")),
            ]),
            sender: connection_sender,
            binary: Vec::new(),
        };

        let storage = Storage::new();


    async fn test_process_request_not_array() {
    
        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::BulkString(String::from("PING")),
            sender: connection_sender,
            binary: Vec::new(),
        };

        let storage = Storage::new();


    async fn test_process_request_not_bulkstrings() {
    
        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Array(vec![RESP::SimpleString(String::from("PING"))]),
            sender: connection_sender,
            binary: Vec::new(),
        };

        let storage = Storage::new();

We now have everything we need to propagate write commands to replicas. As SET is the only "write" command we have, we can implement the propagation directly there.

src/commands/set.rs
use crate::request::Request;
use crate::resp::RESP;
use crate::server::Server;
use crate::server_result::{ServerError, ServerMessage, ServerValue};
use crate::set::parse_set_arguments;

pub async fn command(server: &mut Server, request: &Request, command: &Vec<String>) {

    // Set the value of the key in the storage.
    if let Err(_) = storage.set(key, value, args) {
        request
            .error(ServerError::CommandInternalError(command.join(" ")))
            .await;
        return;
    };

    request
        .data(ServerValue::RESP(RESP::SimpleString(String::from("OK"))))
        .await;

    // Forward the given request to all replicas known to the server.
    for replica_sender in server.replica_senders.iter() {
        let _ = replica_sender
            .send(ServerMessage::Data(ServerValue::Binary(
                request.binary.clone(),
            )))
            .await;
    }

As the challenge doesn't require any error management for this operation, we can safely ignore the output of the propagation.

CodeCrafters

Replication Stage 11: Single-replica propagation

The code we wrote in this section passes Replication - Stage 11 of the CodeCrafters challenge.

Step 8.2 - Multi-replica command propagation#

The code we wrote in the previous section passes the next stage of the challenge without modifications. The reason is that we implemented a solution that naturally handles multiple replicas through a vector of queues and a loop.

While the code is already passing the test, it is worth trying to improve it. At the moment, our code implements a single "write" command, SET, but we are likely to add more in the future (e.g. DEL, which is mentioned in the challenge instructions for the previous step).

This means that it's worth isolating the code that propagates incoming commands to other replicas, moving it from the implementation of SET to a higher level of abstraction, which is the function process_request.

src/commands/set.rs
use crate::request::Request;
use crate::resp::RESP;
use crate::server::Server;
use crate::server_result::{ServerError, ServerValue};
use crate::set::parse_set_arguments;

pub async fn command(server: &mut Server, request: &Request, command: &Vec<String>) {

    // Set the value of the key in the storage.
    if let Err(_) = storage.set(key, value, args) {
        request
            .error(ServerError::CommandInternalError(command.join(" ")))
            .await;
        return;
    };

    request
        .data(ServerValue::RESP(RESP::SimpleString(String::from("OK"))))
        .await;

src/server.rs
// Forward the given request to all replicas known to the server.
pub async fn send_request_to_replicas(request: &Request, server: &Server) {
    for replica_sender in server.replica_senders.iter() {
        let _ = replica_sender
            .send(ServerMessage::Data(ServerValue::Binary(
                request.binary.clone(),
            )))
            .await;
    }
}

pub async fn process_request(request: Request, server: &mut Server) {

    match command_name.as_str() {

        "replconf" => {
            replconf::command(server, &request, &command).await;
        }
        "set" => {
            set::command(server, &request, &command).await;

            // Forward the request to all replicas.
            send_request_to_replicas(&request, server).await;
        }
        _ => {
            request
                .error(ServerError::CommandNotAvailable(command[0].clone()))
                .await;
        }

CodeCrafters

Replication Stage 12: Multi Replica Command Propagation

The code we wrote in this section passes Replication - Stage 12 of the CodeCrafters challenge.

Step 8.3 - Read the RDB file#

In the next stage, tests will ensure that a replica applies the commands propagated by the master. Before we do that, however, we need to add to the replica some code to read the RDB file sent by the master at the end of the handshake process. I recommend activating the test for the step "Command Processing", whose output will help you understand what we are about to do.

We don't need to implement any processing of the content of the RDB file, but we need to read it, otherwise we won't be able to process the commands that the master will send after it.

As we saw in the previous chapter, the server sends a FULLRESYNC followed by the RDB. We are going to receive something like

+FULLRESYNC <repl_id> 0\r\n$<length>\r\n<contents>

which is clearly not a standard RESP message. This means that we need to use some of the lower-level functions we developed for RESP in order to process it.

It is worth having a look at the logs of the new test for a practical example of what the master is going to send. The log lines below contain the bytes we are going to receive

Test logs
[stage-113] [handshake] master: Sent "FULLRESYNC 75cd7bc10c49047e0d163660f3b90625b1af31dc 0"
[stage-113] [handshake] master: Sent bytes: "+FULLRESYNC 75cd7bc10c49047e0d163660f3b90625b1af31dc 0\r\n"
[stage-113] [handshake] Sending RDB file...
[stage-113] [handshake] master: Sent bytes:
"$88\r\nREDIS0011\xfa\tredis-ver\x057.2.0\xfa\nredis-bits\xc0
 \xfa\x05ctime\xc2m\b\xbce\xfa\bused-mem°\xc4\x10\x00\xfa\baof
 -base\xc0\x00\xff\xf0n;\xfe\xc0\xffZ\xa2"
[stage-113] [handshake] Sent RDB file.
[your_program] Error: Server error: Data received from stream is incorrect.

The error Data received from stream is incorrect. comes from process_request. Our code believes that the handshake is over and is treating the new bytes as a Redis command.

The task can be split into the following steps:

  1. Read the FULLRESYNC message.
  2. Read the RDB length and the separator \r\n.
  3. Read the correct amount of bytes (RDB length).
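To make the `$<length>\r\n` prefix of step 2 concrete, here is a minimal standalone sketch; parse_rdb_len_prefix is a hypothetical helper built only on the standard library, not one of the resp functions the server already has:

```rust
// Hypothetical helper: given the bytes that follow the FULLRESYNC line,
// extract the RDB length from the `$<length>\r\n` prefix and return it
// together with the index where the binary payload starts.
fn parse_rdb_len_prefix(buf: &[u8]) -> Option<(usize, usize)> {
    // The prefix must start with the dollar sign.
    if buf.first() != Some(&b'$') {
        return None;
    }
    // Find the `\r\n` separator that terminates the length line.
    let end = buf.windows(2).position(|w| w == b"\r\n")?;
    // Parse the decimal digits between `$` and `\r\n`.
    let len = std::str::from_utf8(&buf[1..end]).ok()?.parse().ok()?;
    // The payload starts right after the separator.
    Some((len, end + 2))
}

fn main() {
    // `$88\r\n` declares an 88-byte RDB whose payload starts at offset 5.
    assert_eq!(parse_rdb_len_prefix(b"$88\r\nREDIS0011"), Some((88, 5)));
}
```

In the actual implementation below we will reuse the server's existing resp helpers instead, but the wire format being parsed is the same.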

Step 8.3.1 - Read the FULLRESYNC message#

The main problem we have here is that we cannot use the high-level functions we developed earlier, such as stream_read_resp. The master will send FULLRESYNC immediately followed by the RDB, and since the latter is invalid RESP the whole message becomes invalid.

This means that we must read one "line" at a time, processing incoming bytes until we find the separator \r\n. This is the function that performs that task

src/connection.rs
// Read bytes from the stream until the \r\n is reached (including those bytes).
pub async fn stream_read_line(
    stream: &mut TcpStream,
    buffer: &mut [u8],
) -> ConnectionResult<usize> {
    let mut byte = [0; 1];
    let mut index = 0;

    // Keep reading every byte until we issue a break.
    loop {
        // Read one byte.
        stream
            .read_exact(&mut byte)
            .await
            .map_err(|e| ConnectionError::CannotReadFromStream(e.to_string()))?;

        // Store the byte in the buffer.
        buffer[index] = byte[0];

        // Check if we reached the end of the line.
        if buffer[index] == b'\n' && buffer[index - 1] == b'\r' {
            break;
        }

        index += 1;
    }

    Ok(index)
}

Using it, we can now add some code to the function handshake that parses the FULLRESYNC part of the message

src/server.rs
use crate::commands::{echo, get, info, ping, psync, replconf, set};
use crate::connection::{stream_read_line, stream_send_receive_resp, ConnectionMessage};
use crate::replication::ReplicationConfig;
use crate::request::Request;
use crate::resp::bytes_to_resp;
use crate::server_result::{ServerError, ServerMessage, ServerResult, ServerValue};
use crate::storage::Storage;
use crate::RESP;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::{io::AsyncWriteExt, net::TcpStream};

...

pub async fn handshake(stream: &mut TcpStream, info: &ServerInfo) -> ServerResult {

    // Send the PSYNC command.
    stream
        .write_all(psync.to_string().as_bytes())
        .await
        .map_err(|e| {
            ServerError::HandshakeFailed(format!(
                "Sending {} - Cannot write to stream: {}",
                replconf.to_string(),
                e.to_string()
            ))
        })?;

    /////////////////////////////////////
    // The response will be
    // * The FULLRESYNC: `+FULLRESYNC ID 0\r\n`
    // * The RDB: `$LEN\r\nBINARY`

    let mut index: usize = 0;

    // Read the FULLRESYNC line.
    if let Some(_) = stream_read_line(stream, &mut buffer).await.err() {
        return Err(ServerError::HandshakeFailed(String::from(
            "PSYNC failed, no FULLRESYNC",
        )));
    }

    // Convert bytes into RESP.
    let response = bytes_to_resp(&buffer, &mut index).ok();
    match response {
        Some(RESP::SimpleString(_)) => {}
        _ => {
            return Err(ServerError::HandshakeFailed(String::from(
                "PSYNC failed, wrong FULLRESYNC response",
            )));
        }
    };

    Ok(ServerValue::None)

At the end of this step the test logs will contain the following output

Test logs
[stage-113] [handshake] master: Sent "FULLRESYNC 75cd7bc10c49047e0d163660f3b90625b1af31dc 0"
[stage-113] [handshake] master: Sent bytes: "+FULLRESYNC 75cd7bc10c49047e0d163660f3b90625b1af31dc 0\r\n"
[stage-113] [handshake] Sending RDB file...
[stage-113] [handshake] master: Sent bytes:
"$88\r\nREDIS0011\xfa\tredis-ver\x057.2.0\xfa\nredis-bits\xc0
 \xfa\x05ctime\xc2m\b\xbce\xfa\bused-mem°\xc4\x10\x00\xfa\baof
 -base\xc0\x00\xff\xf0n;\xfe\xc0\xffZ\xa2"
[stage-113] [handshake] Sent RDB file.
[your_program] Error: Cannot convert from UTF-8

Now the error is Cannot convert from UTF-8, which comes from parse_bulk_string. Our code happily digested the command FULLRESYNC and moved on, but since the RDB starts with a $ it is mistakenly treated as a bulk string. The RDB, however, doesn't contain only text: it carries binary data that is definitely not valid UTF-8.
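A quick standalone check confirms the diagnosis: the RDB header mixes ASCII with raw opcode bytes such as 0xFA, and any attempt to interpret those bytes as UTF-8 fails.

```rust
fn main() {
    // First bytes of the RDB payload from the test logs: the "REDIS0011"
    // magic followed by 0xFA, a byte that is never valid in UTF-8.
    let mut rdb_start = b"REDIS0011".to_vec();
    rdb_start.push(0xFA);

    // The conversion fails, just like the one inside parse_bulk_string.
    assert!(String::from_utf8(rdb_start).is_err());
}
```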

Untested functions

The new function stream_read_line is pretty fragile, and there is at least one evident issue in the logic: if the stream sends \n as the first byte, the check on buffer[index - 1] underflows the index and the code would read outside the buffer to find \r. We should add tests, but as we said in a past chapter, adding mocks to Rust is, while not impossible, not as seamless as in other languages.

Fortunately, here we can introduce the concept of generic functions with trait bounds to make the code testable. This change requires some work, however, so it will be done later. For now, we will keep the function stream_read_line as it is.
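As a taste of that refactor, here is a minimal synchronous sketch of the idea: the same read-until-\r\n logic written against a trait bound instead of a concrete TcpStream, so an in-memory Cursor can stand in for the network in tests. It uses std::io::Read for brevity; the async version would bound on tokio's AsyncReadExt + Unpin instead. The name and the index > 0 guard are illustrative, not the final implementation.

```rust
use std::io::{Cursor, Read};

// Generic over any Read implementor, so tests don't need a real socket.
fn read_line_generic<S: Read>(stream: &mut S, buffer: &mut [u8]) -> std::io::Result<usize> {
    let mut byte = [0u8; 1];
    let mut index = 0;

    loop {
        // Read one byte.
        stream.read_exact(&mut byte)?;
        buffer[index] = byte[0];

        // Check if we reached the end of the line; the `index > 0` guard
        // prevents the underflow when \n is the very first byte.
        if index > 0 && buffer[index] == b'\n' && buffer[index - 1] == b'\r' {
            break;
        }

        index += 1;
    }

    Ok(index)
}

fn main() {
    // A Cursor over in-memory bytes replaces the TcpStream.
    let mut stream = Cursor::new(b"+FULLRESYNC some_repl_id 0\r\nleftover".to_vec());
    let mut buffer = [0u8; 64];

    let index = read_line_generic(&mut stream, &mut buffer).unwrap();

    // As in the original function, the returned index points at the final \n.
    assert_eq!(&buffer[..index + 1], b"+FULLRESYNC some_repl_id 0\r\n");
}
```

The payoff is that a test can feed the function arbitrary byte sequences, including the pathological leading-\n case, without opening a socket.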

Step 8.3.2 - Read the RDB length and the separator#

To know how many bytes are contained in the RDB file, we can just read the length using the two functions resp_remove_type and resp_extract_length, exactly as we do in parse_bulk_string.

src/server.rs
use crate::commands::{echo, get, info, ping, psync, replconf, set};
use crate::connection::{stream_read_line, stream_send_receive_resp, ConnectionMessage};
use crate::replication::ReplicationConfig;
use crate::request::Request;
use crate::resp::{bytes_to_resp, resp_extract_length, resp_remove_type};
use crate::server_result::{ServerError, ServerMessage, ServerResult, ServerValue};
use crate::storage::Storage;
use crate::RESP;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::{io::AsyncWriteExt, net::TcpStream};

...

pub async fn handshake(stream: &mut TcpStream, info: &ServerInfo) -> ServerResult {

    // Read the FULLRESYNC line.
    if let Some(_) = stream_read_line(stream, &mut buffer).await.err() {
        return Err(ServerError::HandshakeFailed(String::from(
            "PSYNC failed, no FULLRESYNC",
        )));
    }

    // Convert bytes into RESP.
    let response = bytes_to_resp(&buffer, &mut index).ok();
    match response {
        Some(RESP::SimpleString(_)) => {}
        _ => {
            return Err(ServerError::HandshakeFailed(String::from(
                "PSYNC failed, wrong FULLRESYNC response",
            )));
        }
    };

    // Read the RDB length.
    if let Some(_) = stream_read_line(stream, &mut buffer).await.err() {
        return Err(ServerError::HandshakeFailed(String::from(
            "PSYNC failed, cannot read RDB length",
        )));
    }

    index = 0;

    // Remove the dollar sign.
    if let Some(_) = resp_remove_type('$', &buffer, &mut index).err() {
        return Err(ServerError::HandshakeFailed(String::from(
            "PSYNC failed, RDB doesn't start with $",
        )));
    }

    // Convert bytes into RDB length.
    let length = resp_extract_length(&buffer, &mut index).unwrap();

    Ok(ServerValue::None)

With these changes, the error of the test becomes

Test logs
[stage-113] [handshake] master: Sent "FULLRESYNC 75cd7bc10c49047e0d163660f3b90625b1af31dc 0"
[stage-113] [handshake] master: Sent bytes: "+FULLRESYNC 75cd7bc10c49047e0d163660f3b90625b1af31dc 0\r\n"
[stage-113] [handshake] Sending RDB file...
[stage-113] [handshake] master: Sent bytes:
"$88\r\nREDIS0011\xfa\tredis-ver\x057.2.0\xfa\nredis-bits\xc0
 \xfa\x05ctime\xc2m\b\xbce\xfa\bused-mem°\xc4\x10\x00\xfa\baof
 -base\xc0\x00\xff\xf0n;\xfe\xc0\xffZ\xa2"
[stage-113] [handshake] Sent RDB file.
[stage-113] [propagation] master: > SET foo 123
[stage-113] [propagation] master: Sent bytes: "*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\n123\r\n"
[stage-113] [propagation] master: > SET bar 456
[your_program] Error: Unknown format for RESP string

The message Unknown format for RESP string comes from the function bytes_to_resp, which is called in handle_connection. Our code is completely confused now, as the data we received starts with REDIS0011, which is not RESP at all. Fortunately, at this point we have almost everything we need to read all the RDB bytes.

Step 8.3.3 - Read the correct amount of bytes#

The last piece of code we need in order to read the RDB is a function that reads a given amount of bytes from a stream. This can be done with a simple loop and the function AsyncReadExt::read_exact [docs]

src/connection.rs
// Read the given amount of bytes from the stream
pub async fn stream_read_data_length(
    stream: &mut TcpStream,
    buffer: &mut [u8],
    length: usize,
) -> ConnectionResult<usize> {
    let mut byte = [0; 1];
    let mut index = 0;

    // Keep reading every byte until we reach the given amount.
    while index < length {
        // Read one byte.
        stream
            .read_exact(&mut byte)
            .await
            .map_err(|e| ConnectionError::CannotReadFromStream(e.to_string()))?;

        // Store the byte in the buffer.
        buffer[index] = byte[0];

        index += 1;
    }

    Ok(index)
}

which allows us to finally extract the RDB from the stream during the handshake process.

src/server.rs
use crate::commands::{echo, get, info, ping, psync, replconf, set};
use crate::connection::{
    stream_read_data_length, stream_read_line, stream_send_receive_resp, ConnectionMessage,
};
use crate::replication::ReplicationConfig;
use crate::request::Request;
use crate::resp::{bytes_to_resp, resp_extract_length, resp_remove_type};
use crate::server_result::{ServerError, ServerMessage, ServerResult, ServerValue};
use crate::storage::Storage;
use crate::RESP;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::{io::AsyncWriteExt, net::TcpStream};

...


pub async fn handshake(stream: &mut TcpStream, info: &ServerInfo) -> ServerResult {

    // Remove the dollar sign.
    if let Some(_) = resp_remove_type('$', &buffer, &mut index).err() {
        return Err(ServerError::HandshakeFailed(String::from(
            "PSYNC failed, RDB doesn't start with $",
        )));
    }

    // Convert bytes into RDB length.
    let length = resp_extract_length(&buffer, &mut index).unwrap();

    // Read the RDB data.
    if let Some(_) = stream_read_data_length(stream, &mut buffer, length as usize)
        .await
        .err()
    {
        return Err(ServerError::HandshakeFailed(String::from(
            "PSYNC failed, cannot read RDB",
        )));
    }

    Ok(ServerValue::None)

In this specific implementation we are going to ignore the content of the RDB, so there is no function to parse it.

The test for this stage, however, doesn't stop after the RDB is sent, so even with this change it still fails: the server is not yet handling multiple commands in the same TCP segment. We will fix that in the next step.
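One aside on stream_read_data_length before moving on: read_exact already blocks until it has filled the whole slice it is given, so the byte-by-byte loop could be collapsed into a single call. Here is a synchronous sketch with std::io::Read (tokio's AsyncReadExt::read_exact behaves the same way; the function name is illustrative):

```rust
use std::io::{Cursor, Read};

// Equivalent of stream_read_data_length: fill exactly `length` bytes
// of the buffer with a single read_exact call.
fn read_data_length<S: Read>(
    stream: &mut S,
    buffer: &mut [u8],
    length: usize,
) -> std::io::Result<usize> {
    stream.read_exact(&mut buffer[..length])?;
    Ok(length)
}

fn main() {
    // The stream carries the 9-byte RDB magic plus trailing command bytes.
    let mut stream = Cursor::new(b"REDIS0011*3\r\n".to_vec());
    let mut buffer = [0u8; 32];

    let read = read_data_length(&mut stream, &mut buffer, 9).unwrap();

    // Only the requested bytes were consumed; the rest stays in the stream.
    assert_eq!(&buffer[..read], b"REDIS0011");
}
```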

Step 8.4 - Command processing#

Now that the replica has received the RDB, it is ready to accept Redis commands sent by the master or by clients. In principle, what we have done so far is enough to provide a working implementation, as the server propagates commands and replicas are ready to read them.

The challenge warns us that "it is not guaranteed that propagated commands will be sent one at a time. One TCP segment might contain bytes for multiple commands."

This is exactly what happens in this case. The test log shows that the master is propagating three commands

[stage-113] [propagation] master: > SET foo 123
[stage-113] [propagation] master: Sent bytes: "*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\n123\r\n"
[stage-113] [propagation] master: > SET bar 456
[stage-113] [propagation] master: Sent bytes: "*3\r\n$3\r\nSET\r\n$3\r\nbar\r\n$3\r\n456\r\n"
[stage-113] [propagation] master: > SET baz 789
[stage-113] [propagation] master: Sent bytes: "*3\r\n$3\r\nSET\r\n$3\r\nbaz\r\n$3\r\n789\r\n"

but you might see our code receive some or all of them in a single read, seemingly at random.

Making sure that our server (master or replica) correctly handles multiple commands is pretty simple. We just need to drop the assumption that the buffer read from the TCP stream contains only one command and loop over all the commands.

src/connection.rs
pub async fn handle_connection(

    loop {

                    // If the stream returned some data,
                    // process the request.
                    Ok(size) if size != 0 => {
                        // Initialise the index to start at the
                        // beginning of the buffer.
                        let mut index: usize = 0;

                        // Process the bytes in the buffer according to
                        // the content and extract the request. Update the index.
                        let resp = match bytes_to_resp(&buffer[..size].to_vec(), &mut index) {
                            Ok(v) => v,
                            Err(e) => {
                                eprintln!("Error: {}", e);
                                return;
                            }
                        };

                        // Create the request. The binary field holds
                        // only the bytes of this command, which end
                        // at the index updated by bytes_to_resp.
                        let request = Request {
                            value: resp,
                            sender: connection_sender.clone(),
                            binary: buffer[..index].to_vec(),
                        };

                        // Send the request to the server.
                        match server_sender.send(ConnectionMessage::Request(request)).await {
                            Ok(()) => {},
                            Err(e) => {
                                eprintln!("Error sending request: {}", e);
                                return;
                            }
                        }
                        // Keep reading until we processed the full
                        // amount of data read from the stream.
                        while index < size {
                            let start: usize = index;

                            // Process the bytes in the buffer according to
                            // the content and extract the request. Update the index.
                            let resp = match bytes_to_resp(&buffer[..size], &mut index) {
                                Ok(v) => v,
                                Err(e) => {
                                    eprintln!("Error: {}", e);
                                    return;
                                }
                            };

                            // Create the request.
                            let request = Request {
                                value: resp,
                                sender: connection_sender.clone(),
                                binary: buffer[start..index].to_vec(),
                            };

                            // Send the request to the server.
                            match server_sender.send(ConnectionMessage::Request(request)).await {
                                Ok(()) => {},
                                Err(e) => {
                                    eprintln!("Error sending request: {}", e);
                                    return;
                                }
                            }
                        }
                    }

CodeCrafters

Replication Stage 13: Command Processing

The code we wrote in this section passes Replication - Stage 13 of the CodeCrafters challenge.

Step 8.5 - Refactoring connection functions#

As discussed before, the stream functions in src/connection.rs are untested, and we already spotted potential bugs in the current implementation. In this step, we are going to make them testable by introducing the powerful concept of generic functions with trait bounds.

Let's take a closer look at the function stream_read_line and at the potential issues.

src/connection.rs
// Read bytes from the stream until the \r\n is reached (including those bytes).
pub async fn stream_read_line(
    stream: &mut TcpStream,
    buffer: &mut [u8],
) -> ConnectionResult<usize> {
    let mut byte = [0; 1];
    let mut index = 0;

    // Keep reading every byte until we issue a break.
    loop {
        // Read one byte.
        stream
            .read_exact(&mut byte)
            .await
            .map_err(|e| ConnectionError::CannotReadFromStream(e.to_string()))?;

        // Store the byte in the buffer.
        buffer[index] = byte[0];

        // Check if we reached the end of the line.
        if buffer[index] == b'\n' && buffer[index - 1] == b'\r' { 1
            break;
        }

        index += 1; 2
    }

    Ok(index)
}

The check 1 assumes that, when we see a \n, there is always a previous byte to look at. But on the very first iteration of the loop index is 0, and since index is a usize, the expression index - 1 underflows: it panics in debug builds ("attempt to subtract with overflow") and wraps around to usize::MAX in release builds, where the subsequent buffer access panics with an out-of-bounds error. If the stream ever delivers \n as its very first byte, the process crashes. Nothing in the type system catches it, and we have no tests to exercise it.
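A two-line experiment makes the failure mode concrete. This snippet is just an illustration, not part of the server code:

```rust
fn main() {
    let index: usize = 0;

    // In a debug build the next line would panic with
    // "attempt to subtract with overflow"; in a release
    // build it would wrap around to usize::MAX instead.
    // let previous = index - 1;

    // checked_sub makes the underflow explicit: it returns
    // None instead of panicking or wrapping.
    assert_eq!(index.checked_sub(1), None);
    assert_eq!(3usize.checked_sub(1), Some(2));
}
```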

We also have a second problem in the same function: we never check that index stays within the bounds of buffer 2. If the terminator \r\n simply never arrives, the loop will happily keep reading bytes from the network and writing past the end of the caller's buffer, causing another panic. This is precisely the kind of defensive check that binary_extract_line in src/resp.rs already has, and that we should be porting over.

We can clearly try to address these problems directly, but it would be great to have tests that verify the function's behaviour.

The function stream_read_line receives a &mut TcpStream, and spinning up a real TCP socket in a unit test is not a great idea. What we really want is to decouple the function from TcpStream so that a test can feed it bytes from an in-memory source.

Generic functions and trait bounds#

A generic function is a function that is not tied to a specific concrete type. Instead, it declares one or more type parameters, and the compiler generates a specialised version for every concrete type the function is called with. This process is called monomorphisation and is described in detail in the chapter Generic Data Types of the Rust book.

A bare type parameter, by itself, would be too permissive: inside the function body the compiler has no idea what operations are available on a value of type T, so it would refuse to let us do anything useful with it. Trait bounds solve this by constraining the type parameter to types that implement one or more traits. The Rust book chapter Traits: Defining Shared Behavior covers the syntax and semantics in depth. In short, a bound like T: Display means "the caller may pick any concrete type for T, as long as that type implements Display", and the function body can then call every method that Display provides.
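As a minimal illustration of the idea, the describe function below (a made-up example, not part of our codebase) uses exactly the T: Display bound mentioned above:

```rust
use std::fmt::Display;

// A generic function: the caller picks the concrete type T,
// as long as T implements Display. The compiler generates a
// specialised copy for every concrete type the function is
// called with (monomorphisation).
fn describe<T: Display>(value: T) -> String {
    format!("value = {}", value)
}

fn main() {
    // Monomorphised as describe::<i32> and describe::<&str>.
    assert_eq!(describe(42), "value = 42");
    assert_eq!(describe("hi"), "value = hi");
}
```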

There are two common ways to express "this function works with anything that can be read from". One is to take a trait object, &mut dyn AsyncRead, which uses dynamic dispatch; the other is to take a generic parameter with a trait bound, <R: AsyncRead>, which uses static dispatch. Both solve our problem, but in this context static dispatch is the better choice: it produces slightly faster code and composes more naturally with the rest of the async ecosystem.
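The same trade-off exists with the synchronous std::io::Read, which makes for a compact illustration that doesn't need an async runtime. Both versions below accept a plain byte slice, because &[u8] implements Read:

```rust
use std::io::Read;

// Dynamic dispatch: one compiled function, and every call
// to a Read method goes through a vtable.
fn read_all_dyn(reader: &mut dyn Read) -> std::io::Result<Vec<u8>> {
    let mut out = Vec::new();
    reader.read_to_end(&mut out)?;
    Ok(out)
}

// Static dispatch: the compiler monomorphises a specialised
// copy of this function for every concrete reader type.
fn read_all_generic<R: Read>(reader: &mut R) -> std::io::Result<Vec<u8>> {
    let mut out = Vec::new();
    reader.read_to_end(&mut out)?;
    Ok(out)
}

fn main() -> std::io::Result<()> {
    // A byte slice implements Read, so both versions accept it.
    let mut data: &[u8] = b"OK\r\n";
    assert_eq!(read_all_generic(&mut data)?, b"OK\r\n");

    let mut data: &[u8] = b"OK\r\n";
    assert_eq!(read_all_dyn(&mut data)?, b"OK\r\n");
    Ok(())
}
```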

Picking the right bounds#

We want stream_read_line to accept "anything that can be asynchronously read from", so the obvious candidate is the AsyncRead trait [docs]. This is the async counterpart of the standard library's std::io::Read: a type that implements AsyncRead exposes a poll_read method, which attempts to fill a buffer with bytes and may return Poll::Pending if the data is not yet available. We don't call poll_read directly; instead we use the convenience methods provided by the AsyncReadExt extension trait, exactly as we already do with read_exact.

AsyncRead alone is not quite enough. The convenience methods on AsyncReadExt, including read_exact, are defined with a Self: Unpin bound so they can accept an ordinary &mut self instead of forcing the caller to pin the reader first. Without + Unpin on our generic R, the compiler would refuse to call read_exact on it, and we would have to pin R manually at every call site.

The combined bound R: AsyncRead + Unpin is therefore the most permissive signature that still lets us do useful work inside the function body. It captures the intent, "I need something I can asynchronously read from and that I can move around freely", and nothing more.

Step 8.5.1 - Make the function generic#

The actual diff is small. We add AsyncRead to the list of imports from tokio::io — we need the trait itself in scope for the bound, even though we'll keep using AsyncReadExt for the read_exact method — and we replace the concrete &mut TcpStream parameter with a generic &mut R.

src/connection.rs
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
    net::{TcpListener, TcpStream},
    select,
    sync::mpsc,
};

...

// Read bytes from the stream until the \r\n is reached (including those bytes).
pub async fn stream_read_line(
    stream: &mut TcpStream,
pub async fn stream_read_line<R: AsyncRead + Unpin>(
    stream: &mut R,
    buffer: &mut [u8],
) -> ConnectionResult<usize> {

The single real caller of stream_read_line lives in src/server.rs and passes a &mut TcpStream. Since TcpStream implements both AsyncRead and Unpin, the call site compiles unchanged: the compiler infers R = TcpStream and monomorphises a specialised copy of the function for that type, which is effectively identical to the function we had before. From the point of view of the production code path, nothing has moved.

What has moved is the set of types that are allowed to call the function. Any type that satisfies AsyncRead + Unpin now works, and Tokio provides such an implementation for humble byte slices. This is the payoff: a plain &[u8] is an AsyncRead, which means a unit test can feed bytes to stream_read_line without ever touching a socket, without a mocking library, and without adding a single line of infrastructure code. The Tokio documentation for AsyncRead [docs] lists every implementor; it's worth a glance to appreciate how much testing surface this single abstraction unlocks.

We can now add a first test.

src/connection.rs
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    // Test that the function stream_read_line
    // correctly detects the \r\n, writes the bytes
    // (terminator included) into the buffer, and
    // returns the index of the final \n.
    async fn test_stream_read_line() {
        let mut stream: &[u8] = b"OK\r\n";
        let mut buffer = [0u8; 16];

        let index = stream_read_line(&mut stream, &mut buffer).await.unwrap();

        assert_eq!(&buffer[..=index], b"OK\r\n");
        assert_eq!(index, 3);
    }
}

Step 8.5.2 - Add tests and fix bugs#

Now that stream_read_line is generic and we have a first passing test, we can turn our attention to the actual bug we flagged: the expression buffer[index - 1] on the very first iteration of the loop.

Let's first write a test that exposes the bug. The data we "send" to the function is the three-byte sequence \n\r\n, where the initial \n is part of the line's content, not its terminator, and the function should happily return once it sees the real \r\n at the end.

src/connection.rs
mod tests {

    #[tokio::test]
    // Test that the function stream_read_line does
    // not panic when the stream sends \n as the very
    // first byte. A lone \n is not preceded by a \r
    // and must not be interpreted as the terminator.
    async fn test_stream_read_line_leading_newline() {
        let mut stream: &[u8] = b"\n\r\n";
        let mut buffer = [0u8; 16];

        let index = stream_read_line(&mut stream, &mut buffer).await.unwrap();

        assert_eq!(&buffer[..=index], b"\n\r\n");
        assert_eq!(index, 2);
    }

As expected, the test panics at runtime.

failures:

---- connection::tests::test_stream_read_line_leading_newline stdout ----

thread 'connection::tests::test_stream_read_line_leading_newline' (3619422) panicked at src/connection.rs:263:45:
attempt to subtract with overflow

The fix is to stop reading from the buffer at all when we need to check the previous byte. Rather than going back into buffer[index - 1], we keep the previous byte in a dedicated local variable and update it at the end of each iteration. This is exactly the same technique that binary_extract_line already uses with its previous_elem variable.

src/connection.rs
pub async fn stream_read_line<R: AsyncRead + Unpin>(
    stream: &mut R,
    buffer: &mut [u8],
) -> ConnectionResult<usize> {
    let mut byte = [0; 1];
    let mut index = 0;

    // Track the byte we read in the previous iteration
    // so that we can detect the terminator \r\n without
    // ever indexing the buffer backwards. We initialise
    // it to 0, which is neither \r nor \n.
    let mut previous_byte: u8 = 0;

    // Keep reading every byte until we issue a break.
    loop {
        // Read one byte.
        stream
            .read_exact(&mut byte)
            .await
            .map_err(|e| ConnectionError::CannotReadFromStream(e.to_string()))?;

        // Store the byte in the buffer.
        buffer[index] = byte[0];

        // Check if we reached the end of the line.
        if buffer[index] == b'\n' && buffer[index - 1] == b'\r' {
        if byte[0] == b'\n' && previous_byte == b'\r' {
            break;
        }

        previous_byte = byte[0];
        index += 1;
    }

    Ok(index)
}

The initial value 0 for previous_byte is a safe choice: it is a valid u8 that is neither \r nor \n, so the very first byte can never be mistaken for the second half of a terminator. With this fix in place, the whole test suite passes.

Step 8.5.3 - Guard against short buffers#

There is still a potential bug in the current version of the function. Every iteration of the loop writes into buffer[index] and then increments index, but nothing checks that index is still a valid offset into buffer. If the terminator \r\n never arrives then index will eventually grow past the length of the buffer and the function will panic.

We already solved the problem in binary_extract_line. There, the very first thing the function does is check that the starting index is still within the slice, and return a RESPError::OutOfBounds error otherwise. The equivalent move for us is to check, at the top of each loop iteration, that we still have room for one more byte.

Let's first add a dedicated variant of ConnectionError that captures this condition.

src/connection.rs
#[derive(Debug)]
pub enum ConnectionError {
    BufferTooSmall(usize),
    CannotReadFromStream(String),
    CannotWriteToStream(String),
    MalformedRESP(String),
    RequestFailed(String, String),
    ServerError(ServerError),
}

impl fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ConnectionError::BufferTooSmall(size) => {
                write!(f, "Buffer too small, size {}.", size)
            }
            ConnectionError::CannotReadFromStream(string) => {
                write!(f, "Cannot read from stream: {}.", string)
            }
            ...
        }
    }
}

With the error variant in place, we can write a test that exposes the issue.

src/connection.rs
mod tests {

    #[tokio::test]
    // Test that the function stream_read_line returns
    // a BufferTooSmall error when the caller's buffer
    // cannot hold the whole line, instead of panicking
    // or writing past the end of the buffer.
    async fn test_stream_read_line_buffer_too_small() {
        let mut stream: &[u8] = b"HELLO\r\n";
        let mut buffer = [0u8; 3];

        match stream_read_line(&mut stream, &mut buffer).await {
            Err(ConnectionError::BufferTooSmall(size)) => {
                assert_eq!(size, 3);
            }
            _ => panic!(),
        }
    }

After we verified that the test fails, we can add the code that makes it pass. The check is pretty straightforward.

src/connection.rs
pub async fn stream_read_line<R: AsyncRead + Unpin>(

    loop {
        // Make sure we still have room in the buffer
        // before consuming another byte from the stream.
        if index >= buffer.len() {
            return Err(ConnectionError::BufferTooSmall(buffer.len()));
        }

        // Read one byte.
        stream
            .read_exact(&mut byte)
            .await
            .map_err(|e| ConnectionError::CannotReadFromStream(e.to_string()))?;

        // Store the byte in the buffer.
        buffer[index] = byte[0];

        // Check if we reached the end of the line.
        if byte[0] == b'\n' && previous_byte == b'\r' {
            break;
        }

        previous_byte = byte[0];
        index += 1;
    }

With this step, stream_read_line is now defended against both of its failure modes: a pathological first byte (fixed in 8.5.2) and a buffer that's too small to hold the incoming line (fixed here). What remains is fleshing out the error-path tests to cover the cases where the stream itself misbehaves — closes unexpectedly, sends only half of a terminator, or sends a line without any terminator at all. That's the subject of the final sub-step.

Step 8.5.4 - Complete the test suite#

The last step is all about tests. Given the similarities between the two functions, we want to increase the coverage of stream_read_line as we did for binary_extract_line.

src/connection.rs
mod tests {

    #[tokio::test]
    // Make sure the function stream_read_line doesn't
    // have any hardcoded values testing a different
    // use case.
    async fn test_stream_read_line_longer_string() {
        let mut stream: &[u8] = b"ECHO\r\n";
        let mut buffer = [0u8; 16];

        let index = stream_read_line(&mut stream, &mut buffer).await.unwrap();

        assert_eq!(&buffer[..=index], b"ECHO\r\n");
        assert_eq!(index, 5);
    }

    #[tokio::test]
    // Test that the function stream_read_line returns
    // a CannotReadFromStream error when the stream is
    // empty and no bytes can be read at all.
    async fn test_stream_read_line_empty_stream() {
        let mut stream: &[u8] = b"";
        let mut buffer = [0u8; 16];

        match stream_read_line(&mut stream, &mut buffer).await {
            Err(ConnectionError::CannotReadFromStream(_)) => {}
            _ => panic!(),
        }
    }

    #[tokio::test]
    // Test that the function stream_read_line returns
    // a CannotReadFromStream error when the stream
    // ends before the terminator \r\n is reached.
    async fn test_stream_read_line_no_separator() {
        let mut stream: &[u8] = b"OK";
        let mut buffer = [0u8; 16];

        match stream_read_line(&mut stream, &mut buffer).await {
            Err(ConnectionError::CannotReadFromStream(_)) => {}
            _ => panic!(),
        }
    }

    #[tokio::test]
    // Test that the function stream_read_line returns
    // a CannotReadFromStream error when the stream
    // sends only \r and then ends, leaving the
    // terminator \r\n incomplete.
    async fn test_stream_read_line_half_separator() {
        let mut stream: &[u8] = b"OK\r";
        let mut buffer = [0u8; 16];

        match stream_read_line(&mut stream, &mut buffer).await {
            Err(ConnectionError::CannotReadFromStream(_)) => {}
            _ => panic!(),
        }
    }

That concludes the refactoring of stream_read_line.

Step 8.5.5 - Other functions#

Since we introduced trait bounds for stream_read_line, we can do the same for the other stream functions: stream_write_resp, stream_read_resp, stream_send_receive_resp, and stream_read_data_length.

Let's start with stream_write_resp. The set of changes to the prototype is similar to what we have done before, but here we need AsyncWrite instead of AsyncRead.

src/connection.rs
use crate::resp::{bytes_to_resp, RESP};
use crate::server::{handshake, ServerInfo};
use crate::server_result::{ServerMessage, ServerValue};
use crate::{request::Request, server_result::ServerError};
use std::fmt;
use tokio::{
    io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
    net::{TcpListener, TcpStream},
    select,
    sync::mpsc,
};

...

// Write RESP data to the stream.
pub async fn stream_write_resp(stream: &mut TcpStream, data: &RESP) -> ConnectionResult<usize> {
pub async fn stream_write_resp<W: AsyncWrite + Unpin>(
    stream: &mut W,
    data: &RESP,
) -> ConnectionResult<usize> {

...

mod tests {

    #[tokio::test]
    // Test that the function stream_write_resp
    // serialises a RESP value and writes the
    // resulting bytes to the stream, returning
    // the number of bytes written.
    async fn test_stream_write_resp() {
        let mut stream: Vec<u8> = Vec::new();
        let data = RESP::SimpleString(String::from("OK"));

        let size = stream_write_resp(&mut stream, &data).await.unwrap();

        assert_eq!(stream, b"+OK\r\n");
        assert_eq!(size, 5);
    }

For stream_read_resp, there are three tests and some code changes, since the current version has some bugs when the RESP message is malformed and when the message is shorter than the buffer. This leads to some changes to the body of the function as well.

src/connection.rs
// Read RESP data from the stream.
async fn stream_read_resp(stream: &mut TcpStream, buffer: &mut [u8]) -> ConnectionResult<RESP> {
async fn stream_read_resp<R: AsyncRead + Unpin>(
    stream: &mut R,
    buffer: &mut [u8],
) -> ConnectionResult<RESP> {
    // Read bytes from the TCP stream.
    match stream.read(buffer).await {
        Ok(size) => Ok(size),
        Err(e) => Err(ConnectionError::CannotReadFromStream(e.to_string())),
    }?;
    // Read bytes from the stream. Remember how many
    // bytes we actually got, so that we only hand the
    // parser the valid portion of the buffer.
    let size = stream
        .read(buffer)
        .await
        .map_err(|e| ConnectionError::CannotReadFromStream(e.to_string()))?;

    // Set the index to start reading from the first byte.
    let mut index: usize = 0;

    // Convert bytes to RESP.
    bytes_to_resp(&buffer, &mut index).map_err(|e| ConnectionError::MalformedRESP(e.to_string()))
    // Convert bytes to RESP, only looking at the slice
    // that actually holds data read from the stream.
    bytes_to_resp(&buffer[..size], &mut index)
        .map_err(|e| ConnectionError::MalformedRESP(e.to_string()))
}

...

mod tests {

    #[tokio::test]
    // Test that the function stream_read_resp
    // reads bytes from the stream and parses them
    // into a RESP value.
    async fn test_stream_read_resp() {
        let mut stream: &[u8] = b"+OK\r\n";
        let mut buffer = [0u8; 16];

        let resp = stream_read_resp(&mut stream, &mut buffer).await.unwrap();

        assert_eq!(resp, RESP::SimpleString(String::from("OK")));
    }

    #[tokio::test]
    // Test that the function stream_read_resp
    // passes to the parser the bytes it has
    // actually read and not any stale bytes
    // left in the buffer.
    async fn test_stream_read_resp_short_message_in_large_buffer() {
        let mut stream: &[u8] = b"+OK\r\n";
        // Pre-fill the buffer with garbage that would
        // confuse the parser if it were consulted.
        let mut buffer = [b'X'; 64];

        let resp = stream_read_resp(&mut stream, &mut buffer).await.unwrap();

        assert_eq!(resp, RESP::SimpleString(String::from("OK")));
    }

    #[tokio::test]
    // Test that the function stream_read_resp returns
    // a MalformedRESP error when the bytes in the
    // stream are not a valid RESP message.
    async fn test_stream_read_resp_malformed() {
        let mut stream: &[u8] = b"not resp";
        let mut buffer = [0u8; 16];

        match stream_read_resp(&mut stream, &mut buffer).await {
            Err(ConnectionError::MalformedRESP(_)) => {}
            _ => panic!(),
        }
    }

Next, the function stream_send_receive_resp.

src/connection.rs
// Write a RESP request to the stream and read the RESP response.
pub async fn stream_send_receive_resp(
    stream: &mut TcpStream,
pub async fn stream_send_receive_resp<S: AsyncRead + AsyncWrite + Unpin>(
    stream: &mut S,

...

mod tests {

    #[tokio::test]
    // Test that the function stream_send_receive_resp
    // writes the request to the stream and then reads
    // the response from the same stream. We use a
    // tokio duplex pair to simulate a connected peer
    // that reads the request and writes back a canned
    // response.
    async fn test_stream_send_receive_resp() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let (mut client, mut server) = tokio::io::duplex(64);

        // Spawn a task that acts as the remote peer:
        // read whatever the client sends, then write
        // back a canned RESP response.
        let peer = tokio::spawn(async move {
            let mut request_buffer = [0u8; 64];
            let n = server.read(&mut request_buffer).await.unwrap();

            // Sanity check: the client must have sent
            // the serialised PING command.
            assert_eq!(&request_buffer[..n], b"+PING\r\n");
            server.write_all(b"+PONG\r\n").await.unwrap();
        });

        let request = RESP::SimpleString(String::from("PING"));
        let mut buffer = [0u8; 64];

        let response = stream_send_receive_resp(&mut client, &request, &mut buffer)
            .await
            .unwrap();

        assert_eq!(response, RESP::SimpleString(String::from("PONG")));
        peer.await.unwrap();
    }

The test for this function deserves an explanation. Here, we need to test an interaction with an active peer: the function sends data and then reads the response. To simulate that, we create a Tokio duplex [docs] and spawn a task that first reads (server.read(...).await) and then writes (server.write_all(...).await).

Last, we change the function stream_read_data_length. This is actually a good chance to refactor the function, replacing the current implementation.

src/connection.rs
// Read the given amount of bytes from the stream
pub async fn stream_read_data_length(
    stream: &mut TcpStream,
    buffer: &mut [u8],
    length: usize,
) -> ConnectionResult<usize> {
    let mut byte = [0; 1];
    let mut index = 0;

    // Keep reading every byte until we reach the given amount.
    while index < length {
        // Read one byte.
        stream
            .read_exact(&mut byte)
            .await
            .map_err(|e| ConnectionError::CannotReadFromStream(e.to_string()))?;

        // Store the byte in the buffer.
        buffer[index] = byte[0];

        index += 1;
    }

    Ok(index)
}
pub async fn stream_read_data_length<R: AsyncRead + Unpin>(
    stream: &mut R,
    buffer: &mut [u8],
    length: usize,
) -> ConnectionResult<usize> {
    // Make sure the caller's buffer can hold the
    // requested amount of bytes before we touch the
    // stream. This avoids panicking when slicing and
    // mirrors the check we use in stream_read_line.
    if length > buffer.len() {
        return Err(ConnectionError::BufferTooSmall(buffer.len()));
    }

    // Read exactly `length` bytes into the head of
    // the buffer. read_exact returns an error if the
    // stream closes before enough bytes are available.
    stream
        .read_exact(&mut buffer[..length])
        .await
        .map_err(|e| ConnectionError::CannotReadFromStream(e.to_string()))?;

    Ok(length)
}

...

mod tests {

    #[tokio::test]
    // Test that the function stream_read_data_length
    // reads exactly the requested number of bytes
    // from the stream and stores them in the buffer.
    async fn test_stream_read_data_length() {
        let mut stream: &[u8] = b"HELLO WORLD";
        let mut buffer = [0u8; 16];

        let size = stream_read_data_length(&mut stream, &mut buffer, 5)
            .await
            .unwrap();

        assert_eq!(size, 5);
        assert_eq!(&buffer[..size], b"HELLO");
    }

    #[tokio::test]
    // Test that the function stream_read_data_length
    // returns a BufferTooSmall error when the caller
    // asks for more bytes than the buffer can hold,
    // without consuming any bytes from the stream.
    async fn test_stream_read_data_length_buffer_too_small() {
        let mut stream: &[u8] = b"HELLO";
        let mut buffer = [0u8; 3];

        match stream_read_data_length(&mut stream, &mut buffer, 5).await {
            Err(ConnectionError::BufferTooSmall(size)) => {
                assert_eq!(size, 3);
            }
            _ => panic!(),
        }
    }

    #[tokio::test]
    // Test that the function stream_read_data_length
    // returns a CannotReadFromStream error when the
    // stream ends before the requested amount of
    // bytes have been read.
    async fn test_stream_read_data_length_short_stream() {
        let mut stream: &[u8] = b"HI";
        let mut buffer = [0u8; 16];

        match stream_read_data_length(&mut stream, &mut buffer, 5).await {
            Err(ConnectionError::CannotReadFromStream(_)) => {}
            _ => panic!(),
        }
    }

CodeCrafters

Replication Stage 13: Command Processing

This last set of steps was pure refactoring, so the code still passes Replication - Stage 13 of the CodeCrafters challenge.

Step 8.6 - ACKs with no commands#

The new stage introduces the command REPLCONF GETACK sent by the server to a replica. The master sends this command periodically to see if replicas are in sync with the current state of the database.

So, the server will send a RESP array with the command REPLCONF GETACK * and the replica is supposed to reply with REPLCONF ACK <offset>, where <offset> is the current offset of the replica. In this stage, we assume that the replica hasn't received any other command, so the offset will be 0.
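For reference, this is what the exchange looks like on the wire as RESP arrays of bulk strings. The encode_resp_array helper below is a hypothetical stand-in for the serialisation our RESP module already performs:

```rust
// Hypothetical helper: encode a command as a RESP array of
// bulk strings — "*<count>\r\n" followed by one
// "$<len>\r\n<data>\r\n" entry per element.
fn encode_resp_array(parts: &[&str]) -> String {
    let mut out = format!("*{}\r\n", parts.len());
    for p in parts {
        out.push_str(&format!("${}\r\n{}\r\n", p.len(), p));
    }
    out
}

fn main() {
    // Master -> replica: poll the replica for its offset.
    assert_eq!(
        encode_resp_array(&["REPLCONF", "GETACK", "*"]),
        "*3\r\n$8\r\nREPLCONF\r\n$6\r\nGETACK\r\n$1\r\n*\r\n"
    );

    // Replica -> master: report an offset of 0.
    assert_eq!(
        encode_resp_array(&["REPLCONF", "ACK", "0"]),
        "*3\r\n$8\r\nREPLCONF\r\n$3\r\nACK\r\n$1\r\n0\r\n"
    );
}
```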

The current implementation of REPLCONF doesn't handle subcommands like GETACK and doesn't distinguish between the two roles of the server (master, replica), because so far the command was expected to reach the master only.

To see if the current server is a master or a replica, we can look at server.replication.master and see if it's None (master) or Some (replica). However, such a comparison won't be very readable, so it makes sense to create a function that does it for us.

src/replication.rs
impl ReplicationConfig {

    pub fn is_replica(&self) -> bool {
        self.master.is_some()
    }

Now we can change the function that implements REPLCONF.

src/commands/replconf.rs
use crate::request::Request;
use crate::resp::RESP;
use crate::server::Server;
use crate::server_result::ServerValue;
use crate::server_result::{ServerError, ServerValue};

pub async fn command(_server: &Server, request: &Request, _command: &Vec<String>) {
pub async fn command(server: &Server, request: &Request, command: &Vec<String>) {
    request
        .data(ServerValue::RESP(RESP::SimpleString("OK".to_string())))
        .await;
    // On a master, the REPLCONF is received only
    // from the handshake process. In that case we
    // always reply with OK.
    if !server.replication.is_replica() {
        request
            .data(ServerValue::RESP(RESP::SimpleString("OK".to_string())))
            .await;
        return;
    }

    // On a replica, the master sends `REPLCONF GETACK *`
    // and expects a `REPLCONF ACK <offset>` reply.
    if command.len() != 3 {
        request
            .error(ServerError::CommandSyntaxError(command.join(" ")))
            .await;
        return;
    }

    // The only subcommand that REPLCONF supports
    // on a replica is GETACK. Reject anything else.
    if command[1].to_uppercase() != "GETACK" {
        request
            .error(ServerError::CommandSyntaxError(command.join(" ")))
            .await;
        return;
    }

    // Create the response for the server.
    let resp = RESP::Array(vec![
        RESP::BulkString(String::from("REPLCONF")),
        RESP::BulkString(String::from("ACK")),
        RESP::BulkString(String::from("0")),
    ]);

    request.data(ServerValue::RESP(resp)).await;
}

As we introduced different logic for the master and the replica, it's worth clarifying that the existing tests are for the master.

src/commands/replconf.rs
mod tests {

    #[tokio::test]
    // Test that the function command processes
    // a `REPLCONF listening-port` request
    // a `REPLCONF listening-port` request on a master
    // and that it responds with the correct value.
    async fn test_command_replconf_listening_port() {

...

    #[tokio::test]
    // Test that the function command processes
    // a `REPLCONF capa` request and
    // a `REPLCONF capa` request on a master and
    // that it responds with the correct value.
    async fn test_command_replconf_capa() {

...

We also added several checks in the function, so we can write new tests for it.

src/commands/replconf.rs
mod tests {

    use super::*;
    use crate::replication::ReplicationConfig;
    use crate::server_result::ServerMessage;
    use tokio::sync::mpsc;

    #[tokio::test]
    // Test that the function command processes
    // a `REPLCONF GETACK *` request on a replica
    // and that it responds with `REPLCONF ACK 0`.
    async fn test_command_replconf_getack_on_replica() {
        let cmd = vec![
            String::from("replconf"),
            String::from("GETACK"),
            String::from("*"),
        ];

        let mut server = Server::new("localhost".to_string(), 6379);
        server.set_replication(ReplicationConfig::new_replica(
            "otherhost".to_string(),
            1234,
        ));

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&server, &request, &cmd).await;

        assert_eq!(
            connection_receiver.try_recv().unwrap(),
            ServerMessage::Data(ServerValue::RESP(RESP::Array(vec![
                RESP::BulkString(String::from("REPLCONF")),
                RESP::BulkString(String::from("ACK")),
                RESP::BulkString(String::from("0")),
            ])))
        )
    }

    #[tokio::test]
    // Test that the function command processes
    // a lowercase `REPLCONF getack *` request on a
    // replica and that it responds with `REPLCONF ACK 0`.
    async fn test_command_replconf_getack_on_replica_lowercase() {
        let cmd = vec![
            String::from("replconf"),
            String::from("getack"),
            String::from("*"),
        ];

        let mut server = Server::new("localhost".to_string(), 6379);
        server.set_replication(ReplicationConfig::new_replica(
            "otherhost".to_string(),
            1234,
        ));

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&server, &request, &cmd).await;

        assert_eq!(
            connection_receiver.try_recv().unwrap(),
            ServerMessage::Data(ServerValue::RESP(RESP::Array(vec![
                RESP::BulkString(String::from("REPLCONF")),
                RESP::BulkString(String::from("ACK")),
                RESP::BulkString(String::from("0")),
            ])))
        )
    }

    #[tokio::test]
    // Test that the function command returns the
    // correct error when a replica receives a
    // `REPLCONF` request with the wrong number of
    // arguments.
    async fn test_command_replconf_on_replica_wrong_length() {
        let cmd = vec![String::from("replconf"), String::from("GETACK")];

        let mut server = Server::new("localhost".to_string(), 6379);
        server.set_replication(ReplicationConfig::new_replica(
            "otherhost".to_string(),
            1234,
        ));

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&server, &request, &cmd).await;

        assert_eq!(
            connection_receiver.try_recv().unwrap(),
            ServerMessage::Error(ServerError::CommandSyntaxError(String::from(
                "replconf GETACK"
            )))
        )
    }

    #[tokio::test]
    // Test that the function command returns the
    // correct error when a replica receives a
    // `REPLCONF` request with an unknown subcommand.
    async fn test_command_replconf_on_replica_wrong_subcommand() {
        let cmd = vec![
            String::from("replconf"),
            String::from("something"),
            String::from("*"),
        ];

        let mut server = Server::new("localhost".to_string(), 6379);
        server.set_replication(ReplicationConfig::new_replica(
            "otherhost".to_string(),
            1234,
        ));

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
        };

        command(&server, &request, &cmd).await;

        assert_eq!(
            connection_receiver.try_recv().unwrap(),
            ServerMessage::Error(ServerError::CommandSyntaxError(String::from(
                "replconf something *"
            )))
        )
    }

CodeCrafters

Replication Stage 14: ACKs with no commands

The code we wrote in this section passes Replication - Stage 14 of the CodeCrafters challenge.

Step 8.7 - ACKs with commands#

In this stage, we change the system so that replicas actually keep track of the commands the master sends them. There are three new concepts here:

  • Replicas silently process PING and SET that come from the master, without sending any response.
  • Replicas keep track of the commands sent by the master so that they can reply to REPLCONF GETACK * with the right number of processed bytes to show that they are in sync.
  • Replicas should reject write commands like SET if they come from clients.
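Before implementing these rules command by command, it helps to see them as a single decision table. The following standalone sketch (the names `Reaction` and `reaction` are illustrative, not types from the book's codebase) condenses the three bullets above, including the one exception where a replica does reply to the master:

```rust
// How a server reacts to a command, depending on its role
// and on where the connection comes from. Illustrative names
// only; the real code spreads this logic across command files.
#[derive(Debug, PartialEq)]
enum Reaction {
    Reply,  // process and answer as usual
    Silent, // process but send no response
    Reject, // refuse with an error
}

fn reaction(is_replica: bool, from_master: bool, command: &str) -> Reaction {
    let cmd = command.to_ascii_uppercase();
    match (is_replica, from_master) {
        // GETACK is the exception: the master expects an answer.
        (true, true) if cmd == "REPLCONF" => Reaction::Reply,
        // A replica silently applies what the master propagates.
        (true, true) => Reaction::Silent,
        // A replica refuses writes coming from regular clients.
        (true, false) if cmd == "SET" => Reaction::Reject,
        // Everything else behaves normally.
        _ => Reaction::Reply,
    }
}

fn main() {
    assert_eq!(reaction(false, false, "SET"), Reaction::Reply); // master, client
    assert_eq!(reaction(true, true, "SET"), Reaction::Silent);  // replica, master
    assert_eq!(reaction(true, true, "PING"), Reaction::Silent); // replica, master
    assert_eq!(reaction(true, false, "set"), Reaction::Reject); // replica, client write
    assert_eq!(reaction(true, false, "GET"), Reaction::Reply);  // replica, client read
}
```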

Step 8.7.1 - Make requests aware of master/replica#

The first change to the code is therefore a way to tell if the connection the server is handling comes from the master or from a normal client. We already have two different functions to manage those connections, run_listener and run_master_listener, and both spawn tasks that run handle_connection. It seems reasonable, then, to add a flag to the latter to signal which type of connection it is serving.

src/connection.rs
// The main entry point for valid TCP connections.
pub async fn handle_connection(
    mut stream: TcpStream,
    server_sender: mpsc::Sender<ConnectionMessage>,
    master_connection: bool,
) {
    // Create a buffer to host incoming data.
    let mut buffer = [0; 512];

...

Let's propagate that change where we spawn the tasks.

src/connection.rs
pub async fn run_listener(host: String, port: u16, server_sender: mpsc::Sender<ConnectionMessage>) {

    loop {
    
            // Process a new connection.
            connection = listener.accept() => {
                match connection {
                    // The connection is valid, handle it.
                    Ok((stream, _)) => {
                        // Spawn a task to take care of this connection.
                        tokio::spawn(handle_connection(stream, server_sender.clone()));
                        // Mark the connection as coming from a client.
                        tokio::spawn(handle_connection(stream, server_sender.clone(), false));
                    }

pub async fn run_master_listener(

    // Run the handshake protocol
    if let Err(e) = handshake(&mut stream, info).await {
        eprintln!("Handshake failed: {}", e.to_string());
        std::process::exit(1);
    }

    // Spawn a task to take care of this connection.
    tokio::spawn(async move { handle_connection(stream, server_sender.clone()).await });
    // Mark the connection as coming from the master.
    tokio::spawn(async move { handle_connection(stream, server_sender.clone(), true).await });

Now, to propagate the knowledge we have when we call handle_connection we can add the flag to Request, which is the message we exchange with the server that runs command functions.

src/request.rs
#[derive(Debug)]
pub struct Request {
    pub value: RESP,
    pub sender: mpsc::Sender<ServerMessage>,
    pub binary: Vec<u8>,
    pub master_connection: bool,
}

And with this we can pass the argument directly when we create the request in handle_connection.

src/connection.rs
pub async fn handle_connection(

    loop {

                            // Process the bytes in the buffer according to
                            // the content and extract the request. Update the index.
                            let resp = match bytes_to_resp(&buffer[..size].to_vec(), &mut index) {
                                Ok(v) => v,
                                Err(e) => {
                                    eprintln!("Error: {}", e);
                                    return;
                                }
                            };

                            // Create the request.
                            let request = Request {
                                value: resp,
                                sender: connection_sender.clone(),
                                binary: buffer[start..index].to_vec(),
                                master_connection,
                            };

This unfortunately prevents several tests from compiling, as they create a request and therefore need to pass the flag. You can use cargo test to find which tests must be fixed.

Fortunately, since all of them have been written from the point of view of the master, the fix is the same everywhere: pass false for the new field.

src/commands/echo.rs
mod tests {

    async fn test_command() {
    
        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
            master_connection: false,
        };

Step 8.7.2 - Silently accept commands#

It's time to make sure the replica silently accepts PING coming from the master and doesn't reply with PONG.

src/commands/ping.rs
pub async fn command(_server: &Server, request: &Request, _command: &Vec<String>) {
    // When the master sends a PING to a replica
    // the replica must stay silent.
    if request.master_connection {
        return;
    }

    request
        .data(ServerValue::RESP(RESP::SimpleString("PONG".to_string())))
        .await;
}

mod tests {

    #[tokio::test]
    // Test that the function command stays silent
    // when the PING comes from the master connection.
    async fn test_command_ping_master_connection() {
        let cmd = vec![String::from("PING")];
        let server = Server::new("localhost".to_string(), 6379);
        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
            master_connection: true,
        };

        command(&server, &request, &cmd).await;

        assert_eq!(
            connection_receiver.try_recv().unwrap_err(),
            mpsc::error::TryRecvError::Empty
        );
    }

And now we can make sure the replica doesn't accept SET from clients.

src/commands/set.rs
pub async fn command(server: &mut Server, request: &Request, command: &Vec<String>) {
    // A replica only accepts SET from the master.
    // Any client that sends SET to a replica gets
    // a CommandNotAvailable error.
    if server.replication.is_replica() && !request.master_connection {
        request
            .error(ServerError::CommandNotAvailable(command[0].clone()))
            .await;
        return;
    }

    // Extract the storage from the server.
    let storage = match server.storage.as_mut() {
        Some(storage) => storage,
        None => {
            request.error(ServerError::StorageNotInitialised).await;
            return;
        }
    };

...

mod tests {

    use super::*;
    use crate::server_result::ServerMessage;
    use crate::storage::Storage;
    use crate::ReplicationConfig;
    use tokio::sync::mpsc;

...

    #[tokio::test]
    // Test that a replica rejects a SET command
    // received from a regular client connection.
    async fn test_command_replica_rejects_client_set() {
        let storage = Storage::new();
        let mut server: Server = Server::new("localhost".to_string(), 6379);
        server.set_replication(ReplicationConfig::new_replica(
            "otherhost".to_string(),
            1234,
        ));
        server.set_storage(storage);

        let cmd = vec![
            String::from("set"),
            String::from("key"),
            String::from("value"),
        ];

        let (request_channel_tx, mut request_channel_rx) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: request_channel_tx.clone(),
            binary: Vec::new(),
            master_connection: false,
        };

        command(&mut server, &request, &cmd).await;

        assert_eq!(
            request_channel_rx.try_recv().unwrap(),
            ServerMessage::Error(ServerError::CommandNotAvailable(String::from("set")))
        );
    }

However, replicas should also silently accept SET from the master.

src/commands/set.rs
pub async fn command(server: &mut Server, request: &Request, command: &Vec<String>) {

    // Set the value of the key in the storage.
    if let Err(_) = storage.set(key, value, args) {
        request
            .error(ServerError::CommandInternalError(command.join(" ")))
            .await;
        return;
    };

    // When the master forwards a SET command to
    // a replica the replica must stay silent.
    if request.master_connection {
        return;
    }

    request
        .data(ServerValue::RESP(RESP::SimpleString(String::from("OK"))))
        .await;

...

mod tests {

    #[tokio::test]
    // Test that the function command stays silent
    // when a SET is propagated from the master
    // down the replication link.
    async fn test_command_master_connection_is_silent() {
        let storage = Storage::new();
        let mut server: Server = Server::new("localhost".to_string(), 6379);
        server.set_replication(ReplicationConfig::new_replica(
            "otherhost".to_string(),
            1234,
        ));
        server.set_storage(storage);

        let cmd = vec![
            String::from("set"),
            String::from("key"),
            String::from("value"),
        ];

        let (request_channel_tx, mut request_channel_rx) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: request_channel_tx.clone(),
            binary: Vec::new(),
            master_connection: true,
        };

        command(&mut server, &request, &cmd).await;

        assert_eq!(
            request_channel_rx.try_recv().unwrap_err(),
            mpsc::error::TryRecvError::Empty
        );
    }

Step 8.7.3 - Increment the offset#

The last change we need to add to the replica in this stage is offset management. When the replica receives a command from the master it has to increment its internal offset. This offset must then be sent as the response when the master issues a REPLCONF GETACK *, so that the master can verify that the replica is in sync.

Keeping track of how many bytes have been processed by the server is extremely simple. As each request carries its binary representation, we can just add the length of that field to the internal counter repl_offset.
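To make the accounting concrete, here is a small standalone sketch (plain std, independent of the server types) that RESP-encodes commands as arrays of bulk strings and accumulates an offset from the raw byte length, mirroring what process_request does with request.binary:

```rust
// Standalone sketch: RESP-encode a command as an array of bulk
// strings and accumulate a replication offset from the raw byte
// length, mirroring `repl_offset += request.binary.len()`.
fn encode_command(parts: &[&str]) -> Vec<u8> {
    let mut out = format!("*{}\r\n", parts.len());
    for p in parts {
        out.push_str(&format!("${}\r\n{}\r\n", p.len(), p));
    }
    out.into_bytes()
}

fn main() {
    let mut repl_offset: usize = 0;

    // The master propagates a SET, then asks for an ACK.
    for cmd in [
        vec!["SET", "key", "value"],
        vec!["REPLCONF", "GETACK", "*"],
    ] {
        let binary = encode_command(&cmd);
        repl_offset += binary.len();
        println!(
            "{} -> {} bytes, offset now {}",
            cmd.join(" "),
            binary.len(),
            repl_offset
        );
    }

    // `REPLCONF GETACK *` alone is 37 bytes on the wire.
    assert_eq!(encode_command(&["REPLCONF", "GETACK", "*"]).len(), 37);
}
```

Note that even the GETACK request itself counts towards the offset, which is why consecutive GETACKs on a busy replica report growing values.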

src/server.rs
pub async fn process_request(request: Request, server: &mut Server) {

    // Process the request using the requested command.
    match command_name.as_str() {
        "echo" => {
            echo::command(server, &request, &command).await;
        }

        ...

        _ => {
            request
                .error(ServerError::CommandNotAvailable(command[0].clone()))
                .await;
        }
    }

    // Track how many bytes have been processed
    // by the server. The master will use this
    // value to check if the replicas are in sync.
    server.replication.repl_offset += request.binary.len();
}

The replica must send the value of that counter when it responds to REPLCONF GETACK *.

src/commands/replconf.rs
pub async fn command(server: &Server, request: &Request, command: &Vec<String>) {

    // The only subcommand that REPLCONF supports
    // on a replica is GETACK. Reject anything else.
    if command[1].to_uppercase() != String::from("GETACK") {
        request
            .error(ServerError::CommandSyntaxError(command.join(" ")))
            .await;
        return;
    }

    // Create the response for the server.
    let resp = RESP::Array(vec![
        RESP::BulkString(String::from("REPLCONF")),
        RESP::BulkString(String::from("ACK")),
        RESP::BulkString(String::from("0")),
        RESP::BulkString(server.replication.repl_offset.to_string()),
    ]);

CodeCrafters

Replication Stage 15: ACKs with commands

The code we wrote in this section passes Replication - Stage 15 of the CodeCrafters challenge.

Step 8.8 - WAIT with no replicas#

In the next three stages, we are tasked with adding a new command, WAIT. The first two stages are rather simple, as we just need to report back the number of registered replicas. The last stage, where the master has to check which replicas have successfully processed the latest write command, will be a bit more complicated.

For now, we need to make sure the server can receive a WAIT and replies with 0.

Since the number of replicas has to be a RESP integer, the first move is to implement the new data type.

src/resp.rs
#[derive(Debug, PartialEq)]
pub enum RESP {
    Array(Vec<RESP>),
    BulkString(String),
    Integer(i64),
    Null,
    RDBPrefix(usize),
    SimpleString(String),
}

impl fmt::Display for RESP {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let data = match self {
            Self::Array(data) => {
                let mut output = String::from("*");
                output.push_str(format!("{}\r\n", data.len()).as_str());

                for elem in data.iter() {
                    output.push_str(elem.to_string().as_str());
                }

                output
            }
            Self::BulkString(data) => format!("${}\r\n{}\r\n", data.len(), data),
            Self::Integer(data) => format!(":{}\r\n", data),
            Self::Null => String::from("$-1\r\n"),
            Self::RDBPrefix(data) => format!("${}\r\n", data),
            Self::SimpleString(data) => format!("+{}\r\n", data),
        };

        write!(f, "{}", data)
    }
}
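To sanity-check the wire format of the new variant, here is a trimmed, standalone copy of the encoder (only the variants needed here; `Resp` is an illustrative name, not the book's type), which should behave like the Display impl above:

```rust
use std::fmt;

// A trimmed copy of the RESP encoder, enough to show the wire
// format of the integer reply used by WAIT.
#[derive(Debug, PartialEq)]
enum Resp {
    Integer(i64),
    SimpleString(String),
}

impl fmt::Display for Resp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Integers are prefixed with ':' and CRLF-terminated.
            Resp::Integer(n) => write!(f, ":{}\r\n", n),
            Resp::SimpleString(s) => write!(f, "+{}\r\n", s),
        }
    }
}

fn main() {
    // WAIT with no replicas replies with the integer 0.
    assert_eq!(Resp::Integer(0).to_string(), ":0\r\n");
    // Negative integers carry their sign after the colon.
    assert_eq!(Resp::Integer(-1).to_string(), ":-1\r\n");
    assert_eq!(Resp::SimpleString("OK".to_string()).to_string(), "+OK\r\n");
}
```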

With that, we can easily write the new command.

src/commands/wait.rs
use crate::request::Request;
use crate::resp::RESP;
use crate::server::Server;
use crate::server_result::ServerValue;

pub async fn command(_server: &Server, request: &Request, _command: &Vec<String>) {
    // At this stage WAIT always reports zero replicas.
    request
        .data(ServerValue::RESP(RESP::Integer(0)))
        .await;
}

That can be paired with a simple test.

src/commands/wait.rs
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server_result::ServerMessage;
    use tokio::sync::mpsc;

    #[tokio::test]
    // Test that the function command replies with
    // the integer 0 regardless of the arguments.
    async fn test_command_wait_no_replicas() {
        let cmd = vec![String::from("wait"), String::from("0"), String::from("100")];
        let server = Server::new("localhost".to_string(), 6379);
        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
            master_connection: false,
        };

        command(&server, &request, &cmd).await;

        assert_eq!(
            connection_receiver.try_recv().unwrap(),
            ServerMessage::Data(ServerValue::RESP(RESP::Integer(0)))
        )
    }
}

We need to declare the new module.

src/commands/mod.rs
pub mod echo;
pub mod get;
pub mod info;
pub mod ping;
pub mod psync;
pub mod replconf;
pub mod set;
pub mod wait;

And finally add the command to process_request.

src/server.rs
use crate::commands::{echo, get, info, ping, psync, replconf, set};
use crate::commands::{echo, get, info, ping, psync, replconf, set, wait};
use crate::connection::{
    stream_read_data_length, stream_read_line, stream_send_receive_resp, ConnectionMessage,
};
use crate::replication::ReplicationConfig;
use crate::request::Request;
use crate::resp::{bytes_to_resp, resp_extract_length, resp_remove_type};
use crate::server_result::{ServerError, ServerMessage, ServerResult, ServerValue};
use crate::storage::Storage;
use crate::RESP;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::{io::AsyncWriteExt, net::TcpStream};

pub async fn process_request(request: Request, server: &mut Server) {

    // Process the request using the requested command.
    match command_name.as_str() {
        "echo" => {
            echo::command(server, &request, &command).await;
        }
        "get" => {
            get::command(server, &request, &command).await;
        }
        "info" => {
            info::command(server, &request, &command).await;
        }
        "ping" => {
            ping::command(server, &request, &command).await;
        }
        "psync" => {
            psync::command(server, &request, &command).await;
        }
        "replconf" => {
            replconf::command(server, &request, &command).await;
        }
        "set" => {
            set::command(server, &request, &command).await;

            // Forward the request to all replicas.
            send_request_to_replicas(&request, server).await;
        }
        "wait" => {
            wait::command(server, &request, &command).await;
        }
        _ => {
            request
                .error(ServerError::CommandNotAvailable(command[0].clone()))
                .await;
        }
    }

CodeCrafters

Replication Stage 16: WAIT with no replicas

The code we wrote in this section passes Replication - Stage 16 of the CodeCrafters challenge.

Step 8.9 - WAIT with no commands#

In this stage, the master is connected to some replicas, but no commands have been issued, so there is no real waiting involved behind the scenes. For now, the command will still ignore its arguments, but it will return the actual number of replicas that completed the handshake.

The change is almost trivial, as the command already has access to the server and thus to the number of replicas.

src/commands/wait.rs
pub async fn command(_server: &Server, request: &Request, _command: &Vec<String>) {
pub async fn command(server: &Server, request: &Request, _command: &Vec<String>) {
    // At this stage WAIT always reports zero replicas.
    request.data(ServerValue::RESP(RESP::Integer(0))).await;
    // Report how many replicas are currently connected.
    request
        .data(ServerValue::RESP(RESP::Integer(
            server.replica_senders.len() as i64,
        )))
        .await
}

It is worth changing the description of the initial test, now that the function no longer returns a hard-coded value. Then, we can add a test for the new case where replicas are connected.

src/commands/wait.rs
mod tests {

    #[tokio::test]
    // Test that the function command replies with
    // the integer 0 regardless of the arguments.
    // 0 when there are no replicas registered on
    // the server.
    async fn test_command_wait_no_replicas() {

...

    #[tokio::test]
    // Test that the function command replies with
    // the number of replicas currently registered
    // on the server.
    async fn test_command_wait_counts_replicas() {
        let cmd = vec![String::from("wait"), String::from("0"), String::from("100")];
        let mut server = Server::new("localhost".to_string(), 6379);

        // Register two fake replica channels. The
        // receivers stay alive for the duration of
        // the test so the senders remain valid.
        let (replica_a_tx, _replica_a_rx) = mpsc::channel::<ServerMessage>(32);
        let (replica_b_tx, _replica_b_rx) = mpsc::channel::<ServerMessage>(32);
        server.replica_senders.push(replica_a_tx);
        server.replica_senders.push(replica_b_tx);

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
            master_connection: false,
        };

        command(&server, &request, &cmd).await;

        assert_eq!(
            connection_receiver.try_recv().unwrap(),
            ServerMessage::Data(ServerValue::RESP(RESP::Integer(2)))
        )
    }

CodeCrafters

Replication Stage 17: WAIT with no commands

The code we wrote in this section passes Replication - Stage 17 of the CodeCrafters challenge.

Step 8.10 - WAIT with multiple commands#

This last stage is more complicated than the previous ones, as it involves new channels and timeouts. Let's first break down how WAIT works from the outside, and then we will look at the internal implementation.

When a master receives write commands like SET, it forwards them to replicas, but doesn't wait for an acknowledgement. Rather, the server counts how many bytes it has sent on the replication stream (repl_offset) and only verifies that replicas are in sync on demand, when a client issues WAIT.

The command WAIT is issued by a user to force the master to check if its replicas are in sync. The master sends to each replica a command REPLCONF GETACK * and waits for the replicas to reply. When all replicas have replied, the master replies to WAIT. This is a simplified view, as there is a timeout involved, but it will do for now.

Managing WAIT poses a problem for the master. It has to keep working to receive messages (the ACKs from replicas), but at the same time it has to hold off the reply to the client until all replicas have answered. This is not impossible to achieve with the current architecture, but it would push some of the logic behind WAIT into the main server loop, which couples the two components and results in a messy system.

A better solution is the following:

  • When the master receives WAIT, it asks all replicas to acknowledge and creates a task whose job is to manage the acknowledgements. This way, the whole logic behind counting acknowledgements is isolated in a separate, independent component.
  • The master then continues to work as usual, but when it receives an acknowledgement from a replica it forwards it to the wait handler.
  • When the wait handler is happy with the number of received ACKs, it can directly reply to the client that issued the WAIT.

This solution also allows us to easily implement the timeout mechanism of WAIT. The wait handler is responsible for receiving the ACKs and for monitoring the timeout. No burden on the server, no need to pollute the server's logic with something that is specific to a single command.
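The shape of that handler can be sketched with plain std threads and channels (the real implementation uses tokio tasks and tokio::select!, but the counting logic is the same): a dedicated handler counts ACK messages until it reaches the target or the deadline expires, then reports how many arrived in time. The function name and types here are illustrative.

```rust
use std::sync::mpsc;
use std::thread;
use std::time::{Duration, Instant};

// Sketch of the wait-handler logic with std threads instead of
// tokio tasks: count ACKs until we reach `target` or `timeout`
// elapses, then return how many ACKs arrived in time.
fn wait_for_acks(rx: mpsc::Receiver<()>, target: usize, timeout: Duration) -> usize {
    let deadline = Instant::now() + timeout;
    let mut acks = 0;
    while acks < target {
        let remaining = deadline.saturating_duration_since(Instant::now());
        match rx.recv_timeout(remaining) {
            Ok(()) => acks += 1,
            // Timeout elapsed or all senders gone: stop counting.
            Err(_) => break,
        }
    }
    acks
}

fn main() {
    let (tx, rx) = mpsc::channel();

    // Two "replicas" acknowledge; a third never does.
    for _ in 0..2 {
        let tx = tx.clone();
        thread::spawn(move || tx.send(()).unwrap());
    }
    drop(tx);

    // Generous timeout; we stop early once all senders disconnect.
    let acks = wait_for_acks(rx, 3, Duration::from_secs(1));
    assert_eq!(acks, 2);
    println!("received {} acks", acks);
}
```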

Step 8.10.1 - Add the required components#

The first addition to the system is a way for the server to know if there are pending write operations. If there aren't any, a WAIT can be answered directly, without the need to actually ask the replicas for ACKs.
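That fast path can be captured in a few lines. This is a sketch under illustrative names (`wait_reply_fast_path` is not a function in the book's codebase), showing only the decision, not the channel machinery:

```rust
// Sketch of the WAIT fast path: with no pending writes there is
// nothing for replicas to acknowledge, so the master can answer
// with the replica count immediately. Illustrative names only.
fn wait_reply_fast_path(pending_write_operations: bool, replica_count: usize) -> Option<i64> {
    if !pending_write_operations {
        // Nothing was written: every replica is trivially in sync.
        Some(replica_count as i64)
    } else {
        // Writes happened: the real code must ask the replicas
        // for ACKs and count the responses.
        None
    }
}

fn main() {
    assert_eq!(wait_reply_fast_path(false, 3), Some(3));
    assert_eq!(wait_reply_fast_path(true, 3), None);
}
```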

src/replication.rs
pub struct ReplicationConfig {
    pub master: Option<MasterConfig>,
    pub replid: String,
    pub repl_offset: usize,
    pub pending_write_operations: bool,
}

impl ReplicationConfig {

    pub fn new_master() -> Self {
        ReplicationConfig {
            master: None,
            replid: Alphanumeric.sample_string(&mut rng(), 40),
            repl_offset: 0,
            pending_write_operations: false,
        }
    }

    pub fn new_replica(master_host: String, master_port: u16) -> Self {
        ReplicationConfig {
            master: Some(MasterConfig {
                host: master_host,
                port: master_port,
            }),
            replid: Alphanumeric.sample_string(&mut rng(), 40),
            repl_offset: 0,
            pending_write_operations: false,
        }
    }

Then, the server must be given access to the wait handler, and this happens through a channel, just like we did for replicas. The wait handler, however, exists only while the command WAIT is being processed, so it's worth making it optional.

src/server.rs
pub struct Server {
    pub info: ServerInfo,
    pub storage: Option<Storage>,
    pub replication: ReplicationConfig,
    pub replica_senders: Vec<mpsc::Sender<ServerMessage>>,
    pub wait_handler_sender: Option<mpsc::Sender<ServerMessage>>,
}

impl Server {

    pub fn new(host: String, port: u16) -> Self {
        Self {
            info: ServerInfo {
                host: host,
                port: port,
            },
            storage: None,
            replication: ReplicationConfig::new_master(),
            replica_senders: Vec::new(),
            wait_handler_sender: None,
        }
    }

Last, we need to create the component that will manage the ACKs, the wait_handler. For now, the only thing that this component needs to do is to count responses. Keep in mind that the replica ACKs are still received by the server as usual (command REPLCONF), but will be forwarded to the wait handler.

src/server.rs
// Collect replica ACKs over an interval.
pub async fn wait_handler(mut wait_handler_receiver: mpsc::Receiver<ServerMessage>) {
    let mut acks: i64 = 0;

    loop {
        tokio::select! {
            Some(_) = wait_handler_receiver.recv() => {
                acks += 1;
            }
        }
    }
}

Step 8.10.2 - Record pending writes in SET#

The components we added in the previous step will be inert for a while, until we wire up the whole system. In this step, we need to add another important piece of the puzzle: record on the server whether at least one write operation has happened.

At the moment, the only write operation that can be run on the server is SET, but a full-fledged Redis implementation provides a wide range of them. For the time being, we will toggle the flag when we process SET, but in the future it will be useful to implement a better mechanism to reduce code duplication.

src/commands/set.rs
pub async fn command(server: &mut Server, request: &Request, command: &Vec<String>) {

    // Set the value of the key in the storage.
    if let Err(_) = storage.set(key, value, args) {
        request
            .error(ServerError::CommandInternalError(command.join(" ")))
            .await;
        return;
    };

    // If the server is the master, toggle the
    // flag to signal that at least one write
    // operation has happened.
    if !server.replication.is_replica() {
        server.replication.pending_write_operations = true;
    }

    // When the master forwards a SET command to
    // a replica the replica must stay silent.
    if request.master_connection {
        return;
    }

We can add two tests for this logic.

src/commands/set.rs
mod tests {

    #[tokio::test]
    // Test that a successful SET on a master toggles
    // pending_write_operations.
    async fn test_command_master_sets_pending_write_operations() {
        let storage = Storage::new();
        let mut server: Server = Server::new("localhost".to_string(), 6379);
        server.set_storage(storage);

        assert_eq!(server.replication.pending_write_operations, false);

        let cmd = vec![
            String::from("set"),
            String::from("key"),
            String::from("value"),
        ];

        let (request_channel_tx, _request_channel_rx) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: request_channel_tx.clone(),
            binary: Vec::new(),
            master_connection: false,
        };

        command(&mut server, &request, &cmd).await;

        assert_eq!(server.replication.pending_write_operations, true);
    }

    #[tokio::test]
    // Test that a SET propagated from the master to
    // a replica does NOT toggle the replica's own
    // pending_write_operations flag.
    async fn test_command_replica_does_not_set_pending_write_operations() {
        let storage = Storage::new();
        let mut server: Server = Server::new("localhost".to_string(), 6379);
        server.set_replication(ReplicationConfig::new_replica(
            "otherhost".to_string(),
            1234,
        ));
        server.set_storage(storage);

        let cmd = vec![
            String::from("set"),
            String::from("key"),
            String::from("value"),
        ];

        let (request_channel_tx, _request_channel_rx) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: request_channel_tx.clone(),
            binary: Vec::new(),
            master_connection: true,
        };

        command(&mut server, &request, &cmd).await;

        assert_eq!(server.replication.pending_write_operations, false);
    }

Step 8.10.3 - Forward ACKs to the wait handler#

The next step is to change the behaviour of REPLCONF. Currently, the function always responds with OK, since until now the command was received only during the handshake. Now, however, a replica can send the command to the master in response to a GETACK. In that case, we will not respond but simply "forward the message" to the wait handler. We don't actually send the same message to the handler, as there is no need: the handler just needs to receive any message on the wait_handler_receiver.

src/commands/replconf.rs
use crate::request::Request;
use crate::resp::RESP;
use crate::server::Server;
use crate::server_result::{ServerError, ServerValue};
use crate::server_result::{ServerError, ServerMessage, ServerValue};

pub async fn command(server: &Server, request: &Request, command: &Vec<String>) {

    // On a master, the REPLCONF is received only
    // from the handshake process. In that case we
    // always reply with OK.
    // On a master, the REPLCONF is received either
    // from the handshake process (in which case we
    // reply with OK) or as a `REPLCONF ACK <offset>`
    // sent by a replica in response to a GETACK we
    // previously issued.
    if !server.replication.is_replica() {
        // A replica is reporting its offset in
        // response to a GETACK. Send an empty
        // message to the wait handler and stay
        // silent on the replica's own channel.
        if command.len() >= 2 && command[1].to_uppercase() == String::from("ACK") {
            if let Some(tx) = &server.wait_handler_sender {
                // The wait handler might have already
                // dropped its receiver.
                // Ignore the send error.
                let _ = tx.send(ServerMessage::Data(ServerValue::None)).await;
            }
            return;
        }

        request
            .data(ServerValue::RESP(RESP::SimpleString("OK".to_string())))
            .await;
        return;
    }

We can add a couple of tests for the new code.

src/commands/replconf.rs
mod tests {

    #[tokio::test]
    // Test that a master receiving `REPLCONF ACK`
    // sends an empty message to the wait handler
    // and sends nothing back on the replica's own
    // channel.
    async fn test_command_replconf_ack_on_master_forwards_to_wait_handler() {
        let cmd = vec![
            String::from("replconf"),
            String::from("ACK"),
            String::from("42"),
        ];

        let mut server = Server::new("localhost".to_string(), 6379);

        // Install a wait handler sender so the
        // master has somewhere to forward the ACK.
        let (wait_tx, mut wait_rx) = mpsc::channel::<ServerMessage>(32);
        server.wait_handler_sender = Some(wait_tx);

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
            master_connection: false,
        };

        command(&server, &request, &cmd).await;

        assert_eq!(
            wait_rx.try_recv().unwrap(),
            ServerMessage::Data(ServerValue::None)
        );

        assert_eq!(
            connection_receiver.try_recv().unwrap_err(),
            mpsc::error::TryRecvError::Empty
        );
    }

    #[tokio::test]
    // Test that a master receiving `REPLCONF ACK`
    // while no wait handler is listening simply
    // drops the ACK without replying.
    async fn test_command_replconf_ack_on_master_no_wait_handler() {
        let cmd = vec![
            String::from("replconf"),
            String::from("ACK"),
            String::from("42"),
        ];

        let server = Server::new("localhost".to_string(), 6379);
        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
            master_connection: false,
        };

        command(&server, &request, &cmd).await;

        assert_eq!(
            connection_receiver.try_recv().unwrap_err(),
            mpsc::error::TryRecvError::Empty
        );
    }

Step 8.10.4 - Add logic to the WAIT command#

It's time to start adding some logic to the function that implements WAIT.

The core of the current architecture is that the wait handler is responsible for replying to the client, so the server will just forward the ACKs and forget them. The handler, in turn, has two conditions to monitor:

  • The required number of replicas has acknowledged all previous write commands.
  • The timeout expires.

We will add the timeout in the next step, so for the time being the wait handler will collect ACKs and respond to the client when enough replicas are in sync.

Let's have a look at the new version of the command step by step. First of all, we need some imports and a change to the prototype: the command will reset the flag pending_write_operations, so the server needs to be mutable. The first piece of logic we can add to the command is a short-circuit: when there are no pending operations, it is pointless to ask the replicas to acknowledge.

src/commands/wait.rs
use crate::request::Request;
use crate::resp::RESP;
use crate::server::Server;
use crate::server::{wait_handler, Server};
use crate::server_result::ServerValue;
use crate::server_result::{ServerError, ServerMessage, ServerValue};
use tokio::sync::mpsc;

pub async fn command(server: &Server, request: &Request, _command: &Vec<String>) {
pub async fn command(server: &mut Server, request: &Request, command: &Vec<String>) {
    // If no write has happened since the last WAIT
    // there is nothing for the replicas to
    // acknowledge. Report the replica count right
    // away and skip the handler machinery.
    if !server.replication.pending_write_operations {
        request
            .data(ServerValue::RESP(RESP::Integer(
                server.replica_senders.len() as i64,
            )))
            .await;
        return;
    }

We need to extract from the command the number of acknowledgements that the user wants to wait for (numreplicas) and then to reset the flag pending_write_operations.

src/commands/wait.rs
pub async fn command(server: &mut Server, request: &Request, command: &Vec<String>) {

    if !server.replication.pending_write_operations {
        request
            .data(ServerValue::RESP(RESP::Integer(
                server.replica_senders.len() as i64,
            )))
            .await;
        return;
    }

    // The command comes in the form `WAIT numreplicas timeout`.
    let numreplicas: i64 = match command[1].parse() {
        Ok(v) => v,
        Err(_) => {
            request
                .error(ServerError::CommandSyntaxError(command.join(" ")))
                .await;
            return;
        }
    };

    // WAIT will process all pending operations.
    // Reset the flag.
    server.replication.pending_write_operations = false;

We need to ask all connected replicas to send an acknowledgement, which can be done with a simple loop.

src/commands/wait.rs
pub async fn command(server: &mut Server, request: &Request, command: &Vec<String>) {

    // WAIT will process all pending operations.
    // Reset the flag.
    server.replication.pending_write_operations = false;
    
    // Prepare the message for all replicas.
    let getack = RESP::Array(vec![
        RESP::BulkString(String::from("REPLCONF")),
        RESP::BulkString(String::from("GETACK")),
        RESP::BulkString(String::from("*")),
    ]);

    // Ask every replica to send back an ACK with
    // its current offset.
    for replica in server.replica_senders.iter() {
        let _ = replica
            .send(ServerMessage::Data(ServerValue::RESP(getack.clone())))
            .await;
    }

This piece of code requires a small change to the type RESP to provide the clone method.

src/resp.rs
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone)]
pub enum RESP {
    Array(Vec<RESP>),
    BulkString(String),
    Integer(i64),
    Null,
    RDBPrefix(usize),
    SimpleString(String),
}

Last, the core of the function. We create a channel to communicate with the handler and we add the sender to the server, so that it can forward ACKs. We then spawn the wait handler itself giving it the parameters it needs to work: the number of replicas to perform the check, the wait handler receiver to receive the ACKs, and the request sender to be able to respond to the client.

Cloning messages

Adding Clone to RESP is a design choice, not a necessity. Building the message inside the loop would cost exactly the same: in both cases every iteration allocates three String and one Vec, because String::clone and Vec::clone both perform a deep copy. At the scale we care about (a handful of replicas) the cost is negligible either way, and Clone pays for itself by making the code more readable. If we were ever optimising for a large number of replicas, the fix would not be to avoid cloning but to change the representation: serialise the message once into a shared byte buffer (Arc<[u8]> or bytes::Bytes) and hand a cheap reference-counted clone to each replica.
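The shared-buffer alternative mentioned in the sidebar can be sketched in a few lines. This is a standalone illustration, not part of the server: `serialised_getack` and its hand-written RESP bytes are invented for the example, and a real implementation would reuse the server's own serialisation code.

```rust
use std::sync::Arc;

// Standalone sketch of the shared-buffer alternative.
// The RESP encoding of `REPLCONF GETACK *` is written out
// by hand here for brevity.
fn serialised_getack() -> Arc<[u8]> {
    let bytes = b"*3\r\n$8\r\nREPLCONF\r\n$6\r\nGETACK\r\n$1\r\n*\r\n".to_vec();
    Arc::from(bytes.into_boxed_slice())
}

fn main() {
    let shared = serialised_getack();

    // Each replica gets a pointer copy plus a reference-count
    // bump, regardless of how large the message is.
    let per_replica: Vec<Arc<[u8]>> = (0..3).map(|_| Arc::clone(&shared)).collect();

    // Every handle points at the same allocation.
    assert!(per_replica.iter().all(|c| Arc::ptr_eq(c, &shared)));
}
```

The difference only matters at scale: cloning an `Arc<[u8]>` is constant-time, while cloning a `RESP` value grows with the size of the message.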

src/commands/wait.rs
pub async fn command(server: &mut Server, request: &Request, command: &Vec<String>) {

    // Ask every replica to send back an ACK with
    // its current offset.
    for replica in server.replica_senders.iter() {
        let _ = replica
            .send(ServerMessage::Data(ServerValue::RESP(getack.clone())))
            .await;
    }

    // Create a channel that will receive replica ACKs.
    // The wait handler will listen there.
    let (wait_handler_sender, wait_handler_receiver) = mpsc::channel::<ServerMessage>(32);

    // Add the wait handler sender to the server.
    server.wait_handler_sender = Some(wait_handler_sender);

    // Spawn the wait handler. It owns the receiver
    // and a clone of the client's reply channel, so
    // it can answer the client directly.
    tokio::spawn(wait_handler(
        numreplicas,
        wait_handler_receiver,
        request.sender.clone(),
    ));
}

We need to change existing tests to match the mutability of the server.

src/commands/wait.rs
mod tests {

    async fn test_command_wait_no_replicas() {

        let server = Server::new("localhost".to_string(), 6379);
        let mut server = Server::new("localhost".to_string(), 6379);
        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
            master_connection: false,
        };

        command(&server, &request, &cmd).await;
        command(&mut server, &request, &cmd).await;


    async fn test_command_wait_counts_replicas() {

        command(&server, &request, &cmd).await;
        command(&mut server, &request, &cmd).await;

And we can add new ones to validate the logic.

src/commands/wait.rs
mod tests {

    #[tokio::test]
    // Test that when pending_write_operations is
    // set, WAIT sends a `REPLCONF GETACK *` request
    // to every registered replica, consumes the
    // flag and installs wait_handler_sender. The
    // handler runs in the background but has no
    // ACKs to report yet, so nothing is sent back
    // to the client during the command execution.
    async fn test_command_wait_pending_writes() {
        let cmd = vec![String::from("wait"), String::from("2"), String::from("100")];
        let mut server = Server::new("localhost".to_string(), 6379);

        // Register two fake replica channels.
        let (replica_a_tx, mut replica_a_rx) = mpsc::channel::<ServerMessage>(32);
        let (replica_b_tx, mut replica_b_rx) = mpsc::channel::<ServerMessage>(32);
        server.replica_senders.push(replica_a_tx);
        server.replica_senders.push(replica_b_tx);

        // Simulate a write having happened since
        // the last WAIT.
        server.replication.pending_write_operations = true;

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
            master_connection: false,
        };

        command(&mut server, &request, &cmd).await;

        // Each replica has received a GETACK.
        let getack = ServerMessage::Data(ServerValue::RESP(RESP::Array(vec![
            RESP::BulkString(String::from("REPLCONF")),
            RESP::BulkString(String::from("GETACK")),
            RESP::BulkString(String::from("*")),
        ])));
        assert_eq!(replica_a_rx.try_recv().unwrap(), getack);
        assert_eq!(replica_b_rx.try_recv().unwrap(), getack);

        // The flag has been consumed.
        assert_eq!(server.replication.pending_write_operations, false);

        // A wait handler sender has been installed.
        assert!(server.wait_handler_sender.is_some());

        // The handler is still waiting for ACKs:
        // nothing has been forwarded to the client
        // yet.
        assert_eq!(
            connection_receiver.try_recv().unwrap_err(),
            mpsc::error::TryRecvError::Empty
        );
    }

    #[tokio::test]
    // Test that the wait handler replies to the
    // client with the number of ACKs received once
    // the requested threshold is reached.
    async fn test_command_wait_handler_replies_on_threshold() {
        let cmd = vec![String::from("wait"), String::from("2"), String::from("100")];
        let mut server = Server::new("localhost".to_string(), 6379);

        // Register two fake replica channels.
        let (replica_a_tx, _replica_a_rx) = mpsc::channel::<ServerMessage>(32);
        let (replica_b_tx, _replica_b_rx) = mpsc::channel::<ServerMessage>(32);
        server.replica_senders.push(replica_a_tx);
        server.replica_senders.push(replica_b_tx);

        server.replication.pending_write_operations = true;

        let (connection_sender, mut connection_receiver) = mpsc::channel::<ServerMessage>(32);

        let request = Request {
            value: RESP::Null,
            sender: connection_sender,
            binary: Vec::new(),
            master_connection: false,
        };

        command(&mut server, &request, &cmd).await;

        // Feed two ACKs through the installed
        // wait_handler_sender, simulating what
        // replconf::command would do when two real
        // replicas reply.
        let wait_sender = server.wait_handler_sender.clone().unwrap();
        wait_sender
            .send(ServerMessage::Data(ServerValue::None))
            .await
            .unwrap();
        wait_sender
            .send(ServerMessage::Data(ServerValue::None))
            .await
            .unwrap();

        // The handler should reply to the client
        // with the collected count.
        let response = connection_receiver.recv().await.unwrap();
        assert_eq!(
            response,
            ServerMessage::Data(ServerValue::RESP(RESP::Integer(2)))
        );
    }

In the code above, we changed the way we call the wait handler function, but we haven't changed the function prototype and its body yet.

src/server.rs
// Collect replica ACKs over an interval.
// Collect replica ACKs and reply to the client
// as soon as the requested number of replicas
// has acknowledged.
pub async fn wait_handler(mut wait_handler_receiver: mpsc::Receiver<ServerMessage>) {
pub async fn wait_handler(
    numreplicas: i64,
    mut wait_handler_receiver: mpsc::Receiver<ServerMessage>,
    client_sender: mpsc::Sender<ServerMessage>,
) {
    let mut acks: i64 = 0;

    loop {
        // When the requested number of replicas
        // is reached, reply to the client on
        // the provided sender.
        if acks >= numreplicas {
            let _ = client_sender
                .send(ServerMessage::Data(ServerValue::RESP(RESP::Integer(acks))))
                .await;
            break;
        }

        tokio::select! {
            Some(_) = wait_handler_receiver.recv() => {
                acks += 1;
            }
        }
    }
}

Step 8.10.5 - Add WAIT timeout#

This is the last step for this stage. We need to add a timeout to the wait handler, so that it replies to the client and terminates even when fewer replicas than requested have acknowledged.

Adding a timeout is not a big deal once a function is running asynchronously. Tokio provides a way to set up a timer and to monitor it with tokio::select!, just like any other async function.

src/server.rs
// Collect replica ACKs and reply to the client
// as soon as the requested number of replicas
// has acknowledged.
// has acknowledged or the timeout has expired.
pub async fn wait_handler(
    timeout: u64,
    numreplicas: i64,
    mut wait_handler_receiver: mpsc::Receiver<ServerMessage>,
    client_sender: mpsc::Sender<ServerMessage>,
) {
    let mut acks: i64 = 0;

    // `tokio::time::interval` returns an `Interval` whose
    // first call to `.tick().await` completes immediately.
    // We consume that first tick here, so that the tick
    // inside `select!` fires only after `timeout` milliseconds.
    let mut interval_timer = tokio::time::interval(Duration::from_millis(timeout));
    interval_timer.tick().await;

    loop {
        // When the requested number of replicas
        // is reached, reply to the client on
        // the provided sender.
        if acks >= numreplicas {
            let _ = client_sender
                .send(ServerMessage::Data(ServerValue::RESP(RESP::Integer(acks))))
                .await;
            break;
        }

        tokio::select! {
            Some(_) = wait_handler_receiver.recv() => {
                acks += 1;
            }

            // If the timeout has expired, reply with
            // whatever count we have collected so far.
            _ = interval_timer.tick() => {
                let _ = client_sender
                    .send(ServerMessage::Data(ServerValue::RESP(RESP::Integer(acks))))
                    .await;
                break;
            }
        }
    }
}

Please note that the Interval needs to be ticked once before entering the loop. The rationale for this is that the common use case for interval is periodic work, where you run something every N milliseconds, starting now: the first tick therefore completes immediately. What we want here, instead, is to first sleep for that amount of time, and only then trigger the branch in select!.
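To see the collect-until-threshold-or-deadline shape in isolation, here is a small synchronous analogue built on std channels and recv_timeout. It is only an illustration with invented names (collect_acks), not the server's code, and it sidesteps Tokio entirely.

```rust
use std::sync::mpsc;
use std::time::{Duration, Instant};

// Synchronous stand-in for the wait handler loop: count
// acknowledgements until either the threshold or the
// deadline is reached, then report the count collected.
fn collect_acks(rx: mpsc::Receiver<()>, needed: i64, timeout: Duration) -> i64 {
    let deadline = Instant::now() + timeout;
    let mut acks = 0;
    while acks < needed {
        // Zero when the deadline has already passed.
        let remaining = deadline.saturating_duration_since(Instant::now());
        match rx.recv_timeout(remaining) {
            Ok(()) => acks += 1,
            // Deadline expired (or all senders dropped):
            // stop and report whatever we collected.
            Err(_) => break,
        }
    }
    acks
}

fn main() {
    let (tx, rx) = mpsc::channel();
    tx.send(()).unwrap();
    tx.send(()).unwrap();

    // Only two ACKs will ever arrive, so asking for three
    // exits through the deadline and reports two.
    assert_eq!(collect_acks(rx, 3, Duration::from_millis(50)), 2);
}
```

The async version in src/server.rs follows the same logic, with `tokio::select!` playing the role that `recv_timeout` plays here.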

The changes to the command WAIT are minimal. We need to extract the timeout from the command and to pass it to wait_handler.

src/commands/wait.rs
pub async fn command(server: &mut Server, request: &Request, command: &Vec<String>) {

    // The command comes in the form `WAIT numreplicas timeout`.
    let numreplicas: i64 = match command[1].parse() {
        Ok(v) => v,
        Err(_) => {
            request
                .error(ServerError::CommandSyntaxError(command.join(" ")))
                .await;
            return;
        }
    };

    // The timeout is expressed in milliseconds.
    let timeout: u64 = match command[2].parse() {
        Ok(v) => v,
        Err(_) => {
            request
                .error(ServerError::CommandSyntaxError(command.join(" ")))
                .await;
            return;
        }
    };

    ...

    // Spawn the wait handler. It owns the receiver
    // and a clone of the client's reply channel, so
    // it can answer the client directly.
    tokio::spawn(wait_handler(
        timeout,
        numreplicas,
        wait_handler_receiver,
        request.sender.clone(),
    ));

We will not add a test to check that the timer works. Unit tests that depend on real elapsed time are notoriously unreliable. A busy CI machine can delay a spawned task by tens of milliseconds, which is enough to turn a test that expects to see a timeout fire into one that sees the opposite. The safer option for genuinely time-dependent behaviour is to isolate it into integration tests, or to abstract the clock so the test controls virtual time instead of wall-clock time.

In this case we could register one replica, ask WAIT to wait for two ACKs, and feed only one. The handler would have no way to reach the numreplicas threshold, so it could only exit through the timer arm. However, depending on actual time is always dangerous and this would set a precedent. Therefore, we will rely only on the CodeCrafters integration tests to check that the timer logic actually works.

CodeCrafters

Replication Stage 18: WAIT with multiple commands

The code we wrote in this section passes Replication - Stage 18 of the CodeCrafters challenge.

Step 8.11 - Final refactoring#

In this final step, we will go through some refactoring. The changes presented here are mostly cosmetic: they make the code more readable, and they are not driven by a specific architectural change, as was the case when we introduced actors.

Looking at the current state of the REPLCONF code, we can see that the function command contains two completely different behaviours, depending on the nature of the server, master or replica. To simplify it, we can consider splitting the code into two separate functions and use command as a simple router. The following pseudo-code illustrates the idea.

src/commands/replconf.rs
pub async fn command(...) {
    let result = if server.replication.is_replica() {
        command_replica(...)
    } else {
        command_master(...).await
    };

    request.result(result).await;
}

This highlights a missing piece of the puzzle. Currently, we call either request.data or request.error explicitly, depending on whether we are handling a successful result or an error. Behind the scenes, however, both methods of Request call the same function, self.sender.send, adding only the explicit wrapping of a ServerValue into ServerMessage::Data or of a ServerError into ServerMessage::Error.

src/request.rs
impl Request {
    pub async fn error(&self, e: ServerError) {
        self.sender.send(ServerMessage::Error(e)).await.unwrap();
    }

    pub async fn data(&self, d: ServerValue) {
        self.sender.send(ServerMessage::Data(d)).await.unwrap();
    }
}

Fortunately, this can be done automatically by Rust through the powerful trait From. We will therefore first implement an automated way to convert a ServerResult into a ServerMessage, and at that point write the generic sending function Request::result. With that, we will be able to split the logic of replconf::command.

Step 8.11.1 - Convert types automatically#

We already implemented From several times in our project: StorageData can be converted automatically from String, and RESPError can be created automatically from the errors num::ParseIntError and FromUtf8Error. This is the first time we implement it between two types we created, but the idea is the same.

src/server_result.rs
impl From<ServerResult> for ServerMessage {
    fn from(result: ServerResult) -> Self {
        match result {
            Ok(v) => ServerMessage::Data(v),
            Err(e) => ServerMessage::Error(e),
        }
    }
}
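As a self-contained illustration of the pattern, here is the same conversion written with hypothetical stand-in types (Message and CommandResult are invented for the example, playing the roles of ServerMessage and ServerResult):

```rust
// Hypothetical stand-ins for the project's real types.
#[derive(Debug, PartialEq)]
enum Message {
    Data(String),
    Error(String),
}

type CommandResult = Result<String, String>;

// Because `Message` is a local type, we are allowed to
// implement `From` for it even though `Result` is not ours.
impl From<CommandResult> for Message {
    fn from(result: CommandResult) -> Self {
        match result {
            Ok(v) => Message::Data(v),
            Err(e) => Message::Error(e),
        }
    }
}

fn main() {
    let ok: CommandResult = Ok("OK".to_string());
    assert_eq!(Message::from(ok), Message::Data("OK".to_string()));

    // Implementing `From` also provides `Into` for free,
    // so the conversion can be driven from either side.
    let err: CommandResult = Err("syntax".to_string());
    let msg: Message = err.into();
    assert_eq!(msg, Message::Error("syntax".to_string()));
}
```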

At this point we can write Request::result.

src/request.rs
use crate::{
    resp::RESP,
    server_result::{ServerError, ServerMessage, ServerValue},
    server_result::{ServerError, ServerMessage, ServerResult, ServerValue},
};
use tokio::sync::mpsc;

...

impl Request {
    pub async fn error(&self, e: ServerError) {
        self.sender.send(ServerMessage::Error(e)).await.unwrap();
    }

    pub async fn data(&self, d: ServerValue) {
        self.sender.send(ServerMessage::Data(d)).await.unwrap();
    }

    pub async fn result(&self, r: ServerResult) {
        self.sender.send(ServerMessage::from(r)).await.unwrap();
    }
}

Explicit or implicit conversion

The code of the function Request::result calls ServerMessage::from explicitly. The same effect can be obtained by delegating the selection of the function to the compiler using into.

src/request.rs
    pub async fn result(&self, r: ServerResult) {
        self.sender.send(r.into()).await.unwrap();
    }

The choice between the two is a matter of preference, as the code produced by the compiler is exactly the same. The version used here makes the conversion explicit and reads more clearly.

Step 8.11.2 - Split the logic#

We can at this point change replconf::command. We are performing two changes here: use Request::result and split the function in two. First, let's replace .data and .error with .result.

src/commands/replconf.rs
pub async fn command(server: &Server, request: &Request, command: &Vec<String>) {

    if !server.replication.is_replica() {

        ...

        request
            .data(ServerValue::RESP(RESP::SimpleString("OK".to_string())))
            .result(Ok(ServerValue::RESP(RESP::SimpleString("OK".to_string()))))
            .await;
        return;
    }

    // On a replica, the master sends `REPLCONF GETACK *`
    // and expects a `REPLCONF ACK <offset>` reply.
    if command.len() != 3 {
        request
            .error(ServerError::CommandSyntaxError(command.join(" ")))
            .result(Err(ServerError::CommandSyntaxError(command.join(" "))))
            .await;
        return;
    }

    // The only subcommand that REPLCONF supports
    // on a replica is GETACK. Reject anything else.
    if command[1].to_uppercase() != String::from("GETACK") {
        request
            .error(ServerError::CommandSyntaxError(command.join(" ")))
            .result(Err(ServerError::CommandSyntaxError(command.join(" "))))
            .await;
        return;
    }

    // Create the response for the server.
    let resp = RESP::Array(vec![
        RESP::BulkString(String::from("REPLCONF")),
        RESP::BulkString(String::from("ACK")),
        RESP::BulkString(server.replication.repl_offset.to_string()),
    ]);

    request.data(ServerValue::RESP(resp)).await;
    request.result(Ok(ServerValue::RESP(resp))).await;
}

As it is, this seems to introduce more complexity rather than make things more readable. Using request.result we are forced to introduce Ok and Err to wrap the actual results. Let's have a look at the next stage of the change to see how this can help us.

src/commands/replconf.rs
pub async fn command(server: &Server, request: &Request, command: &Vec<String>) {
    if !server.replication.is_replica() { 1
        let result = command_master(server.wait_handler_sender.clone(), command).await;
        request.result(result).await; 2
    }

...

async fn command_master(
    wait_handler_sender: Option<mpsc::Sender<ServerMessage>>,
    command: &Vec<String>,
) -> ServerResult {
    // A replica is reporting its offset in
    // response to a GETACK. Send an empty
    // message to the wait handler and stay
    // silent on the replica's own channel.
    if command.len() >= 2 && command[1].to_uppercase() == String::from("ACK") { 3
        if let Some(tx) = wait_handler_sender {
            // The wait handler might have already
            // dropped its receiver.
            // Ignore the send error.
            let _ = tx.send(ServerMessage::Data(ServerValue::None)).await;
        }

        return Ok(ServerValue::None); 4
    }

    Ok(ServerValue::RESP(RESP::SimpleString("OK".to_string()))) 5
}

fn command_replica(server: &Server, command: &Vec<String>) -> ServerResult {...}

As you can see, we can now isolate the logic that returns Ok and Err from the logic that sends a response to the request. The code in the main function is responsible for choosing the path 1 and dispatching results 2, while the code in command_master and command_replica is responsible for the actual logic 3 and for generating results 4 5.

The full set of changes is the following. The body of the function command has been split into command_master and command_replica, so the diff has been simplified for the sake of readability.

src/commands/replconf.rs
pub async fn command(server: &Server, request: &Request, command: &Vec<String>) {
    // CODE MOVED TO THE FUNCTIONS BELOW
    ...
    let result = if server.replication.is_replica() {
        command_replica(server, command)
    } else {
        command_master(server.wait_handler_sender.clone(), command).await
    };

    request.result(result).await;
}

// On a master, the REPLCONF is received either
// from the handshake process (in which case we
// reply with OK) or as a `REPLCONF ACK <offset>`
// sent by a replica in response to a GETACK we
// previously issued, which shouldn't produce
// any reply.
async fn command_master(
    wait_handler_sender: Option<mpsc::Sender<ServerMessage>>,
    command: &Vec<String>,
) -> ServerResult {
    // A replica is reporting its offset in
    // response to a GETACK. Send an empty
    // message to the wait handler and stay
    // silent on the replica's own channel.
    if command.len() >= 2 && command[1].to_uppercase() == String::from("ACK") {
        if let Some(tx) = wait_handler_sender {
            // The wait handler might have already
            // dropped its receiver.
            // Ignore the send error.
            let _ = tx.send(ServerMessage::Data(ServerValue::None)).await;
        }

        return Ok(ServerValue::None);
    }

    Ok(ServerValue::RESP(RESP::SimpleString("OK".to_string())))
}

// On a replica, the master sends `REPLCONF GETACK *`
// and expects a `REPLCONF ACK <offset>` reply
// reporting how many bytes of the replication stream
// the replica processed.
fn command_replica(server: &Server, command: &Vec<String>) -> ServerResult {
    // On a replica, the master sends `REPLCONF GETACK *`
    // and expects a `REPLCONF ACK <offset>` reply.
    if command.len() != 3 {
        return Err(ServerError::CommandSyntaxError(command.join(" ")));
    }

    // The only subcommand that REPLCONF supports
    // on a replica is GETACK. Reject anything else.
    if command[1].to_uppercase() != String::from("GETACK") {
        return Err(ServerError::CommandSyntaxError(command.join(" ")));
    }

    // Create the response for the server.
    let resp = RESP::Array(vec![
        RESP::BulkString(String::from("REPLCONF")),
        RESP::BulkString(String::from("ACK")),
        RESP::BulkString(server.replication.repl_offset.to_string()),
    ]);

    Ok(ServerValue::RESP(resp))
}

We also need to change two tests, as now we rely on ServerValue::None for empty results.

src/commands/replconf.rs
mod tests {

    async fn test_command_replconf_ack_on_master_forwards_to_wait_handler() {

        assert_eq!(
            wait_rx.try_recv().unwrap(),
            ServerMessage::Data(ServerValue::None)
        );

        // The dispatcher sends ServerValue::None
        // back through request.result(), but the
        // connection handler treats None as a no-op
        // so the client sees nothing.
        assert_eq!(
            connection_receiver.try_recv().unwrap_err(),
            mpsc::error::TryRecvError::Empty
            connection_receiver.try_recv().unwrap(),
            ServerMessage::Data(ServerValue::None)
        );


    async fn test_command_replconf_ack_on_master_no_wait_handler() {

        command(&server, &request, &cmd).await;

        // command_master returns Ok(ServerValue::None),
        // which the dispatcher sends through
        // request.result(). The connection handler
        // treats None as a no-op.
        assert_eq!(
            connection_receiver.try_recv().unwrap_err(),
            mpsc::error::TryRecvError::Empty
            connection_receiver.try_recv().unwrap(),
            ServerMessage::Data(ServerValue::None)
        );

CodeCrafters

Replication Stage 18: WAIT with multiple commands

The code we wrote in this section was a refactoring, so it still passes Replication - Stage 18 of the CodeCrafters challenge and concludes the replication extension.