Skip to content

Commit 4c155bd

Browse files
dan-zheng and copybara-github
authored and committed
Restore reverted changes.
Sync to 84444c9. PiperOrigin-RevId: 610263918
1 parent 6a30858 commit 4c155bd

File tree

7 files changed

+196
-21
lines changed

7 files changed

+196
-21
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: build
22

3-
# Trigger on push or via manual dispatch.
4-
on: [push, workflow_dispatch]
3+
# Trigger on push, pull request, or via manual dispatch.
4+
on: [push, pull_request, workflow_dispatch]
55

66
jobs:
77
build:

CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
2424
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
2525
FetchContent_MakeAvailable(highway)
2626

27-
## Note: absl meeds tp be installed by sentencepiece. This will only happen if
27+
## Note: absl needs to be installed by sentencepiece. This will only happen if
2828
## cmake is invoked with -DSPM_ENABLE_SHARED=OFF and -DSPM_ABSL_PROVIDER=module
2929
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
3030
FetchContent_MakeAvailable(sentencepiece)
@@ -49,7 +49,7 @@ endif()
4949

5050
# Allowable types for WEIGHT_TYPE:
5151
# float - slow, not recommended
52-
# hwy::bfloat16_t - bfloat16 as impemented by https://github.com/google/highway
52+
# hwy::bfloat16_t - bfloat16 as implemented by https://github.com/google/highway
5353
# SfpStream - 8-bit switched floating point (recommended)
5454
# NuqStream - experimental, work-in-progress
5555
option(WEIGHT_TYPE "Set weight type" "")
@@ -67,6 +67,8 @@ target_link_libraries(gemma hwy hwy_contrib sentencepiece)
6767
target_include_directories(gemma PRIVATE ./)
6868
FetchContent_GetProperties(sentencepiece)
6969
target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR})
70+
target_compile_definitions(gemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
71+
target_compile_options(gemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
7072

7173
## Library Target
7274

@@ -76,3 +78,5 @@ set_target_properties(libgemma PROPERTIES PREFIX "")
7678
target_include_directories(libgemma PUBLIC ./)
7779
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
7880
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})
81+
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
82+
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)

CMakePresets.json

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
{
2+
"version": 3,
3+
"cmakeMinimumRequired": {
4+
"major": 3,
5+
"minor": 11,
6+
"patch": 0
7+
},
8+
"configurePresets": [
9+
{
10+
"name": "__defaults__",
11+
"hidden": true,
12+
"binaryDir": "${sourceDir}/build"
13+
},
14+
{
15+
"name": "make",
16+
"inherits": "__defaults__",
17+
"displayName": "Make",
18+
"description": "Unix Makefiles",
19+
"generator": "Unix Makefiles",
20+
"binaryDir": "${sourceDir}/build"
21+
},
22+
{
23+
"name": "windows",
24+
"inherits": "__defaults__",
25+
"displayName": "Windows",
26+
"description": "Visual Studio 2022 with Clang/LLVM frontend",
27+
"generator": "Visual Studio 17 2022",
28+
"toolset": "ClangCL",
29+
"condition": {
30+
"type": "equals",
31+
"lhs": "${hostSystemName}",
32+
"rhs": "Windows"
33+
}
34+
}
35+
],
36+
"buildPresets": [
37+
{
38+
"name": "__defaults__",
39+
"hidden": true,
40+
"targets": [
41+
"gemma",
42+
"libgemma"
43+
]
44+
},
45+
{
46+
"name": "make",
47+
"inherits": "__defaults__",
48+
"displayName": "Unix Makefiles",
49+
"configurePreset": "make"
50+
},
51+
{
52+
"name": "windows",
53+
"inherits": "__defaults__",
54+
"displayName": "Windows",
55+
"configuration": "Release",
56+
"configurePreset": "windows"
57+
}
58+
]
59+
}

README.md

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ Before starting, you should have installed:
5555
least C++17.
5656
- `tar` for extracting archives from Kaggle.
5757

