Skip to content

Commit

Permalink
Merge pull request #328 from morganchen12/mc/ios
Browse files Browse the repository at this point in the history
Add iOS on-device chat example
  • Loading branch information
schmidt-sebastian authored Mar 8, 2024
2 parents aca7bfe + 061fc7e commit dbf95a6
Show file tree
Hide file tree
Showing 14 changed files with 866 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/llm_inference/ios/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.bin
InferenceExample.xcodeproj/xcuserdata/
.DS_Store
380 changes: 380 additions & 0 deletions examples/llm_inference/ios/InferenceExample.xcodeproj/project.pbxproj

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"colors" : [
{
"idiom" : "universal"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{
"images" : [
{
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "16x16"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "16x16"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "32x32"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "32x32"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "128x128"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "128x128"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "256x256"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "256x256"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "512x512"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "512x512"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"info" : {
"author" : "xcode",
"version" : 1
}
}
139 changes: 139 additions & 0 deletions examples/llm_inference/ios/InferenceExample/ConversationScreen.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright 2024 The Mediapipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import SwiftUI

struct ConversationScreen: View {
@EnvironmentObject
var viewModel: ConversationViewModel

@State
private var userPrompt = ""

enum FocusedField: Hashable {
case message
}

@FocusState
var focusedField: FocusedField?

var body: some View {
VStack {
ScrollViewReader { scrollViewProxy in
List {
ForEach(viewModel.messages) { message in
MessageView(message: message)
}
if let error = viewModel.error {
ErrorView(error: error)
.tag("errorView")
}
}
.listStyle(.plain)
.onChange(of: viewModel.messages) { _, newValue in
if viewModel.hasError {
// wait for a short moment to make sure we can actually scroll to the bottom
DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) {
withAnimation {
scrollViewProxy.scrollTo("errorView", anchor: .bottom)
}
focusedField = .message
}
} else {
guard let lastMessage = viewModel.messages.last else { return }

// wait for a short moment to make sure we can actually scroll to the bottom
DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) {
withAnimation {
scrollViewProxy.scrollTo(lastMessage.id, anchor: .bottom)
}
focusedField = .message
}
}
}
}
TextField("Message...", text: $userPrompt)
.focused($focusedField, equals: .message)
.onSubmit { sendOrStop() }
.submitLabel(.send)
.disabled(viewModel.busy)
.padding()
}
.toolbar {
ToolbarItem(placement: .primaryAction) {
Button(action: newChat) {
Image(systemName: "square.and.pencil")
}
}
}
.navigationTitle("Chat sample")
.onAppear {
focusedField = .message
}
}

private func sendMessage() {
Task {
let prompt = userPrompt
userPrompt = ""
await viewModel.sendMessage(prompt)
}
}

private func sendOrStop() {
if viewModel.busy {
viewModel.stop()
} else {
sendMessage()
}
}

private func newChat() {
viewModel.startNewChat()
}
}

struct MessageView: View {
var message: ChatMessage

var body: some View {
HStack {
if message.participant == .user {
Spacer()
}
Text(message.message)
.padding(10)
.background(message.participant == .system
? Color(white: 0.9231)
: Color(red: 0.8627, green: 0.9725, blue: 0.7764))
.clipShape(RoundedRectangle(cornerRadius: 16))
if message.participant == .system {
Spacer()
}
}
.listRowSeparator(.hidden)
}
}

struct ErrorView: View {
var error: Error

var body: some View {
HStack {
Text("An error occurred: \(error.localizedDescription)")
}
.frame(maxWidth: .infinity, alignment: .center)
.listRowSeparator(.hidden)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

enum Participant {
case system
case user
}

struct ChatMessage: Identifiable, Equatable {
let id = UUID().uuidString
var message: String
let participant: Participant
var pending = false

static func pending(participant: Participant) -> ChatMessage {
self.init(message: "", participant: participant, pending: true)
}
}

@MainActor
class ConversationViewModel: ObservableObject {
/// This array holds both the user's and the system's chat messages
@Published var messages = [ChatMessage]()

/// Indicates we're waiting for the model to finish
@Published var busy = false

@Published var error: Error?
var hasError: Bool {
return error != nil
}

private var model: OnDeviceModel
private var chat: Chat
private var stopGenerating = false

private var chatTask: Task<Void, Never>?

init() {
model = OnDeviceModel()
chat = model.startChat()
}

func sendMessage(_ text: String) async {
error = nil
await internalSendMessage(text)
}

func startNewChat() {
stop()
error = nil
chat = model.startChat()
messages.removeAll()
}

func stop() {
chatTask?.cancel()
error = nil
}

private func internalSendMessage(_ text: String) async {
chatTask?.cancel()

chatTask = Task {
busy = true
defer {
busy = false
}

// first, add the user's message to the chat
let userMessage = ChatMessage(message: text, participant: .user)
messages.append(userMessage)

// add a pending message while we're waiting for a response from the backend
let systemMessage = ChatMessage.pending(participant: .system)
messages.append(systemMessage)

do {
let response = try await chat.sendMessage(text)

// replace pending message with model response
messages[messages.count - 1].message = response
messages[messages.count - 1].pending = false
} catch {
self.error = error
print(error.localizedDescription)
messages.removeLast()
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>com.apple.security.app-sandbox</key>
<true/>
<key>com.apple.security.files.user-selected.read-only</key>
<true/>
</dict>
</plist>
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import SwiftUI

@main
struct InferenceExampleApp: App {
var body: some Scene {
WindowGroup {
ConversationScreen()
.environmentObject(ConversationViewModel())
}
}
}
Loading

0 comments on commit dbf95a6

Please sign in to comment.