From 49659021c2b105ab076e213d38b96fa022ad44f2 Mon Sep 17 00:00:00 2001 From: Tanner Kvarfordt Date: Tue, 23 Nov 2021 00:50:39 -0700 Subject: [PATCH] closes #1 --- conversational.go | 37 ++++++++++++++++++++++++++++++------- conversational_test.go | 22 ++++++++++------------ summarization.go | 37 ++++++++++++++++++++++++++++++------- summarization_test.go | 12 ++++++------ 4 files changed, 76 insertions(+), 32 deletions(-) diff --git a/conversational.go b/conversational.go index 9136d97..da478ab 100644 --- a/conversational.go +++ b/conversational.go @@ -32,19 +32,19 @@ type ConverstationalInputs struct { // Used with ConversationalRequest type ConversationalParameters struct { // (Default: None). Integer to define the minimum length in tokens of the output summary. - MinLength int `json:"min_length,omitempty"` + MinLength *int `json:"min_length,omitempty"` // (Default: None). Integer to define the maximum length in tokens of the output summary. - MaxLength int `json:"max_length,omitempty"` + MaxLength *int `json:"max_length,omitempty"` // (Default: None). Integer to define the top tokens considered within the sample operation to create // new text. - TopK int `json:"top_k,omitempty"` + TopK *int `json:"top_k,omitempty"` // (Default: None). Float to define the tokens that are within the sample` operation of text generation. // Add tokens in the sample for more probable to least probable until the sum of the probabilities is // greater than top_p. - TopP float64 `json:"top_p,omitempty"` + TopP *float64 `json:"top_p,omitempty"` // (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, // 0 mens top_k=1, 100.0 is getting closer to uniform probability. @@ -52,21 +52,44 @@ type ConversationalParameters struct { // (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized // to not be picked in successive generation passes. - RepetitionPenalty float64 `json:"repetitionpenalty,omitempty"` + RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"` // (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. // Network can cause some overhead so it will be a soft limit. - MaxTime float64 `json:"maxtime,omitempty"` + MaxTime *float64 `json:"maxtime,omitempty"` } func NewConversationalParameters() *ConversationalParameters { return &ConversationalParameters{} } - +func (c *ConversationalParameters) SetMinLength(minLength int) *ConversationalParameters { + c.MinLength = &minLength + return c +} +func (c *ConversationalParameters) SetMaxLength(maxLength int) *ConversationalParameters { + c.MaxLength = &maxLength + return c +} +func (c *ConversationalParameters) SetTopK(topK int) *ConversationalParameters { + c.TopK = &topK + return c +} +func (c *ConversationalParameters) SetTopP(topP float64) *ConversationalParameters { + c.TopP = &topP + return c +} func (c *ConversationalParameters) SetTempurature(temperature float64) *ConversationalParameters { c.Temperature = &temperature return c } +func (c *ConversationalParameters) SetRepetitionPenalty(penalty float64) *ConversationalParameters { + c.RepetitionPenalty = &penalty + return c +} +func (c *ConversationalParameters) SetMaxTime(maxTime float64) *ConversationalParameters { + c.MaxTime = &maxTime + return c +} // Response structure for the conversational endpoint type ConversationalResponse struct { diff --git a/conversational_test.go b/conversational_test.go index 7558834..2f54af5 100644 --- a/conversational_test.go +++ b/conversational_test.go @@ -39,12 +39,11 @@ func TestMarshalUnMarshalConversationalRequest(t *testing.T) { Inputs: hfapigo.ConverstationalInputs{ Text: "Hey my name is Julien! How are you?", }, - Parameters: *(&hfapigo.ConversationalParameters{ - MinLength: 10, - TopK: 0, - TopP: 0.12345, - MaxTime: 0.2, - }).SetTempurature(0.2345), + Parameters: *(&hfapigo.ConversationalParameters{}). + SetTempurature(0.2345). + SetMinLength(10). + SetMaxLength(20). + SetRepetitionPenalty(20), Options: *hfapigo.NewOptions().SetWaitForModel(true), } @@ -72,12 +71,11 @@ func TestConversationalRequest(t *testing.T) { Inputs: hfapigo.ConverstationalInputs{ Text: "Hey my name is Julien! How are you?", }, - Parameters: *(&hfapigo.ConversationalParameters{ - MinLength: 10, - TopK: 0, - TopP: 0.12345, - MaxTime: 0.2, - }).SetTempurature(0.2345), + Parameters: *(&hfapigo.ConversationalParameters{}). + SetTempurature(0.2345). + SetMinLength(10). + SetMaxLength(20). + SetRepetitionPenalty(20), Options: *hfapigo.NewOptions().SetWaitForModel(true), }) if err != nil { diff --git a/summarization.go b/summarization.go index 16ac8eb..6dc152b 100644 --- a/summarization.go +++ b/summarization.go @@ -18,19 +18,19 @@ type SummarizationRequest struct { // Used with SummarizationRequest type SummarizationParameters struct { // (Default: None). Integer to define the minimum length in tokens of the output summary. - MinLength int `json:"min_length,omitempty"` + MinLength *int `json:"min_length,omitempty"` // (Default: None). Integer to define the maximum length in tokens of the output summary. - MaxLength int `json:"max_length,omitempty"` + MaxLength *int `json:"max_length,omitempty"` // (Default: None). Integer to define the top tokens considered within the sample operation to create // new text. - TopK int `json:"top_k,omitempty"` + TopK *int `json:"top_k,omitempty"` // (Default: None). Float to define the tokens that are within the sample` operation of text generation. // Add tokens in the sample for more probable to least probable until the sum of the probabilities is // greater than top_p. - TopP float64 `json:"top_p,omitempty"` + TopP *float64 `json:"top_p,omitempty"` // (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, // 0 mens top_k=1, 100.0 is getting closer to uniform probability. @@ -38,21 +38,44 @@ type SummarizationParameters struct { // (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized // to not be picked in successive generation passes. - RepetitionPenalty float64 `json:"repetitionpenalty,omitempty"` + RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"` // (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. // Network can cause some overhead so it will be a soft limit. - MaxTime float64 `json:"maxtime,omitempty"` + MaxTime *float64 `json:"maxtime,omitempty"` } func NewSummarizationParameters() *SummarizationParameters { return &SummarizationParameters{} } - +func (sp *SummarizationParameters) SetMinLength(minLength int) *SummarizationParameters { + sp.MinLength = &minLength + return sp +} +func (sp *SummarizationParameters) SetMaxLength(maxLength int) *SummarizationParameters { + sp.MaxLength = &maxLength + return sp +} +func (sp *SummarizationParameters) SetTopK(topK int) *SummarizationParameters { + sp.TopK = &topK + return sp +} +func (sp *SummarizationParameters) SetTopP(topP float64) *SummarizationParameters { + sp.TopP = &topP + return sp +} func (sp *SummarizationParameters) SetTempurature(temperature float64) *SummarizationParameters { sp.Temperature = &temperature return sp } +func (sp *SummarizationParameters) SetRepetitionPenalty(penalty float64) *SummarizationParameters { + sp.RepetitionPenalty = &penalty + return sp +} +func (sp *SummarizationParameters) SetMaxTime(maxTime float64) *SummarizationParameters { + sp.MaxTime = &maxTime + return sp +} // Response structure for the summarization endpoint type SummarizationResponse struct { diff --git a/summarization_test.go b/summarization_test.go index 612c87b..432fb40 100644 --- a/summarization_test.go +++ b/summarization_test.go @@ -35,12 +35,12 @@ func TestMarshalUnmarshalSummarizationRequest(t *testing.T) { { srExpected := hfapigo.SummarizationRequest{ Inputs: []string{"Foobar", "baz"}, - Parameters: *(&hfapigo.SummarizationParameters{ - MaxLength: 5, - TopK: 20, - TopP: 1.25, - RepetitionPenalty: 0.215, - }).SetTempurature(92.123456789), + Parameters: *(&hfapigo.SummarizationParameters{}). + SetTempurature(92.123456789). + SetMinLength(5). + SetMaxLength(10). + SetTopK(30). + SetTopP(55.505), Options: *hfapigo.NewOptions().SetUseCache(false), }