Skip to content

Commit

Permalink
homogenize to use get_name() in all extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
tfoote committed Feb 5, 2024
1 parent 9bedb6a commit f63f89c
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 80 deletions.
41 changes: 3 additions & 38 deletions src/rocker/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ class Devices(RockerExtension):
def get_name():
return 'devices'

def __init__(self):
self.name = Devices.get_name()

def get_preamble(self, cliargs):
return ''

Expand Down Expand Up @@ -68,8 +65,6 @@ def get_name():

def __init__(self):
self._env_subs = None
self.name = DevHelpers.get_name()


def get_environment_subs(self):
if not self._env_subs:
Expand All @@ -80,7 +75,7 @@ def get_preamble(self, cliargs):
return ''

def get_snippet(self, cliargs):
snippet = pkgutil.get_data('rocker', 'templates/%s_snippet.Dockerfile.em' % self.name).decode('utf-8')
snippet = pkgutil.get_data('rocker', 'templates/%s_snippet.Dockerfile.em' % self.get_name()).decode('utf-8')
return em.expand(snippet, self.get_environment_subs())

@staticmethod
Expand All @@ -96,9 +91,6 @@ class Hostname(RockerExtension):
def get_name():
return 'hostname'

def __init__(self):
self.name = Hostname.get_name()

def get_preamble(self, cliargs):
return ''

Expand All @@ -119,9 +111,6 @@ class Name(RockerExtension):
def get_name():
return 'name'

def __init__(self):
self.name = Name.get_name()

def get_preamble(self, cliargs):
return ''

Expand All @@ -143,9 +132,6 @@ class Network(RockerExtension):
def get_name():
return 'network'

def __init__(self):
self.name = Network.get_name()

def get_preamble(self, cliargs):
return ''

Expand All @@ -168,9 +154,6 @@ class Expose(RockerExtension):
def get_name():
return 'expose'

def __init__(self):
self.name = Expose.get_name()

def get_preamble(self, cliargs):
return ''

Expand All @@ -194,9 +177,6 @@ class Port(RockerExtension):
def get_name():
return 'port'

def __init__(self):
self.name = Port.get_name()

def get_preamble(self, cliargs):
return ''

Expand All @@ -222,8 +202,6 @@ def get_name():

def __init__(self):
self._env_subs = None
self.name = PulseAudio.get_name()


def get_environment_subs(self):
if not self._env_subs:
Expand All @@ -237,7 +215,7 @@ def get_preamble(self, cliargs):
return ''

def get_snippet(self, cliargs):
snippet = pkgutil.get_data('rocker', 'templates/%s_snippet.Dockerfile.em' % self.name).decode('utf-8')
snippet = pkgutil.get_data('rocker', 'templates/%s_snippet.Dockerfile.em' % self.get_name()).decode('utf-8')
return em.expand(snippet, self.get_environment_subs())

def get_docker_args(self, cliargs):
Expand All @@ -258,9 +236,6 @@ class HomeDir(RockerExtension):
def get_name():
return 'home'

def __init__(self):
self.name = HomeDir.get_name()

def get_docker_args(self, cliargs):
return ' -v %s:%s ' % (Path.home(), Path.home())

Expand Down Expand Up @@ -288,10 +263,9 @@ def get_environment_subs(self):

def __init__(self):
self._env_subs = None
self.name = User.get_name()

def get_snippet(self, cliargs):
snippet = pkgutil.get_data('rocker', 'templates/%s_snippet.Dockerfile.em' % self.name).decode('utf-8')
snippet = pkgutil.get_data('rocker', 'templates/%s_snippet.Dockerfile.em' % self.get_name()).decode('utf-8')
substitutions = self.get_environment_subs()
if 'user_override_name' in cliargs and cliargs['user_override_name']:
substitutions['name'] = cliargs['user_override_name']
Expand Down Expand Up @@ -355,9 +329,6 @@ class Environment(RockerExtension):
def get_name():
return 'env'

def __init__(self):
self.name = Environment.get_name()

def get_snippet(self, cli_args):
return ''

Expand Down Expand Up @@ -403,9 +374,6 @@ class Privileged(RockerExtension):
def get_name():
return 'privileged'

def __init__(self):
self.name = Privileged.get_name()

def get_snippet(self, cli_args):
return ''

Expand All @@ -426,9 +394,6 @@ class GroupAdd(RockerExtension):
def get_name():
return 'group_add'

def __init__(self):
self.name = GroupAdd.get_name()

def get_preamble(self, cliargs):
return ''

Expand Down
8 changes: 3 additions & 5 deletions src/rocker/git_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@

class Git(RockerExtension):

name = 'git'

@classmethod
def get_name(cls):
return cls.name
@staticmethod
def get_name():
return 'git'


def get_docker_args(self, cli_args):
Expand Down
11 changes: 4 additions & 7 deletions src/rocker/nvidia_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def get_name():
return 'x11'

def __init__(self):
self.name = X11.get_name()
self._env_subs = None
self._xauth = None

Expand Down Expand Up @@ -84,7 +83,6 @@ def get_name():

