Commit d38723f4 authored by Davis King's avatar Davis King

Fixed a race condition in the BSP code and also simplified the

logic somewhat.
parent cf643ee9
...@@ -97,33 +97,36 @@ namespace dlib ...@@ -97,33 +97,36 @@ namespace dlib
namespace impl2 namespace impl2
{ {
// These control bytes are sent before each message nodes send to each other. // These control bytes are sent before each message between nodes. Note that many
// of these are only sent between the control node (node 0) and the other nodes.
// This is because the controller node is responsible for handling the
// synchronization that needs to happen when all nodes block on calls to
// receive_data()
// at the same time.
// denotes a normal content message. // denotes a normal content message.
const static char MESSAGE_HEADER = 0; const static char MESSAGE_HEADER = 0;
// sent back to sender, means message was returned by receive(). // sent to the controller node when someone receives a message via receive_data().
const static char GOT_MESSAGE = 1; const static char GOT_MESSAGE = 1;
// broadcast when a node goes into a state where it has no outstanding sent // sent to the controller node when someone sends a message via send().
// messages (i.e. it received GOT_MESSAGE for all its sent messages) and is waiting const static char SENT_MESSAGE = 2;
// on receive().
const static char IN_WAITING_STATE = 2;
// broadcast when no longer in IN_WAITING_STATE state. // sent to the controller node when someone enters a call to receive_data()
const static char NOT_IN_WAITING_STATE = 3; const static char IN_WAITING_STATE = 3;
// broadcast when a node terminates itself. // broadcast when a node terminates itself.
const static char NODE_TERMINATE = 4; const static char NODE_TERMINATE = 5;
// broadcast when a node finds out that all non-terminated nodes are in the // broadcast by the controller node when it determines that all nodes are blocked
// IN_WAITING_STATE state. sending this message puts a node into the // on calls to receive_data() and there aren't any messages in flight. This is also
// SEE_ALL_IN_WAITING_STATE where it will wait until it gets this message from all // what makes us go to the next epoch.
// others and then return from receive() once this happens. const static char SEE_ALL_IN_WAITING_STATE = 6;
const static char SEE_ALL_IN_WAITING_STATE = 5;
// This isn't ever transmitted between nodes. It is used internally to indicate
const static char READ_ERROR = 6; // that an error occurred.
const static char READ_ERROR = 7;
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
...@@ -131,7 +134,7 @@ namespace dlib ...@@ -131,7 +134,7 @@ namespace dlib
impl1::bsp_con* con, impl1::bsp_con* con,
unsigned long node_id, unsigned long node_id,
unsigned long sender_id, unsigned long sender_id,
impl1::thread_safe_deque& msg_buffer impl1::thread_safe_message_queue& msg_buffer
) )
{ {
try try
...@@ -145,6 +148,7 @@ namespace dlib ...@@ -145,6 +148,7 @@ namespace dlib
if (msg.msg_type == MESSAGE_HEADER) if (msg.msg_type == MESSAGE_HEADER)
{ {
msg.data.reset(new std::string); msg.data.reset(new std::string);
deserialize(msg.epoch, con->stream);
deserialize(*msg.data, con->stream); deserialize(*msg.data, con->stream);
} }
...@@ -202,6 +206,8 @@ namespace dlib ...@@ -202,6 +206,8 @@ namespace dlib
void bsp_context:: void bsp_context::
close_all_connections_gracefully( close_all_connections_gracefully(
) )
{
if (node_id() != 0)
{ {
_cons.reset(); _cons.reset();
while (_cons.move_next()) while (_cons.move_next())
...@@ -210,6 +216,7 @@ namespace dlib ...@@ -210,6 +216,7 @@ namespace dlib
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
_cons.element().value()->stream.flush(); _cons.element().value()->stream.flush();
} }
}
impl1::msg_data msg; impl1::msg_data msg;
// now wait for all the other nodes to terminate // now wait for all the other nodes to terminate
...@@ -219,12 +226,50 @@ namespace dlib ...@@ -219,12 +226,50 @@ namespace dlib
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
if (msg.msg_type == impl2::NODE_TERMINATE) if (msg.msg_type == impl2::NODE_TERMINATE)
{
++num_terminated_nodes; ++num_terminated_nodes;
_cons[msg.sender_id]->terminated = true;
}
else if (msg.msg_type == impl2::READ_ERROR) else if (msg.msg_type == impl2::READ_ERROR)
{
throw dlib::socket_error(*msg.data); throw dlib::socket_error(*msg.data);
}
else if (msg.msg_type == impl2::MESSAGE_HEADER)
{
throw dlib::socket_error("A BSP node received a message after it has terminated.");
}
else if (msg.msg_type == impl2::GOT_MESSAGE) else if (msg.msg_type == impl2::GOT_MESSAGE)
{
--num_waiting_nodes;
--outstanding_messages; --outstanding_messages;
} }
else if (msg.msg_type == impl2::SENT_MESSAGE)
{
++outstanding_messages;
}
else if (msg.msg_type == impl2::IN_WAITING_STATE)
{
++num_waiting_nodes;
}
if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
{
num_waiting_nodes = 0;
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
++current_epoch;
}
}
if (node_id() == 0)
{
_cons.reset();
while (_cons.move_next())
{
// tell the other end that we are intentionally dropping the connection
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
_cons.element().value()->stream.flush();
}
if (outstanding_messages != 0) if (outstanding_messages != 0)
{ {
...@@ -235,6 +280,7 @@ namespace dlib ...@@ -235,6 +280,7 @@ namespace dlib
throw dlib::socket_error(sout.str()); throw dlib::socket_error(sout.str());
} }
} }
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -263,6 +309,7 @@ namespace dlib ...@@ -263,6 +309,7 @@ namespace dlib
outstanding_messages(0), outstanding_messages(0),
num_waiting_nodes(0), num_waiting_nodes(0),
num_terminated_nodes(0), num_terminated_nodes(0),
current_epoch(1),
_cons(cons_), _cons(cons_),
_node_id(node_id_) _node_id(node_id_)
{ {
...@@ -288,95 +335,73 @@ namespace dlib ...@@ -288,95 +335,73 @@ namespace dlib
unsigned long& sending_node_id unsigned long& sending_node_id
) )
{ {
if (outstanding_messages == 0) notify_control_node(impl2::IN_WAITING_STATE);
broadcast_byte(impl2::IN_WAITING_STATE);
unsigned long num_in_see_all_in_waiting_state = 0;
bool sent_see_all_in_waiting_state = false;
std::stack<impl1::msg_data> buf;
while (true) while (true)
{ {
// if there aren't any nodes left to give us messages then return right now. // If there aren't any nodes left to give us messages then return right now.
if (num_terminated_nodes == _cons.size()) // We need to check the msg_buffer size to make sure there aren't any
// unprocessed message there. Recall that this can happen because status
// messages always jump to the front of the message buffer. So we might have
// learned about the node terminations before processing their messages for us.
if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0)
{
return false; return false;
}
// if all running nodes are currently blocking forever on receive_data() // if all running nodes are currently blocking forever on receive_data()
if (outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size()) if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
{ {
num_waiting_nodes = 0; num_waiting_nodes = 0;
sent_see_all_in_waiting_state = true;
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
// Note that the reason we have this epoch counter is so we can tell if a
// sent message is from before or after one of these "all nodes waiting"
// synchronization events. If we didn't have the epoch count we would have
// a race condition where one node gets the SEE_ALL_IN_WAITING_STATE
// message before others and then sends out a message to another node
// before that node got the SEE_ALL_IN_WAITING_STATE message. Then that
// node would think the normal message came before SEE_ALL_IN_WAITING_STATE
// which would be bad.
++current_epoch;
return false;
} }
impl1::msg_data data; impl1::msg_data data;
if (!msg_buffer.pop(data)) if (!msg_buffer.pop(data, current_epoch))
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
if (sent_see_all_in_waiting_state)
{
// Once we have gotten one SEE_ALL_IN_WAITING_STATE, all we care about is
// getting the rest of them. So the effect of this code is to always move
// any SEE_ALL_IN_WAITING_STATE messages to the front of the message queue.
if (data.msg_type != impl2::SEE_ALL_IN_WAITING_STATE)
{
buf.push(data);
continue;
}
}
switch(data.msg_type) switch(data.msg_type)
{ {
case impl2::MESSAGE_HEADER: { case impl2::MESSAGE_HEADER: {
item = data.data; item = data.data;
sending_node_id = data.sender_id; sending_node_id = data.sender_id;
notify_control_node(impl2::GOT_MESSAGE);
// if we would have send the IN_WAITING_STATE message before getting to
// this point then let other nodes know that we aren't waiting anymore.
if (outstanding_messages == 0)
broadcast_byte(impl2::NOT_IN_WAITING_STATE);
send_byte(impl2::GOT_MESSAGE, data.sender_id);
return true; return true;
} break; } break;
case impl2::IN_WAITING_STATE: { case impl2::IN_WAITING_STATE: {
++num_waiting_nodes; ++num_waiting_nodes;
} break; } break;
case impl2::NOT_IN_WAITING_STATE: { case impl2::GOT_MESSAGE: {
--outstanding_messages;
--num_waiting_nodes; --num_waiting_nodes;
} break; } break;
case impl2::GOT_MESSAGE: { case impl2::SENT_MESSAGE: {
--outstanding_messages; ++outstanding_messages;
if (outstanding_messages == 0)
broadcast_byte(impl2::IN_WAITING_STATE);
} break; } break;
case impl2::NODE_TERMINATE: { case impl2::NODE_TERMINATE: {
++num_terminated_nodes; ++num_terminated_nodes;
_cons[data.sender_id]->terminated = true; _cons[data.sender_id]->terminated = true;
if (num_terminated_nodes == _cons.size())
{
return false;
}
} break; } break;
case impl2::SEE_ALL_IN_WAITING_STATE: { case impl2::SEE_ALL_IN_WAITING_STATE: {
++num_in_see_all_in_waiting_state; ++current_epoch;
if (num_in_see_all_in_waiting_state + num_terminated_nodes == _cons.size())
{
// put stuff from buf back into msg_buffer
while (buf.size() != 0)
{
msg_buffer.push_front(buf.top());
buf.pop();
}
return false; return false;
}
} break; } break;
case impl2::READ_ERROR: { case impl2::READ_ERROR: {
...@@ -393,13 +418,36 @@ namespace dlib ...@@ -393,13 +418,36 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void bsp_context:: void bsp_context::
send_byte ( notify_control_node (
char val, char val
unsigned long target_node_id
) )
{ {
serialize(val, _cons[target_node_id]->stream); if (node_id() == 0)
_cons[target_node_id]->stream.flush(); {
using namespace impl2;
switch(val)
{
case SENT_MESSAGE: {
++outstanding_messages;
} break;
case GOT_MESSAGE: {
--outstanding_messages;
} break;
case IN_WAITING_STATE: {
// nothing to do in this case
} break;
default:
DLIB_CASSERT(false,"This should never happen");
}
}
else
{
serialize(val, _cons[0]->stream);
_cons[0]->stream.flush();
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -415,7 +463,8 @@ namespace dlib ...@@ -415,7 +463,8 @@ namespace dlib
if (i == node_id() || _cons[i]->terminated) if (i == node_id() || _cons[i]->terminated)
continue; continue;
send_byte(val,i); serialize(val, _cons[i]->stream);
_cons[i]->stream.flush();
} }
} }
...@@ -432,10 +481,11 @@ namespace dlib ...@@ -432,10 +481,11 @@ namespace dlib
throw socket_error("Attempt to send a message to a node that has terminated."); throw socket_error("Attempt to send a message to a node that has terminated.");
serialize(MESSAGE_HEADER, _cons[target_node_id]->stream); serialize(MESSAGE_HEADER, _cons[target_node_id]->stream);
serialize(current_epoch, _cons[target_node_id]->stream);
serialize(item, _cons[target_node_id]->stream); serialize(item, _cons[target_node_id]->stream);
_cons[target_node_id]->stream.flush(); _cons[target_node_id]->stream.flush();
++outstanding_messages; notify_control_node(SENT_MESSAGE);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "../serialize.h" #include "../serialize.h"
#include "../map.h" #include "../map.h"
#include "../ref.h" #include "../ref.h"
#include <deque> #include <queue>
#include <vector> #include <vector>
namespace dlib namespace dlib
...@@ -210,15 +210,64 @@ namespace dlib ...@@ -210,15 +210,64 @@ namespace dlib
shared_ptr<std::string> data; shared_ptr<std::string> data;
unsigned long sender_id; unsigned long sender_id;
char msg_type; char msg_type;
dlib::uint64 epoch;
msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {}
}; };
// ------------------------------------------------------------------------------------
class thread_safe_message_queue : noncopyable
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a simple message queue for msg_data objects. Note that it
has the special property that, while messages will generally leave
the queue in the order they are inserted, any message with a smaller
epoch value will always be popped out first. But for all messages
with equal epoch values the queue functions as a normal FIFO queue.
!*/
private:
struct msg_wrap
{
msg_wrap(
const msg_data& data_,
const dlib::uint64& sequence_number_
) : data(data_), sequence_number(sequence_number_) {}
msg_wrap() : sequence_number(0){}
class thread_safe_deque msg_data data;
dlib::uint64 sequence_number;
// Make it so that when msg_wrap objects are in a std::priority_queue,
// messages with a smaller epoch number always come first. Then, within an
// epoch, messages are ordered by their sequence number (so smaller first
// there as well).
bool operator<(const msg_wrap& item) const
{
if (data.epoch < item.data.epoch)
{
return false;
}
else if (data.epoch > item.data.epoch)
{
return true;
}
else
{ {
if (sequence_number < item.sequence_number)
return false;
else
return true;
}
}
};
public: public:
thread_safe_deque() : sig(class_mutex),disabled(false) {} thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {}
~thread_safe_deque() ~thread_safe_message_queue()
{ {
disable(); disable();
} }
...@@ -230,19 +279,16 @@ namespace dlib ...@@ -230,19 +279,16 @@ namespace dlib
sig.broadcast(); sig.broadcast();
} }
unsigned long size() const { return data.size(); } unsigned long size() const
void push_front( const msg_data& item)
{ {
auto_mutex lock(class_mutex); auto_mutex lock(class_mutex);
data.push_front(item); return data.size();
sig.signal();
} }
void push_and_consume( msg_data& item) void push_and_consume( msg_data& item)
{ {
auto_mutex lock(class_mutex); auto_mutex lock(class_mutex);
data.push_back(item); data.push(msg_wrap(item, next_seq_num++));
// do this here so that we don't have to worry about different threads touching the shared_ptr. // do this here so that we don't have to worry about different threads touching the shared_ptr.
item.data.reset(); item.data.reset();
sig.signal(); sig.signal();
...@@ -266,17 +312,43 @@ namespace dlib ...@@ -266,17 +312,43 @@ namespace dlib
if (disabled) if (disabled)
return false; return false;
item = data.front(); item = data.top().data;
data.pop_front(); data.pop();
return true;
}
bool pop (
msg_data& item,
const dlib::uint64& max_epoch
)
/*!
ensures
- if (this function returns true) then
- #item == the next thing from the queue that has an epoch <= max_epoch
- else
- this object is disabled
!*/
{
auto_mutex lock(class_mutex);
while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled)
sig.wait();
if (disabled)
return false;
item = data.top().data;
data.pop();
return true; return true;
} }
private: private:
std::deque<msg_data> data; std::priority_queue<msg_wrap> data;
dlib::mutex class_mutex; dlib::mutex class_mutex;
dlib::signaler sig; dlib::signaler sig;
bool disabled; bool disabled;
dlib::uint64 next_seq_num;
}; };
...@@ -396,9 +468,8 @@ namespace dlib ...@@ -396,9 +468,8 @@ namespace dlib
); );
void send_byte ( void notify_control_node (
char val, char val
unsigned long target_node_id
); );
void broadcast_byte ( void broadcast_byte (
...@@ -423,8 +494,9 @@ namespace dlib ...@@ -423,8 +494,9 @@ namespace dlib
unsigned long outstanding_messages; unsigned long outstanding_messages;
unsigned long num_waiting_nodes; unsigned long num_waiting_nodes;
unsigned long num_terminated_nodes; unsigned long num_terminated_nodes;
dlib::uint64 current_epoch;
impl1::thread_safe_deque msg_buffer; impl1::thread_safe_message_queue msg_buffer;
impl1::map_id_to_con& _cons; impl1::map_id_to_con& _cons;
const unsigned long _node_id; const unsigned long _node_id;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment