diff --git a/.env.example b/.env.example index 0cde9f3..f259d4a 100644 --- a/.env.example +++ b/.env.example @@ -9,9 +9,4 @@ DATABASE_URL= # so the Replicate API can send it webhooks # # e.g. https://8db01fea81ad.ngrok.io -NGROK_HOST= - - -# Optional: Set a value for this to run the private Replicate deployment of the model -# instead of the public ControlNet Scribble model. -USE_REPLICATE_DEPLOYMENT= \ No newline at end of file +NGROK_HOST= \ No newline at end of file diff --git a/components/welcome.js b/components/welcome.js new file mode 100644 index 0000000..162238b --- /dev/null +++ b/components/welcome.js @@ -0,0 +1,57 @@ +import Link from "next/link"; + +export default function Welcome({ handleTokenSubmit }) { + return ( +
+
+
+

+ + Scribble Diffusion turns your sketch into a refined image using + AI. + +

+
+ +
+

+ To get started, grab your{" "} + + Replicate API token + {" "} + and paste it here: +

+ +
+ + +
+ +
+
+
+
+
+ ); +} diff --git a/pages/api/predictions/[id].js b/pages/api/predictions/[id].js index 9e3163c..35abb43 100644 --- a/pages/api/predictions/[id].js +++ b/pages/api/predictions/[id].js @@ -2,12 +2,14 @@ import { NextResponse } from "next/server"; import Replicate from "replicate"; import packageData from "../../../package.json"; -const replicate = new Replicate({ - auth: process.env.REPLICATE_API_TOKEN, - userAgent: `${packageData.name}/${packageData.version}`, -}); - export default async function handler(req) { + const authHeader = req.headers.get("authorization"); + const replicate_api_token = authHeader.split(" ")[1]; // Assuming a "Bearer" token + + const replicate = new Replicate({ + auth: replicate_api_token, + userAgent: `${packageData.name}/${packageData.version}`, + }); const predictionId = req.nextUrl.searchParams.get("id"); const prediction = await replicate.predictions.get(predictionId); diff --git a/pages/api/predictions/index.js b/pages/api/predictions/index.js index 010105d..c202dea 100644 --- a/pages/api/predictions/index.js +++ b/pages/api/predictions/index.js @@ -2,11 +2,6 @@ import { NextResponse } from "next/server"; import Replicate from "replicate"; import packageData from "../../../package.json"; -const replicate = new Replicate({ - auth: process.env.REPLICATE_API_TOKEN, - userAgent: `${packageData.name}/${packageData.version}`, -}); - async function getObjectFromRequestBodyStream(body) { const input = await body.getReader().read(); const decoder = new TextDecoder(); @@ -19,40 +14,25 @@ const WEBHOOK_HOST = process.env.VERCEL_URL : process.env.NGROK_HOST; export default async function handler(req) { - if (!process.env.REPLICATE_API_TOKEN) { - throw new Error( - "The REPLICATE_API_TOKEN environment variable is not set. See README.md for instructions on how to set it." - ); - } - const input = await getObjectFromRequestBodyStream(req.body); + // Destructure to extract replicate_api_token and keep the rest of the properties in input + const { replicate_api_token, ...restInput } = input; + const replicate = new Replicate({ - auth: process.env.REPLICATE_API_TOKEN, + auth: replicate_api_token, + userAgent: `${packageData.name}/${packageData.version}`, }); - let prediction; - if (process.env.USE_REPLICATE_DEPLOYMENT) { - prediction = await replicate.deployments.predictions.create( - "replicate", - "scribble-diffusion-jagilley-controlnet", - { - input, - webhook: `${WEBHOOK_HOST}/api/replicate-webhook`, - webhook_events_filter: ["start", "completed"], - } - ); - } else { - // https://replicate.com/jagilley/controlnet-scribble/versions - prediction = await replicate.predictions.create({ - version: - "435061a1b5a4c1e26740464bf786efdfa9cb3a3ac488595a2de23e143fdb0117", - input, - webhook: `${WEBHOOK_HOST}/api/replicate-webhook`, - webhook_events_filter: ["start", "completed"], - }); - } + // https://replicate.com/jagilley/controlnet-scribble/versions + const prediction = await replicate.predictions.create({ + version: + "435061a1b5a4c1e26740464bf786efdfa9cb3a3ac488595a2de23e143fdb0117", + input, + webhook: `${WEBHOOK_HOST}/api/replicate-webhook`, + webhook_events_filter: ["start", "completed"], + }); if (prediction?.error) { return NextResponse.json({ detail: prediction.error }, { status: 500 }); diff --git a/pages/index.js b/pages/index.js index b12d439..abca2e8 100644 --- a/pages/index.js +++ b/pages/index.js @@ -1,10 +1,10 @@ import Canvas from "components/canvas"; import PromptForm from "components/prompt-form"; import Head from "next/head"; -import Link from "next/link"; -import { useState } from "react"; +import { useState, useEffect } from "react"; import Predictions from "components/predictions"; import Error from "components/error"; +import Welcome from "components/welcome"; import uploadFile from "lib/upload"; import naughtyWords from "naughty-words"; import Script from "next/script"; @@ -25,6 +25,7 @@ export default function Home() { const [seed] = useState(seeds[Math.floor(Math.random() * seeds.length)]); const [initialPrompt] = useState(seed.prompt); const [scribble, setScribble] = useState(null); + const [welcomeOpen, setWelcomeOpen] = useState(false); const handleSubmit = async (e) => { e.preventDefault(); @@ -46,6 +47,7 @@ export default function Home() { prompt, image: fileUrl, structure: "scribble", + replicate_api_token: localStorage.getItem("replicate_api_token"), }; const response = await fetch("/api/predictions", { @@ -72,7 +74,13 @@ export default function Home() { prediction.status !== "failed" ) { await sleep(500); - const response = await fetch("/api/predictions/" + prediction.id); + const response = await fetch("/api/predictions/" + prediction.id, { + headers: { + Authorization: `Bearer ${localStorage.getItem( + "replicate_api_token" + )}`, + }, + }); prediction = await response.json(); setPredictions((predictions) => ({ ...predictions, @@ -87,6 +95,23 @@ export default function Home() { setIsProcessing(false); }; + const handleTokenSubmit = (e) => { + e.preventDefault(); + console.log(e.target[0].value); + localStorage.setItem("replicate_api_token", e.target[0].value); + setWelcomeOpen(false); + }; + + useEffect(() => { + const replicateApiToken = localStorage.getItem("replicate_api_token"); + + if (replicateApiToken) { + setWelcomeOpen(false); + } else { + setWelcomeOpen(true); + } + }, []); + return ( <> @@ -102,32 +127,36 @@ export default function Home() {
-
-
-

- {pkg.appName} -

-

- {pkg.appSubtitle} -

-
- - - - - - -
+ {welcomeOpen ? ( + + ) : ( +
+
+

+ {pkg.appName} +

+

+ {pkg.appSubtitle} +

+
+ + + + + + +
+ )}