Unverified Commit c56832ed authored by Francisco Massa's avatar Francisco Massa Committed by GitHub

Add support for Python 2 (#11)

* Add missing __init__.py files

* Add packages

* Rename logging.py to logger.py

Import rules from Python2 makes this a bad idea

* Make import_file py2 compatible

* list does not have .copy() in py2

* math.log2 does not exist in py2

* Miscellaneous fixes for py2

* Address comments
parent 8323c118
...@@ -3,7 +3,7 @@ __pycache__ ...@@ -3,7 +3,7 @@ __pycache__
_ext _ext
*.pyc *.pyc
*.so *.so
maskrcnn-benchmark.egg-info/ maskrcnn_benchmark.egg-info/
build/ build/
dist/ dist/
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import bisect import bisect
import copy
import logging import logging
import torch.utils.data import torch.utils.data
...@@ -63,7 +64,8 @@ def make_data_sampler(dataset, shuffle, distributed): ...@@ -63,7 +64,8 @@ def make_data_sampler(dataset, shuffle, distributed):
def _quantize(x, bins): def _quantize(x, bins):
bins = sorted(bins.copy()) bins = copy.copy(bins)
bins = sorted(bins)
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
return quantized return quantized
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
...@@ -57,7 +56,7 @@ class Pooler(nn.Module): ...@@ -57,7 +56,7 @@ class Pooler(nn.Module):
""" """
Arguments: Arguments:
output_size (list[tuple[int]] or list[int]): output size for the pooled region output_size (list[tuple[int]] or list[int]): output size for the pooled region
scales (list[flaot]): scales for each Pooler scales (list[float]): scales for each Pooler
sampling_ratio (int): sampling ratio for ROIAlign sampling_ratio (int): sampling ratio for ROIAlign
""" """
super(Pooler, self).__init__() super(Pooler, self).__init__()
...@@ -72,8 +71,8 @@ class Pooler(nn.Module): ...@@ -72,8 +71,8 @@ class Pooler(nn.Module):
self.output_size = output_size self.output_size = output_size
# get the levels in the feature map by leveraging the fact that the network always # get the levels in the feature map by leveraging the fact that the network always
# downsamples by a factor of 2 at each level. # downsamples by a factor of 2 at each level.
lvl_min = -math.log2(scales[0]) lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
lvl_max = -math.log2(scales[-1]) lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
self.map_levels = LevelMapper(lvl_min, lvl_max) self.map_levels = LevelMapper(lvl_min, lvl_max)
def convert_to_roi_format(self, boxes): def convert_to_roi_format(self, boxes):
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from __future__ import division
import torch import torch
......
...@@ -119,7 +119,10 @@ def _rename_weights_for_resnet(weights, stage_names): ...@@ -119,7 +119,10 @@ def _rename_weights_for_resnet(weights, stage_names):
def _load_c2_pickled_weights(file_path): def _load_c2_pickled_weights(file_path):
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
data = pickle.load(f, encoding="latin1") if torch._six.PY3:
data = pickle.load(f, encoding="latin1")
else:
data = pickle.load(f)
if "blobs" in data: if "blobs" in data:
weights = data["blobs"] weights = data["blobs"]
else: else:
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import importlib import torch
import importlib.util
import sys if torch._six.PY3:
import importlib
import importlib.util
# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa import sys
def import_file(module_name, file_path, make_importable=False):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec) # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
spec.loader.exec_module(module) def import_file(module_name, file_path, make_importable=False):
if make_importable: spec = importlib.util.spec_from_file_location(module_name, file_path)
sys.modules[module_name] = module module = importlib.util.module_from_spec(spec)
return module spec.loader.exec_module(module)
if make_importable:
sys.modules[module_name] = module
return module
else:
import imp
def import_file(module_name, file_path, make_importable=None):
module = imp.load_source(module_name, file_path)
return module
...@@ -62,7 +62,7 @@ setup( ...@@ -62,7 +62,7 @@ setup(
author="fmassa", author="fmassa",
url="https://github.com/facebookresearch/maskrnn-benchmark", url="https://github.com/facebookresearch/maskrnn-benchmark",
description="object detection in pytorch", description="object detection in pytorch",
# packages=find_packages(exclude=("configs", "examples", "test",)), packages=find_packages(exclude=("configs", "tests",)),
# install_requires=requirements, # install_requires=requirements,
ext_modules=get_extensions(), ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
......
...@@ -14,7 +14,7 @@ from maskrcnn_benchmark.modeling.detector import build_detection_model ...@@ -14,7 +14,7 @@ from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize from maskrcnn_benchmark.utils.comm import synchronize
from maskrcnn_benchmark.utils.logging import setup_logger from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir from maskrcnn_benchmark.utils.miscellaneous import mkdir
......
...@@ -22,7 +22,7 @@ from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer ...@@ -22,7 +22,7 @@ from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize from maskrcnn_benchmark.utils.comm import synchronize
from maskrcnn_benchmark.utils.imports import import_file from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logging import setup_logger from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir from maskrcnn_benchmark.utils.miscellaneous import mkdir
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment