diff --git a/pkg/models/subscription.go b/pkg/models/subscription.go index 0d77cee7a..5bbe3afba 100644 --- a/pkg/models/subscription.go +++ b/pkg/models/subscription.go @@ -107,38 +107,37 @@ func getSubscription(s *xorm.Session, entityType SubscriptionEntityType, a web.A ) } + if entityType == SubscriptionEntityTask { + cond = builder.Or( + builder.And( + builder.Eq{"entity_id": entityID}, + builder.Eq{"entity_type": SubscriptionEntityTask}, + ), + builder.And( + builder.Eq{"entity_id": builder. + Select("namespace_id"). + From("list"). + Join("INNER", "tasks", "list.id = tasks.list_id"). + Where(builder.Eq{"tasks.id": entityID}), + }, + builder.Eq{"entity_type": SubscriptionEntityNamespace}, + ), + builder.And( + builder.Eq{"entity_id": builder. + Select("list_id"). + From("tasks"). + Where(builder.Eq{"id": entityID}), + }, + builder.Eq{"entity_type": SubscriptionEntityList}, + ), + ) + } + _, err = s. Where("user_id = ?", u.ID). And(cond). Get(subscription) return - // - //subscriptions := []*Subscription{} - //err = query.Find(&subscriptions) - //if err != nil { - // return - //} - // - //subscriptionsByType := make(map[SubscriptionEntityType]*Subscription, 3) - //for _, sb := range subscriptions { - // if sb.EntityType == SubscriptionEntityNamespace && entityType == SubscriptionEntityNamespace { - // return sb, nil - // } - // subscriptionsByType[sb.EntityType] = sb - //} - // - //if entityType == SubscriptionEntityList { - // // If there's a subscription for this list, return it directly - // sb, exists := subscriptionsByType[SubscriptionEntityList] - // if exists { - // return sb, nil - // } - // - // // Otherwise check if there's one for the namespace the list belongs to - // //sb, exists := subscriptionsByType[SubscriptionEntityNamespace] - //} - // - //return } // Create subscribes the current user to an entity diff --git a/pkg/models/subscription_test.go b/pkg/models/subscription_test.go index 2d22f5871..de52ba5db 100644 --- a/pkg/models/subscription_test.go +++ b/pkg/models/subscription_test.go @@ -275,6 +275,7 @@ func TestSubscriptionGet(t *testing.T) { sub, err := getSubscription(s, SubscriptionEntityNamespace, u, 6) assert.NoError(t, err) assert.NotNil(t, sub) + assert.Equal(t, int64(2), sub.ID) }) t.Run("list", func(t *testing.T) { db.LoadAndAssertFixtures(t) @@ -284,6 +285,7 @@ func TestSubscriptionGet(t *testing.T) { sub, err := getSubscription(s, SubscriptionEntityList, u, 12) assert.NoError(t, err) assert.NotNil(t, sub) + assert.Equal(t, int64(3), sub.ID) }) t.Run("task", func(t *testing.T) { db.LoadAndAssertFixtures(t) @@ -293,6 +295,7 @@ func TestSubscriptionGet(t *testing.T) { sub, err := getSubscription(s, SubscriptionEntityTask, u, 22) assert.NoError(t, err) assert.NotNil(t, sub) + assert.Equal(t, int64(4), sub.ID) }) }) t.Run("inherited", func(t *testing.T) { @@ -305,6 +308,7 @@ func TestSubscriptionGet(t *testing.T) { sub, err := getSubscription(s, SubscriptionEntityList, u, 6) assert.NoError(t, err) assert.NotNil(t, sub) + assert.Equal(t, int64(2), sub.ID) }) t.Run("task from namespace", func(t *testing.T) { db.LoadAndAssertFixtures(t) @@ -315,6 +319,7 @@ func TestSubscriptionGet(t *testing.T) { sub, err := getSubscription(s, SubscriptionEntityTask, u, 20) assert.NoError(t, err) assert.NotNil(t, sub) + assert.Equal(t, int64(2), sub.ID) }) t.Run("task from list", func(t *testing.T) { db.LoadAndAssertFixtures(t) @@ -325,6 +330,7 @@ func TestSubscriptionGet(t *testing.T) { sub, err := getSubscription(s, SubscriptionEntityTask, u, 21) assert.NoError(t, err) assert.NotNil(t, sub) + assert.Equal(t, int64(3), sub.ID) }) }) }