-
Notifications
You must be signed in to change notification settings - Fork 419
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #328 from morganchen12/mc/ios
Add iOS on-device chat example
- Loading branch information
Showing
14 changed files
with
866 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
*.bin | ||
InferenceExample.xcodeproj/xcuserdata/ | ||
.DS_Store |
380 changes: 380 additions & 0 deletions
380
examples/llm_inference/ios/InferenceExample.xcodeproj/project.pbxproj
Large diffs are not rendered by default.
Oops, something went wrong.
11 changes: 11 additions & 0 deletions
11
...les/llm_inference/ios/InferenceExample/Assets.xcassets/AccentColor.colorset/Contents.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"colors" : [ | ||
{ | ||
"idiom" : "universal" | ||
} | ||
], | ||
"info" : { | ||
"author" : "xcode", | ||
"version" : 1 | ||
} | ||
} |
63 changes: 63 additions & 0 deletions
63
examples/llm_inference/ios/InferenceExample/Assets.xcassets/AppIcon.appiconset/Contents.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
{ | ||
"images" : [ | ||
{ | ||
"idiom" : "universal", | ||
"platform" : "ios", | ||
"size" : "1024x1024" | ||
}, | ||
{ | ||
"idiom" : "mac", | ||
"scale" : "1x", | ||
"size" : "16x16" | ||
}, | ||
{ | ||
"idiom" : "mac", | ||
"scale" : "2x", | ||
"size" : "16x16" | ||
}, | ||
{ | ||
"idiom" : "mac", | ||
"scale" : "1x", | ||
"size" : "32x32" | ||
}, | ||
{ | ||
"idiom" : "mac", | ||
"scale" : "2x", | ||
"size" : "32x32" | ||
}, | ||
{ | ||
"idiom" : "mac", | ||
"scale" : "1x", | ||
"size" : "128x128" | ||
}, | ||
{ | ||
"idiom" : "mac", | ||
"scale" : "2x", | ||
"size" : "128x128" | ||
}, | ||
{ | ||
"idiom" : "mac", | ||
"scale" : "1x", | ||
"size" : "256x256" | ||
}, | ||
{ | ||
"idiom" : "mac", | ||
"scale" : "2x", | ||
"size" : "256x256" | ||
}, | ||
{ | ||
"idiom" : "mac", | ||
"scale" : "1x", | ||
"size" : "512x512" | ||
}, | ||
{ | ||
"idiom" : "mac", | ||
"scale" : "2x", | ||
"size" : "512x512" | ||
} | ||
], | ||
"info" : { | ||
"author" : "xcode", | ||
"version" : 1 | ||
} | ||
} |
6 changes: 6 additions & 0 deletions
6
examples/llm_inference/ios/InferenceExample/Assets.xcassets/Contents.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"info" : { | ||
"author" : "xcode", | ||
"version" : 1 | ||
} | ||
} |
139 changes: 139 additions & 0 deletions
139
examples/llm_inference/ios/InferenceExample/ConversationScreen.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
// Copyright 2024 The Mediapipe Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import SwiftUI | ||
|
||
struct ConversationScreen: View { | ||
@EnvironmentObject | ||
var viewModel: ConversationViewModel | ||
|
||
@State | ||
private var userPrompt = "" | ||
|
||
enum FocusedField: Hashable { | ||
case message | ||
} | ||
|
||
@FocusState | ||
var focusedField: FocusedField? | ||
|
||
var body: some View { | ||
VStack { | ||
ScrollViewReader { scrollViewProxy in | ||
List { | ||
ForEach(viewModel.messages) { message in | ||
MessageView(message: message) | ||
} | ||
if let error = viewModel.error { | ||
ErrorView(error: error) | ||
.tag("errorView") | ||
} | ||
} | ||
.listStyle(.plain) | ||
.onChange(of: viewModel.messages) { _, newValue in | ||
if viewModel.hasError { | ||
// wait for a short moment to make sure we can actually scroll to the bottom | ||
DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) { | ||
withAnimation { | ||
scrollViewProxy.scrollTo("errorView", anchor: .bottom) | ||
} | ||
focusedField = .message | ||
} | ||
} else { | ||
guard let lastMessage = viewModel.messages.last else { return } | ||
|
||
// wait for a short moment to make sure we can actually scroll to the bottom | ||
DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) { | ||
withAnimation { | ||
scrollViewProxy.scrollTo(lastMessage.id, anchor: .bottom) | ||
} | ||
focusedField = .message | ||
} | ||
} | ||
} | ||
} | ||
TextField("Message...", text: $userPrompt) | ||
.focused($focusedField, equals: .message) | ||
.onSubmit { sendOrStop() } | ||
.submitLabel(.send) | ||
.disabled(viewModel.busy) | ||
.padding() | ||
} | ||
.toolbar { | ||
ToolbarItem(placement: .primaryAction) { | ||
Button(action: newChat) { | ||
Image(systemName: "square.and.pencil") | ||
} | ||
} | ||
} | ||
.navigationTitle("Chat sample") | ||
.onAppear { | ||
focusedField = .message | ||
} | ||
} | ||
|
||
private func sendMessage() { | ||
Task { | ||
let prompt = userPrompt | ||
userPrompt = "" | ||
await viewModel.sendMessage(prompt) | ||
} | ||
} | ||
|
||
private func sendOrStop() { | ||
if viewModel.busy { | ||
viewModel.stop() | ||
} else { | ||
sendMessage() | ||
} | ||
} | ||
|
||
private func newChat() { | ||
viewModel.startNewChat() | ||
} | ||
} | ||
|
||
struct MessageView: View { | ||
var message: ChatMessage | ||
|
||
var body: some View { | ||
HStack { | ||
if message.participant == .user { | ||
Spacer() | ||
} | ||
Text(message.message) | ||
.padding(10) | ||
.background(message.participant == .system | ||
? Color(white: 0.9231) | ||
: Color(red: 0.8627, green: 0.9725, blue: 0.7764)) | ||
.clipShape(RoundedRectangle(cornerRadius: 16)) | ||
if message.participant == .system { | ||
Spacer() | ||
} | ||
} | ||
.listRowSeparator(.hidden) | ||
} | ||
} | ||
|
||
struct ErrorView: View { | ||
var error: Error | ||
|
||
var body: some View { | ||
HStack { | ||
Text("An error occurred: \(error.localizedDescription)") | ||
} | ||
.frame(maxWidth: .infinity, alignment: .center) | ||
.listRowSeparator(.hidden) | ||
} | ||
} |
104 changes: 104 additions & 0 deletions
104
examples/llm_inference/ios/InferenceExample/ConversationViewModel.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
// Copyright 2024 The MediaPipe Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import Foundation | ||
|
||
enum Participant { | ||
case system | ||
case user | ||
} | ||
|
||
struct ChatMessage: Identifiable, Equatable { | ||
let id = UUID().uuidString | ||
var message: String | ||
let participant: Participant | ||
var pending = false | ||
|
||
static func pending(participant: Participant) -> ChatMessage { | ||
self.init(message: "", participant: participant, pending: true) | ||
} | ||
} | ||
|
||
@MainActor | ||
class ConversationViewModel: ObservableObject { | ||
/// This array holds both the user's and the system's chat messages | ||
@Published var messages = [ChatMessage]() | ||
|
||
/// Indicates we're waiting for the model to finish | ||
@Published var busy = false | ||
|
||
@Published var error: Error? | ||
var hasError: Bool { | ||
return error != nil | ||
} | ||
|
||
private var model: OnDeviceModel | ||
private var chat: Chat | ||
private var stopGenerating = false | ||
|
||
private var chatTask: Task<Void, Never>? | ||
|
||
init() { | ||
model = OnDeviceModel() | ||
chat = model.startChat() | ||
} | ||
|
||
func sendMessage(_ text: String) async { | ||
error = nil | ||
await internalSendMessage(text) | ||
} | ||
|
||
func startNewChat() { | ||
stop() | ||
error = nil | ||
chat = model.startChat() | ||
messages.removeAll() | ||
} | ||
|
||
func stop() { | ||
chatTask?.cancel() | ||
error = nil | ||
} | ||
|
||
private func internalSendMessage(_ text: String) async { | ||
chatTask?.cancel() | ||
|
||
chatTask = Task { | ||
busy = true | ||
defer { | ||
busy = false | ||
} | ||
|
||
// first, add the user's message to the chat | ||
let userMessage = ChatMessage(message: text, participant: .user) | ||
messages.append(userMessage) | ||
|
||
// add a pending message while we're waiting for a response from the backend | ||
let systemMessage = ChatMessage.pending(participant: .system) | ||
messages.append(systemMessage) | ||
|
||
do { | ||
let response = try await chat.sendMessage(text) | ||
|
||
// replace pending message with model response | ||
messages[messages.count - 1].message = response | ||
messages[messages.count - 1].pending = false | ||
} catch { | ||
self.error = error | ||
print(error.localizedDescription) | ||
messages.removeLast() | ||
} | ||
} | ||
} | ||
} |
10 changes: 10 additions & 0 deletions
10
examples/llm_inference/ios/InferenceExample/InferenceExample.entitlements
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd"> | ||
<plist version="1.0"> | ||
<dict> | ||
<key>com.apple.security.app-sandbox</key> | ||
<true/> | ||
<key>com.apple.security.files.user-selected.read-only</key> | ||
<true/> | ||
</dict> | ||
</plist> |
25 changes: 25 additions & 0 deletions
25
examples/llm_inference/ios/InferenceExample/InferenceExampleApp.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright 2024 The MediaPipe Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import SwiftUI | ||
|
||
@main | ||
struct InferenceExampleApp: App { | ||
var body: some Scene { | ||
WindowGroup { | ||
ConversationScreen() | ||
.environmentObject(ConversationViewModel()) | ||
} | ||
} | ||
} |
Oops, something went wrong.