Commit d38723f4 authored by Davis King

Fixed a race condition in the BSP code and also simplified the logic somewhat.

parent cf643ee9
@@ -97,33 +97,36 @@ namespace dlib
namespace impl2
{
// These control bytes are sent before each message nodes send to each other.
// These control bytes are sent before each message passed between nodes. Note that many
// of these are only sent between the control node (node 0) and the other nodes.
// This is because the controller node is responsible for handling the
// synchronization that needs to happen when all nodes block on calls to
// receive_data()
// at the same time.
// denotes a normal content message.
const static char MESSAGE_HEADER = 0;
// sent back to sender, means message was returned by receive().
// sent to the controller node when someone receives a message via receive_data().
const static char GOT_MESSAGE = 1;
// broadcast when a node goes into a state where it has no outstanding sent
// messages (i.e. it received GOT_MESSAGE for all its sent messages) and is waiting
// on receive().
const static char IN_WAITING_STATE = 2;
// sent to the controller node when someone sends a message via send().
const static char SENT_MESSAGE = 2;
// broadcast when no longer in IN_WAITING_STATE state.
const static char NOT_IN_WAITING_STATE = 3;
// sent to the controller node when someone enters a call to receive_data()
const static char IN_WAITING_STATE = 3;
// broadcast when a node terminates itself.
const static char NODE_TERMINATE = 4;
const static char NODE_TERMINATE = 5;
// broadcast when a node finds out that all non-terminated nodes are in the
// IN_WAITING_STATE state. sending this message puts a node into the
// SEE_ALL_IN_WAITING_STATE where it will wait until it gets this message from all
// others and then return from receive() once this happens.
const static char SEE_ALL_IN_WAITING_STATE = 5;
// broadcast by the controller node when it determines that all nodes are blocked
// on calls to receive_data() and there aren't any messages in flight. This is also
// what makes us go to the next epoch.
const static char SEE_ALL_IN_WAITING_STATE = 6;
const static char READ_ERROR = 6;
// This isn't ever transmitted between nodes. It is used internally to indicate
// that an error occurred.
const static char READ_ERROR = 7;
// ------------------------------------------------------------------------------------
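
The comments above describe a controller-centric protocol: the other nodes report SENT_MESSAGE, GOT_MESSAGE, and IN_WAITING_STATE to node 0, and node 0 alone decides when everyone is idle. As a rough sketch of that decision, assuming nothing beyond this diff (the struct and function below are hypothetical, though the counter names match bsp_context members that appear later):

// Minimal sketch of the bookkeeping node 0 performs on the status bytes above.
// The struct and function are illustrative only; in the commit these counters
// are members of bsp_context and the check is written inline.
#include <iostream>

struct controller_counters
{
    unsigned long outstanding_messages;   // +1 on SENT_MESSAGE, -1 on GOT_MESSAGE
    unsigned long num_waiting_nodes;      // +1 on IN_WAITING_STATE, -1 on GOT_MESSAGE
    unsigned long num_terminated_nodes;   // +1 on NODE_TERMINATE
};

// Node 0 broadcasts SEE_ALL_IN_WAITING_STATE (and advances the epoch) only when
// every non-terminated node is blocked in receive_data() and no regular message
// is still in flight.
inline bool time_for_next_epoch (const controller_counters& c, unsigned long num_connections)
{
    return c.outstanding_messages == 0 &&
           c.num_waiting_nodes + c.num_terminated_nodes == num_connections;
}

int main()
{
    controller_counters c = {0, 3, 1};  // nothing in flight, 3 nodes waiting, 1 terminated
    std::cout << std::boolalpha << time_for_next_epoch(c, 4) << std::endl;  // prints true
}
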
@@ -131,7 +134,7 @@ namespace dlib
impl1::bsp_con* con,
unsigned long node_id,
unsigned long sender_id,
impl1::thread_safe_deque& msg_buffer
impl1::thread_safe_message_queue& msg_buffer
)
{
try
@@ -145,6 +148,7 @@ namespace dlib
if (msg.msg_type == MESSAGE_HEADER)
{
msg.data.reset(new std::string);
deserialize(msg.epoch, con->stream);
deserialize(*msg.data, con->stream);
}
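
The added deserialize of msg.epoch pairs with the change to the send path further down in this diff, which now writes the MESSAGE_HEADER byte, then the sender's current epoch, then the payload. A minimal stand-alone sketch of that wire order, using dlib's serialize/deserialize with a string stream standing in for the real socket stream (variable names are invented):

#include <dlib/serialize.h>
#include <dlib/uintn.h>
#include <iostream>
#include <sstream>
#include <string>

int main()
{
    std::stringstream wire;  // stands in for the connection's stream

    // Sender side: mirrors the order used by the send path in this commit.
    const char MESSAGE_HEADER = 0;
    dlib::uint64 current_epoch = 1;
    std::string payload = "hello from node 2";
    dlib::serialize(MESSAGE_HEADER, wire);
    dlib::serialize(current_epoch, wire);
    dlib::serialize(payload, wire);

    // Receiver side: mirrors the read thread shown above.
    char msg_type;
    dlib::uint64 epoch = 0;
    std::string data;
    dlib::deserialize(msg_type, wire);
    if (msg_type == MESSAGE_HEADER)
    {
        dlib::deserialize(epoch, wire);
        dlib::deserialize(data, wire);
    }
    std::cout << "epoch " << epoch << ": " << data << std::endl;
}
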
@@ -203,12 +207,15 @@ namespace dlib
close_all_connections_gracefully(
)
{
_cons.reset();
while (_cons.move_next())
if (node_id() != 0)
{
// tell the other end that we are intentionally dropping the connection
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
_cons.element().value()->stream.flush();
_cons.reset();
while (_cons.move_next())
{
// tell the other end that we are intentionally dropping the connection
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
_cons.element().value()->stream.flush();
}
}
impl1::msg_data msg;
@@ -219,20 +226,59 @@ namespace dlib
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
if (msg.msg_type == impl2::NODE_TERMINATE)
{
++num_terminated_nodes;
_cons[msg.sender_id]->terminated = true;
}
else if (msg.msg_type == impl2::READ_ERROR)
{
throw dlib::socket_error(*msg.data);
}
else if (msg.msg_type == impl2::MESSAGE_HEADER)
{
throw dlib::socket_error("A BSP node received a message after it has terminated.");
}
else if (msg.msg_type == impl2::GOT_MESSAGE)
{
--num_waiting_nodes;
--outstanding_messages;
}
else if (msg.msg_type == impl2::SENT_MESSAGE)
{
++outstanding_messages;
}
else if (msg.msg_type == impl2::IN_WAITING_STATE)
{
++num_waiting_nodes;
}
if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
{
num_waiting_nodes = 0;
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
++current_epoch;
}
}
if (outstanding_messages != 0)
if (node_id() == 0)
{
std::ostringstream sout;
sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n";
sout << "have a corresponding call to receive().";
throw dlib::socket_error(sout.str());
_cons.reset();
while (_cons.move_next())
{
// tell the other end that we are intentionally dropping the connection
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
_cons.element().value()->stream.flush();
}
if (outstanding_messages != 0)
{
std::ostringstream sout;
sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n";
sout << "have a corresponding call to receive().";
throw dlib::socket_error(sout.str());
}
}
}
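
Both this shutdown loop and the receive_data() loop further down update the same three counters from incoming status bytes. Summarized as a hypothetical helper (the commit keeps this logic inline; the function exists only for illustration):

#include <iostream>

// Control byte values from the top of this diff, repeated so the sketch is self-contained.
const char GOT_MESSAGE      = 1;
const char SENT_MESSAGE     = 2;
const char IN_WAITING_STATE = 3;
const char NODE_TERMINATE   = 5;

// Hypothetical helper summarizing how each status byte changes the counters that
// close_all_connections_gracefully() and receive_data() maintain.
void apply_status_byte (
    char msg_type,
    unsigned long& outstanding_messages,
    unsigned long& num_waiting_nodes,
    unsigned long& num_terminated_nodes
)
{
    if (msg_type == GOT_MESSAGE)           { --outstanding_messages; --num_waiting_nodes; }
    else if (msg_type == SENT_MESSAGE)     { ++outstanding_messages; }
    else if (msg_type == IN_WAITING_STATE) { ++num_waiting_nodes; }
    else if (msg_type == NODE_TERMINATE)   { ++num_terminated_nodes; }
}

int main()
{
    unsigned long outstanding = 0, waiting = 0, terminated = 0;
    apply_status_byte(IN_WAITING_STATE, outstanding, waiting, terminated); // a node blocks in receive_data()
    apply_status_byte(SENT_MESSAGE,     outstanding, waiting, terminated); // some node sends it a message
    apply_status_byte(GOT_MESSAGE,      outstanding, waiting, terminated); // it receives the message and stops waiting
    std::cout << outstanding << " in flight, " << waiting << " waiting\n"; // prints "0 in flight, 0 waiting"
}
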
@@ -263,6 +309,7 @@ namespace dlib
outstanding_messages(0),
num_waiting_nodes(0),
num_terminated_nodes(0),
current_epoch(1),
_cons(cons_),
_node_id(node_id_)
{
@@ -288,95 +335,73 @@ namespace dlib
unsigned long& sending_node_id
)
{
if (outstanding_messages == 0)
broadcast_byte(impl2::IN_WAITING_STATE);
unsigned long num_in_see_all_in_waiting_state = 0;
bool sent_see_all_in_waiting_state = false;
std::stack<impl1::msg_data> buf;
notify_control_node(impl2::IN_WAITING_STATE);
while (true)
{
// if there aren't any nodes left to give us messages then return right now.
if (num_terminated_nodes == _cons.size())
// If there aren't any nodes left to give us messages then return right now.
// We need to check the msg_buffer size to make sure there aren't any
// unprocessed messages there. Recall that this can happen because status
// messages always jump to the front of the message buffer. So we might have
// learned about the node terminations before processing their messages for us.
if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0)
{
return false;
}
// if all running nodes are currently blocking forever on receive_data()
if (outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
{
num_waiting_nodes = 0;
sent_see_all_in_waiting_state = true;
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
// Note that the reason we have this epoch counter is so we can tell if a
// sent message is from before or after one of these "all nodes waiting"
// synchronization events. If we didn't have the epoch count we would have
// a race condition where one node gets the SEE_ALL_IN_WAITING_STATE
// message before others and then sends out a message to another node
// before that node got the SEE_ALL_IN_WAITING_STATE message. Then that
// node would think the normal message came before SEE_ALL_IN_WAITING_STATE
// which would be bad.
++current_epoch;
return false;
}
impl1::msg_data data;
if (!msg_buffer.pop(data))
if (!msg_buffer.pop(data, current_epoch))
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
if (sent_see_all_in_waiting_state)
{
// Once we have gotten one SEE_ALL_IN_WAITING_STATE, all we care about is
// getting the rest of them. So the effect of this code is to always move
// any SEE_ALL_IN_WAITING_STATE messages to the front of the message queue.
if (data.msg_type != impl2::SEE_ALL_IN_WAITING_STATE)
{
buf.push(data);
continue;
}
}
switch(data.msg_type)
{
case impl2::MESSAGE_HEADER: {
item = data.data;
sending_node_id = data.sender_id;
// if we would have send the IN_WAITING_STATE message before getting to
// this point then let other nodes know that we aren't waiting anymore.
if (outstanding_messages == 0)
broadcast_byte(impl2::NOT_IN_WAITING_STATE);
send_byte(impl2::GOT_MESSAGE, data.sender_id);
notify_control_node(impl2::GOT_MESSAGE);
return true;
} break;
case impl2::IN_WAITING_STATE: {
++num_waiting_nodes;
} break;
case impl2::NOT_IN_WAITING_STATE: {
case impl2::GOT_MESSAGE: {
--outstanding_messages;
--num_waiting_nodes;
} break;
case impl2::GOT_MESSAGE: {
--outstanding_messages;
if (outstanding_messages == 0)
broadcast_byte(impl2::IN_WAITING_STATE);
case impl2::SENT_MESSAGE: {
++outstanding_messages;
} break;
case impl2::NODE_TERMINATE: {
++num_terminated_nodes;
_cons[data.sender_id]->terminated = true;
if (num_terminated_nodes == _cons.size())
{
return false;
}
} break;
case impl2::SEE_ALL_IN_WAITING_STATE: {
++num_in_see_all_in_waiting_state;
if (num_in_see_all_in_waiting_state + num_terminated_nodes == _cons.size())
{
// put stuff from buf back into msg_buffer
while (buf.size() != 0)
{
msg_buffer.push_front(buf.top());
buf.pop();
}
return false;
}
++current_epoch;
return false;
} break;
case impl2::READ_ERROR: {
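
The epoch comment in the hunk above is the core of the race fix. A hedged illustration of the scenario it describes, together with the gate that pop(data, current_epoch) enforces (the helper below is hypothetical, not part of the commit):

// The race, without epochs:
//   1. node 0 decides everyone is idle and broadcasts SEE_ALL_IN_WAITING_STATE.
//   2. node 1 reads it, returns from receive_data(), and immediately sends a new
//      message to node 2.
//   3. node 2 has not read SEE_ALL_IN_WAITING_STATE yet, so node 1's message could
//      be handed to node 2's *previous* superstep.
// With epochs, node 1's message is tagged with its new current_epoch (say 2) while
// node 2 is still popping with max_epoch == 1, so the message sits in the priority
// queue until node 2 has also processed SEE_ALL_IN_WAITING_STATE and incremented
// its own epoch.  (Status bytes carry epoch 0, so they are never held back.)

#include <dlib/uintn.h>

// Mirrors the condition thread_safe_message_queue::pop(item, max_epoch) enforces.
inline bool deliverable_now (dlib::uint64 msg_epoch, dlib::uint64 max_epoch)
{
    return msg_epoch <= max_epoch;
}
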
@@ -393,13 +418,36 @@ namespace dlib
// ----------------------------------------------------------------------------------------
void bsp_context::
send_byte (
char val,
unsigned long target_node_id
notify_control_node (
char val
)
{
serialize(val, _cons[target_node_id]->stream);
_cons[target_node_id]->stream.flush();
if (node_id() == 0)
{
using namespace impl2;
switch(val)
{
case SENT_MESSAGE: {
++outstanding_messages;
} break;
case GOT_MESSAGE: {
--outstanding_messages;
} break;
case IN_WAITING_STATE: {
// nothing to do in this case
} break;
default:
DLIB_CASSERT(false,"This should never happen");
}
}
else
{
serialize(val, _cons[0]->stream);
_cons[0]->stream.flush();
}
}
// ----------------------------------------------------------------------------------------
@@ -415,7 +463,8 @@ namespace dlib
if (i == node_id() || _cons[i]->terminated)
continue;
send_byte(val,i);
serialize(val, _cons[i]->stream);
_cons[i]->stream.flush();
}
}
@@ -432,10 +481,11 @@ namespace dlib
throw socket_error("Attempt to send a message to a node that has terminated.");
serialize(MESSAGE_HEADER, _cons[target_node_id]->stream);
serialize(current_epoch, _cons[target_node_id]->stream);
serialize(item, _cons[target_node_id]->stream);
_cons[target_node_id]->stream.flush();
++outstanding_messages;
notify_control_node(SENT_MESSAGE);
}
// ----------------------------------------------------------------------------------------
@@ -12,7 +12,7 @@
#include "../serialize.h"
#include "../map.h"
#include "../ref.h"
#include <deque>
#include <queue>
#include <vector>
namespace dlib
@@ -210,15 +210,64 @@ namespace dlib
shared_ptr<std::string> data;
unsigned long sender_id;
char msg_type;
dlib::uint64 epoch;
msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {}
};
// ------------------------------------------------------------------------------------
class thread_safe_deque
class thread_safe_message_queue : noncopyable
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a simple message queue for msg_data objects. Note that it
has the special property that, while messages will generally leave
the queue in the order they are inserted, any message with a smaller
epoch value will always be popped out first. But for all messages
with equal epoch values the queue functions as a normal FIFO queue.
!*/
private:
struct msg_wrap
{
msg_wrap(
const msg_data& data_,
const dlib::uint64& sequence_number_
) : data(data_), sequence_number(sequence_number_) {}
msg_wrap() : sequence_number(0){}
msg_data data;
dlib::uint64 sequence_number;
// Make it so that when msg_wrap objects are in a std::priority_queue,
// messages with a smaller epoch number always come first. Then, within an
// epoch, messages are ordered by their sequence number (so smaller first
// there as well).
bool operator<(const msg_wrap& item) const
{
if (data.epoch < item.data.epoch)
{
return false;
}
else if (data.epoch > item.data.epoch)
{
return true;
}
else
{
if (sequence_number < item.sequence_number)
return false;
else
return true;
}
}
};
public:
thread_safe_deque() : sig(class_mutex),disabled(false) {}
thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {}
~thread_safe_deque()
~thread_safe_message_queue()
{
disable();
}
@@ -230,19 +279,16 @@ namespace dlib
sig.broadcast();
}
unsigned long size() const { return data.size(); }
void push_front( const msg_data& item)
{
unsigned long size() const
{
auto_mutex lock(class_mutex);
data.push_front(item);
sig.signal();
return data.size();
}
void push_and_consume( msg_data& item)
{
auto_mutex lock(class_mutex);
data.push_back(item);
data.push(msg_wrap(item, next_seq_num++));
// do this here so that we don't have to worry about different threads touching the shared_ptr.
item.data.reset();
sig.signal();
@@ -266,17 +312,43 @@ namespace dlib
if (disabled)
return false;
item = data.front();
data.pop_front();
item = data.top().data;
data.pop();
return true;
}
bool pop (
msg_data& item,
const dlib::uint64& max_epoch
)
/*!
ensures
- if (this function returns true) then
- #item == the next thing from the queue that has an epoch <= max_epoch
- else
- this object is disabled
!*/
{
auto_mutex lock(class_mutex);
while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled)
sig.wait();
if (disabled)
return false;
item = data.top().data;
data.pop();
return true;
}
private:
std::deque<msg_data> data;
std::priority_queue<msg_wrap> data;
dlib::mutex class_mutex;
dlib::signaler sig;
bool disabled;
dlib::uint64 next_seq_num;
};
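
To see the ordering msg_wrap::operator< produces, here is a small self-contained analogue using a plain std::priority_queue without the locking or the pop-by-epoch logic (the item type and field names are invented for the example):

#include <cstdint>
#include <iostream>
#include <queue>
#include <string>

struct item
{
    std::string payload;
    std::uint64_t epoch;
    std::uint64_t seq;   // assigned in push order, like next_seq_num above

    // Same convention as msg_wrap::operator<: because std::priority_queue pops
    // its *greatest* element, a smaller epoch (and, within an epoch, a smaller
    // sequence number) is made to compare as greater so it drains first.
    bool operator< (const item& rhs) const
    {
        if (epoch != rhs.epoch)
            return epoch > rhs.epoch;
        return seq > rhs.seq;
    }
};

int main()
{
    std::priority_queue<item> q;
    q.push({"A (epoch 2)", 2, 1});
    q.push({"B (epoch 1)", 1, 2});
    q.push({"C (epoch 1)", 1, 3});

    // Prints B, C, A: older epochs drain first, FIFO within an epoch.
    while (!q.empty())
    {
        std::cout << q.top().payload << "\n";
        q.pop();
    }
}
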
@@ -396,9 +468,8 @@ namespace dlib
);
void send_byte (
char val,
unsigned long target_node_id
void notify_control_node (
char val
);
void broadcast_byte (
@@ -423,8 +494,9 @@ namespace dlib
unsigned long outstanding_messages;
unsigned long num_waiting_nodes;
unsigned long num_terminated_nodes;
dlib::uint64 current_epoch;
impl1::thread_safe_deque msg_buffer;
impl1::thread_safe_message_queue msg_buffer;
impl1::map_id_to_con& _cons;
const unsigned long _node_id;