58+
Building natively on Windows requires the Visual Studio 2022 Build Tools with the
59+
optional Clang/LLVM C++ frontend (`clang-cl`). This can be installed from the
60+
command line with
61+
[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/):
62+
63+
```sh
64+
winget install --id Kitware.CMake
65+
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
66+
```
67+
5868
### Step 1: Obtain model weights and tokenizer from Kaggle
5969

6070
Visit [the Gemma model page on
@@ -107,6 +117,7 @@ runtime, create a build directory and generate the build files using `cmake`
107117
from the top-level project directory. For the 8-bit switched floating point
108118
weights (sfp), run cmake with no options:
109119

120+
#### Unix-like Platforms
110121
```sh
111122
cmake -B build
112123
```
@@ -126,17 +137,18 @@ your weights, you can enter the `build/` directory and run `make` to build the
126137
`./gemma` executable:
127138

128139
```sh
129-
cd build
130-
make -j [number of parallel threads to use] gemma
140+
# Configure `build` directory
141+
cmake --preset make
142+
143+
# Build project using make
144+
cmake --build --preset make -j [number of parallel threads to use]
131145
```
132146

133147
Replace `[number of parallel threads to use]` with a number - the number of
134-
cores available on your system is a reasonable heuristic.
135-
136-
For example, `make -j4 gemma` will build using 4 threads. If this is successful,
137-
you should now have a `gemma` executable in the `build/` directory. If the
138-
`nproc` command is available, you can use `make -j$(nproc) gemma` as a
139-
reasonable default for the number of threads.
148+
cores available on your system is a reasonable heuristic. For example,
149+
`make -j4 gemma` will build using 4 threads. If the `nproc` command is
150+
available, you can use `make -j$(nproc) gemma` as a reasonable default
151+
for the number of threads.
140152

141153
If you aren't sure of the right value for the `-j` flag, you can simply run
142154
`make gemma` instead and it should still build the `./gemma` executable.
@@ -145,6 +157,20 @@ If you aren't sure of the right value for the `-j` flag, you can simply run
145157
> Users on Windows Subsystem for Linux (WSL) should set the number of
146158
> parallel threads to 1. Using a larger number may result in errors.
147159
160+
If the build is successful, you should now have a `gemma` executable in the `build/` directory.
161+
162+
#### Windows
163+
164+
```sh
165+
# Configure `build` directory
166+
cmake --preset windows
167+
168+
# Build project using Visual Studio Build Tools
169+
cmake --build --preset windows -j [number of parallel threads to use]
170+
```
171+
172+
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
173+
148174
### Step 4: Run
149175

150176
You can now run `gemma` from inside the `build/` directory.

compression/blob_store.cc

Lines changed: 88 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@
1616
// copybara:import_next_line:gemma_cpp
1717
#include "compression/blob_store.h"
1818

19-
#include <fcntl.h> // open
2019
#include <stdint.h>
2120
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
2221
#include <sys/stat.h> // O_RDONLY
23-
#include <unistd.h> // read, close
22+
#include <fcntl.h> // open
23+
#if HWY_OS_WIN
24+
#include <io.h> // read, write, close
25+
#include <fileapi.h>
26+
#else
27+
#include <unistd.h> // read, write, close
28+
#endif
2429

2530
#include <atomic>
2631
#include <vector>
@@ -30,6 +35,54 @@
3035
#include "hwy/contrib/thread_pool/thread_pool.h"
3136
#include "hwy/detect_compiler_arch.h"
3237

38+
namespace {
39+
#if HWY_OS_WIN
40+
41+
// pread is not supported on Windows
42+
static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
43+
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
44+
if (file == INVALID_HANDLE_VALUE) {
45+
return -1;
46+
}
47+
48+
OVERLAPPED overlapped = {0};
49+
overlapped.Offset = offset & 0xFFFFFFFF;
50+
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
51+
52+
DWORD bytes_read;
53+
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
54+
if (GetLastError() != ERROR_HANDLE_EOF) {
55+
return -1;
56+
}
57+
}
58+
59+
return bytes_read;
60+
}
61+
62+
// pwrite is not supported on Windows
63+
static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) {
64+
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
65+
if (file == INVALID_HANDLE_VALUE) {
66+
return -1;
67+
}
68+
69+
OVERLAPPED overlapped = {0};
70+
overlapped.Offset = offset & 0xFFFFFFFF;
71+
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
72+
73+
DWORD bytes_written;
74+
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
75+
if (GetLastError() != ERROR_HANDLE_EOF) {
76+
return -1;
77+
}
78+
}
79+
80+
return bytes_written;
81+
}
82+
83+
#endif
84+
}
85+
3386
namespace gcpp {
3487

3588
hwy::uint128_t MakeKey(const char* string) {
@@ -64,19 +117,30 @@ static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
64117
}
65118
}
66119

