Skip to content

Commit

Permalink
feat: Allow configuring device (gpu/cpu) via env vars
Browse files Browse the repository at this point in the history
  • Loading branch information
aecio committed Jul 16, 2024
1 parent df32f75 commit a2246f7
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 6 deletions.
12 changes: 12 additions & 0 deletions bdikit/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os
import torch


BDIKIT_DEVICE: str = os.getenv("BDIKIT_DEVICE", default="cpu")


def get_device() -> str:
if BDIKIT_DEVICE == "auto":
return "cuda" if torch.cuda.is_available() else "cpu"
else:
return BDIKIT_DEVICE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from typing import List, Dict, Tuple, Optional

from bdikit.config import get_device
import numpy as np
import pandas as pd
import torch
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(

self.unlabeled = PretrainTableDataset()
self.batch_size = batch_size
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = get_device()
self.model = self.load_checkpoint()

def load_checkpoint(self, lm: str = "roberta"):
Expand Down
5 changes: 5 additions & 0 deletions bdikit/mapping_algorithms/value_mapping/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from autofj import AutoFJ
from Levenshtein import ratio
import pandas as pd
import flair
import torch
from bdikit.config import get_device

flair.device = torch.device(get_device())


class ValueMatch(NamedTuple):
Expand Down
5 changes: 1 addition & 4 deletions examples/getting-started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
"outputs": [],
"source": [
"import bdikit as bdi\n",
"import pandas as pd\n",
"\n",
"import flair, torch\n",
"flair.device = torch.device(\"cpu\") "
"import pandas as pd"
]
},
{
Expand Down

0 comments on commit a2246f7

Please sign in to comment.