Commit

closes #1
Tanner Kvarfordt committed Nov 23, 2021
1 parent d3983e0 commit 4965902
Showing 4 changed files with 76 additions and 32 deletions.
37 changes: 30 additions & 7 deletions conversational.go
@@ -32,41 +32,64 @@ type ConverstationalInputs struct {
// Used with ConversationalRequest
type ConversationalParameters struct {
// (Default: None). Integer to define the minimum length in tokens of the output summary.
MinLength int `json:"min_length,omitempty"`
MinLength *int `json:"min_length,omitempty"`

// (Default: None). Integer to define the maximum length in tokens of the output summary.
MaxLength int `json:"max_length,omitempty"`
MaxLength *int `json:"max_length,omitempty"`

// (Default: None). Integer to define the top tokens considered within the sample operation to create
// new text.
TopK int `json:"top_k,omitempty"`
TopK *int `json:"top_k,omitempty"`

// (Default: None). Float to define the tokens that are within the sample operation of text generation.
// Add tokens in the sample from most probable to least probable until the sum of the probabilities is
// greater than top_p.
TopP float64 `json:"top_p,omitempty"`
TopP *float64 `json:"top_p,omitempty"`

// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
// 0 means top_k=1, and 100.0 approaches uniform probability.
Temperature *float64 `json:"temperature,omitempty"`

// (Default: None). Float (0.0-100.0). The more a token is used within generation, the more it is
// penalized so that it is less likely to be picked in successive generation passes.
RepetitionPenalty float64 `json:"repetitionpenalty,omitempty"`
RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"`

// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
// Network overhead can add to this, so it is a soft limit.
MaxTime float64 `json:"maxtime,omitempty"`
MaxTime *float64 `json:"maxtime,omitempty"`
}

func NewConversationalParameters() *ConversationalParameters {
return &ConversationalParameters{}
}

func (c *ConversationalParameters) SetMinLength(minLength int) *ConversationalParameters {
c.MinLength = &minLength
return c
}
func (c *ConversationalParameters) SetMaxLength(maxLength int) *ConversationalParameters {
c.MaxLength = &maxLength
return c
}
func (c *ConversationalParameters) SetTopK(topK int) *ConversationalParameters {
c.TopK = &topK
return c
}
func (c *ConversationalParameters) SetTopP(topP float64) *ConversationalParameters {
c.TopP = &topP
return c
}
func (c *ConversationalParameters) SetTempurature(temperature float64) *ConversationalParameters {
c.Temperature = &temperature
return c
}
func (c *ConversationalParameters) SetRepetitionPenalty(penalty float64) *ConversationalParameters {
c.RepetitionPenalty = &penalty
return c
}
func (c *ConversationalParameters) SetMaxTime(maxTime float64) *ConversationalParameters {
c.MaxTime = &maxTime
return c
}

// Response structure for the conversational endpoint
type ConversationalResponse struct {
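The switch from value fields to pointer fields is what lets `omitempty` distinguish an unset parameter from an explicit zero: a nil pointer is dropped from the JSON body, while a pointer to 0 is kept. A minimal sketch of the difference, using only the types and setters above (the module import path is an assumption for illustration):

```go
package main

import (
    "encoding/json"
    "fmt"

    // NOTE: the module path below is an assumption for illustration.
    "github.com/TannerKvarfordt/hfapigo"
)

func main() {
    // An explicitly set zero now survives marshaling because the pointer is
    // non-nil, while fields that were never set are still omitted entirely.
    params := hfapigo.NewConversationalParameters().
        SetTopK(0). // explicit zero: serialized as "top_k":0
        SetTempurature(0.7)

    buf, err := json.Marshal(params)
    if err != nil {
        panic(err)
    }
    // Expected output (roughly): {"top_k":0,"temperature":0.7}
    fmt.Println(string(buf))
}
```

Under the old value fields, a literal `TopK: 0` was indistinguishable from leaving the field unset, so `omitempty` silently dropped it from the payload.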
22 changes: 10 additions & 12 deletions conversational_test.go
@@ -39,12 +39,11 @@ func TestMarshalUnMarshalConversationalRequest(t *testing.T) {
Inputs: hfapigo.ConverstationalInputs{
Text: "Hey my name is Julien! How are you?",
},
Parameters: *(&hfapigo.ConversationalParameters{
MinLength: 10,
TopK: 0,
TopP: 0.12345,
MaxTime: 0.2,
}).SetTempurature(0.2345),
Parameters: *(&hfapigo.ConversationalParameters{}).
SetTempurature(0.2345).
SetMinLength(10).
SetMaxLength(20).
SetRepetitionPenalty(20),
Options: *hfapigo.NewOptions().SetWaitForModel(true),
}

@@ -72,12 +71,11 @@ func TestConversationalRequest(t *testing.T) {
Inputs: hfapigo.ConverstationalInputs{
Text: "Hey my name is Julien! How are you?",
},
Parameters: *(&hfapigo.ConversationalParameters{
MinLength: 10,
TopK: 0,
TopP: 0.12345,
MaxTime: 0.2,
}).SetTempurature(0.2345),
Parameters: *(&hfapigo.ConversationalParameters{}).
SetTempurature(0.2345).
SetMinLength(10).
SetMaxLength(20).
SetRepetitionPenalty(20),
Options: *hfapigo.NewOptions().SetWaitForModel(true),
})
if err != nil {
37 changes: 30 additions & 7 deletions summarization.go
@@ -18,41 +18,64 @@ type SummarizationRequest struct {
// Used with SummarizationRequest
type SummarizationParameters struct {
// (Default: None). Integer to define the minimum length in tokens of the output summary.
MinLength int `json:"min_length,omitempty"`
MinLength *int `json:"min_length,omitempty"`

// (Default: None). Integer to define the maximum length in tokens of the output summary.
MaxLength int `json:"max_length,omitempty"`
MaxLength *int `json:"max_length,omitempty"`

// (Default: None). Integer to define the top tokens considered within the sample operation to create
// new text.
TopK int `json:"top_k,omitempty"`
TopK *int `json:"top_k,omitempty"`

// (Default: None). Float to define the tokens that are within the sample operation of text generation.
// Add tokens in the sample from most probable to least probable until the sum of the probabilities is
// greater than top_p.
TopP float64 `json:"top_p,omitempty"`
TopP *float64 `json:"top_p,omitempty"`

// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
// 0 means top_k=1, and 100.0 approaches uniform probability.
Temperature *float64 `json:"temperature,omitempty"`

// (Default: None). Float (0.0-100.0). The more a token is used within generation, the more it is
// penalized so that it is less likely to be picked in successive generation passes.
RepetitionPenalty float64 `json:"repetitionpenalty,omitempty"`
RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"`

// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
// Network overhead can add to this, so it is a soft limit.
MaxTime float64 `json:"maxtime,omitempty"`
MaxTime *float64 `json:"maxtime,omitempty"`
}

func NewSummarizationParameters() *SummarizationParameters {
return &SummarizationParameters{}
}

func (sp *SummarizationParameters) SetMinLength(minLength int) *SummarizationParameters {
sp.MinLength = &minLength
return sp
}
func (sp *SummarizationParameters) SetMaxLength(maxLength int) *SummarizationParameters {
sp.MaxLength = &maxLength
return sp
}
func (sp *SummarizationParameters) SetTopK(topK int) *SummarizationParameters {
sp.TopK = &topK
return sp
}
func (sp *SummarizationParameters) SetTopP(topP float64) *SummarizationParameters {
sp.TopP = &topP
return sp
}
func (sp *SummarizationParameters) SetTempurature(temperature float64) *SummarizationParameters {
sp.Temperature = &temperature
return sp
}
func (sp *SummarizationParameters) SetRepetitionPenalty(penalty float64) *SummarizationParameters {
sp.RepetitionPenalty = &penalty
return sp
}
func (sp *SummarizationParameters) SetMaxTime(maxTime float64) *SummarizationParameters {
sp.MaxTime = &maxTime
return sp
}

// Response structure for the summarization endpoint
type SummarizationResponse struct {
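Request construction follows the same pattern as the updated tests: chain setters off NewSummarizationParameters instead of filling a struct literal. A short sketch under the same assumed module path, with placeholder input text:

```go
package main

import (
    "encoding/json"
    "fmt"

    // NOTE: the module path below is an assumption for illustration.
    "github.com/TannerKvarfordt/hfapigo"
)

func main() {
    req := hfapigo.SummarizationRequest{
        Inputs: []string{"Some long article text to summarize..."},
        // Chained setters replace the old struct-literal style; only fields
        // that are explicitly set end up in the serialized request body.
        Parameters: *hfapigo.NewSummarizationParameters().
            SetMinLength(10).
            SetMaxLength(64).
            SetTempurature(1.0),
        Options: *hfapigo.NewOptions().SetWaitForModel(true),
    }

    body, err := json.Marshal(req)
    if err != nil {
        panic(err)
    }
    fmt.Println(string(body))
}
```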
12 changes: 6 additions & 6 deletions summarization_test.go
@@ -35,12 +35,12 @@ func TestMarshalUnmarshalSummarizationRequest(t *testing.T) {
{
srExpected := hfapigo.SummarizationRequest{
Inputs: []string{"Foobar", "baz"},
Parameters: *(&hfapigo.SummarizationParameters{
MaxLength: 5,
TopK: 20,
TopP: 1.25,
RepetitionPenalty: 0.215,
}).SetTempurature(92.123456789),
Parameters: *(&hfapigo.SummarizationParameters{}).
SetTempurature(92.123456789).
SetMinLength(5).
SetMaxLength(10).
SetTopK(30).
SetTopP(55.505),
Options: *hfapigo.NewOptions().SetUseCache(false),
}
