Skip to content

Commit

Permalink
update to more stagehand methods with succes bool handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Filip Michalsky committed Sep 26, 2024
1 parent fe6569f commit 16b8ea9
Showing 1 changed file with 106 additions and 65 deletions.
171 changes: 106 additions & 65 deletions lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,30 +120,41 @@ export class Stagehand {
await download.delete();
}

async init({ modelName = "gpt-4o" }: { modelName?: string } = {}) {
const { context } = await getBrowser(this.env, this.headless);
this.context = context;
this.page = context.pages()[0];
this.defaultModelName = modelName;
async init({ modelName = "gpt-4o" }: { modelName?: string } = {}): Promise<{ success: boolean; error?: string }> {
try {
const { context } = await getBrowser(this.env, this.headless);
this.context = context;
this.page = context.pages()[0];
this.defaultModelName = modelName;

// Set the browser to headless mode if specified
if (this.headless) {
await this.page.setViewportSize({ width: 1280, height: 720 });
}

// Set the browser to headless mode if specified
if (this.headless) {
await this.page.setViewportSize({ width: 1280, height: 720 });
}
// This can be greatly improved, but the tldr is we put our built web scripts in dist, which should always
// be one level above our running directly across evals, example, and as a package
await this.page.addInitScript({
path: path.join(__dirname, "..", "dist", "dom", "build", "process.js"),
});

// This can be greatly improved, but the tldr is we put our built web scripts in dist, which should always
// be one level above our running directly across evals, example, and as a package
await this.page.addInitScript({
path: path.join(__dirname, "..", "dist", "dom", "build", "process.js"),
});
await this.page.addInitScript({
path: path.join(__dirname, "..", "dist", "dom", "build", "utils.js"),
});

await this.page.addInitScript({
path: path.join(__dirname, "..", "dist", "dom", "build", "utils.js"),
});
await this.page.addInitScript({
path: path.join(__dirname, "..", "dist", "dom", "build", "debug.js"),
});

await this.page.addInitScript({
path: path.join(__dirname, "..", "dist", "dom", "build", "debug.js"),
});
return { success: true };
} catch (error) {
this.log({
category: "init",
message: `Error during initialization: ${error.message}`,
level: 1,
});
return { success: false, error: error.message };
}
}

async waitForSettledDom() {
Expand Down Expand Up @@ -200,54 +211,63 @@ export class Stagehand {
content?: z.infer<T>;
chunksSeen?: Array<number>;
modelName?: string;
}): Promise<z.infer<T>> {
this.log({
category: "extraction",
message: `starting extraction ${instruction}`,
level: 1
});
}): Promise<{ success: boolean; data?: z.infer<T>; error?: string }> {
try {
this.log({
category: "extraction",
message: `starting extraction ${instruction}`,
level: 1,
});

await this.waitForSettledDom();
await this.startDomDebug();
const { outputString, chunk, chunks } = await this.page.evaluate(() =>
window.processDom([])
);
await this.waitForSettledDom();
await this.startDomDebug();
const { outputString, chunk, chunks } = await this.page.evaluate(() =>
window.processDom([])
);

const extractionResponse = await extract({
instruction,
progress,
domElements: outputString,
llmProvider: this.llmProvider,
schema,
modelName: modelName || this.defaultModelName,
});
const { progress: newProgress, completed, ...output } = extractionResponse;
await this.cleanupDomDebug();
const extractionResponse = await extract({
instruction,
progress,
domElements: outputString,
llmProvider: this.llmProvider,
schema,
modelName: modelName || this.defaultModelName,
});
const { progress: newProgress, completed, ...output } = extractionResponse;
await this.cleanupDomDebug();

chunksSeen.push(chunk);
chunksSeen.push(chunk);

if (completed || chunksSeen.length === chunks.length) {
this.log({
category: "extraction",
message: `response: ${JSON.stringify(extractionResponse)}`,
level: 1
});
if (completed || chunksSeen.length === chunks.length) {
this.log({
category: "extraction",
message: `response: ${JSON.stringify(extractionResponse)}`,
level: 1
});

return merge(content, output);
} else {
return { success: true, data: merge(content, output) };
} else {
this.log({
category: "extraction",
message: `continuing extraction, progress: ${progress + newProgress + ", "}`,
level: 1
});
return this.extract({
instruction,
schema,
progress: progress + newProgress + ", ",
content: merge(content, output),
chunksSeen,
modelName,
});
}
} catch (error) {
this.log({
category: "extraction",
message: `continuing extraction, progress: ${progress + newProgress + ", "}`,
level: 1
});
return this.extract({
instruction,
schema,
progress: progress + newProgress + ", ",
content: merge(content, output),
chunksSeen,
modelName,
message: `Error during extraction: ${error.message}`,
level: 1,
});
return { success: false, error: error.message };
}
}

Expand Down Expand Up @@ -312,12 +332,33 @@ export class Stagehand {

return { success: true, result: observationId };
}
async ask(question: string, modelName?: string): Promise<string | null> {
return ask({
question,
llmProvider: this.llmProvider,
modelName: modelName || this.defaultModelName,
async ask(question: string, modelName?: string): Promise<{ success: boolean; result?: string; error?: string }> {
this.log({
category: "ask",
message: `Asking question: ${question}`,
level: 1,
});

try {
const response = await ask({
question,
llmProvider: this.llmProvider,
modelName: modelName || this.defaultModelName,
});

if (!response) {
throw new Error("No response from LLM");
}

return { success: true, result: response };
} catch (error) {
this.log({
category: "ask",
message: `Error during ask: ${error.message}`,
level: 1,
});
return { success: false, error: error.message };
}
}

async recordObservation(
Expand Down

0 comments on commit 16b8ea9

Please sign in to comment.