Skip to content

Commit

Permalink
Can pass parameters to lua class constructors now
Browse files Browse the repository at this point in the history
  • Loading branch information
hughperkins committed Mar 4, 2016
1 parent 96b1bae commit 7522c32
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 20 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def get_file_datetime(filepath):

setup(
name='PyTorch',
version='2.6.0',
version='2.7.0',
author='Hugh Perkins',
author_email='[email protected]',
description=(
Expand Down
8 changes: 7 additions & 1 deletion simpleexample/luabit.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@ require 'nn'

local Luabit = torch.class('Luabit')

function Luabit:__init()
function Luabit:__init(someName)
print('Luabit:__init(', someName, ')')
self.someName = someName
end

function Luabit:getName()
return self.someName
end

function Luabit:getOut(inTensor, outSize, kernelSize)
Expand Down
3 changes: 2 additions & 1 deletion simpleexample/pybit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
outSize = 3
kernelSize = 3

luabit = Luabit()
luabit = Luabit('green')
print(luabit.getName())

inTensor = np.random.randn(batchSize, numFrames, inSize).astype('float32')
luain = PyTorch.asFloatTensor(inTensor)
Expand Down
32 changes: 17 additions & 15 deletions src/PyTorchAug.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def torchType(lua, pos):
return popString(lua)

class LuaClass(object):
def __init__(self, nameList, *args):
def __init__(self, *args, nameList):
# print('LuaClass.__init__()')
lua = PyTorch.getGlobalState().getLua()
# self.luaclass = luaclass
Expand All @@ -79,6 +79,8 @@ def __init__(self, nameList, *args):
for arg in args:
if isinstance(arg, int):
lua.pushNumber(arg)
elif isinstance(arg, str):
lua.pushString(arg)
else:
raise Exception('arg type ' + str(type(arg)) + ' not implemented')
lua.call(len(args), 1)
Expand Down Expand Up @@ -236,7 +238,7 @@ def __init__(self, _fromLua=False):
# print('Table.__init__')
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__(['nn', name])
super(self.__class__, self).__init__(nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()
self.luaclass = 'table'
Expand All @@ -247,7 +249,7 @@ def __init__(self, numIn=1, numOut=1, _fromLua=False):
self.luaclass = 'nn.Linear'
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__(['nn', name], numIn, numOut)
super(self.__class__, self).__init__(numIn, numOut, nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()

Expand All @@ -256,7 +258,7 @@ def __init__(self, _fromLua=False):
self.luaclass = 'nn.ClassNLLCriterion'
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__(['nn', name])
super(self.__class__, self).__init__(nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()

Expand All @@ -265,7 +267,7 @@ def __init__(self, _fromLua=False):
self.luaclass = 'nn.MSECriterion'
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__(['nn', name])
super(self.__class__, self).__init__(nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()

Expand All @@ -274,7 +276,7 @@ def __init__(self, _fromLua=False):
self.luaclass = 'nn.Sequential'
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__(['nn', name])
super(self.__class__, self).__init__(nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()

Expand All @@ -283,7 +285,7 @@ def __init__(self, _fromLua=False):
self.luaclass = 'nn.LogSoftMax'
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__(['nn', name])
super(self.__class__, self).__init__(nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()

Expand All @@ -293,13 +295,13 @@ def __init__(self, s1, s2=None, s3=None, s4=None, _fromLua=False):
if not _fromLua:
name = self.__class__.__name__
if s4 is not None: # this is a bit hacky, but gets it working for now...
super(self.__class__, self).__init__(['nn', name], s1, s2, s3, s4)
super(self.__class__, self).__init__(s1, s2, s3, s4, nameList=['nn', name])
elif s3 is not None:
super(self.__class__, self).__init__(['nn', name], s1, s2, s3)
super(self.__class__, self).__init__(s1, s2, s3, nameList=['nn', name])
elif s2 is not None:
super(self.__class__, self).__init__(['nn', name], s1, s2)
super(self.__class__, self).__init__(s1, s2, nameList=['nn', name])
else:
super(self.__class__, self).__init__(['nn', name], s1)
super(self.__class__, self).__init__(s1, nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()

Expand All @@ -308,7 +310,7 @@ def __init__(self, nInputPlane, nOutputPlane, kW, kH, dW=1, dH=1, padW=0, padH=0
self.luaclass = 'nn.SpatialConvolutionMM'
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__(['nn', name], nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
super(self.__class__, self).__init__(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()

Expand All @@ -317,7 +319,7 @@ def __init__(self, kW, kH, dW, dH, padW=0, padH=0, _fromLua=False):
self.luaclass = 'nn.SpatialMaxPooling'
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__(['nn', name], kW, kH, dW, dH, padW, padH)
super(self.__class__, self).__init__(kW, kH, dW, dH, padW, padH, nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()

Expand All @@ -326,7 +328,7 @@ def __init__(self, _fromLua=False):
self.luaclass = 'nn.ReLU'
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__(['nn', name])
super(self.__class__, self).__init__(nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()

Expand All @@ -335,7 +337,7 @@ def __init__(self, _fromLua=False):
self.luaclass = 'nn.Tanh'
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__(['nn', name])
super(self.__class__, self).__init__(nameList=['nn', name])
else:
self.__dict__['__objectId'] = getNextObjectId()

Expand Down
7 changes: 5 additions & 2 deletions src/PyTorchHelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ def load_lua_class(lua_filename, lua_classname):
module = lua_filename.replace('.lua', '')
PyTorch.require(module)
class LuaWrapper(PyTorchAug.LuaClass):
def __init__(self, _fromLua=False):
def __init__(self, *args, _fromLua=False):
#print('calling super constructor with', args)
#super(LuaWrapper, self).__init__(*args)
self.luaclass = lua_classname
if not _fromLua:
name = lua_classname
super(self.__class__, self).__init__([name])
super(self.__class__, self).__init__(*args, nameList=[name])
else:
self.__dict__['__objectId'] = getNextObjectId()
# self.__getattr__('__init')(*args)
return LuaWrapper

0 comments on commit 7522c32

Please sign in to comment.