Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Digit Classifier Migration to Superthin Template #2690

Open
wants to merge 8 commits into
base: humble-devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os.path
from typing import Callable

from src.manager.libs.applications.compatibility.exercise_wrapper_ros2 import CompatibilityExerciseWrapperRos2


class Exercise(CompatibilityExerciseWrapperRos2):
def __init__(self, circuit: str, update_callback: Callable):
current_path = os.path.dirname(__file__)

super(Exercise, self).__init__(exercise_command=f"{current_path}/../../python_template/ros2_humble/exercise.py 0.0.0.0",
gui_command=f"{current_path}/../../python_template/ros2_humble/gui.py 0.0.0.0 {circuit}",
update_callback=update_callback)
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import json
import os
import rclpy
import cv2
import sys
import base64
import threading
import time
import numpy as np
from datetime import datetime
import websocket
import subprocess
import logging

from hal_interfaces.general.odometry import OdometryNode
from console_interfaces.general.console import start_console


# Graphical User Interface Class
class GUI:
# Initialization function
# The actual initialization
def __init__(self, host):

self.payload = {'image': '', 'shape': []}

# ROS2 init
if not rclpy.ok():
rclpy.init(args=None)


# Image variables
self.image_to_be_shown = None
self.image_to_be_shown_updated = False
self.image_show_lock = threading.Lock()
self.host = host
self.client = None



self.ack = False
self.ack_lock = threading.Lock()

# Create the lap object
# TODO: maybe move this to HAL and have it be hybrid


self.client_thread = threading.Thread(target=self.run_websocket)
self.client_thread.start()

def run_websocket(self):
while True:
print("GUI WEBSOCKET CONNECTED")
self.client = websocket.WebSocketApp(self.host, on_message=self.on_message)
self.client.run_forever(ping_timeout=None, ping_interval=0)

# Function to prepare image payload
# Encodes the image as a JSON string and sends through the WS
def payloadImage(self):
with self.image_show_lock:
image_to_be_shown_updated = self.image_to_be_shown_updated
image_to_be_shown = self.image_to_be_shown

image = image_to_be_shown
payload = {'image': '', 'shape': ''}

if not image_to_be_shown_updated:
return payload

shape = image.shape
frame = cv2.imencode('.JPEG', image)[1]
encoded_image = base64.b64encode(frame)

payload['image'] = encoded_image.decode('utf-8')
payload['shape'] = shape
with self.image_show_lock:
self.image_to_be_shown_updated = False

return payload

# Function for student to call
def showImage(self, image):
with self.image_show_lock:
self.image_to_be_shown = image
self.image_to_be_shown_updated = True

# Update the gui
def update_gui(self):
# print("GUI update")
# Payload Image Message
payload = self.payloadImage()
self.payload["image"] = json.dumps(payload)


message = json.dumps(self.payload)
if self.client:
try:
self.client.send(message)
# print(message)
except Exception as e:
print(f"Error sending message: {e}")

def on_message(self, ws, message):
"""Handles incoming messages from the websocket client."""
if message.startswith("#ack"):
# print("on message" + str(message))
self.set_acknowledge(True)

def get_acknowledge(self):
"""Gets the acknowledge status."""
with self.ack_lock:
ack = self.ack

return ack

def set_acknowledge(self, value):
"""Sets the acknowledge status."""
with self.ack_lock:
self.ack = value


class ThreadGUI:
"""Class to manage GUI updates and frequency measurements in separate threads."""

def __init__(self, gui):
"""Initializes the ThreadGUI with a reference to the GUI instance."""
self.gui = gui
self.ideal_cycle = 80
self.real_time_factor = 0
self.frequency_message = {'brain': '', 'gui': ''}
self.iteration_counter = 0
self.running = True

def start(self):
"""Starts the GUI, frequency measurement, and real-time factor threads."""
self.frequency_thread = threading.Thread(target=self.measure_and_send_frequency)
self.gui_thread = threading.Thread(target=self.run)
self.frequency_thread.start()
self.gui_thread.start()
print("GUI Thread Started!")

def measure_and_send_frequency(self):
"""Measures and sends the frequency of GUI updates and brain cycles."""
previous_time = datetime.now()
while self.running:
time.sleep(2)

current_time = datetime.now()
dt = current_time - previous_time
ms = (dt.days * 24 * 60 * 60 + dt.seconds) * 1000 + dt.microseconds / 1000.0
previous_time = current_time
measured_cycle = ms / self.iteration_counter if self.iteration_counter > 0 else 0
self.iteration_counter = 0
brain_frequency = round(1000 / measured_cycle, 1) if measured_cycle != 0 else 0
gui_frequency = round(1000 / self.ideal_cycle, 1)
self.frequency_message = {'brain': brain_frequency, 'gui': gui_frequency}
message = json.dumps(self.frequency_message)
if self.gui.client:
try:
self.gui.client.send(message)
except Exception as e:
print(f"Error sending frequency message: {e}")

def run(self):
"""Main loop to update the GUI at regular intervals."""
while self.running:
start_time = datetime.now()

