Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve variables in prompt templates #116

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export class CodeCompletionAgentImpl implements CodeCompletionAgent {
const file = model.uri.toString(false);
const language = model.getLanguageId();

const prompt = this.promptService.getPrompt('code-completion-prompt', { snippet, file, language });
const prompt = await this.promptService.getPrompt('code-completion-prompt', { snippet, file, language });
if (!prompt) {
console.error('No prompt found for code-completion-agent');
return undefined;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
// *****************************************************************************

import 'reflect-metadata';
import { Container } from 'inversify';
import { PromptService, PromptServiceImpl } from '../prompt-service';

import { expect } from 'chai';
import { Container } from 'inversify';
import { PromptService, PromptServiceImpl } from './prompt-service';
import { DefaultAIVariableService, AIVariableService } from './variable-service';

describe('PromptService', () => {
let promptService: PromptService;
Expand All @@ -26,24 +28,34 @@ describe('PromptService', () => {
const container = new Container();
container.bind<PromptService>(PromptService).to(PromptServiceImpl).inSingletonScope();

const variableService = new DefaultAIVariableService({ getContributions: () => [] });
const nameVariable = { id: 'test', name: 'name', description: 'Test name ' };
variableService.registerResolver(nameVariable, {
canResolve: () => 100,
resolve: async () => ({ variable: nameVariable, value: 'Jane' })
});
container.bind<AIVariableService>(AIVariableService).toConstantValue(variableService);

promptService = container.get<PromptService>(PromptService);
promptService.storePrompt('1', 'Hello, ${name}!');
promptService.storePrompt('2', 'Goodbye, ${name}!');
promptService.storePrompt('3', 'Ciao, ${invalid}!');
});

it('should initialize prompts from PromptCollectionService', () => {
const allPrompts = promptService.getAllPrompts();
expect(allPrompts['1'].template).to.equal('Hello, ${name}!');
expect(allPrompts['2'].template).to.equal('Goodbye, ${name}!');
expect(allPrompts['3'].template).to.equal('Ciao, ${invalid}!');
});

it('should retrieve raw prompt by id', () => {
const rawPrompt = promptService.getRawPrompt('1');
expect(rawPrompt?.template).to.equal('Hello, ${name}!');
});

it('should format prompt with provided arguments', () => {
const formattedPrompt = promptService.getPrompt('1', { name: 'John' });
it('should format prompt with provided arguments', async () => {
const formattedPrompt = await promptService.getPrompt('1', { name: 'John' });
expect(formattedPrompt).to.equal('Hello, John!');
});

Expand All @@ -52,4 +64,24 @@ describe('PromptService', () => {
const newPrompt = promptService.getRawPrompt('3');
expect(newPrompt?.template).to.equal('Welcome, ${name}!');
});

it('should replace placeholders with provided arguments', async () => {
const prompt = await promptService.getPrompt('1', { name: 'John' });
expect(prompt).to.equal('Hello, John!');
});

it('should use variable service to resolve placeholders if argument value is not provided', async () => {
const prompt = await promptService.getPrompt('1');
expect(prompt).to.equal('Hello, Jane!');
});

it('should return the prompt even if there are no replacements', async () => {
const prompt = await promptService.getPrompt('3');
expect(prompt).to.equal('Ciao, ${invalid}!');
});

it('should return undefined if the prompt id is not found', async () => {
const prompt = await promptService.getPrompt('4');
expect(prompt).to.be.undefined;
});
});
28 changes: 20 additions & 8 deletions packages/ai-core/src/common/prompt-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import { URI } from '@theia/core';
import { inject, injectable, optional } from '@theia/core/shared/inversify';
import { PromptTemplate } from './types';
import { AIVariableService } from './variable-service';

export interface PromptMap { [id: string]: PromptTemplate }

Expand All @@ -38,7 +39,7 @@ export interface PromptService {
* @param id the id of the prompt
* @param args the object with placeholders, mapping the placeholder key to the value
*/
getPrompt(id: string, args?: { [key: string]: unknown }): string | undefined;
getPrompt(id: string, args?: { [key: string]: unknown }): Promise<string | undefined>;
/**
* Manually add a prompt to the list of prompts.
* @param id the id of the prompt
Expand Down Expand Up @@ -88,12 +89,16 @@ export interface PromptCustomizationService {
getTemplateIDFromURI(uri: URI): string | undefined;
}

const PROMPT_VARIABLE_REGEX = /\$\{([\w_\-]+)\}/g;

@injectable()
export class PromptServiceImpl implements PromptService {

@inject(PromptCustomizationService) @optional()
protected readonly customizationService: PromptCustomizationService | undefined;

@inject(AIVariableService) @optional()
protected readonly variableService: AIVariableService | undefined;

protected _prompts: PromptMap = {};

getRawPrompt(id: string): PromptTemplate | undefined {
Expand All @@ -108,16 +113,23 @@ export class PromptServiceImpl implements PromptService {
getDefaultRawPrompt(id: string): PromptTemplate | undefined {
return this._prompts[id];
}
getPrompt(id: string, args?: { [key: string]: unknown }): string | undefined {
async getPrompt(id: string, args?: { [key: string]: unknown }): Promise<string | undefined> {
const prompt = this.getRawPrompt(id);
if (prompt === undefined) {
return undefined;
}
if (args === undefined) {
return prompt.template;
}
const formattedPrompt = Object.keys(args).reduce((acc, key) => acc.replace(`\${${key}}`, args[key] as string), prompt.template);
return formattedPrompt;

const matches = [...prompt.template.matchAll(PROMPT_VARIABLE_REGEX)];
const replacements = await Promise.all(matches.map(async match => {
const key = match[1];
return {
placeholder: match[0],
value: String(args?.[key] ?? (await this.variableService?.resolveVariable(key, {}))?.value ?? match[0])
};
}));
let result = prompt.template;
replacements.forEach(replacement => result = result.replace(replacement.placeholder, replacement.value));
return result;
}
getAllPrompts(): PromptMap {
if (this.customizationService !== undefined) {
Expand Down
4 changes: 2 additions & 2 deletions packages/ai-terminal/src/browser/ai-terminal-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ recent-terminal-contents:
}
const lm = lms[0];

const systemPrompt = this.promptService.getPrompt('ai-terminal:system-prompt', input);
const userPrompt = this.promptService.getPrompt('ai-terminal:user-prompt', input);
const systemPrompt = await this.promptService.getPrompt('ai-terminal:system-prompt', input);
const userPrompt = await this.promptService.getPrompt('ai-terminal:user-prompt', input);
if (!systemPrompt || !userPrompt) {
this.logger.error('The prompt service didn\'t return prompts for the AI Terminal Agent.');
return [];
Expand Down