Skip to content

Commit

Permalink
Merge pull request #69 from davidmytton/arcjet
Browse files Browse the repository at this point in the history
Added Arcjet security (Shield, rate limit, bot detection)
  • Loading branch information
davidmytton authored Jun 4, 2024
2 parents 0f659c1 + d202c38 commit 47bdac8
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 8 deletions.
3 changes: 3 additions & 0 deletions .env.local.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ NEXT_PUBLIC_CLERK_SIGN_UP_URL=/sign-up
NEXT_PUBLIC_CLERK_AFTER_SIGN_IN_URL=/
NEXT_PUBLIC_CLERK_AFTER_SIGN_UP_URL=/

# Arcjet related environment variables
ARCJET_KEY=ajkey_****

# OpenAI related environment variables
OPENAI_API_KEY=sk-****

Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- Text Model: [OpenAI](https://platform.openai.com/docs/models)
- Text streaming: [ai sdk](https://github.com/vercel-labs/ai)
- Deployment: [Fly](https://fly.io/)
- Security: [Arcjet](https://arcjet.com/)

## Overview
- 🚀 [Quickstart](#quickstart)
Expand Down Expand Up @@ -76,6 +77,10 @@ e. **Supabase API key**
- `SUPABASE_PRIVATE_KEY` is the key starts with `ey` under Project API Keys
- Now, you should enable pgvector on Supabase and create a schema. You can do this easily by clicking on "SQL editor" on the left hand side on supabase UI and then clicking on "+New Query". Copy paste [this code snippet](https://github.com/a16z-infra/ai-getting-started/blob/main/pgvector.sql) in the SQL editor and click "Run".

f. **Arcjet key**

Visit https://app.arcjet.com to sign up for free and get your Arcjet key.

### 4. Generate embeddings

There are a few markdown files under `/blogs` directory as examples so you can do Q&A on them. To generate embeddings and store them in the vector database for future queries, you can run the following command:
Expand Down
117 changes: 117 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"generate-embeddings-supabase": "node src/scripts/indexBlogPGVector.mjs"
},
"dependencies": {
"@arcjet/next": "^1.0.0-alpha.13",
"@clerk/nextjs": "^4.21.9-snapshot.56dc3e",
"@headlessui/react": "^1.7.15",
"@pinecone-database/pinecone": "^0.1.6",
Expand Down
71 changes: 70 additions & 1 deletion src/app/api/qa-pg-vector/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,79 @@ import dotenv from "dotenv";
import { VectorDBQAChain } from "langchain/chains";
import { StreamingTextResponse, LangChainStream } from "ai";
import { CallbackManager } from "langchain/callbacks";
import { currentUser } from "@clerk/nextjs";
import arcjet, { shield, fixedWindow, detectBot } from "@arcjet/next";
import { NextResponse } from "next/server";

dotenv.config({ path: `.env.local` });

// The arcjet instance is created outside of the handler
const aj = arcjet({
key: process.env.ARCJET_KEY!, // Get your site key from https://app.arcjet.com
rules: [
// Arcjet Shield protects against common attacks e.g. SQL injection
shield({
mode: "LIVE",
}),
// Create a fixed window rate limit. Other algorithms are supported.
fixedWindow({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
characteristics: ["userId"], // Rate limit based on the Clerk userId
window: "60s", // 60 second fixed window
max: 10, // allow a maximum of 10 requests
}),
// Blocks all automated clients
detectBot({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
block: ["AUTOMATED"],
}),
],
});

export async function POST(req: Request) {
// Get the current user from Clerk
const user = await currentUser();
if (!user) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
}

// Use Arcjet to protect the route
const decision = await aj.protect(req, { userId: user.id });

if (decision.isDenied()) {
if (decision.reason.isRateLimit()) {
return NextResponse.json(
{
error: "Too Many Requests",
reason: decision.reason,
},
{
status: 429,
},
);
} else if (decision.reason.isBot()) {
return NextResponse.json(
{
error: "Bots are not allowed",
reason: decision.reason,
},
{
status: 403,
},
);
} else {
return NextResponse.json(
{
error: "Unauthorized",
reason: decision.reason,
},
{
status: 401,
},
);
}
}

const { prompt } = await req.json();

const privateKey = process.env.SUPABASE_PRIVATE_KEY;
Expand All @@ -31,7 +100,7 @@ export async function POST(req: Request) {
client,
tableName: "documents",
queryName: "match_documents",
}
},
);

const { stream, handlers } = LangChainStream();
Expand Down
71 changes: 70 additions & 1 deletion src/app/api/qa-pinecone/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,79 @@ import { OpenAI } from "langchain/llms/openai";
import { PineconeStore } from "langchain/vectorstores/pinecone";
import { StreamingTextResponse, LangChainStream } from "ai";
import { CallbackManager } from "langchain/callbacks";
import { currentUser } from "@clerk/nextjs";
import arcjet, { shield, fixedWindow, detectBot } from "@arcjet/next";
import { NextResponse } from "next/server";

dotenv.config({ path: `.env.local` });

// The arcjet instance is created outside of the handler
const aj = arcjet({
key: process.env.ARCJET_KEY!, // Get your site key from https://app.arcjet.com
rules: [
// Arcjet Shield protects against common attacks e.g. SQL injection
shield({
mode: "LIVE",
}),
// Create a fixed window rate limit. Other algorithms are supported.
fixedWindow({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
characteristics: ["userId"], // Rate limit based on the Clerk userId
window: "60s", // 60 second fixed window
max: 10, // allow a maximum of 10 requests
}),
// Blocks all automated clients
detectBot({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
block: ["AUTOMATED"],
}),
],
});

export async function POST(request: Request) {
// Get the current user from Clerk
const user = await currentUser();
if (!user) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
}

// Use Arcjet to protect the route
const decision = await aj.protect(request, { userId: user.id });

if (decision.isDenied()) {
if (decision.reason.isRateLimit()) {
return NextResponse.json(
{
error: "Too Many Requests",
reason: decision.reason,
},
{
status: 429,
},
);
} else if (decision.reason.isBot()) {
return NextResponse.json(
{
error: "Bots are not allowed",
reason: decision.reason,
},
{
status: 403,
},
);
} else {
return NextResponse.json(
{
error: "Unauthorized",
reason: decision.reason,
},
{
status: 401,
},
);
}
}

const { prompt } = await request.json();
const client = new PineconeClient();
await client.init({
Expand All @@ -20,7 +89,7 @@ export async function POST(request: Request) {

const vectorStore = await PineconeStore.fromExistingIndex(
new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }),
{ pineconeIndex }
{ pineconeIndex },
);

const { stream, handlers } = LangChainStream();
Expand Down
Loading

0 comments on commit 47bdac8

Please sign in to comment.