Commit

Merge pull request #102 from ymd-h/Feature_linter

keiohta authored Aug 25, 2020
2 parents 0d226ca + 58fb730 commit 5e188a0

Showing 28 changed files with 48 additions and 49 deletions.
3 changes: 3 additions & 0 deletions .github/linters/.flake8
@@ -0,0 +1,3 @@
[flake8]
ignore = E129,E226,E303,E501,W503,W504,E741
exclude = .git,__pycache__,egg-info,.eggs,prepare_output_dir.py
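The ignored codes are standard pycodestyle/flake8 rules (E501 line too long, E226 missing whitespace around arithmetic operators, E303 too many blank lines, E129/W503/W504 continuation-indent and operator line-break style, E741 ambiguous variable names). As an illustration only, not part of this commit and with hypothetical paths, the same rule set can be exercised locally through flake8's documented legacy Python API:

    # Illustrative local check mirroring .github/linters/.flake8 (path is hypothetical)
    from flake8.api import legacy as flake8

    style_guide = flake8.get_style_guide(
        ignore=["E129", "E226", "E303", "E501", "W503", "W504", "E741"])
    report = style_guide.check_files(["setup.py"])
    print(report.total_errors)  # number of remaining violations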
14 changes: 14 additions & 0 deletions .github/workflows/linter.yml
@@ -0,0 +1,14 @@
name: Lint Code

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: docker://github/super-linter:v3
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          VALIDATE_PYTHON_FLAKE8: true
          PYTHON_FLAKE8_CONFIG_FILE: ".flake8"
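Note (based on my reading of super-linter's documented defaults, not stated in this commit): super-linter looks for linter configuration under .github/linters/ in the repository, so PYTHON_FLAKE8_CONFIG_FILE: ".flake8" resolves to the .github/linters/.flake8 file added above, and setting VALIDATE_PYTHON_FLAKE8 restricts the run to the flake8 check.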
5 changes: 4 additions & 1 deletion examples/run_apex_ddpg.py
@@ -7,12 +7,13 @@

# Prepare env and policy function
class env_fn:
-    def __init__(self,env_name):
+    def __init__(self, env_name):
        self.env_name = env_name

    def __call__(self):
        return gym.make(self.env_name)


def policy_fn(env, name, memory_capacity=int(1e6),
              gpu=-1, noise_level=0.3):
    return DDPG(
@@ -29,6 +30,7 @@ def policy_fn(env, name, memory_capacity=int(1e6),
        critic_units=[400, 300],
        memory_capacity=memory_capacity)


def get_weights_fn(policy):
    # TODO: Check if following needed
    import tensorflow as tf
@@ -37,6 +39,7 @@ def get_weights_fn(policy):
            policy.critic.weights,
            policy.critic_target.weights]


def set_weights_fn(policy, weights):
    actor_weights, critic_weights, critic_target_weights = weights
    update_target_variables(
17 changes: 9 additions & 8 deletions examples/run_apex_dqn.py
@@ -1,5 +1,3 @@
import argparse
import numpy as np
import gym
import tensorflow as tf

@@ -11,15 +9,16 @@

# Prepare env and policy function
class env_fn:
-    def __init__(self,env_name):
+    def __init__(self, env_name):
        self.env_name = env_name

    def __call__(self):
        return gym.make(self.env_name)


class policy_fn:
-    def __init__(self,args,n_warmup,target_replace_interval,batch_size,
-                 optimizer,epsilon_decay_rate,QFunc):
+    def __init__(self, args, n_warmup, target_replace_interval, batch_size,
+                 optimizer, epsilon_decay_rate, QFunc):
        self.args = args
        self.n_warmup = n_warmup
        self.target_replace_interval = target_replace_interval
@@ -28,7 +27,7 @@ def __init__(self,args,n_warmup,target_replace_interval,batch_size,
        self.epsilon_decay_rate = epsilon_decay_rate
        self.QFunc = QFunc

-    def __call__(self,env, name, memory_capacity=int(1e6),
+    def __call__(self, env, name, memory_capacity=int(1e6),
                 gpu=-1, noise_level=0.3):
        return DQN(
            name=name,
@@ -51,10 +50,12 @@ def __call__(self,env, name, memory_capacity=int(1e6),
            q_func=self.QFunc,
            gpu=gpu)


def get_weights_fn(policy):
    return [policy.q_func.weights,
            policy.q_func_target.weights]


def set_weights_fn(policy, weights):
    q_func_weights, qfunc_target_weights = weights
    update_target_variables(
@@ -90,6 +91,6 @@ def set_weights_fn(policy, weights):
        QFunc = None

    run(args, env_fn(env_name),
-        policy_fn(args,n_warmup,target_replace_interval,batch_size,optimizer,
-                  epsilon_decay_rate,QFunc),
+        policy_fn(args, n_warmup, target_replace_interval, batch_size, optimizer,
+                  epsilon_decay_rate, QFunc),
        get_weights_fn, set_weights_fn)
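A side note on the pattern above (my reading, not stated in the diff): env_fn and policy_fn are written as callable classes holding plain attributes rather than closures, presumably so the Ape-X runner can pickle them and hand them to worker processes. A minimal, hypothetical sketch of that property:

    # Hypothetical sketch: a callable-class instance with only picklable state
    # can be serialized for multiprocessing workers, unlike a local closure.
    import pickle
    import gym

    class env_fn:
        def __init__(self, env_name):
            self.env_name = env_name

        def __call__(self):
            return gym.make(self.env_name)

    make_env = env_fn("CartPole-v0")
    blob = pickle.dumps(make_env)   # works: the instance state is just a string
    env = pickle.loads(blob)()      # each worker can rebuild its own env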
2 changes: 0 additions & 2 deletions examples/run_ppo_atari.py
@@ -1,7 +1,5 @@
import gym

import numpy as np
import tensorflow as tf

from tf2rl.algos.ppo import PPO
from tf2rl.envs.atari_wrapper import wrap_dqn
2 changes: 1 addition & 1 deletion setup.py
@@ -1,4 +1,4 @@
-from setuptools import setup, Extension, find_packages
+from setuptools import setup, find_packages

install_requires = [
    "cpprb>=8.1.1",
1 change: 0 additions & 1 deletion tests/algos/__init__.py
@@ -1 +0,0 @@
-from .common import CommonAlgos
6 changes: 4 additions & 2 deletions tests/algos/test_apex.py
@@ -2,8 +2,6 @@

import unittest
import gym
import numpy as np
import tensorflow as tf

from tf2rl.algos.apex import apex_argument, run
from tf2rl.misc.target_update_ops import update_target_variables
@@ -42,6 +40,7 @@ def test_run_continuous(self):
def env_fn_discrete():
    return gym.make("CartPole-v0")


def policy_fn_discrete(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
    from tf2rl.algos.dqn import DQN
    return DQN(
@@ -55,10 +54,12 @@ def policy_fn_discrete(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwa
        discount=0.99,
        gpu=-1)


def get_weights_fn_discrete(policy):
    return [policy.q_func.weights,
            policy.q_func_target.weights]


def set_weights_fn_discrete(policy, weights):
    q_func_weights, qfunc_target_weights = weights
    update_target_variables(
@@ -95,5 +96,6 @@ def set_weights_fn_continuous(policy, weights):
    update_target_variables(
        policy.critic_target.weights, critic_target_weights, tau=1.)


if __name__ == '__main__':
    unittest.main()
2 changes: 0 additions & 2 deletions tests/algos/test_bi_res_ddpg.py
@@ -1,6 +1,4 @@
import unittest
import numpy as np
import tensorflow as tf

from tf2rl.algos.bi_res_ddpg import BiResDDPG
from tests.algos.common import CommonOffPolContinuousAlgos
3 changes: 0 additions & 3 deletions tests/algos/test_ppo.py
@@ -1,7 +1,4 @@
import unittest
import gym
import numpy as np
import tensorflow as tf

from tf2rl.algos.ppo import PPO
from tests.algos.common import CommonOnPolActorCriticContinuousAlgos, CommonOnPolActorCriticDiscreteAlgos
2 changes: 0 additions & 2 deletions tests/algos/test_td3.py
@@ -1,6 +1,4 @@
import unittest
import numpy as np
import tensorflow as tf

from tf2rl.algos.td3 import TD3
from tests.algos.common import CommonOffPolContinuousAlgos
3 changes: 0 additions & 3 deletions tests/algos/test_vpg.py
@@ -1,7 +1,4 @@
import unittest
import gym
import numpy as np
import tensorflow as tf

from tf2rl.algos.vpg import VPG
from tests.algos.common import CommonOnPolActorCriticContinuousAlgos, CommonOnPolActorCriticDiscreteAlgos
2 changes: 1 addition & 1 deletion tests/distributions/test_diagonal_gaussian.py
@@ -10,7 +10,7 @@ class TestDiagonalGaussian(CommonDist):
    def setUpClass(cls):
        super().setUpClass()
        cls.dist = DiagonalGaussian(dim=cls.dim)
-        cls.param = param = {
+        cls.param = {
            "mean": np.zeros(shape=(1, cls.dim), dtype=np.float32),
            "log_std": np.ones(shape=(1, cls.dim), dtype=np.float32)*np.log(1.)}
        cls.params = {
1 change: 1 addition & 0 deletions tests/envs/test_atari_wrapper.py
@@ -7,6 +7,7 @@

from tf2rl.envs.atari_wrapper import wrap_dqn


@unittest.skipIf((platform.system() == 'Windows') and (sys.version_info.minor >= 8),
                 "atari-py doesn't work at Windows with Python3.8 and later")
class TestAtariWrapper(unittest.TestCase):
10 changes: 7 additions & 3 deletions tests/envs/test_multi_thread_env.py
@@ -1,6 +1,5 @@
import unittest
import gym
import numpy as np
import tensorflow as tf

from tf2rl.envs.multi_thread_env import MultiThreadEnv
@@ -12,14 +11,19 @@ def setUpClass(cls):
        cls.batch_size = 64
        cls.thread_pool = 4
        cls.max_episode_steps = 1000
-        def env_fn(): return gym.make("Pendulum-v0")

+        def env_fn():
+            return gym.make("Pendulum-v0")

        cls.continuous_sample_env = env_fn()
        cls.continuous_envs = MultiThreadEnv(
            env_fn=env_fn,
            batch_size=cls.batch_size,
            max_episode_steps=cls.max_episode_steps)

-        def env_fn(): return gym.make("CartPole-v0")
+        def env_fn():
+            return gym.make("CartPole-v0")

        cls.discrete_sample_env = env_fn()
        cls.discrete_envs = MultiThreadEnv(
            env_fn=env_fn,
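For orientation only, a hedged sketch of how MultiThreadEnv is driven by such an env_fn, using just the constructor arguments visible in this test (other parameters and their defaults are not shown here):

    # Sketch based only on the arguments used in the test above; not a full API reference.
    import gym
    from tf2rl.envs.multi_thread_env import MultiThreadEnv

    def env_fn():
        return gym.make("Pendulum-v0")

    envs = MultiThreadEnv(
        env_fn=env_fn,
        batch_size=64,
        max_episode_steps=1000)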
2 changes: 0 additions & 2 deletions tests/misc/test_get_replay_buffer.py
@@ -1,7 +1,5 @@
import unittest

import os
import numpy as np
import gym

from cpprb import ReplayBuffer
1 change: 0 additions & 1 deletion tests/policies/test_categorical_actor.py
@@ -1,6 +1,5 @@
import unittest
import numpy as np
import tensorflow as tf

from tf2rl.policies.categorical_actor import CategoricalActor
from tests.policies.common import CommonModel
1 change: 0 additions & 1 deletion tests/policies/test_gaussian_actor.py
@@ -1,6 +1,5 @@
import unittest
import numpy as np
import tensorflow as tf

from tf2rl.policies.gaussian_actor import GaussianActor
from tests.policies.common import CommonModel
2 changes: 1 addition & 1 deletion tf2rl/algos/apex.py
@@ -3,7 +3,7 @@
import argparse
import logging
import multiprocessing
-from multiprocessing import Process, Queue, Value, Event, Lock
+from multiprocessing import Process, Value, Event
from multiprocessing.managers import SyncManager

from cpprb import ReplayBuffer, PrioritizedReplayBuffer
1 change: 0 additions & 1 deletion tf2rl/algos/policy_base.py
@@ -1,4 +1,3 @@
import numpy as np
import tensorflow as tf

2 changes: 1 addition & 1 deletion tf2rl/algos/td3.py
@@ -2,7 +2,7 @@
import tensorflow as tf
from tensorflow.keras.layers import Dense

-from tf2rl.algos.ddpg import DDPG, Actor
+from tf2rl.algos.ddpg import DDPG
from tf2rl.misc.target_update_ops import update_target_variables
from tf2rl.misc.huber_loss import huber_loss

1 change: 0 additions & 1 deletion tf2rl/distributions/categorical.py
@@ -1,4 +1,3 @@
import numpy as np
import tensorflow as tf

from tf2rl.distributions.base import Distribution
2 changes: 0 additions & 2 deletions tf2rl/envs/utils.py
@@ -21,14 +21,12 @@ def get_act_dim(action_space):


def is_mujoco_env(env):
-    from gym.envs import mujoco
    if not hasattr(env, "env"):
        return False
    return gym.envs.mujoco.mujoco_env.MujocoEnv in env.env.__class__.__bases__


def is_atari_env(env):
-    from gym.envs import atari
    if not hasattr(env, "env"):
        return False
    return gym.envs.atari.atari_env.AtariEnv == env.env.__class__
2 changes: 0 additions & 2 deletions tf2rl/experiments/utils.py
@@ -1,11 +1,9 @@
import os
import random

import numpy as np
import joblib
import matplotlib.pyplot as plt
from matplotlib import animation
import tensorflow as tf


def save_path(samples, filename):
1 change: 0 additions & 1 deletion tf2rl/misc/get_replay_buffer.py
@@ -1,7 +1,6 @@
import numpy as np
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
from gym.spaces.dict import Dict

from cpprb import ReplayBuffer, PrioritizedReplayBuffer

1 change: 0 additions & 1 deletion tf2rl/misc/huber_loss.py
@@ -1,4 +1,3 @@
import numpy as np
import tensorflow as tf

4 changes: 1 addition & 3 deletions tf2rl/networks/spectral_norm_dense.py
@@ -1,6 +1,4 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense
from tensorflow.python.eager import context
from tensorflow.python.framework import common_shapes
@@ -52,7 +50,7 @@ def call(self, inputs):
        rank = common_shapes.rank(inputs)
        if rank > 2:
            # Broadcasting is required for the inputs.
-            outputs = standard_ops.tensordot(inputs, w, [[rank - 1], [0]])
+            outputs = tf.tensordot(inputs, w, [[rank - 1], [0]])
            # Reshape the output back to the original ndim of the input.
            if not context.executing_eagerly():
                shape = inputs.get_shape().as_list()
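For context on the tf.tensordot substitution above, a small standalone sketch (shapes are illustrative, not taken from this file): contracting the last axis of the input with the first axis of the kernel performs the dense multiplication for every leading position.

    # Illustrative shapes only: tensordot over the last input axis acts like a
    # dense layer applied at every leading position.
    import tensorflow as tf

    x = tf.random.normal([2, 5, 3])     # [batch, time, features]
    w = tf.random.normal([3, 4])        # kernel: [features, units]
    y = tf.tensordot(x, w, [[2], [0]])  # contract x's axis 2 with w's axis 0
    print(y.shape)                      # (2, 5, 4)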
4 changes: 1 addition & 3 deletions tf2rl/tools/vae.py
@@ -66,8 +66,6 @@ def log_normal_pdf(sample, mean, logvar, raxis=1):
import PIL
import imageio

from IPython import display

(train_images, _), (test_images, _) = tf.keras.datasets.fashion_mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')
@@ -134,7 +132,7 @@ def log_normal_pdf(sample, mean, logvar, raxis=1):
def generate_and_save_images(model, epoch, test_input):
    predictions = model.sample(test_input)
    plt.close()
-    fig = plt.figure(figsize=(4, 4))
+    plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
