Merge pull request #1 from Orca-bit/develop
rename file
Orca-bit authored Apr 27, 2024
2 parents 5ebfb9c + 96480d0 commit 6048ea9
Showing 2 changed files with 24 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/SUMMARY.md
@@ -1,3 +1,3 @@
# Summary

- [tf optimizer](./chapter_1.md)
- [TF optimizer](./optimizer.md)
25 changes: 23 additions & 2 deletions src/chapter_1.md → src/optimizer.md
@@ -1,7 +1,26 @@
# tf optimizer
# TF optimizer

In `tensorflow`, optimizer classes inherit from `optimizer.Optimizer`:
```python
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    ...
```
The default implementation of this method does not depend on the optimizer type; subclasses can override the following methods to change the concrete behavior (a sketch follows the code block below):
```python
def _create_slots(self, var_list):
    ...
def _prepare(self):
    ...
def _apply_dense(self, grad, var):
    ...
def _apply_sparse(self, grad, var):
    ...
```
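
As an illustration of these hooks, here is a minimal sketch of a custom momentum optimizer written against the TF1 `tf.compat.v1.train.Optimizer` API. The class name `MomentumSketchOptimizer` and its hyperparameters are hypothetical, not part of TensorFlow:

```python
import tensorflow.compat.v1 as tf

class MomentumSketchOptimizer(tf.train.Optimizer):
    """Hypothetical momentum optimizer showing the override points."""

    def __init__(self, learning_rate=0.01, momentum=0.9,
                 use_locking=False, name="MomentumSketch"):
        super().__init__(use_locking, name)
        self._lr = learning_rate
        self._momentum = momentum

    def _create_slots(self, var_list):
        # One accumulator slot per variable, initialized to zeros.
        for v in var_list:
            self._zeros_slot(v, "accum", self._name)

    def _prepare(self):
        # Convert Python hyperparameters to tensors once per apply.
        self._lr_t = tf.convert_to_tensor(self._lr, name="learning_rate")
        self._momentum_t = tf.convert_to_tensor(self._momentum, name="momentum")

    def _apply_dense(self, grad, var):
        lr = tf.cast(self._lr_t, var.dtype.base_dtype)
        mu = tf.cast(self._momentum_t, var.dtype.base_dtype)
        accum = self.get_slot(var, "accum")
        # accum <- mu * accum + grad; var <- var - lr * accum
        accum_t = tf.assign(accum, mu * accum + grad,
                            use_locking=self._use_locking)
        return tf.assign_sub(var, lr * accum_t,
                             use_locking=self._use_locking)
```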

## ADAM optimizer

Adam - A Method for Stochastic Optimization: [Kingma et al., 2015](https://arxiv.org/abs/1412.6980) ([pdf](https://arxiv.org/pdf/1412.6980.pdf))

Initialization

$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
@@ -17,10 +36,12 @@
$$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
$$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$
$$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilon)$$
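
To make the update concrete, here is a minimal NumPy sketch of one Adam step. The function name `adam_step` and its defaults are illustrative; the `lr_t` bias correction follows section 2 of the paper, whose corresponding lines are collapsed in the diff above:

```python
import numpy as np

def adam_step(var, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on a NumPy array, mirroring the formulas above."""
    t += 1
    m = beta1 * m + (1 - beta1) * grad          # biased 1st moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2     # biased 2nd moment estimate
    lr_t = lr * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)  # bias-corrected step size
    var = var - lr_t * m / (np.sqrt(v) + eps)
    return var, m, v, t

# Usage: minimize sum(var**2); var converges toward zero.
var, m, v, t = np.array([1.0, 2.0]), 0.0, 0.0, 0
for _ in range(1000):
    grad = 2 * var
    var, m, v, t = adam_step(var, grad, m, v, t, lr=0.1)
```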

`tensorflow` provides an implementation of the `Adam` optimizer; the core computation is as follows:

[tensorflow 源码](https://github.com/tensorflow/tensorflow/blob/80b1605dbc7ac2f475dff03b13d7efcf295d35c4/tensorflow/python/training/adam.py#L247)
```python
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    # main computation logic
beta1_power, beta2_power = self._get_beta_accumulators()
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
@@ -43,4 +64,4 @@ def _apply_sparse_shared(self, grad, var, indices, scatter_add):
var_update = state_ops.assign_sub(
var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, m_t, v_t])
```
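
For context, a minimal (hypothetical) usage sketch in TF1 graph mode: `apply_gradients` is what ultimately invokes `_create_slots`, `_prepare`, and the `_apply_*` methods shown above:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Toy quadratic objective; Adam should drive x toward zero.
x = tf.Variable([1.0, 2.0])
loss = tf.reduce_sum(tf.square(x))

opt = tf.train.AdamOptimizer(learning_rate=0.1)
grads_and_vars = opt.compute_gradients(loss, var_list=[x])
train_op = opt.apply_gradients(grads_and_vars)  # dispatches to _apply_dense/_apply_sparse

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(100):
        sess.run(train_op)
    print(sess.run(x))  # close to [0., 0.]
```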
