Skip to content

Commit

Permalink
Merge pull request replicate#84 from replicate/bring-your-own-token
Browse files Browse the repository at this point in the history
  • Loading branch information
cbh123 authored Mar 8, 2024
2 parents 3107992 + 2d0a01d commit 7b5cfaa
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 73 deletions.
7 changes: 1 addition & 6 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,4 @@ DATABASE_URL=
# so the Replicate API can send it webhooks
#
# e.g. https://8db01fea81ad.ngrok.io
NGROK_HOST=


# Optional: Set a value for this to run the private Replicate deployment of the model
# instead of the public ControlNet Scribble model.
USE_REPLICATE_DEPLOYMENT=
NGROK_HOST=
57 changes: 57 additions & 0 deletions components/welcome.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import Link from "next/link";

export default function Welcome({ handleTokenSubmit }) {
return (
<div className="landing-page container flex-column mx-auto mt-24">
<div className="hero mx-auto">
<div className="hero-text text-center max-w-3xl mx-auto">
<h1 className="text-4xl font-bold mb-4">
<a href="https://replicate.com?utm_source=project&utm_campaign=scribblediffusion">
Scribble Diffusion turns your sketch into a refined image using
AI.
</a>
</h1>
</div>

<div className="mt-12 max-w-xl mx-auto text-center">
<p className="text-base text-gray-500">
To get started, grab your{" "}
<Link
className="underline"
href="https://replicate.com/account/api-tokens?utm_campaign=scribblediffusion-diy&utm_source=project"
target="_blank"
rel="noopener noreferrer"
>
Replicate API token
</Link>{" "}
and paste it here:
</p>

<form onSubmit={handleTokenSubmit}>
<label htmlFor="api-key" className="sr-only">
API token
</label>
<input
type="text"
name="api-key"
id="api-key"
className="block mt-6 w-full p-3 rounded-md border-gray-300 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 text-xl"
placeholder="r8_..."
minLength="40"
maxLength="40"
required
/>
<div className="mt-5 sm:mt-6 sm:gap-3">
<button
type="submit"
className="inline-flex w-full justify-center rounded-md bg-black p-3 text-xl text-sm font-semibold text-white shadow-sm hover:bg-gray-700 focus-visible:outline focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-gray-600 sm:col-start-2"
>
Start scribbling &rarr;
</button>
</div>
</form>
</div>
</div>
</div>
);
}
12 changes: 7 additions & 5 deletions pages/api/predictions/[id].js
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ import { NextResponse } from "next/server";
import Replicate from "replicate";
import packageData from "../../../package.json";

const replicate = new Replicate({
auth: process.env.REPLICATE_API_TOKEN,
userAgent: `${packageData.name}/${packageData.version}`,
});

