Allows CPU-based execution #235

Open · wants to merge 1 commit into main
4 changes: 4 additions & 0 deletions requirements-cpu.txt
@@ -0,0 +1,4 @@
dm_haiku==0.0.12
jax==0.4.25
numpy==1.26.4
sentencepiece==0.2.0
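
As a quick post-install sanity check for these pinned CPU-only dependencies, something like the following snippet can confirm that the expected versions resolved and that JAX is using the CPU backend. This is an illustrative sketch, not part of the PR; the version numbers are simply the pins above.

# Illustrative check that the requirements-cpu.txt pins were installed.
import haiku, jax, numpy, sentencepiece

print(haiku.__version__)          # 0.0.12 per the pin above
print(jax.__version__)            # 0.4.25
print(numpy.__version__)          # 1.26.4
print(sentencepiece.__version__)  # 0.2.0
print(jax.default_backend())      # "cpu" when no GPU/TPU backend is present
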
24 changes: 23 additions & 1 deletion run.py
@@ -12,11 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import logging
+import logging, os

from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model

# Fall back to CPU execution if fewer than 8 GPUs are available
# ONLY MEANT FOR DEVELOPERS WITH 384GB RAM
# CURRENTLY TOO SLOW FOR MEANINGFUL INFERENCE WORKLOADS
#
# Set True to run model on CPU only
USE_CPU_ONLY = False

if USE_CPU_ONLY:
    # Simulate 8 devices via CPUs
    xla_flags = os.environ.get("XLA_FLAGS", "")
    xla_flags += " --xla_force_host_platform_device_count=8"
    os.environ["XLA_FLAGS"] = xla_flags
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # Suppress warnings about unused backends
    logging.getLogger("jax._src.xla_bridge").addFilter(logging.Filter("Unable to initialize backend"))
    # Suppress false warnings about stuck processes
    logging.getLogger("collective_ops_utils").addFilter(logging.Filter("This thread has been waiting for"))
    logging.getLogger("collective_ops_utils").addFilter(logging.Filter("Thread is unstuck"))
    # Suppress warnings about slow compiling
    logging.getLogger("slow_operation_alarm").addFilter(logging.Filter("Very slow compile"))


CKPT_PATH = "./checkpoints/"

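
The key trick in this diff is the --xla_force_host_platform_device_count flag, which makes XLA expose the host CPU as several devices so code written for an 8-device mesh can still place its shards. A minimal standalone sketch of the effect, assuming the CPU-only JAX install from requirements-cpu.txt (this is not part of the PR itself):

# Set the flags before JAX initializes its backends (here, before the first jax call).
import os
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # keep any local GPUs out of the picture

import jax
print(jax.device_count())  # 8
print(jax.devices())       # eight CPU devices, id 0 through 7

With USE_CPU_ONLY = True, run.py sets the same environment variables before any JAX computation runs, so the 8-way sharded checkpoint can be mapped onto these simulated CPU devices; as the comments in the diff warn, this is currently far too slow for meaningful inference.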