Skip to content

Commit

Permalink
Automated formatting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions[bot] committed Aug 15, 2023
1 parent 023c8dc commit e01f515
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 42 deletions.
8 changes: 4 additions & 4 deletions example/Example-Analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,11 @@
plt.show()

# plot RSI with custom arguments
rsi(dis, oversold = 20, overbought = 80)
rsi(dis, oversold=20, overbought=80)
plt.show()

# plot RSI standalone graph
rsi(dis, oversold = 20, overbought = 80, standalone=True)
rsi(dis, oversold=20, overbought=80, standalone=True)
plt.show()

# <codecell>
Expand All @@ -353,9 +353,9 @@
plt.show()

# plot MACD using custom arguments
macd(dis, longer_ema_window = 30, shorter_ema_window = 15, signal_ema_window = 10)
macd(dis, longer_ema_window=30, shorter_ema_window=15, signal_ema_window=10)
plt.show()

# plot MACD standalone graph
macd(standlone = True)
macd(standlone=True)
plt.show()
52 changes: 33 additions & 19 deletions finquant/momentum_indicators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
import matplotlib.pyplot as plt
import pandas as pd

def relative_strength_index(data, window_length: int = 14, oversold: int = 30,
overbought: int = 70, standalone: bool = False) -> None:

