diff --git a/.env.example b/.env.example
index 0cde9f3..f259d4a 100644
--- a/.env.example
+++ b/.env.example
@@ -9,9 +9,4 @@ DATABASE_URL=
# so the Replicate API can send it webhooks
#
# e.g. https://8db01fea81ad.ngrok.io
-NGROK_HOST=
-
-
-# Optional: Set a value for this to run the private Replicate deployment of the model
-# instead of the public ControlNet Scribble model.
-USE_REPLICATE_DEPLOYMENT=
\ No newline at end of file
+NGROK_HOST=
\ No newline at end of file
diff --git a/components/welcome.js b/components/welcome.js
new file mode 100644
index 0000000..162238b
--- /dev/null
+++ b/components/welcome.js
@@ -0,0 +1,57 @@
+import Link from "next/link";
+
+export default function Welcome({ handleTokenSubmit }) {
+ return (
+
+
+
+
+
+
+ To get started, grab your{" "}
+
+ Replicate API token
+ {" "}
+ and paste it here:
+
+
+
+
+
+
+ );
+}
diff --git a/pages/api/predictions/[id].js b/pages/api/predictions/[id].js
index 9e3163c..35abb43 100644
--- a/pages/api/predictions/[id].js
+++ b/pages/api/predictions/[id].js
@@ -2,12 +2,14 @@ import { NextResponse } from "next/server";
import Replicate from "replicate";
import packageData from "../../../package.json";
-const replicate = new Replicate({
- auth: process.env.REPLICATE_API_TOKEN,
- userAgent: `${packageData.name}/${packageData.version}`,
-});
-
export default async function handler(req) {
+ const authHeader = req.headers.get("authorization");
+ const replicate_api_token = authHeader.split(" ")[1]; // Assuming a "Bearer" token
+
+ const replicate = new Replicate({
+ auth: replicate_api_token,
+ userAgent: `${packageData.name}/${packageData.version}`,
+ });
const predictionId = req.nextUrl.searchParams.get("id");
const prediction = await replicate.predictions.get(predictionId);
diff --git a/pages/api/predictions/index.js b/pages/api/predictions/index.js
index 010105d..c202dea 100644
--- a/pages/api/predictions/index.js
+++ b/pages/api/predictions/index.js
@@ -2,11 +2,6 @@ import { NextResponse } from "next/server";
import Replicate from "replicate";
import packageData from "../../../package.json";
-const replicate = new Replicate({
- auth: process.env.REPLICATE_API_TOKEN,
- userAgent: `${packageData.name}/${packageData.version}`,
-});
-
async function getObjectFromRequestBodyStream(body) {
const input = await body.getReader().read();
const decoder = new TextDecoder();
@@ -19,40 +14,25 @@ const WEBHOOK_HOST = process.env.VERCEL_URL
: process.env.NGROK_HOST;
export default async function handler(req) {
- if (!process.env.REPLICATE_API_TOKEN) {
- throw new Error(
- "The REPLICATE_API_TOKEN environment variable is not set. See README.md for instructions on how to set it."
- );
- }
-
const input = await getObjectFromRequestBodyStream(req.body);
+ // Destructure to extract replicate_api_token and keep the rest of the properties in input
+ const { replicate_api_token, ...restInput } = input;
+
const replicate = new Replicate({
- auth: process.env.REPLICATE_API_TOKEN,
+ auth: replicate_api_token,
+ userAgent: `${packageData.name}/${packageData.version}`,
});
- let prediction;
- if (process.env.USE_REPLICATE_DEPLOYMENT) {
- prediction = await replicate.deployments.predictions.create(
- "replicate",
- "scribble-diffusion-jagilley-controlnet",
- {
- input,
- webhook: `${WEBHOOK_HOST}/api/replicate-webhook`,
- webhook_events_filter: ["start", "completed"],
- }
- );
- } else {
- // https://replicate.com/jagilley/controlnet-scribble/versions
- prediction = await replicate.predictions.create({
- version:
- "435061a1b5a4c1e26740464bf786efdfa9cb3a3ac488595a2de23e143fdb0117",
- input,
- webhook: `${WEBHOOK_HOST}/api/replicate-webhook`,
- webhook_events_filter: ["start", "completed"],
- });
- }
+ // https://replicate.com/jagilley/controlnet-scribble/versions
+ const prediction = await replicate.predictions.create({
+ version:
+ "435061a1b5a4c1e26740464bf786efdfa9cb3a3ac488595a2de23e143fdb0117",
+ input,
+ webhook: `${WEBHOOK_HOST}/api/replicate-webhook`,
+ webhook_events_filter: ["start", "completed"],
+ });
if (prediction?.error) {
return NextResponse.json({ detail: prediction.error }, { status: 500 });
diff --git a/pages/index.js b/pages/index.js
index b12d439..abca2e8 100644
--- a/pages/index.js
+++ b/pages/index.js
@@ -1,10 +1,10 @@
import Canvas from "components/canvas";
import PromptForm from "components/prompt-form";
import Head from "next/head";
-import Link from "next/link";
-import { useState } from "react";
+import { useState, useEffect } from "react";
import Predictions from "components/predictions";
import Error from "components/error";
+import Welcome from "components/welcome";
import uploadFile from "lib/upload";
import naughtyWords from "naughty-words";
import Script from "next/script";
@@ -25,6 +25,7 @@ export default function Home() {
const [seed] = useState(seeds[Math.floor(Math.random() * seeds.length)]);
const [initialPrompt] = useState(seed.prompt);
const [scribble, setScribble] = useState(null);
+ const [welcomeOpen, setWelcomeOpen] = useState(false);
const handleSubmit = async (e) => {
e.preventDefault();
@@ -46,6 +47,7 @@ export default function Home() {
prompt,
image: fileUrl,
structure: "scribble",
+ replicate_api_token: localStorage.getItem("replicate_api_token"),
};
const response = await fetch("/api/predictions", {
@@ -72,7 +74,13 @@ export default function Home() {
prediction.status !== "failed"
) {
await sleep(500);
- const response = await fetch("/api/predictions/" + prediction.id);
+ const response = await fetch("/api/predictions/" + prediction.id, {
+ headers: {
+ Authorization: `Bearer ${localStorage.getItem(
+ "replicate_api_token"
+ )}`,
+ },
+ });
prediction = await response.json();
setPredictions((predictions) => ({
...predictions,
@@ -87,6 +95,23 @@ export default function Home() {
setIsProcessing(false);
};
+ const handleTokenSubmit = (e) => {
+ e.preventDefault();
+ console.log(e.target[0].value);
+ localStorage.setItem("replicate_api_token", e.target[0].value);
+ setWelcomeOpen(false);
+ };
+
+ useEffect(() => {
+ const replicateApiToken = localStorage.getItem("replicate_api_token");
+
+ if (replicateApiToken) {
+ setWelcomeOpen(false);
+ } else {
+ setWelcomeOpen(true);
+ }
+ }, []);
+
return (
<>
@@ -102,32 +127,36 @@ export default function Home() {
-
-
-
- {pkg.appName}
-
-
- {pkg.appSubtitle}
-
-
-
-
-
-
-
-
-
+ {welcomeOpen ? (
+
+ ) : (
+
+
+
+ {pkg.appName}
+
+
+ {pkg.appSubtitle}
+
+
+
+
+
+
+
+
+
+ )}