Argsort for Rank
Background
Back in 2020, I re-implemented SASRec [1][3] and TiSASRec [2][4] for research use at http://claws.cc.gatech.edu/. Inside the repo, there is the following piece of code for getting the rank of the target item, which can be confusing at times:
predictions = -model.predict(uid, user_interacted_item_seq, to_predict_item_indices)
rank = predictions.argsort().argsort()[0].item()
Here I explain in detail how it works, so that you can focus on more important parts, such as the model structure, when reading the code of RecSys experiments.
Explanation
to_predict_item_indices is a list of item ids whose first element is the target item (the one the user actually interacted with later). Most of the time, the other items are randomly picked from the set of items the user has not interacted with.
predictions is a list of predicted scores, one for each item in to_predict_item_indices, indicating how likely the model thinks the user is to interact with that item.
The negative sign in -model.predict(...) flips the sign of every prediction score, so that the highest-scored item ends up at position 0 when the list is sorted in ascending order.
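To see the effect of the minus sign in isolation, here is a minimal pure-Python sketch. The `argsort` helper and the score values are made up for illustration; they only mimic what the tensor method does and are not taken from the repos.

```python
def argsort(values):
    """Return the indices that would sort `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])

scores = [0.233, 0.6, 0.1]        # hypothetical scores, higher means more likely
negated = [-s for s in scores]    # mirrors the leading minus sign

order = argsort(negated)          # ascending sort of the negated scores
best_index = order[0]             # index of the highest original score
print(order, best_index)
```

With these values the item at index 1 (score 0.6) comes first, because -0.6 is the smallest negated score.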
argsort is used here just for convenience, but it can be confusing.
Forget the formal specification of the argsort method for a moment. Given a score list like [0.233, 0.6, 0.1], we can do argsort by:
- pair each score with its index to record its original position: [(0.233, 0), (0.6, 1), (0.1, 2)]
- sort these pairs by score: [(0.1, 2), (0.233, 0), (0.6, 1)]
- return the original index from each pair: [2, 0, 1], which is the argsort result
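The three steps above can be written out almost literally. This is only a sketch of the idea in plain Python; `argsort_by_pairs` is a hypothetical helper name, and real libraries implement argsort differently under the hood.

```python
def argsort_by_pairs(scores):
    # Step 1: pair each score with its original index.
    pairs = [(score, index) for index, score in enumerate(scores)]
    # Step 2: sort the pairs by score (Python's sort is stable, so ties keep their order).
    pairs.sort(key=lambda pair: pair[0])
    # Step 3: keep only the original indices.
    return [index for _, index in pairs]

result = argsort_by_pairs([0.233, 0.6, 0.1])
print(result)  # [2, 0, 1]
```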
Now we know that the lowest score sits at index 2 of the current list. In general, the kth entry of the argsort result tells us where the kth-ranked element currently is. If we move each element to its rank as given by this index list, all scores end up sorted, and that is why the dirty trick arr[arr.argsort()] sorts an array :)
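The same trick can be checked without any array library. The sketch below assumes a plain Python list, so the fancy indexing arr[arr.argsort()] is spelled out as a list comprehension.

```python
scores = [0.233, 0.6, 0.1]

# Plain-Python argsort: indices ordered by the score they point at.
order = sorted(range(len(scores)), key=lambda i: scores[i])

# Equivalent of arr[arr.argsort()]: pick the elements in that order.
sorted_scores = [scores[i] for i in order]
print(sorted_scores)  # [0.1, 0.233, 0.6]
```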
predictions.argsort() tells us where the kth-ranked item is. What we want is the rank of the target item, which is at index 0 of the list, and we get it by:
- make pairs again, each original index with its rank: [(2 indexed, 0 ranked), (0 indexed, 1 ranked), (1 indexed, 2 ranked)]
- sort the pairs again, this time by index: [(0 indexed, 1 ranked), (1 indexed, 2 ranked), (2 indexed, 0 ranked)]
- return [1, 2, 0], the rank of each score in the score list
As the target item is at index 0, we get its rank with predictions.argsort().argsort()[0].
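Putting both passes together, a small sketch with the running example shows that the second argsort yields the rank of every score. The `argsort` helper is again a hypothetical plain-Python stand-in for the tensor method.

```python
def argsort(values):
    """Return the indices that would sort `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])

scores = [0.233, 0.6, 0.1]

order = argsort(scores)   # where the kth-ranked score currently sits
ranks = argsort(order)    # the rank of the score at each index
target_rank = ranks[0]    # the target item is stored at index 0

print(order, ranks, target_rank)  # [2, 0, 1] [1, 2, 0] 1
```

When the scores are distinct, each rank equals the number of scores that are strictly smaller, which is an easy way to sanity-check the result.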
That’s all! BTW, for a permutation of 0, 1, ..., n-1, permutation.argsort().argsort() gives back the same permutation. argsort can be misleading, so please always think of it as sorting pairs; that makes things easier to understand. Math is powerful, as the same number can represent a lot of things. Math is also sometimes hard, as we can easily forget what that number currently means.
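The permutation remark can be checked by brute force. The sketch below assumes permutations of 0..3 and the same hypothetical `argsort` helper; one argsort inverts a permutation, so applying it twice returns the original.

```python
from itertools import permutations

def argsort(values):
    """Return the indices that would sort `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])

# Try every permutation of 0..3 and compare the double argsort with the input.
all_match = all(
    argsort(argsort(list(p))) == list(p)
    for p in permutations(range(4))
)
print(all_match)  # True
```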
References
[1] SASRec: https://github.com/kang205/SASRec
[2] TiSASRec: https://github.com/JiachengLi1995/TiSASRec
[3] SASRec.pytorch: https://github.com/pmixer/SASRec.pytorch
[4] TiSASRec.pytorch: https://github.com/pmixer/TiSASRec.pytorch