update char_cnn and fasttext #66

Open

wants to merge 7 commits into base: master
Changes from 4 commits
1 change: 1 addition & 0 deletions models/char_cnn/README.md
@@ -28,6 +28,7 @@ We experiment the model on the following datasets.
- Reuters (ModApte)
- AAPD

**It should be noted** that the version that follows the implementation of [Char-CNN (2015)](http://papers.nips.cc/paper/5782-character-level-convolutional-networks-for-text-classification.pdf) produces a dev F1 of 0 on Reuters.
## Settings

Adam is used for training.
10 changes: 10 additions & 0 deletions models/char_cnn/args.py
@@ -14,6 +14,16 @@ def get_args():
parser.add_argument('--epoch-decay', type=int, default=15)
parser.add_argument('--weight-decay', type=float, default=0)

parser.add_argument('--number_of_characters', type=int, default=68)
parser.add_argument('--first_kernel', type=int, default=7)
parser.add_argument('--second_kernel', type=int, default=3)
parser.add_argument('--pool_size', type=int, default=3)
parser.add_argument('--max_sentence_length', type=int, default=1000)

# a plain flag: present -> True, absent -> False (avoids the argparse type=bool pitfall)
parser.add_argument('--using_fixed', action='store_true', default=False)

parser.add_argument('--word-vectors-dir', default=os.path.join(os.pardir, 'hedwig-data', 'embeddings', 'word2vec'))
parser.add_argument('--word-vectors-file', default='GoogleNews-vectors-negative300.txt')
parser.add_argument('--save-path', type=str, default=os.path.join('model_checkpoints', 'char_cnn'))
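A standalone sketch (not the project's actual args.py) of how the hyperparameter flags added above would parse from the command line; the flag names and defaults mirror the diff, and the script name in the comment is hypothetical.

```
import argparse

parser = argparse.ArgumentParser(description='Char-CNN hyperparameters (sketch)')
parser.add_argument('--number_of_characters', type=int, default=68)
parser.add_argument('--first_kernel', type=int, default=7)
parser.add_argument('--second_kernel', type=int, default=3)
parser.add_argument('--pool_size', type=int, default=3)
parser.add_argument('--max_sentence_length', type=int, default=1000)
parser.add_argument('--using_fixed', action='store_true', default=False)

# e.g. python args_sketch.py --using_fixed --max_sentence_length 1014
args = parser.parse_args(['--using_fixed', '--max_sentence_length', '1014'])
print(args.using_fixed, args.max_sentence_length)  # True 1014
```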
43 changes: 30 additions & 13 deletions models/char_cnn/model.py
@@ -11,20 +11,36 @@ def __init__(self, config):
self.is_cuda_enabled = config.cuda

num_conv_filters = config.num_conv_filters
output_channel = config.output_channel
output_channel = config.output_channel # only used for conv6 and fc1 in the non-fixed version
num_affine_neurons = config.num_affine_neurons
target_class = config.target_class
input_channel = 68
# parameters added to the config
input_channel = config.number_of_characters # number of characters
first_kernel_size = config.first_kernel
second_kernel_size = config.second_kernel
pool_size = config.pool_size
# whether to use the fixed version that follows the paper's implementation
self.using_fixed = config.using_fixed

self.conv1 = nn.Conv1d(input_channel, num_conv_filters, kernel_size=7)
self.conv2 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=7)
self.conv3 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=3)
self.conv4 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=3)
self.conv5 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=3)
self.conv6 = nn.Conv1d(num_conv_filters, output_channel, kernel_size=3)
max_sentence_length = config.max_sentence_length # maximum number of characters per sentence

self.conv1 = nn.Conv1d(input_channel, num_conv_filters, kernel_size=first_kernel_size)
self.conv2 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=first_kernel_size)
self.conv3 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=second_kernel_size)
self.conv4 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=second_kernel_size)
self.conv5 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=second_kernel_size)
if self.using_fixed:
self.conv6 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=second_kernel_size)
# each conv shrinks the sequence length by (kernel_size - 1) and each of the three
# max-pools divides it by pool_size, so the length reaching fc1 is
# (max_sentence_length - temp) / pool_size**3
temp = first_kernel_size - 1 + pool_size * (first_kernel_size - 1) + (
pool_size ** 2 * 4 * (second_kernel_size - 1))
linear_size_temp = int((max_sentence_length - temp) / (pool_size ** 3)) * num_conv_filters

self.fc1 = nn.Linear(linear_size_temp, num_affine_neurons)
else:
self.conv6 = nn.Conv1d(num_conv_filters, output_channel, kernel_size=second_kernel_size)
self.fc1 = nn.Linear(output_channel, num_affine_neurons)
self.dropout = nn.Dropout(config.dropout)
self.fc1 = nn.Linear(output_channel, num_affine_neurons)
self.fc2 = nn.Linear(num_affine_neurons, num_affine_neurons)
self.fc3 = nn.Linear(num_affine_neurons, target_class)

@@ -33,15 +49,16 @@ def forward(self, x, **kwargs):
x = x.transpose(1, 2).type(torch.cuda.FloatTensor)
else:
x = x.transpose(1, 2).type(torch.FloatTensor)

x = F.max_pool1d(F.relu(self.conv1(x)), 3)
x = F.max_pool1d(F.relu(self.conv2(x)), 3)
x = F.relu(self.conv3(x))
x = F.relu(self.conv4(x))
x = F.relu(self.conv5(x))
x = F.relu(self.conv6(x))

x = F.max_pool1d(x, x.size(2)).squeeze(2)
if self.using_fixed:
x = F.max_pool1d(F.relu(self.conv6(x)), 3)
else:
x = F.relu(self.conv6(x))
x = F.max_pool1d(x, x.size(2)).squeeze(2)
x = F.relu(self.fc1(x.view(x.size(0), -1)))
x = self.dropout(x)
x = F.relu(self.fc2(x))
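For the defaults above (68 characters, kernels 7 and 3, pool size 3, 1000 characters per document) and an assumed 256 convolutional filters, the closed-form linear_size_temp works out to 33 * 256 = 8448. A standalone sketch that checks the formula against an actual pass through the convolution/pooling stack of the fixed version:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

num_chars, num_filters = 68, 256         # number_of_characters; num_conv_filters (256 assumed)
k1, k2, pool, max_len = 7, 3, 3, 1000    # first_kernel, second_kernel, pool_size, max_sentence_length

# closed-form size fed to fc1, same formula as in __init__
temp = (k1 - 1) + pool * (k1 - 1) + pool ** 2 * 4 * (k2 - 1)
linear_size = int((max_len - temp) / pool ** 3) * num_filters
print(linear_size)  # 8448

# actual pass through the six conv layers of the fixed version
convs = [nn.Conv1d(num_chars, num_filters, k1),
         nn.Conv1d(num_filters, num_filters, k1)] + \
        [nn.Conv1d(num_filters, num_filters, k2) for _ in range(4)]
x = torch.zeros(2, num_chars, max_len)   # (batch, characters, sequence length)
x = F.max_pool1d(F.relu(convs[0](x)), pool)
x = F.max_pool1d(F.relu(convs[1](x)), pool)
for conv in convs[2:-1]:
    x = F.relu(conv(x))
x = F.max_pool1d(F.relu(convs[-1](x)), pool)
print(x.view(x.size(0), -1).size(1))     # 8448, matches linear_size
```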
27 changes: 27 additions & 0 deletions models/fasttext/README.md
@@ -0,0 +1,27 @@
## Bag of Tricks for Efficient Text Classification

Implementation of [FastText (2016)](https://arxiv.org/pdf/1607.01759.pdf)

## Quick Start

To run the model on the Reuters dataset, run the following from the Castor working directory:

```
python -m models.fasttext --dataset Reuters --batch-size 128 --lr 0.001 --seed 3435
```

The best model weights will be saved in

```
models/fasttext/saves/Reuters/best_model.pt
```

To test the model, use the following command:

```
python -m models.fasttext --dataset Reuters --batch-size 32 --trained-model models/fasttext/saves/Reuters/best_model.pt --seed 3435
```

## Settings

Adam is used for training.
6 changes: 1 addition & 5 deletions models/fasttext/model.py
@@ -35,10 +35,6 @@ def forward(self, x, **kwargs):
x = self.static_embed(x) # (batch, sent_len, embed_dim)
elif self.mode == 'non-static':
x = self.non_static_embed(x) # (batch, sent_len, embed_dim)

x = F.avg_pool2d(x, (x.shape[1], 1)).squeeze(1) # (batch, embed_dim)

x = F.avg_pool1d(x.transpose(1, 2), x.shape[1]).squeeze(2) # (batch, embed_dim)
logit = self.fc1(x) # (batch, target_size)
return logit
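The replaced pooling line averages each embedding dimension over the sentence length. A minimal standalone check (shapes taken from the comments above) showing it matches a plain mean over the token axis:

```
import torch
import torch.nn.functional as F

batch, sent_len, embed_dim = 4, 10, 300
x = torch.randn(batch, sent_len, embed_dim)   # (batch, sent_len, embed_dim)

# pool over the sentence dimension -> (batch, embed_dim)
pooled = F.avg_pool1d(x.transpose(1, 2), x.shape[1]).squeeze(2)

# same result as averaging the token embeddings directly
assert torch.allclose(pooled, x.mean(dim=1), atol=1e-6)
print(pooled.shape)  # torch.Size([4, 300])
```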