Commit ed78e8b1 authored by Davis King's avatar Davis King

Made this not use stdin's file descriptor for data transfer between the

processes since sometimes stdin was closed in some environments.  Also cleaned
up the code a little bit.
parent d40e34cf
...@@ -401,10 +401,17 @@ namespace dlib ...@@ -401,10 +401,17 @@ namespace dlib
namespace impl namespace impl
{ {
std::ostream& get_data_ostream() int get_data_fd()
{ {
static filestreambuf dbuff(STDIN_FILENO, -1); char* env_fd = getenv("DLIB_SUBPROCESS_DATA_FD");
static ostream out(&dbuff); DLIB_CASSERT(env_fd != 0,"");
return atoi(env_fd);
}
std::iostream& get_data_iostream()
{
static filestreambuf dbuff(get_data_fd(), -1);
static iostream out(&dbuff);
return out; return out;
} }
} }
...@@ -426,10 +433,8 @@ namespace dlib ...@@ -426,10 +433,8 @@ namespace dlib
if (child_pid == 0) if (child_pid == 0)
{ {
// In child process // In child process
dup2(data_pipe.child_fd(), STDIN_FILENO);
dup2(stdout_pipe.child_fd(), STDOUT_FILENO); dup2(stdout_pipe.child_fd(), STDOUT_FILENO);
dup2(stderr_pipe.child_fd(), STDERR_FILENO); dup2(stderr_pipe.child_fd(), STDERR_FILENO);
data_pipe.close();
stdout_pipe.close(); stdout_pipe.close();
stderr_pipe.close(); stderr_pipe.close();
...@@ -437,13 +442,20 @@ namespace dlib ...@@ -437,13 +442,20 @@ namespace dlib
char* cudadevs = getenv("CUDA_VISIBLE_DEVICES"); char* cudadevs = getenv("CUDA_VISIBLE_DEVICES");
if (cudadevs) if (cudadevs)
{ {
std::string extra = std::string("CUDA_VISIBLE_DEVICES=") + cudadevs; std::ostringstream sout;
char* envp[] = {(char*)extra.c_str(), nullptr}; sout << "DLIB_SUBPROCESS_DATA_FD="<<data_pipe.child_fd();
std::string extra = sout.str();
std::string extra2 = std::string("CUDA_VISIBLE_DEVICES=") + cudadevs;
char* envp[] = {(char*)extra.c_str(), (char*)extra2.c_str(), nullptr};
execve(argv[0], argv, envp); execve(argv[0], argv, envp);
} }
else else
{ {
char* envp[] = {nullptr}; std::ostringstream sout;
sout << "DLIB_SUBPROCESS_DATA_FD="<<data_pipe.child_fd();
std::string extra = sout.str();
char* envp[] = {(char*)extra.c_str(), nullptr};
execve(argv[0], argv, envp); execve(argv[0], argv, envp);
} }
......
...@@ -52,9 +52,9 @@ namespace dlib ...@@ -52,9 +52,9 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace impl{ std::ostream& get_data_ostream(); } namespace impl{ std::iostream& get_data_iostream(); }
inline void send_to_parent_process() {impl::get_data_ostream().flush();} inline void send_to_parent_process() {impl::get_data_iostream().flush();}
template <typename U, typename ...T> template <typename U, typename ...T>
void send_to_parent_process(U&& arg1, T&& ...args) void send_to_parent_process(U&& arg1, T&& ...args)
/*! /*!
...@@ -63,9 +63,9 @@ namespace dlib ...@@ -63,9 +63,9 @@ namespace dlib
serializing them with interprocess_serialize(). serializing them with interprocess_serialize().
!*/ !*/
{ {
interprocess_serialize(arg1, impl::get_data_ostream()); interprocess_serialize(arg1, impl::get_data_iostream());
send_to_parent_process(std::forward<T>(args)...); send_to_parent_process(std::forward<T>(args)...);
if (!impl::get_data_ostream()) if (!impl::get_data_iostream())
throw dlib::error("Error sending object to parent process."); throw dlib::error("Error sending object to parent process.");
} }
...@@ -74,14 +74,13 @@ namespace dlib ...@@ -74,14 +74,13 @@ namespace dlib
void receive_from_parent_process(U&& arg1, T&& ...args) void receive_from_parent_process(U&& arg1, T&& ...args)
/*! /*!
ensures ensures
- receives all the arguments to receive_from_parent_process() from standard - receives all the arguments to receive_from_parent_process() from the parent
input (and hence from the parent process) by deserializing them with process by deserializing them from interprocess_serialize().
interprocess_deserialize().
!*/ !*/
{ {
interprocess_deserialize(arg1, std::cin); interprocess_deserialize(arg1, impl::get_data_iostream());
receive_from_parent_process(std::forward<T>(args)...); receive_from_parent_process(std::forward<T>(args)...);
if (!std::cin) if (!impl::get_data_iostream())
throw dlib::error("Error receiving object from parent process."); throw dlib::error("Error receiving object from parent process.");
} }
...@@ -90,12 +89,12 @@ namespace dlib ...@@ -90,12 +89,12 @@ namespace dlib
class filestreambuf; class filestreambuf;
class subprocess_stream class subprocess_stream : noncopyable
{ {
/*! /*!
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This is a tool for spawning a subprocess and communicating with it through This is a tool for spawning a subprocess and communicating with it. Here
its standard input, output, and error. Here is an example: is an example:
subprocess_stream s("/usr/bin/some_program"); subprocess_stream s("/usr/bin/some_program");
s.send(obj1, obj2, obj3); s.send(obj1, obj2, obj3);
...@@ -111,11 +110,9 @@ namespace dlib ...@@ -111,11 +110,9 @@ namespace dlib
Additionally, if the sub process writes to its standard out then that will Additionally, if the sub process writes to its standard out then that will
be echoed to std::cout in the parent process. Also, the communication of be echoed to std::cout in the parent process. Writing to std::cerr or
send()/receive() calls between the parent and child happens all on the returning a non-zero value from main will also be noted by the parent
standard input file descriptor. So you can't really use std::cin for process and an appropriate exception will be thrown.
anything inside the child process as that would interfere with
receive_from_parent_process() and send_to_parent_process().
!*/ !*/
public: public:
...@@ -140,8 +137,8 @@ namespace dlib ...@@ -140,8 +137,8 @@ namespace dlib
); );
/*! /*!
ensures ensures
- closes the standard input of the child process and then waits for the - closes the input stream to the child process and then waits for the child
child to terminate. to terminate.
- If the child returns an error (by returning != 0 from its main) or - If the child returns an error (by returning != 0 from its main) or
outputs to its standard error then wait() throws a dlib::error() with the outputs to its standard error then wait() throws a dlib::error() with the
standard error output in it. standard error output in it.
...@@ -196,7 +193,7 @@ namespace dlib ...@@ -196,7 +193,7 @@ namespace dlib
void send_eof(); void send_eof();
class cpipe class cpipe : noncopyable
{ {
private: private:
int fd[2]; int fd[2];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment