Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nemo_curator import in CPU only environment when GPU packages are installed. #123

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo_curator/utils/gpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

GPU_INSTALL_STRING = """Install GPU packages via `pip install --extra-index-url https://pypi.nvidia.com nemo_curator[cuda12x]`
GPU_INSTALL_STRING = """Install GPU packages via `pip install --extra-index-url https://pypi.nvidia.com nemo-curator[cuda12x]`
or use `pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"` if installing from source"""


Expand Down
10 changes: 5 additions & 5 deletions nemo_curator/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def safe_import(module, *, msg=None, alt=None):
"""A function used to import modules that may not be available

This function will attempt to import a module with the given name, but it
will not throw an ModuleNotFoundError if the module is not found. Instead, it will
will not throw an ImportError if the module is not found. Instead, it will
return a placeholder object which will raise an exception only if used.

Parameters
Expand All @@ -259,7 +259,7 @@ def safe_import(module, *, msg=None, alt=None):
"""
try:
return importlib.import_module(module)
except ModuleNotFoundError:
except ImportError:
exception_text = traceback.format_exc()
logger.debug(f"Import of {module} failed with: {exception_text}")
except Exception:
Expand Down Expand Up @@ -303,7 +303,7 @@ def safe_import_from(module, symbol, *, msg=None, alt=None):
try:
imported_module = importlib.import_module(module)
return getattr(imported_module, symbol)
except ModuleNotFoundError:
except ImportError:
exception_text = traceback.format_exc()
logger.debug(f"Import of {module} failed with: {exception_text}")
except AttributeError:
Expand Down Expand Up @@ -346,7 +346,7 @@ def gpu_only_import(module, *, alt=None):

return safe_import(
module,
msg=f"{module} is not installed in non GPU-enabled installations. {GPU_INSTALL_STRING}",
msg=f"{module} is not enabled in non GPU-enabled installations or environemnts. {GPU_INSTALL_STRING}",
alt=alt,
)

Expand Down Expand Up @@ -379,6 +379,6 @@ def gpu_only_import_from(module, symbol, *, alt=None):
return safe_import_from(
module,
symbol,
msg=f"{module}.{symbol} is not installed in non GPU-enabled installations. {GPU_INSTALL_STRING}",
msg=f"{module}.{symbol} is not enabled in non GPU-enabled installations or environments. {GPU_INSTALL_STRING}",
alt=alt,
)