""" Computes and visualizes a RSI graph,
def relative_strength_index(
data,
window_length: int = 14,
oversold: int = 30,
overbought: int = 70,
standalone: bool = False,
) -> None:
"""Computes and visualizes a RSI graph,
plotted along with the prices in another sub-graph
for comparison.
Expand Down Expand Up @@ -68,24 +73,33 @@ def relative_strength_index(data, window_length: int = 14, oversold: int = 30,
# Single plot
fig = plt.figure()
ax = fig.add_subplot(111)
ax.axhline(y = oversold, color = 'g', linestyle = '--')
ax.axhline(y = overbought, color = 'r', linestyle ='--')
data['rsi'].plot(ylabel = 'RSI', xlabel = 'Date', ax = ax, grid = True)
ax.axhline(y=oversold, color="g", linestyle="--")
ax.axhline(y=overbought, color="r", linestyle="--")
data["rsi"].plot(ylabel="RSI", xlabel="Date", ax=ax, grid=True)
plt.title("RSI Plot")
plt.legend()
else:
# RSI against price in 2 plots
fig, ax = plt.subplots(2, 1, sharex=True, sharey=False)
ax[0].axhline(y = oversold, color = 'g', linestyle = '--')
ax[0].axhline(y = overbought, color = 'r', linestyle ='--')
ax[0].set_title('RSI + Price Plot')
ax[0].axhline(y=oversold, color="g", linestyle="--")
ax[0].axhline(y=overbought, color="r", linestyle="--")
ax[0].set_title("RSI + Price Plot")
# plot 2 graphs in 2 colors
colors = plt.rcParams["axes.prop_cycle"]()
data['rsi'].plot(ylabel = 'RSI', ax = ax[0], grid = True, color = next(colors)["color"], legend=True)
data[stock].plot(xlabel = 'Date', ylabel = 'Price', ax = ax[1], grid = True,
color = next(colors)["color"], legend = True)
data["rsi"].plot(
ylabel="RSI", ax=ax[0], grid=True, color=next(colors)["color"], legend=True
)
data[stock].plot(
xlabel="Date",
ylabel="Price",
ax=ax[1],
grid=True,
color=next(colors)["color"],
legend=True,
)
plt.legend()



def macd(
data,
longer_ema_window: int = 26,
Expand Down Expand Up @@ -121,8 +135,8 @@ def macd(
if longer_ema_window < shorter_ema_window:
raise ValueError("longer ema window should be > shorter ema window")
if longer_ema_window < signal_ema_window:
raise ValueError("longer ema window should be > signal ema window")
raise ValueError("longer ema window should be > signal ema window")

# converting data to pd.DataFrame if it is a pd.Series (for subsequent function calls):
if isinstance(data, pd.Series):
data = data.to_frame()
Expand Down Expand Up @@ -171,9 +185,9 @@ def macd(

for i, key in enumerate(hist.index):
if hist[key] < 0:
ax.bar(data.index[i], hist[key], color = 'orange')
ax.bar(data.index[i], hist[key], color="orange")
else:
ax.bar(data.index[i], hist[key], color = 'black')
ax.bar(data.index[i], hist[key], color="black")
else:
# RSI against price in 2 plots
fig, ax = plt.subplots(2, 1, sharex=True, sharey=False)
Expand All @@ -197,9 +211,9 @@ def macd(

for i, key in enumerate(hist.index):
if hist[key] < 0:
ax.bar(data.index[i], hist[key], color = 'orange')
ax.bar(data.index[i], hist[key], color="orange")
else:
ax.bar(data.index[i], hist[key], color = 'black')
ax.bar(data.index[i], hist[key], color="black")

data[stock].plot(
xlabel="Date",
Expand Down
40 changes: 21 additions & 19 deletions tests/test_momentum_indicators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import numpy as np
import pandas as pd

from finquant.momentum_indicators import (
relative_strength_index as rsi,
macd,
)
from finquant.momentum_indicators import macd
from finquant.momentum_indicators import relative_strength_index as rsi

plt.switch_backend("Agg")


def test_rsi():
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
Expand All @@ -18,23 +17,24 @@ def test_rsi():
rsi(df)
# get data from axis object
ax = plt.gca()
# ax.lines[0] is the data we passed to plot_bollinger_band
# ax.lines[0] is the data we passed to plot_bollinger_band
line1 = ax.lines[0]
stock_plot = line1.get_xydata()
xlabel_plot = ax.get_xlabel()
ylabel_plot = ax.get_ylabel()
# tests
assert (df['Stock'].index.values == stock_plot[:, 0]).all()
assert (df["Stock"].index.values == stock_plot[:, 0]).all()
assert (df["Stock"].values == stock_plot[:, 1]).all()
assert xlabel_orig == xlabel_plot
assert ylabel_orig == ylabel_plot



def test_rsi_standalone():
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
ylabel_orig = "RSI"
labels_orig = ['rsi']
title_orig = 'RSI Plot'
labels_orig = ["rsi"]
title_orig = "RSI Plot"
df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100))
df.index.name = "Date"
rsi(df, standalone=True)
Expand All @@ -45,20 +45,21 @@ def test_rsi_standalone():
rsi_plot = line1.get_xydata()
xlabel_plot = ax.get_xlabel()
ylabel_plot = ax.get_ylabel()
print (xlabel_plot, ylabel_plot)
print(xlabel_plot, ylabel_plot)
# tests
assert (df['rsi'].index.values == rsi_plot[:, 0]).all()
assert (df["rsi"].index.values == rsi_plot[:, 0]).all()
# for comparing values, we need to remove nan
a, b = df['rsi'].values, rsi_plot[:, 1]
a, b = df["rsi"].values, rsi_plot[:, 1]
a, b = map(lambda x: x[~np.isnan(x)], (a, b))
assert (a == b).all()
labels_plot = ax.get_legend_handles_labels()[1]
title_plot = ax.get_title()
assert labels_plot == labels_orig
assert xlabel_plot == xlabel_orig
assert ylabel_plot == ylabel_orig
assert ylabel_plot == ylabel_orig
assert title_plot == title_orig


def test_macd():
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
Expand All @@ -68,19 +69,20 @@ def test_macd():
macd(df)
# get data from axis object
ax = plt.gca()
# ax.lines[0] is the data we passed to plot_bollinger_band
# ax.lines[0] is the data we passed to plot_bollinger_band
line1 = ax.lines[0]
stock_plot = line1.get_xydata()
xlabel_plot = ax.get_xlabel()
ylabel_plot = ax.get_ylabel()
# tests
assert (df['Stock'].index.values == stock_plot[:, 0]).all()
assert (df["Stock"].index.values == stock_plot[:, 0]).all()
assert (df["Stock"].values == stock_plot[:, 1]).all()
assert xlabel_orig == xlabel_plot
assert ylabel_orig == ylabel_plot



def test_macd_standalone():
labels_orig = ['MACD', 'diff', 'SIGNAL']
labels_orig = ["MACD", "diff", "SIGNAL"]
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
ylabel_orig = "MACD"
Expand All @@ -94,12 +96,12 @@ def test_macd_standalone():
ylabel_plot = ax.get_ylabel()
assert labels_plot == labels_orig
assert xlabel_plot == xlabel_orig
assert ylabel_plot == ylabel_orig
assert ylabel_plot == ylabel_orig
# ax.lines[0] is macd data
# ax.lines[1] is diff data
# ax.lines[2] is macd_s data
# tests
for i, key in ((0, 'macd'), (1, 'diff'), (2, 'macd_s')):
for i, key in ((0, "macd"), (1, "diff"), (2, "macd_s")):
line = ax.lines[i]
data_plot = line.get_xydata()
# tests
Expand Down

0 comments on commit e01f515

Please sign in to comment.