Skip to content

Commit

Permalink
[RUNTIME] Fix cache dir (#2196)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Keren Zhou <[email protected]>
  • Loading branch information
jon-chuang and Jokeren authored Aug 30, 2023
1 parent 2ff88c1 commit 9af76e7
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions python/triton/runtime/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,20 @@ def __init__(self, key):
self.key = key
self.lock_path = None
# create cache directory if it doesn't exist
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
else:
raise RuntimeError("Could not create or locate cache dir")

def _make_path(self, filename) -> str:
return os.path.join(self.cache_dir, filename)

def has_file(self, filename):
def has_file(self, filename) -> bool:
if not self.cache_dir:
return False
raise RuntimeError("Could not create or locate cache dir")
return os.path.exists(self._make_path(filename))

def get_file(self, filename) -> Optional[str]:
Expand Down Expand Up @@ -80,16 +82,16 @@ def get_group(self, filename: str) -> Optional[Dict[str, str]]:
return result

# Note a group of pushed files as being part of a group
def put_group(self, filename: str, group: Dict[str, str]):
def put_group(self, filename: str, group: Dict[str, str]) -> str:
if not self.cache_dir:
return
raise RuntimeError("Could not create or locate cache dir")
grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
grp_filename = f"__grp__{filename}"
return self.put(grp_contents, grp_filename, binary=False)

def put(self, data, filename, binary=True) -> str:
if not self.cache_dir:
return
raise RuntimeError("Could not create or locate cache dir")
binary = isinstance(data, bytes)
if not binary:
data = str(data)
Expand Down

0 comments on commit 9af76e7

Please sign in to comment.