Skip to content

Commit

Permalink
Merge pull request google-ai-edge#232 from googlesamples/update_selec…
Browse files Browse the repository at this point in the history
…ting_seperate_model

Update selecting model face stylization
  • Loading branch information
PaulTR authored Sep 14, 2023
2 parents 40d3bda + 3093314 commit 301476e
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 16 deletions.
35 changes: 35 additions & 0 deletions examples/face_stylizer/android/app/download_models.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

task downloadColorInkTask(type: Download) {
src ''
dest project.ext.TEST_ASSETS_DIR + '/face_stylizer_color_ink.task'
overwrite false
}

task downloadColorSketchTask(type: Download) {
src ''
dest project.ext.TEST_ASSETS_DIR + '/face_stylizer_color_sketch.task'
overwrite false
}

task downloadOilPainting(type: Download) {
src ''
dest project.ext.TEST_ASSETS_DIR + '/face_stylizer_oil_painting.task'
overwrite false
}

preBuild.dependsOn downloadColorInkTask, downloadColorSketchTask, downloadOilPainting
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.google.mediapipe.tasks.vision.facestylizer.FaceStylizer.FaceStylizerO
import com.google.mediapipe.tasks.vision.facestylizer.FaceStylizerResult

class FaceStylizationHelper(
private val modelPosition: Int,
private val context: Context,
var faceStylizerListener: FaceStylizerListener? = null
) {
Expand All @@ -36,10 +37,17 @@ class FaceStylizationHelper(
setupFaceStylizer()
}

fun setupFaceStylizer() {
private fun setupFaceStylizer() {
val baseOptionsBuilder = BaseOptions.builder()

baseOptionsBuilder.setModelAssetPath(MODEL_PATH)
// Sets the model selection.
baseOptionsBuilder.setModelAssetPath(
when (modelPosition) {
0 -> MODEL_PATH_COLOR_SKETCH
1 -> MODEL_PATH_COLOR_INK
2 -> MODEL_PATH_OIL_PAINTING
else -> throw Throwable("Invalid model type position")
}
)

try {
val baseOptions = baseOptionsBuilder.build()
Expand Down Expand Up @@ -80,6 +88,10 @@ class FaceStylizationHelper(
return ResultBundle(result, timestampMs)
}

fun close() {
faceStylizer?.close()
}

// Wraps results from inference, the time it takes for inference to be
// performed.
data class ResultBundle(
Expand All @@ -88,7 +100,9 @@ class FaceStylizationHelper(
)

companion object {
const val MODEL_PATH = "face_stylizer.tflite"
const val MODEL_PATH_OIL_PAINTING = "face_stylizer_oil_painting.task"
const val MODEL_PATH_COLOR_INK = "face_stylizer_color_ink.task"
const val MODEL_PATH_COLOR_SKETCH = "face_stylizer_color_sketch.task"
const val OTHER_ERROR = 0
const val GPU_ERROR = 1
private const val TAG = "FaceStylizationHelper"
Expand All @@ -97,5 +111,4 @@ class FaceStylizationHelper(
interface FaceStylizerListener {
fun onError(error: String, errorCode: Int = OTHER_ERROR)
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,26 @@ import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import androidx.appcompat.app.AppCompatActivity
import android.os.Bundle
import android.view.View
import android.widget.AdapterView
import android.widget.AdapterView.OnItemSelectedListener
import android.widget.ArrayAdapter
import android.widget.Toast
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import com.google.mediapipe.examples.facestylizer.databinding.ActivityMainBinding
import com.google.mediapipe.framework.image.ByteBufferExtractor

class MainActivity : AppCompatActivity(), FaceStylizationHelper.FaceStylizerListener {

class MainActivity : AppCompatActivity(),
FaceStylizationHelper.FaceStylizerListener {

private lateinit var binding: ActivityMainBinding
private lateinit var faceStylizationHelper: FaceStylizationHelper
private var faceStylizationHelper: FaceStylizationHelper? = null
private var inputImage: Bitmap? = null


private val getContent =
registerForActivityResult(ActivityResultContracts.GetContent()) {
inputImage = it?.getImage(this)
Expand All @@ -47,11 +53,39 @@ class MainActivity : AppCompatActivity(), FaceStylizationHelper.FaceStylizerList
val view = binding.root
setContentView(view)

faceStylizationHelper = FaceStylizationHelper(this, faceStylizerListener = this)
// Init spinner model name
val modelNameArray = resources.getStringArray(R.array.model_name_array)
val adapter = ArrayAdapter(
this,
android.R.layout.simple_spinner_dropdown_item,
modelNameArray
)
binding.bottomSheetLayout.modelSpinner.adapter = adapter

binding.bottomSheetLayout.modelSpinner.onItemSelectedListener =
object : OnItemSelectedListener {
override fun onItemSelected(
parent: AdapterView<*>?,
view: View?,
position: Int,
id: Long
) {
// Reset the helper if the model type is changed.
faceStylizationHelper?.close()
initHelper(position)
}

override fun onNothingSelected(parent: AdapterView<*>?) {
// do nothing
}

}

binding.btnStylize.setOnClickListener {
inputImage?.let { input ->
onResult(faceStylizationHelper.stylize(input))
faceStylizationHelper?.stylize(input)?.let {
onResult(it)
}
}
}

Expand All @@ -60,6 +94,14 @@ class MainActivity : AppCompatActivity(), FaceStylizationHelper.FaceStylizerList
}
}

private fun initHelper(modelPosition: Int) {
faceStylizationHelper = FaceStylizationHelper(
modelPosition,
this,
faceStylizerListener = this
)
}

private fun Uri.getImage(context: Context): Bitmap {
return BitmapFactory.decodeStream(
context.contentResolver.openInputStream(this)
Expand All @@ -70,8 +112,8 @@ class MainActivity : AppCompatActivity(), FaceStylizationHelper.FaceStylizerList
Toast.makeText(this, error, Toast.LENGTH_SHORT).show()
}

fun onResult(result: FaceStylizationHelper.ResultBundle) {
if( result.stylizedFace == null ) {
private fun onResult(result: FaceStylizationHelper.ResultBundle) {
if (result.stylizedFace == null) {
onError("Failed to stylize image")
return
}
Expand All @@ -87,6 +129,7 @@ class MainActivity : AppCompatActivity(), FaceStylizationHelper.FaceStylizerList

binding.tvImageTwoDescription.visibility = View.GONE
binding.outputImage.setImageBitmap(bitmap)
binding.bottomSheetLayout.inferenceTimeVal.text = result.inferenceTime.toString()
binding.bottomSheetLayout.inferenceTimeVal.text =
result.inferenceTime.toString()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,33 @@
android:src="@drawable/icn_chevron_up" />
</RelativeLayout>

<!-- Select model type row -->
<RelativeLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_marginTop="@dimen/bottom_sheet_default_row_margin"
android:orientation="horizontal">

<TextView
android:id="@+id/model_type"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_centerVertical="true"
android:text="@string/tv_model"
android:textColor="@color/bottom_sheet_text_color"
android:textSize="@dimen/bottom_sheet_text_size" />

<Spinner
android:id="@+id/model_spinner"
android:layout_width="wrap_content"
android:layout_height="@dimen/spinner_model_height"
android:layout_alignParentEnd="true"
android:gravity="end"
android:padding="@dimen/spinner_model_padding"
android:spinnerMode="dialog"
android:textSize="@dimen/bottom_sheet_text_size" />
</RelativeLayout>

<!-- Inference time row -->
<androidx.appcompat.widget.LinearLayoutCompat
android:layout_width="match_parent"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@

<integer name="bottom_sheet_control_text_min_ems">3</integer>
<dimen name="fab_margin">16dp</dimen>
</resources>
<dimen name="spinner_model_padding">10dp</dimen>
<dimen name="spinner_model_height">48dp</dimen>
</resources>
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,11 @@
<string name="tv_second_image_description">This is where the stylized image will appear</string>
<string name="tv_stylize">Select an image to stylize</string>
<string name="btn_stylize">Stylize</string>
<string name="tv_model">Model</string>

<string-array name="model_name_array">
<item>Color sketch</item>
<item>Color ink</item>
<item>Oil painting</item>
</string-array>
</resources>

0 comments on commit 301476e

Please sign in to comment.