Transformer

Overview

transformer is a pure Go package for training, testing, and running inference with Natural Language Processing (NLP) models in Go.

This package is under active development and many changes are ahead, so use it at your own risk. The package will be considered stable when version 1.0 is released.

transformer is heavily inspired by and based on the popular Python library HuggingFace Transformers. It is also influenced by its Rust counterpart, rust-bert. In fact, all pre-trained rust-bert models can be imported into this Go package, because rust-bert's dependency tch-rs (the PyTorch Rust binding) and the Go binding gotch are built on similar principles.

transformer is part of an ambitious goal (together with tokenizer and gotch) to bring more AI/deep-learning tools to Gophers, so that they can stick to the language they love and are good at, and build faster software in production.

Dependencies

The two main dependencies are tokenizer (github.com/sugarme/tokenizer) and gotch (github.com/sugarme/gotch).

Prerequisites and installation

Basic example

    package main

    import (
        "fmt"
        "log"

        "github.com/sugarme/gotch"
        ts "github.com/sugarme/gotch/tensor"
        "github.com/sugarme/tokenizer"

        "github.com/sugarme/transformer"
        "github.com/sugarme/transformer/bert"
    )

    func main() {
        var config *bert.BertConfig = new(bert.BertConfig)
        if err := transformer.LoadConfig(config, "bert-base-uncased", nil); err != nil {
            log.Fatal(err)
        }

        var model *bert.BertForMaskedLM = new(bert.BertForMaskedLM)
        if err := transformer.LoadModel(model, "bert-base-uncased", config, nil, gotch.CPU); err != nil {
            log.Fatal(err)
        }

        var tk *bert.Tokenizer = bert.NewTokenizer()
        if err := tk.Load("bert-base-uncased", nil); err != nil {
            log.Fatal(err)
        }

        sentence1 := "Looks like one [MASK] is missing"
        sentence2 := "It was a very nice and [MASK] day"

        var input []tokenizer.EncodeInput
        input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
        input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))

        encodings, err := tk.EncodeBatch(input, true)
        if err != nil {
            log.Fatal(err)
        }

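        // Right-pad every sequence with zeros (the [PAD] id in the
        // bert-base-uncased vocabulary) up to the longest one in the batch.
        maxLen := 0
        for _, en := range encodings {
            if len(en.Ids) > maxLen {
                maxLen = len(en.Ids)
            }
        }

        var tensors []ts.Tensor
        for _, en := range encodings {
            tokInput := make([]int64, maxLen)
            for i := 0; i < len(en.Ids); i++ {
                tokInput[i] = int64(en.Ids[i])
            }

            tensors = append(tensors, ts.TensorFrom(tokInput))
        }

        // Stack into a [batch, seqLen] tensor on the same device as the model.
        inputTensor := ts.MustStack(tensors, 0).MustTo(gotch.CPU, true)
        var output ts.Tensor
        ts.NoGrad(func() {
            output, _, _ = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false)
        })

        // output has shape [batch, seqLen, vocabSize]. With [CLS] prepended,
        // [MASK] sits at position 4 in sentence1 and position 7 in sentence2;
        // argmax over the vocabulary gives the predicted token id.
        index1 := output.MustGet(0).MustGet(4).MustArgmax(0, false, false).Int64Values()[0]
        index2 := output.MustGet(1).MustGet(7).MustArgmax(0, false, false).Int64Values()[0]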

        got1, ok := tk.IdToToken(int(index1))
        if !ok {
            fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index1)
        }
        got2, ok := tk.IdToToken(int(index2))
        if !ok {
            fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index2)
        }

        fmt.Println(got1)
        fmt.Println(got2)
        
        // Output:
        // person
        // pleasant
    }
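The example performs two bookkeeping steps by hand around the model call: padding the batch to a common length, and reading the logits at hard-coded [MASK] positions (4 and 7). The sketch below reproduces those steps on plain Go slices so the logic can be checked without loading a model. It is a minimal illustration under stated assumptions, not part of the transformer or tokenizer API: `padBatch`, `maskPositions` and `argmax` are hypothetical helpers, the pad id 0 assumes the bert-base-uncased vocabulary, and the token ids and wordpiece tokens in `main` are assumed values. `padBatch` also builds an attention mask (1 for real tokens, 0 for padding), which the example above does not pass to `ForwardT`; how such a mask would be supplied there is not covered here.

```go
package main

import "fmt"

// Hypothetical helpers for illustration only; they are not part of
// sugarme/transformer or sugarme/tokenizer.

// padBatch right-pads every sequence of token ids with padID up to the
// longest sequence in the batch (0 is assumed to be [PAD], as in
// bert-base-uncased). It also returns an attention mask in which 1 marks
// a real token and 0 marks padding.
func padBatch(batch [][]int, padID int64) (ids [][]int64, mask [][]int64) {
	maxLen := 0
	for _, seq := range batch {
		if len(seq) > maxLen {
			maxLen = len(seq)
		}
	}
	ids = make([][]int64, len(batch))
	mask = make([][]int64, len(batch))
	for i, seq := range batch {
		ids[i] = make([]int64, maxLen)
		mask[i] = make([]int64, maxLen)
		for j := 0; j < maxLen; j++ {
			if j < len(seq) {
				ids[i][j] = int64(seq[j])
				mask[i][j] = 1
			} else {
				ids[i][j] = padID
			}
		}
	}
	return ids, mask
}

// maskPositions returns the index of every occurrence of maskToken.
func maskPositions(tokens []string, maskToken string) []int {
	var pos []int
	for i, tok := range tokens {
		if tok == maskToken {
			pos = append(pos, i)
		}
	}
	return pos
}

// argmax returns the index of the largest logit (the first one on ties),
// i.e. the greedy prediction for one position over the vocabulary.
// It returns -1 for an empty slice.
func argmax(logits []float32) int {
	if len(logits) == 0 {
		return -1
	}
	best := 0
	for i, v := range logits {
		if v > logits[best] {
			best = i
		}
	}
	return best
}

func main() {
	// Assumed token ids; real ids come from the tokenizer.
	ids, mask := padBatch([][]int{{101, 7592, 102}, {101, 102}}, 0)
	fmt.Println(ids)  // [[101 7592 102] [101 102 0]]
	fmt.Println(mask) // [[1 1 1] [1 1 0]]

	// Assumed wordpiece tokenization of sentence1, with [CLS] and [SEP] added.
	tokens := []string{"[CLS]", "looks", "like", "one", "[MASK]", "is", "missing", "[SEP]"}
	fmt.Println(maskPositions(tokens, "[MASK]")) // [4]

	fmt.Println(argmax([]float32{0.1, 2.5, -1.0, 2.4})) // 1
}
```

Deriving the positions from the encoded tokens, rather than hard-coding 4 and 7, keeps the readout correct if the input sentences change.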

Getting Started

License

transformer is Apache 2.0 licensed.

Acknowledgement