def __init__(self):
self._env_subs = None
self.name = Nvidia.get_name()
self.supported_distros = ['Ubuntu', 'Debian GNU/Linux']
self.supported_versions = ['16.04', '18.04', '20.04', '10', '22.04']

Expand Down Expand Up @@ -115,11 +113,11 @@ def get_environment_subs(self, cliargs={}):
return self._env_subs

def get_preamble(self, cliargs):
preamble = pkgutil.get_data('rocker', 'templates/%s_preamble.Dockerfile.em' % self.name).decode('utf-8')
preamble = pkgutil.get_data('rocker', 'templates/%s_preamble.Dockerfile.em' % self.get_name()).decode('utf-8')
return em.expand(preamble, self.get_environment_subs(cliargs))

def get_snippet(self, cliargs):
snippet = pkgutil.get_data('rocker', 'templates/%s_snippet.Dockerfile.em' % self.name).decode('utf-8')
snippet = pkgutil.get_data('rocker', 'templates/%s_snippet.Dockerfile.em' % self.get_name()).decode('utf-8')
return em.expand(snippet, self.get_environment_subs(cliargs))

def get_docker_args(self, cliargs):
Expand Down Expand Up @@ -148,7 +146,6 @@ def get_name():

def __init__(self):
self._env_subs = None
self.name = Cuda.get_name()
self.supported_distros = ['Ubuntu', 'Debian GNU/Linux']
self.supported_versions = ['20.04', '22.04', '18.04', '11'] # Debian 11

Expand Down Expand Up @@ -183,11 +180,11 @@ def get_environment_subs(self, cliargs={}):

def get_preamble(self, cliargs):
return ''
# preamble = pkgutil.get_data('rocker', 'templates/%s_preamble.Dockerfile.em' % self.name).decode('utf-8')
# preamble = pkgutil.get_data('rocker', 'templates/%s_preamble.Dockerfile.em' % self.get_name()).decode('utf-8')
# return em.expand(preamble, self.get_environment_subs(cliargs))

def get_snippet(self, cliargs):
snippet = pkgutil.get_data('rocker', 'templates/%s_snippet.Dockerfile.em' % self.name).decode('utf-8')
snippet = pkgutil.get_data('rocker', 'templates/%s_snippet.Dockerfile.em' % self.get_name()).decode('utf-8')
return em.expand(snippet, self.get_environment_subs(cliargs))

def get_docker_args(self, cliargs):
Expand Down
8 changes: 3 additions & 5 deletions src/rocker/ssh_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@

class Ssh(RockerExtension):

name = 'ssh'

@classmethod
def get_name(cls):
return cls.name
@staticmethod
def get_name():
return 'ssh'

def precondition_environment(self, cli_args):
pass
Expand Down
9 changes: 4 additions & 5 deletions src/rocker/volume_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ class Volume(RockerExtension):

ARG_DOCKER_VOLUME = "-v"
ARG_ROCKER_VOLUME = "--volume"
name = 'volume'

@classmethod
def get_name(cls):
return cls.name
@staticmethod
def get_name():
return 'volume'

def get_docker_args(self, cli_args):
"""
Expand All @@ -40,7 +39,7 @@ def get_docker_args(self, cli_args):
args = ['']

# flatten cli_args['volume']
volumes = [ x for sublist in cli_args[self.name] for x in sublist]
volumes = [ x for sublist in cli_args[self.get_name()] for x in sublist]

for volume in volumes:
elems = volume.split(':')
Expand Down
28 changes: 14 additions & 14 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ def test_extension_manager(self):

def test_strict_required_extensions(self):
class Foo(RockerExtension):
@classmethod
def get_name(cls):
@staticmethod
def get_name():
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
@staticmethod
def get_name():
return 'bar'

def required(self, cli_args):
Expand All @@ -163,13 +163,13 @@ def required(self, cli_args):

def test_implicit_required_extensions(self):
class Foo(RockerExtension):
@classmethod
def get_name(cls):
@staticmethod
def get_name():
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
@staticmethod
def get_name():
return 'bar'

def required(self, cli_args):
Expand All @@ -190,13 +190,13 @@ def required(self, cli_args):

def test_extension_sorting(self):
class Foo(RockerExtension):
@classmethod
def get_name(cls):
@staticmethod
def get_name():
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
@staticmethod
def get_name():
return 'bar'

def invoke_after(self, cli_args):
Expand Down Expand Up @@ -254,8 +254,8 @@ class UserSnippet(RockerExtension):
def __init__(self):
self.name = 'usersnippet'

@classmethod
def get_name(cls):
@staticmethod
def get_name():
return 'usersnippet'

def get_snippet(self, cli_args):
Expand Down
9 changes: 3 additions & 6 deletions test/test_file_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,9 @@ def test_name_to_argument(self):
self.assertEqual(name_to_argument('as-df'), '--as-df')

class TestFileInjection(RockerExtension):

name = 'test_file_injection'

@classmethod
def get_name(cls):
return cls.name
@staticmethod
def get_name():
return 'test_file_injection'

def get_files(self, cliargs):
all_files = {}
Expand Down

0 comments on commit f63f89c

Please sign in to comment.