self.gui.update_gui()
self.iteration_counter += 1
finish_time = datetime.now()

dt = finish_time - start_time
ms = (dt.days * 24 * 60 * 60 + dt.seconds) * 1000 + dt.microseconds / 1000.0
sleep_time = max(0, (50 - ms) / 1000.0)
time.sleep(sleep_time)


# Create a GUI interface
host = "ws://127.0.0.1:2303"
gui_interface = GUI(host)

start_console()

# Spin a thread to keep the interface updated
thread_gui = ThreadGUI(gui_interface)
thread_gui.start()

def showImage(image):
gui_interface.showImage(image)

Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
import threading
import cv2

current_frame = None # Global variable to store the frame

class WebcamSubscriber(Node):
MihirGore23 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self):
super().__init__('webcam_subscriber')
self.subscription = self.create_subscription(
Image,
'/image_raw',
self.listener_callback,
10)
self.subscription # prevent unused variable warning
self.bridge = CvBridge()

def listener_callback(self, msg):
global current_frame
self.get_logger().info('Receiving video frame')
current_frame = self.bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')

def run_webcam_node():

webcam_subscriber = WebcamSubscriber()

rclpy.spin(webcam_subscriber)
webcam_subscriber.destroy_node()


# Start the ROS2 node in a separate thread
thread = threading.Thread(target=run_webcam_node)
thread.start()

def getImage():
global current_frame
return current_frame

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[Exercise Documentation Website](https://jderobot.github.io/RoboticsAcademy/exercises/ComputerVision/dl_digit_classifier)
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import GUI
import HAL
import base64
from datetime import datetime
import json
import sys
import time
import cv2
import numpy as np
import onnxruntime

roi_scale = 0.75
input_size = (28, 28)

# Receive model
raw_dl_model = '/workspace/code/demo_model/mnist_cnn.onnx'

# Load ONNX model
try:
ort_session = onnxruntime.InferenceSession(raw_dl_model)
except Exception:
exc_type, exc_value, exc_traceback = sys.exc_info()
print(str(exc_value))
print("ERROR: Model couldn't be loaded")

previous_pred = 0
previous_established_pred = "-"
count_same_digit = 0

while True:

# Get input webcam image
image = HAL.getImage()
if image is not None:
input_image_gray = np.mean(image, axis=2).astype(np.uint8)

# Get original image and ROI dimensions
h_in, w_in = image.shape[:2]
min_dim_in = min(h_in, w_in)
h_roi, w_roi = (int(min_dim_in * roi_scale), int(min_dim_in * roi_scale))
h_border, w_border = (int((h_in - h_roi) / 2.), int((w_in - w_roi) / 2.))

# Extract ROI and convert to tensor format required by the model
roi = input_image_gray[h_border:h_border + h_roi, w_border:w_border + w_roi]
roi_norm = (roi - np.mean(roi)) / np.std(roi)
roi_resized = cv2.resize(roi_norm, input_size)
input_tensor = roi_resized.reshape((1, 1, input_size[0], input_size[1])).astype(np.float32)

# Inference
ort_inputs = {ort_session.get_inputs()[0].name: input_tensor}
output = ort_session.run(None, ort_inputs)[0]
pred = int(np.argmax(output, axis=1)) # get the index of the max log-probability

# Show region used as ROI
cv2.rectangle(image, pt2=(w_border, h_border), pt1=(w_border + w_roi, h_border + h_roi), color=(255, 0, 0), thickness=3)

# Show FPS count
cv2.putText(image, "Pred: {}".format(int(pred)), (7, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

# Send result
GUI.showImage(image)


Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import * as React from "react";
import {Fragment} from "react";

import "./css/DigitClassifierRR.css";

const DigitClassifierRR = (props) => {
return (
<Fragment>
{props.children}
</Fragment>
);
};

export default DigitClassifierRR;
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import * as React from "react";
import { Box } from "@mui/material";
import "./css/GUICanvas.css";
import { drawImage } from "./helpers/showImages";


const DisplayFeed = (props) => {
const [image, setImage] = React.useState(null)
const canvasRef = React.useRef(null)

React.useEffect(() => {
console.log("TestShowScreen subscribing to ['update'] events");
const callback = (message) => {
if(message.data.update.image){
console.log('image')
const image = JSON.parse(message.data.update.image)
if(image.image){
drawImage(message.data.update)
}
}
};

window.RoboticsExerciseComponents.commsManager.subscribe(
[window.RoboticsExerciseComponents.commsManager.events.UPDATE],
callback
);

return () => {
console.log("TestShowScreen unsubscribing from ['state-changed'] events");
window.RoboticsExerciseComponents.commsManager.unsubscribe(
[window.RoboticsExerciseComponents.commsManager.events.UPDATE],
callback
);
};
}, []);

return (
<Box sx={{ height: "100%" }}>
<canvas
ref={canvasRef}
className={"exercise-canvas"}
id="canvas"
></canvas>
</Box>
);
};

DisplayFeed.defaultProps = {
width: 800,
height: 600,
};

export default DisplayFeed
Loading