Commit 5f2a8263
Authored Jan 25, 2019 by wat3rBro, committed by Francisco Massa on Jan 25, 2019
use all_gather to gather results from all gpus (#383)
parent 9b53d15c
Showing 2 changed files with 79 additions and 106 deletions (+79 -106)

maskrcnn_benchmark/engine/inference.py   +2  -2
maskrcnn_benchmark/utils/comm.py         +77 -104
maskrcnn_benchmark/engine/inference.py  View file @ 5f2a8263

@@ -9,7 +9,7 @@ from tqdm import tqdm
 from maskrcnn_benchmark.data.datasets.evaluation import evaluate
 from ..utils.comm import is_main_process
-from ..utils.comm import scatter_gather
+from ..utils.comm import all_gather
 from ..utils.comm import synchronize
...
@@ -30,7 +30,7 @@ def compute_on_dataset(model, data_loader, device):
 def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
-    all_predictions = scatter_gather(predictions_per_gpu)
+    all_predictions = all_gather(predictions_per_gpu)
     if not is_main_process():
         return
     # merge the list of dicts
...
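
Note (not part of the commit): after this change, all_predictions on the main process is a list with one entry per rank, each entry being that rank's predictions; the elided lines then merge them into a single dict. The sketch below illustrates that usage pattern only — it is not the repository's elided code, and the assumption that each rank holds an {image_id: prediction} dict is made up for the example.

    # Illustrative sketch only: gather per-rank prediction dicts and merge on rank 0.
    # Treating predictions_per_gpu as an {image_id: prediction} dict is an assumption.
    from maskrcnn_benchmark.utils.comm import all_gather, is_main_process

    def gather_and_merge(predictions_per_gpu):
        # one list entry per rank after the collective
        all_predictions = all_gather(predictions_per_gpu)
        if not is_main_process():
            return None
        merged = {}
        for per_rank in all_predictions:
            merged.update(per_rank)
        return merged
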
maskrcnn_benchmark/utils/comm.py  View file @ 5f2a8263

 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 """
 This file contains primitives for multi-gpu communication.
 This is useful when doing distributed training.
 """
 
-import os
 import pickle
-import tempfile
 import time
 
 import torch
+import torch.distributed as dist
 
 
 def get_world_size():
-    if not torch.distributed.is_available():
+    if not dist.is_available():
         return 1
-    if not torch.distributed.is_initialized():
+    if not dist.is_initialized():
         return 1
-    return torch.distributed.get_world_size()
+    return dist.get_world_size()
 
 
 def get_rank():
-    if not torch.distributed.is_available():
+    if not dist.is_available():
         return 0
-    if not torch.distributed.is_initialized():
+    if not dist.is_initialized():
         return 0
-    return torch.distributed.get_rank()
+    return dist.get_rank()
 
 
 def is_main_process():
-    if not torch.distributed.is_available():
-        return True
-    if not torch.distributed.is_initialized():
-        return True
-    return torch.distributed.get_rank() == 0
+    return get_rank() == 0
 
 
 def synchronize():
     """
-    Helper function to synchronize between multiple processes when
+    Helper function to synchronize (barrier) among all processes when
     using distributed training
     """
-    if not torch.distributed.is_available():
+    if not dist.is_available():
         return
-    if not torch.distributed.is_initialized():
+    if not dist.is_initialized():
         return
-    world_size = torch.distributed.get_world_size()
-    rank = torch.distributed.get_rank()
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
     if world_size == 1:
         return
...
@@ -55,7 +49,7 @@ def synchronize():
             tensor = torch.tensor(0, device="cuda")
         else:
             tensor = torch.tensor(1, device="cuda")
-        torch.distributed.broadcast(tensor, r)
+        dist.broadcast(tensor, r)
         while tensor.item() == 1:
             time.sleep(1)
...
@@ -64,94 +58,73 @@ def synchronize():
     _send_and_wait(1)
 
 
-def _encode(encoded_data, data):
-    # gets a byte representation for the data
-    encoded_bytes = pickle.dumps(data)
-    # convert this byte string into a byte tensor
-    storage = torch.ByteStorage.from_buffer(encoded_bytes)
-    tensor = torch.ByteTensor(storage).to("cuda")
-    # encoding: first byte is the size and then rest is the data
-    s = tensor.numel()
-    assert s <= 255, "Can't encode data greater than 255 bytes"
-    # put the encoded data in encoded_data
-    encoded_data[0] = s
-    encoded_data[1 : (s + 1)] = tensor
-
-
-def _decode(encoded_data):
-    size = encoded_data[0]
-    encoded_tensor = encoded_data[1 : (size + 1)].to("cpu")
-    return pickle.loads(bytearray(encoded_tensor.tolist()))
-
-
-# TODO try to use tensor in shared-memory instead of serializing to disk
-# this involves getting the all_gather to work
-def scatter_gather(data):
-    """
-    This function gathers data from multiple processes, and returns them
-    in a list, as they were obtained from each process.
-
-    This function is useful for retrieving data from multiple processes,
-    when launching the code with torch.distributed.launch
-
-    Note: this function is slow and should not be used in tight loops, i.e.,
-    do not use it in the training loop.
-
-    Arguments:
-        data: the object to be gathered from multiple processes.
-            It must be serializable
-
-    Returns:
-        result (list): a list with as many elements as there are processes,
-            where each element i in the list corresponds to the data that was
-            gathered from the process of rank i.
-    """
-    # strategy: the main process creates a temporary directory, and communicates
-    # the location of the temporary directory to all other processes.
-    # each process will then serialize the data to the folder defined by
-    # the main process, and then the main process reads all of the serialized
-    # files and returns them in a list
-    if not torch.distributed.is_available():
-        return [data]
-    if not torch.distributed.is_initialized():
-        return [data]
-    synchronize()
-    # get rank of the current process
-    rank = torch.distributed.get_rank()
-
-    # the data to communicate should be small
-    data_to_communicate = torch.empty(256, dtype=torch.uint8, device="cuda")
-    if rank == 0:
-        # manually creates a temporary directory, that needs to be cleaned
-        # afterwards
-        tmp_dir = tempfile.mkdtemp()
-        _encode(data_to_communicate, tmp_dir)
-
-    synchronize()
-    # the main process (rank=0) communicates the data to all processes
-    torch.distributed.broadcast(data_to_communicate, 0)
-
-    # get the data that was communicated
-    tmp_dir = _decode(data_to_communicate)
-
-    # each process serializes to a different file
-    file_template = "file{}.pth"
-    tmp_file = os.path.join(tmp_dir, file_template.format(rank))
-    torch.save(data, tmp_file)
-
-    # synchronize before loading the data
-    synchronize()
-
-    # only the master process returns the data
-    if rank == 0:
-        data_list = []
-        world_size = torch.distributed.get_world_size()
-        for r in range(world_size):
-            file_path = os.path.join(tmp_dir, file_template.format(r))
-            d = torch.load(file_path)
-            data_list.append(d)
-            # cleanup
-            os.remove(file_path)
-        # cleanup
-        os.rmdir(tmp_dir)
-        return data_list
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to("cuda")
+
+    # obtain Tensor size of each rank
+    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
+    size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
+    if local_size != max_size:
+        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that process with rank
+    0 has the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.reduce(values, dst=0)
+        if dist.get_rank() == 0 and average:
+            # only main process gets accumulated, so only divide by
+            # world_size in this case
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
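
Usage note (not part of the commit): all_gather accepts any picklable object and simply returns [data] when only one process is running, while reduce_dict is intended for small dicts of scalar loss tensors, for example to log losses averaged across ranks on the main process. A minimal sketch under those assumptions; the loss names and values below are hypothetical.

    # Minimal usage sketch for the new helpers; loss names and values are made up.
    import torch
    from maskrcnn_benchmark.utils.comm import all_gather, is_main_process, reduce_dict

    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss_dict = {
        "loss_classifier": torch.tensor(0.5, device=device),
        "loss_box_reg": torch.tensor(0.2, device=device),
    }

    # averages each entry across ranks; only rank 0 holds the averaged values
    loss_dict_reduced = reduce_dict(loss_dict, average=True)
    if is_main_process():
        print({k: v.item() for k, v in loss_dict_reduced.items()})

    # gathers arbitrary picklable data; with a single process it returns [data]
    results = all_gather({"rank_results": [1, 2, 3]})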