import distutils
import imp
import os
import shutil
import subprocess
import sys
from distutils.core import Extension
from distutils.dist import Distribution
from distutils.sysconfig import customize_compiler
from os.path import abspath, dirname, exists, join, getmtime
import numpy as np
from Cython.Build import cythonize
from Cython.Distutils.old_build_ext import old_build_ext as build_ext
from mujoco_py.utils import discover_mujoco
def load_cython_ext(mjpro_path):
"""
Loads the cymj Cython extension. This is safe to be called from
multiple processes running on the same machine.
Cython only gives us back the raw path, regardless of whether
it found a cached version or actually compiled. Since we do
non-idempotent postprocessing of the DLL, be extra careful
to only do that once and then atomically move to the final
location.
"""
if ('glfw' in sys.modules and
'mujoco' in abspath(sys.modules["glfw"].__file__)):
print('''
WARNING: Existing glfw python module detected!
MuJoCo comes with its own version of GLFW, so it's preferable to use that one.
The easy solution is to `import mujoco_py` _before_ `import glfw`.
''')
if sys.platform == 'darwin':
Builder = MacExtensionBuilder
elif sys.platform == 'linux':
if exists('/usr/local/nvidia/lib64'):
Builder = LinuxGPUExtensionBuilder
else:
Builder = LinuxCPUExtensionBuilder
elif sys.platform.startswith("win"):
Builder = WindowsExtensionBuilder
else:
raise RuntimeError("Unsupported platform %s" % sys.platform)
builder = Builder(mjpro_path)
cext_so_path = builder.build()
mod = imp.load_dynamic("cymj", cext_so_path)
return mod
class custom_build_ext(build_ext):
"""
Custom build_ext to suppress the "-Wstrict-prototypes" warning.
It arises from the fact that we're using C++. This seems to be
the cleanest way to get rid of the extra flag.
See http://stackoverflow.com/a/36293331/248400
"""
def build_extensions(self):
customize_compiler(self.compiler)
try:
self.compiler.compiler_so.remove("-Wstrict-prototypes")
except (AttributeError, ValueError):
pass
build_ext.build_extensions(self)
def fix_shared_library(so_file, name, library_path):
ldd_output = subprocess.check_output(
['ldd', so_file]).decode('utf-8')
if name in ldd_output:
subprocess.check_call(['patchelf',
'--remove-needed',
name,
so_file])
subprocess.check_call(
['patchelf', '--add-needed',
library_path,
so_file])
class MujocoExtensionBuilder():
CYMJ_DIR_PATH = abspath(dirname(__file__))
def __init__(self, mjpro_path):
self.mjpro_path = mjpro_path
self.extension = Extension(
'mujoco_py.cymj',
sources=[join(self.CYMJ_DIR_PATH, "cymj.pyx")],
include_dirs=[
self.CYMJ_DIR_PATH,
join(mjpro_path, 'include'),
np.get_include(),
],
libraries=['mujoco150'],
library_dirs=[join(mjpro_path, 'bin')],
extra_compile_args=[
'-fopenmp', # needed for OpenMP
'-w', # suppress numpy compilation warnings
],
extra_link_args=['-fopenmp'],
language='c')
def build(self):
dist = Distribution({
"script_name": None,
"script_args": ["build_ext"]
})
dist.ext_modules = cythonize([self.extension])
dist.include_dirs = []
dist.cmdclass = {'build_ext': custom_build_ext}
build = dist.get_command_obj('build')
# following the convention of cython's pyxbuild and naming
# base directory "_pyxbld"
build.build_base = join(self.CYMJ_DIR_PATH, 'generated',
'_pyxbld_%s' % self.__class__.__name__)
dist.parse_command_line()
obj_build_ext = dist.get_command_obj("build_ext")
dist.run_commands()
so_file_path, = obj_build_ext.get_outputs()
return so_file_path
class WindowsExtensionBuilder(MujocoExtensionBuilder):
def __init__(self, mjpro_path):
super().__init__(mjpro_path)
os.environ["PATH"] += ";" + join(mjpro_path, "bin")
self.extension.sources.append(self.CYMJ_DIR_PATH + "/gl/dummyshim.c")
class LinuxCPUExtensionBuilder(MujocoExtensionBuilder):
def __init__(self, mjpro_path):
super().__init__(mjpro_path)
self.extension.sources.append(
join(self.CYMJ_DIR_PATH, "gl", "osmesashim.c"))
self.extension.libraries.extend(['glewosmesa', 'OSMesa'])
self.extension.runtime_library_dirs = [join(mjpro_path, 'bin')]
class LinuxGPUExtensionBuilder(MujocoExtensionBuilder):
def __init__(self, mjpro_path):
super().__init__(mjpro_path)
self.extension.sources.append(self.CYMJ_DIR_PATH + "/gl/eglshim.c")
self.extension.include_dirs.append(self.CYMJ_DIR_PATH + '/vendor/egl')
self.extension.libraries.extend(['glewegl'])
self.extension.runtime_library_dirs = [join(mjpro_path, 'bin')]
def build(self):
so_file_path = super().build()
nvidia_lib_dir = '/usr/local/nvidia/lib64/'
fix_shared_library(so_file_path, 'libOpenGL.so',
join(nvidia_lib_dir, 'libOpenGL.so.0'))
fix_shared_library(so_file_path, 'libEGL.so',
join(nvidia_lib_dir, 'libEGL.so.1'))
return so_file_path
class MacExtensionBuilder(MujocoExtensionBuilder):
def __init__(self, mjpro_path):
super().__init__(mjpro_path)
self.extension.sources.append(self.CYMJ_DIR_PATH + "/gl/dummyshim.c")
self.extension.libraries.extend(['glfw.3'])
self.extension.define_macros = [('ONMAC', None)]
self.extension.runtime_library_dirs = [join(mjpro_path, 'bin')]
def build(self):
# Prefer GCC 6 for now since GCC 7 may behave differently.
c_compilers = ['/usr/local/bin/gcc-6', '/usr/local/bin/gcc-7']
available_c_compiler = None
for c_compiler in c_compilers:
if distutils.spawn.find_executable(c_compiler) is not None:
available_c_compiler = c_compiler
break
if available_c_compiler is None:
raise RuntimeError(
'Could not find GCC 6 or GCC 7 executable.\n\n'
'HINT: On OS X, install GCC 6 with '
'`brew install gcc --without-multilib`.')
os.environ['CC'] = available_c_compiler
so_file_path = super().build()
del os.environ['CC']
return self.manually_link_libraries(so_file_path)
def manually_link_libraries(self, raw_cext_dll_path):
root, ext = os.path.splitext(raw_cext_dll_path)
final_cext_dll_path = root + '_final' + ext
# If someone else already built the final DLL, don't bother
# recreating it here, even though this should still be idempotent.
if (exists(final_cext_dll_path) and
getmtime(final_cext_dll_path) >= getmtime(raw_cext_dll_path)):
return final_cext_dll_path
tmp_final_cext_dll_path = final_cext_dll_path + '~'
shutil.copyfile(raw_cext_dll_path, tmp_final_cext_dll_path)
mj_bin_path = join(self.mjpro_path, 'bin')
# Fix the rpath of the generated library -- i lost the Stackoverflow
# reference here
from_mujoco_path = '@executable_path/libmujoco150.dylib'
to_mujoco_path = '%s/libmujoco150.dylib' % mj_bin_path
subprocess.check_call(['install_name_tool',
'-change',
from_mujoco_path,
to_mujoco_path,
tmp_final_cext_dll_path])
from_glfw_path = 'libglfw.3.dylib'
to_glfw_path = os.path.join(mj_bin_path, 'libglfw.3.dylib')
subprocess.check_call(['install_name_tool',
'-change',
from_glfw_path,
to_glfw_path,
tmp_final_cext_dll_path])
os.rename(tmp_final_cext_dll_path, final_cext_dll_path)
return final_cext_dll_path
class MujocoException(Exception):
pass
def user_warning_raise_exception(warn_bytes):
'''
User-defined warning callback, which is called by mujoco on warnings.
Here we have two primary jobs:
- Detect known warnings and suggest fixes (with code)
- Decide whether to raise an Exception and raise if needed
More cases should be added as we find new failures.
'''
# TODO: look through test output to see MuJoCo warnings to catch
# and recommend. Also fix those tests
warn = warn_bytes.decode() # Convert bytes to string
if 'Pre-allocated constraint buffer is full' in warn:
raise MujocoException(warn + 'Increase njmax in mujoco XML')
if 'Pre-allocated contact buffer is full' in warn:
raise MujocoException(warn + 'Increase njconmax in mujoco XML')
raise MujocoException('Got MuJoCo Warning: {}'.format(warn))
def user_warning_ignore_exception(warn_bytes):
pass
[docs]class ignore_mujoco_warnings:
"""
Class to turn off mujoco warning exceptions within a scope. Useful for
large, vectorized rollouts.
"""
def __enter__(self):
self.prev_user_warning = cymj.get_warning_callback()
cymj.set_warning_callback(user_warning_ignore_exception)
return self
def __exit__(self, type, value, traceback):
cymj.set_warning_callback(self.prev_user_warning)
mjpro_path, key_path = discover_mujoco()
cymj = load_cython_ext(mjpro_path)
# Trick to expose all mj* functions from mujoco in mujoco_py.*
class dict2(object):
pass
functions = dict2()
for func_name in dir(cymj):
if func_name.startswith("_mj"):
setattr(functions, func_name[1:], getattr(cymj, func_name))
functions.mj_activate(key_path)
# Set user-defined callbacks that raise assertion with message
cymj.set_warning_callback(user_warning_raise_exception)