Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SamratThapa120 committed Dec 7, 2024
1 parent 3de50e0 commit 06dbabb
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/rocker/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def get_docker_args(self, cliargs):
args = ''
shm_size = cliargs.get('shm_size', None)
if shm_size:
args += f' --shm-size={shm_size} '
args += f' --shm-size {shm_size} '
return args

@staticmethod
Expand Down
70 changes: 70 additions & 0 deletions test/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,3 +617,73 @@ def test_group_add_extension(self):
args = p.get_docker_args(mock_cliargs)
self.assertIn('--group-add sudo', args)
self.assertIn('--group-add docker', args)

class ShmSizeExtensionTest(unittest.TestCase):

def setUp(self):
# Work around interference between empy Interpreter
# stdout proxy and test runner. empy installs a proxy on stdout
# to be able to capture the information.
# And the test runner creates a new stdout object for each test.
# This breaks empy as it assumes that the proxy has persistent
# between instances of the Interpreter class
# empy will error with the exception
# "em.Error: interpreter stdout proxy lost"
em.Interpreter._wasProxyInstalled = False

@pytest.mark.docker
def test_shm_size_extension(self):
plugins = list_plugins()
shm_size_plugin = plugins['shm_size']
self.assertEqual(shm_size_plugin.get_name(), 'shm_size')

p = shm_size_plugin()
self.assertTrue(plugin_load_parser_correctly(shm_size_plugin))

mock_cliargs = {}
self.assertEqual(p.get_snippet(mock_cliargs), '')
self.assertEqual(p.get_preamble(mock_cliargs), '')
args = p.get_docker_args(mock_cliargs)
self.assertNotIn('--shm-size', args)

mock_cliargs = {'shm_size': '12g'}
args = p.get_docker_args(mock_cliargs)
self.assertIn('--shm-size 12g', args)

class GpusExtensionTest(unittest.TestCase):

def setUp(self):
# Work around interference between empy Interpreter
# stdout proxy and test runner. empy installs a proxy on stdout
# to be able to capture the information.
# And the test runner creates a new stdout object for each test.
# This breaks empy as it assumes that the proxy has persistent
# between instances of the Interpreter class
# empy will error with the exception
# "em.Error: interpreter stdout proxy lost"
em.Interpreter._wasProxyInstalled = False

@pytest.mark.docker
def test_gpus_extension(self):
plugins = list_plugins()
gpus_plugin = plugins['gpus']
self.assertEqual(gpus_plugin.get_name(), 'gpus')

p = gpus_plugin()
self.assertTrue(plugin_load_parser_correctly(gpus_plugin))

# Test when no GPUs are specified
mock_cliargs = {}
self.assertEqual(p.get_snippet(mock_cliargs), '')
self.assertEqual(p.get_preamble(mock_cliargs), '')
args = p.get_docker_args(mock_cliargs)
self.assertNotIn('--gpus', args)

# Test when GPUs are specified
mock_cliargs = {'gpus': 'all'}
args = p.get_docker_args(mock_cliargs)
self.assertIn('--gpus all', args)

mock_cliargs = {'gpus': '0,1'}
args = p.get_docker_args(mock_cliargs)
self.assertIn('--gpus 0,1', args)

0 comments on commit 06dbabb

Please sign in to comment.