Skip to content

Commit

Permalink
0
Browse files Browse the repository at this point in the history
  • Loading branch information
wjy030522 committed Dec 18, 2024
1 parent ccbdba9 commit 827e416
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 19 deletions.
40 changes: 38 additions & 2 deletions exercises/22_class_template/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ struct Tensor4D {
Tensor4D(unsigned int const shape_[4], T const *data_) {
unsigned int size = 1;
// TODO: 填入正确的 shape 并计算 size
std::memcpy(shape, shape_, sizeof(shape)); // 保存 shape
for (int i = 0; i < 4; ++i) {
size *= shape[i]; // 计算 size
}
data = new T[size];
std::memcpy(data, data_, size * sizeof(T));
}
Expand All @@ -27,9 +31,41 @@ struct Tensor4D {
// 例如,`this` 形状为 `[1, 2, 3, 4]`,`others` 形状为 `[1, 2, 1, 4]`,
// 则 `this` 与 `others` 相加时,3 个形状为 `[1, 2, 1, 4]` 的子张量各自与 `others` 对应项相加。
Tensor4D &operator+=(Tensor4D const &others) {
// TODO: 实现单向广播的加法
return *this;
// 检查维度和形状是否符合广播规则
for (int i = 0; i < 4; ++i) {
if (shape[i] != others.shape[i] && others.shape[i] != 1) {
throw std::invalid_argument("Shapes are not broadcastable.");
}
}

// 计算总元素数量
unsigned int size = 1;
for (int i = 0; i < 4; ++i) {
size *= shape[i];
}

// 执行加法运算(支持广播)
for (unsigned int idx = 0; idx < size; ++idx) {
unsigned int this_idx = idx;
unsigned int other_idx = 0;
unsigned int other_stride = 1;

for (int dim = 3; dim >= 0; --dim) {
unsigned int this_dim_index = (this_idx % shape[dim]);
unsigned int other_dim_index = (others.shape[dim] == 1) ? 0 : this_dim_index;
other_idx += other_dim_index * other_stride;
this_idx /= shape[dim];
if (others.shape[dim] != 1) {
other_stride *= others.shape[dim];
}
}

data[idx] += others.data[other_idx];
}

return *this;
}

};

// ---- 不要修改以下代码 ----
Expand Down
12 changes: 9 additions & 3 deletions exercises/23_template_const/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ struct Tensor {
Tensor(unsigned int const shape_[N]) {
unsigned int size = 1;
// TODO: 填入正确的 shape 并计算 size
std::memcpy(shape, shape_, sizeof(shape));
for (unsigned int i = 0; i < N; ++i) {
size *= shape[i]; // 计算总元素数量
}
data = new T[size];
std::memset(data, 0, size * sizeof(T));
}
Expand All @@ -32,9 +36,11 @@ struct Tensor {
private:
unsigned int data_index(unsigned int const indices[N]) const {
unsigned int index = 0;
for (unsigned int i = 0; i < N; ++i) {
ASSERT(indices[i] < shape[i], "Invalid index");
// TODO: 计算 index
unsigned int stride = 1;
for (int i = N - 1; i >= 0; --i) {
ASSERT(indices[i] < shape[i], "Invalid index"); // 检查索引范围
index += indices[i] * stride;
stride *= shape[i]; // 更新 stride
}
return index;
}
Expand Down
11 changes: 8 additions & 3 deletions exercises/27_strides/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@ using udim = unsigned int;
/// @return 张量每维度的访问步长
std::vector<udim> strides(std::vector<udim> const &shape) {
std::vector<udim> strides(shape.size());
// TODO: 完成函数体,根据张量形状计算张量连续存储时的步长。
// READ: 逆向迭代器 std::vector::rbegin <https://zh.cppreference.com/w/cpp/container/vector/rbegin>
// 使用逆向迭代器可能可以简化代码
if (shape.empty()) return strides;

// 计算步长,从最后一维开始
strides.back() = 1; // 最后一维的步长为 1
for (int i = shape.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1]; // 递推计算步长
}
return strides;
}


// ---- 不要修改以下代码 ----
int main(int argc, char **argv) {
ASSERT((strides({2, 3, 4}) == std::vector<udim>{12, 4, 1}), "Make this assertion pass.");
Expand Down
4 changes: 2 additions & 2 deletions exercises/30_std_unique_ptr/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ int main(int argc, char **argv) {
std::vector<const char *> answers[]{
{"fd"},
// TODO: 分析 problems[1] 中资源的生命周期,将记录填入 `std::vector`
{"rffd"},
{"rd", "rd"},
{"ffr","d"},
{"r","d","d"},
};

// ---- 不要修改以下代码 ----
Expand Down
18 changes: 9 additions & 9 deletions exercises/31_std_shared_ptr/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,36 @@ int main(int argc, char **argv) {
std::shared_ptr<int> ptrs[]{shared, shared, shared};

std::weak_ptr<int> observer = shared;
ASSERT(observer.use_count() == ?, "");
ASSERT(observer.use_count() == 4, "");

ptrs[0].reset();
ASSERT(observer.use_count() == ?, "");
ASSERT(observer.use_count() == 3, "");

ptrs[1] = nullptr;
ASSERT(observer.use_count() == ?, "");
ASSERT(observer.use_count() == 2, "");

ptrs[2] = std::make_shared<int>(*shared);
ASSERT(observer.use_count() == ?, "");
ASSERT(observer.use_count() == 1, "");

ptrs[0] = shared;
ptrs[1] = shared;
ptrs[2] = std::move(shared);
ASSERT(observer.use_count() == ?, "");
ASSERT(observer.use_count() == 3, "");

std::ignore = std::move(ptrs[0]);
ptrs[1] = std::move(ptrs[1]);
ptrs[1] = std::move(ptrs[2]);
ASSERT(observer.use_count() == ?, "");
ASSERT(observer.use_count() == 2, "");

shared = observer.lock();
ASSERT(observer.use_count() == ?, "");
ASSERT(observer.use_count() == 3, "");

shared = nullptr;
for (auto &ptr : ptrs) ptr = nullptr;
ASSERT(observer.use_count() == ?, "");
ASSERT(observer.use_count() == 0, "");

shared = observer.lock();
ASSERT(observer.use_count() == ?, "");
ASSERT(observer.use_count() == 0, "");

return 0;
}
2 changes: 2 additions & 0 deletions exercises/32_std_transform/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ int main(int argc, char **argv) {
std::vector<int> val{8, 13, 21, 34, 55};
// TODO: 调用 `std::transform`,将 `v` 中的每个元素乘以 2,并转换为字符串,存入 `ans`
// std::vector<std::string> ans
std::vector<std::string> ans(val.size());
std::transform(val.begin(), val.end(), ans.begin(), [](int x) { return std::to_string(x * 2); });
ASSERT(ans.size() == val.size(), "ans size should be equal to val size");
ASSERT(ans[0] == "16", "ans[0] should be 16");
ASSERT(ans[1] == "26", "ans[1] should be 26");
Expand Down
4 changes: 4 additions & 0 deletions exercises/33_std_accumulate/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ int main(int argc, char **argv) {
// - 连续存储;
// 的张量占用的字节数
// int size =
int size = std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int>());

// 计算字节数
size *= sizeof(DataType);
ASSERT(size == 602112, "4x1x3x224x224 = 602112");
return 0;
}

0 comments on commit 827e416

Please sign in to comment.