Skip to content

Commit

Permalink
Merge pull request #1 from avbodas/abdev
Browse files Browse the repository at this point in the history
Code cleanup
  • Loading branch information
ttrigui committed May 29, 2024
2 parents 5e8bfe0 + a245751 commit 0e7b4aa
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 40 deletions.
13 changes: 7 additions & 6 deletions VideoRAGQnA/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ Visual RAG is a framework that retrives video based on provided user prompt. It

## Prerequisites

There are 10 example videos present in ```files/videos``` along with their description generated by open-source vision model.
There are 10 example videos present in ```video_ingest/videos``` along with their description generated by open-source vision model.
If you want these visual RAG to work on your own videos, make sure it matches below format.

## File Structure

```bash
files/
video_ingest/
.
├── scene_description
│ ├── op_10_0320241830.mp4.txt
Expand Down Expand Up @@ -52,7 +52,8 @@ files/
Install pip requirements

```bash
pip3 install -r VideoRAGQnA/requirements.txt
cd VideoRAGQnA
pip3 install -r docs/requirements.txt
```

The current framework supports both Chroma DB and Intel's VDMS, use either of them,
Expand All @@ -72,12 +73,12 @@ docker run -d -p 55555:55555 intellabs/vdms:latest

Update your choice of db and port in ```config.yaml```.

Generating Image embeddigns and store them into selected db, specify config file location and video input location
Generating Image embeddings and store them into selected db, specify config file location and video input location
```bash
python3 VideoRAGQnA/embedding/generate_store_embeddings.py VideoRAGQnA/docs/config.yaml VideoRAGQnA/video_ingest/videos/
python3 embedding/generate_store_embeddings.py docs/config.yaml video_ingest/videos/
```

**Web UI Video RAG**
```bash
streamlit run video-rag-ui.py --server.address 0.0.0.0 --server.port 50055
```
```
14 changes: 7 additions & 7 deletions VideoRAGQnA/docs/config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
# Path to all videos
videos: VideoRAGQnA/video_ingest/videos/
videos: video_ingest/videos/
# Path to video description generated by open-source vision models (ex. video-llama, video-llava, etc.)
description: VideoRAGQnA/video_ingest/scene_description/
description: video_ingest/scene_description/
# Do you want to extract frames of videos (True if not done already, else False)
generate_frames: True
# Do you wnat to generate image embeddings?
embed_frames: True
# Path to store extracted frames
image_output_dir: VideoRAGQnA/video_ingest/frames/
image_output_dir: video_ingest/frames/
# Path to store metadata files
meta_output_dir: VideoRAGQnA/video_ingest/frame_metadata/
meta_output_dir: video_ingest/frame_metadata/
# Number of frames to extract per second,
# if 24 fps, and this value is 2, then it will extract 12th and 24th frame
number_of_frames_per_second: 2

vector_db:
choice_of_db: 'vdms' #'chroma' # #Supported databases [vdms, chroma]
host: 10.190.167.193
port: 55556 #8000 #
host: 0.0.0.0
port: 55555 #8000 #

# LLM path
model_path: VideoRAGQnA/ckpt/llama-2-7b-chat-hf
model_path: meta-llama/Llama-2-7b-chat-hf
4 changes: 3 additions & 1 deletion VideoRAGQnA/docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ streamlit
metafunctions
sentence-transformers
accelerate
vdms
vdms
tzlocal
dateparser
6 changes: 4 additions & 2 deletions VideoRAGQnA/embedding/generate_store_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# Add the parent directory of the current script to the Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
VECTORDB_SERVICE_HOST_IP = os.getenv("VECTORDB_SERVICE_HOST_IP", "0.0.0.0")


# sys.path.append(os.path.abspath('../utils'))
Expand Down Expand Up @@ -150,16 +151,17 @@ def retrieval_testing():
meta_output_dir = config['meta_output_dir']
N = config['number_of_frames_per_second']

host = config['vector_db']['host']
host = VECTORDB_SERVICE_HOST_IP
port = int(config['vector_db']['port'])
selected_db = config['vector_db']['choice_of_db']

# Creating DB
print ('Creating DB with text and image embedding support, \nIt may take few minutes to download and load all required models if you are running for first time.')
print('Connect to {} at {}:{}'.format(selected_db, host, port))

vs = db.VS(host, port, selected_db)

generate_image_embeddings(selected_db)

retrieval_testing()

4 changes: 2 additions & 2 deletions VideoRAGQnA/utils/prompt_handler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from jinja2 import Environment, BaseLoader

PROMPT = open("VideoRAGQnA/utils/prompt_template.jinja2").read().strip()
PROMPT = open("utils/prompt_template.jinja2").read().strip()

def get_formatted_prompt(scene, prompt):
env = Environment(loader=BaseLoader())
template = env.from_string(PROMPT)
return template.render(scene=scene, prompt=prompt)
return template.render(scene=scene, prompt=prompt)
34 changes: 12 additions & 22 deletions VideoRAGQnA/video-rag-ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,25 @@
from embedding.vector_stores import db
import time
import torch
import streamlit as st

import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers import set_seed

from transformers import TextIteratorStreamer
from typing import Any, List, Mapping, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
import threading
from transformers import set_seed
from utils import config_reader as reader
from utils import prompt_handler as ph
# from vector_stores import db
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN", "")

set_seed(22)

if 'config' not in st.session_state.keys():
st.session_state.config = reader.read_config('VideoRAGQnA/docs/config.yaml')
st.session_state.config = reader.read_config('docs/config.yaml')

config = st.session_state.config

Expand Down Expand Up @@ -51,11 +50,12 @@

@st.cache_resource
def load_models():
#print("HF Token: ", HUGGINGFACEHUB_API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.float32, device_map='auto', trust_remote_code=True,
model_path, torch_dtype=torch.float32, device_map='auto', trust_remote_code=True, token=HUGGINGFACEHUB_API_TOKEN
)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, token=HUGGINGFACEHUB_API_TOKEN)
tokenizer.padding_size = 'right'
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

Expand Down Expand Up @@ -248,22 +248,12 @@ def display_messages():
'Find similar videos',
'Man wearing glasses',
'People reading item description',
'Man wearing khaki pants',
'Man laughing',
'Black tshirt guy holding red basket',
'Man holding red shopping basket',
'Man wearing blue shirt',
'Man putting object into his pocket',
'Was there any shoplifting reported?',
'Was there any shoplifting reported today?',
'Was there any shoplifting reported in the last 6 hours?',
'Was there any shoplifting reported last Sunday?',
'Was there any shoplifting reported last Monday?',
'Have there been instances of shoplifting?',
'Have there been instances of shoplifting last Friday?',
'Have there been any instances of theft or shoplifting in the last 30 minutes?',
'Have there been any instances of theft or shoplifting in the last 48 hours?',
'Have there been any instances of theft or shoplifting in the last 72 hours?',
'Was there any person wearing a blue shirt seen today?',
'Was there any person wearing a blue shirt seen in the last 6 hours?',
'Was there any person wearing a blue shirt seen last Sunday?',
'Was a person wearing glasses seen in the last 30 minutes?',
'Was a person wearing glasses seen in the last 72 hours?',
),
key='example_video'
)
Expand All @@ -290,4 +280,4 @@ def display_messages():

with col1:
display_messages()
handle_message()
handle_message()

0 comments on commit 0e7b4aa

Please sign in to comment.