diff --git a/core.go b/core.go index 41d7145..9cd9810 100644 --- a/core.go +++ b/core.go @@ -1,7 +1,6 @@ package time_arrow import ( - "errors" "fmt" "github.com/google/uuid" "log" @@ -40,10 +39,6 @@ type TimeArrow struct { type TimeArrows []TimeArrow -var GetData = func(group string) (TimeArrows, error) { - return nil, errors.New("GetData 未实现") -} - func isInDateSlice(t time.Time, ta TimeArrow) bool { for e := range ta.DateSlice { startStr := ta.DateSlice[e].Start @@ -213,8 +208,12 @@ func CreateDateSliceTypePlan(group string, dateSlice []DateSlice, timesOnDay []s } } -func GetHitTimeArrow(t time.Time, group string, expandTags ...string) (*TimeArrow, error) { - ta, err := GetData(group) +type TimeArrowHelper struct { + GetData func(group string) (TimeArrows, error) +} + +func (th *TimeArrowHelper) GetHitTimeArrow(t time.Time, group string, expandTags ...string) (*TimeArrow, error) { + ta, err := th.GetData(group) if err != nil { return nil, err } diff --git a/core_test.go b/core_test.go index 5a27cb2..abcc32b 100644 --- a/core_test.go +++ b/core_test.go @@ -150,6 +150,8 @@ func Test_isInExpandTags(t *testing.T) { } +var th TimeArrowHelper + func TestGetHitTimeArrow(t *testing.T) { tas := TimeArrows{ { @@ -242,7 +244,7 @@ func TestGetHitTimeArrow(t *testing.T) { }, } - GetData = func(group string) (arrows TimeArrows, e error) { + th.GetData = func(group string) (arrows TimeArrows, e error) { return tas, nil } @@ -251,7 +253,7 @@ func TestGetHitTimeArrow(t *testing.T) { panic(err) } - result, err := GetHitTimeArrow(ti, "", "一号场") + result, err := th.GetHitTimeArrow(ti, "", "一号场") if err != nil || result == nil { t.Fatal("error") panic(err) @@ -262,7 +264,7 @@ func TestGetHitTimeArrow(t *testing.T) { } ti = ti.AddDate(0, 0, -1) - result, err = GetHitTimeArrow(ti, "", "一号场") + result, err = th.GetHitTimeArrow(ti, "", "一号场") if err != nil || result == nil { t.Fatal("error") panic(err) @@ -271,7 +273,7 @@ func TestGetHitTimeArrow(t *testing.T) { if result.ExpandValue.(int) != 30 { t.Fatal("error", result.ExpandValue.(int)) } - result, err = GetHitTimeArrow(ti, "", "三号场") + result, err = th.GetHitTimeArrow(ti, "", "三号场") if err != nil || result == nil { t.Fatal("error") panic(err) @@ -286,7 +288,7 @@ func TestGetHitTimeArrow(t *testing.T) { panic(err) } - result, err = GetHitTimeArrow(ti, "", "三号场") + result, err = th.GetHitTimeArrow(ti, "", "三号场") if err != nil || result == nil { t.Fatal("error") panic(err) @@ -297,7 +299,7 @@ func TestGetHitTimeArrow(t *testing.T) { } ti = ti.Add(time.Hour) - result, err = GetHitTimeArrow(ti, "", "三号场") + result, err = th.GetHitTimeArrow(ti, "", "三号场") if err != nil || result == nil { t.Fatal("error") panic(err)