diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index eb03df156..b393fc668 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -42,7 +42,7 @@ def create_llm_chain( async def handle_request(self, request: InlineCompletionRequest) -> None: """Handles an inline completion request without streaming.""" self.get_llm_chain() - model_arguments = self._template_inputs_from_request(request) + model_arguments = await self._template_inputs_from_request(request) suggestion = await self.llm_chain.ainvoke(input=model_arguments) suggestion = self._post_process_suggestion(suggestion, request) self.write_message( @@ -80,7 +80,7 @@ async def handle_stream_request(self, request: InlineCompletionRequest): # then, generate and stream LLM output over this connection. self.get_llm_chain() token = self._token_from_request(request, 0) - model_arguments = self._template_inputs_from_request(request) + model_arguments = await self._template_inputs_from_request(request) suggestion = "" async for fragment in self.llm_chain.astream(input=model_arguments): @@ -115,12 +115,71 @@ def _token_from_request(self, request: InlineCompletionRequest, suggestion: int) using request number and suggestion number""" return f"t{request.number}s{suggestion}" - def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict: + async def _get_document(self, request: InlineCompletionRequest): + collaboration = self.settings.get("jupyter_collaboration", None) + file_id_manager = self.settings.get("file_id_manager", None) + + if not collaboration or not file_id_manager or not request.path: + return None + + from jupyter_collaboration.rooms import DocumentRoom + from jupyter_collaboration.utils import encode_file_path + from jupyter_collaboration.websocketserver import RoomNotFound + + file_id = file_id_manager.index(request.path) + is_notebook = request.path.endswith("ipynb") + content_type = "notebook" if is_notebook else "file" + file_format = "json" if is_notebook else "text" + + encoded_path = encode_file_path(file_format, content_type, file_id) + room_id: str = encoded_path.split("/")[-1] + + try: + room = await collaboration.ywebsocket_server.get_room(room_id) + except RoomNotFound: + return None + + if isinstance(room, DocumentRoom): + document = room._document + return document + + async def _template_inputs_from_request( + self, request: InlineCompletionRequest + ) -> Dict: + prefix = request.prefix suffix = request.suffix.strip() filename = request.path.split("/")[-1] if request.path else "untitled" + document = await self._get_document(request) + + if document: + from jupyter_ydoc import YNotebook + + if document and isinstance(document, YNotebook): + cell_type = "markdown" if request.language == "markdown" else "code" + + is_before_request_cell = True + before = [] + after = [suffix] + + for cell in document.ycells: + if is_before_request_cell and cell["id"] == request.cell_id: + is_before_request_cell = False + continue + if cell["cell_type"] != cell_type: + continue + source = cell["source"].to_py() + if is_before_request_cell: + before.append(source) + else: + after.append(source) + + before.append(prefix) + prefix = "\n\n".join(before) + suffix = "\n\n".join(after) + return { - "prefix": request.prefix, + "prefix": prefix, "suffix": suffix, "language": request.language, "filename": filename,