Almost working version
henneber committed Jan 7, 2024
1 parent c688d62 commit 431513c
Showing 3 changed files with 218 additions and 66 deletions.
152 changes: 99 additions & 53 deletions README.md
@@ -1,53 +1,99 @@
# Create a new service (generic) template

This repository contains the Python + FastAPI template to create a service
without a model or from an existing model compatible with the Core engine.

Please read the documentation at
<https://docs.swiss-ai-center.ch/how-to-guides/how-to-create-a-new-service> to
understand how to use this template.

## Guidelines

TODO: Add instructions on how to edit this template.

### Publishing and deploying using a CI/CD pipeline

This is the recommended way to publish and deploy your service if you have
access to GitHub Actions or GitLab CI.

TODO

### Publishing and deploying manually

This is the recommended way to publish and deploy your service if you do not
have access to GitHub Actions or GitLab CI or do not want to use these services.

TODO

## Checklist

These checklists allow you to ensure everything is set up correctly.

### Common tasks

- [ ] Rename the project in the [`pyproject.toml`](./pyproject.toml) file
- [x] Add files that must be ignored to the [`.gitignore`](.gitignore) configuration file
- [ ] TODO

### Publishing and deploying using a CI/CD pipeline

> [!NOTE]
> This checklist is specific to the _Publishing and deploying using a CI/CD
> pipeline_ section.
- [x] Add the environment variables
- [ ] TODO

### Publishing and deploying manually

> [!NOTE]
> This checklist is specific to the _Publishing and deploying manually_ section.
- [x] Edit the [`.env`](.env) configuration file
- [ ] TODO
# Core Engine service for Not Safe For Work image detection

This repository contains the Python + FastAPI code to run a Core Engine service for NSFW detection. It was created from the *template to create a service without a model or from an existing model* available in the repository templates. See <https://docs.swiss-ai-center.ch/how-to-guides/how-to-create-a-new-service> and <https://docs.swiss-ai-center.ch/tutorials/implement-service/>.

This service takes an image as input and returns a JSON document with information about the likelihood that the image contains NSFW content.

## NSFW content detection

NSFW stands for *not safe for work*. This Internet slang is a general term for inappropriate content such as nudity, pornography, etc. See e.g. https://en.wikipedia.org/wiki/Not_safe_for_work. It is important to exercise caution when viewing or sharing NSFW images, as they may violate workplace policies or community guidelines.

The current service encapsulates a trained AI model that detects NSFW images, with a focus on sexual content. Caution: the current version of the service does not yet detect profanity or violence.

### Definition of categories

The border between categories is sometimes thin: what is considered
acceptable nudity in one cultural context may be considered pornography in
another. We therefore disclaim responsibility for any complaint arising from
the use of the model trained in this project. We cannot be held responsible
for any offense caused, or for any image wrongly classified as appropriate
or inappropriate. To make the task even more interesting, we chose two main
categories, *nsfw* and *safe*, each of which has sub-categories.

- **nsfw**:
  - **porn**: male erection, open legs, touching breast or genital parts,
    intercourse, blowjob, etc.; nude men or women with open legs fall into
    this category; nudity with semen on body parts is considered porn
  - **nudity**: penis visible, female breast visible, vagina visible in
    normal position (i.e. standing or sitting but without open legs)
  - **suggestive**: images of people or objects that make someone think
    of sex and sexual relationships; genital parts are not visible, otherwise
    the image should be in the porn or nudity category; dressed people kissing
    and/or touching fall into this category; people undressing; licking
    fingers; woman in a thong with a sexy bra
  - **cartoon_sex**: cartoon images that show or strongly
    suggest a sexual situation
- **safe**:
  - **neutral**: all kinds of images, with or without people, that do not fall
    into the porn, nudity or suggestive category
  - **cartoon_neutral**: cartoon images that do not show or
    suggest a sexual situation

By inspecting the output probabilities for the main categories (safe vs. not safe) and
the sub-categories, users can decide where to place the threshold between what is
acceptable and what is not for a given service.
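
As a rough illustration of how such a threshold could be applied, the sketch below sums the NSFW sub-category probabilities and compares the total with a threshold. It assumes the scores are available as a mapping from sub-category name to probability, using the sub-category names listed in `src/main.py`. The function names `nsfw_probability` and `is_acceptable`, the example scores and the threshold values are hypothetical and not part of the service.

```python
# Sub-category names as listed in src/main.py; the first four are NSFW.
SUB_CAT_NAMES = ['nsfw_cartoon', 'nsfw_nudity', 'nsfw_porn', 'nsfw_suggestive',
                 'safe_cartoon', 'safe_general', 'safe_person']


def nsfw_probability(sub_scores):
    """Sum the probabilities of all sub-categories whose name starts with 'nsfw_'."""
    return sum(p for name, p in sub_scores.items() if name.startswith('nsfw_'))


def is_acceptable(sub_scores, threshold=0.5):
    """Hypothetical policy: accept the image if its NSFW probability is below the threshold."""
    return nsfw_probability(sub_scores) < threshold


# Made-up scores for illustration only (they sum to 1, like a softmax output).
example = dict(zip(SUB_CAT_NAMES, [0.02, 0.05, 0.03, 0.30, 0.05, 0.15, 0.40]))
print(round(nsfw_probability(example), 2))
print(is_acceptable(example, threshold=0.5))
print(is_acceptable(example, threshold=0.2))
```

With these made-up scores, a permissive threshold of 0.5 accepts the image while a strict threshold of 0.2 rejects it, mostly because of the `nsfw_suggestive` score. A stricter service could also apply a separate threshold to a single sub-category.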


### Data set used to build the model

A dataset was assembled from existing NSFW image sets and completed with web-scraped data.
The dataset is available for research purposes; contact us if you would like access. Here
are some statistics about its content (numbers indicate the number of images). The dataset is balanced across
the sub-categories, which should help avoid biased classifications.

| categories | safe | | | nsfw | | | | total | | |
|----------------|---------|--------|---------|------------|--------|------|---------|---------|-------|-------|
| sub-categories | general | person | cartoon | suggestive | nudity | porn | cartoon | safe | nsfw | all |
| v2.2 | 5500 | 5500 | 5500 | 5500 | 5500 | 5500 | 5500 | 16500 | 22000 | 38500 |

### Model training and performance

We used transfer learning on MobileNetV2, which presents a good trade-off between performance and runtime efficiency.

| Set | Model | Whole | | Val | | Test | |
|------|---------------------------------------------------------|-------|-------|-------|-------|-------|-------|
| | | sa/ns | sub | sa/ns | sub | sa/ns | sub |
| V2.1 | TL_MNV2_finetune_224_B32_AD1E10-5_NSFW-V2.1_DA2.hdf5 | | | 95.7% | 85.1% | 95.7% | 86.1% |

In this table, performance is reported as accuracy on the safe vs. not-safe (sa/ns) main categories and
on the sub-categories (sub). The sub-category accuracy is indeed lower, as there is naturally more confusion between
some sub-categories and there are simply more classes to distinguish.
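
To make the relation between the two accuracy figures concrete, here is a minimal sketch that computes both from the same sub-category predictions. It assumes that the main category is the prefix of the sub-category name, as in the names used in `src/main.py`. The helpers `main_category` and `accuracy` and the label lists are made up for illustration and are not the evaluation code used for the table.

```python
def main_category(sub_category):
    """Map a sub-category such as 'nsfw_porn' to its main category ('nsfw' or 'safe')."""
    return sub_category.split('_', 1)[0]


def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


# Made-up labels for illustration only.
true_sub = ['nsfw_porn', 'nsfw_nudity', 'safe_person', 'safe_general', 'nsfw_suggestive']
pred_sub = ['nsfw_nudity', 'nsfw_nudity', 'safe_person', 'safe_person', 'safe_person']

sub_acc = accuracy(true_sub, pred_sub)
main_acc = accuracy([main_category(s) for s in true_sub],
                    [main_category(s) for s in pred_sub])
print(sub_acc, main_acc)
```

In this toy example, two of the three sub-category errors stay within the same main category (porn predicted as nudity, general predicted as person), so they do not count as errors at the sa/ns level. The main-category accuracy can therefore never be lower than the sub-category accuracy.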


## How to test the service locally

1. Create and activate the virtual environment:
```sh
python3.10 -m venv .venv
source .venv/bin/activate
```

2. Then install the dependencies:
```sh
pip install --requirement requirements.txt
pip install --requirement requirements-all.txt
```

3. Run a local instance of the Core AI Engine by following the installation
   instructions available here: https://docs.swiss-ai-center.ch/reference/core-engine/. Here are
   the steps:
   - Get the Core Engine code from here: https://github.com/swiss-ai-center/core-engine/tree/main
   - Backend: follow the instructions in the section `Start the service locally with Python`. In a first
     terminal, start the dependencies with `docker compose up`; in a second terminal, in the `src`
     sub-directory, start the application with `uvicorn --reload --port 8080 main:app`. The backend
     API should be visible in the browser.
   - This service: in a terminal, start the service with `cd src` and
     `uvicorn main:app --reload --host localhost --port 9090`. The service should register with the
     Core Engine backend and now be visible on the API page.
   - Frontend: in a terminal, follow the startup instructions (make sure Node.js and npm are
     installed).
4 changes: 4 additions & 0 deletions requirements.txt
@@ -1 +1,5 @@
common-code[test] @ git+https://github.com/swiss-ai-center/common-code.git@main
fastapi
imagehash
pip-chill
tensorflow
128 changes: 115 additions & 13 deletions src/main.py
@@ -20,25 +20,52 @@

# Imports required by the service's model
# TODO: 1. ADD REQUIRED IMPORTS (ALSO IN THE REQUIREMENTS.TXT)
import os
import io
from PIL import Image
import numpy as np
import tensorflow as tf


settings = get_settings()

class ImageInfo():
    filename = None
    img_type = None
    img_dim_x = None
    img_dim_y = None



class MyService(Service):
    # TODO: 2. CHANGE THIS DESCRIPTION
    """
    My service model
    Not Safe For Work (NSFW) image classification service.
    Caution: the current version of the service is able to detect nudity, sexual and hentai content.
    It is not able to detect profanity and violence for now.
    """

    # Any additional fields must be excluded for Pydantic to work
    model: object = Field(exclude=True)
    base_model: object = Field(exclude=True)
    nsfw_model: object = Field(exclude=True)
    logger: object = Field(exclude=True)

    # Some class attributes
    SUB_CAT_NAMES = ['nsfw_cartoon', 'nsfw_nudity', 'nsfw_porn', 'nsfw_suggestive',
                     'safe_cartoon', 'safe_general', 'safe_person']
    CAT_NAMES = ['nsfw', 'safe']
    IMG_SIZE = 224
    CHANNELS = 3
    N_CLASSES = len(SUB_CAT_NAMES)
    WEIGHT_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               'model/TL_MNV2_finetune_224_B32_AD1E10-5_NSFW-V2.1_DA2.hdf5')


    def __init__(self):
        super().__init__(
            # TODO: 3. CHANGE THE SERVICE NAME AND SLUG
            name="My Service",
            slug="my-service",
            name="NSFW Image Classification",
            slug="nsfw-image-classification",
            url=settings.service_url,
            summary=api_summary,
            description=api_description,
@@ -52,36 +79,111 @@ def __init__(self):
            ],
            tags=[
                ExecutionUnitTag(
                    name=ExecutionUnitTagName.IMAGE_PROCESSING,
                    acronym=ExecutionUnitTagAcronym.IMAGE_PROCESSING,
                    name=ExecutionUnitTagName.IMAGE_RECOGNITION,
                    acronym=ExecutionUnitTagAcronym.IMAGE_RECOGNITION,
                ),
            ],
            has_ai=False,
            has_ai=True,
        )
        self.logger = get_logger(settings)
        # read the ai model here
        self.logger.info("Loading the base model...")
        self.base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(
            include_top=False,
            weights='imagenet',
            input_shape=(self.IMG_SIZE, self.IMG_SIZE, self.CHANNELS))
        self.logger.info("Base model loaded. Recreating structure of model before loading fine-tuned weights...")
        self.nsfw_model = tf.keras.Sequential([
            self.base_model,
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(16),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(self.N_CLASSES),
            tf.keras.layers.Activation('softmax')
        ], name='MNV2')
        self.logger.info('Loading weights from file: {}'.format(self.WEIGHT_FILE))
        self.nsfw_model.load_weights(self.WEIGHT_FILE)
        self.logger.info('Weights loaded.')


    def build_score_list(self, scores, class_names):
        """
        Build a list of (class_name, score) tuples from a numpy array of
        float (scores) and a list of class names.
        :param scores: the numpy array of scores
        :param class_names: the list of class names to be associated to the scores
        :return: a list of (class_name, score) tuples
        """
        score_list = []
        for i, score in enumerate(scores):
            s = (class_names[i], score)  # each score is a tuple (category_name, score)
            score_list.append(s)
        return score_list


    def predict_from_image(self, image_tensor):
        """
        Compute the predicted classes from an image tensor by calling the model.predict()
        on that tensor. The method decides on the winning
        category by summing the scores on the range of sub-category scores. Then it takes
        the arg max to elect the winner of the categories and sub-categories.
        :param image_tensor: the image tensor from which to predict
        :return: a tuple with the winner category, the winner sub-category, the list of
        category scores and the list of sub-category scores
        """
        image_tensor = np.array([image_tensor])
        self.logger.info("Image tensor shape: {}".format(image_tensor.shape))
        pred_sub_cat = self.nsfw_model.predict(image_tensor, verbose=0)
        self.logger.info("Prediction shape: {}".format(pred_sub_cat.shape))
        self.logger.info("Prediction: {}".format(pred_sub_cat))
        pred_cat = np.zeros((1, 2))
        pred_cat[:, 0] = np.sum(pred_sub_cat[:, :4], axis=1)  # do the sum of nsfw sub-categories to compute nsfw pred
        pred_cat[:, 1] = np.sum(pred_sub_cat[:, 4:], axis=1)  # same thing for safe
        # in the end, the pred_cat is a similar output tensor as pred_sub_cat but on 2 main categories nsfw and safe
        # let's use the first prediction for now (disregarding the flipped image)
        scores_sub_cat = self.build_score_list(pred_sub_cat[0], self.SUB_CAT_NAMES)
        self.logger.info("Scores sub-cat: {}".format(scores_sub_cat))
        scores_cat = self.build_score_list(pred_cat[0], self.CAT_NAMES)
        self.logger.info("Scores cat: {}".format(scores_cat))
        winner_sub_cat = pred_sub_cat.argmax(axis=1)[0]
        winner_cat = pred_cat.argmax(axis=1)[0]
        # get the prediction as category and subcategory
        prediction_subcategory = self.SUB_CAT_NAMES[winner_sub_cat]
        prediction_category = self.CAT_NAMES[winner_cat]
        return prediction_category, prediction_subcategory, scores_cat, scores_sub_cat


    # TODO: 5. CHANGE THE PROCESS METHOD (CORE OF THE SERVICE)
    def process(self, data):
        # NOTE that the data is a dictionary with the keys being the field names set in the data_in_fields
        raw = data["image"].data
        input_type = data["image"].type
        # ... do something with the raw data
        buff = io.BytesIO(raw)
        image = Image.open(buff)
        image = image.resize((self.IMG_SIZE, self.IMG_SIZE), Image.LANCZOS)
        image_tensor = np.array(image)
        self.logger.info("Image shape: {}".format(image_tensor.shape))
        image_tensor = tf.keras.applications.mobilenet.preprocess_input(image_tensor)
        self.logger.info("Image shape after preprocessing: {}".format(image_tensor.shape))
        prediction_category, prediction_subcategory, scores_cat, scores_sub_cat = self.predict_from_image(image_tensor)

        # NOTE that the result must be a dictionary with the keys being the field names set in the data_out_fields
        return {
            "result": TaskData(
                data=...,
                data={'prediction_category': prediction_category, 'prediction_subcategory': prediction_subcategory},
                type=FieldDescriptionType.APPLICATION_JSON
            )
        }


# TODO: 6. CHANGE THE API DESCRIPTION AND SUMMARY
api_description = """My service
bla bla bla...
api_description = """
This service detects nudity, sexual and hentai content in images, or if the image is 'safe for work'.
"""
api_summary = """My service
bla bla bla...
api_summary = """
Detects between two main categories : 'nsfw' and 'safe', and detects the following sub-categories:
'nsfw_cartoon', 'nsfw_nudity', 'nsfw_porn', 'nsfw_suggestive', 'safe_cartoon', 'safe_general', 'safe_person'
"""

# Define the FastAPI application with information