export default async function handler(req) {
const authHeader = req.headers.get("authorization");
const replicate_api_token = authHeader.split(" ")[1]; // Assuming a "Bearer" token

const replicate = new Replicate({
auth: replicate_api_token,
userAgent: `${packageData.name}/${packageData.version}`,
});
const predictionId = req.nextUrl.searchParams.get("id");
const prediction = await replicate.predictions.get(predictionId);

Expand Down
46 changes: 13 additions & 33 deletions pages/api/predictions/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@ import { NextResponse } from "next/server";
import Replicate from "replicate";
import packageData from "../../../package.json";

const replicate = new Replicate({
auth: process.env.REPLICATE_API_TOKEN,
userAgent: `${packageData.name}/${packageData.version}`,
});

async function getObjectFromRequestBodyStream(body) {
const input = await body.getReader().read();
const decoder = new TextDecoder();
Expand All @@ -19,40 +14,25 @@ const WEBHOOK_HOST = process.env.VERCEL_URL
: process.env.NGROK_HOST;

export default async function handler(req) {
if (!process.env.REPLICATE_API_TOKEN) {
throw new Error(
"The REPLICATE_API_TOKEN environment variable is not set. See README.md for instructions on how to set it."
);
}

const input = await getObjectFromRequestBodyStream(req.body);

// Destructure to extract replicate_api_token and keep the rest of the properties in input
const { replicate_api_token, ...restInput } = input;

const replicate = new Replicate({
auth: process.env.REPLICATE_API_TOKEN,
auth: replicate_api_token,
userAgent: `${packageData.name}/${packageData.version}`,
});

let prediction;

if (process.env.USE_REPLICATE_DEPLOYMENT) {
prediction = await replicate.deployments.predictions.create(
"replicate",
"scribble-diffusion-jagilley-controlnet",
{
input,
webhook: `${WEBHOOK_HOST}/api/replicate-webhook`,
webhook_events_filter: ["start", "completed"],
}
);
} else {
// https://replicate.com/jagilley/controlnet-scribble/versions
prediction = await replicate.predictions.create({
version:
"435061a1b5a4c1e26740464bf786efdfa9cb3a3ac488595a2de23e143fdb0117",
input,
webhook: `${WEBHOOK_HOST}/api/replicate-webhook`,
webhook_events_filter: ["start", "completed"],
});
}
// https://replicate.com/jagilley/controlnet-scribble/versions
const prediction = await replicate.predictions.create({
version:
"435061a1b5a4c1e26740464bf786efdfa9cb3a3ac488595a2de23e143fdb0117",
input,
webhook: `${WEBHOOK_HOST}/api/replicate-webhook`,
webhook_events_filter: ["start", "completed"],
});

if (prediction?.error) {
return NextResponse.json({ detail: prediction.error }, { status: 500 });
Expand Down
87 changes: 58 additions & 29 deletions pages/index.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import Canvas from "components/canvas";
import PromptForm from "components/prompt-form";
import Head from "next/head";
import Link from "next/link";
import { useState } from "react";
import { useState, useEffect } from "react";
import Predictions from "components/predictions";
import Error from "components/error";
import Welcome from "components/welcome";
import uploadFile from "lib/upload";
import naughtyWords from "naughty-words";
import Script from "next/script";
Expand All @@ -25,6 +25,7 @@ export default function Home() {
const [seed] = useState(seeds[Math.floor(Math.random() * seeds.length)]);
const [initialPrompt] = useState(seed.prompt);
const [scribble, setScribble] = useState(null);
const [welcomeOpen, setWelcomeOpen] = useState(false);

const handleSubmit = async (e) => {
e.preventDefault();
Expand All @@ -46,6 +47,7 @@ export default function Home() {
prompt,
image: fileUrl,
structure: "scribble",
replicate_api_token: localStorage.getItem("replicate_api_token"),
};

const response = await fetch("/api/predictions", {
Expand All @@ -72,7 +74,13 @@ export default function Home() {
prediction.status !== "failed"
) {
await sleep(500);
const response = await fetch("/api/predictions/" + prediction.id);
const response = await fetch("/api/predictions/" + prediction.id, {
headers: {
Authorization: `Bearer ${localStorage.getItem(
"replicate_api_token"
)}`,
},
});
prediction = await response.json();
setPredictions((predictions) => ({
...predictions,
Expand All @@ -87,6 +95,23 @@ export default function Home() {
setIsProcessing(false);
};

const handleTokenSubmit = (e) => {
e.preventDefault();
console.log(e.target[0].value);
localStorage.setItem("replicate_api_token", e.target[0].value);
setWelcomeOpen(false);
};

useEffect(() => {
const replicateApiToken = localStorage.getItem("replicate_api_token");

if (replicateApiToken) {
setWelcomeOpen(false);
} else {
setWelcomeOpen(true);
}
}, []);

return (
<>
<Head>
Expand All @@ -102,32 +127,36 @@ export default function Home() {
<link rel="icon" href="/favicon.svg" type="image/svg+xml" />
</Head>
<main className="container max-w-[1024px] mx-auto p-5 ">
<div className="container max-w-[512px] mx-auto">
<hgroup>
<h1 className="text-center text-5xl font-bold m-4">
{pkg.appName}
</h1>
<p className="text-center text-xl opacity-60 m-4">
{pkg.appSubtitle}
</p>
</hgroup>

<Canvas
startingPaths={seed.paths}
onScribble={setScribble}
scribbleExists={scribbleExists}
setScribbleExists={setScribbleExists}
/>

<PromptForm
initialPrompt={initialPrompt}
onSubmit={handleSubmit}
isProcessing={isProcessing}
scribbleExists={scribbleExists}
/>

<Error error={error} />
</div>
{welcomeOpen ? (
<Welcome handleTokenSubmit={handleTokenSubmit} />
) : (
<div className="container max-w-[512px] mx-auto">
<hgroup>
<h1 className="text-center text-5xl font-bold m-4">
{pkg.appName}
</h1>
<p className="text-center text-xl opacity-60 m-4">
{pkg.appSubtitle}
</p>
</hgroup>

<Canvas
startingPaths={seed.paths}
onScribble={setScribble}
scribbleExists={scribbleExists}
setScribbleExists={setScribbleExists}
/>

<PromptForm
initialPrompt={initialPrompt}
onSubmit={handleSubmit}
isProcessing={isProcessing}
scribbleExists={scribbleExists}
/>

<Error error={error} />
</div>
)}

<Predictions
predictions={predictions}
Expand Down

0 comments on commit 7b5cfaa

Please sign in to comment.