Pytorch Negative Sampling Loss
Negative Sampling Loss implemented in PyTorch.
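For reference, here is the usual skip-gram negative sampling objective from word2vec. This assumes the repository follows that formulation; its exact reduction over the batch may differ. For a center word $w$ with context words $c_1, \dots, c_W$ (the `window_size` columns of `target`), $K$ noise words per context word (`num_sample`), input embeddings $v$, output embeddings $v'$, the logistic sigmoid $\sigma$, and a noise distribution $P_n$:

$$
\mathcal{L}(w) = -\sum_{j=1}^{W}\Big[\log\sigma\big({v'_{c_j}}^{\top} v_{w}\big) + \sum_{k=1}^{K}\log\sigma\big(-{v'_{n_{jk}}}^{\top} v_{w}\big)\Big], \qquad n_{jk} \sim P_n
$$

The first term pulls true context words toward the center word. The second term pushes sampled noise words away from it, so the full softmax over `num_classes` never has to be computed.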
Usage
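```python
from torch.optim import SGD

# NEG_loss is the module provided by this repository. num_classes (vocabulary size),
# embedding_size, num_iterations, batch_size, num_sample and next_batch are user-supplied.
neg_loss = NEG_loss(num_classes, embedding_size)
optimizer = SGD(neg_loss.parameters(), lr=0.1)

for i in range(num_iterations):
    # input:  LongTensor of shape [batch_size]
    # target: LongTensor of shape [batch_size, window_size]
    input, target = next_batch(batch_size)

    # num_sample: number of negative (noise) samples
    loss = neg_loss(input, target, num_sample)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# learned input embeddings, one row per class
word_embeddings = neg_loss.input_embeddings()
```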
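The loop above relies on the repository's `NEG_loss`. To show what such a module computes, here is a minimal self-contained sketch of skip-gram negative sampling with the same call shape, `(input, target, num_sample)`. It is not this repository's implementation. The class name `NegSamplingLossSketch` and the `weights` argument are hypothetical. The noise distribution defaults to uniform, whereas word2vec uses unigram counts raised to the 3/4 power, which you could pass in as `weights`. Random indices stand in for the hypothetical `next_batch`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NegSamplingLossSketch(nn.Module):
    """Hypothetical skip-gram negative sampling loss (not this repo's NEG_loss)."""

    def __init__(self, num_classes, embedding_size, weights=None):
        super().__init__()
        # separate "input" (center word) and "output" (context word) tables
        self.in_embed = nn.Embedding(num_classes, embedding_size)
        self.out_embed = nn.Embedding(num_classes, embedding_size)
        # noise distribution P_n: uniform unless weights are given (assumption)
        if weights is None:
            weights = torch.ones(num_classes)
        weights = weights.float()
        self.register_buffer("noise", weights / weights.sum())

    def forward(self, input, target, num_sample):
        batch_size, window_size = target.shape
        v_in = self.in_embed(input).unsqueeze(2)           # [B, D, 1]
        v_out = self.out_embed(target)                     # [B, W, D]

        # positive term: log sigma(v'_c . v_w) for every true context word
        pos = torch.bmm(v_out, v_in).squeeze(2)            # [B, W]
        pos_term = F.logsigmoid(pos).sum(1)                # [B]

        # draw num_sample noise words per context word, with replacement
        noise_idx = torch.multinomial(
            self.noise, batch_size * window_size * num_sample, replacement=True
        ).view(batch_size, window_size * num_sample)
        v_noise = self.out_embed(noise_idx)                # [B, W*K, D]

        # negative term: log sigma(-v'_n . v_w) for every noise word
        neg = torch.bmm(v_noise, v_in).squeeze(2)          # [B, W*K]
        neg_term = F.logsigmoid(-neg).sum(1)               # [B]

        # negative log-likelihood, averaged over the batch
        return -(pos_term + neg_term).mean()

    def input_embeddings(self):
        # detached copy of the learned center-word embeddings
        return self.in_embed.weight.detach().clone()


torch.manual_seed(0)
model = NegSamplingLossSketch(num_classes=50, embedding_size=8)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

centers = torch.randint(0, 50, (16,))       # [batch_size]
contexts = torch.randint(0, 50, (16, 4))    # [batch_size, window_size]

loss = model(centers, contexts, num_sample=5)
opt.zero_grad()
loss.backward()
opt.step()

emb = model.input_embeddings()
print(emb.shape)  # torch.Size([50, 8])
```

Sampling all noise indices in one `torch.multinomial` call and scoring them with a single `torch.bmm` keeps the cost proportional to `window_size * num_sample` per example instead of `num_classes`. This is the point of negative sampling.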