Skip to content

Commit

Permalink
Resolve variables in prompt templates
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-fleck-at committed Jul 30, 2024
1 parent ba3bf65 commit c6848c2
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export class CodeCompletionAgentImpl implements CodeCompletionAgent {
const file = model.uri.toString(false);
const language = model.getLanguageId();

const prompt = this.promptService.getPrompt('code-completion-prompt', { snippet, file, language });
const prompt = await this.promptService.getPrompt('code-completion-prompt', { snippet, file, language });
if (!prompt) {
console.error('No prompt found for code-completion-agent');
return undefined;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
// *****************************************************************************

import 'reflect-metadata';
import { Container } from 'inversify';
import { PromptService, PromptServiceImpl } from '../prompt-service';
import { expect } from 'chai';
import { Container } from 'inversify';
import 'reflect-metadata';
import { PromptService, PromptServiceImpl } from './prompt-service';
import { DefaultAIVariableService, AIVariableService } from './variable-service';

describe('PromptService', () => {
let promptService: PromptService;
Expand All @@ -26,24 +27,34 @@ describe('PromptService', () => {
const container = new Container();
container.bind<PromptService>(PromptService).to(PromptServiceImpl).inSingletonScope();

const variableService = new DefaultAIVariableService({ getContributions: () => [] });
const nameVariable = { id: 'test', name: 'name', description: 'Test name ' };
variableService.registerResolver(nameVariable, {
canResolve: () => 100,
resolve: async () => ({ variable: nameVariable, value: 'Jane' })
});
container.bind<AIVariableService>(AIVariableService).toConstantValue(variableService);

promptService = container.get<PromptService>(PromptService);
promptService.storePrompt('1', 'Hello, ${name}!');
promptService.storePrompt('2', 'Goodbye, ${name}!');
promptService.storePrompt('3', 'Ciao, ${invalid}!');
});

it('should initialize prompts from PromptCollectionService', () => {
const allPrompts = promptService.getAllPrompts();
expect(allPrompts['1'].template).to.equal('Hello, ${name}!');
expect(allPrompts['2'].template).to.equal('Goodbye, ${name}!');
expect(allPrompts['3'].template).to.equal('Ciao, ${invalid}!');
});

it('should retrieve raw prompt by id', () => {
const rawPrompt = promptService.getRawPrompt('1');
expect(rawPrompt?.template).to.equal('Hello, ${name}!');
});

it('should format prompt with provided arguments', () => {
const formattedPrompt = promptService.getPrompt('1', { name: 'John' });
it('should format prompt with provided arguments', async () => {
const formattedPrompt = await promptService.getPrompt('1', { name: 'John' });
expect(formattedPrompt).to.equal('Hello, John!');
});

Expand All @@ -52,4 +63,24 @@ describe('PromptService', () => {
const newPrompt = promptService.getRawPrompt('3');
expect(newPrompt?.template).to.equal('Welcome, ${name}!');
});

it('should replace placeholders with provided arguments', async () => {
const prompt = await promptService.getPrompt('1', { name: 'John' });
expect(prompt).to.equal('Hello, John!');
});

it('should use variable service to resolve placeholders if argument value is not provided', async () => {
const prompt = await promptService.getPrompt('1');
expect(prompt).to.equal('Hello, Jane!');
});

it('should return the prompt even if there are no replacements', async () => {
const prompt = await promptService.getPrompt('3');
expect(prompt).to.equal('Ciao, ${invalid}!');
});

it('should return undefined if the prompt id is not found', async () => {
const prompt = await promptService.getPrompt('4');
expect(prompt).to.be.undefined;
});
});
28 changes: 20 additions & 8 deletions packages/ai-core/src/common/prompt-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import { URI } from '@theia/core';
import { inject, injectable, optional } from '@theia/core/shared/inversify';
import { PromptTemplate } from './types';
import { AIVariableService } from './variable-service';

export interface PromptMap { [id: string]: PromptTemplate }

Expand All @@ -38,7 +39,7 @@ export interface PromptService {
* @param id the id of the prompt
* @param args the object with placeholders, mapping the placeholder key to the value
*/
getPrompt(id: string, args?: { [key: string]: unknown }): string | undefined;
getPrompt(id: string, args?: { [key: string]: unknown }): Promise<string | undefined>;
/**
* Manually add a prompt to the list of prompts.
* @param id the id of the prompt
Expand Down Expand Up @@ -88,12 +89,16 @@ export interface PromptCustomizationService {
getTemplateIDFromURI(uri: URI): string | undefined;
}

const PROMPT_VARIABLE_REGEX = /\$\{([\w_\-]+)\}/g;

@injectable()
export class PromptServiceImpl implements PromptService {

@inject(PromptCustomizationService) @optional()
protected readonly customizationService: PromptCustomizationService | undefined;

@inject(AIVariableService) @optional()
protected readonly variableService: AIVariableService | undefined;

protected _prompts: PromptMap = {};

getRawPrompt(id: string): PromptTemplate | undefined {
Expand All @@ -108,16 +113,23 @@ export class PromptServiceImpl implements PromptService {
getDefaultRawPrompt(id: string): PromptTemplate | undefined {
return this._prompts[id];
}
getPrompt(id: string, args?: { [key: string]: unknown }): string | undefined {
async getPrompt(id: string, args?: { [key: string]: unknown }): Promise<string | undefined> {
const prompt = this.getRawPrompt(id);
if (prompt === undefined) {
return undefined;
}
if (args === undefined) {
return prompt.template;
}
const formattedPrompt = Object.keys(args).reduce((acc, key) => acc.replace(`\${${key}}`, args[key] as string), prompt.template);
return formattedPrompt;

const matches = [...prompt.template.matchAll(PROMPT_VARIABLE_REGEX)];
const replacements = await Promise.all(matches.map(async match => {
const key = match[1];
return {
placeholder: match[0],
value: String(args?.[key] ?? (await this.variableService?.resolveVariable(key, {}))?.value ?? match[0])
};
}));
let result = prompt.template;
replacements.forEach(replacement => result = result.replace(replacement.placeholder, replacement.value));
return result;
}
getAllPrompts(): PromptMap {
if (this.customizationService !== undefined) {
Expand Down
4 changes: 2 additions & 2 deletions packages/ai-terminal/src/browser/ai-terminal-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ recent-terminal-contents:
}
const lm = lms[0];

const systemPrompt = this.promptService.getPrompt('ai-terminal:system-prompt', input);
const userPrompt = this.promptService.getPrompt('ai-terminal:user-prompt', input);
const systemPrompt = await this.promptService.getPrompt('ai-terminal:system-prompt', input);
const userPrompt = await this.promptService.getPrompt('ai-terminal:user-prompt', input);
if (!systemPrompt || !userPrompt) {
this.logger.error('The prompt service didn\'t return prompts for the AI Terminal Agent.');
return [];
Expand Down

0 comments on commit c6848c2

Please sign in to comment.