120+
67121
struct IO {
68122
// Returns size in bytes or 0.
69123
static uint64_t FileSize(const char* filename) {
70124
int fd = open(filename, O_RDONLY);
71-
if (fd >= 0) {
72-
const off_t size = lseek(fd, 0, SEEK_END);
73-
HWY_ASSERT(close(fd) != -1);
74-
if (size != static_cast<off_t>(-1)) {
75-
return static_cast<uint64_t>(size);
76-
}
125+
if (fd < 0) {
126+
return 0;
77127
}
78128

79-
return 0;
129+
#if HWY_OS_WIN
130+
const int64_t size = _lseeki64(fd, 0, SEEK_END);
131+
HWY_ASSERT(close(fd) != -1);
132+
if (size < 0) {
133+
return 0;
134+
}
135+
#else
136+
const off_t size = lseek(fd, 0, SEEK_END);
137+
HWY_ASSERT(close(fd) != -1);
138+
if (size == static_cast<off_t>(-1)) {
139+
return 0;
140+
}
141+
#endif
142+
143+
return static_cast<uint64_t>(size);
80144
}
81145

82146
static bool Read(int fd, uint64_t offset, uint64_t size, void* to) {
@@ -252,7 +316,14 @@ class BlobStore {
252316
#pragma pack(pop)
253317

254318
BlobError BlobReader::Open(const char* filename) {
319+
#if HWY_OS_WIN
320+
DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN;
321+
HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, flags, nullptr);
322+
if (file == INVALID_HANDLE_VALUE) return __LINE__;
323+
fd_ = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_RDONLY);
324+
#else
255325
fd_ = open(filename, O_RDONLY);
326+
#endif
256327
if (fd_ < 0) return __LINE__;
257328

258329
#if _POSIX_C_SOURCE >= 200112L
@@ -330,7 +401,14 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
330401
keys_.data(), blobs_.data(), keys_.size());
331402

332403
// Create/replace existing file.
404+
#if HWY_OS_WIN
405+
DWORD flags = FILE_ATTRIBUTE_NORMAL;
406+
HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS, flags, nullptr);
407+
if (file == INVALID_HANDLE_VALUE) return __LINE__;
408+
const int fd = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_WRONLY);
409+
#else
333410
const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644);
411+
#endif
334412
if (fd < 0) return __LINE__;
335413

336414
std::atomic_flag err = ATOMIC_FLAG_INIT;
@@ -341,6 +419,7 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
341419
err.test_and_set();
342420
}
343421
});
422+
HWY_ASSERT(close(fd) != -1);
344423
if (err.test_and_set()) return __LINE__;
345424
return 0;
346425
}

run.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
144144
return;
145145
}
146146

147+
if (prompt_string == "%c" || prompt_string == "%C") {
148+
abs_pos = 0;
149+
continue;
150+
}
151+
147152
if (model.model_training == ModelTraining::GEMMA_IT) {
148153
// For instruction-tuned models: add control tokens.
149154
prompt_string = "<start_of_turn>user\n" + prompt_string +

util/app.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
1919
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
2020

21+
#if HWY_OS_LINUX
2122
#include <sched.h>
23+
#endif
2224
#include <stddef.h>
2325

2426
#include <algorithm> // std::clamp

0 commit comments

Comments
 (0)