Argsort for Rank


November 30, 2021

Background

Back in 2020, I re-implemented SASRec[1][3] and TiSASRec[2][4] for research use at http://claws.cc.gatech.edu/. Inside the repos there is a piece of code that gets the rank of the target item, and it can be confusing:

predictions = -model.predict(uid, user_interacted_item_seq, to_predict_item_indices)
rank = predictions.argsort().argsort()[0].item()

Here I will explain in detail how it works, so that you can focus on more important parts, such as the model structure, when reading the code of RecSys experiments.

Explanation

to_predict_item_indices is a list of item ids with the target item (the one the user actually interacted with later) as the first element. The other items are, most of the time, randomly picked from the set of items the user has not interacted with.

predictions is a list of CTR scores, one for each item in to_predict_item_indices, indicating how likely the model thinks the user is to interact with that item.

The negative sign in -model.predict(...) flips the sign of every prediction score, so that the highest-scored item lands at position 0 when the scores are sorted in ascending order.

argsort is used here just for convenience, but it is the part that tends to confuse.

Setting aside the formal specification of the argsort method, given a score list like [0.233, 0.6, 0.1] we can do argsort by hand:

  1. pair each score with its index so that the original index is recorded: [(0.233, 0), (0.6, 1), (0.1, 2)]
  2. sort these pairs by score: [(0.1, 2), (0.233, 0), (0.6, 1)]
  3. return the original index from each pair: [2, 0, 1]. That’s the argsort result
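
The three steps above can be sketched in plain Python. This is only an illustration of the pairing idea; argsort_by_pairs is a made-up helper name, not a function from NumPy or PyTorch.

```python
def argsort_by_pairs(scores):
    """Mimic argsort by sorting (score, original_index) pairs."""
    # Step 1: pair each score with its original index.
    pairs = [(score, index) for index, score in enumerate(scores)]
    # Step 2: sort the pairs by score (ascending).
    pairs.sort(key=lambda pair: pair[0])
    # Step 3: keep only the original indices.
    return [index for _, index in pairs]


scores = [0.233, 0.6, 0.1]
order = argsort_by_pairs(scores)
print(order)  # [2, 0, 1]
```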

Now we know that the lowest score sits at index 2 of the original list. More generally, the kth entry of the argsort result is the original index of the kth smallest score. Picking the scores in that order gives the sorted list, and that’s why the dirty trick arr[arr.argsort()] returns a sorted array :)
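
A minimal sketch of that trick without any array library: the list comprehension below plays the role of arr[arr.argsort()], and the sorted(range(...)) call stands in for argsort.

```python
scores = [0.233, 0.6, 0.1]

# Stand-in for scores.argsort(): indices ordered by ascending score.
order = sorted(range(len(scores)), key=lambda i: scores[i])

# Stand-in for scores[order]: pick the scores in that order.
sorted_scores = [scores[i] for i in order]

print(order)          # [2, 0, 1]
print(sorted_scores)  # [0.1, 0.233, 0.6]
```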

predictions.argsort() tells us where the kth-ranked item sits. What we want is the opposite: the rank of the target item, which sits at index 0 of the list. We get it by:

  1. make pairs again, this time of original index and rank: [(2 indexed, 0 ranked), (0 indexed, 1 ranked), (1 indexed, 2 ranked)]
  2. sort the pairs again, this time by the index entry: [(0 indexed, 1 ranked), (1 indexed, 2 ranked), (2 indexed, 0 ranked)]
  3. return [1, 2, 0] as the rank of each score in the score list
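
The second round of pairing and sorting can be checked with a small sketch. The argsort helper here is a plain-Python stand-in written for this post, not the library method.

```python
def argsort(values):
    """Return the indices that would sort `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])


scores = [0.233, 0.6, 0.1]

order = argsort(scores)  # order[k] is the original index of the kth smallest score
ranks = argsort(order)   # ranks[i] is the rank of the score at original index i

print(order)  # [2, 0, 1]
print(ranks)  # [1, 2, 0]

# The two lists are inverses of each other.
for rank, index in enumerate(order):
    assert ranks[index] == rank
```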

As the target item’s index is 0, we take its rank with predictions.argsort().argsort()[0].
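
Putting the pieces together, here is a rough end-to-end sketch with made-up scores. The numbers are hypothetical and the argsort helper is a plain-Python stand-in; only the shape of the computation follows the snippet from the repo.

```python
def argsort(values):
    """Return the indices that would sort `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])


# Hypothetical model scores; index 0 is the target item,
# as in to_predict_item_indices.
scores = [0.7, 0.9, 0.2, 0.4]

# Negate so that the highest score becomes the smallest value.
negated = [-s for s in scores]

ranks = argsort(argsort(negated))
target_rank = ranks[0]  # 0-based rank of the target item

print(ranks)        # [1, 0, 3, 2]
print(target_rank)  # 1
```

The target item has the second-highest score, so its 0-based rank is 1.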

That’s all! BTW, permutation.argsort().argsort() == the_same_permutation. argsort can be misleading, so always think of it as sorting pairs; that makes things easier to understand. Math is powerful because the same number can represent many things, and math is sometimes hard because we easily forget what a number currently means.
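
The permutation remark can be checked by brute force on small inputs. This sketch uses a plain-Python argsort stand-in and tries every permutation of 0..3; it is a sanity check, not a proof.

```python
from itertools import permutations


def argsort(values):
    """Return the indices that would sort `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])


# For a permutation p, argsort(p) is its inverse, so applying
# argsort twice should give p back.
all_fixed = all(
    argsort(argsort(list(p))) == list(p)
    for p in permutations(range(4))
)

print(all_fixed)  # True
```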


References

[1] SASRec: https://github.com/kang205/SASRec

[2] TiSASRec: https://github.com/JiachengLi1995/TiSASRec

[3] SASRec.pytorch: https://github.com/pmixer/SASRec.pytorch

[4] TiSASRec.pytorch: https://github.com/pmixer/TiSASRec.pytorch