New upstream version 0.33.8
Erik Johnston
5 years ago
22 | 22 | - run: docker push matrixdotorg/synapse:latest |
23 | 23 | - run: docker push matrixdotorg/synapse:latest-py3 |
24 | 24 | sytestpy2: |
25 | machine: true | |
25 | docker: | |
26 | - image: matrixdotorg/sytest-synapsepy2 | |
27 | working_directory: /src | |
26 | 28 | steps: |
27 | 29 | - checkout |
28 | - run: docker pull matrixdotorg/sytest-synapsepy2 | |
29 | - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy2 | |
30 | - run: /synapse_sytest.sh | |
30 | 31 | - store_artifacts: |
31 | path: ~/project/logs | |
32 | path: /logs | |
32 | 33 | destination: logs |
33 | 34 | - store_test_results: |
34 | path: logs | |
35 | path: /logs | |
35 | 36 | sytestpy2postgres: |
36 | machine: true | |
37 | docker: | |
38 | - image: matrixdotorg/sytest-synapsepy2 | |
39 | working_directory: /src | |
37 | 40 | steps: |
38 | 41 | - checkout |
39 | - run: docker pull matrixdotorg/sytest-synapsepy2 | |
40 | - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy2 | |
42 | - run: POSTGRES=1 /synapse_sytest.sh | |
41 | 43 | - store_artifacts: |
42 | path: ~/project/logs | |
44 | path: /logs | |
43 | 45 | destination: logs |
44 | 46 | - store_test_results: |
45 | path: logs | |
47 | path: /logs | |
46 | 48 | sytestpy2merged: |
47 | machine: true | |
49 | docker: | |
50 | - image: matrixdotorg/sytest-synapsepy2 | |
51 | working_directory: /src | |
48 | 52 | steps: |
49 | 53 | - checkout |
50 | 54 | - run: bash .circleci/merge_base_branch.sh |
51 | - run: docker pull matrixdotorg/sytest-synapsepy2 | |
52 | - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy2 | |
55 | - run: /synapse_sytest.sh | |
53 | 56 | - store_artifacts: |
54 | path: ~/project/logs | |
57 | path: /logs | |
55 | 58 | destination: logs |
56 | 59 | - store_test_results: |
57 | path: logs | |
58 | ||
60 | path: /logs | |
59 | 61 | sytestpy2postgresmerged: |
60 | machine: true | |
62 | docker: | |
63 | - image: matrixdotorg/sytest-synapsepy2 | |
64 | working_directory: /src | |
61 | 65 | steps: |
62 | 66 | - checkout |
63 | 67 | - run: bash .circleci/merge_base_branch.sh |
64 | - run: docker pull matrixdotorg/sytest-synapsepy2 | |
65 | - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy2 | |
68 | - run: POSTGRES=1 /synapse_sytest.sh | |
66 | 69 | - store_artifacts: |
67 | path: ~/project/logs | |
70 | path: /logs | |
68 | 71 | destination: logs |
69 | 72 | - store_test_results: |
70 | path: logs | |
73 | path: /logs | |
71 | 74 | |
72 | 75 | sytestpy3: |
73 | machine: true | |
76 | docker: | |
77 | - image: matrixdotorg/sytest-synapsepy3 | |
78 | working_directory: /src | |
74 | 79 | steps: |
75 | 80 | - checkout |
76 | - run: docker pull matrixdotorg/sytest-synapsepy3 | |
77 | - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy3 | |
81 | - run: /synapse_sytest.sh | |
78 | 82 | - store_artifacts: |
79 | path: ~/project/logs | |
83 | path: /logs | |
80 | 84 | destination: logs |
81 | 85 | - store_test_results: |
82 | path: logs | |
86 | path: /logs | |
83 | 87 | sytestpy3postgres: |
84 | machine: true | |
88 | docker: | |
89 | - image: matrixdotorg/sytest-synapsepy3 | |
90 | working_directory: /src | |
85 | 91 | steps: |
86 | 92 | - checkout |
87 | - run: docker pull matrixdotorg/sytest-synapsepy3 | |
88 | - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy3 | |
93 | - run: POSTGRES=1 /synapse_sytest.sh | |
89 | 94 | - store_artifacts: |
90 | path: ~/project/logs | |
95 | path: /logs | |
91 | 96 | destination: logs |
92 | 97 | - store_test_results: |
93 | path: logs | |
98 | path: /logs | |
94 | 99 | sytestpy3merged: |
95 | machine: true | |
100 | docker: | |
101 | - image: matrixdotorg/sytest-synapsepy3 | |
102 | working_directory: /src | |
96 | 103 | steps: |
97 | 104 | - checkout |
98 | 105 | - run: bash .circleci/merge_base_branch.sh |
99 | - run: docker pull matrixdotorg/sytest-synapsepy3 | |
100 | - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy3 | |
106 | - run: /synapse_sytest.sh | |
101 | 107 | - store_artifacts: |
102 | path: ~/project/logs | |
108 | path: /logs | |
103 | 109 | destination: logs |
104 | 110 | - store_test_results: |
105 | path: logs | |
111 | path: /logs | |
106 | 112 | sytestpy3postgresmerged: |
107 | machine: true | |
113 | docker: | |
114 | - image: matrixdotorg/sytest-synapsepy3 | |
115 | working_directory: /src | |
108 | 116 | steps: |
109 | 117 | - checkout |
110 | 118 | - run: bash .circleci/merge_base_branch.sh |
111 | - run: docker pull matrixdotorg/sytest-synapsepy3 | |
112 | - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy3 | |
119 | - run: POSTGRES=1 /synapse_sytest.sh | |
113 | 120 | - store_artifacts: |
114 | path: ~/project/logs | |
121 | path: /logs | |
115 | 122 | destination: logs |
116 | 123 | - store_test_results: |
117 | path: logs | |
124 | path: /logs | |
118 | 125 | |
119 | 126 | workflows: |
120 | 127 | version: 2 |
15 | 15 | GITBASE="develop" |
16 | 16 | else |
17 | 17 | # Get the reference, using the GitHub API |
18 | GITBASE=`curl -q https://api.github.com/repos/matrix-org/synapse/pulls/${CIRCLE_PR_NUMBER} | jq -r '.base.ref'` | |
18 | GITBASE=`wget -O- https://api.github.com/repos/matrix-org/synapse/pulls/${CIRCLE_PR_NUMBER} | jq -r '.base.ref'` | |
19 | 19 | fi |
20 | 20 | |
21 | 21 | # Show what we are before |
30 | 30 | git merge --no-edit origin/$GITBASE |
31 | 31 | |
32 | 32 | # Show what we are after. |
33 | git show -s⏎ | |
33 | git show -s |
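
The change above swaps curl for wget when resolving a pull request's base branch from the GitHub pulls API. For reference, the same lookup as a minimal Python sketch (the PR number and the use of urllib here are illustrative only, not part of the script)::

    import json
    import urllib.request

    pr_number = 4041  # hypothetical PR number; CircleCI exposes it as CIRCLE_PR_NUMBER
    url = "https://api.github.com/repos/matrix-org/synapse/pulls/%d" % pr_number
    with urllib.request.urlopen(url) as resp:
        base_ref = json.loads(resp.read())["base"]["ref"]  # e.g. "develop"
    print(base_ref)
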
0 | 0 | sudo: false |
1 | 1 | language: python |
2 | 2 | |
3 | # tell travis to cache ~/.cache/pip | |
4 | cache: pip | |
3 | cache: | |
4 | directories: | |
5 | # we only bother to cache the wheels; parts of the http cache get | |
6 | # invalidated every build (because they get served with a max-age of 600 | |
7 | # seconds), which means that we end up re-uploading the whole cache for | |
8 | # every build, which is time-consuming. In any case, it's not obvious that | |
9 | # downloading the cache from S3 would be much faster than downloading the | |
10 | # originals from pypi. | |
11 | # | |
12 | - $HOME/.cache/pip/wheels | |
5 | 13 | |
6 | before_script: | |
7 | - git remote set-branches --add origin develop | |
8 | - git fetch origin develop | |
14 | # don't clone the whole repo history, one commit will do | |
15 | git: | |
16 | depth: 1 | |
17 | ||
18 | # only build branches we care about (PRs are built separately) | |
19 | branches: | |
20 | only: | |
21 | - master | |
22 | - develop | |
23 | - /^release-v/ | |
9 | 24 | |
10 | 25 | matrix: |
11 | 26 | fast_finish: true |
13 | 28 | - python: 2.7 |
14 | 29 | env: TOX_ENV=packaging |
15 | 30 | |
16 | - python: 2.7 | |
17 | env: TOX_ENV=pep8 | |
31 | - python: 3.6 | |
32 | env: TOX_ENV="pep8,check_isort" | |
18 | 33 | |
19 | 34 | - python: 2.7 |
20 | 35 | env: TOX_ENV=py27 |
38 | 53 | services: |
39 | 54 | - postgresql |
40 | 55 | |
41 | - python: 3.6 | |
42 | env: TOX_ENV=check_isort | |
43 | ||
44 | - python: 3.6 | |
56 | - # we only need to check for the newsfragment if it's a PR build | |
57 | if: type = pull_request | |
58 | python: 3.6 | |
45 | 59 | env: TOX_ENV=check-newsfragment |
60 | script: | |
61 | - git remote set-branches --add origin develop | |
62 | - git fetch origin develop | |
63 | - tox -e $TOX_ENV | |
46 | 64 | |
47 | 65 | install: |
48 | 66 | - pip install tox |
0 | Synapse 0.33.8 (2018-11-01) | |
1 | =========================== | |
2 | ||
3 | No significant changes. | |
4 | ||
5 | ||
6 | Synapse 0.33.8rc2 (2018-10-31) | |
7 | ============================== | |
8 | ||
9 | Bugfixes | |
10 | -------- | |
11 | ||
12 | - Searches that request profile info now no longer fail with a 500. Fixes | |
13 | a regression in 0.33.8rc1. ([\#4122](https://github.com/matrix-org/synapse/issues/4122)) | |
14 | ||
15 | ||
16 | Synapse 0.33.8rc1 (2018-10-29) | |
17 | ============================== | |
18 | ||
19 | Features | |
20 | -------- | |
21 | ||
22 | - Servers with auto-join rooms will now automatically create those rooms when the first user registers ([\#3975](https://github.com/matrix-org/synapse/issues/3975)) | |
23 | - Add config option to control alias creation ([\#4051](https://github.com/matrix-org/synapse/issues/4051)) | |
24 | - The register_new_matrix_user script is now ported to Python 3. ([\#4085](https://github.com/matrix-org/synapse/issues/4085)) | |
25 | - Configure Docker image to listen on both ipv4 and ipv6. ([\#4089](https://github.com/matrix-org/synapse/issues/4089)) | |
26 | ||
27 | ||
28 | Bugfixes | |
29 | -------- | |
30 | ||
31 | - Fix HTTP error response codes for federated group requests. ([\#3969](https://github.com/matrix-org/synapse/issues/3969)) | |
32 | - Fix issue where Python 3 users couldn't paginate /publicRooms ([\#4046](https://github.com/matrix-org/synapse/issues/4046)) | |
33 | - Fix URL previewing to work in Python 3.7 ([\#4050](https://github.com/matrix-org/synapse/issues/4050)) | |
34 | - synctl will use the right python executable to run worker processes ([\#4057](https://github.com/matrix-org/synapse/issues/4057)) | |
35 | - Manhole now works again on Python 3, instead of failing with a "couldn't match all kex parts" when connecting. ([\#4060](https://github.com/matrix-org/synapse/issues/4060), [\#4067](https://github.com/matrix-org/synapse/issues/4067)) | |
36 | - Fix some metrics being racy and causing exceptions when polled by Prometheus. ([\#4061](https://github.com/matrix-org/synapse/issues/4061)) | |
37 | - Fix bug which prevented email notifications from being sent unless an absolute path was given for `email_templates`. ([\#4068](https://github.com/matrix-org/synapse/issues/4068)) | |
38 | - Correctly account for cpu usage by background threads ([\#4074](https://github.com/matrix-org/synapse/issues/4074)) | |
39 | - Fix race condition where config defined reserved users were not being added to | |
40 | the monthly active user list prior to the homeserver reactor firing up ([\#4081](https://github.com/matrix-org/synapse/issues/4081)) | |
41 | - Fix bug which prevented backslashes being used in event field filters ([\#4083](https://github.com/matrix-org/synapse/issues/4083)) | |
42 | ||
43 | ||
44 | Internal Changes | |
45 | ---------------- | |
46 | ||
47 | - Add information about the [matrix-docker-ansible-deploy](https://github.com/spantaleev/matrix-docker-ansible-deploy) playbook ([\#3698](https://github.com/matrix-org/synapse/issues/3698)) | |
48 | - Add initial implementation of new state resolution algorithm ([\#3786](https://github.com/matrix-org/synapse/issues/3786)) | |
49 | - Reduce database load when fetching state groups ([\#4011](https://github.com/matrix-org/synapse/issues/4011)) | |
50 | - Various cleanups in the federation client code ([\#4031](https://github.com/matrix-org/synapse/issues/4031)) | |
51 | - Run the CircleCI builds in docker containers ([\#4041](https://github.com/matrix-org/synapse/issues/4041)) | |
52 | - Only colourise synctl output when attached to tty ([\#4049](https://github.com/matrix-org/synapse/issues/4049)) | |
53 | - Refactor room alias creation code ([\#4063](https://github.com/matrix-org/synapse/issues/4063)) | |
54 | - Make the Python scripts in the top-level scripts folders meet pep8 and pass flake8. ([\#4068](https://github.com/matrix-org/synapse/issues/4068)) | |
55 | - The README now contains an example for the Caddy web server. Contributed by steamp0rt. ([\#4072](https://github.com/matrix-org/synapse/issues/4072)) | |
56 | - Add psutil as an explicit dependency ([\#4073](https://github.com/matrix-org/synapse/issues/4073)) | |
57 | - Clean up threading and logcontexts in pushers ([\#4075](https://github.com/matrix-org/synapse/issues/4075)) | |
58 | - Correctly manage logcontexts during startup to fix some "Unexpected logging context" warnings ([\#4076](https://github.com/matrix-org/synapse/issues/4076)) | |
59 | - Give some more things logcontexts ([\#4077](https://github.com/matrix-org/synapse/issues/4077)) | |
60 | - Clean up some bits of code which were flagged by the linter ([\#4082](https://github.com/matrix-org/synapse/issues/4082)) | |
61 | ||
62 | ||
0 | 63 | Synapse 0.33.7 (2018-10-18) |
1 | 64 | =========================== |
2 | 65 | |
9 | 72 | `email.template_dir` is either configured to point at a directory where you |
10 | 73 | have installed customised templates, or left unset to use the default |
11 | 74 | templates. |
12 | ||
13 | The configuration parser will try to detect the situation where | |
14 | `email.template_dir` is incorrectly set to `res/templates` and do the right | |
15 | thing, but will warn about this. | |
16 | 75 | |
17 | 76 | Synapse 0.33.7rc2 (2018-10-17) |
18 | 77 | ============================== |
172 | 172 | Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a |
173 | 173 | Dockerfile to automate a synapse server in a single Docker image, at |
174 | 174 | https://hub.docker.com/r/avhost/docker-matrix/tags/ |
175 | ||
176 | Slavi Pantaleev has created an Ansible playbook, | |
177 | which installs the official Docker image of Matrix Synapse | |
178 | along with many other Matrix-related services (Postgres database, riot-web, coturn, mxisd, SSL support, etc.). | |
179 | For more details, see | |
180 | https://github.com/spantaleev/matrix-docker-ansible-deploy | |
175 | 181 | |
176 | 182 | Configuring Synapse |
177 | 183 | ------------------- |
650 | 656 | |
651 | 657 | It is recommended to put a reverse proxy such as |
652 | 658 | `nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_, |
653 | `Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_ or | |
659 | `Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_, | |
660 | `Caddy <https://caddyserver.com/docs/proxy>`_ or | |
654 | 661 | `HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of |
655 | 662 | doing so is that it means that you can expose the default https port (443) to |
656 | 663 | Matrix clients without needing to run Synapse with root privileges. |
681 | 688 | } |
682 | 689 | } |
683 | 690 | |
684 | and an example apache configuration may look like:: | |
691 | an example Caddy configuration might look like:: | |
692 | ||
693 | matrix.example.com { | |
694 | proxy /_matrix http://localhost:8008 { | |
695 | transparent | |
696 | } | |
697 | } | |
698 | ||
699 | and an example Apache configuration might look like:: | |
685 | 700 | |
686 | 701 | <VirtualHost *:443> |
687 | 702 | SSLEngine on |
46 | 46 | # You may need to specify a port (eg, :8448) if your server is not |
47 | 47 | # configured on port 443. |
48 | 48 | curl -kv https://<host.name>/_matrix/client/versions 2>&1 | grep "Server:" |
49 | ||
50 | Upgrading to v0.33.7 | |
51 | ==================== | |
52 | ||
53 | This release removes the example email notification templates from | |
54 | ``res/templates`` (they are now internal to the python package). This should | |
55 | only affect you if you (a) deploy your Synapse instance from a git checkout or | |
56 | a github snapshot URL, and (b) have email notifications enabled. | |
57 | ||
58 | If you have email notifications enabled, you should ensure that | |
59 | ``email.template_dir`` is either configured to point at a directory where you | |
60 | have installed customised templates, or left unset to use the default | |
61 | templates. | |
49 | 62 | |
50 | 63 | Upgrading to v0.27.3 |
51 | 64 | ==================== |
20 | 20 | {% if not SYNAPSE_NO_TLS %} |
21 | 21 | - |
22 | 22 | port: 8448 |
23 | bind_addresses: ['0.0.0.0'] | |
23 | bind_addresses: ['::'] | |
24 | 24 | type: http |
25 | 25 | tls: true |
26 | 26 | x_forwarded: false |
33 | 33 | |
34 | 34 | - port: 8008 |
35 | 35 | tls: false |
36 | bind_addresses: ['0.0.0.0'] | |
36 | bind_addresses: ['::'] | |
37 | 37 | type: http |
38 | 38 | x_forwarded: false |
39 | 39 |
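
The listener template now binds to '::' instead of '0.0.0.0' so the Docker image listens on both IPv4 and IPv6 (changelog entry #4089 above). As a standalone illustration rather than Synapse code: on Linux a single '::' socket covers both address families when IPV6_V6ONLY is disabled::

    import socket

    s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    # 0 means "also accept IPv4-mapped addresses"; many distros default to this
    s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
    s.bind(("::", 8008))
    s.listen(5)
    print("listening for IPv4 and IPv6 on", s.getsockname()[:2])
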
0 | 0 | #!/usr/bin/env python |
1 | 1 | |
2 | 2 | import argparse |
3 | ||
3 | import getpass | |
4 | 4 | import sys |
5 | 5 | |
6 | 6 | import bcrypt |
7 | import getpass | |
8 | ||
9 | 7 | import yaml |
10 | 8 | |
11 | 9 | bcrypt_rounds=12 |
51 | 49 | password = prompt_for_pass() |
52 | 50 | |
53 | 51 | print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds)) |
54 |
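
The hash_password hunk above only reorders imports; the script body still uses the Python 2 print statement shown on the unchanged line. A rough Python 3 equivalent of its core, as a sketch (the pepper value here is a stand-in for the one the real script reads from the homeserver config)::

    import getpass

    import bcrypt

    bcrypt_rounds = 12
    password_pepper = ""  # hypothetical; normally taken from the config file
    password = getpass.getpass("Password: ")
    # bcrypt wants bytes on Python 3, hence the explicit encode/decode
    hashed = bcrypt.hashpw(
        (password + password_pepper).encode("utf-8"), bcrypt.gensalt(bcrypt_rounds)
    )
    print(hashed.decode("ascii"))
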
35 | 35 | |
36 | 36 | import argparse |
37 | 37 | import logging |
38 | ||
38 | import os | |
39 | import shutil | |
39 | 40 | import sys |
40 | ||
41 | import os | |
42 | ||
43 | import shutil | |
44 | 41 | |
45 | 42 | from synapse.rest.media.v1.filepath import MediaFilePaths |
46 | 43 | |
76 | 73 | if not os.path.exists(original_file): |
77 | 74 | logger.warn( |
78 | 75 | "Original for %s/%s (%s) does not exist", |
79 | origin_server, file_id, original_file, | |
76 | origin_server, | |
77 | file_id, | |
78 | original_file, | |
80 | 79 | ) |
81 | 80 | else: |
82 | 81 | mkdir_and_move( |
83 | original_file, | |
84 | dest_paths.remote_media_filepath(origin_server, file_id), | |
82 | original_file, dest_paths.remote_media_filepath(origin_server, file_id) | |
85 | 83 | ) |
86 | 84 | |
87 | 85 | # now look for thumbnails |
88 | original_thumb_dir = src_paths.remote_media_thumbnail_dir( | |
89 | origin_server, file_id, | |
90 | ) | |
86 | original_thumb_dir = src_paths.remote_media_thumbnail_dir(origin_server, file_id) | |
91 | 87 | if not os.path.exists(original_thumb_dir): |
92 | 88 | return |
93 | 89 | |
94 | 90 | mkdir_and_move( |
95 | 91 | original_thumb_dir, |
96 | dest_paths.remote_media_thumbnail_dir(origin_server, file_id) | |
92 | dest_paths.remote_media_thumbnail_dir(origin_server, file_id), | |
97 | 93 | ) |
98 | 94 | |
99 | 95 | |
108 | 104 | |
109 | 105 | if __name__ == "__main__": |
110 | 106 | parser = argparse.ArgumentParser( |
111 | description=__doc__, | |
112 | formatter_class = argparse.RawDescriptionHelpFormatter, | |
107 | description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter | |
113 | 108 | ) |
114 | parser.add_argument( | |
115 | "-v", action='store_true', help='enable debug logging') | |
116 | parser.add_argument( | |
117 | "src_repo", | |
118 | help="Path to source content repo", | |
119 | ) | |
120 | parser.add_argument( | |
121 | "dest_repo", | |
122 | help="Path to destination content repo", | |
123 | ) | |
109 | parser.add_argument("-v", action='store_true', help='enable debug logging') | |
110 | parser.add_argument("src_repo", help="Path to source content repo") | |
111 | parser.add_argument("dest_repo", help="Path to source content repo") | |
124 | 112 | args = parser.parse_args() |
125 | 113 | |
126 | 114 | logging_config = { |
127 | 115 | "level": logging.DEBUG if args.v else logging.INFO, |
128 | "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" | |
116 | "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", | |
129 | 117 | } |
130 | 118 | logging.basicConfig(**logging_config) |
131 | 119 |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | |
16 | from __future__ import print_function | |
16 | 17 | |
17 | import argparse | |
18 | import getpass | |
19 | import hashlib | |
20 | import hmac | |
21 | import json | |
22 | import sys | |
23 | import urllib2 | |
24 | import yaml | |
25 | ||
26 | ||
27 | def request_registration(user, password, server_location, shared_secret, admin=False): | |
28 | req = urllib2.Request( | |
29 | "%s/_matrix/client/r0/admin/register" % (server_location,), | |
30 | headers={'Content-Type': 'application/json'} | |
31 | ) | |
32 | ||
33 | try: | |
34 | if sys.version_info[:3] >= (2, 7, 9): | |
35 | # As of version 2.7.9, urllib2 now checks SSL certs | |
36 | import ssl | |
37 | f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) | |
38 | else: | |
39 | f = urllib2.urlopen(req) | |
40 | body = f.read() | |
41 | f.close() | |
42 | nonce = json.loads(body)["nonce"] | |
43 | except urllib2.HTTPError as e: | |
44 | print "ERROR! Received %d %s" % (e.code, e.reason,) | |
45 | if 400 <= e.code < 500: | |
46 | if e.info().type == "application/json": | |
47 | resp = json.load(e) | |
48 | if "error" in resp: | |
49 | print resp["error"] | |
50 | sys.exit(1) | |
51 | ||
52 | mac = hmac.new( | |
53 | key=shared_secret, | |
54 | digestmod=hashlib.sha1, | |
55 | ) | |
56 | ||
57 | mac.update(nonce) | |
58 | mac.update("\x00") | |
59 | mac.update(user) | |
60 | mac.update("\x00") | |
61 | mac.update(password) | |
62 | mac.update("\x00") | |
63 | mac.update("admin" if admin else "notadmin") | |
64 | ||
65 | mac = mac.hexdigest() | |
66 | ||
67 | data = { | |
68 | "nonce": nonce, | |
69 | "username": user, | |
70 | "password": password, | |
71 | "mac": mac, | |
72 | "admin": admin, | |
73 | } | |
74 | ||
75 | server_location = server_location.rstrip("/") | |
76 | ||
77 | print "Sending registration request..." | |
78 | ||
79 | req = urllib2.Request( | |
80 | "%s/_matrix/client/r0/admin/register" % (server_location,), | |
81 | data=json.dumps(data), | |
82 | headers={'Content-Type': 'application/json'} | |
83 | ) | |
84 | try: | |
85 | if sys.version_info[:3] >= (2, 7, 9): | |
86 | # As of version 2.7.9, urllib2 now checks SSL certs | |
87 | import ssl | |
88 | f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) | |
89 | else: | |
90 | f = urllib2.urlopen(req) | |
91 | f.read() | |
92 | f.close() | |
93 | print "Success." | |
94 | except urllib2.HTTPError as e: | |
95 | print "ERROR! Received %d %s" % (e.code, e.reason,) | |
96 | if 400 <= e.code < 500: | |
97 | if e.info().type == "application/json": | |
98 | resp = json.load(e) | |
99 | if "error" in resp: | |
100 | print resp["error"] | |
101 | sys.exit(1) | |
102 | ||
103 | ||
104 | def register_new_user(user, password, server_location, shared_secret, admin): | |
105 | if not user: | |
106 | try: | |
107 | default_user = getpass.getuser() | |
108 | except: | |
109 | default_user = None | |
110 | ||
111 | if default_user: | |
112 | user = raw_input("New user localpart [%s]: " % (default_user,)) | |
113 | if not user: | |
114 | user = default_user | |
115 | else: | |
116 | user = raw_input("New user localpart: ") | |
117 | ||
118 | if not user: | |
119 | print "Invalid user name" | |
120 | sys.exit(1) | |
121 | ||
122 | if not password: | |
123 | password = getpass.getpass("Password: ") | |
124 | ||
125 | if not password: | |
126 | print "Password cannot be blank." | |
127 | sys.exit(1) | |
128 | ||
129 | confirm_password = getpass.getpass("Confirm password: ") | |
130 | ||
131 | if password != confirm_password: | |
132 | print "Passwords do not match" | |
133 | sys.exit(1) | |
134 | ||
135 | if admin is None: | |
136 | admin = raw_input("Make admin [no]: ") | |
137 | if admin in ("y", "yes", "true"): | |
138 | admin = True | |
139 | else: | |
140 | admin = False | |
141 | ||
142 | request_registration(user, password, server_location, shared_secret, bool(admin)) | |
143 | ||
18 | from synapse._scripts.register_new_matrix_user import main | |
144 | 19 | |
145 | 20 | if __name__ == "__main__": |
146 | parser = argparse.ArgumentParser( | |
147 | description="Used to register new users with a given home server when" | |
148 | " registration has been disabled. The home server must be" | |
149 | " configured with the 'registration_shared_secret' option" | |
150 | " set.", | |
151 | ) | |
152 | parser.add_argument( | |
153 | "-u", "--user", | |
154 | default=None, | |
155 | help="Local part of the new user. Will prompt if omitted.", | |
156 | ) | |
157 | parser.add_argument( | |
158 | "-p", "--password", | |
159 | default=None, | |
160 | help="New password for user. Will prompt if omitted.", | |
161 | ) | |
162 | admin_group = parser.add_mutually_exclusive_group() | |
163 | admin_group.add_argument( | |
164 | "-a", "--admin", | |
165 | action="store_true", | |
166 | help="Register new user as an admin. Will prompt if --no-admin is not set either.", | |
167 | ) | |
168 | admin_group.add_argument( | |
169 | "--no-admin", | |
170 | action="store_true", | |
171 | help="Register new user as a regular user. Will prompt if --admin is not set either.", | |
172 | ) | |
173 | ||
174 | group = parser.add_mutually_exclusive_group(required=True) | |
175 | group.add_argument( | |
176 | "-c", "--config", | |
177 | type=argparse.FileType('r'), | |
178 | help="Path to server config file. Used to read in shared secret.", | |
179 | ) | |
180 | ||
181 | group.add_argument( | |
182 | "-k", "--shared-secret", | |
183 | help="Shared secret as defined in server config file.", | |
184 | ) | |
185 | ||
186 | parser.add_argument( | |
187 | "server_url", | |
188 | default="https://localhost:8448", | |
189 | nargs='?', | |
190 | help="URL to use to talk to the home server. Defaults to " | |
191 | " 'https://localhost:8448'.", | |
192 | ) | |
193 | ||
194 | args = parser.parse_args() | |
195 | ||
196 | if "config" in args and args.config: | |
197 | config = yaml.safe_load(args.config) | |
198 | secret = config.get("registration_shared_secret", None) | |
199 | if not secret: | |
200 | print "No 'registration_shared_secret' defined in config." | |
201 | sys.exit(1) | |
202 | else: | |
203 | secret = args.shared_secret | |
204 | ||
205 | admin = None | |
206 | if args.admin or args.no_admin: | |
207 | admin = args.admin | |
208 | ||
209 | register_new_user(args.user, args.password, args.server_url, secret, admin) | |
21 | main() |
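
The inline implementation removed above is replaced by a thin wrapper around synapse._scripts.register_new_matrix_user (the Python 3 port noted in the changelog, #4085). For reference, the MAC the removed code sent to /_matrix/client/r0/admin/register, re-expressed for Python 3; this is a sketch mirroring the removed code, not the new module's API::

    import hashlib
    import hmac

    def registration_mac(shared_secret, nonce, user, password, admin=False):
        # nonce, localpart, password and the admin flag are joined with NUL bytes,
        # then HMAC-SHA1'd with the homeserver's registration_shared_secret
        parts = [nonce, user, password, "admin" if admin else "notadmin"]
        mac = hmac.new(shared_secret.encode("utf-8"), digestmod=hashlib.sha1)
        mac.update(b"\x00".join(p.encode("utf-8") for p in parts))
        return mac.hexdigest()
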
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
16 | 16 | |
17 | from twisted.internet import defer, reactor | |
18 | from twisted.enterprise import adbapi | |
19 | ||
20 | from synapse.storage._base import LoggingTransaction, SQLBaseStore | |
21 | from synapse.storage.engines import create_engine | |
22 | from synapse.storage.prepare_database import prepare_database | |
23 | ||
24 | 17 | import argparse |
25 | 18 | import curses |
26 | 19 | import logging |
27 | 20 | import sys |
28 | 21 | import time |
29 | 22 | import traceback |
23 | ||
24 | from six import string_types | |
25 | ||
30 | 26 | import yaml |
31 | 27 | |
32 | from six import string_types | |
33 | ||
28 | from twisted.enterprise import adbapi | |
29 | from twisted.internet import defer, reactor | |
30 | ||
31 | from synapse.storage._base import LoggingTransaction, SQLBaseStore | |
32 | from synapse.storage.engines import create_engine | |
33 | from synapse.storage.prepare_database import prepare_database | |
34 | 34 | |
35 | 35 | logger = logging.getLogger("synapse_port_db") |
36 | 36 | |
104 | 104 | |
105 | 105 | *All* database interactions should go through this object. |
106 | 106 | """ |
107 | ||
107 | 108 | def __init__(self, db_pool, engine): |
108 | 109 | self.db_pool = db_pool |
109 | 110 | self.database_engine = engine |
134 | 135 | txn = conn.cursor() |
135 | 136 | return func( |
136 | 137 | LoggingTransaction(txn, desc, self.database_engine, [], []), |
137 | *args, **kwargs | |
138 | *args, | |
139 | **kwargs | |
138 | 140 | ) |
139 | 141 | except self.database_engine.module.DatabaseError as e: |
140 | 142 | if self.database_engine.is_deadlock(e): |
157 | 159 | def r(txn): |
158 | 160 | txn.execute(sql, args) |
159 | 161 | return txn.fetchall() |
162 | ||
160 | 163 | return self.runInteraction("execute_sql", r) |
161 | 164 | |
162 | 165 | def insert_many_txn(self, txn, table, headers, rows): |
163 | 166 | sql = "INSERT INTO %s (%s) VALUES (%s)" % ( |
164 | 167 | table, |
165 | 168 | ", ".join(k for k in headers), |
166 | ", ".join("%s" for _ in headers) | |
169 | ", ".join("%s" for _ in headers), | |
167 | 170 | ) |
168 | 171 | |
169 | 172 | try: |
170 | 173 | txn.executemany(sql, rows) |
171 | except: | |
172 | logger.exception( | |
173 | "Failed to insert: %s", | |
174 | table, | |
175 | ) | |
174 | except Exception: | |
175 | logger.exception("Failed to insert: %s", table) | |
176 | 176 | raise |
177 | 177 | |
178 | 178 | |
205 | 205 | "table_name": table, |
206 | 206 | "forward_rowid": 1, |
207 | 207 | "backward_rowid": 0, |
208 | } | |
208 | }, | |
209 | 209 | ) |
210 | 210 | |
211 | 211 | forward_chunk = 1 |
220 | 220 | table, forward_chunk, backward_chunk |
221 | 221 | ) |
222 | 222 | else: |
223 | ||
223 | 224 | def delete_all(txn): |
224 | 225 | txn.execute( |
225 | "DELETE FROM port_from_sqlite3 WHERE table_name = %s", | |
226 | (table,) | |
226 | "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,) | |
227 | 227 | ) |
228 | 228 | txn.execute("TRUNCATE %s CASCADE" % (table,)) |
229 | 229 | |
231 | 231 | |
232 | 232 | yield self.postgres_store._simple_insert( |
233 | 233 | table="port_from_sqlite3", |
234 | values={ | |
235 | "table_name": table, | |
236 | "forward_rowid": 1, | |
237 | "backward_rowid": 0, | |
238 | } | |
234 | values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, | |
239 | 235 | ) |
240 | 236 | |
241 | 237 | forward_chunk = 1 |
250 | 246 | ) |
251 | 247 | |
252 | 248 | @defer.inlineCallbacks |
253 | def handle_table(self, table, postgres_size, table_size, forward_chunk, | |
254 | backward_chunk): | |
249 | def handle_table( | |
250 | self, table, postgres_size, table_size, forward_chunk, backward_chunk | |
251 | ): | |
255 | 252 | logger.info( |
256 | 253 | "Table %s: %i/%i (rows %i-%i) already ported", |
257 | table, postgres_size, table_size, | |
258 | backward_chunk+1, forward_chunk-1, | |
254 | table, | |
255 | postgres_size, | |
256 | table_size, | |
257 | backward_chunk + 1, | |
258 | forward_chunk - 1, | |
259 | 259 | ) |
260 | 260 | |
261 | 261 | if not table_size: |
270 | 270 | return |
271 | 271 | |
272 | 272 | if table in ( |
273 | "user_directory", "user_directory_search", "users_who_share_rooms", | |
273 | "user_directory", | |
274 | "user_directory_search", | |
275 | "users_who_share_rooms", | |
274 | 276 | "users_in_pubic_room", |
275 | 277 | ): |
276 | 278 | # We don't port these tables, as they're a faff and we can regenerate |
282 | 284 | # We need to make sure there is a single row, `(X, null)`, as that is |
283 | 285 | # what synapse expects to be there. |
284 | 286 | yield self.postgres_store._simple_insert( |
285 | table=table, | |
286 | values={"stream_id": None}, | |
287 | table=table, values={"stream_id": None} | |
287 | 288 | ) |
288 | 289 | self.progress.update(table, table_size) # Mark table as done |
289 | 290 | return |
290 | 291 | |
291 | 292 | forward_select = ( |
292 | "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" | |
293 | % (table,) | |
293 | "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,) | |
294 | 294 | ) |
295 | 295 | |
296 | 296 | backward_select = ( |
297 | "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" | |
298 | % (table,) | |
297 | "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,) | |
299 | 298 | ) |
300 | 299 | |
301 | 300 | do_forward = [True] |
302 | 301 | do_backward = [True] |
303 | 302 | |
304 | 303 | while True: |
304 | ||
305 | 305 | def r(txn): |
306 | 306 | forward_rows = [] |
307 | 307 | backward_rows = [] |
308 | 308 | if do_forward[0]: |
309 | txn.execute(forward_select, (forward_chunk, self.batch_size,)) | |
309 | txn.execute(forward_select, (forward_chunk, self.batch_size)) | |
310 | 310 | forward_rows = txn.fetchall() |
311 | 311 | if not forward_rows: |
312 | 312 | do_forward[0] = False |
313 | 313 | |
314 | 314 | if do_backward[0]: |
315 | txn.execute(backward_select, (backward_chunk, self.batch_size,)) | |
315 | txn.execute(backward_select, (backward_chunk, self.batch_size)) | |
316 | 316 | backward_rows = txn.fetchall() |
317 | 317 | if not backward_rows: |
318 | 318 | do_backward[0] = False |
324 | 324 | |
325 | 325 | return headers, forward_rows, backward_rows |
326 | 326 | |
327 | headers, frows, brows = yield self.sqlite_store.runInteraction( | |
328 | "select", r | |
329 | ) | |
327 | headers, frows, brows = yield self.sqlite_store.runInteraction("select", r) | |
330 | 328 | |
331 | 329 | if frows or brows: |
332 | 330 | if frows: |
338 | 336 | rows = self._convert_rows(table, headers, rows) |
339 | 337 | |
340 | 338 | def insert(txn): |
341 | self.postgres_store.insert_many_txn( | |
342 | txn, table, headers[1:], rows | |
343 | ) | |
339 | self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) | |
344 | 340 | |
345 | 341 | self.postgres_store._simple_update_one_txn( |
346 | 342 | txn, |
361 | 357 | return |
362 | 358 | |
363 | 359 | @defer.inlineCallbacks |
364 | def handle_search_table(self, postgres_size, table_size, forward_chunk, | |
365 | backward_chunk): | |
360 | def handle_search_table( | |
361 | self, postgres_size, table_size, forward_chunk, backward_chunk | |
362 | ): | |
366 | 363 | select = ( |
367 | 364 | "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" |
368 | 365 | " FROM event_search as es" |
372 | 369 | ) |
373 | 370 | |
374 | 371 | while True: |
372 | ||
375 | 373 | def r(txn): |
376 | txn.execute(select, (forward_chunk, self.batch_size,)) | |
374 | txn.execute(select, (forward_chunk, self.batch_size)) | |
377 | 375 | rows = txn.fetchall() |
378 | 376 | headers = [column[0] for column in txn.description] |
379 | 377 | |
401 | 399 | else: |
402 | 400 | rows_dict.append(d) |
403 | 401 | |
404 | txn.executemany(sql, [ | |
405 | ( | |
406 | row["event_id"], | |
407 | row["room_id"], | |
408 | row["key"], | |
409 | row["sender"], | |
410 | row["value"], | |
411 | row["origin_server_ts"], | |
412 | row["stream_ordering"], | |
413 | ) | |
414 | for row in rows_dict | |
415 | ]) | |
402 | txn.executemany( | |
403 | sql, | |
404 | [ | |
405 | ( | |
406 | row["event_id"], | |
407 | row["room_id"], | |
408 | row["key"], | |
409 | row["sender"], | |
410 | row["value"], | |
411 | row["origin_server_ts"], | |
412 | row["stream_ordering"], | |
413 | ) | |
414 | for row in rows_dict | |
415 | ], | |
416 | ) | |
416 | 417 | |
417 | 418 | self.postgres_store._simple_update_one_txn( |
418 | 419 | txn, |
436 | 437 | def setup_db(self, db_config, database_engine): |
437 | 438 | db_conn = database_engine.module.connect( |
438 | 439 | **{ |
439 | k: v for k, v in db_config.get("args", {}).items() | |
440 | k: v | |
441 | for k, v in db_config.get("args", {}).items() | |
440 | 442 | if not k.startswith("cp_") |
441 | 443 | } |
442 | 444 | ) |
449 | 451 | def run(self): |
450 | 452 | try: |
451 | 453 | sqlite_db_pool = adbapi.ConnectionPool( |
452 | self.sqlite_config["name"], | |
453 | **self.sqlite_config["args"] | |
454 | self.sqlite_config["name"], **self.sqlite_config["args"] | |
454 | 455 | ) |
455 | 456 | |
456 | 457 | postgres_db_pool = adbapi.ConnectionPool( |
457 | self.postgres_config["name"], | |
458 | **self.postgres_config["args"] | |
458 | self.postgres_config["name"], **self.postgres_config["args"] | |
459 | 459 | ) |
460 | 460 | |
461 | 461 | sqlite_engine = create_engine(sqlite_config) |
464 | 464 | self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) |
465 | 465 | self.postgres_store = Store(postgres_db_pool, postgres_engine) |
466 | 466 | |
467 | yield self.postgres_store.execute( | |
468 | postgres_engine.check_database | |
469 | ) | |
467 | yield self.postgres_store.execute(postgres_engine.check_database) | |
470 | 468 | |
471 | 469 | # Step 1. Set up databases. |
472 | 470 | self.progress.set_state("Preparing SQLite3") |
476 | 474 | self.setup_db(postgres_config, postgres_engine) |
477 | 475 | |
478 | 476 | self.progress.set_state("Creating port tables") |
477 | ||
479 | 478 | def create_port_table(txn): |
480 | 479 | txn.execute( |
481 | 480 | "CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" |
500 | 499 | ) |
501 | 500 | |
502 | 501 | try: |
503 | yield self.postgres_store.runInteraction( | |
504 | "alter_table", alter_table | |
505 | ) | |
506 | except Exception as e: | |
502 | yield self.postgres_store.runInteraction("alter_table", alter_table) | |
503 | except Exception: | |
504 | # On Error Resume Next | |
507 | 505 | pass |
508 | 506 | |
509 | 507 | yield self.postgres_store.runInteraction( |
513 | 511 | # Step 2. Get tables. |
514 | 512 | self.progress.set_state("Fetching tables") |
515 | 513 | sqlite_tables = yield self.sqlite_store._simple_select_onecol( |
516 | table="sqlite_master", | |
517 | keyvalues={ | |
518 | "type": "table", | |
519 | }, | |
520 | retcol="name", | |
514 | table="sqlite_master", keyvalues={"type": "table"}, retcol="name" | |
521 | 515 | ) |
522 | 516 | |
523 | 517 | postgres_tables = yield self.postgres_store._simple_select_onecol( |
544 | 538 | # Step 4. Do the copying. |
545 | 539 | self.progress.set_state("Copying to postgres") |
546 | 540 | yield defer.gatherResults( |
547 | [ | |
548 | self.handle_table(*res) | |
549 | for res in setup_res | |
550 | ], | |
551 | consumeErrors=True, | |
541 | [self.handle_table(*res) for res in setup_res], consumeErrors=True | |
552 | 542 | ) |
553 | 543 | |
554 | 544 | # Step 5. Do final post-processing |
555 | 545 | yield self._setup_state_group_id_seq() |
556 | 546 | |
557 | 547 | self.progress.done() |
558 | except: | |
548 | except Exception: | |
559 | 549 | global end_error_exec_info |
560 | 550 | end_error_exec_info = sys.exc_info() |
561 | 551 | logger.exception("") |
565 | 555 | def _convert_rows(self, table, headers, rows): |
566 | 556 | bool_col_names = BOOLEAN_COLUMNS.get(table, []) |
567 | 557 | |
568 | bool_cols = [ | |
569 | i for i, h in enumerate(headers) if h in bool_col_names | |
570 | ] | |
558 | bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names] | |
571 | 559 | |
572 | 560 | class BadValueException(Exception): |
573 | 561 | pass |
576 | 564 | if j in bool_cols: |
577 | 565 | return bool(col) |
578 | 566 | elif isinstance(col, string_types) and "\0" in col: |
579 | logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col) | |
580 | raise BadValueException(); | |
567 | logger.warn( | |
568 | "DROPPING ROW: NUL value in table %s col %s: %r", | |
569 | table, | |
570 | headers[j], | |
571 | col, | |
572 | ) | |
573 | raise BadValueException() | |
581 | 574 | return col |
582 | 575 | |
583 | 576 | outrows = [] |
584 | 577 | for i, row in enumerate(rows): |
585 | 578 | try: |
586 | outrows.append(tuple( | |
587 | conv(j, col) | |
588 | for j, col in enumerate(row) | |
589 | if j > 0 | |
590 | )) | |
579 | outrows.append( | |
580 | tuple(conv(j, col) for j, col in enumerate(row) if j > 0) | |
581 | ) | |
591 | 582 | except BadValueException: |
592 | 583 | pass |
593 | 584 | |
615 | 606 | |
616 | 607 | return headers, [r for r in rows if r[ts_ind] < yesterday] |
617 | 608 | |
618 | headers, rows = yield self.sqlite_store.runInteraction( | |
619 | "select", r, | |
620 | ) | |
609 | headers, rows = yield self.sqlite_store.runInteraction("select", r) | |
621 | 610 | |
622 | 611 | rows = self._convert_rows("sent_transactions", headers, rows) |
623 | 612 | |
638 | 627 | txn.execute( |
639 | 628 | "SELECT rowid FROM sent_transactions WHERE ts >= ?" |
640 | 629 | " ORDER BY rowid ASC LIMIT 1", |
641 | (yesterday,) | |
630 | (yesterday,), | |
642 | 631 | ) |
643 | 632 | |
644 | 633 | rows = txn.fetchall() |
656 | 645 | "table_name": "sent_transactions", |
657 | 646 | "forward_rowid": next_chunk, |
658 | 647 | "backward_rowid": 0, |
659 | } | |
648 | }, | |
660 | 649 | ) |
661 | 650 | |
662 | 651 | def get_sent_table_size(txn): |
663 | 652 | txn.execute( |
664 | "SELECT count(*) FROM sent_transactions" | |
665 | " WHERE ts >= ?", | |
666 | (yesterday,) | |
653 | "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) | |
667 | 654 | ) |
668 | 655 | size, = txn.fetchone() |
669 | 656 | return int(size) |
670 | 657 | |
671 | remaining_count = yield self.sqlite_store.execute( | |
672 | get_sent_table_size | |
673 | ) | |
658 | remaining_count = yield self.sqlite_store.execute(get_sent_table_size) | |
674 | 659 | |
675 | 660 | total_count = remaining_count + inserted_rows |
676 | 661 | |
679 | 664 | @defer.inlineCallbacks |
680 | 665 | def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): |
681 | 666 | frows = yield self.sqlite_store.execute_sql( |
682 | "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), | |
683 | forward_chunk, | |
667 | "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk | |
684 | 668 | ) |
685 | 669 | |
686 | 670 | brows = yield self.sqlite_store.execute_sql( |
687 | "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), | |
688 | backward_chunk, | |
671 | "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk | |
689 | 672 | ) |
690 | 673 | |
691 | 674 | defer.returnValue(frows[0][0] + brows[0][0]) |
693 | 676 | @defer.inlineCallbacks |
694 | 677 | def _get_already_ported_count(self, table): |
695 | 678 | rows = yield self.postgres_store.execute_sql( |
696 | "SELECT count(*) FROM %s" % (table,), | |
679 | "SELECT count(*) FROM %s" % (table,) | |
697 | 680 | ) |
698 | 681 | |
699 | 682 | defer.returnValue(rows[0][0]) |
716 | 699 | def _setup_state_group_id_seq(self): |
717 | 700 | def r(txn): |
718 | 701 | txn.execute("SELECT MAX(id) FROM state_groups") |
719 | next_id = txn.fetchone()[0]+1 | |
720 | txn.execute( | |
721 | "ALTER SEQUENCE state_group_id_seq RESTART WITH %s", | |
722 | (next_id,), | |
723 | ) | |
702 | next_id = txn.fetchone()[0] + 1 | |
703 | txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) | |
704 | ||
724 | 705 | return self.postgres_store.runInteraction("setup_state_group_id_seq", r) |
725 | 706 | |
726 | 707 | |
727 | 708 | ############################################## |
728 | ###### The following is simply UI stuff ###### | |
709 | # The following is simply UI stuff | |
729 | 710 | ############################################## |
730 | 711 | |
731 | 712 | |
732 | 713 | class Progress(object): |
733 | 714 | """Used to report progress of the port |
734 | 715 | """ |
716 | ||
735 | 717 | def __init__(self): |
736 | 718 | self.tables = {} |
737 | 719 | |
757 | 739 | class CursesProgress(Progress): |
758 | 740 | """Reports progress to a curses window |
759 | 741 | """ |
742 | ||
760 | 743 | def __init__(self, stdscr): |
761 | 744 | self.stdscr = stdscr |
762 | 745 | |
800 | 783 | duration = int(now) - int(self.start_time) |
801 | 784 | |
802 | 785 | minutes, seconds = divmod(duration, 60) |
803 | duration_str = '%02dm %02ds' % (minutes, seconds,) | |
786 | duration_str = '%02dm %02ds' % (minutes, seconds) | |
804 | 787 | |
805 | 788 | if self.finished: |
806 | 789 | status = "Time spent: %s (Done!)" % (duration_str,) |
813 | 796 | est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) |
814 | 797 | else: |
815 | 798 | est_remaining_str = "Unknown" |
816 | status = ( | |
817 | "Time spent: %s (est. remaining: %s)" | |
818 | % (duration_str, est_remaining_str,) | |
819 | ) | |
820 | ||
821 | self.stdscr.addstr( | |
822 | 0, 0, | |
823 | status, | |
824 | curses.A_BOLD, | |
825 | ) | |
799 | status = "Time spent: %s (est. remaining: %s)" % ( | |
800 | duration_str, | |
801 | est_remaining_str, | |
802 | ) | |
803 | ||
804 | self.stdscr.addstr(0, 0, status, curses.A_BOLD) | |
826 | 805 | |
827 | 806 | max_len = max([len(t) for t in self.tables.keys()]) |
828 | 807 | |
830 | 809 | middle_space = 1 |
831 | 810 | |
832 | 811 | items = self.tables.items() |
833 | items.sort( | |
834 | key=lambda i: (i[1]["perc"], i[0]), | |
835 | ) | |
812 | items.sort(key=lambda i: (i[1]["perc"], i[0])) | |
836 | 813 | |
837 | 814 | for i, (table, data) in enumerate(items): |
838 | 815 | if i + 2 >= rows: |
843 | 820 | color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) |
844 | 821 | |
845 | 822 | self.stdscr.addstr( |
846 | i + 2, left_margin + max_len - len(table), | |
847 | table, | |
848 | curses.A_BOLD | color, | |
823 | i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color | |
849 | 824 | ) |
850 | 825 | |
851 | 826 | size = 20 |
856 | 831 | ) |
857 | 832 | |
858 | 833 | self.stdscr.addstr( |
859 | i + 2, left_margin + max_len + middle_space, | |
834 | i + 2, | |
835 | left_margin + max_len + middle_space, | |
860 | 836 | "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), |
861 | 837 | ) |
862 | 838 | |
863 | 839 | if self.finished: |
864 | self.stdscr.addstr( | |
865 | rows - 1, 0, | |
866 | "Press any key to exit...", | |
867 | ) | |
840 | self.stdscr.addstr(rows - 1, 0, "Press any key to exit...") | |
868 | 841 | |
869 | 842 | self.stdscr.refresh() |
870 | 843 | self.last_update = time.time() |
876 | 849 | |
877 | 850 | def set_state(self, state): |
878 | 851 | self.stdscr.clear() |
879 | self.stdscr.addstr( | |
880 | 0, 0, | |
881 | state + "...", | |
882 | curses.A_BOLD, | |
883 | ) | |
852 | self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD) | |
884 | 853 | self.stdscr.refresh() |
885 | 854 | |
886 | 855 | |
887 | 856 | class TerminalProgress(Progress): |
888 | 857 | """Just prints progress to the terminal |
889 | 858 | """ |
859 | ||
890 | 860 | def update(self, table, num_done): |
891 | 861 | super(TerminalProgress, self).update(table, num_done) |
892 | 862 | |
893 | 863 | data = self.tables[table] |
894 | 864 | |
895 | print "%s: %d%% (%d/%d)" % ( | |
896 | table, data["perc"], | |
897 | data["num_done"], data["total"], | |
865 | print( | |
866 | "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"]) | |
898 | 867 | ) |
899 | 868 | |
900 | 869 | def set_state(self, state): |
901 | print state + "..." | |
870 | print(state + "...") | |
902 | 871 | |
903 | 872 | |
904 | 873 | ############################################## |
908 | 877 | if __name__ == "__main__": |
909 | 878 | parser = argparse.ArgumentParser( |
910 | 879 | description="A script to port an existing synapse SQLite database to" |
911 | " a new PostgreSQL database." | |
880 | " a new PostgreSQL database." | |
912 | 881 | ) |
913 | 882 | parser.add_argument("-v", action='store_true') |
914 | 883 | parser.add_argument( |
915 | "--sqlite-database", required=True, | |
884 | "--sqlite-database", | |
885 | required=True, | |
916 | 886 | help="The snapshot of the SQLite database file. This must not be" |
917 | " currently used by a running synapse server" | |
887 | " currently used by a running synapse server", | |
918 | 888 | ) |
919 | 889 | parser.add_argument( |
920 | "--postgres-config", type=argparse.FileType('r'), required=True, | |
921 | help="The database config file for the PostgreSQL database" | |
890 | "--postgres-config", | |
891 | type=argparse.FileType('r'), | |
892 | required=True, | |
893 | help="The database config file for the PostgreSQL database", | |
922 | 894 | ) |
923 | 895 | parser.add_argument( |
924 | "--curses", action='store_true', | |
925 | help="display a curses based progress UI" | |
896 | "--curses", action='store_true', help="display a curses based progress UI" | |
926 | 897 | ) |
927 | 898 | |
928 | 899 | parser.add_argument( |
929 | "--batch-size", type=int, default=1000, | |
900 | "--batch-size", | |
901 | type=int, | |
902 | default=1000, | |
930 | 903 | help="The number of rows to select from the SQLite table each" |
931 | " iteration [default=1000]", | |
904 | " iteration [default=1000]", | |
932 | 905 | ) |
933 | 906 | |
934 | 907 | args = parser.parse_args() |
935 | 908 | |
936 | 909 | logging_config = { |
937 | 910 | "level": logging.DEBUG if args.v else logging.INFO, |
938 | "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" | |
911 | "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", | |
939 | 912 | } |
940 | 913 | |
941 | 914 | if args.curses: |
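
Most of the synapse_port_db hunks above are mechanical reformatting. The behavioural detail worth noting is _convert_rows: boolean columns are coerced and any row containing a NUL byte is dropped before insertion into Postgres. Assembled from the hunks above into a simplified standalone sketch (column handling abbreviated)::

    def convert_rows(bool_cols, rows):
        """bool_cols: indexes of columns stored as 0/1 integers in SQLite."""
        out = []
        for row in rows:
            try:
                converted = []
                for j, col in enumerate(row):
                    if j in bool_cols:
                        converted.append(bool(col))
                    elif isinstance(col, str) and "\0" in col:
                        # Postgres rejects NUL bytes in text, so the whole row is dropped
                        raise ValueError("NUL value in column %d" % j)
                    else:
                        converted.append(col)
                out.append(tuple(converted))
            except ValueError:
                pass  # dropped, matching the script's BadValueException handling
        return out
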
0 | from synapse.events import FrozenEvent | |
1 | from synapse.api.auth import Auth | |
2 | ||
3 | from mock import Mock | |
0 | from __future__ import print_function | |
4 | 1 | |
5 | 2 | import argparse |
6 | 3 | import itertools |
7 | 4 | import json |
8 | 5 | import sys |
9 | 6 | |
7 | from mock import Mock | |
8 | ||
9 | from synapse.api.auth import Auth | |
10 | from synapse.events import FrozenEvent | |
11 | ||
10 | 12 | |
11 | 13 | def check_auth(auth, auth_chain, events): |
12 | 14 | auth_chain.sort(key=lambda e: e.depth) |
13 | 15 | |
14 | auth_map = { | |
15 | e.event_id: e | |
16 | for e in auth_chain | |
17 | } | |
16 | auth_map = {e.event_id: e for e in auth_chain} | |
18 | 17 | |
19 | 18 | create_events = {} |
20 | 19 | for e in auth_chain: |
24 | 23 | for e in itertools.chain(auth_chain, events): |
25 | 24 | auth_events_list = [auth_map[i] for i, _ in e.auth_events] |
26 | 25 | |
27 | auth_events = { | |
28 | (e.type, e.state_key): e | |
29 | for e in auth_events_list | |
30 | } | |
26 | auth_events = {(e.type, e.state_key): e for e in auth_events_list} | |
31 | 27 | |
32 | 28 | auth_events[("m.room.create", "")] = create_events[e.room_id] |
33 | 29 | |
34 | 30 | try: |
35 | 31 | auth.check(e, auth_events=auth_events) |
36 | 32 | except Exception as ex: |
37 | print "Failed:", e.event_id, e.type, e.state_key | |
38 | print "Auth_events:", auth_events | |
39 | print ex | |
40 | print json.dumps(e.get_dict(), sort_keys=True, indent=4) | |
33 | print("Failed:", e.event_id, e.type, e.state_key) | |
34 | print("Auth_events:", auth_events) | |
35 | print(ex) | |
36 | print(json.dumps(e.get_dict(), sort_keys=True, indent=4)) | |
41 | 37 | # raise |
42 | print "Success:", e.event_id, e.type, e.state_key | |
38 | print("Success:", e.event_id, e.type, e.state_key) | |
39 | ||
43 | 40 | |
44 | 41 | if __name__ == '__main__': |
45 | 42 | parser = argparse.ArgumentParser() |
46 | 43 | |
47 | 44 | parser.add_argument( |
48 | 'json', | |
49 | nargs='?', | |
50 | type=argparse.FileType('r'), | |
51 | default=sys.stdin, | |
45 | 'json', nargs='?', type=argparse.FileType('r'), default=sys.stdin | |
52 | 46 | ) |
53 | 47 | |
54 | 48 | args = parser.parse_args() |
0 | from synapse.crypto.event_signing import * | |
0 | import argparse | |
1 | import hashlib | |
2 | import json | |
3 | import logging | |
4 | import sys | |
5 | ||
1 | 6 | from unpaddedbase64 import encode_base64 |
2 | 7 | |
3 | import argparse | |
4 | import hashlib | |
5 | import sys | |
6 | import json | |
8 | from synapse.crypto.event_signing import ( | |
9 | check_event_content_hash, | |
10 | compute_event_reference_hash, | |
11 | ) | |
7 | 12 | |
8 | 13 | |
9 | 14 | class dictobj(dict): |
23 | 28 | |
24 | 29 | def main(): |
25 | 30 | parser = argparse.ArgumentParser() |
26 | parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), | |
27 | default=sys.stdin) | |
31 | parser.add_argument( | |
32 | "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin | |
33 | ) | |
28 | 34 | args = parser.parse_args() |
29 | 35 | logging.basicConfig() |
30 | 36 | |
31 | 37 | event_json = dictobj(json.load(args.input_json)) |
32 | 38 | |
33 | algorithms = { | |
34 | "sha256": hashlib.sha256, | |
35 | } | |
39 | algorithms = {"sha256": hashlib.sha256} | |
36 | 40 | |
37 | 41 | for alg_name in event_json.hashes: |
38 | 42 | if check_event_content_hash(event_json, algorithms[alg_name]): |
39 | print "PASS content hash %s" % (alg_name,) | |
43 | print("PASS content hash %s" % (alg_name,)) | |
40 | 44 | else: |
41 | print "FAIL content hash %s" % (alg_name,) | |
45 | print("FAIL content hash %s" % (alg_name,)) | |
42 | 46 | |
43 | 47 | for algorithm in algorithms.values(): |
44 | 48 | name, h_bytes = compute_event_reference_hash(event_json, algorithm) |
45 | print "Reference hash %s: %s" % (name, encode_base64(h_bytes)) | |
49 | print("Reference hash %s: %s" % (name, encode_base64(h_bytes))) | |
46 | 50 | |
47 | if __name__=="__main__": | |
51 | ||
52 | if __name__ == "__main__": | |
48 | 53 | main() |
49 |
0 | 0 | |
1 | import argparse | |
2 | import json | |
3 | import logging | |
4 | import sys | |
5 | import urllib2 | |
6 | ||
7 | import dns.resolver | |
8 | from signedjson.key import decode_verify_key_bytes, write_signing_keys | |
1 | 9 | from signedjson.sign import verify_signed_json |
2 | from signedjson.key import decode_verify_key_bytes, write_signing_keys | |
3 | 10 | from unpaddedbase64 import decode_base64 |
4 | 11 | |
5 | import urllib2 | |
6 | import json | |
7 | import sys | |
8 | import dns.resolver | |
9 | import pprint | |
10 | import argparse | |
11 | import logging | |
12 | 12 | |
13 | 13 | def get_targets(server_name): |
14 | 14 | if ":" in server_name: |
22 | 22 | except dns.resolver.NXDOMAIN: |
23 | 23 | yield (server_name, 8448) |
24 | 24 | |
25 | ||
25 | 26 | def get_server_keys(server_name, target, port): |
26 | 27 | url = "https://%s:%i/_matrix/key/v1" % (target, port) |
27 | 28 | keys = json.load(urllib2.urlopen(url)) |
32 | 33 | verify_keys[key_id] = verify_key |
33 | 34 | return verify_keys |
34 | 35 | |
36 | ||
35 | 37 | def main(): |
36 | 38 | |
37 | 39 | parser = argparse.ArgumentParser() |
38 | 40 | parser.add_argument("signature_name") |
39 | parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), | |
40 | default=sys.stdin) | |
41 | parser.add_argument( | |
42 | "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin | |
43 | ) | |
41 | 44 | |
42 | 45 | args = parser.parse_args() |
43 | 46 | logging.basicConfig() |
47 | 50 | for target, port in get_targets(server_name): |
48 | 51 | try: |
49 | 52 | keys = get_server_keys(server_name, target, port) |
50 | print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port) | |
53 | print("Using keys from https://%s:%s/_matrix/key/v1" % (target, port)) | |
51 | 54 | write_signing_keys(sys.stdout, keys.values()) |
52 | 55 | break |
53 | except: | |
56 | except Exception: | |
54 | 57 | logging.exception("Error talking to %s:%s", target, port) |
55 | 58 | |
56 | 59 | json_to_check = json.load(args.input_json) |
57 | print "Checking JSON:" | |
60 | print("Checking JSON:") | |
58 | 61 | for key_id in json_to_check["signatures"][args.signature_name]: |
59 | 62 | try: |
60 | 63 | key = keys[key_id] |
61 | 64 | verify_signed_json(json_to_check, args.signature_name, key) |
62 | print "PASS %s" % (key_id,) | |
63 | except: | |
65 | print("PASS %s" % (key_id,)) | |
66 | except Exception: | |
64 | 67 | logging.exception("Check for key %s failed" % (key_id,)) |
65 | print "FAIL %s" % (key_id,) | |
68 | print("FAIL %s" % (key_id,)) | |
66 | 69 | |
67 | 70 | |
68 | 71 | if __name__ == '__main__': |
69 | 72 | main() |
70 |
0 | import hashlib | |
1 | import json | |
2 | import sys | |
3 | import time | |
4 | ||
5 | import six | |
6 | ||
0 | 7 | import psycopg2 |
1 | 8 | import yaml |
2 | import sys | |
3 | import json | |
4 | import time | |
5 | import hashlib | |
6 | from unpaddedbase64 import encode_base64 | |
9 | from canonicaljson import encode_canonical_json | |
7 | 10 | from signedjson.key import read_signing_keys |
8 | 11 | from signedjson.sign import sign_json |
9 | from canonicaljson import encode_canonical_json | |
12 | from unpaddedbase64 import encode_base64 | |
13 | ||
14 | if six.PY2: | |
15 | db_type = six.moves.builtins.buffer | |
16 | else: | |
17 | db_type = memoryview | |
10 | 18 | |
11 | 19 | |
12 | 20 | def select_v1_keys(connection): |
38 | 46 | cursor.close() |
39 | 47 | results = {} |
40 | 48 | for server_name, key_id, key_json in rows: |
41 | results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8")) | |
49 | results.setdefault(server_name, {})[key_id] = json.loads( | |
50 | str(key_json).decode("utf-8") | |
51 | ) | |
42 | 52 | return results |
43 | 53 | |
44 | 54 | |
46 | 56 | return { |
47 | 57 | "old_verify_keys": {}, |
48 | 58 | "server_name": server_name, |
49 | "verify_keys": { | |
50 | key_id: {"key": key} | |
51 | for key_id, key in keys.items() | |
52 | }, | |
59 | "verify_keys": {key_id: {"key": key} for key_id, key in keys.items()}, | |
53 | 60 | "valid_until_ts": valid_until, |
54 | 61 | "tls_fingerprints": [fingerprint(certificate)], |
55 | 62 | } |
64 | 71 | valid_until = json["valid_until_ts"] |
65 | 72 | key_json = encode_canonical_json(json) |
66 | 73 | for key_id in json["verify_keys"]: |
67 | yield (server, key_id, "-", valid_until, valid_until, buffer(key_json)) | |
74 | yield (server, key_id, "-", valid_until, valid_until, db_type(key_json)) | |
68 | 75 | |
69 | 76 | |
70 | 77 | def main(): |
86 | 93 | |
87 | 94 | result = {} |
88 | 95 | for server in keys: |
89 | if not server in json: | |
96 | if server not in json: | |
90 | 97 | v2_json = convert_v1_to_v2( |
91 | 98 | server, valid_until, keys[server], certificates[server] |
92 | 99 | ) |
95 | 102 | |
96 | 103 | yaml.safe_dump(result, sys.stdout, default_flow_style=False) |
97 | 104 | |
98 | rows = list( | |
99 | row for server, json in result.items() | |
100 | for row in rows_v2(server, json) | |
101 | ) | |
105 | rows = list(row for server, json in result.items() for row in rows_v2(server, json)) | |
102 | 106 | |
103 | 107 | cursor = connection.cursor() |
104 | 108 | cursor.executemany( |
106 | 110 | " server_name, key_id, from_server," |
107 | 111 | " ts_added_ms, ts_valid_until_ms, key_json" |
108 | 112 | ") VALUES (%s, %s, %s, %s, %s, %s)", |
109 | rows | |
113 | rows, | |
110 | 114 | ) |
111 | 115 | connection.commit() |
112 | 116 |
0 | 0 | #! /usr/bin/python |
1 | 1 | |
2 | from __future__ import print_function | |
3 | ||
4 | import argparse | |
2 | 5 | import ast |
6 | import os | |
7 | import re | |
8 | import sys | |
9 | ||
3 | 10 | import yaml |
11 | ||
4 | 12 | |
5 | 13 | class DefinitionVisitor(ast.NodeVisitor): |
6 | 14 | def __init__(self): |
41 | 49 | functions = {name: non_empty(f) for name, f in defs['def'].items()} |
42 | 50 | classes = {name: non_empty(f) for name, f in defs['class'].items()} |
43 | 51 | result = {} |
44 | if functions: result['def'] = functions | |
45 | if classes: result['class'] = classes | |
52 | if functions: | |
53 | result['def'] = functions | |
54 | if classes: | |
55 | result['class'] = classes | |
46 | 56 | names = defs['names'] |
47 | 57 | uses = [] |
48 | 58 | for name in names.get('Load', ()): |
49 | 59 | if name not in names.get('Param', ()) and name not in names.get('Store', ()): |
50 | 60 | uses.append(name) |
51 | 61 | uses.extend(defs['attrs']) |
52 | if uses: result['uses'] = uses | |
62 | if uses: | |
63 | result['uses'] = uses | |
53 | 64 | result['names'] = names |
54 | 65 | result['attrs'] = defs['attrs'] |
55 | 66 | return result |
94 | 105 | |
95 | 106 | |
96 | 107 | if __name__ == '__main__': |
97 | import sys, os, argparse, re | |
98 | 108 | |
99 | 109 | parser = argparse.ArgumentParser(description='Find definitions.') |
100 | 110 | parser.add_argument( |
104 | 114 | "--ignore", action="append", metavar="REGEXP", help="Ignore a pattern" |
105 | 115 | ) |
106 | 116 | parser.add_argument( |
107 | "--pattern", action="append", metavar="REGEXP", | |
108 | help="Search for a pattern" | |
109 | ) | |
110 | parser.add_argument( | |
111 | "directories", nargs='+', metavar="DIR", | |
112 | help="Directories to search for definitions" | |
113 | ) | |
114 | parser.add_argument( | |
115 | "--referrers", default=0, type=int, | |
116 | help="Include referrers up to the given depth" | |
117 | ) | |
118 | parser.add_argument( | |
119 | "--referred", default=0, type=int, | |
120 | help="Include referred down to the given depth" | |
121 | ) | |
122 | parser.add_argument( | |
123 | "--format", default="yaml", | |
124 | help="Output format, one of 'yaml' or 'dot'" | |
117 | "--pattern", action="append", metavar="REGEXP", help="Search for a pattern" | |
118 | ) | |
119 | parser.add_argument( | |
120 | "directories", | |
121 | nargs='+', | |
122 | metavar="DIR", | |
123 | help="Directories to search for definitions", | |
124 | ) | |
125 | parser.add_argument( | |
126 | "--referrers", | |
127 | default=0, | |
128 | type=int, | |
129 | help="Include referrers up to the given depth", | |
130 | ) | |
131 | parser.add_argument( | |
132 | "--referred", | |
133 | default=0, | |
134 | type=int, | |
135 | help="Include referred down to the given depth", | |
136 | ) | |
137 | parser.add_argument( | |
138 | "--format", default="yaml", help="Output format, one of 'yaml' or 'dot'" | |
125 | 139 | ) |
126 | 140 | args = parser.parse_args() |
127 | 141 | |
161 | 175 | for used_by in entry.get("used", ()): |
162 | 176 | referrers.add(used_by) |
163 | 177 | for name, definition in names.items(): |
164 | if not name in referrers: | |
178 | if name not in referrers: | |
165 | 179 | continue |
166 | 180 | if ignore and any(pattern.match(name) for pattern in ignore): |
167 | 181 | continue |
175 | 189 | for uses in entry.get("uses", ()): |
176 | 190 | referred.add(uses) |
177 | 191 | for name, definition in names.items(): |
178 | if not name in referred: | |
192 | if name not in referred: | |
179 | 193 | continue |
180 | 194 | if ignore and any(pattern.match(name) for pattern in ignore): |
181 | 195 | continue |
184 | 198 | if args.format == 'yaml': |
185 | 199 | yaml.dump(result, sys.stdout, default_flow_style=False) |
186 | 200 | elif args.format == 'dot': |
187 | print "digraph {" | |
201 | print("digraph {") | |
188 | 202 | for name, entry in result.items(): |
189 | print name | |
203 | print(name) | |
190 | 204 | for used_by in entry.get("used", ()): |
191 | 205 | if used_by in result: |
192 | print used_by, "->", name | |
193 | print "}" | |
206 | print(used_by, "->", name) | |
207 | print("}") | |
194 | 208 | else: |
195 | 209 | raise ValueError("Unknown format %r" % (args.format)) |
0 | 0 | #!/usr/bin/env python2 |
1 | 1 | |
2 | from __future__ import print_function | |
3 | ||
4 | import sys | |
5 | ||
2 | 6 | import pymacaroons |
3 | import sys | |
4 | 7 | |
5 | 8 | if len(sys.argv) == 1: |
6 | 9 | sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],)) |
10 | 13 | key = sys.argv[2] if len(sys.argv) > 2 else None |
11 | 14 | |
12 | 15 | macaroon = pymacaroons.Macaroon.deserialize(macaroon_string) |
13 | print macaroon.inspect() | |
16 | print(macaroon.inspect()) | |
14 | 17 | |
15 | print "" | |
18 | print("") | |
16 | 19 | |
17 | 20 | verifier = pymacaroons.Verifier() |
18 | 21 | verifier.satisfy_general(lambda c: True) |
19 | 22 | try: |
20 | 23 | verifier.verify(macaroon, key) |
21 | print "Signature is correct" | |
24 | print("Signature is correct") | |
22 | 25 | except Exception as e: |
23 | print str(e) | |
26 | print(str(e)) |
17 | 17 | from __future__ import print_function |
18 | 18 | |
19 | 19 | import argparse |
20 | import base64 | |
21 | import json | |
22 | import sys | |
20 | 23 | from urlparse import urlparse, urlunparse |
21 | 24 | |
22 | 25 | import nacl.signing |
23 | import json | |
24 | import base64 | |
25 | 26 | import requests |
26 | import sys | |
27 | ||
28 | from requests.adapters import HTTPAdapter | |
29 | 27 | import srvlookup |
30 | 28 | import yaml |
29 | from requests.adapters import HTTPAdapter | |
31 | 30 | |
32 | 31 | # uncomment the following to enable debug logging of http requests |
33 | #from httplib import HTTPConnection | |
34 | #HTTPConnection.debuglevel = 1 | |
32 | # from httplib import HTTPConnection | |
33 | # HTTPConnection.debuglevel = 1 | |
34 | ||
35 | 35 | |
36 | 36 | def encode_base64(input_bytes): |
37 | 37 | """Encode bytes as a base64 string without any padding.""" |
57 | 57 | |
58 | 58 | def encode_canonical_json(value): |
59 | 59 | return json.dumps( |
60 | value, | |
61 | # Encode code-points outside of ASCII as UTF-8 rather than \u escapes | |
62 | ensure_ascii=False, | |
63 | # Remove unnecessary white space. | 
64 | separators=(',',':'), | |
65 | # Sort the keys of dictionaries. | |
66 | sort_keys=True, | |
67 | # Encode the resulting unicode as UTF-8 bytes. | |
68 | ).encode("UTF-8") | |
60 | value, | |
61 | # Encode code-points outside of ASCII as UTF-8 rather than \u escapes | |
62 | ensure_ascii=False, | |
63 | # Remove unnecessary white space. | 
64 | separators=(',', ':'), | |
65 | # Sort the keys of dictionaries. | |
66 | sort_keys=True, | |
67 | # Encode the resulting unicode as UTF-8 bytes. | |
68 | ).encode("UTF-8") | |
69 | 69 | |
70 | 70 | |
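For reference, the encoding above sorts keys, strips optional whitespace and keeps non-ASCII code points as UTF-8 rather than \u escapes. A small worked example of the expected output (the input dict is hypothetical):

    >>> encode_canonical_json({"b": 1, "a": "é"})
    b'{"a":"\xc3\xa9","b":1}'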
71 | 71 | def sign_json(json_object, signing_key, signing_name): |
86 | 86 | |
87 | 87 | |
88 | 88 | NACL_ED25519 = "ed25519" |
89 | ||
89 | 90 | |
90 | 91 | def decode_signing_key_base64(algorithm, version, key_base64): |
91 | 92 | """Decode a base64 encoded signing key |
142 | 143 | authorization_headers = [] |
143 | 144 | |
144 | 145 | for key, sig in signed_json["signatures"][origin_name].items(): |
145 | header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( | |
146 | origin_name, key, sig, | |
147 | ) | |
146 | header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig) | |
148 | 147 | authorization_headers.append(bytes(header)) |
149 | print ("Authorization: %s" % header, file=sys.stderr) | |
148 | print("Authorization: %s" % header, file=sys.stderr) | |
150 | 149 | |
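For illustration, the Authorization header built in the loop above ends up in this shape (the origin, key id and signature below are placeholders, not real values):

    Authorization: X-Matrix origin=example.com,key="ed25519:a_AbCd",sig="<unpadded-base64 signature>"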
151 | 150 | dest = "matrix://%s%s" % (destination, path) |
152 | print ("Requesting %s" % dest, file=sys.stderr) | |
151 | print("Requesting %s" % dest, file=sys.stderr) | |
153 | 152 | |
154 | 153 | s = requests.Session() |
155 | 154 | s.mount("matrix://", MatrixConnectionAdapter()) |
157 | 156 | result = s.request( |
158 | 157 | method=method, |
159 | 158 | url=dest, |
160 | headers={ | |
161 | "Host": destination, | |
162 | "Authorization": authorization_headers[0] | |
163 | }, | |
159 | headers={"Host": destination, "Authorization": authorization_headers[0]}, | |
164 | 160 | verify=False, |
165 | 161 | data=content, |
166 | 162 | ) |
170 | 166 | |
171 | 167 | def main(): |
172 | 168 | parser = argparse.ArgumentParser( |
173 | description= | |
174 | "Signs and sends a federation request to a matrix homeserver", | |
175 | ) | |
176 | ||
177 | parser.add_argument( | |
178 | "-N", "--server-name", | |
169 | description="Signs and sends a federation request to a matrix homeserver" | |
170 | ) | |
171 | ||
172 | parser.add_argument( | |
173 | "-N", | |
174 | "--server-name", | |
179 | 175 | help="Name to give as the local homeserver. If unspecified, will be " |
180 | "read from the config file.", | |
181 | ) | |
182 | ||
183 | parser.add_argument( | |
184 | "-k", "--signing-key-path", | |
176 | "read from the config file.", | |
177 | ) | |
178 | ||
179 | parser.add_argument( | |
180 | "-k", | |
181 | "--signing-key-path", | |
185 | 182 | help="Path to the file containing the private ed25519 key to sign the " |
186 | "request with.", | |
187 | ) | |
188 | ||
189 | parser.add_argument( | |
190 | "-c", "--config", | |
183 | "request with.", | |
184 | ) | |
185 | ||
186 | parser.add_argument( | |
187 | "-c", | |
188 | "--config", | |
191 | 189 | default="homeserver.yaml", |
192 | 190 | help="Path to server config file. Ignored if --server-name and " |
193 | "--signing-key-path are both given.", | |
194 | ) | |
195 | ||
196 | parser.add_argument( | |
197 | "-d", "--destination", | |
191 | "--signing-key-path are both given.", | |
192 | ) | |
193 | ||
194 | parser.add_argument( | |
195 | "-d", | |
196 | "--destination", | |
198 | 197 | default="matrix.org", |
199 | 198 | help="name of the remote homeserver. We will do SRV lookups and " |
200 | "connect appropriately.", | |
201 | ) | |
202 | ||
203 | parser.add_argument( | |
204 | "-X", "--method", | |
199 | "connect appropriately.", | |
200 | ) | |
201 | ||
202 | parser.add_argument( | |
203 | "-X", | |
204 | "--method", | |
205 | 205 | help="HTTP method to use for the request. Defaults to GET if --data is" |
206 | "unspecified, POST if it is." | |
207 | ) | |
208 | ||
209 | parser.add_argument( | |
210 | "--body", | |
211 | help="Data to send as the body of the HTTP request" | |
212 | ) | |
213 | ||
214 | parser.add_argument( | |
215 | "path", | |
216 | help="request path. We will add '/_matrix/federation/v1/' to this." | |
206 | "unspecified, POST if it is.", | |
207 | ) | |
208 | ||
209 | parser.add_argument("--body", help="Data to send as the body of the HTTP request") | |
210 | ||
211 | parser.add_argument( | |
212 | "path", help="request path. We will add '/_matrix/federation/v1/' to this." | |
217 | 213 | ) |
218 | 214 | |
219 | 215 | args = parser.parse_args() |
226 | 222 | |
227 | 223 | result = request_json( |
228 | 224 | args.method, |
229 | args.server_name, key, args.destination, | |
225 | args.server_name, | |
226 | key, | |
227 | args.destination, | |
230 | 228 | "/_matrix/federation/v1/" + args.path, |
231 | 229 | content=args.body, |
232 | 230 | ) |
233 | 231 | |
234 | 232 | json.dump(result, sys.stdout) |
235 | print ("") | |
233 | print("") | |
236 | 234 | |
237 | 235 | |
238 | 236 | def read_args_from_config(args): |
252 | 250 | return s, 8448 |
253 | 251 | |
254 | 252 | if ":" in s: |
255 | out = s.rsplit(":",1) | |
253 | out = s.rsplit(":", 1) | |
256 | 254 | try: |
257 | 255 | port = int(out[1]) |
258 | 256 | except ValueError: |
262 | 260 | try: |
263 | 261 | srv = srvlookup.lookup("matrix", "tcp", s)[0] |
264 | 262 | return srv.host, srv.port |
265 | except: | |
263 | except Exception: | |
266 | 264 | return s, 8448 |
267 | 265 | |
268 | 266 | def get_connection(self, url, proxies=None): |
271 | 269 | (host, port) = self.lookup(parsed.netloc) |
272 | 270 | netloc = "%s:%d" % (host, port) |
273 | 271 | print("Connecting to %s" % (netloc,), file=sys.stderr) |
274 | url = urlunparse(( | |
275 | "https", netloc, parsed.path, parsed.params, parsed.query, | |
276 | parsed.fragment, | |
277 | )) | |
272 | url = urlunparse( | |
273 | ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment) | |
274 | ) | |
278 | 275 | return super(MatrixConnectionAdapter, self).get_connection(url, proxies) |
279 | 276 | |
280 | 277 |
0 | from __future__ import print_function | |
1 | ||
2 | import sqlite3 | |
3 | import sys | |
4 | ||
5 | from unpaddedbase64 import decode_base64, encode_base64 | |
6 | ||
7 | from synapse.crypto.event_signing import ( | |
8 | add_event_pdu_content_hash, | |
9 | compute_pdu_event_reference_hash, | |
10 | ) | |
11 | from synapse.federation.units import Pdu | |
12 | from synapse.storage._base import SQLBaseStore | |
0 | 13 | from synapse.storage.pdu import PduStore |
1 | 14 | from synapse.storage.signatures import SignatureStore |
2 | from synapse.storage._base import SQLBaseStore | |
3 | from synapse.federation.units import Pdu | |
4 | from synapse.crypto.event_signing import ( | |
5 | add_event_pdu_content_hash, compute_pdu_event_reference_hash | |
6 | ) | |
7 | from synapse.api.events.utils import prune_pdu | |
8 | from unpaddedbase64 import encode_base64, decode_base64 | |
9 | from canonicaljson import encode_canonical_json | |
10 | import sqlite3 | |
11 | import sys | |
15 | ||
12 | 16 | |
13 | 17 | class Store(object): |
14 | 18 | _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"] |
15 | 19 | _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"] |
16 | 20 | _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"] |
17 | _get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"] | |
21 | _get_pdu_origin_signatures_txn = SignatureStore.__dict__[ | |
22 | "_get_pdu_origin_signatures_txn" | |
23 | ] | |
18 | 24 | _store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"] |
19 | _store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"] | |
25 | _store_pdu_reference_hash_txn = SignatureStore.__dict__[ | |
26 | "_store_pdu_reference_hash_txn" | |
27 | ] | |
20 | 28 | _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] |
21 | 29 | _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] |
22 | 30 | |
25 | 33 | |
26 | 34 | |
27 | 35 | def select_pdus(cursor): |
28 | cursor.execute( | |
29 | "SELECT pdu_id, origin FROM pdus ORDER BY depth ASC" | |
30 | ) | |
36 | cursor.execute("SELECT pdu_id, origin FROM pdus ORDER BY depth ASC") | |
31 | 37 | |
32 | 38 | ids = cursor.fetchall() |
33 | 39 | |
40 | 46 | for pdu in pdus: |
41 | 47 | try: |
42 | 48 | if pdu.prev_pdus: |
43 | print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus | |
49 | print("PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus) | |
44 | 50 | for pdu_id, origin, hashes in pdu.prev_pdus: |
45 | 51 | ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)] |
46 | 52 | hashes[ref_alg] = encode_base64(ref_hsh) |
47 | store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh) | |
48 | print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus | |
53 | store._store_prev_pdu_hash_txn( | |
54 | cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh | |
55 | ) | |
56 | print("SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus) | |
49 | 57 | pdu = add_event_pdu_content_hash(pdu) |
50 | 58 | ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu) |
51 | 59 | reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh) |
52 | store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh) | |
60 | store._store_pdu_reference_hash_txn( | |
61 | cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh | |
62 | ) | |
53 | 63 | |
54 | 64 | for alg, hsh_base64 in pdu.hashes.items(): |
55 | print alg, hsh_base64 | |
56 | store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64)) | |
65 | print(alg, hsh_base64) | |
66 | store._store_pdu_content_hash_txn( | |
67 | cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64) | |
68 | ) | |
57 | 69 | |
58 | except: | |
59 | print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus | |
70 | except Exception: | |
71 | print("FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus) | |
72 | ||
60 | 73 | |
61 | 74 | def main(): |
62 | 75 | conn = sqlite3.connect(sys.argv[1]) |
64 | 77 | select_pdus(cursor) |
65 | 78 | conn.commit() |
66 | 79 | |
67 | if __name__=='__main__': | |
80 | ||
81 | if __name__ == '__main__': | |
68 | 82 | main() |
0 | 0 | #! /usr/bin/python |
1 | 1 | |
2 | import argparse | |
2 | 3 | import ast |
3 | import argparse | |
4 | 4 | import os |
5 | 5 | import sys |
6 | ||
6 | 7 | import yaml |
7 | 8 | |
8 | 9 | PATTERNS_V1 = [] |
9 | 10 | PATTERNS_V2 = [] |
10 | 11 | |
11 | RESULT = { | |
12 | "v1": PATTERNS_V1, | |
13 | "v2": PATTERNS_V2, | |
14 | } | |
12 | RESULT = {"v1": PATTERNS_V1, "v2": PATTERNS_V2} | |
13 | ||
15 | 14 | |
16 | 15 | class CallVisitor(ast.NodeVisitor): |
17 | 16 | def visit_Call(self, node): |
19 | 18 | name = node.func.id |
20 | 19 | else: |
21 | 20 | return |
22 | ||
23 | 21 | |
24 | 22 | if name == "client_path_patterns": |
25 | 23 | PATTERNS_V1.append(node.args[0].s) |
41 | 39 | parser = argparse.ArgumentParser(description='Find url patterns.') |
42 | 40 | |
43 | 41 | parser.add_argument( |
44 | "directories", nargs='+', metavar="DIR", | |
45 | help="Directories to search for definitions" | |
42 | "directories", | |
43 | nargs='+', | |
44 | metavar="DIR", | |
45 | help="Directories to search for definitions", | |
46 | 46 | ) |
47 | 47 | |
48 | 48 | args = parser.parse_args() |
0 | import requests | |
1 | 0 | import collections |
1 | import json | |
2 | 2 | import sys |
3 | 3 | import time |
4 | import json | |
4 | ||
5 | import requests | |
5 | 6 | |
6 | 7 | Entry = collections.namedtuple("Entry", "name position rows") |
7 | 8 | |
29 | 30 | |
30 | 31 | |
31 | 32 | def replicate(server, streams): |
32 | return parse_response(requests.get( | |
33 | server + "/_synapse/replication", | |
34 | verify=False, | |
35 | params=streams | |
36 | ).content) | |
33 | return parse_response( | |
34 | requests.get( | |
35 | server + "/_synapse/replication", verify=False, params=streams | |
36 | ).content | |
37 | ) | |
37 | 38 | |
38 | 39 | |
39 | 40 | def main(): |
44 | 45 | try: |
45 | 46 | streams = { |
46 | 47 | row.name: row.position |
47 | for row in replicate(server, {"streams":"-1"})["streams"].rows | |
48 | for row in replicate(server, {"streams": "-1"})["streams"].rows | |
48 | 49 | } |
49 | except requests.exceptions.ConnectionError as e: | |
50 | except requests.exceptions.ConnectionError: | |
50 | 51 | time.sleep(0.1) |
51 | 52 | |
52 | 53 | while True: |
53 | 54 | try: |
54 | 55 | results = replicate(server, streams) |
55 | except: | |
56 | sys.stdout.write("connection_lost("+ repr(streams) + ")\n") | |
56 | except Exception: | |
57 | sys.stdout.write("connection_lost(" + repr(streams) + ")\n") | |
57 | 58 | break |
58 | 59 | for update in results.values(): |
59 | 60 | for row in update.rows: |
61 | 62 | streams[update.name] = update.position |
62 | 63 | |
63 | 64 | |
64 | ||
65 | if __name__=='__main__': | |
65 | if __name__ == '__main__': | |
66 | 66 | main() |
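As a usage note: the polling protocol visible in replicate() starts by asking the special "streams" stream (position "-1") for the current position of every replication stream, then passes those positions back on each subsequent call. A minimal sketch of the first two calls (the server URL and positions are placeholders):

    # Discover current positions of all replication streams.
    positions = {
        row.name: row.position
        for row in replicate("https://localhost:8448", {"streams": "-1"})["streams"].rows
    }
    # e.g. positions == {"events": 3812, "receipts": 977}

    # Poll for new rows from those positions onwards.
    updates = replicate("https://localhost:8448", positions)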
13 | 13 | pylint.cfg |
14 | 14 | tox.ini |
15 | 15 | |
16 | [pep8] | |
16 | [flake8] | |
17 | 17 | max-line-length = 90 |
18 | # W503 requires that binary operators be at the end, not start, of lines. Erik | |
19 | # doesn't like it. E203 is contrary to PEP8. E731 is silly. | |
20 | ignore = W503,E203,E731 | |
21 | 18 | |
22 | [flake8] | |
23 | # note that flake8 inherits the "ignore" settings from "pep8" (because it uses | |
24 | # pep8 to do those checks), but not the "max-line-length" setting | |
25 | max-line-length = 90 | |
26 | ignore=W503,E203,E731 | |
19 | # see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes | |
20 | # for error codes. The ones we ignore are: | |
21 | # W503: line break before binary operator | |
22 | # W504: line break after binary operator | |
23 | # E203: whitespace before ':' (which is contrary to pep8?) | |
24 | # E731: do not assign a lambda expression, use a def | |
25 | ignore=W503,W504,E203,E731 | |
27 | 26 | |
28 | 27 | [isort] |
29 | 28 | line_length = 89 |
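As background on the two W50x codes ignored above: W503 and W504 are complementary checks about which side of a line break a binary operator falls on, so ignoring both accepts either style. A minimal illustration (variable names are hypothetical):

    # W504 flags the operator at the end of the first line:
    total = (first_value +
             second_value)

    # W503 flags the operator at the start of the second line:
    total = (first_value
             + second_value)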
0 | 0 | #!/usr/bin/env python |
1 | 1 | |
2 | # Copyright 2014-2016 OpenMarket Ltd | |
2 | # Copyright 2014-2017 OpenMarket Ltd | |
3 | # Copyright 2017 Vector Creations Ltd | |
4 | # Copyright 2017-2018 New Vector Ltd | |
3 | 5 | # |
4 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 7 | # you may not use this file except in compliance with the License. |
85 | 87 | name="matrix-synapse", |
86 | 88 | version=version, |
87 | 89 | packages=find_packages(exclude=["tests", "tests.*"]), |
88 | description="Reference Synapse Home Server", | |
90 | description="Reference homeserver for the Matrix decentralised comms protocol", | |
89 | 91 | install_requires=dependencies['requirements'](include_conditional=True).keys(), |
90 | 92 | dependency_links=dependencies["DEPENDENCY_LINKS"].values(), |
91 | 93 | include_package_data=True, |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2015, 2016 OpenMarket Ltd | |
2 | # Copyright 2018 New Vector | |
3 | # | |
4 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
5 | # you may not use this file except in compliance with the License. | |
6 | # You may obtain a copy of the License at | |
7 | # | |
8 | # http://www.apache.org/licenses/LICENSE-2.0 | |
9 | # | |
10 | # Unless required by applicable law or agreed to in writing, software | |
11 | # distributed under the License is distributed on an "AS IS" BASIS, | |
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
13 | # See the License for the specific language governing permissions and | |
14 | # limitations under the License. | |
15 | ||
16 | from __future__ import print_function | |
17 | ||
18 | import argparse | |
19 | import getpass | |
20 | import hashlib | |
21 | import hmac | |
22 | import logging | |
23 | import sys | |
24 | ||
25 | from six.moves import input | |
26 | ||
27 | import requests as _requests | |
28 | import yaml | |
29 | ||
30 | ||
31 | def request_registration( | |
32 | user, | |
33 | password, | |
34 | server_location, | |
35 | shared_secret, | |
36 | admin=False, | |
37 | requests=_requests, | |
38 | _print=print, | |
39 | exit=sys.exit, | |
40 | ): | |
41 | ||
42 | url = "%s/_matrix/client/r0/admin/register" % (server_location,) | |
43 | ||
44 | # Get the nonce | |
45 | r = requests.get(url, verify=False) | |
46 | ||
47 | if r.status_code is not 200: | |
48 | _print("ERROR! Received %d %s" % (r.status_code, r.reason)) | |
49 | if 400 <= r.status_code < 500: | |
50 | try: | |
51 | _print(r.json()["error"]) | |
52 | except Exception: | |
53 | pass | |
54 | return exit(1) | |
55 | ||
56 | nonce = r.json()["nonce"] | |
57 | ||
58 | mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1) | |
59 | ||
60 | mac.update(nonce.encode('utf8')) | |
61 | mac.update(b"\x00") | |
62 | mac.update(user.encode('utf8')) | |
63 | mac.update(b"\x00") | |
64 | mac.update(password.encode('utf8')) | |
65 | mac.update(b"\x00") | |
66 | mac.update(b"admin" if admin else b"notadmin") | |
67 | ||
68 | mac = mac.hexdigest() | |
69 | ||
70 | data = { | |
71 | "nonce": nonce, | |
72 | "username": user, | |
73 | "password": password, | |
74 | "mac": mac, | |
75 | "admin": admin, | |
76 | } | |
77 | ||
78 | _print("Sending registration request...") | |
79 | r = requests.post(url, json=data, verify=False) | |
80 | ||
81 | if r.status_code is not 200: | |
82 | _print("ERROR! Received %d %s" % (r.status_code, r.reason)) | |
83 | if 400 <= r.status_code < 500: | |
84 | try: | |
85 | _print(r.json()["error"]) | |
86 | except Exception: | |
87 | pass | |
88 | return exit(1) | |
89 | ||
90 | _print("Success!") | |
91 | ||
92 | ||
93 | def register_new_user(user, password, server_location, shared_secret, admin): | |
94 | if not user: | |
95 | try: | |
96 | default_user = getpass.getuser() | |
97 | except Exception: | |
98 | default_user = None | |
99 | ||
100 | if default_user: | |
101 | user = input("New user localpart [%s]: " % (default_user,)) | |
102 | if not user: | |
103 | user = default_user | |
104 | else: | |
105 | user = input("New user localpart: ") | |
106 | ||
107 | if not user: | |
108 | print("Invalid user name") | |
109 | sys.exit(1) | |
110 | ||
111 | if not password: | |
112 | password = getpass.getpass("Password: ") | |
113 | ||
114 | if not password: | |
115 | print("Password cannot be blank.") | |
116 | sys.exit(1) | |
117 | ||
118 | confirm_password = getpass.getpass("Confirm password: ") | |
119 | ||
120 | if password != confirm_password: | |
121 | print("Passwords do not match") | |
122 | sys.exit(1) | |
123 | ||
124 | if admin is None: | |
125 | admin = input("Make admin [no]: ") | |
126 | if admin in ("y", "yes", "true"): | |
127 | admin = True | |
128 | else: | |
129 | admin = False | |
130 | ||
131 | request_registration(user, password, server_location, shared_secret, bool(admin)) | |
132 | ||
133 | ||
134 | def main(): | |
135 | ||
136 | logging.captureWarnings(True) | |
137 | ||
138 | parser = argparse.ArgumentParser( | |
139 | description="Used to register new users with a given home server when" | |
140 | " registration has been disabled. The home server must be" | |
141 | " configured with the 'registration_shared_secret' option" | |
142 | " set." | |
143 | ) | |
144 | parser.add_argument( | |
145 | "-u", | |
146 | "--user", | |
147 | default=None, | |
148 | help="Local part of the new user. Will prompt if omitted.", | |
149 | ) | |
150 | parser.add_argument( | |
151 | "-p", | |
152 | "--password", | |
153 | default=None, | |
154 | help="New password for user. Will prompt if omitted.", | |
155 | ) | |
156 | admin_group = parser.add_mutually_exclusive_group() | |
157 | admin_group.add_argument( | |
158 | "-a", | |
159 | "--admin", | |
160 | action="store_true", | |
161 | help=( | |
162 | "Register new user as an admin. " | |
163 | "Will prompt if --no-admin is not set either." | |
164 | ), | |
165 | ) | |
166 | admin_group.add_argument( | |
167 | "--no-admin", | |
168 | action="store_true", | |
169 | help=( | |
170 | "Register new user as a regular user. " | |
171 | "Will prompt if --admin is not set either." | |
172 | ), | |
173 | ) | |
174 | ||
175 | group = parser.add_mutually_exclusive_group(required=True) | |
176 | group.add_argument( | |
177 | "-c", | |
178 | "--config", | |
179 | type=argparse.FileType('r'), | |
180 | help="Path to server config file. Used to read in shared secret.", | |
181 | ) | |
182 | ||
183 | group.add_argument( | |
184 | "-k", "--shared-secret", help="Shared secret as defined in server config file." | |
185 | ) | |
186 | ||
187 | parser.add_argument( | |
188 | "server_url", | |
189 | default="https://localhost:8448", | |
190 | nargs='?', | |
191 | help="URL to use to talk to the home server. Defaults to " | |
192 | " 'https://localhost:8448'.", | |
193 | ) | |
194 | ||
195 | args = parser.parse_args() | |
196 | ||
197 | if "config" in args and args.config: | |
198 | config = yaml.safe_load(args.config) | |
199 | secret = config.get("registration_shared_secret", None) | |
200 | if not secret: | |
201 | print("No 'registration_shared_secret' defined in config.") | |
202 | sys.exit(1) | |
203 | else: | |
204 | secret = args.shared_secret | |
205 | ||
206 | admin = None | |
207 | if args.admin or args.no_admin: | |
208 | admin = args.admin | |
209 | ||
210 | register_new_user(args.user, args.password, args.server_url, secret, admin) | |
211 | ||
212 | ||
213 | if __name__ == "__main__": | |
214 | main() |
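The registration MAC computed in request_registration() above is an HMAC-SHA1, keyed with the homeserver's registration_shared_secret, over the nonce, localpart, password and admin flag joined with NUL bytes. A minimal standalone sketch of just that step (all input values are placeholders):

    import hashlib
    import hmac

    def build_registration_mac(shared_secret, nonce, user, password, admin=False):
        # Same construction as request_registration(): NUL-separated fields,
        # keyed with the shared secret from homeserver.yaml.
        mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1)
        for field in (nonce, user, password):
            mac.update(field.encode('utf8'))
            mac.update(b"\x00")
        mac.update(b"admin" if admin else b"notadmin")
        return mac.hexdigest()

    build_registration_mac("secret", "abcd1234", "alice", "wonderland")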
171 | 171 | # events a lot easier as we can then use a negative lookbehind |
172 | 172 | # assertion to split '\.' If we allowed \\ then it would |
173 | 173 | # incorrectly split '\\.' See synapse.events.utils.serialize_event |
174 | "pattern": "^((?!\\\).)*$" | |
174 | # | |
175 | # Note that because this is a regular expression, we have to escape | |
176 | # each backslash in the pattern. | |
177 | "pattern": r"^((?!\\\\).)*$" | |
175 | 178 | } |
176 | 179 | } |
177 | 180 | }, |
19 | 19 | |
20 | 20 | from six import iteritems |
21 | 21 | |
22 | import psutil | |
22 | 23 | from prometheus_client import Gauge |
23 | 24 | |
24 | 25 | from twisted.application import service |
501 | 502 | |
502 | 503 | def performance_stats_init(): |
503 | 504 | try: |
504 | import psutil | |
505 | 505 | process = psutil.Process() |
506 | 506 | # Ensure we can fetch both, and make the initial request for cpu_percent |
507 | 507 | # so the next request will use this as the initial point. |
509 | 509 | process.cpu_percent(interval=None) |
510 | 510 | logger.info("report_stats can use psutil") |
511 | 511 | stats_process.append(process) |
512 | except (ImportError, AttributeError): | |
513 | logger.warn( | |
514 | "report_stats enabled but psutil is not installed or incorrect version." | |
515 | " Disabling reporting of memory/cpu stats." | |
516 | " Ensuring psutil is available will help matrix.org track performance" | |
517 | " changes across releases." | |
512 | except (AttributeError): | |
513 | logger.warning( | |
514 | "Unable to read memory/cpu stats. Disabling reporting." | |
518 | 515 | ) |
519 | 516 | |
520 | 517 | def generate_user_daily_visit_stats(): |
529 | 526 | clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) |
530 | 527 | |
531 | 528 | # monthly active user limiting functionality |
532 | clock.looping_call( | |
533 | hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60 | |
534 | ) | |
535 | hs.get_datastore().reap_monthly_active_users() | |
529 | def reap_monthly_active_users(): | |
530 | return run_as_background_process( | |
531 | "reap_monthly_active_users", | |
532 | hs.get_datastore().reap_monthly_active_users, | |
533 | ) | |
534 | clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) | |
535 | reap_monthly_active_users() | |
536 | 536 | |
537 | 537 | @defer.inlineCallbacks |
538 | 538 | def generate_monthly_active_users(): |
546 | 546 | registered_reserved_users_mau_gauge.set(float(reserved_count)) |
547 | 547 | max_mau_gauge.set(float(hs.config.max_mau_value)) |
548 | 548 | |
549 | hs.get_datastore().initialise_reserved_users( | |
550 | hs.config.mau_limits_reserved_threepids | |
551 | ) | |
552 | generate_monthly_active_users() | |
549 | def start_generate_monthly_active_users(): | |
550 | return run_as_background_process( | |
551 | "generate_monthly_active_users", | |
552 | generate_monthly_active_users, | |
553 | ) | |
554 | ||
555 | start_generate_monthly_active_users() | |
553 | 556 | if hs.config.limit_usage_by_mau: |
554 | clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) | |
557 | clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) | |
555 | 558 | # End of monthly active user settings |
556 | 559 | |
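The change in this hunk follows one pattern in both places: the callable handed to clock.looping_call is wrapped in run_as_background_process, so each periodic run gets its own logcontext and shows up in the background-process metrics. A minimal sketch of the shape (the task name, do_cleanup and the interval are hypothetical; run_as_background_process is the helper imported from synapse.metrics.background_process_metrics):

    from synapse.metrics.background_process_metrics import run_as_background_process

    def start_cleanup():
        # Wrapping here means the loop body is tracked and logged as a
        # background process rather than running in the caller's context.
        return run_as_background_process("cleanup", do_cleanup)

    clock.looping_call(start_cleanup, 60 * 1000)  # interval in milliseconds
    start_cleanup()  # also kick it off once at startup, as above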
557 | 560 | if hs.config.report_stats: |
567 | 570 | clock.call_later(5 * 60, start_phone_stats_home) |
568 | 571 | |
569 | 572 | if hs.config.daemonize and hs.config.print_pidfile: |
570 | print (hs.config.pid_file) | |
573 | print(hs.config.pid_file) | |
571 | 574 | |
572 | 575 | _base.start_reactor( |
573 | 576 | "synapse-homeserver", |
160 | 160 | else: |
161 | 161 | yield self.start_pusher(row.user_id, row.app_id, row.pushkey) |
162 | 162 | elif stream_name == "events": |
163 | self.pusher_pool.on_new_notifications( | |
163 | yield self.pusher_pool.on_new_notifications( | |
164 | 164 | token, token, |
165 | 165 | ) |
166 | 166 | elif stream_name == "receipts": |
167 | self.pusher_pool.on_new_receipts( | |
167 | yield self.pusher_pool.on_new_receipts( | |
168 | 168 | token, token, set(row.room_id for row in rows) |
169 | 169 | ) |
170 | 170 | except Exception: |
182 | 182 | def start_pusher(self, user_id, app_id, pushkey): |
183 | 183 | key = "%s:%s" % (app_id, pushkey) |
184 | 184 | logger.info("Starting pusher %r / %r", user_id, key) |
185 | return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id) | |
185 | return self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) | |
186 | 186 | |
187 | 187 | |
188 | 188 | def start(config_options): |
27 | 27 | sys.stderr.write("\n" + str(e) + "\n") |
28 | 28 | sys.exit(1) |
29 | 29 | |
30 | print (getattr(config, key)) | |
30 | print(getattr(config, key)) | |
31 | 31 | sys.exit(0) |
32 | 32 | else: |
33 | 33 | sys.stderr.write("Unknown command %r\n" % (action,)) |
105 | 105 | @classmethod |
106 | 106 | def check_file(cls, file_path, config_name): |
107 | 107 | if file_path is None: |
108 | raise ConfigError( | |
109 | "Missing config for %s." | |
110 | % (config_name,) | |
111 | ) | |
108 | raise ConfigError("Missing config for %s." % (config_name,)) | |
112 | 109 | try: |
113 | 110 | os.stat(file_path) |
114 | 111 | except OSError as e: |
127 | 124 | if e.errno != errno.EEXIST: |
128 | 125 | raise |
129 | 126 | if not os.path.isdir(dir_path): |
130 | raise ConfigError( | |
131 | "%s is not a directory" % (dir_path,) | |
132 | ) | |
127 | raise ConfigError("%s is not a directory" % (dir_path,)) | |
133 | 128 | return dir_path |
134 | 129 | |
135 | 130 | @classmethod |
155 | 150 | return results |
156 | 151 | |
157 | 152 | def generate_config( |
158 | self, | |
159 | config_dir_path, | |
160 | server_name, | |
161 | is_generating_file, | |
162 | report_stats=None, | |
153 | self, config_dir_path, server_name, is_generating_file, report_stats=None | |
163 | 154 | ): |
164 | 155 | default_config = "# vim:ft=yaml\n" |
165 | 156 | |
166 | default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( | |
167 | "default_config", | |
168 | config_dir_path=config_dir_path, | |
169 | server_name=server_name, | |
170 | is_generating_file=is_generating_file, | |
171 | report_stats=report_stats, | |
172 | )) | |
157 | default_config += "\n\n".join( | |
158 | dedent(conf) | |
159 | for conf in self.invoke_all( | |
160 | "default_config", | |
161 | config_dir_path=config_dir_path, | |
162 | server_name=server_name, | |
163 | is_generating_file=is_generating_file, | |
164 | report_stats=report_stats, | |
165 | ) | |
166 | ) | |
173 | 167 | |
174 | 168 | config = yaml.load(default_config) |
175 | 169 | |
177 | 171 | |
178 | 172 | @classmethod |
179 | 173 | def load_config(cls, description, argv): |
180 | config_parser = argparse.ArgumentParser( | |
181 | description=description, | |
182 | ) | |
183 | config_parser.add_argument( | |
184 | "-c", "--config-path", | |
174 | config_parser = argparse.ArgumentParser(description=description) | |
175 | config_parser.add_argument( | |
176 | "-c", | |
177 | "--config-path", | |
185 | 178 | action="append", |
186 | 179 | metavar="CONFIG_FILE", |
187 | 180 | help="Specify config file. Can be given multiple times and" |
188 | " may specify directories containing *.yaml files." | |
181 | " may specify directories containing *.yaml files.", | |
189 | 182 | ) |
190 | 183 | |
191 | 184 | config_parser.add_argument( |
192 | 185 | "--keys-directory", |
193 | 186 | metavar="DIRECTORY", |
194 | 187 | help="Where files such as certs and signing keys are stored when" |
195 | " their location is given explicitly in the config." | |
196 | " Defaults to the directory containing the last config file", | |
188 | " their location is given explicitly in the config." | |
189 | " Defaults to the directory containing the last config file", | |
197 | 190 | ) |
198 | 191 | |
199 | 192 | config_args = config_parser.parse_args(argv) |
202 | 195 | |
203 | 196 | obj = cls() |
204 | 197 | obj.read_config_files( |
205 | config_files, | |
206 | keys_directory=config_args.keys_directory, | |
207 | generate_keys=False, | |
198 | config_files, keys_directory=config_args.keys_directory, generate_keys=False | |
208 | 199 | ) |
209 | 200 | return obj |
210 | 201 | |
212 | 203 | def load_or_generate_config(cls, description, argv): |
213 | 204 | config_parser = argparse.ArgumentParser(add_help=False) |
214 | 205 | config_parser.add_argument( |
215 | "-c", "--config-path", | |
206 | "-c", | |
207 | "--config-path", | |
216 | 208 | action="append", |
217 | 209 | metavar="CONFIG_FILE", |
218 | 210 | help="Specify config file. Can be given multiple times and" |
219 | " may specify directories containing *.yaml files." | |
211 | " may specify directories containing *.yaml files.", | |
220 | 212 | ) |
221 | 213 | config_parser.add_argument( |
222 | 214 | "--generate-config", |
223 | 215 | action="store_true", |
224 | help="Generate a config file for the server name" | |
216 | help="Generate a config file for the server name", | |
225 | 217 | ) |
226 | 218 | config_parser.add_argument( |
227 | 219 | "--report-stats", |
228 | 220 | action="store", |
229 | 221 | help="Whether the generated config reports anonymized usage statistics", |
230 | choices=["yes", "no"] | |
222 | choices=["yes", "no"], | |
231 | 223 | ) |
232 | 224 | config_parser.add_argument( |
233 | 225 | "--generate-keys", |
234 | 226 | action="store_true", |
235 | help="Generate any missing key files then exit" | |
227 | help="Generate any missing key files then exit", | |
236 | 228 | ) |
237 | 229 | config_parser.add_argument( |
238 | 230 | "--keys-directory", |
239 | 231 | metavar="DIRECTORY", |
240 | 232 | help="Used with 'generate-*' options to specify where files such as" |
241 | " certs and signing keys should be stored in, unless explicitly" | |
242 | " specified in the config." | |
243 | ) | |
244 | config_parser.add_argument( | |
245 | "-H", "--server-name", | |
246 | help="The server name to generate a config file for" | |
233 | " certs and signing keys should be stored in, unless explicitly" | |
234 | " specified in the config.", | |
235 | ) | |
236 | config_parser.add_argument( | |
237 | "-H", "--server-name", help="The server name to generate a config file for" | |
247 | 238 | ) |
248 | 239 | config_args, remaining_args = config_parser.parse_known_args(argv) |
249 | 240 | |
256 | 247 | if config_args.generate_config: |
257 | 248 | if config_args.report_stats is None: |
258 | 249 | config_parser.error( |
259 | "Please specify either --report-stats=yes or --report-stats=no\n\n" + | |
260 | MISSING_REPORT_STATS_SPIEL | |
250 | "Please specify either --report-stats=yes or --report-stats=no\n\n" | |
251 | + MISSING_REPORT_STATS_SPIEL | |
261 | 252 | ) |
262 | 253 | if not config_files: |
263 | 254 | config_parser.error( |
286 | 277 | config_dir_path=config_dir_path, |
287 | 278 | server_name=server_name, |
288 | 279 | report_stats=(config_args.report_stats == "yes"), |
289 | is_generating_file=True | |
280 | is_generating_file=True, | |
290 | 281 | ) |
291 | 282 | obj.invoke_all("generate_files", config) |
292 | 283 | config_file.write(config_str) |
293 | print(( | |
294 | "A config file has been generated in %r for server name" | |
295 | " %r with corresponding SSL keys and self-signed" | |
296 | " certificates. Please review this file and customise it" | |
297 | " to your needs." | |
298 | ) % (config_path, server_name)) | |
284 | print( | |
285 | ( | |
286 | "A config file has been generated in %r for server name" | |
287 | " %r with corresponding SSL keys and self-signed" | |
288 | " certificates. Please review this file and customise it" | |
289 | " to your needs." | |
290 | ) | |
291 | % (config_path, server_name) | |
292 | ) | |
299 | 293 | print( |
300 | 294 | "If this server name is incorrect, you will need to" |
301 | 295 | " regenerate the SSL certificates" |
302 | 296 | ) |
303 | 297 | return |
304 | 298 | else: |
305 | print(( | |
306 | "Config file %r already exists. Generating any missing key" | |
307 | " files." | |
308 | ) % (config_path,)) | |
299 | print( | |
300 | ( | |
301 | "Config file %r already exists. Generating any missing key" | |
302 | " files." | |
303 | ) | |
304 | % (config_path,) | |
305 | ) | |
309 | 306 | generate_keys = True |
310 | 307 | |
311 | 308 | parser = argparse.ArgumentParser( |
337 | 334 | |
338 | 335 | return obj |
339 | 336 | |
340 | def read_config_files(self, config_files, keys_directory=None, | |
341 | generate_keys=False): | |
337 | def read_config_files(self, config_files, keys_directory=None, generate_keys=False): | |
342 | 338 | if not keys_directory: |
343 | 339 | keys_directory = os.path.dirname(config_files[-1]) |
344 | 340 | |
363 | 359 | |
364 | 360 | if "report_stats" not in config: |
365 | 361 | raise ConfigError( |
366 | MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + | |
367 | MISSING_REPORT_STATS_SPIEL | |
362 | MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS | |
363 | + "\n" | |
364 | + MISSING_REPORT_STATS_SPIEL | |
368 | 365 | ) |
369 | 366 | |
370 | 367 | if generate_keys: |
398 | 395 | for entry in os.listdir(config_path): |
399 | 396 | entry_path = os.path.join(config_path, entry) |
400 | 397 | if not os.path.isfile(entry_path): |
401 | print ( | |
402 | "Found subdirectory in config directory: %r. IGNORING." | |
403 | ) % (entry_path, ) | |
398 | err = "Found subdirectory in config directory: %r. IGNORING." | |
399 | print(err % (entry_path,)) | |
404 | 400 | continue |
405 | 401 | |
406 | 402 | if not entry.endswith(".yaml"): |
407 | print ( | |
408 | "Found file in config directory that does not" | |
409 | " end in '.yaml': %r. IGNORING." | |
410 | ) % (entry_path, ) | |
403 | err = ( | |
404 | "Found file in config directory that does not end in " | |
405 | "'.yaml': %r. IGNORING." | |
406 | ) | |
407 | print(err % (entry_path,)) | |
411 | 408 | continue |
412 | 409 | |
413 | 410 | files.append(entry_path) |
18 | 18 | import email.utils |
19 | 19 | import logging |
20 | 20 | import os |
21 | import sys | |
22 | import textwrap | |
23 | 21 | |
24 | from ._base import Config | |
22 | import pkg_resources | |
23 | ||
24 | from ._base import Config, ConfigError | |
25 | 25 | |
26 | 26 | logger = logging.getLogger(__name__) |
27 | ||
28 | TEMPLATE_DIR_WARNING = """\ | |
29 | WARNING: The email notifier is configured to look for templates in '%(template_dir)s', | |
30 | but no templates could be found there. We will fall back to using the example templates; | |
31 | to get rid of this warning, leave 'email.template_dir' unset. | |
32 | """ | |
33 | 27 | |
34 | 28 | |
35 | 29 | class EmailConfig(Config): |
77 | 71 | self.email_notif_template_html = email_config["notif_template_html"] |
78 | 72 | self.email_notif_template_text = email_config["notif_template_text"] |
79 | 73 | |
80 | self.email_template_dir = email_config.get("template_dir") | |
74 | template_dir = email_config.get("template_dir") | |
75 | # we need an absolute path, because we change directory after starting (and | |
76 | # we don't yet know what auxiliary templates like mail.css we will need). | 
77 | # (Note that loading as package_resources with jinja.PackageLoader doesn't | |
78 | # work for the same reason.) | |
79 | if not template_dir: | |
80 | template_dir = pkg_resources.resource_filename( | |
81 | 'synapse', 'res/templates' | |
82 | ) | |
83 | template_dir = os.path.abspath(template_dir) | |
81 | 84 | |
82 | # backwards-compatibility hack | |
83 | if ( | |
84 | self.email_template_dir == "res/templates" | |
85 | and not os.path.isfile( | |
86 | os.path.join(self.email_template_dir, self.email_notif_template_text) | |
87 | ) | |
88 | ): | |
89 | t = TEMPLATE_DIR_WARNING % { | |
90 | "template_dir": self.email_template_dir, | |
91 | } | |
92 | print(textwrap.fill(t, width=80) + "\n", file=sys.stderr) | |
93 | self.email_template_dir = None | |
85 | for f in self.email_notif_template_text, self.email_notif_template_html: | |
86 | p = os.path.join(template_dir, f) | |
87 | if not os.path.isfile(p): | |
88 | raise ConfigError("Unable to find email template file %s" % (p, )) | |
89 | self.email_template_dir = template_dir | |
94 | 90 | |
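The fallback above resolves the template directory that ships inside the synapse package itself. A quick illustrative check of what it resolves to (the printed path is only an example and varies by installation):

    >>> import pkg_resources
    >>> # Exact path depends on how synapse is installed.
    >>> pkg_resources.resource_filename('synapse', 'res/templates')
    '/usr/lib/python3/dist-packages/synapse/res/templates'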
95 | 91 | self.email_notif_for_new_users = email_config.get( |
96 | 92 | "notif_for_new_users", True |
30 | 30 | from .ratelimiting import RatelimitConfig |
31 | 31 | from .registration import RegistrationConfig |
32 | 32 | from .repository import ContentRepositoryConfig |
33 | from .room_directory import RoomDirectoryConfig | |
33 | 34 | from .saml2 import SAML2Config |
34 | 35 | from .server import ServerConfig |
35 | 36 | from .server_notices_config import ServerNoticesConfig |
48 | 49 | WorkerConfig, PasswordAuthProviderConfig, PushConfig, |
49 | 50 | SpamCheckerConfig, GroupsConfig, UserDirectoryConfig, |
50 | 51 | ConsentConfig, |
51 | ServerNoticesConfig, | |
52 | ServerNoticesConfig, RoomDirectoryConfig, | |
52 | 53 | ): |
53 | 54 | pass |
54 | 55 |
14 | 14 | |
15 | 15 | from distutils.util import strtobool |
16 | 16 | |
17 | from synapse.config._base import Config, ConfigError | |
18 | from synapse.types import RoomAlias | |
17 | 19 | from synapse.util.stringutils import random_string_with_symbols |
18 | ||
19 | from ._base import Config | |
20 | 20 | |
21 | 21 | |
22 | 22 | class RegistrationConfig(Config): |
43 | 43 | ) |
44 | 44 | |
45 | 45 | self.auto_join_rooms = config.get("auto_join_rooms", []) |
46 | for room_alias in self.auto_join_rooms: | |
47 | if not RoomAlias.is_valid(room_alias): | |
48 | raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) | |
49 | self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) | |
46 | 50 | |
47 | 51 | def default_config(self, **kwargs): |
48 | 52 | registration_shared_secret = random_string_with_symbols(50) |
97 | 101 | # to these rooms |
98 | 102 | #auto_join_rooms: |
99 | 103 | # - "#example:example.com" |
104 | ||
105 | # Where auto_join_rooms are specified, setting this flag ensures that | 
106 | # the rooms exist by creating them when the first user on the | |
107 | # homeserver registers. | |
108 | # Setting to false means that if the rooms are not manually created, | |
109 | # users cannot be auto-joined since they do not exist. | |
110 | autocreate_auto_join_rooms: true | |
100 | 111 | """ % locals() |
101 | 112 | |
102 | 113 | def add_arguments(self, parser): |
177 | 177 | def default_config(self, **kwargs): |
178 | 178 | media_store = self.default_path("media_store") |
179 | 179 | uploads_path = self.default_path("uploads") |
180 | return """ | |
180 | return r""" | |
181 | 181 | # Directory where uploaded images and attachments are stored. |
182 | 182 | media_store_path: "%(media_store)s" |
183 | 183 |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2018 New Vector Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from synapse.util import glob_to_regex | |
16 | ||
17 | from ._base import Config, ConfigError | |
18 | ||
19 | ||
20 | class RoomDirectoryConfig(Config): | |
21 | def read_config(self, config): | |
22 | alias_creation_rules = config["alias_creation_rules"] | |
23 | ||
24 | self._alias_creation_rules = [ | |
25 | _AliasRule(rule) | |
26 | for rule in alias_creation_rules | |
27 | ] | |
28 | ||
29 | def default_config(self, config_dir_path, server_name, **kwargs): | |
30 | return """ | |
31 | # The `alias_creation_rules` option controls who's allowed to create aliases | 
32 | # on this server. | |
33 | # | |
34 | # The format of this option is a list of rules that contain globs that | |
35 | # match against user_id and the new alias (fully qualified with server | |
36 | # name). The action in the first rule that matches is taken, which can | |
37 | # currently either be "allow" or "deny". | |
38 | # | |
39 | # If no rules match the request is denied. | |
40 | alias_creation_rules: | |
41 | - user_id: "*" | |
42 | alias: "*" | |
43 | action: allow | |
44 | """ | |
45 | ||
46 | def is_alias_creation_allowed(self, user_id, alias): | |
47 | """Checks if the given user is allowed to create the given alias | |
48 | ||
49 | Args: | |
50 | user_id (str) | |
51 | alias (str) | |
52 | ||
53 | Returns: | |
54 | boolean: True if user is allowed to create the alias | 
55 | """ | |
56 | for rule in self._alias_creation_rules: | |
57 | if rule.matches(user_id, alias): | |
58 | return rule.action == "allow" | |
59 | ||
60 | return False | |
61 | ||
62 | ||
63 | class _AliasRule(object): | |
64 | def __init__(self, rule): | |
65 | action = rule["action"] | |
66 | user_id = rule["user_id"] | |
67 | alias = rule["alias"] | |
68 | ||
69 | if action in ("allow", "deny"): | |
70 | self.action = action | |
71 | else: | |
72 | raise ConfigError( | |
73 | "alias_creation_rules rules can only have action of 'allow'" | |
74 | " or 'deny'" | |
75 | ) | |
76 | ||
77 | try: | |
78 | self._user_id_regex = glob_to_regex(user_id) | |
79 | self._alias_regex = glob_to_regex(alias) | |
80 | except Exception as e: | |
81 | raise ConfigError("Failed to parse glob into regex: %s", e) | |
82 | ||
83 | def matches(self, user_id, alias): | |
84 | """Tests if this rule matches the given user_id and alias. | |
85 | ||
86 | Args: | |
87 | user_id (str) | |
88 | alias (str) | |
89 | ||
90 | Returns: | |
91 | boolean | |
92 | """ | |
93 | ||
94 | # Note: The regexes are anchored at both ends | |
95 | if not self._user_id_regex.match(user_id): | |
96 | return False | |
97 | ||
98 | if not self._alias_regex.match(alias): | |
99 | return False | |
100 | ||
101 | return True |
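Given the first-match-wins behaviour described in is_alias_creation_allowed() above (and deny-by-default when nothing matches), a restrictive policy can be expressed by listing an allow rule before a catch-all deny. A hypothetical sketch exercising the classes above (user ids and aliases are placeholders):

    rules = [
        {"user_id": "@admin:example.com", "alias": "*", "action": "allow"},
        {"user_id": "*", "alias": "*", "action": "deny"},
    ]
    parsed = [_AliasRule(r) for r in rules]

    def allowed(user_id, alias):
        # First matching rule wins; anything unmatched is denied.
        for rule in parsed:
            if rule.matches(user_id, alias):
                return rule.action == "allow"
        return False

    allowed("@admin:example.com", "#lounge:example.com")    # True
    allowed("@mallory:example.com", "#lounge:example.com")  # False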
54 | 54 | raise IOError("Cannot get key for %r" % server_name) |
55 | 55 | except (ConnectError, DomainError) as e: |
56 | 56 | logger.warn("Error getting key for %r: %s", server_name, e) |
57 | except Exception as e: | |
57 | except Exception: | |
58 | 58 | logger.exception("Error getting key for %r", server_name) |
59 | 59 | raise IOError("Cannot get key for %r" % server_name) |
60 | 60 |
689 | 689 | auth_types = [] |
690 | 690 | |
691 | 691 | auth_types.append((EventTypes.PowerLevels, "", )) |
692 | auth_types.append((EventTypes.Member, event.user_id, )) | |
692 | auth_types.append((EventTypes.Member, event.sender, )) | |
693 | 693 | auth_types.append((EventTypes.Create, "", )) |
694 | 694 | |
695 | 695 | if event.type == EventTypes.Member: |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | import logging |
16 | import re | |
17 | 16 | |
18 | 17 | import six |
19 | 18 | from six import iteritems |
43 | 42 | ReplicationGetQueryRestServlet, |
44 | 43 | ) |
45 | 44 | from synapse.types import get_domain_from_id |
45 | from synapse.util import glob_to_regex | |
46 | 46 | from synapse.util.async_helpers import Linearizer, concurrently_execute |
47 | 47 | from synapse.util.caches.response_cache import ResponseCache |
48 | 48 | from synapse.util.logcontext import nested_logging_context |
728 | 728 | if not isinstance(acl_entry, six.string_types): |
729 | 729 | logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) |
730 | 730 | return False |
731 | regex = _glob_to_regex(acl_entry) | |
731 | regex = glob_to_regex(acl_entry) | |
732 | 732 | return regex.match(server_name) |
733 | ||
734 | ||
735 | def _glob_to_regex(glob): | |
736 | res = '' | |
737 | for c in glob: | |
738 | if c == '*': | |
739 | res = res + '.*' | |
740 | elif c == '?': | |
741 | res = res + '.' | |
742 | else: | |
743 | res = res + re.escape(c) | |
744 | return re.compile(res + "\\Z", re.IGNORECASE) | |
745 | 733 | |
746 | 734 | |
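The removed helper (and the shared synapse.util.glob_to_regex that replaces it) turns a server ACL glob into an anchored, case-insensitive regex: '*' becomes '.*', '?' becomes '.', and everything else is escaped. A quick illustration based on that behaviour (hostnames are placeholders):

    >>> regex = glob_to_regex("*.example.com")
    >>> bool(regex.match("Matrix.Example.Com"))
    True
    >>> bool(regex.match("example.org"))
    False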
747 | 735 | class FederationHandlerRegistry(object): |
799 | 787 | yield handler(origin, content) |
800 | 788 | except SynapseError as e: |
801 | 789 | logger.info("Failed to handle edu %r: %r", edu_type, e) |
802 | except Exception as e: | |
790 | except Exception: | |
803 | 791 | logger.exception("Failed to handle edu %r", edu_type) |
804 | 792 | |
805 | 793 | def on_query(self, query_type, args): |
632 | 632 | transaction, json_data_cb |
633 | 633 | ) |
634 | 634 | code = 200 |
635 | ||
636 | if response: | |
637 | for e_id, r in response.get("pdus", {}).items(): | |
638 | if "error" in r: | |
639 | logger.warn( | |
640 | "Transaction returned error for %s: %s", | |
641 | e_id, r, | |
642 | ) | |
643 | 635 | except HttpResponseException as e: |
644 | 636 | code = e.code |
645 | 637 | response = e.response |
656 | 648 | destination, txn_id, code |
657 | 649 | ) |
658 | 650 | |
659 | logger.debug("TX [%s] Sent transaction", destination) | |
660 | logger.debug("TX [%s] Marking as delivered...", destination) | |
661 | ||
662 | 651 | yield self.transaction_actions.delivered( |
663 | 652 | transaction, code, response |
664 | 653 | ) |
665 | 654 | |
666 | logger.debug("TX [%s] Marked as delivered", destination) | |
667 | ||
668 | if code != 200: | |
655 | logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id) | |
656 | ||
657 | if code == 200: | |
658 | for e_id, r in response.get("pdus", {}).items(): | |
659 | if "error" in r: | |
660 | logger.warn( | |
661 | "TX [%s] {%s} Remote returned error for %s: %s", | |
662 | destination, txn_id, e_id, r, | |
663 | ) | |
664 | else: | |
669 | 665 | for p in pdus: |
670 | logger.info( | |
671 | "Failed to send event %s to %s", p.event_id, destination | |
666 | logger.warn( | |
667 | "TX [%s] {%s} Failed to send event %s", | |
668 | destination, txn_id, p.event_id, | |
672 | 669 | ) |
673 | 670 | success = False |
674 | 671 |
142 | 142 | transaction (Transaction) |
143 | 143 | |
144 | 144 | Returns: |
145 | Deferred: Results of the deferred is a tuple in the form of | |
146 | (response_code, response_body) where the response_body is a | |
147 | python dict decoded from json | |
145 | Deferred: Succeeds when we get a 2xx HTTP response. The result | |
146 | will be the decoded JSON body. | |
147 | ||
148 | Fails with ``HttpResponseException`` if we get an HTTP response | 
149 | code >= 300. | |
150 | ||
151 | Fails with ``NotRetryingDestination`` if we are not yet ready | |
152 | to retry this server. | |
153 | ||
154 | Fails with ``FederationDeniedError`` if this destination | |
155 | is not on our federation whitelist | |
148 | 156 | """ |
149 | 157 | logger.debug( |
150 | 158 | "send_data dest=%s, txid=%s", |
167 | 175 | json_data_callback=json_data_callback, |
168 | 176 | long_retries=True, |
169 | 177 | backoff_on_404=True, # If we get a 404 the other side has gone |
170 | ) | |
171 | ||
172 | logger.debug( | |
173 | "send_data dest=%s, txid=%s, got response: 200", | |
174 | transaction.destination, transaction.transaction_id, | |
175 | 178 | ) |
176 | 179 | |
177 | 180 | defer.returnValue(response) |
21 | 21 | import pymacaroons |
22 | 22 | from canonicaljson import json |
23 | 23 | |
24 | from twisted.internet import defer, threads | |
24 | from twisted.internet import defer | |
25 | 25 | from twisted.web.client import PartialDownloadError |
26 | 26 | |
27 | 27 | import synapse.util.stringutils as stringutils |
36 | 36 | ) |
37 | 37 | from synapse.module_api import ModuleApi |
38 | 38 | from synapse.types import UserID |
39 | from synapse.util import logcontext | |
39 | 40 | from synapse.util.caches.expiringcache import ExpiringCache |
40 | from synapse.util.logcontext import make_deferred_yieldable | |
41 | 41 | |
42 | 42 | from ._base import BaseHandler |
43 | 43 | |
883 | 883 | bcrypt.gensalt(self.bcrypt_rounds), |
884 | 884 | ).decode('ascii') |
885 | 885 | |
886 | return make_deferred_yieldable( | |
887 | threads.deferToThreadPool( | |
888 | self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash | |
889 | ), | |
890 | ) | |
886 | return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash) | |
891 | 887 | |
892 | 888 | def validate_hash(self, password, stored_hash): |
893 | 889 | """Validates that self.hash(password) == stored_hash. |
912 | 908 | if not isinstance(stored_hash, bytes): |
913 | 909 | stored_hash = stored_hash.encode('ascii') |
914 | 910 | |
915 | return make_deferred_yieldable( | |
916 | threads.deferToThreadPool( | |
917 | self.hs.get_reactor(), | |
918 | self.hs.get_reactor().getThreadPool(), | |
919 | _do_validate_hash, | |
920 | ), | |
921 | ) | |
911 | return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash) | |
922 | 912 | else: |
923 | 913 | return defer.succeed(False) |
924 | 914 |
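
The two hashing hunks above collapse the hand-rolled deferToThreadPool/make_deferred_yieldable pattern into a single logcontext.defer_to_thread call. A minimal sketch of what such a helper amounts to, built only from the Twisted calls visible in the removed lines (the real helper also has to manage logging contexts across the thread boundary, so this is illustrative rather than the actual implementation):

    from twisted.internet import threads
    from synapse.util.logcontext import make_deferred_yieldable

    def defer_to_thread(reactor, f, *args, **kwargs):
        # Run f on the reactor's thread pool and hand back a deferred
        # that is safe to yield on from the calling logging context.
        return make_deferred_yieldable(
            threads.deferToThreadPool(
                reactor, reactor.getThreadPool(), f, *args, **kwargs
            )
        )
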
16 | 16 | from twisted.internet import defer |
17 | 17 | |
18 | 18 | from synapse.api.errors import SynapseError |
19 | from synapse.metrics.background_process_metrics import run_as_background_process | |
19 | 20 | from synapse.types import UserID, create_requester |
20 | from synapse.util.logcontext import run_in_background | |
21 | 21 | |
22 | 22 | from ._base import BaseHandler |
23 | 23 | |
120 | 120 | None |
121 | 121 | """ |
122 | 122 | if not self._user_parter_running: |
123 | run_in_background(self._user_parter_loop) | |
123 | run_as_background_process("user_parter_loop", self._user_parter_loop) | |
124 | 124 | |
125 | 125 | @defer.inlineCallbacks |
126 | 126 | def _user_parter_loop(self): |
42 | 42 | self.state = hs.get_state_handler() |
43 | 43 | self.appservice_handler = hs.get_application_service_handler() |
44 | 44 | self.event_creation_handler = hs.get_event_creation_handler() |
45 | self.config = hs.config | |
45 | 46 | |
46 | 47 | self.federation = hs.get_federation_client() |
47 | 48 | hs.get_federation_registry().register_query_handler( |
79 | 80 | ) |
80 | 81 | |
81 | 82 | @defer.inlineCallbacks |
82 | def create_association(self, user_id, room_alias, room_id, servers=None): | |
83 | # association creation for human users | |
84 | # TODO(erikj): Do user auth. | |
85 | ||
86 | if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): | |
87 | raise SynapseError( | |
88 | 403, "This user is not permitted to create this alias", | |
89 | ) | |
90 | ||
91 | can_create = yield self.can_modify_alias( | |
92 | room_alias, | |
93 | user_id=user_id | |
94 | ) | |
95 | if not can_create: | |
96 | raise SynapseError( | |
97 | 400, "This alias is reserved by an application service.", | |
98 | errcode=Codes.EXCLUSIVE | |
99 | ) | |
83 | def create_association(self, requester, room_alias, room_id, servers=None, | |
84 | send_event=True): | |
85 | """Attempt to create a new alias | |
86 | ||
87 | Args: | |
88 | requester (Requester) | |
89 | room_alias (RoomAlias) | |
90 | room_id (str) | |
91 | servers (list[str]|None): List of servers that other servers | 
92 | should try and join via | |
93 | send_event (bool): Whether to send an updated m.room.aliases event | |
94 | ||
95 | Returns: | |
96 | Deferred | |
97 | """ | |
98 | ||
99 | user_id = requester.user.to_string() | |
100 | ||
101 | service = requester.app_service | |
102 | if service: | |
103 | if not service.is_interested_in_alias(room_alias.to_string()): | |
104 | raise SynapseError( | |
105 | 400, "This application service has not reserved" | |
106 | " this kind of alias.", errcode=Codes.EXCLUSIVE | |
107 | ) | |
108 | else: | |
109 | if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): | |
110 | raise AuthError( | |
111 | 403, "This user is not permitted to create this alias", | |
112 | ) | |
113 | ||
114 | if not self.config.is_alias_creation_allowed(user_id, room_alias.to_string()): | |
115 | # Let's just return a generic message, as there may be all sorts of | 
116 | # reasons why we said no. TODO: Allow configurable error messages | |
117 | # per alias creation rule? | |
118 | raise SynapseError( | |
119 | 403, "Not allowed to create alias", | |
120 | ) | |
121 | ||
122 | can_create = yield self.can_modify_alias( | |
123 | room_alias, | |
124 | user_id=user_id | |
125 | ) | |
126 | if not can_create: | |
127 | raise AuthError( | |
128 | 400, "This alias is reserved by an application service.", | |
129 | errcode=Codes.EXCLUSIVE | |
130 | ) | |
131 | ||
100 | 132 | yield self._create_association(room_alias, room_id, servers, creator=user_id) |
101 | ||
102 | @defer.inlineCallbacks | |
103 | def create_appservice_association(self, service, room_alias, room_id, | |
104 | servers=None): | |
105 | if not service.is_interested_in_alias(room_alias.to_string()): | |
106 | raise SynapseError( | |
107 | 400, "This application service has not reserved" | |
108 | " this kind of alias.", errcode=Codes.EXCLUSIVE | |
109 | ) | |
110 | ||
111 | # association creation for app services | |
112 | yield self._create_association(room_alias, room_id, servers) | |
113 | ||
114 | @defer.inlineCallbacks | |
115 | def delete_association(self, requester, user_id, room_alias): | |
133 | if send_event: | |
134 | yield self.send_room_alias_update_event( | |
135 | requester, | |
136 | room_id | |
137 | ) | |
138 | ||
139 | @defer.inlineCallbacks | |
140 | def delete_association(self, requester, room_alias): | |
116 | 141 | # association deletion for human users |
142 | ||
143 | user_id = requester.user.to_string() | |
117 | 144 | |
118 | 145 | try: |
119 | 146 | can_delete = yield self._user_can_delete_alias(room_alias, user_id) |
142 | 169 | try: |
143 | 170 | yield self.send_room_alias_update_event( |
144 | 171 | requester, |
145 | requester.user.to_string(), | |
146 | 172 | room_id |
147 | 173 | ) |
148 | 174 | |
260 | 286 | ) |
261 | 287 | |
262 | 288 | @defer.inlineCallbacks |
263 | def send_room_alias_update_event(self, requester, user_id, room_id): | |
289 | def send_room_alias_update_event(self, requester, room_id): | |
264 | 290 | aliases = yield self.store.get_aliases_for_room(room_id) |
265 | 291 | |
266 | 292 | yield self.event_creation_handler.create_and_send_nonmember_event( |
269 | 295 | "type": EventTypes.Aliases, |
270 | 296 | "state_key": self.hs.hostname, |
271 | 297 | "room_id": room_id, |
272 | "sender": user_id, | |
298 | "sender": requester.user.to_string(), | |
273 | 299 | "content": {"aliases": aliases}, |
274 | 300 | }, |
275 | 301 | ratelimit=False |
52 | 52 | ReplicationFederationSendEventsRestServlet, |
53 | 53 | ) |
54 | 54 | from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet |
55 | from synapse.state import resolve_events_with_factory | |
55 | from synapse.state import StateResolutionStore, resolve_events_with_store | |
56 | 56 | from synapse.types import UserID, get_domain_from_id |
57 | 57 | from synapse.util import logcontext, unwrapFirstError |
58 | 58 | from synapse.util.async_helpers import Linearizer |
383 | 383 | for x in remote_state: |
384 | 384 | event_map[x.event_id] = x |
385 | 385 | |
386 | # Resolve any conflicting state | |
387 | @defer.inlineCallbacks | |
388 | def fetch(ev_ids): | |
389 | fetched = yield self.store.get_events( | |
390 | ev_ids, get_prev_content=False, check_redacted=False, | |
391 | ) | |
392 | # add any events we fetch here to the `event_map` so that we | |
393 | # can use them to build the state event list below. | |
394 | event_map.update(fetched) | |
395 | defer.returnValue(fetched) | |
396 | ||
397 | 386 | room_version = yield self.store.get_room_version(room_id) |
398 | state_map = yield resolve_events_with_factory( | |
399 | room_version, state_maps, event_map, fetch, | |
387 | state_map = yield resolve_events_with_store( | |
388 | room_version, state_maps, event_map, | |
389 | state_res_store=StateResolutionStore(self.store), | |
400 | 390 | ) |
401 | 391 | |
402 | # we need to give _process_received_pdu the actual state events | |
392 | # We need to give _process_received_pdu the actual state events | |
403 | 393 | # rather than event ids, so generate that now. |
394 | ||
395 | # First though we need to fetch all the events that are in | |
396 | # state_map, so we can build up the state below. | |
397 | evs = yield self.store.get_events( | |
398 | list(state_map.values()), | |
399 | get_prev_content=False, | |
400 | check_redacted=False, | |
401 | ) | |
402 | event_map.update(evs) | |
403 | ||
404 | 404 | state = [ |
405 | 405 | event_map[e] for e in six.itervalues(state_map) |
406 | 406 | ] |
2519 | 2519 | |
2520 | 2520 | if not backfilled: # Never notify for backfilled events |
2521 | 2521 | for event, _ in event_and_contexts: |
2522 | self._notify_persisted_event(event, max_stream_id) | |
2522 | yield self._notify_persisted_event(event, max_stream_id) | |
2523 | 2523 | |
2524 | 2524 | def _notify_persisted_event(self, event, max_stream_id): |
2525 | 2525 | """Checks to see if notifier/pushers should be notified about the |
2552 | 2552 | extra_users=extra_users |
2553 | 2553 | ) |
2554 | 2554 | |
2555 | self.pusher_pool.on_new_notifications( | |
2555 | return self.pusher_pool.on_new_notifications( | |
2556 | 2556 | event_stream_id, max_stream_id, |
2557 | 2557 | ) |
2558 | 2558 |
19 | 19 | |
20 | 20 | from twisted.internet import defer |
21 | 21 | |
22 | from synapse.api.errors import SynapseError | |
22 | from synapse.api.errors import HttpResponseException, SynapseError | |
23 | 23 | from synapse.types import get_domain_from_id |
24 | 24 | |
25 | 25 | logger = logging.getLogger(__name__) |
36 | 36 | ) |
37 | 37 | else: |
38 | 38 | destination = get_domain_from_id(group_id) |
39 | return getattr(self.transport_client, func_name)( | |
39 | d = getattr(self.transport_client, func_name)( | |
40 | 40 | destination, group_id, *args, **kwargs |
41 | 41 | ) |
42 | ||
43 | # Capture errors returned by the remote homeserver and | |
44 | # re-throw specific errors as SynapseErrors. This is so | |
45 | # when the remote end responds with things like 403 Not | |
46 | # In Group, we can communicate that to the client instead | |
47 | # of a 500. | |
48 | def h(failure): | |
49 | failure.trap(HttpResponseException) | |
50 | e = failure.value | |
51 | if e.code == 403: | |
52 | raise e.to_synapse_error() | |
53 | return failure | |
54 | d.addErrback(h) | |
55 | return d | |
42 | 56 | return f |
43 | 57 | |
44 | 58 |
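
The errback added above is the standard Twisted failure-trapping idiom; the same idea as a standalone helper, using only the calls that appear in the hunk (translate_remote_403 is a hypothetical name):

    from synapse.api.errors import HttpResponseException

    def translate_remote_403(d):
        """Re-raise a remote 403 as a synapse error; pass anything else through."""
        def h(failure):
            failure.trap(HttpResponseException)
            e = failure.value
            if e.code == 403:
                # Propagate the remote errcode/message instead of a 500.
                raise e.to_synapse_error()
            return failure
        d.addErrback(h)
        return d
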
155 | 155 | room_end_token = "s%d" % (event.stream_ordering,) |
156 | 156 | deferred_room_state = run_in_background( |
157 | 157 | self.store.get_state_for_events, |
158 | [event.event_id], None, | |
158 | [event.event_id], | |
159 | 159 | ) |
160 | 160 | deferred_room_state.addCallback( |
161 | 161 | lambda states: states[event.event_id] |
300 | 300 | def _room_initial_sync_parted(self, user_id, room_id, pagin_config, |
301 | 301 | membership, member_event_id, is_peeking): |
302 | 302 | room_state = yield self.store.get_state_for_events( |
303 | [member_event_id], None | |
303 | [member_event_id], | |
304 | 304 | ) |
305 | 305 | |
306 | 306 | room_state = room_state[member_event_id] |
34 | 34 | from synapse.events.utils import serialize_event |
35 | 35 | from synapse.events.validator import EventValidator |
36 | 36 | from synapse.replication.http.send_event import ReplicationSendEventRestServlet |
37 | from synapse.storage.state import StateFilter | |
37 | 38 | from synapse.types import RoomAlias, UserID |
38 | 39 | from synapse.util.async_helpers import Linearizer |
39 | 40 | from synapse.util.frozenutils import frozendict_json_encoder |
79 | 80 | elif membership == Membership.LEAVE: |
80 | 81 | key = (event_type, state_key) |
81 | 82 | room_state = yield self.store.get_state_for_events( |
82 | [membership_event_id], [key] | |
83 | [membership_event_id], StateFilter.from_types([key]) | |
83 | 84 | ) |
84 | 85 | data = room_state[membership_event_id].get(key) |
85 | 86 | |
87 | 88 | |
88 | 89 | @defer.inlineCallbacks |
89 | 90 | def get_state_events( |
90 | self, user_id, room_id, types=None, filtered_types=None, | |
91 | self, user_id, room_id, state_filter=StateFilter.all(), | |
91 | 92 | at_token=None, is_guest=False, |
92 | 93 | ): |
93 | 94 | """Retrieve all state events for a given room. If the user is |
99 | 100 | Args: |
100 | 101 | user_id(str): The user requesting state events. |
101 | 102 | room_id(str): The room ID to get all state events from. |
102 | types(list[(str, str|None)]|None): List of (type, state_key) tuples | |
103 | which are used to filter the state fetched. If `state_key` is None, | |
104 | all events are returned of the given type. | |
105 | May be None, which matches any key. | |
106 | filtered_types(list[str]|None): Only apply filtering via `types` to this | |
107 | list of event types. Other types of events are returned unfiltered. | |
108 | If None, `types` filtering is applied to all events. | |
103 | state_filter (StateFilter): The state filter used to fetch state | |
104 | from the database. | |
109 | 105 | at_token(StreamToken|None): the stream token at which we are requesting
110 | 106 | the state. If the user is not allowed to view the state as of that
111 | 107 | stream token, we raise a 403 SynapseError. If None, returns the current
138 | 134 | event = last_events[0] |
139 | 135 | if visible_events: |
140 | 136 | room_state = yield self.store.get_state_for_events( |
141 | [event.event_id], types, filtered_types=filtered_types, | |
137 | [event.event_id], state_filter=state_filter, | |
142 | 138 | ) |
143 | 139 | room_state = room_state[event.event_id] |
144 | 140 | else: |
157 | 153 | |
158 | 154 | if membership == Membership.JOIN: |
159 | 155 | state_ids = yield self.store.get_filtered_current_state_ids( |
160 | room_id, types, filtered_types=filtered_types, | |
156 | room_id, state_filter=state_filter, | |
161 | 157 | ) |
162 | 158 | room_state = yield self.store.get_events(state_ids.values()) |
163 | 159 | elif membership == Membership.LEAVE: |
164 | 160 | room_state = yield self.store.get_state_for_events( |
165 | [membership_event_id], types, filtered_types=filtered_types, | |
161 | [membership_event_id], state_filter=state_filter, | |
166 | 162 | ) |
167 | 163 | room_state = room_state[membership_event_id] |
168 | 164 | |
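
The same substitution recurs throughout this patch: the (types, filtered_types) pair becomes a single StateFilter. A minimal sketch of the call-shape change (the user id is made up):

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    # Old style, removed throughout this patch:
    #   types=[(EventTypes.Member, "@alice:example.com")],
    #   filtered_types=[EventTypes.Member]
    # i.e. restrict member events to the listed keys, pass other types through.

    # New style: lazy-loading call sites use from_lazy_load_member_list,
    # which keeps non-member state unfiltered:
    state_filter = StateFilter.from_lazy_load_member_list(["@alice:example.com"])

    # Call sites that want only specific (type, state_key) pairs use from_types:
    only_name = StateFilter.from_types([(EventTypes.Name, "")])

    # And "no filtering" (the old types=None) becomes:
    everything = StateFilter.all()
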
778 | 774 | event, context=context |
779 | 775 | ) |
780 | 776 | |
781 | self.pusher_pool.on_new_notifications( | |
777 | yield self.pusher_pool.on_new_notifications( | |
782 | 778 | event_stream_id, max_stream_id, |
783 | 779 | ) |
784 | 780 |
20 | 20 | from synapse.api.constants import EventTypes, Membership |
21 | 21 | from synapse.api.errors import SynapseError |
22 | 22 | from synapse.events.utils import serialize_event |
23 | from synapse.storage.state import StateFilter | |
23 | 24 | from synapse.types import RoomStreamToken |
24 | 25 | from synapse.util.async_helpers import ReadWriteLock |
25 | 26 | from synapse.util.logcontext import run_in_background |
254 | 255 | if event_filter and event_filter.lazy_load_members(): |
255 | 256 | # TODO: remove redundant members |
256 | 257 | |
257 | types = [ | |
258 | (EventTypes.Member, state_key) | |
259 | for state_key in set( | |
260 | event.sender # FIXME: we also care about invite targets etc. | |
261 | for event in events | |
262 | ) | |
263 | ] | |
258 | # FIXME: we also care about invite targets etc. | |
259 | state_filter = StateFilter.from_types( | |
260 | (EventTypes.Member, event.sender) | |
261 | for event in events | |
262 | ) | |
264 | 263 | |
265 | 264 | state_ids = yield self.store.get_state_ids_for_event( |
266 | events[0].event_id, types=types, | |
265 | events[0].event_id, state_filter=state_filter, | |
267 | 266 | ) |
268 | 267 | |
269 | 268 | if state_ids: |
118 | 118 | "receipt_key", max_batch_id, rooms=affected_room_ids |
119 | 119 | ) |
120 | 120 | # Note that the min here shouldn't be relied upon to be accurate. |
121 | self.hs.get_pusherpool().on_new_receipts( | |
121 | yield self.hs.get_pusherpool().on_new_receipts( | |
122 | 122 | min_batch_id, max_batch_id, affected_room_ids, |
123 | 123 | ) |
124 | 124 |
49 | 49 | self._auth_handler = hs.get_auth_handler() |
50 | 50 | self.profile_handler = hs.get_profile_handler() |
51 | 51 | self.user_directory_handler = hs.get_user_directory_handler() |
52 | self.room_creation_handler = self.hs.get_room_creation_handler() | |
52 | 53 | self.captcha_client = CaptchaServerHttpClient(hs) |
53 | 54 | |
54 | 55 | self._next_generated_user_id = None |
219 | 220 | |
220 | 221 | # auto-join the user to any rooms we're supposed to dump them into |
221 | 222 | fake_requester = create_requester(user_id) |
223 | ||
224 | # try to create the room if we're the first user on the server | |
225 | should_auto_create_rooms = False | |
226 | if self.hs.config.autocreate_auto_join_rooms: | |
227 | count = yield self.store.count_all_users() | |
228 | should_auto_create_rooms = count == 1 | |
229 | ||
222 | 230 | for r in self.hs.config.auto_join_rooms: |
223 | 231 | try: |
224 | yield self._join_user_to_room(fake_requester, r) | |
232 | if should_auto_create_rooms: | |
233 | room_alias = RoomAlias.from_string(r) | |
234 | if self.hs.hostname != room_alias.domain: | |
235 | logger.warning( | |
236 | 'Cannot create room alias %s, ' | |
237 | 'it does not match server domain', | |
238 | r, | |
239 | ) | |
240 | else: | |
241 | # create room expects the localpart of the room alias | |
242 | room_alias_localpart = room_alias.localpart | |
243 | yield self.room_creation_handler.create_room( | |
244 | fake_requester, | |
245 | config={ | |
246 | "preset": "public_chat", | |
247 | "room_alias_name": room_alias_localpart | |
248 | }, | |
249 | ratelimit=False, | |
250 | ) | |
251 | else: | |
252 | yield self._join_user_to_room(fake_requester, r) | |
225 | 253 | except Exception as e: |
226 | 254 | logger.error("Failed to join new user to %r: %r", r, e) |
227 | 255 |
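
For reference, the alias handling above relies on RoomAlias splitting into a localpart and a domain; a small sketch with a made-up alias:

    from synapse.types import RoomAlias

    alias = RoomAlias.from_string("#welcome:example.com")

    alias.localpart   # "welcome"      -> passed as room_alias_name
    alias.domain      # "example.com"  -> must match hs.hostname, otherwise
                      #                   auto-creation is skipped with a warning
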
32 | 32 | RoomCreationPreset, |
33 | 33 | ) |
34 | 34 | from synapse.api.errors import AuthError, Codes, StoreError, SynapseError |
35 | from synapse.storage.state import StateFilter | |
35 | 36 | from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID |
36 | 37 | from synapse.util import stringutils |
37 | 38 | from synapse.visibility import filter_events_for_client |
189 | 190 | if room_alias: |
190 | 191 | directory_handler = self.hs.get_handlers().directory_handler |
191 | 192 | yield directory_handler.create_association( |
192 | user_id=user_id, | |
193 | requester=requester, | |
193 | 194 | room_id=room_id, |
194 | 195 | room_alias=room_alias, |
195 | 196 | servers=[self.hs.hostname], |
197 | send_event=False, | |
196 | 198 | ) |
197 | 199 | |
198 | 200 | preset_config = config.get( |
288 | 290 | if room_alias: |
289 | 291 | result["room_alias"] = room_alias.to_string() |
290 | 292 | yield directory_handler.send_room_alias_update_event( |
291 | requester, user_id, room_id | |
293 | requester, room_id | |
292 | 294 | ) |
293 | 295 | |
294 | 296 | defer.returnValue(result) |
487 | 489 | else: |
488 | 490 | last_event_id = event_id |
489 | 491 | |
490 | types = None | |
491 | filtered_types = None | |
492 | 492 | if event_filter and event_filter.lazy_load_members(): |
493 | members = set(ev.sender for ev in itertools.chain( | |
494 | results["events_before"], | |
495 | (results["event"],), | |
496 | results["events_after"], | |
497 | )) | |
498 | filtered_types = [EventTypes.Member] | |
499 | types = [(EventTypes.Member, member) for member in members] | |
493 | state_filter = StateFilter.from_lazy_load_member_list( | |
494 | ev.sender | |
495 | for ev in itertools.chain( | |
496 | results["events_before"], | |
497 | (results["event"],), | |
498 | results["events_after"], | |
499 | ) | |
500 | ) | |
501 | else: | |
502 | state_filter = StateFilter.all() | |
500 | 503 | |
501 | 504 | # XXX: why do we return the state as of the last event rather than the |
502 | 505 | # first? Shouldn't we be consistent with /sync? |
503 | 506 | # https://github.com/matrix-org/matrix-doc/issues/687 |
504 | 507 | |
505 | 508 | state = yield self.store.get_state_for_events( |
506 | [last_event_id], types, filtered_types=filtered_types, | |
509 | [last_event_id], state_filter=state_filter, | |
507 | 510 | ) |
508 | 511 | results["state"] = list(state[last_event_id].values()) |
509 | 512 |
15 | 15 | import logging |
16 | 16 | from collections import namedtuple |
17 | 17 | |
18 | from six import iteritems | |
18 | from six import PY3, iteritems | |
19 | 19 | from six.moves import range |
20 | 20 | |
21 | 21 | import msgpack |
443 | 443 | |
444 | 444 | @classmethod |
445 | 445 | def from_token(cls, token): |
446 | if PY3: | |
447 | # The argument raw=False is only available on new versions of | |
448 | # msgpack, and only really needed on Python 3. Gate it behind | |
449 | # a PY3 check to avoid causing issues on Debian-packaged versions. | |
450 | decoded = msgpack.loads(decode_base64(token), raw=False) | |
451 | else: | |
452 | decoded = msgpack.loads(decode_base64(token)) | |
446 | 453 | return RoomListNextBatch(**{ |
447 | 454 | cls.REVERSE_KEY_DICT[key]: val |
448 | for key, val in msgpack.loads(decode_base64(token)).items() | |
455 | for key, val in decoded.items() | |
449 | 456 | }) |
450 | 457 | |
451 | 458 | def to_token(self): |
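
The raw=False gate above matters because, with older msgpack defaults on Python 3, keys come back as bytes and no longer match REVERSE_KEY_DICT in the dict expansion. A small sketch of the difference (the payload key is made up):

    import msgpack

    packed = msgpack.dumps({"example_key": 5})

    # Old Python 3 default: bytes keys, so REVERSE_KEY_DICT[key] would fail.
    msgpack.loads(packed)                # {b'example_key': 5}
    msgpack.loads(packed, raw=False)     # {'example_key': 5}
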
23 | 23 | from synapse.api.errors import SynapseError |
24 | 24 | from synapse.api.filtering import Filter |
25 | 25 | from synapse.events.utils import serialize_event |
26 | from synapse.storage.state import StateFilter | |
26 | 27 | from synapse.visibility import filter_events_for_client |
27 | 28 | |
28 | 29 | from ._base import BaseHandler |
323 | 324 | else: |
324 | 325 | last_event_id = event.event_id |
325 | 326 | |
327 | state_filter = StateFilter.from_types( | |
328 | [(EventTypes.Member, sender) for sender in senders] | |
329 | ) | |
330 | ||
326 | 331 | state = yield self.store.get_state_for_event( |
327 | last_event_id, | |
328 | types=[(EventTypes.Member, sender) for sender in senders] | |
332 | last_event_id, state_filter | |
329 | 333 | ) |
330 | 334 | |
331 | 335 | res["profile_info"] = { |
26 | 26 | from synapse.api.constants import EventTypes, Membership |
27 | 27 | from synapse.push.clientformat import format_push_rules_for_user |
28 | 28 | from synapse.storage.roommember import MemberSummary |
29 | from synapse.storage.state import StateFilter | |
29 | 30 | from synapse.types import RoomStreamToken |
30 | 31 | from synapse.util.async_helpers import concurrently_execute |
31 | 32 | from synapse.util.caches.expiringcache import ExpiringCache |
468 | 469 | )) |
469 | 470 | |
470 | 471 | @defer.inlineCallbacks |
471 | def get_state_after_event(self, event, types=None, filtered_types=None): | |
472 | def get_state_after_event(self, event, state_filter=StateFilter.all()): | |
472 | 473 | """ |
473 | 474 | Get the room state after the given event |
474 | 475 | |
475 | 476 | Args: |
476 | 477 | event(synapse.events.EventBase): event of interest |
477 | types(list[(str, str|None)]|None): List of (type, state_key) tuples | |
478 | which are used to filter the state fetched. If `state_key` is None, | |
479 | all events are returned of the given type. | |
480 | May be None, which matches any key. | |
481 | filtered_types(list[str]|None): Only apply filtering via `types` to this | |
482 | list of event types. Other types of events are returned unfiltered. | |
483 | If None, `types` filtering is applied to all events. | |
478 | state_filter (StateFilter): The state filter used to fetch state | |
479 | from the database. | |
484 | 480 | |
485 | 481 | Returns: |
486 | 482 | A Deferred map from ((type, state_key)->Event) |
487 | 483 | """ |
488 | 484 | state_ids = yield self.store.get_state_ids_for_event( |
489 | event.event_id, types, filtered_types=filtered_types, | |
485 | event.event_id, state_filter=state_filter, | |
490 | 486 | ) |
491 | 487 | if event.is_state(): |
492 | 488 | state_ids = state_ids.copy() |
494 | 490 | defer.returnValue(state_ids) |
495 | 491 | |
496 | 492 | @defer.inlineCallbacks |
497 | def get_state_at(self, room_id, stream_position, types=None, filtered_types=None): | |
493 | def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): | |
498 | 494 | """ Get the room state at a particular stream position |
499 | 495 | |
500 | 496 | Args: |
501 | 497 | room_id(str): room for which to get state |
502 | 498 | stream_position(StreamToken): point at which to get state |
503 | types(list[(str, str|None)]|None): List of (type, state_key) tuples | |
504 | which are used to filter the state fetched. If `state_key` is None, | |
505 | all events are returned of the given type. | |
506 | filtered_types(list[str]|None): Only apply filtering via `types` to this | |
507 | list of event types. Other types of events are returned unfiltered. | |
508 | If None, `types` filtering is applied to all events. | |
499 | state_filter (StateFilter): The state filter used to fetch state | |
500 | from the database. | |
509 | 501 | |
510 | 502 | Returns: |
511 | 503 | A Deferred map from ((type, state_key)->Event) |
521 | 513 | if last_events: |
522 | 514 | last_event = last_events[-1] |
523 | 515 | state = yield self.get_state_after_event( |
524 | last_event, types, filtered_types=filtered_types, | |
516 | last_event, state_filter=state_filter, | |
525 | 517 | ) |
526 | 518 | |
527 | 519 | else: |
562 | 554 | |
563 | 555 | last_event = last_events[-1] |
564 | 556 | state_ids = yield self.store.get_state_ids_for_event( |
565 | last_event.event_id, [ | |
557 | last_event.event_id, | |
558 | state_filter=StateFilter.from_types([ | |
566 | 559 | (EventTypes.Name, ''), |
567 | 560 | (EventTypes.CanonicalAlias, ''), |
568 | ] | |
561 | ]), | |
569 | 562 | ) |
570 | 563 | |
571 | 564 | # this is heavily cached, thus: fast. |
716 | 709 | |
717 | 710 | with Measure(self.clock, "compute_state_delta"): |
718 | 711 | |
719 | types = None | |
720 | filtered_types = None | |
712 | members_to_fetch = None | |
721 | 713 | |
722 | 714 | lazy_load_members = sync_config.filter_collection.lazy_load_members() |
723 | 715 | include_redundant_members = ( |
728 | 720 | # We only request state for the members needed to display the |
729 | 721 | # timeline: |
730 | 722 | |
731 | types = [ | |
732 | (EventTypes.Member, state_key) | |
733 | for state_key in set( | |
734 | event.sender # FIXME: we also care about invite targets etc. | |
735 | for event in batch.events | |
736 | ) | |
737 | ] | |
738 | ||
739 | # only apply the filtering to room members | |
740 | filtered_types = [EventTypes.Member] | |
723 | members_to_fetch = set( | |
724 | event.sender # FIXME: we also care about invite targets etc. | |
725 | for event in batch.events | |
726 | ) | |
727 | ||
728 | if full_state: | |
729 | # always make sure we LL ourselves so we know we're in the room | |
730 | # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 | |
731 | # We only need to apply this on full state syncs given we disabled | 
732 | # LL for incr syncs in #3840. | |
733 | members_to_fetch.add(sync_config.user.to_string()) | |
734 | ||
735 | state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) | |
736 | else: | |
737 | state_filter = StateFilter.all() | |
741 | 738 | |
742 | 739 | timeline_state = { |
743 | 740 | (event.type, event.state_key): event.event_id |
745 | 742 | } |
746 | 743 | |
747 | 744 | if full_state: |
748 | if lazy_load_members: | |
749 | # always make sure we LL ourselves so we know we're in the room | |
750 | # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 | |
751 | # We only need apply this on full state syncs given we disabled | |
752 | # LL for incr syncs in #3840. | |
753 | types.append((EventTypes.Member, sync_config.user.to_string())) | |
754 | ||
755 | 745 | if batch: |
756 | 746 | current_state_ids = yield self.store.get_state_ids_for_event( |
757 | batch.events[-1].event_id, types=types, | |
758 | filtered_types=filtered_types, | |
747 | batch.events[-1].event_id, state_filter=state_filter, | |
759 | 748 | ) |
760 | 749 | |
761 | 750 | state_ids = yield self.store.get_state_ids_for_event( |
762 | batch.events[0].event_id, types=types, | |
763 | filtered_types=filtered_types, | |
751 | batch.events[0].event_id, state_filter=state_filter, | |
764 | 752 | ) |
765 | 753 | |
766 | 754 | else: |
767 | 755 | current_state_ids = yield self.get_state_at( |
768 | room_id, stream_position=now_token, types=types, | |
769 | filtered_types=filtered_types, | |
756 | room_id, stream_position=now_token, | |
757 | state_filter=state_filter, | |
770 | 758 | ) |
771 | 759 | |
772 | 760 | state_ids = current_state_ids |
780 | 768 | ) |
781 | 769 | elif batch.limited: |
782 | 770 | state_at_timeline_start = yield self.store.get_state_ids_for_event( |
783 | batch.events[0].event_id, types=types, | |
784 | filtered_types=filtered_types, | |
771 | batch.events[0].event_id, state_filter=state_filter, | |
785 | 772 | ) |
786 | 773 | |
787 | 774 | # for now, we disable LL for gappy syncs - see |
796 | 783 | # members to just be ones which were timeline senders, which then ensures |
797 | 784 | # all of the rest get included in the state block (if we need to know |
798 | 785 | # about them). |
799 | types = None | |
800 | filtered_types = None | |
786 | state_filter = StateFilter.all() | |
801 | 787 | |
802 | 788 | state_at_previous_sync = yield self.get_state_at( |
803 | room_id, stream_position=since_token, types=types, | |
804 | filtered_types=filtered_types, | |
789 | room_id, stream_position=since_token, | |
790 | state_filter=state_filter, | |
805 | 791 | ) |
806 | 792 | |
807 | 793 | current_state_ids = yield self.store.get_state_ids_for_event( |
808 | batch.events[-1].event_id, types=types, | |
809 | filtered_types=filtered_types, | |
794 | batch.events[-1].event_id, state_filter=state_filter, | |
810 | 795 | ) |
811 | 796 | |
812 | 797 | state_ids = _calculate_state( |
820 | 805 | else: |
821 | 806 | state_ids = {} |
822 | 807 | if lazy_load_members: |
823 | if types and batch.events: | |
808 | if members_to_fetch and batch.events: | |
824 | 809 | # We're returning an incremental sync, with no |
825 | 810 | # "gap" since the previous sync, so normally there would be |
826 | 811 | # no state to return. |
830 | 815 | # timeline here, and then dedupe any redundant ones below. |
831 | 816 | |
832 | 817 | state_ids = yield self.store.get_state_ids_for_event( |
833 | batch.events[0].event_id, types=types, | |
834 | filtered_types=None, # we only want members! | |
818 | batch.events[0].event_id, | |
819 | # we only want members! | |
820 | state_filter=StateFilter.from_types( | |
821 | (EventTypes.Member, member) | |
822 | for member in members_to_fetch | |
823 | ), | |
835 | 824 | ) |
836 | 825 | |
837 | 826 | if lazy_load_members and not include_redundant_members: |
19 | 19 | from twisted.internet import defer |
20 | 20 | |
21 | 21 | from synapse.api.constants import EventTypes, JoinRules, Membership |
22 | from synapse.metrics.background_process_metrics import run_as_background_process | |
22 | 23 | from synapse.storage.roommember import ProfileInfo |
23 | 24 | from synapse.types import get_localpart_from_id |
24 | 25 | from synapse.util.metrics import Measure |
97 | 98 | """ |
98 | 99 | return self.store.search_user_dir(user_id, search_term, limit) |
99 | 100 | |
100 | @defer.inlineCallbacks | |
101 | 101 | def notify_new_event(self): |
102 | 102 | """Called when there may be more deltas to process |
103 | 103 | """ |
107 | 107 | if self._is_processing: |
108 | 108 | return |
109 | 109 | |
110 | @defer.inlineCallbacks | |
111 | def process(): | |
112 | try: | |
113 | yield self._unsafe_process() | |
114 | finally: | |
115 | self._is_processing = False | |
116 | ||
110 | 117 | self._is_processing = True |
111 | try: | |
112 | yield self._unsafe_process() | |
113 | finally: | |
114 | self._is_processing = False | |
118 | run_as_background_process("user_directory.notify_new_event", process) | |
115 | 119 | |
116 | 120 | @defer.inlineCallbacks |
117 | 121 | def handle_local_profile_change(self, user_id, profile): |
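
Several handlers in this patch get the same treatment as notify_new_event above: the processing loop moves out of the caller's context into run_as_background_process. A sketch of the shape (ExampleHandler is illustrative, not a real Synapse class):

    from twisted.internet import defer

    from synapse.metrics.background_process_metrics import run_as_background_process


    class ExampleHandler(object):
        def __init__(self):
            self._is_processing = False

        def notify_new_event(self):
            if self._is_processing:
                return

            @defer.inlineCallbacks
            def process():
                try:
                    yield self._unsafe_process()
                finally:
                    self._is_processing = False

            self._is_processing = True
            # The loop now runs under its own log context and shows up in
            # the background-process metrics, instead of being tied to
            # (and blocking) whichever request triggered it.
            run_as_background_process("example.notify_new_event", process)

        @defer.inlineCallbacks
        def _unsafe_process(self):
            yield defer.succeed(None)  # stand-in for the real work
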
194 | 194 | ) |
195 | 195 | self.clock = hs.get_clock() |
196 | 196 | self._store = hs.get_datastore() |
197 | self.version_string = hs.version_string.encode('ascii') | |
197 | self.version_string_bytes = hs.version_string.encode('ascii') | |
198 | 198 | self.default_timeout = 60 |
199 | 199 | |
200 | 200 | def schedule(x): |
229 | 229 | Returns: |
230 | 230 | Deferred: resolves with the http response object on success. |
231 | 231 | |
232 | Fails with ``HTTPRequestException``: if we get an HTTP response | |
232 | Fails with ``HttpResponseException``: if we get an HTTP response | |
233 | 233 | code >= 300. |
234 | 234 | |
235 | 235 | Fails with ``NotRetryingDestination`` if we are not yet ready |
260 | 260 | ignore_backoff=ignore_backoff, |
261 | 261 | ) |
262 | 262 | |
263 | method = request.method | |
264 | destination = request.destination | |
263 | method_bytes = request.method.encode("ascii") | |
264 | destination_bytes = request.destination.encode("ascii") | |
265 | 265 | path_bytes = request.path.encode("ascii") |
266 | 266 | if request.query: |
267 | 267 | query_bytes = encode_query_args(request.query) |
269 | 269 | query_bytes = b"" |
270 | 270 | |
271 | 271 | headers_dict = { |
272 | "User-Agent": [self.version_string], | |
273 | "Host": [request.destination], | |
272 | b"User-Agent": [self.version_string_bytes], | |
273 | b"Host": [destination_bytes], | |
274 | 274 | } |
275 | 275 | |
276 | 276 | with limiter: |
281 | 281 | else: |
282 | 282 | retries_left = MAX_SHORT_RETRIES |
283 | 283 | |
284 | url = urllib.parse.urlunparse(( | |
285 | b"matrix", destination.encode("ascii"), | |
284 | url_bytes = urllib.parse.urlunparse(( | |
285 | b"matrix", destination_bytes, | |
286 | 286 | path_bytes, None, query_bytes, b"", |
287 | )).decode('ascii') | |
288 | ||
289 | http_url = urllib.parse.urlunparse(( | |
287 | )) | |
288 | url_str = url_bytes.decode('ascii') | |
289 | ||
290 | url_to_sign_bytes = urllib.parse.urlunparse(( | |
290 | 291 | b"", b"", |
291 | 292 | path_bytes, None, query_bytes, b"", |
292 | )).decode('ascii') | |
293 | )) | |
293 | 294 | |
294 | 295 | while True: |
295 | 296 | try: |
296 | 297 | json = request.get_json() |
297 | 298 | if json: |
299 | headers_dict[b"Content-Type"] = [b"application/json"] | |
300 | self.sign_request( | |
301 | destination_bytes, method_bytes, url_to_sign_bytes, | |
302 | headers_dict, json, | |
303 | ) | |
298 | 304 | data = encode_canonical_json(json) |
299 | headers_dict["Content-Type"] = ["application/json"] | |
300 | self.sign_request( | |
301 | destination, method, http_url, headers_dict, json | |
302 | ) | |
303 | else: | |
304 | data = None | |
305 | self.sign_request(destination, method, http_url, headers_dict) | |
306 | ||
307 | logger.info( | |
308 | "{%s} [%s] Sending request: %s %s", | |
309 | request.txn_id, destination, method, url | |
310 | ) | |
311 | ||
312 | if data: | |
313 | 305 | producer = FileBodyProducer( |
314 | 306 | BytesIO(data), |
315 | cooperator=self._cooperator | |
307 | cooperator=self._cooperator, | |
316 | 308 | ) |
317 | 309 | else: |
318 | 310 | producer = None |
319 | ||
320 | request_deferred = treq.request( | |
321 | method, | |
322 | url, | |
311 | self.sign_request( | |
312 | destination_bytes, method_bytes, url_to_sign_bytes, | |
313 | headers_dict, | |
314 | ) | |
315 | ||
316 | logger.info( | |
317 | "{%s} [%s] Sending request: %s %s", | |
318 | request.txn_id, request.destination, request.method, | |
319 | url_str, | |
320 | ) | |
321 | ||
322 | # we don't want all the fancy cookie and redirect handling that | |
323 | # treq.request gives: just use the raw Agent. | |
324 | request_deferred = self.agent.request( | |
325 | method_bytes, | |
326 | url_bytes, | |
323 | 327 | headers=Headers(headers_dict), |
324 | data=producer, | |
325 | agent=self.agent, | |
326 | reactor=self.hs.get_reactor(), | |
327 | unbuffered=True | |
328 | bodyProducer=producer, | |
328 | 329 | ) |
329 | 330 | |
330 | 331 | request_deferred = timeout_deferred( |
343 | 344 | logger.warn( |
344 | 345 | "{%s} [%s] Request failed: %s %s: %s", |
345 | 346 | request.txn_id, |
346 | destination, | |
347 | method, | |
348 | url, | |
347 | request.destination, | |
348 | request.method, | |
349 | url_str, | |
349 | 350 | _flatten_response_never_received(e), |
350 | 351 | ) |
351 | 352 | |
365 | 366 | logger.debug( |
366 | 367 | "{%s} [%s] Waiting %ss before re-sending...", |
367 | 368 | request.txn_id, |
368 | destination, | |
369 | request.destination, | |
369 | 370 | delay, |
370 | 371 | ) |
371 | 372 | |
377 | 378 | logger.info( |
378 | 379 | "{%s} [%s] Got response headers: %d %s", |
379 | 380 | request.txn_id, |
380 | destination, | |
381 | request.destination, | |
381 | 382 | response.code, |
382 | 383 | response.phrase.decode('ascii', errors='replace'), |
383 | 384 | ) |
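
The switch from treq.request to the raw Agent above is what forces the method, URL and header names/values to be pre-encoded as bytes. A minimal sketch of the bare Agent call shape (plain https with a placeholder host and path, since the real client plugs in its own matrix endpoint handling):

    from io import BytesIO

    from twisted.internet import reactor
    from twisted.web.client import Agent, FileBodyProducer
    from twisted.web.http_headers import Headers

    agent = Agent(reactor)

    d = agent.request(
        b"PUT",                                   # method must be bytes
        b"https://remote.example.com/_matrix/federation/v1/send/123",
        headers=Headers({
            b"User-Agent": [b"Synapse"],
            b"Content-Type": [b"application/json"],
        }),
        bodyProducer=FileBodyProducer(BytesIO(b'{"example": true}')),
    )
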
410 | 411 | destination_is must be non-None. |
411 | 412 | method (bytes): The HTTP method of the request |
412 | 413 | url_bytes (bytes): The URI path of the request |
413 | headers_dict (dict): Dictionary of request headers to append to | |
414 | content (bytes): The body of the request | |
414 | headers_dict (dict[bytes, list[bytes]]): Dictionary of request headers to | |
415 | append to | |
416 | content (object): The body of the request | |
415 | 417 | destination_is (bytes): As 'destination', but if the destination is an |
416 | 418 | identity server |
417 | 419 | |
477 | 479 | Deferred: Succeeds when we get a 2xx HTTP response. The result |
478 | 480 | will be the decoded JSON body. |
479 | 481 | |
480 | Fails with ``HTTPRequestException`` if we get an HTTP response | |
482 | Fails with ``HttpResponseException`` if we get an HTTP response | |
481 | 483 | code >= 300. |
482 | 484 | |
483 | 485 | Fails with ``NotRetryingDestination`` if we are not yet ready |
531 | 533 | Deferred: Succeeds when we get a 2xx HTTP response. The result |
532 | 534 | will be the decoded JSON body. |
533 | 535 | |
534 | Fails with ``HTTPRequestException`` if we get an HTTP response | |
536 | Fails with ``HttpResponseException`` if we get an HTTP response | |
535 | 537 | code >= 300. |
536 | 538 | |
537 | 539 | Fails with ``NotRetryingDestination`` if we are not yet ready |
586 | 588 | Deferred: Succeeds when we get a 2xx HTTP response. The result |
587 | 589 | will be the decoded JSON body. |
588 | 590 | |
589 | Fails with ``HTTPRequestException`` if we get an HTTP response | |
591 | Fails with ``HttpResponseException`` if we get an HTTP response | |
590 | 592 | code >= 300. |
591 | 593 | |
592 | 594 | Fails with ``NotRetryingDestination`` if we are not yet ready |
637 | 639 | Deferred: Succeeds when we get a 2xx HTTP response. The result |
638 | 640 | will be the decoded JSON body. |
639 | 641 | |
640 | Fails with ``HTTPRequestException`` if we get an HTTP response | |
642 | Fails with ``HttpResponseException`` if we get an HTTP response | |
641 | 643 | code >= 300. |
642 | 644 | |
643 | 645 | Fails with ``NotRetryingDestination`` if we are not yet ready |
681 | 683 | Deferred: resolves with an (int,dict) tuple of the file length and |
682 | 684 | a dict of the response headers. |
683 | 685 | |
684 | Fails with ``HTTPRequestException`` if we get an HTTP response code | |
686 | Fails with ``HttpResponseException`` if we get an HTTP response code | |
685 | 687 | >= 300 |
686 | 688 | |
687 | 689 | Fails with ``NotRetryingDestination`` if we are not yet ready |
38 | 38 | ) |
39 | 39 | |
40 | 40 | response_timer = Histogram( |
41 | "synapse_http_server_response_time_seconds", "sec", | |
41 | "synapse_http_server_response_time_seconds", | |
42 | "sec", | |
42 | 43 | ["method", "servlet", "tag", "code"], |
43 | 44 | ) |
44 | 45 | |
78 | 79 | # than when the response was written. |
79 | 80 | |
80 | 81 | in_flight_requests_ru_utime = Counter( |
81 | "synapse_http_server_in_flight_requests_ru_utime_seconds", | |
82 | "", | |
83 | ["method", "servlet"], | |
82 | "synapse_http_server_in_flight_requests_ru_utime_seconds", "", ["method", "servlet"] | |
84 | 83 | ) |
85 | 84 | |
86 | 85 | in_flight_requests_ru_stime = Counter( |
87 | "synapse_http_server_in_flight_requests_ru_stime_seconds", | |
88 | "", | |
89 | ["method", "servlet"], | |
86 | "synapse_http_server_in_flight_requests_ru_stime_seconds", "", ["method", "servlet"] | |
90 | 87 | ) |
91 | 88 | |
92 | 89 | in_flight_requests_db_txn_count = Counter( |
133 | 130 | # type |
134 | 131 | counts = {} |
135 | 132 | for rm in reqs: |
136 | key = (rm.method, rm.name,) | |
133 | key = (rm.method, rm.name) | |
137 | 134 | counts[key] = counts.get(key, 0) + 1 |
138 | 135 | |
139 | 136 | return counts |
174 | 171 | if context != self.start_context: |
175 | 172 | logger.warn( |
176 | 173 | "Context have unexpectedly changed %r, %r", |
177 | context, self.start_context | |
174 | context, | |
175 | self.start_context, | |
178 | 176 | ) |
179 | 177 | return |
180 | 178 | |
191 | 189 | resource_usage = context.get_resource_usage() |
192 | 190 | |
193 | 191 | response_ru_utime.labels(self.method, self.name, tag).inc( |
194 | resource_usage.ru_utime, | |
192 | resource_usage.ru_utime | |
195 | 193 | ) |
196 | 194 | response_ru_stime.labels(self.method, self.name, tag).inc( |
197 | resource_usage.ru_stime, | |
195 | resource_usage.ru_stime | |
198 | 196 | ) |
199 | 197 | response_db_txn_count.labels(self.method, self.name, tag).inc( |
200 | 198 | resource_usage.db_txn_count |
221 | 219 | diff = new_stats - self._request_stats |
222 | 220 | self._request_stats = new_stats |
223 | 221 | |
224 | in_flight_requests_ru_utime.labels(self.method, self.name).inc(diff.ru_utime) | |
225 | in_flight_requests_ru_stime.labels(self.method, self.name).inc(diff.ru_stime) | |
222 | # max() is used since rapid use of ru_stime/ru_utime can end up with the | |
223 | # count going backwards due to NTP, time smearing, fine-grained | |
224 | # correction, or floating points. Who knows, really? | |
225 | in_flight_requests_ru_utime.labels(self.method, self.name).inc( | |
226 | max(diff.ru_utime, 0) | |
227 | ) | |
228 | in_flight_requests_ru_stime.labels(self.method, self.name).inc( | |
229 | max(diff.ru_stime, 0) | |
230 | ) | |
226 | 231 | |
227 | 232 | in_flight_requests_db_txn_count.labels(self.method, self.name).inc( |
228 | 233 | diff.db_txn_count |
185 | 185 | def count_listeners(): |
186 | 186 | all_user_streams = set() |
187 | 187 | |
188 | for x in self.room_to_user_streams.values(): | |
188 | for x in list(self.room_to_user_streams.values()): | |
189 | 189 | all_user_streams |= x |
190 | for x in self.user_to_user_stream.values(): | |
190 | for x in list(self.user_to_user_stream.values()): | |
191 | 191 | all_user_streams.add(x) |
192 | 192 | |
193 | 193 | return sum(stream.count_listeners() for stream in all_user_streams) |
195 | 195 | |
196 | 196 | LaterGauge( |
197 | 197 | "synapse_notifier_rooms", "", [], |
198 | lambda: count(bool, self.room_to_user_streams.values()), | |
198 | lambda: count(bool, list(self.room_to_user_streams.values())), | |
199 | 199 | ) |
200 | 200 | LaterGauge( |
201 | 201 | "synapse_notifier_users", "", [], |
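
The list() wrapping above is a Python 3 concern: the gauge callbacks iterate dict views that other code can mutate, and a view raises RuntimeError if the dict changes size mid-iteration. A tiny sketch (made-up room and stream names):

    streams = {"!room:example.com": {"stream-a"}}

    # Without the copy, adding a key while iterating the live view raises
    # "RuntimeError: dictionary changed size during iteration" on Python 3.
    for members in list(streams.values()):
        streams["!other:example.com"] = {"stream-b"}
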
17 | 17 | from twisted.internet import defer |
18 | 18 | from twisted.internet.error import AlreadyCalled, AlreadyCancelled |
19 | 19 | |
20 | from synapse.util.logcontext import LoggingContext | |
21 | from synapse.util.metrics import Measure | |
20 | from synapse.metrics.background_process_metrics import run_as_background_process | |
22 | 21 | |
23 | 22 | logger = logging.getLogger(__name__) |
24 | 23 | |
70 | 69 | # See httppusher |
71 | 70 | self.max_stream_ordering = None |
72 | 71 | |
73 | self.processing = False | |
74 | ||
75 | @defer.inlineCallbacks | |
72 | self._is_processing = False | |
73 | ||
76 | 74 | def on_started(self): |
77 | 75 | if self.mailer is not None: |
78 | try: | |
79 | self.throttle_params = yield self.store.get_throttle_params_by_room( | |
80 | self.pusher_id | |
81 | ) | |
82 | yield self._process() | |
83 | except Exception: | |
84 | logger.exception("Error starting email pusher") | |
76 | self._start_processing() | |
85 | 77 | |
86 | 78 | def on_stop(self): |
87 | 79 | if self.timed_call: |
91 | 83 | pass |
92 | 84 | self.timed_call = None |
93 | 85 | |
94 | @defer.inlineCallbacks | |
95 | 86 | def on_new_notifications(self, min_stream_ordering, max_stream_ordering): |
96 | 87 | self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) |
97 | yield self._process() | |
88 | self._start_processing() | |
98 | 89 | |
99 | 90 | def on_new_receipts(self, min_stream_id, max_stream_id): |
100 | 91 | # We could wake up and cancel the timer but there tend to be quite a |
101 | 92 | # lot of read receipts so it's probably less work to just let the |
102 | 93 | # timer fire |
103 | return defer.succeed(None) | |
104 | ||
105 | @defer.inlineCallbacks | |
94 | pass | |
95 | ||
106 | 96 | def on_timer(self): |
107 | 97 | self.timed_call = None |
108 | yield self._process() | |
98 | self._start_processing() | |
99 | ||
100 | def _start_processing(self): | |
101 | if self._is_processing: | |
102 | return | |
103 | ||
104 | run_as_background_process("emailpush.process", self._process) | |
109 | 105 | |
110 | 106 | @defer.inlineCallbacks |
111 | 107 | def _process(self): |
112 | if self.processing: | |
113 | return | |
114 | ||
115 | with LoggingContext("emailpush._process"): | |
116 | with Measure(self.clock, "emailpush._process"): | |
108 | # we should never get here if we are already processing | |
109 | assert not self._is_processing | |
110 | ||
111 | try: | |
112 | self._is_processing = True | |
113 | ||
114 | if self.throttle_params is None: | |
115 | # this is our first loop: load up the throttle params | |
116 | self.throttle_params = yield self.store.get_throttle_params_by_room( | |
117 | self.pusher_id | |
118 | ) | |
119 | ||
120 | # if the max ordering changes while we're running _unsafe_process, | |
121 | # call it again, and so on until we've caught up. | |
122 | while True: | |
123 | starting_max_ordering = self.max_stream_ordering | |
117 | 124 | try: |
118 | self.processing = True | |
119 | # if the max ordering changes while we're running _unsafe_process, | |
120 | # call it again, and so on until we've caught up. | |
121 | while True: | |
122 | starting_max_ordering = self.max_stream_ordering | |
123 | try: | |
124 | yield self._unsafe_process() | |
125 | except Exception: | |
126 | logger.exception("Exception processing notifs") | |
127 | if self.max_stream_ordering == starting_max_ordering: | |
128 | break | |
129 | finally: | |
130 | self.processing = False | |
125 | yield self._unsafe_process() | |
126 | except Exception: | |
127 | logger.exception("Exception processing notifs") | |
128 | if self.max_stream_ordering == starting_max_ordering: | |
129 | break | |
130 | finally: | |
131 | self._is_processing = False | |
131 | 132 | |
132 | 133 | @defer.inlineCallbacks |
133 | 134 | def _unsafe_process(self): |
21 | 21 | from twisted.internet import defer |
22 | 22 | from twisted.internet.error import AlreadyCalled, AlreadyCancelled |
23 | 23 | |
24 | from synapse.metrics.background_process_metrics import run_as_background_process | |
24 | 25 | from synapse.push import PusherConfigException |
25 | from synapse.util.logcontext import LoggingContext | |
26 | from synapse.util.metrics import Measure | |
27 | 26 | |
28 | 27 | from . import push_rule_evaluator, push_tools |
29 | 28 | |
60 | 59 | self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC |
61 | 60 | self.failing_since = pusherdict['failing_since'] |
62 | 61 | self.timed_call = None |
63 | self.processing = False | |
62 | self._is_processing = False | |
64 | 63 | |
65 | 64 | # This is the highest stream ordering we know it's safe to process. |
66 | 65 | # When new events arrive, we'll be given a window of new events: we |
91 | 90 | self.data_minus_url.update(self.data) |
92 | 91 | del self.data_minus_url['url'] |
93 | 92 | |
94 | @defer.inlineCallbacks | |
95 | 93 | def on_started(self): |
96 | try: | |
97 | yield self._process() | |
98 | except Exception: | |
99 | logger.exception("Error starting http pusher") | |
100 | ||
101 | @defer.inlineCallbacks | |
94 | self._start_processing() | |
95 | ||
102 | 96 | def on_new_notifications(self, min_stream_ordering, max_stream_ordering): |
103 | 97 | self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0) |
104 | yield self._process() | |
105 | ||
106 | @defer.inlineCallbacks | |
98 | self._start_processing() | |
99 | ||
107 | 100 | def on_new_receipts(self, min_stream_id, max_stream_id): |
108 | 101 | # Note that the min here shouldn't be relied upon to be accurate. |
109 | 102 | |
110 | 103 | # We could check the receipts are actually m.read receipts here, |
111 | 104 | # but currently that's the only type of receipt anyway... |
112 | with LoggingContext("push.on_new_receipts"): | |
113 | with Measure(self.clock, "push.on_new_receipts"): | |
114 | badge = yield push_tools.get_badge_count( | |
115 | self.hs.get_datastore(), self.user_id | |
116 | ) | |
117 | yield self._send_badge(badge) | |
118 | ||
119 | @defer.inlineCallbacks | |
105 | run_as_background_process("http_pusher.on_new_receipts", self._update_badge) | |
106 | ||
107 | @defer.inlineCallbacks | |
108 | def _update_badge(self): | |
109 | badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) | |
110 | yield self._send_badge(badge) | |
111 | ||
120 | 112 | def on_timer(self): |
121 | yield self._process() | |
113 | self._start_processing() | |
122 | 114 | |
123 | 115 | def on_stop(self): |
124 | 116 | if self.timed_call: |
128 | 120 | pass |
129 | 121 | self.timed_call = None |
130 | 122 | |
123 | def _start_processing(self): | |
124 | if self._is_processing: | |
125 | return | |
126 | ||
127 | run_as_background_process("httppush.process", self._process) | |
128 | ||
131 | 129 | @defer.inlineCallbacks |
132 | 130 | def _process(self): |
133 | if self.processing: | |
134 | return | |
135 | ||
136 | with LoggingContext("push._process"): | |
137 | with Measure(self.clock, "push._process"): | |
131 | # we should never get here if we are already processing | |
132 | assert not self._is_processing | |
133 | ||
134 | try: | |
135 | self._is_processing = True | |
136 | # if the max ordering changes while we're running _unsafe_process, | |
137 | # call it again, and so on until we've caught up. | |
138 | while True: | |
139 | starting_max_ordering = self.max_stream_ordering | |
138 | 140 | try: |
139 | self.processing = True | |
140 | # if the max ordering changes while we're running _unsafe_process, | |
141 | # call it again, and so on until we've caught up. | |
142 | while True: | |
143 | starting_max_ordering = self.max_stream_ordering | |
144 | try: | |
145 | yield self._unsafe_process() | |
146 | except Exception: | |
147 | logger.exception("Exception processing notifs") | |
148 | if self.max_stream_ordering == starting_max_ordering: | |
149 | break | |
150 | finally: | |
151 | self.processing = False | |
141 | yield self._unsafe_process() | |
142 | except Exception: | |
143 | logger.exception("Exception processing notifs") | |
144 | if self.max_stream_ordering == starting_max_ordering: | |
145 | break | |
146 | finally: | |
147 | self._is_processing = False | |
152 | 148 | |
153 | 149 | @defer.inlineCallbacks |
154 | 150 | def _unsafe_process(self): |
525 | 525 | Returns: |
526 | 526 | (notif_template_html, notif_template_text) |
527 | 527 | """ |
528 | logger.info("loading jinja2") | |
529 | ||
530 | if config.email_template_dir: | |
531 | loader = jinja2.FileSystemLoader(config.email_template_dir) | |
532 | else: | |
533 | loader = jinja2.PackageLoader('synapse', 'res/templates') | |
528 | logger.info("loading email templates from '%s'", config.email_template_dir) | |
529 | loader = jinja2.FileSystemLoader(config.email_template_dir) | |
534 | 530 | env = jinja2.Environment(loader=loader) |
535 | 531 | env.filters["format_ts"] = format_ts_filter |
536 | 532 | env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config) |
19 | 19 | |
20 | 20 | from synapse.metrics.background_process_metrics import run_as_background_process |
21 | 21 | from synapse.push.pusher import PusherFactory |
22 | from synapse.util.logcontext import make_deferred_yieldable, run_in_background | |
23 | 22 | |
24 | 23 | logger = logging.getLogger(__name__) |
25 | 24 | |
26 | 25 | |
27 | 26 | class PusherPool: |
27 | """ | |
28 | The pusher pool. This is responsible for dispatching notifications of new events to | |
29 | the http and email pushers. | |
30 | ||
31 | It provides three methods which are designed to be called by the rest of the | |
32 | application: `start`, `on_new_notifications`, and `on_new_receipts`: each of these | |
33 | delegates to each of the relevant pushers. | |
34 | ||
35 | Note that it is expected that each pusher will have its own 'processing' loop which | |
36 | will send out the notifications in the background, rather than blocking until the | |
37 | notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and | |
38 | Pusher.on_new_receipts are not expected to return deferreds. | |
39 | """ | |
28 | 40 | def __init__(self, _hs): |
29 | 41 | self.hs = _hs |
30 | 42 | self.pusher_factory = PusherFactory(_hs) |
31 | self.start_pushers = _hs.config.start_pushers | |
43 | self._should_start_pushers = _hs.config.start_pushers | |
32 | 44 | self.store = self.hs.get_datastore() |
33 | 45 | self.clock = self.hs.get_clock() |
34 | 46 | self.pushers = {} |
35 | 47 | |
36 | @defer.inlineCallbacks | |
37 | 48 | def start(self): |
38 | pushers = yield self.store.get_all_pushers() | |
39 | self._start_pushers(pushers) | |
49 | """Starts the pushers off in a background process. | |
50 | """ | |
51 | if not self._should_start_pushers: | |
52 | logger.info("Not starting pushers because they are disabled in the config") | |
53 | return | |
54 | run_as_background_process("start_pushers", self._start_pushers) | |
40 | 55 | |
41 | 56 | @defer.inlineCallbacks |
42 | 57 | def add_pusher(self, user_id, access_token, kind, app_id, |
85 | 100 | last_stream_ordering=last_stream_ordering, |
86 | 101 | profile_tag=profile_tag, |
87 | 102 | ) |
88 | yield self._refresh_pusher(app_id, pushkey, user_id) | |
103 | yield self.start_pusher_by_id(app_id, pushkey, user_id) | |
89 | 104 | |
90 | 105 | @defer.inlineCallbacks |
91 | 106 | def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, |
122 | 137 | p['app_id'], p['pushkey'], p['user_name'], |
123 | 138 | ) |
124 | 139 | |
140 | @defer.inlineCallbacks | |
125 | 141 | def on_new_notifications(self, min_stream_id, max_stream_id): |
126 | run_as_background_process( | |
127 | "on_new_notifications", | |
128 | self._on_new_notifications, min_stream_id, max_stream_id, | |
129 | ) | |
130 | ||
131 | @defer.inlineCallbacks | |
132 | def _on_new_notifications(self, min_stream_id, max_stream_id): | |
133 | 142 | try: |
134 | 143 | users_affected = yield self.store.get_push_action_users_in_range( |
135 | 144 | min_stream_id, max_stream_id |
136 | 145 | ) |
137 | 146 | |
138 | deferreds = [] | |
139 | ||
140 | 147 | for u in users_affected: |
141 | 148 | if u in self.pushers: |
142 | 149 | for p in self.pushers[u].values(): |
143 | deferreds.append( | |
144 | run_in_background( | |
145 | p.on_new_notifications, | |
146 | min_stream_id, max_stream_id, | |
147 | ) | |
148 | ) | |
149 | ||
150 | yield make_deferred_yieldable( | |
151 | defer.gatherResults(deferreds, consumeErrors=True), | |
152 | ) | |
150 | p.on_new_notifications(min_stream_id, max_stream_id) | |
151 | ||
153 | 152 | except Exception: |
154 | 153 | logger.exception("Exception in pusher on_new_notifications") |
155 | 154 | |
155 | @defer.inlineCallbacks | |
156 | 156 | def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): |
157 | run_as_background_process( | |
158 | "on_new_receipts", | |
159 | self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids, | |
160 | ) | |
161 | ||
162 | @defer.inlineCallbacks | |
163 | def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): | |
164 | 157 | try: |
165 | 158 | # Need to subtract 1 from the minimum because the lower bound here |
166 | 159 | # is not inclusive |
170 | 163 | # This returns a tuple, user_id is at index 3 |
171 | 164 | users_affected = set([r[3] for r in updated_receipts]) |
172 | 165 | |
173 | deferreds = [] | |
174 | ||
175 | 166 | for u in users_affected: |
176 | 167 | if u in self.pushers: |
177 | 168 | for p in self.pushers[u].values(): |
178 | deferreds.append( | |
179 | run_in_background( | |
180 | p.on_new_receipts, | |
181 | min_stream_id, max_stream_id, | |
182 | ) | |
183 | ) | |
184 | ||
185 | yield make_deferred_yieldable( | |
186 | defer.gatherResults(deferreds, consumeErrors=True), | |
187 | ) | |
169 | p.on_new_receipts(min_stream_id, max_stream_id) | |
170 | ||
188 | 171 | except Exception: |
189 | 172 | logger.exception("Exception in pusher on_new_receipts") |
190 | 173 | |
191 | 174 | @defer.inlineCallbacks |
192 | def _refresh_pusher(self, app_id, pushkey, user_id): | |
175 | def start_pusher_by_id(self, app_id, pushkey, user_id): | |
176 | """Look up the details for the given pusher, and start it""" | |
177 | if not self._should_start_pushers: | |
178 | return | |
179 | ||
193 | 180 | resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( |
194 | 181 | app_id, pushkey |
195 | 182 | ) |
200 | 187 | p = r |
201 | 188 | |
202 | 189 | if p: |
203 | ||
204 | self._start_pushers([p]) | |
205 | ||
206 | def _start_pushers(self, pushers): | |
207 | if not self.start_pushers: | |
208 | logger.info("Not starting pushers because they are disabled in the config") | |
209 | return | |
190 | self._start_pusher(p) | |
191 | ||
192 | @defer.inlineCallbacks | |
193 | def _start_pushers(self): | |
194 | """Start all the pushers | |
195 | ||
196 | Returns: | |
197 | Deferred | |
198 | """ | |
199 | pushers = yield self.store.get_all_pushers() | |
210 | 200 | logger.info("Starting %d pushers", len(pushers)) |
211 | 201 | for pusherdict in pushers: |
212 | try: | |
213 | p = self.pusher_factory.create_pusher(pusherdict) | |
214 | except Exception: | |
215 | logger.exception("Couldn't start a pusher: caught Exception") | |
216 | continue | |
217 | if p: | |
218 | appid_pushkey = "%s:%s" % ( | |
219 | pusherdict['app_id'], | |
220 | pusherdict['pushkey'], | |
221 | ) | |
222 | byuser = self.pushers.setdefault(pusherdict['user_name'], {}) | |
223 | ||
224 | if appid_pushkey in byuser: | |
225 | byuser[appid_pushkey].on_stop() | |
226 | byuser[appid_pushkey] = p | |
227 | run_in_background(p.on_started) | |
228 | ||
202 | self._start_pusher(pusherdict) | |
229 | 203 | logger.info("Started pushers") |
204 | ||
205 | def _start_pusher(self, pusherdict): | |
206 | """Start the given pusher | |
207 | ||
208 | Args: | |
209 | pusherdict (dict): | |
210 | ||
211 | Returns: | |
212 | None | |
213 | """ | |
214 | try: | |
215 | p = self.pusher_factory.create_pusher(pusherdict) | |
216 | except Exception: | |
217 | logger.exception("Couldn't start a pusher: caught Exception") | |
218 | return | |
219 | ||
220 | if not p: | |
221 | return | |
222 | ||
223 | appid_pushkey = "%s:%s" % ( | |
224 | pusherdict['app_id'], | |
225 | pusherdict['pushkey'], | |
226 | ) | |
227 | byuser = self.pushers.setdefault(pusherdict['user_name'], {}) | |
228 | ||
229 | if appid_pushkey in byuser: | |
230 | byuser[appid_pushkey].on_stop() | |
231 | byuser[appid_pushkey] = p | |
232 | p.on_started() | |
230 | 233 | |
231 | 234 | @defer.inlineCallbacks |
232 | 235 | def remove_pusher(self, app_id, pushkey, user_id): |
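The hunks above split pusher startup into start_pusher_by_id, _start_pushers and the new per-pusher _start_pusher, all of which share one in-memory map: self.pushers is keyed by user, and within a user by the "app_id:pushkey" string, with any existing pusher for that slot stopped before it is replaced. A minimal standalone sketch of that bookkeeping (FakePusher and register_pusher are hypothetical stand-ins, not Synapse classes):

    # Sketch of the self.pushers bookkeeping used by _start_pusher above.
    class FakePusher(object):
        def __init__(self, name):
            self.name = name

        def on_stop(self):
            print("stopping %s" % self.name)

        def on_started(self):
            print("started %s" % self.name)

    pushers = {}  # user_name -> {"app_id:pushkey" -> pusher}

    def register_pusher(pushers, pusherdict, pusher):
        appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
        byuser = pushers.setdefault(pusherdict["user_name"], {})
        if appid_pushkey in byuser:
            # a pusher already exists for this app_id/pushkey: stop it first
            byuser[appid_pushkey].on_stop()
        byuser[appid_pushkey] = pusher
        pusher.on_started()

    alice = {"user_name": "@alice:example.com", "app_id": "m.http", "pushkey": "abc"}
    register_pusher(pushers, alice, FakePusher("first"))
    register_pusher(pushers, alice, FakePusher("replacement"))
    # "first" is stopped before "replacement" takes over the same slot.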
52 | 52 | "pillow>=3.1.2": ["PIL"], |
53 | 53 | "pydenticon>=0.2": ["pydenticon"], |
54 | 54 | "sortedcontainers>=1.4.4": ["sortedcontainers"], |
55 | "psutil>=2.0.0": ["psutil>=2.0.0"], | |
55 | 56 | "pysaml2>=3.0.0": ["saml2"], |
56 | 57 | "pymacaroons-pynacl>=0.9.3": ["pymacaroons"], |
57 | "msgpack-python>=0.3.0": ["msgpack"], | |
58 | "msgpack-python>=0.4.2": ["msgpack"], | |
58 | 59 | "phonenumbers>=8.2.0": ["phonenumbers"], |
59 | 60 | "six>=1.10": ["six"], |
60 | 61 | |
77 | 78 | }, |
78 | 79 | "matrix-synapse-ldap3": { |
79 | 80 | "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"], |
80 | }, | |
81 | "psutil": { | |
82 | "psutil>=2.0.0": ["psutil>=2.0.0"], | |
83 | 81 | }, |
84 | 82 | "postgres": { |
85 | 83 | "psycopg2>=2.6": ["psycopg2"] |
73 | 73 | if room is None: |
74 | 74 | raise SynapseError(400, "Room does not exist") |
75 | 75 | |
76 | dir_handler = self.handlers.directory_handler | |
76 | requester = yield self.auth.get_user_by_req(request) | |
77 | 77 | |
78 | try: | |
79 | # try to auth as a user | |
80 | requester = yield self.auth.get_user_by_req(request) | |
81 | try: | |
82 | user_id = requester.user.to_string() | |
83 | yield dir_handler.create_association( | |
84 | user_id, room_alias, room_id, servers | |
85 | ) | |
86 | yield dir_handler.send_room_alias_update_event( | |
87 | requester, | |
88 | user_id, | |
89 | room_id | |
90 | ) | |
91 | except SynapseError as e: | |
92 | raise e | |
93 | except Exception: | |
94 | logger.exception("Failed to create association") | |
95 | raise | |
96 | except AuthError: | |
97 | # try to auth as an application service | |
98 | service = yield self.auth.get_appservice_by_req(request) | |
99 | yield dir_handler.create_appservice_association( | |
100 | service, room_alias, room_id, servers | |
101 | ) | |
102 | logger.info( | |
103 | "Application service at %s created alias %s pointing to %s", | |
104 | service.url, | |
105 | room_alias.to_string(), | |
106 | room_id | |
107 | ) | |
78 | yield self.handlers.directory_handler.create_association( | |
79 | requester, room_alias, room_id, servers | |
80 | ) | |
108 | 81 | |
109 | 82 | defer.returnValue((200, {})) |
110 | 83 | |
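With the change above the servlet no longer decides whether the caller is a user or an application service; it passes the whole requester to directory_handler.create_association. A hypothetical sketch of the kind of dispatch such a unified entry point implies (the real handler is not part of this diff, so the names and branches here are illustrative only):

    # Illustrative only: a toy create_association that branches on the requester,
    # mirroring the single call site shown above.
    from collections import namedtuple

    Requester = namedtuple("Requester", ["user", "app_service"])

    def create_association(requester, room_alias, room_id, servers=None):
        if requester.app_service is not None:
            # application-service callers would get their own namespace checks
            return "appservice mapped %s -> %s" % (room_alias, room_id)
        # ordinary users would additionally trigger the alias update event
        return "%s mapped %s -> %s" % (requester.user, room_alias, room_id)

    print(create_association(
        Requester(user="@alice:example.com", app_service=None),
        "#lounge:example.com", "!abc123:example.com",
    ))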
134 | 107 | room_alias = RoomAlias.from_string(room_alias) |
135 | 108 | |
136 | 109 | yield dir_handler.delete_association( |
137 | requester, user.to_string(), room_alias | |
110 | requester, room_alias | |
138 | 111 | ) |
139 | 112 | |
140 | 113 | logger.info( |
32 | 32 | parse_json_object_from_request, |
33 | 33 | parse_string, |
34 | 34 | ) |
35 | from synapse.storage.state import StateFilter | |
35 | 36 | from synapse.streams.config import PaginationConfig |
36 | 37 | from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID |
37 | 38 | |
408 | 409 | room_id=room_id, |
409 | 410 | user_id=requester.user.to_string(), |
410 | 411 | at_token=at_token, |
411 | types=[(EventTypes.Member, None)], | |
412 | state_filter=StateFilter.from_types([(EventTypes.Member, None)]), | |
412 | 413 | ) |
413 | 414 | |
414 | 415 | chunk = [] |
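The members endpoint above now expresses "member events only" through the StateFilter type added later in this diff rather than a bare types list. A minimal sketch of how such a filter narrows a state map (the event IDs and the m.room.name entry are made up for illustration):

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    state = {
        (EventTypes.Member, "@alice:example.com"): "$membership_event",
        (EventTypes.Name, ""): "$room_name_event",
    }

    # Same filter as the servlet above: all m.room.member events, nothing else.
    members_only = StateFilter.from_types([(EventTypes.Member, None)])
    print(members_only.filter_state(state))
    # {('m.room.member', '@alice:example.com'): '$membership_event'}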
98 | 98 | cannot be handled in the normal flow (with requests to the same endpoint). |
99 | 99 | Current use is for web fallback auth. |
100 | 100 | """ |
101 | PATTERNS = client_v2_patterns("/auth/(?P<stagetype>[\w\.]*)/fallback/web") | |
101 | PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") | |
102 | 102 | |
103 | 103 | def __init__(self, hs): |
104 | 104 | super(AuthRestServlet, self).__init__() |
24 | 24 | |
25 | 25 | import twisted.internet.error |
26 | 26 | import twisted.web.http |
27 | from twisted.internet import defer, threads | |
27 | from twisted.internet import defer | |
28 | 28 | from twisted.web.resource import Resource |
29 | 29 | |
30 | 30 | from synapse.api.errors import ( |
35 | 35 | ) |
36 | 36 | from synapse.http.matrixfederationclient import MatrixFederationHttpClient |
37 | 37 | from synapse.metrics.background_process_metrics import run_as_background_process |
38 | from synapse.util import logcontext | |
38 | 39 | from synapse.util.async_helpers import Linearizer |
39 | from synapse.util.logcontext import make_deferred_yieldable | |
40 | 40 | from synapse.util.retryutils import NotRetryingDestination |
41 | 41 | from synapse.util.stringutils import is_ascii, random_string |
42 | 42 | |
491 | 491 | )) |
492 | 492 | |
493 | 493 | thumbnailer = Thumbnailer(input_path) |
494 | t_byte_source = yield make_deferred_yieldable(threads.deferToThread( | |
494 | t_byte_source = yield logcontext.defer_to_thread( | |
495 | self.hs.get_reactor(), | |
495 | 496 | self._generate_thumbnail, |
496 | 497 | thumbnailer, t_width, t_height, t_method, t_type |
497 | )) | |
498 | ) | |
498 | 499 | |
499 | 500 | if t_byte_source: |
500 | 501 | try: |
533 | 534 | )) |
534 | 535 | |
535 | 536 | thumbnailer = Thumbnailer(input_path) |
536 | t_byte_source = yield make_deferred_yieldable(threads.deferToThread( | |
537 | t_byte_source = yield logcontext.defer_to_thread( | |
538 | self.hs.get_reactor(), | |
537 | 539 | self._generate_thumbnail, |
538 | 540 | thumbnailer, t_width, t_height, t_method, t_type |
539 | )) | |
541 | ) | |
540 | 542 | |
541 | 543 | if t_byte_source: |
542 | 544 | try: |
619 | 621 | for (t_width, t_height, t_type), t_method in iteritems(thumbnails): |
620 | 622 | # Generate the thumbnail |
621 | 623 | if t_method == "crop": |
622 | t_byte_source = yield make_deferred_yieldable(threads.deferToThread( | |
624 | t_byte_source = yield logcontext.defer_to_thread( | |
625 | self.hs.get_reactor(), | |
623 | 626 | thumbnailer.crop, |
624 | 627 | t_width, t_height, t_type, |
625 | )) | |
628 | ) | |
626 | 629 | elif t_method == "scale": |
627 | t_byte_source = yield make_deferred_yieldable(threads.deferToThread( | |
630 | t_byte_source = yield logcontext.defer_to_thread( | |
631 | self.hs.get_reactor(), | |
628 | 632 | thumbnailer.scale, |
629 | 633 | t_width, t_height, t_type, |
630 | )) | |
634 | ) | |
631 | 635 | else: |
632 | 636 | logger.error("Unrecognized method: %r", t_method) |
633 | 637 | continue |
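These hunks replace the open-coded make_deferred_yieldable(threads.deferToThread(...)) pattern with the logcontext.defer_to_thread helper, passing the reactor explicitly. A minimal usage sketch, assuming the helper keeps the (reactor, fn, *args) call shape shown above and that its purpose is to run blocking work in a thread pool without losing the caller's log context (_resize_sync is a hypothetical placeholder for the PIL work):

    from twisted.internet import defer, reactor

    from synapse.util import logcontext

    def _resize_sync(path, width, height):
        # stand-in for blocking image work that must stay off the reactor thread
        return "%s@%dx%d" % (path, width, height)

    @defer.inlineCallbacks
    def generate_thumbnail(path):
        result = yield logcontext.defer_to_thread(
            reactor, _resize_sync, path, 64, 64,
        )
        defer.returnValue(result)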
20 | 20 | |
21 | 21 | import six |
22 | 22 | |
23 | from twisted.internet import defer, threads | |
23 | from twisted.internet import defer | |
24 | 24 | from twisted.protocols.basic import FileSender |
25 | 25 | |
26 | from synapse.util import logcontext | |
26 | 27 | from synapse.util.file_consumer import BackgroundFileConsumer |
27 | 28 | from synapse.util.logcontext import make_deferred_yieldable |
28 | 29 | |
63 | 64 | |
64 | 65 | with self.store_into_file(file_info) as (f, fname, finish_cb): |
65 | 66 | # Write to the main repository |
66 | yield make_deferred_yieldable(threads.deferToThread( | |
67 | yield logcontext.defer_to_thread( | |
68 | self.hs.get_reactor(), | |
67 | 69 | _write_file_synchronously, source, f, |
68 | )) | |
70 | ) | |
69 | 71 | yield finish_cb() |
70 | 72 | |
71 | 73 | defer.returnValue(fname) |
595 | 595 | # to be returned. |
596 | 596 | elements = iter([tree]) |
597 | 597 | while True: |
598 | el = next(elements) | |
598 | el = next(elements, None) | |
599 | if el is None: | |
600 | return | |
601 | ||
599 | 602 | if isinstance(el, string_types): |
600 | 603 | yield el |
601 | elif el is not None and el.tag not in tags_to_ignore: | |
604 | elif el.tag not in tags_to_ignore: | |
602 | 605 | # el.text is the text before the first child, so we can immediately |
603 | 606 | # return it if the text exists. |
604 | 607 | if el.text: |
670 | 673 | # This splits the paragraph into words, while keeping the
671 | 674 | # preceding whitespace intact so we can easily concat
672 | 675 | # words back together. |
673 | for match in re.finditer("\s*\S+", description): | |
676 | for match in re.finditer(r"\s*\S+", description): | |
674 | 677 | word = match.group() |
675 | 678 | |
676 | 679 | # Keep adding words while the total length is less than |
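The summariser hunk above only switches the word-splitting pattern to a raw string, but the pattern itself is worth a quick illustration: \s*\S+ keeps each word's preceding whitespace attached to the match, so concatenating the matches reproduces the original text exactly.

    import re

    description = "Matrix is  an open network"
    words = [m.group() for m in re.finditer(r"\s*\S+", description)]
    print(words)                            # ['Matrix', ' is', '  an', ' open', ' network']
    print("".join(words) == description)    # True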
16 | 16 | import os |
17 | 17 | import shutil |
18 | 18 | |
19 | from twisted.internet import defer, threads | |
19 | from twisted.internet import defer | |
20 | 20 | |
21 | 21 | from synapse.config._base import Config |
22 | from synapse.util import logcontext | |
22 | 23 | from synapse.util.logcontext import run_in_background |
23 | 24 | |
24 | 25 | from .media_storage import FileResponder |
119 | 120 | if not os.path.exists(dirname): |
120 | 121 | os.makedirs(dirname) |
121 | 122 | |
122 | return threads.deferToThread( | |
123 | return logcontext.defer_to_thread( | |
124 | self.hs.get_reactor(), | |
123 | 125 | shutil.copyfile, primary_fname, backup_fname, |
124 | 126 | ) |
125 | 127 |
206 | 206 | logger.info("Setting up.") |
207 | 207 | with self.get_db_conn() as conn: |
208 | 208 | self.datastore = self.DATASTORE_CLASS(conn, self) |
209 | conn.commit() | |
209 | 210 | logger.info("Finished setting up.") |
210 | 211 | |
211 | 212 | def get_reactor(self): |
18 | 18 | |
19 | 19 | from six import iteritems, itervalues |
20 | 20 | |
21 | import attr | |
21 | 22 | from frozendict import frozendict |
22 | 23 | |
23 | 24 | from twisted.internet import defer |
24 | 25 | |
25 | 26 | from synapse.api.constants import EventTypes, RoomVersions |
26 | 27 | from synapse.events.snapshot import EventContext |
27 | from synapse.state import v1 | |
28 | from synapse.state import v1, v2 | |
28 | 29 | from synapse.util.async_helpers import Linearizer |
29 | 30 | from synapse.util.caches import get_cache_factor_for |
30 | 31 | from synapse.util.caches.expiringcache import ExpiringCache |
371 | 372 | |
372 | 373 | result = yield self._state_resolution_handler.resolve_state_groups( |
373 | 374 | room_id, room_version, state_groups_ids, None, |
374 | self._state_map_factory, | |
375 | state_res_store=StateResolutionStore(self.store), | |
375 | 376 | ) |
376 | 377 | defer.returnValue(result) |
377 | ||
378 | def _state_map_factory(self, ev_ids): | |
379 | return self.store.get_events( | |
380 | ev_ids, get_prev_content=False, check_redacted=False, | |
381 | ) | |
382 | 378 | |
383 | 379 | @defer.inlineCallbacks |
384 | 380 | def resolve_events(self, room_version, state_sets, event): |
397 | 393 | } |
398 | 394 | |
399 | 395 | with Measure(self.clock, "state._resolve_events"): |
400 | new_state = yield resolve_events_with_factory( | |
396 | new_state = yield resolve_events_with_store( | |
401 | 397 | room_version, state_set_ids, |
402 | 398 | event_map=state_map, |
403 | state_map_factory=self._state_map_factory | |
399 | state_res_store=StateResolutionStore(self.store), | |
404 | 400 | ) |
405 | 401 | |
406 | 402 | new_state = { |
435 | 431 | @defer.inlineCallbacks |
436 | 432 | @log_function |
437 | 433 | def resolve_state_groups( |
438 | self, room_id, room_version, state_groups_ids, event_map, state_map_factory, | |
434 | self, room_id, room_version, state_groups_ids, event_map, state_res_store, | |
439 | 435 | ): |
440 | 436 | """Resolves conflicts between a set of state groups |
441 | 437 | |
453 | 449 | a dict from event_id to event, for any events that we happen to |
454 | 450 | have in flight (eg, those currently being persisted). This will be |
455 | 451 | used as a starting point for finding the state we need; any missing
456 | events will be requested via state_map_factory. | |
457 | ||
458 | If None, all events will be fetched via state_map_factory. | |
452 | events will be requested via state_res_store. | |
453 | ||
454 | If None, all events will be fetched via state_res_store. | |
455 | ||
456 | state_res_store (StateResolutionStore) | |
459 | 457 | |
460 | 458 | Returns: |
461 | 459 | Deferred[_StateCacheEntry]: resolved state |
479 | 477 | |
480 | 478 | # start by assuming we won't have any conflicted state, and build up the new |
481 | 479 | # state map by iterating through the state groups. If we discover a conflict, |
482 | # we give up and instead use `resolve_events_with_factory`. | |
480 | # we give up and instead use `resolve_events_with_store`. | |
483 | 481 | # |
484 | 482 | # XXX: is this actually worthwhile, or should we just let |
485 | # resolve_events_with_factory do it? | |
483 | # resolve_events_with_store do it? | |
486 | 484 | new_state = {} |
487 | 485 | conflicted_state = False |
488 | 486 | for st in itervalues(state_groups_ids): |
497 | 495 | if conflicted_state: |
498 | 496 | logger.info("Resolving conflicted state for %r", room_id) |
499 | 497 | with Measure(self.clock, "state._resolve_events"): |
500 | new_state = yield resolve_events_with_factory( | |
498 | new_state = yield resolve_events_with_store( | |
501 | 499 | room_version, |
502 | 500 | list(itervalues(state_groups_ids)), |
503 | 501 | event_map=event_map, |
504 | state_map_factory=state_map_factory, | |
502 | state_res_store=state_res_store, | |
505 | 503 | ) |
506 | 504 | |
507 | 505 | # if the new state matches any of the input state groups, we can |
582 | 580 | ) |
583 | 581 | |
584 | 582 | |
585 | def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory): | |
583 | def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): | |
586 | 584 | """ |
587 | 585 | Args: |
588 | 586 | room_version(str): Version of the room |
598 | 596 | |
599 | 597 | If None, all events will be fetched via state_map_factory. |
600 | 598 | |
601 | state_map_factory(func): will be called | |
602 | with a list of event_ids that are needed, and should return with | |
603 | a Deferred of dict of event_id to event. | |
599 | state_res_store (StateResolutionStore) | |
604 | 600 | |
605 | 601 | Returns |
606 | 602 | Deferred[dict[(str, str), str]]: |
607 | 603 | a map from (type, state_key) to event_id. |
608 | 604 | """ |
609 | if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): | |
610 | return v1.resolve_events_with_factory( | |
611 | state_sets, event_map, state_map_factory, | |
605 | if room_version == RoomVersions.V1: | |
606 | return v1.resolve_events_with_store( | |
607 | state_sets, event_map, state_res_store.get_events, | |
608 | ) | |
609 | elif room_version == RoomVersions.VDH_TEST: | |
610 | return v2.resolve_events_with_store( | |
611 | state_sets, event_map, state_res_store, | |
612 | 612 | ) |
613 | 613 | else: |
614 | 614 | # This should only happen if we added a version but forgot to add it to |
616 | 616 | raise Exception( |
617 | 617 | "No state resolution algorithm defined for version %r" % (room_version,) |
618 | 618 | ) |
619 | ||
620 | ||
621 | @attr.s | |
622 | class StateResolutionStore(object): | |
623 | """Interface that allows state resolution algorithms to access the database | |
624 | in a well defined way. | |
625 | ||
626 | Args: | |
627 | store (DataStore) | |
628 | """ | |
629 | ||
630 | store = attr.ib() | |
631 | ||
632 | def get_events(self, event_ids, allow_rejected=False): | |
633 | """Get events from the database | |
634 | ||
635 | Args: | |
636 | event_ids (list): The event_ids of the events to fetch | |
637 | allow_rejected (bool): If True return rejected events. | |
638 | ||
639 | Returns: | |
640 | Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. | |
641 | """ | |
642 | ||
643 | return self.store.get_events( | |
644 | event_ids, | |
645 | check_redacted=False, | |
646 | get_prev_content=False, | |
647 | allow_rejected=allow_rejected, | |
648 | ) | |
649 | ||
650 | def get_auth_chain(self, event_ids): | |
651 | """Gets the full auth chain for a set of events (including rejected | |
652 | events). | |
653 | ||
654 | Includes the given event IDs in the result. | |
655 | ||
656 | Note that: | |
657 | 1. All events must be state events. | |
658 | 2. For v1 rooms this may not have the full auth chain in the | |
659 | presence of rejected events | |
660 | ||
661 | Args: | |
662 | event_ids (list): The event IDs of the events to fetch the auth | |
663 | chain for. Must be state events. | |
664 | ||
665 | Returns: | |
666 | Deferred[list[str]]: List of event IDs of the auth chain. | |
667 | """ | |
668 | ||
669 | return self.store.get_auth_chain_ids(event_ids, include_given=True) |
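resolve_events_with_store now takes a StateResolutionStore and chooses the algorithm from the room version: V1 rooms keep the old resolver, while VDH_TEST rooms use the new v2 implementation. A minimal sketch of calling the new entry point, assuming `store` is a DataStore providing the get_events/get_auth_chain_ids methods the wrapper above delegates to:

    from synapse.state import StateResolutionStore, resolve_events_with_store

    def resolve(room_version, state_sets, store):
        # state_sets: list of dicts of (type, state_key) -> event_id.
        # Returns a Deferred resolving to a single (type, state_key) -> event_id map.
        return resolve_events_with_store(
            room_version,
            state_sets,
            event_map=None,            # fetch everything through the store
            state_res_store=StateResolutionStore(store),
        )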
30 | 30 | |
31 | 31 | |
32 | 32 | @defer.inlineCallbacks |
33 | def resolve_events_with_factory(state_sets, event_map, state_map_factory): | |
33 | def resolve_events_with_store(state_sets, event_map, state_map_factory): | |
34 | 34 | """ |
35 | 35 | Args: |
36 | 36 | state_sets(list): List of dicts of (type, state_key) -> event_id, |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2018 New Vector Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | import heapq | |
16 | import itertools | |
17 | import logging | |
18 | ||
19 | from six import iteritems, itervalues | |
20 | ||
21 | from twisted.internet import defer | |
22 | ||
23 | from synapse import event_auth | |
24 | from synapse.api.constants import EventTypes | |
25 | from synapse.api.errors import AuthError | |
26 | ||
27 | logger = logging.getLogger(__name__) | |
28 | ||
29 | ||
30 | @defer.inlineCallbacks | |
31 | def resolve_events_with_store(state_sets, event_map, state_res_store): | |
32 | """Resolves the state using the v2 state resolution algorithm | |
33 | ||
34 | Args: | |
35 | state_sets(list): List of dicts of (type, state_key) -> event_id, | |
36 | which are the different state groups to resolve. | |
37 | ||
38 | event_map(dict[str,FrozenEvent]|None): | |
39 | a dict from event_id to event, for any events that we happen to | |
40 | have in flight (eg, those currently being persisted). This will be | |
41 | used as a starting point for finding the state we need; any missing | |
42 | events will be requested via state_res_store. | |
43 | ||
44 | If None, all events will be fetched via state_res_store. | |
45 | ||
46 | state_res_store (StateResolutionStore) | |
47 | ||
48 | Returns | |
49 | Deferred[dict[(str, str), str]]: | |
50 | a map from (type, state_key) to event_id. | |
51 | """ | |
52 | ||
53 | logger.debug("Computing conflicted state") | |
54 | ||
55 | # First split up the un/conflicted state | |
56 | unconflicted_state, conflicted_state = _seperate(state_sets) | |
57 | ||
58 | if not conflicted_state: | |
59 | defer.returnValue(unconflicted_state) | |
60 | ||
61 | logger.debug("%d conflicted state entries", len(conflicted_state)) | |
62 | logger.debug("Calculating auth chain difference") | |
63 | ||
64 | # Also fetch all auth events that appear in only some of the state sets' | |
65 | # auth chains. | |
66 | auth_diff = yield _get_auth_chain_difference( | |
67 | state_sets, event_map, state_res_store, | |
68 | ) | |
69 | ||
70 | full_conflicted_set = set(itertools.chain( | |
71 | itertools.chain.from_iterable(itervalues(conflicted_state)), | |
72 | auth_diff, | |
73 | )) | |
74 | ||
75 | events = yield state_res_store.get_events([ | |
76 | eid for eid in full_conflicted_set | |
77 | if eid not in event_map | |
78 | ], allow_rejected=True) | |
79 | event_map.update(events) | |
80 | ||
81 | full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) | |
82 | ||
83 | logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) | |
84 | ||
85 | # Get and sort all the power events (kicks/bans/etc) | |
86 | power_events = ( | |
87 | eid for eid in full_conflicted_set | |
88 | if _is_power_event(event_map[eid]) | |
89 | ) | |
90 | ||
91 | sorted_power_events = yield _reverse_topological_power_sort( | |
92 | power_events, | |
93 | event_map, | |
94 | state_res_store, | |
95 | full_conflicted_set, | |
96 | ) | |
97 | ||
98 | logger.debug("sorted %d power events", len(sorted_power_events)) | |
99 | ||
100 | # Now sequentially auth each one | |
101 | resolved_state = yield _iterative_auth_checks( | |
102 | sorted_power_events, unconflicted_state, event_map, | |
103 | state_res_store, | |
104 | ) | |
105 | ||
106 | logger.debug("resolved power events") | |
107 | ||
108 | # OK, so we've now resolved the power events. Now sort the remaining | |
109 | # events using the mainline of the resolved power level. | |
110 | ||
111 | leftover_events = [ | |
112 | ev_id | |
113 | for ev_id in full_conflicted_set | |
114 | if ev_id not in sorted_power_events | |
115 | ] | |
116 | ||
117 | logger.debug("sorting %d remaining events", len(leftover_events)) | |
118 | ||
119 | pl = resolved_state.get((EventTypes.PowerLevels, ""), None) | |
120 | leftover_events = yield _mainline_sort( | |
121 | leftover_events, pl, event_map, state_res_store, | |
122 | ) | |
123 | ||
124 | logger.debug("resolving remaining events") | |
125 | ||
126 | resolved_state = yield _iterative_auth_checks( | |
127 | leftover_events, resolved_state, event_map, | |
128 | state_res_store, | |
129 | ) | |
130 | ||
131 | logger.debug("resolved") | |
132 | ||
133 | # We make sure that unconflicted state always still applies. | |
134 | resolved_state.update(unconflicted_state) | |
135 | ||
136 | logger.debug("done") | |
137 | ||
138 | defer.returnValue(resolved_state) | |
139 | ||
140 | ||
141 | @defer.inlineCallbacks | |
142 | def _get_power_level_for_sender(event_id, event_map, state_res_store): | |
143 | """Return the power level of the sender of the given event according to | |
144 | their auth events. | |
145 | ||
146 | Args: | |
147 | event_id (str) | |
148 | event_map (dict[str,FrozenEvent]) | |
149 | state_res_store (StateResolutionStore) | |
150 | ||
151 | Returns: | |
152 | Deferred[int] | |
153 | """ | |
154 | event = yield _get_event(event_id, event_map, state_res_store) | |
155 | ||
156 | pl = None | |
157 | for aid, _ in event.auth_events: | |
158 | aev = yield _get_event(aid, event_map, state_res_store) | |
159 | if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): | |
160 | pl = aev | |
161 | break | |
162 | ||
163 | if pl is None: | |
164 | # Couldn't find power level. Check if they're the creator of the room | |
165 | for aid, _ in event.auth_events: | |
166 | aev = yield _get_event(aid, event_map, state_res_store) | |
167 | if (aev.type, aev.state_key) == (EventTypes.Create, ""): | |
168 | if aev.content.get("creator") == event.sender: | |
169 | defer.returnValue(100) | |
170 | break | |
171 | defer.returnValue(0) | |
172 | ||
173 | level = pl.content.get("users", {}).get(event.sender) | |
174 | if level is None: | |
175 | level = pl.content.get("users_default", 0) | |
176 | ||
177 | if level is None: | |
178 | defer.returnValue(0) | |
179 | else: | |
180 | defer.returnValue(int(level)) | |
181 | ||
182 | ||
183 | @defer.inlineCallbacks | |
184 | def _get_auth_chain_difference(state_sets, event_map, state_res_store): | |
185 | """Compare the auth chains of each state set and return the set of events | |
186 | that only appear in some but not all of the auth chains. | |
187 | ||
188 | Args: | |
189 | state_sets (list) | |
190 | event_map (dict[str,FrozenEvent]) | |
191 | state_res_store (StateResolutionStore) | |
192 | ||
193 | Returns: | |
194 | Deferred[set[str]]: Set of event IDs | |
195 | """ | |
196 | common = set(itervalues(state_sets[0])).intersection( | |
197 | *(itervalues(s) for s in state_sets[1:]) | |
198 | ) | |
199 | ||
200 | auth_sets = [] | |
201 | for state_set in state_sets: | |
202 | auth_ids = set( | |
203 | eid | |
204 | for key, eid in iteritems(state_set) | |
205 | if (key[0] in ( | |
206 | EventTypes.Member, | |
207 | EventTypes.ThirdPartyInvite, | |
208 | ) or key in ( | |
209 | (EventTypes.PowerLevels, ''), | |
210 | (EventTypes.Create, ''), | |
211 | (EventTypes.JoinRules, ''), | |
212 | )) and eid not in common | |
213 | ) | |
214 | ||
215 | auth_chain = yield state_res_store.get_auth_chain(auth_ids) | |
216 | auth_ids.update(auth_chain) | |
217 | ||
218 | auth_sets.append(auth_ids) | |
219 | ||
220 | intersection = set(auth_sets[0]).intersection(*auth_sets[1:]) | |
221 | union = set().union(*auth_sets) | |
222 | ||
223 | defer.returnValue(union - intersection) | |
224 | ||
225 | ||
226 | def _seperate(state_sets): | |
227 | """Return the unconflicted and conflicted state. This differs from the | |
228 | original algorithm in that it defines a key to be conflicted if one of | |
229 | the state sets doesn't have that key. | |
230 | ||
231 | Args: | |
232 | state_sets (list) | |
233 | ||
234 | Returns: | |
235 | tuple[dict, dict]: A tuple of unconflicted and conflicted state. The | |
236 | conflicted state dict is a map from type/state_key to set of event IDs | |
237 | """ | |
238 | unconflicted_state = {} | |
239 | conflicted_state = {} | |
240 | ||
241 | for key in set(itertools.chain.from_iterable(state_sets)): | |
242 | event_ids = set(state_set.get(key) for state_set in state_sets) | |
243 | if len(event_ids) == 1: | |
244 | unconflicted_state[key] = event_ids.pop() | |
245 | else: | |
246 | event_ids.discard(None) | |
247 | conflicted_state[key] = event_ids | |
248 | ||
249 | return unconflicted_state, conflicted_state | |
250 | ||
251 | ||
252 | def _is_power_event(event): | |
253 | """Return whether or not the event is a "power event", as defined by the | |
254 | v2 state resolution algorithm | |
255 | ||
256 | Args: | |
257 | event (FrozenEvent) | |
258 | ||
259 | Returns: | |
260 | boolean | |
261 | """ | |
262 | if (event.type, event.state_key) in ( | |
263 | (EventTypes.PowerLevels, ""), | |
264 | (EventTypes.JoinRules, ""), | |
265 | (EventTypes.Create, ""), | |
266 | ): | |
267 | return True | |
268 | ||
269 | if event.type == EventTypes.Member: | |
270 | if event.membership in ('leave', 'ban'): | |
271 | return event.sender != event.state_key | |
272 | ||
273 | return False | |
274 | ||
275 | ||
276 | @defer.inlineCallbacks | |
277 | def _add_event_and_auth_chain_to_graph(graph, event_id, event_map, | |
278 | state_res_store, auth_diff): | |
279 | """Helper function for _reverse_topological_power_sort that adds the event | |
280 | and its auth chain (that is in the auth diff) to the graph | |
281 | ||
282 | Args: | |
283 | graph (dict[str, set[str]]): A map from event ID to the events auth | |
284 | event IDs | |
285 | event_id (str): Event to add to the graph | |
286 | event_map (dict[str,FrozenEvent]) | |
287 | state_res_store (StateResolutionStore) | |
288 | auth_diff (set[str]): Set of event IDs that are in the auth difference. | |
289 | """ | |
290 | ||
291 | state = [event_id] | |
292 | while state: | |
293 | eid = state.pop() | |
294 | graph.setdefault(eid, set()) | |
295 | ||
296 | event = yield _get_event(eid, event_map, state_res_store) | |
297 | for aid, _ in event.auth_events: | |
298 | if aid in auth_diff: | |
299 | if aid not in graph: | |
300 | state.append(aid) | |
301 | ||
302 | graph.setdefault(eid, set()).add(aid) | |
303 | ||
304 | ||
305 | @defer.inlineCallbacks | |
306 | def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): | |
307 | """Returns a list of the event_ids sorted by reverse topological ordering, | |
308 | and then by power level and origin_server_ts | |
309 | ||
310 | Args: | |
311 | event_ids (list[str]): The events to sort | |
312 | event_map (dict[str,FrozenEvent]) | |
313 | state_res_store (StateResolutionStore) | |
314 | auth_diff (set[str]): Set of event IDs that are in the auth difference. | |
315 | ||
316 | Returns: | |
317 | Deferred[list[str]]: The sorted list | |
318 | """ | |
319 | ||
320 | graph = {} | |
321 | for event_id in event_ids: | |
322 | yield _add_event_and_auth_chain_to_graph( | |
323 | graph, event_id, event_map, state_res_store, auth_diff, | |
324 | ) | |
325 | ||
326 | event_to_pl = {} | |
327 | for event_id in graph: | |
328 | pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) | |
329 | event_to_pl[event_id] = pl | |
330 | ||
331 | def _get_power_order(event_id): | |
332 | ev = event_map[event_id] | |
333 | pl = event_to_pl[event_id] | |
334 | ||
335 | return -pl, ev.origin_server_ts, event_id | |
336 | ||
337 | # Note: graph is modified during the sort | |
338 | it = lexicographical_topological_sort( | |
339 | graph, | |
340 | key=_get_power_order, | |
341 | ) | |
342 | sorted_events = list(it) | |
343 | ||
344 | defer.returnValue(sorted_events) | |
345 | ||
346 | ||
347 | @defer.inlineCallbacks | |
348 | def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store): | |
349 | """Sequentially apply auth checks to each event in given list, updating the | |
350 | state as it goes along. | |
351 | ||
352 | Args: | |
353 | event_ids (list[str]): Ordered list of events to apply auth checks to | |
354 | base_state (dict[tuple[str, str], str]): The set of state to start with | |
355 | event_map (dict[str,FrozenEvent]) | |
356 | state_res_store (StateResolutionStore) | |
357 | ||
358 | Returns: | |
359 | Deferred[dict[tuple[str, str], str]]: Returns the final updated state | |
360 | """ | |
361 | resolved_state = base_state.copy() | |
362 | ||
363 | for event_id in event_ids: | |
364 | event = event_map[event_id] | |
365 | ||
366 | auth_events = {} | |
367 | for aid, _ in event.auth_events: | |
368 | ev = yield _get_event(aid, event_map, state_res_store) | |
369 | ||
370 | if ev.rejected_reason is None: | |
371 | auth_events[(ev.type, ev.state_key)] = ev | |
372 | ||
373 | for key in event_auth.auth_types_for_event(event): | |
374 | if key in resolved_state: | |
375 | ev_id = resolved_state[key] | |
376 | ev = yield _get_event(ev_id, event_map, state_res_store) | |
377 | ||
378 | if ev.rejected_reason is None: | |
379 | auth_events[key] = event_map[ev_id] | |
380 | ||
381 | try: | |
382 | event_auth.check( | |
383 | event, auth_events, | |
384 | do_sig_check=False, | |
385 | do_size_check=False | |
386 | ) | |
387 | ||
388 | resolved_state[(event.type, event.state_key)] = event_id | |
389 | except AuthError: | |
390 | pass | |
391 | ||
392 | defer.returnValue(resolved_state) | |
393 | ||
394 | ||
395 | @defer.inlineCallbacks | |
396 | def _mainline_sort(event_ids, resolved_power_event_id, event_map, | |
397 | state_res_store): | |
398 | """Returns the given event_ids sorted by mainline ordering based on | |
399 | the given event resolved_power_event_id | |
400 | ||
401 | Args: | |
402 | event_ids (list[str]): Events to sort | |
403 | resolved_power_event_id (str): The final resolved power level event ID | |
404 | event_map (dict[str,FrozenEvent]) | |
405 | state_res_store (StateResolutionStore) | |
406 | ||
407 | Returns: | |
408 | Deferred[list[str]]: The sorted list | |
409 | """ | |
410 | mainline = [] | |
411 | pl = resolved_power_event_id | |
412 | while pl: | |
413 | mainline.append(pl) | |
414 | pl_ev = yield _get_event(pl, event_map, state_res_store) | |
415 | auth_events = pl_ev.auth_events | |
416 | pl = None | |
417 | for aid, _ in auth_events: | |
418 | ev = yield _get_event(aid, event_map, state_res_store) | |
419 | if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): | |
420 | pl = aid | |
421 | break | |
422 | ||
423 | mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))} | |
424 | ||
425 | event_ids = list(event_ids) | |
426 | ||
427 | order_map = {} | |
428 | for ev_id in event_ids: | |
429 | depth = yield _get_mainline_depth_for_event( | |
430 | event_map[ev_id], mainline_map, | |
431 | event_map, state_res_store, | |
432 | ) | |
433 | order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) | |
434 | ||
435 | event_ids.sort(key=lambda ev_id: order_map[ev_id]) | |
436 | ||
437 | defer.returnValue(event_ids) | |
438 | ||
439 | ||
440 | @defer.inlineCallbacks | |
441 | def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store): | |
442 | """Get the mainline depth for the given event based on the mainline map | |
443 | ||
444 | Args: | |
445 | event (FrozenEvent) | |
446 | mainline_map (dict[str, int]): Map from event_id to mainline depth for | |
447 | events in the mainline. | |
448 | event_map (dict[str,FrozenEvent]) | |
449 | state_res_store (StateResolutionStore) | |
450 | ||
451 | Returns: | |
452 | Deferred[int] | |
453 | """ | |
454 | ||
455 | # We do an iterative search, replacing `event` with the power level in its | |
456 | # auth events (if any) | |
457 | while event: | |
458 | depth = mainline_map.get(event.event_id) | |
459 | if depth is not None: | |
460 | defer.returnValue(depth) | |
461 | ||
462 | auth_events = event.auth_events | |
463 | event = None | |
464 | ||
465 | for aid, _ in auth_events: | |
466 | aev = yield _get_event(aid, event_map, state_res_store) | |
467 | if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): | |
468 | event = aev | |
469 | break | |
470 | ||
471 | # Didn't find a power level auth event, so we just return 0 | |
472 | defer.returnValue(0) | |
473 | ||
474 | ||
475 | @defer.inlineCallbacks | |
476 | def _get_event(event_id, event_map, state_res_store): | |
477 | """Helper function to look up event in event_map, falling back to looking | |
478 | it up in the store | |
479 | ||
480 | Args: | |
481 | event_id (str) | |
482 | event_map (dict[str,FrozenEvent]) | |
483 | state_res_store (StateResolutionStore) | |
484 | ||
485 | Returns: | |
486 | Deferred[FrozenEvent] | |
487 | """ | |
488 | if event_id not in event_map: | |
489 | events = yield state_res_store.get_events([event_id], allow_rejected=True) | |
490 | event_map.update(events) | |
491 | defer.returnValue(event_map[event_id]) | |
492 | ||
493 | ||
494 | def lexicographical_topological_sort(graph, key): | |
495 | """Performs a lexicographic reverse topological sort on the graph. | |
496 | ||
497 | This returns a reverse topological sort (i.e. if node A references B then B | |
498 | appears before A in the sort), with ties broken lexicographically based on | |
499 | the return value of the `key` function. | |
500 | ||
501 | NOTE: `graph` is modified during the sort. | |
502 | ||
503 | Args: | |
504 | graph (dict[str, set[str]]): A representation of the graph where each | |
505 | node is a key in the dict and its value is the set of the node's edges. | |
506 | key (func): A function that takes a node and returns a value that is | |
507 | comparable and used to order nodes | |
508 | ||
509 | Yields: | |
510 | str: The next node in the topological sort | |
511 | """ | |
512 | ||
513 | # Note, this is basically Kahn's algorithm except we look at nodes with no | |
514 | # outgoing edges, c.f. | |
515 | # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm | |
516 | outdegree_map = graph | |
517 | reverse_graph = {} | |
518 | ||
519 | # List of nodes with zero out degree. Entries are actually tuples of | |
520 | # `(key(node), node)` so that sorting does the right thing | |
521 | zero_outdegree = [] | |
522 | ||
523 | for node, edges in iteritems(graph): | |
524 | if len(edges) == 0: | |
525 | zero_outdegree.append((key(node), node)) | |
526 | ||
527 | reverse_graph.setdefault(node, set()) | |
528 | for edge in edges: | |
529 | reverse_graph.setdefault(edge, set()).add(node) | |
530 | ||
531 | # heapq is a built in implementation of a sorted queue. | |
532 | heapq.heapify(zero_outdegree) | |
533 | ||
534 | while zero_outdegree: | |
535 | _, node = heapq.heappop(zero_outdegree) | |
536 | ||
537 | for parent in reverse_graph[node]: | |
538 | out = outdegree_map[parent] | |
539 | out.discard(node) | |
540 | if len(out) == 0: | |
541 | heapq.heappush(zero_outdegree, (key(parent), parent)) | |
542 | ||
543 | yield node |
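lexicographical_topological_sort at the end of the new module is a pure function, so its behaviour is easy to check in isolation. A toy example, assuming the new file lands as synapse/state/v2.py (as the `from synapse.state import v1, v2` hunk above suggests); note that the graph dict is consumed by the sort:

    from synapse.state.v2 import lexicographical_topological_sort

    # "B" and "C" both reference "A"; "D" references both "B" and "C".
    graph = {
        "A": set(),
        "B": {"A"},
        "C": {"A"},
        "D": {"B", "C"},
    }

    # Reverse topological order: referenced events come first, with the
    # lexicographic key breaking the tie between "B" and "C".
    print(list(lexicographical_topological_sort(graph, key=lambda node: node)))
    # ['A', 'B', 'C', 'D']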
17 | 17 | import time |
18 | 18 | |
19 | 19 | from six import PY2, iteritems, iterkeys, itervalues |
20 | from six.moves import intern, range | |
20 | from six.moves import builtins, intern, range | |
21 | 21 | |
22 | 22 | from canonicaljson import json |
23 | 23 | from prometheus_client import Histogram |
1232 | 1232 | |
1233 | 1233 | # psycopg2 on Python 2 returns buffer objects, which we need to cast to |
1234 | 1234 | # bytes to decode |
1235 | if PY2 and isinstance(db_content, buffer): | |
1235 | if PY2 and isinstance(db_content, builtins.buffer): | |
1236 | 1236 | db_content = bytes(db_content) |
1237 | 1237 | |
1238 | 1238 | # Decode it to a Unicode string before feeding it to json.loads, so we |
89 | 89 | class DirectoryStore(DirectoryWorkerStore): |
90 | 90 | @defer.inlineCallbacks |
91 | 91 | def create_room_alias_association(self, room_alias, room_id, servers, creator=None): |
92 | """ Creates an associatin between a room alias and room_id/servers | |
92 | """ Creates an association between a room alias and room_id/servers | |
93 | 93 | |
94 | 94 | Args: |
95 | 95 | room_alias (RoomAlias) |
33 | 33 | from synapse.events import EventBase # noqa: F401 |
34 | 34 | from synapse.events.snapshot import EventContext # noqa: F401 |
35 | 35 | from synapse.metrics.background_process_metrics import run_as_background_process |
36 | from synapse.state import StateResolutionStore | |
36 | 37 | from synapse.storage.background_updates import BackgroundUpdateStore |
37 | 38 | from synapse.storage.event_federation import EventFederationStore |
38 | 39 | from synapse.storage.events_worker import EventsWorkerStore |
730 | 731 | |
731 | 732 | # Ok, we need to defer to the state handler to resolve our state sets. |
732 | 733 | |
733 | def get_events(ev_ids): | |
734 | return self.get_events( | |
735 | ev_ids, get_prev_content=False, check_redacted=False, | |
736 | ) | |
737 | ||
738 | 734 | state_groups = { |
739 | 735 | sg: state_groups_map[sg] for sg in new_state_groups |
740 | 736 | } |
744 | 740 | |
745 | 741 | logger.debug("calling resolve_state_groups from preserve_events") |
746 | 742 | res = yield self._state_resolution_handler.resolve_state_groups( |
747 | room_id, room_version, state_groups, events_map, get_events | |
743 | room_id, room_version, state_groups, events_map, | |
744 | state_res_store=StateResolutionStore(self) | |
748 | 745 | ) |
749 | 746 | |
750 | 747 | defer.returnValue((res.state, None)) |
852 | 849 | |
853 | 850 | # Insert into event_to_state_groups. |
854 | 851 | self._store_event_state_mappings_txn(txn, events_and_contexts) |
852 | ||
853 | # We want to store event_auth mappings for rejected events, as they're | |
854 | # used in state res v2. | |
855 | # This is only necessary if the rejected event appears in an accepted | |
856 | # event's auth chain, but it's easier for now just to store them (and | |
857 | # it doesn't take much storage compared to storing the entire event | |
858 | # anyway). | |
859 | self._simple_insert_many_txn( | |
860 | txn, | |
861 | table="event_auth", | |
862 | values=[ | |
863 | { | |
864 | "event_id": event.event_id, | |
865 | "room_id": event.room_id, | |
866 | "auth_id": auth_id, | |
867 | } | |
868 | for event, _ in events_and_contexts | |
869 | for auth_id, _ in event.auth_events | |
870 | if event.is_state() | |
871 | ], | |
872 | ) | |
855 | 873 | |
856 | 874 | # _store_rejected_events_txn filters out any events which were |
857 | 875 | # rejected, and returns the filtered list. |
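Moving the event_auth insert above the rejected-event filtering means auth mappings are written even for rejected state events, which the v2 resolver can later ask for. A toy illustration of the rows that comprehension produces (StubEvent is a hypothetical stand-in for a FrozenEvent, with auth_events as the (event_id, hashes) pairs the comprehension above unpacks):

    class StubEvent(object):
        def __init__(self, event_id, room_id, auth_events):
            self.event_id = event_id
            self.room_id = room_id
            self.auth_events = auth_events

        def is_state(self):
            return True

    events_and_contexts = [
        (StubEvent("$rejected_member", "!room:example.com",
                   [("$create", {}), ("$power_levels", {})]), None),
    ]

    rows = [
        {"event_id": event.event_id, "room_id": event.room_id, "auth_id": auth_id}
        for event, _ in events_and_contexts
        for auth_id, _ in event.auth_events
        if event.is_state()
    ]
    print(rows)
    # two rows for $rejected_member: one per auth event ($create, $power_levels)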
1328 | 1346 | txn, event.room_id, event.redacts |
1329 | 1347 | ) |
1330 | 1348 | |
1331 | self._simple_insert_many_txn( | |
1332 | txn, | |
1333 | table="event_auth", | |
1334 | values=[ | |
1335 | { | |
1336 | "event_id": event.event_id, | |
1337 | "room_id": event.room_id, | |
1338 | "auth_id": auth_id, | |
1339 | } | |
1340 | for event, _ in events_and_contexts | |
1341 | for auth_id, _ in event.auth_events | |
1342 | if event.is_state() | |
1343 | ], | |
1344 | ) | |
1345 | ||
1346 | 1349 | # Update the event_forward_extremities, event_backward_extremities and |
1347 | 1350 | # event_edges tables. |
1348 | 1351 | self._handle_mult_prev_events( |
2085 | 2088 | for sg in remaining_state_groups: |
2086 | 2089 | logger.info("[purge] de-delta-ing remaining state group %s", sg) |
2087 | 2090 | curr_state = self._get_state_groups_from_groups_txn( |
2088 | txn, [sg], types=None | |
2091 | txn, [sg], | |
2089 | 2092 | ) |
2090 | 2093 | curr_state = curr_state[sg] |
2091 | 2094 |
31 | 31 | # py2 sqlite has buffer hardcoded as only binary type, so we must use it, |
32 | 32 | # despite being deprecated and removed in favor of memoryview |
33 | 33 | if six.PY2: |
34 | db_binary_type = buffer | |
34 | db_binary_type = six.moves.builtins.buffer | |
35 | 35 | else: |
36 | 36 | db_binary_type = memoryview |
37 | 37 |
32 | 32 | self._clock = hs.get_clock() |
33 | 33 | self.hs = hs |
34 | 34 | self.reserved_users = () |
35 | ||
36 | @defer.inlineCallbacks | |
37 | def initialise_reserved_users(self, threepids): | |
38 | store = self.hs.get_datastore() | |
35 | # Do not add more reserved users than the total allowable number | |
36 | self._initialise_reserved_users( | |
37 | dbconn.cursor(), | |
38 | hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value], | |
39 | ) | |
40 | ||
41 | def _initialise_reserved_users(self, txn, threepids): | |
42 | """Ensures that reserved threepids are accounted for in the MAU table; should | |
43 | be called on start up. | |
44 | ||
45 | Args: | |
46 | txn (cursor): | |
47 | threepids (list[dict]): List of threepid dicts to reserve | |
48 | """ | |
39 | 49 | reserved_user_list = [] |
40 | 50 | |
41 | # Do not add more reserved users than the total allowable number | |
42 | for tp in threepids[:self.hs.config.max_mau_value]: | |
43 | user_id = yield store.get_user_id_by_threepid( | |
51 | for tp in threepids: | |
52 | user_id = self.get_user_id_by_threepid_txn( | |
53 | txn, | |
44 | 54 | tp["medium"], tp["address"] |
45 | 55 | ) |
46 | 56 | if user_id: |
47 | yield self.upsert_monthly_active_user(user_id) | |
57 | self.upsert_monthly_active_user_txn(txn, user_id) | |
48 | 58 | reserved_user_list.append(user_id) |
49 | 59 | else: |
50 | 60 | logger.warning( |
54 | 64 | |
55 | 65 | @defer.inlineCallbacks |
56 | 66 | def reap_monthly_active_users(self): |
57 | """ | |
58 | Cleans out monthly active user table to ensure that no stale | |
67 | """Cleans out monthly active user table to ensure that no stale | |
59 | 68 | entries exist. |
60 | 69 | |
61 | 70 | Returns: |
164 | 173 | |
165 | 174 | @defer.inlineCallbacks |
166 | 175 | def upsert_monthly_active_user(self, user_id): |
167 | """ | |
168 | Updates or inserts monthly active user member | |
169 | Arguments: | |
170 | user_id (str): user to add/update | |
171 | Deferred[bool]: True if a new entry was created, False if an | |
172 | existing one was updated. | |
176 | """Updates or inserts the user into the monthly active user table, which | |
177 | is used to track the current MAU usage of the server | |
178 | ||
179 | Args: | |
180 | user_id (str): user to add/update | |
181 | """ | |
182 | is_insert = yield self.runInteraction( | |
183 | "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, | |
184 | user_id | |
185 | ) | |
186 | ||
187 | if is_insert: | |
188 | self.user_last_seen_monthly_active.invalidate((user_id,)) | |
189 | self.get_monthly_active_count.invalidate(()) | |
190 | ||
191 | def upsert_monthly_active_user_txn(self, txn, user_id): | |
192 | """Updates or inserts monthly active user member | |
193 | ||
194 | Note that, after calling this method, it will generally be necessary | |
195 | to invalidate the caches on user_last_seen_monthly_active and | |
196 | get_monthly_active_count. We can't do that here, because we are running | |
197 | in a database thread rather than the main thread, and we can't call | |
198 | txn.call_after because txn may not be a LoggingTransaction. | |
199 | ||
200 | Args: | |
201 | txn (cursor): | |
202 | user_id (str): user to add/update | |
203 | ||
204 | Returns: | |
205 | bool: True if a new entry was created, False if an | |
206 | existing one was updated. | |
173 | 207 | """ |
174 | 208 | # Am consciously deciding to lock the table on the basis that is ought |
175 | 209 | # never be a big table and alternative approaches (batching multiple |
176 | 210 | # upserts into a single txn) introduced a lot of extra complexity. |
177 | 211 | # See https://github.com/matrix-org/synapse/issues/3854 for more |
178 | is_insert = yield self._simple_upsert( | |
179 | desc="upsert_monthly_active_user", | |
212 | is_insert = self._simple_upsert_txn( | |
213 | txn, | |
180 | 214 | table="monthly_active_users", |
181 | 215 | keyvalues={ |
182 | 216 | "user_id": user_id, |
185 | 219 | "timestamp": int(self._clock.time_msec()), |
186 | 220 | }, |
187 | 221 | ) |
188 | if is_insert: | |
189 | self.user_last_seen_monthly_active.invalidate((user_id,)) | |
190 | self.get_monthly_active_count.invalidate(()) | |
222 | ||
223 | return is_insert | |
191 | 224 | |
192 | 225 | @cached(num_args=1) |
193 | 226 | def user_last_seen_monthly_active(self, user_id): |
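_initialise_reserved_users above receives hs.config.mau_limits_reserved_threepids already sliced to max_mau_value at the call site, and only reads the "medium" and "address" keys of each entry. A hypothetical sketch of that shape and of the slice:

    # Hypothetical reserved-threepid entries; only "medium" and "address" are read.
    mau_limits_reserved_threepids = [
        {"medium": "email", "address": "support@example.com"},
        {"medium": "email", "address": "oncall@example.com"},
    ]

    max_mau_value = 1
    # Mirrors the call-site slice: never reserve more users than the MAU cap.
    print([tp["address"] for tp in mau_limits_reserved_threepids[:max_mau_value]])
    # ['support@example.com']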
28 | 28 | logger = logging.getLogger(__name__) |
29 | 29 | |
30 | 30 | if six.PY2: |
31 | db_binary_type = buffer | |
31 | db_binary_type = six.moves.builtins.buffer | |
32 | 32 | else: |
33 | 33 | db_binary_type = memoryview |
34 | 34 |
473 | 473 | |
474 | 474 | @defer.inlineCallbacks |
475 | 475 | def get_user_id_by_threepid(self, medium, address): |
476 | ret = yield self._simple_select_one( | |
476 | """Returns user id from threepid | |
477 | ||
478 | Args: | |
479 | medium (str): threepid medium e.g. email | |
480 | address (str): threepid address e.g. me@example.com | |
481 | ||
482 | Returns: | |
483 | Deferred[str|None]: user id or None if no user id/threepid mapping exists | |
484 | """ | |
485 | user_id = yield self.runInteraction( | |
486 | "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, | |
487 | medium, address | |
488 | ) | |
489 | defer.returnValue(user_id) | |
490 | ||
491 | def get_user_id_by_threepid_txn(self, txn, medium, address): | |
492 | """Returns user id from threepid | |
493 | ||
494 | Args: | |
495 | txn (cursor): | |
496 | medium (str): threepid medium e.g. email | |
497 | address (str): threepid address e.g. me@example.com | |
498 | ||
499 | Returns: | |
500 | str|None: user id or None if no user id/threepid mapping exists | |
501 | """ | |
502 | ret = self._simple_select_one_txn( | |
503 | txn, | |
477 | 504 | "user_threepids", |
478 | 505 | { |
479 | 506 | "medium": medium, |
480 | 507 | "address": address |
481 | 508 | }, |
482 | ['user_id'], True, 'get_user_id_by_threepid' | |
509 | ['user_id'], True | |
483 | 510 | ) |
484 | 511 | if ret: |
485 | defer.returnValue(ret['user_id']) | |
486 | defer.returnValue(None) | |
512 | return ret['user_id'] | |
513 | return None | |
487 | 514 | |
488 | 515 | def user_delete_threepid(self, user_id, medium, address): |
489 | 516 | return self._simple_delete( |
566 | 593 | def _find_next_generated_user_id(txn): |
567 | 594 | txn.execute("SELECT name FROM users") |
568 | 595 | |
569 | regex = re.compile("^@(\d+):") | |
596 | regex = re.compile(r"^@(\d+):") | |
570 | 597 | |
571 | 598 | found = set() |
572 | 599 |
26 | 26 | # py2 sqlite has buffer hardcoded as only binary type, so we must use it, |
27 | 27 | # despite being deprecated and removed in favor of memoryview |
28 | 28 | if six.PY2: |
29 | db_binary_type = buffer | |
29 | db_binary_type = six.moves.builtins.buffer | |
30 | 30 | else: |
31 | 31 | db_binary_type = memoryview |
32 | 32 |
18 | 18 | from six import iteritems, itervalues |
19 | 19 | from six.moves import range |
20 | 20 | |
21 | import attr | |
22 | ||
21 | 23 | from twisted.internet import defer |
22 | 24 | |
23 | 25 | from synapse.api.constants import EventTypes |
45 | 47 | |
46 | 48 | def __len__(self): |
47 | 49 | return len(self.delta_ids) if self.delta_ids else 0 |
50 | ||
51 | ||
52 | @attr.s(slots=True) | |
53 | class StateFilter(object): | |
54 | """A filter used when querying for state. | |
55 | ||
56 | Attributes: | |
57 | types (dict[str, set[str]|None]): Map from type to set of state keys (or | |
58 | None). This specifies which state_keys for the given type to fetch | |
59 | from the DB. If None then all events with that type are fetched. If | |
60 | the set is empty then no events with that type are fetched. | |
61 | include_others (bool): Whether to fetch events with types that do not | |
62 | appear in `types`. | |
63 | """ | |
64 | ||
65 | types = attr.ib() | |
66 | include_others = attr.ib(default=False) | |
67 | ||
68 | def __attrs_post_init__(self): | |
69 | # If `include_others` is set we canonicalise the filter by removing | |
70 | # wildcards from the types dictionary | |
71 | if self.include_others: | |
72 | self.types = { | |
73 | k: v for k, v in iteritems(self.types) | |
74 | if v is not None | |
75 | } | |
76 | ||
77 | @staticmethod | |
78 | def all(): | |
79 | """Creates a filter that fetches everything. | |
80 | ||
81 | Returns: | |
82 | StateFilter | |
83 | """ | |
84 | return StateFilter(types={}, include_others=True) | |
85 | ||
86 | @staticmethod | |
87 | def none(): | |
88 | """Creates a filter that fetches nothing. | |
89 | ||
90 | Returns: | |
91 | StateFilter | |
92 | """ | |
93 | return StateFilter(types={}, include_others=False) | |
94 | ||
95 | @staticmethod | |
96 | def from_types(types): | |
97 | """Creates a filter that only fetches the given types | |
98 | ||
99 | Args: | |
100 | types (Iterable[tuple[str, str|None]]): A list of type and state | |
101 | keys to fetch. A state_key of None fetches everything for | |
102 | that type | |
103 | ||
104 | Returns: | |
105 | StateFilter | |
106 | """ | |
107 | type_dict = {} | |
108 | for typ, s in types: | |
109 | if typ in type_dict: | |
110 | if type_dict[typ] is None: | |
111 | continue | |
112 | ||
113 | if s is None: | |
114 | type_dict[typ] = None | |
115 | continue | |
116 | ||
117 | type_dict.setdefault(typ, set()).add(s) | |
118 | ||
119 | return StateFilter(types=type_dict) | |
120 | ||
121 | @staticmethod | |
122 | def from_lazy_load_member_list(members): | |
123 | """Creates a filter that returns all non-member events, plus the member | |
124 | events for the given users | |
125 | ||
126 | Args: | |
127 | members (iterable[str]): Set of user IDs | |
128 | ||
129 | Returns: | |
130 | StateFilter | |
131 | """ | |
132 | return StateFilter( | |
133 | types={EventTypes.Member: set(members)}, | |
134 | include_others=True, | |
135 | ) | |
136 | ||
137 | def return_expanded(self): | |
138 | """Creates a new StateFilter where type wild cards have been removed | |
139 | (except for memberships). The returned filter is a superset of the | |
140 | current one, i.e. anything that passes the current filter will pass | |
141 | the returned filter. | |
142 | ||
143 | This helps the caching as the DictionaryCache knows if it has *all* the | |
144 | state, but does not know if it has all of the keys of a particular type, | |
145 | which makes wildcard lookups expensive unless we have a complete cache. | |
146 | Hence, if we are doing a wildcard lookup, populate the cache fully so | |
147 | that we can do an efficient lookup next time. | |
148 | ||
149 | Note that since we have two caches, one for membership events and one for | |
150 | other events, we can be a bit more clever than simply returning | |
151 | `StateFilter.all()` if `has_wildcards()` is True. | |
152 | ||
153 | We return a StateFilter where: | |
154 | 1. the list of membership events to return is the same | |
155 | 2. if there is a wildcard that matches non-member events we | |
156 | return all non-member events | |
157 | ||
158 | Returns: | |
159 | StateFilter | |
160 | """ | |
161 | ||
162 | if self.is_full(): | |
163 | # If we're going to return everything then there's nothing to do | |
164 | return self | |
165 | ||
166 | if not self.has_wildcards(): | |
167 | # If there are no wild cards, there's nothing to do | |
168 | return self | |
169 | ||
170 | if EventTypes.Member in self.types: | |
171 | get_all_members = self.types[EventTypes.Member] is None | |
172 | else: | |
173 | get_all_members = self.include_others | |
174 | ||
175 | has_non_member_wildcard = self.include_others or any( | |
176 | state_keys is None | |
177 | for t, state_keys in iteritems(self.types) | |
178 | if t != EventTypes.Member | |
179 | ) | |
180 | ||
181 | if not has_non_member_wildcard: | |
182 | # If there are no non-member wild cards we can just return ourselves | |
183 | return self | |
184 | ||
185 | if get_all_members: | |
186 | # We want to return everything. | |
187 | return StateFilter.all() | |
188 | else: | |
189 | # We want to return all non-members, but only particular | |
190 | # memberships | |
191 | return StateFilter( | |
192 | types={EventTypes.Member: self.types[EventTypes.Member]}, | |
193 | include_others=True, | |
194 | ) | |
195 | ||
196 | def make_sql_filter_clause(self): | |
197 | """Converts the filter to an SQL clause. | |
198 | ||
199 | For example: | |
200 | ||
201 | f = StateFilter.from_types([("m.room.create", "")]) | |
202 | clause, args = f.make_sql_filter_clause() | |
203 | clause == "(type = ? AND state_key = ?)" | |
204 | args == ['m.room.create', ''] | |
205 | ||
206 | ||
207 | Returns: | |
208 | tuple[str, list]: The SQL string (may be empty) and arguments. An | |
209 | empty SQL string is returned when the filter matches everything | |
210 | (i.e. is "full"). | |
211 | """ | |
212 | ||
213 | where_clause = "" | |
214 | where_args = [] | |
215 | ||
216 | if self.is_full(): | |
217 | return where_clause, where_args | |
218 | ||
219 | if not self.include_others and not self.types: | |
220 | # i.e. this is an empty filter, so we need to return a clause that | |
221 | # will match nothing | |
222 | return "1 = 2", [] | |
223 | ||
224 | # First we build up a list of clauses for each type/state_key combo | |
225 | clauses = [] | |
226 | for etype, state_keys in iteritems(self.types): | |
227 | if state_keys is None: | |
228 | clauses.append("(type = ?)") | |
229 | where_args.append(etype) | |
230 | continue | |
231 | ||
232 | for state_key in state_keys: | |
233 | clauses.append("(type = ? AND state_key = ?)") | |
234 | where_args.extend((etype, state_key)) | |
235 | ||
236 | # This will match anything that appears in `self.types` | |
237 | where_clause = " OR ".join(clauses) | |
238 | ||
239 | # If we want to include stuff that's not in the types dict then we add | |
240 | # an `OR type NOT IN (...)` clause to the end. | |
241 | if self.include_others: | |
242 | if where_clause: | |
243 | where_clause += " OR " | |
244 | ||
245 | where_clause += "type NOT IN (%s)" % ( | |
246 | ",".join(["?"] * len(self.types)), | |
247 | ) | |
248 | where_args.extend(self.types) | |
249 | ||
250 | return where_clause, where_args | |
251 | ||
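The docstring example covers concrete types only; for a filter with `include_others=True` the clause also gains a `type NOT IN (...)` tail. A sketch of the expected output, following the logic above:

    f = StateFilter(types={"m.room.member": None}, include_others=True)
    clause, args = f.make_sql_filter_clause()
    # clause == "(type = ?) OR type NOT IN (?)"
    # args   == ['m.room.member', 'm.room.member']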
252 | def max_entries_returned(self): | |
253 | """Returns the maximum number of entries this filter will return if | |
254 | known, otherwise returns None. | |
255 | ||
256 | For example a simple state filter asking for `("m.room.create", "")` | |
257 | will return 1, whereas the default state filter will return None. | |
258 | ||
259 | This is used to bail out early if the right number of entries have been | |
260 | fetched. | |
261 | """ | |
262 | if self.has_wildcards(): | |
263 | return None | |
264 | ||
265 | return len(self.concrete_types()) | |
266 | ||
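For instance (hypothetical filters, built with the `from_types` constructor shown earlier):

    StateFilter.from_types(
        [("m.room.create", ""), ("m.room.name", "")]
    ).max_entries_returned()                    # 2: two concrete keys, no wildcards

    StateFilter.all().max_entries_returned()    # None: wildcard, so unbounded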
267 | def filter_state(self, state_dict): | |
268 | """Returns the state filtered with by this StateFilter | |
269 | ||
270 | Args: | |
271 | state_dict (dict[tuple[str, str], Any]): The state map to filter | |
272 | ||
273 | Returns: | |
274 | dict[tuple[str, str], Any]: The filtered state map | |
275 | """ | |
276 | if self.is_full(): | |
277 | return dict(state_dict) | |
278 | ||
279 | filtered_state = {} | |
280 | for k, v in iteritems(state_dict): | |
281 | typ, state_key = k | |
282 | if typ in self.types: | |
283 | state_keys = self.types[typ] | |
284 | if state_keys is None or state_key in state_keys: | |
285 | filtered_state[k] = v | |
286 | elif self.include_others: | |
287 | filtered_state[k] = v | |
288 | ||
289 | return filtered_state | |
290 | ||
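A minimal sketch of the filtering (the event IDs are hypothetical placeholders):

    f = StateFilter.from_types([("m.room.name", "")])
    state = {
        ("m.room.name", ""): "$name_event",
        ("m.room.member", "@bob:example.com"): "$member_event",
    }
    f.filter_state(state)   # {("m.room.name", ""): "$name_event"}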
291 | def is_full(self): | |
292 | """Whether this filter fetches everything or not | |
293 | ||
294 | Returns: | |
295 | bool | |
296 | """ | |
297 | return self.include_others and not self.types | |
298 | ||
299 | def has_wildcards(self): | |
300 | """Whether the filter includes wildcards or is attempting to fetch | |
301 | specific state. | |
302 | ||
303 | Returns: | |
304 | bool | |
305 | """ | |
306 | ||
307 | return ( | |
308 | self.include_others | |
309 | or any( | |
310 | state_keys is None | |
311 | for state_keys in itervalues(self.types) | |
312 | ) | |
313 | ) | |
314 | ||
315 | def concrete_types(self): | |
316 | """Returns a list of concrete type/state_keys (i.e. not None) that | |
317 | will be fetched. This will be a complete list if `has_wildcards` | |
318 | returns False, but otherwise will be a subset (or even empty). | |
319 | ||
320 | Returns: | |
321 | list[tuple[str,str]] | |
322 | """ | |
323 | return [ | |
324 | (t, s) | |
325 | for t, state_keys in iteritems(self.types) | |
326 | if state_keys is not None | |
327 | for s in state_keys | |
328 | ] | |
329 | ||
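To illustrate the two helpers above with a hypothetical filter:

    f = StateFilter.from_types([("m.room.member", None), ("m.room.create", "")])
    f.has_wildcards()     # True: the member entry matches any state_key
    f.concrete_types()    # [("m.room.create", "")]: only the fully-specified key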
330 | def get_member_split(self): | |
331 | """Return the filter split into two: one which assumes it's exclusively | |
332 | matching against member state, and one which assumes it's matching | |
333 | against non member state. | |
334 | ||
335 | This is useful because the returned filters give correct results for | |
336 | `is_full()`, `has_wildcards()`, etc., when operating against maps that | |
337 | either exclusively contain member events or only contain non-member | |
338 | events. (Which is the case when dealing with the member vs non-member | |
339 | state caches). | |
340 | ||
341 | Returns: | |
342 | tuple[StateFilter, StateFilter]: The member and non member filters | |
343 | """ | |
344 | ||
345 | if EventTypes.Member in self.types: | |
346 | state_keys = self.types[EventTypes.Member] | |
347 | if state_keys is None: | |
348 | member_filter = StateFilter.all() | |
349 | else: | |
350 | member_filter = StateFilter({EventTypes.Member: state_keys}) | |
351 | elif self.include_others: | |
352 | member_filter = StateFilter.all() | |
353 | else: | |
354 | member_filter = StateFilter.none() | |
355 | ||
356 | non_member_filter = StateFilter( | |
357 | types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member}, | |
358 | include_others=self.include_others, | |
359 | ) | |
360 | ||
361 | return member_filter, non_member_filter | |
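A sketch of the split, assuming `EventTypes.Member` is "m.room.member" (the user ID is hypothetical):

    f = StateFilter.from_types([
        ("m.room.member", "@alice:example.com"),
        ("m.room.create", ""),
    ])
    member_filter, non_member_filter = f.get_member_split()
    # member_filter matches only @alice's membership event;
    # non_member_filter matches only ("m.room.create", "") and nothing else.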
48 | 362 | |
49 | 363 | |
50 | 364 | # this inherits from EventsWorkerStore because it calls self.get_events |
151 | 465 | ) |
152 | 466 | |
153 | 467 | # FIXME: how should this be cached? |
154 | def get_filtered_current_state_ids(self, room_id, types, filtered_types=None): | |
468 | def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): | |
155 | 469 | """Get the current state event of a given type for a room based on the |
156 | 470 | current_state_events table. This may not be as up-to-date as the result |
157 | 471 | of doing a fresh state resolution as per state_handler.get_current_state |
472 | ||
158 | 473 | Args: |
159 | 474 | room_id (str) |
160 | types (list[(Str, (Str|None))]): List of (type, state_key) tuples | |
161 | which are used to filter the state fetched. `state_key` may be | |
162 | None, which matches any `state_key` | |
163 | filtered_types (list[Str]|None): List of types to apply the above filter to. | |
164 | Returns: | |
165 | deferred: dict of (type, state_key) -> event | |
166 | """ | |
167 | ||
168 | include_other_types = False if filtered_types is None else True | |
475 | state_filter (StateFilter): The state filter used to fetch state | |
476 | from the database. | |
477 | ||
478 | Returns: | |
479 | Deferred[dict[tuple[str, str], str]]: Map from type/state_key to | |
480 | event ID. | |
481 | """ | |
169 | 482 | |
170 | 483 | def _get_filtered_current_state_ids_txn(txn): |
171 | 484 | results = {} |
172 | sql = """SELECT type, state_key, event_id FROM current_state_events | |
173 | WHERE room_id = ? %s""" | |
174 | # Turns out that postgres doesn't like doing a list of OR's and | |
175 | # is about 1000x slower, so we just issue a query for each specific | |
176 | # type seperately. | |
177 | if types: | |
178 | clause_to_args = [ | |
179 | ( | |
180 | "AND type = ? AND state_key = ?", | |
181 | (etype, state_key) | |
182 | ) if state_key is not None else ( | |
183 | "AND type = ?", | |
184 | (etype,) | |
185 | ) | |
186 | for etype, state_key in types | |
187 | ] | |
188 | ||
189 | if include_other_types: | |
190 | unique_types = set(filtered_types) | |
191 | clause_to_args.append( | |
192 | ( | |
193 | "AND type <> ? " * len(unique_types), | |
194 | list(unique_types) | |
195 | ) | |
196 | ) | |
197 | else: | |
198 | # If types is None we fetch all the state, and so just use an | |
199 | # empty where clause with no extra args. | |
200 | clause_to_args = [("", [])] | |
201 | for where_clause, where_args in clause_to_args: | |
202 | args = [room_id] | |
203 | args.extend(where_args) | |
204 | txn.execute(sql % (where_clause,), args) | |
205 | for row in txn: | |
206 | typ, state_key, event_id = row | |
207 | key = (intern_string(typ), intern_string(state_key)) | |
208 | results[key] = event_id | |
485 | sql = """ | |
486 | SELECT type, state_key, event_id FROM current_state_events | |
487 | WHERE room_id = ? | |
488 | """ | |
489 | ||
490 | where_clause, where_args = state_filter.make_sql_filter_clause() | |
491 | ||
492 | if where_clause: | |
493 | sql += " AND (%s)" % (where_clause,) | |
494 | ||
495 | args = [room_id] | |
496 | args.extend(where_args) | |
497 | txn.execute(sql, args) | |
498 | for row in txn: | |
499 | typ, state_key, event_id = row | |
500 | key = (intern_string(typ), intern_string(state_key)) | |
501 | results[key] = event_id | |
502 | ||
209 | 503 | return results |
210 | 504 | |
211 | 505 | return self.runInteraction( |
321 | 615 | }) |
322 | 616 | |
323 | 617 | @defer.inlineCallbacks |
324 | def _get_state_groups_from_groups(self, groups, types, members=None): | |
618 | def _get_state_groups_from_groups(self, groups, state_filter): | |
325 | 619 | """Returns the state groups for a given set of groups, filtering on |
326 | 620 | types of state events. |
327 | 621 | |
328 | 622 | Args: |
329 | 623 | groups(list[int]): list of state group IDs to query |
330 | types (Iterable[str, str|None]|None): list of 2-tuples of the form | |
331 | (`type`, `state_key`), where a `state_key` of `None` matches all | |
332 | state_keys for the `type`. If None, all types are returned. | |
333 | members (bool|None): If not None, then, in addition to any filtering | |
334 | implied by types, the results are also filtered to only include | |
335 | member events (if True), or to exclude member events (if False) | |
336 | ||
337 | Returns: | |
624 | state_filter (StateFilter): The state filter used to fetch state | |
625 | from the database. | |
338 | 626 | Returns: |
339 | 627 | Deferred[dict[int, dict[tuple[str, str], str]]]: |
340 | 628 | dict of state_group_id -> (dict of (type, state_key) -> event id) |
345 | 633 | for chunk in chunks: |
346 | 634 | res = yield self.runInteraction( |
347 | 635 | "_get_state_groups_from_groups", |
348 | self._get_state_groups_from_groups_txn, chunk, types, members, | |
636 | self._get_state_groups_from_groups_txn, chunk, state_filter, | |
349 | 637 | ) |
350 | 638 | results.update(res) |
351 | 639 | |
352 | 640 | defer.returnValue(results) |
353 | 641 | |
354 | 642 | def _get_state_groups_from_groups_txn( |
355 | self, txn, groups, types=None, members=None, | |
643 | self, txn, groups, state_filter=StateFilter.all(), | |
356 | 644 | ): |
357 | 645 | results = {group: {} for group in groups} |
358 | 646 | |
359 | if types is not None: | |
360 | types = list(set(types)) # deduplicate types list | |
647 | where_clause, where_args = state_filter.make_sql_filter_clause() | |
648 | ||
649 | # Unless the filter clause is empty, we're going to append it after an | |
650 | # existing where clause | |
651 | if where_clause: | |
652 | where_clause = " AND (%s)" % (where_clause,) | |
361 | 653 | |
362 | 654 | if isinstance(self.database_engine, PostgresEngine): |
363 | 655 | # Temporarily disable sequential scans in this transaction. This is |
373 | 665 | # group for the given type, state_key. |
374 | 666 | # This may return multiple rows per (type, state_key), but last_value |
375 | 667 | # should be the same. |
376 | sql = (""" | |
668 | sql = """ | |
377 | 669 | WITH RECURSIVE state(state_group) AS ( |
378 | 670 | VALUES(?::bigint) |
379 | 671 | UNION ALL |
380 | 672 | SELECT prev_state_group FROM state_group_edges e, state s |
381 | 673 | WHERE s.state_group = e.state_group |
382 | 674 | ) |
383 | SELECT type, state_key, last_value(event_id) OVER ( | |
675 | SELECT DISTINCT type, state_key, last_value(event_id) OVER ( | |
384 | 676 | PARTITION BY type, state_key ORDER BY state_group ASC |
385 | 677 | ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING |
386 | 678 | ) AS event_id FROM state_groups_state |
387 | 679 | WHERE state_group IN ( |
388 | 680 | SELECT state_group FROM state |
389 | 681 | ) |
390 | %s | |
391 | """) | |
392 | ||
393 | if members is True: | |
394 | sql += " AND type = '%s'" % (EventTypes.Member,) | |
395 | elif members is False: | |
396 | sql += " AND type <> '%s'" % (EventTypes.Member,) | |
397 | ||
398 | # Turns out that postgres doesn't like doing a list of OR's and | |
399 | # is about 1000x slower, so we just issue a query for each specific | |
400 | # type seperately. | |
401 | if types is not None: | |
402 | clause_to_args = [ | |
403 | ( | |
404 | "AND type = ? AND state_key = ?", | |
405 | (etype, state_key) | |
406 | ) if state_key is not None else ( | |
407 | "AND type = ?", | |
408 | (etype,) | |
409 | ) | |
410 | for etype, state_key in types | |
411 | ] | |
412 | else: | |
413 | # If types is None we fetch all the state, and so just use an | |
414 | # empty where clause with no extra args. | |
415 | clause_to_args = [("", [])] | |
416 | ||
417 | for where_clause, where_args in clause_to_args: | |
418 | for group in groups: | |
419 | args = [group] | |
420 | args.extend(where_args) | |
421 | ||
422 | txn.execute(sql % (where_clause,), args) | |
423 | for row in txn: | |
424 | typ, state_key, event_id = row | |
425 | key = (typ, state_key) | |
426 | results[group][key] = event_id | |
682 | """ | |
683 | ||
684 | for group in groups: | |
685 | args = [group] | |
686 | args.extend(where_args) | |
687 | ||
688 | txn.execute(sql + where_clause, args) | |
689 | for row in txn: | |
690 | typ, state_key, event_id = row | |
691 | key = (typ, state_key) | |
692 | results[group][key] = event_id | |
427 | 693 | else: |
428 | where_args = [] | |
429 | where_clauses = [] | |
430 | wildcard_types = False | |
431 | if types is not None: | |
432 | for typ in types: | |
433 | if typ[1] is None: | |
434 | where_clauses.append("(type = ?)") | |
435 | where_args.append(typ[0]) | |
436 | wildcard_types = True | |
437 | else: | |
438 | where_clauses.append("(type = ? AND state_key = ?)") | |
439 | where_args.extend([typ[0], typ[1]]) | |
440 | ||
441 | where_clause = "AND (%s)" % (" OR ".join(where_clauses)) | |
442 | else: | |
443 | where_clause = "" | |
444 | ||
445 | if members is True: | |
446 | where_clause += " AND type = '%s'" % EventTypes.Member | |
447 | elif members is False: | |
448 | where_clause += " AND type <> '%s'" % EventTypes.Member | |
694 | max_entries_returned = state_filter.max_entries_returned() | |
449 | 695 | |
450 | 696 | # We don't use WITH RECURSIVE on sqlite3 as there are distributions |
451 | 697 | # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) |
459 | 705 | # without the right indices (which we can't add until |
460 | 706 | # after we finish deduping state, which requires this func) |
461 | 707 | args = [next_group] |
462 | if types: | |
463 | args.extend(where_args) | |
708 | args.extend(where_args) | |
464 | 709 | |
465 | 710 | txn.execute( |
466 | 711 | "SELECT type, state_key, event_id FROM state_groups_state" |
467 | " WHERE state_group = ? %s" % (where_clause,), | |
712 | " WHERE state_group = ? " + where_clause, | |
468 | 713 | args |
469 | 714 | ) |
470 | 715 | results[group].update( |
480 | 725 | # wildcards (i.e. Nones) in which case we have to do an exhaustive |
481 | 726 | # search |
482 | 727 | if ( |
483 | types is not None and | |
484 | not wildcard_types and | |
485 | len(results[group]) == len(types) | |
728 | max_entries_returned is not None and | |
729 | len(results[group]) == max_entries_returned | |
486 | 730 | ): |
487 | 731 | break |
488 | 732 | |
497 | 741 | return results |
498 | 742 | |
499 | 743 | @defer.inlineCallbacks |
500 | def get_state_for_events(self, event_ids, types, filtered_types=None): | |
744 | def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): | |
501 | 745 | """Given a list of event_ids and type tuples, return a list of state |
502 | dicts for each event. The state dicts will only have the type/state_keys | |
503 | that are in the `types` list. | |
746 | dicts for each event. | |
504 | 747 | |
505 | 748 | Args: |
506 | 749 | event_ids (list[string]) |
507 | types (list[(str, str|None)]|None): List of (type, state_key) tuples | |
508 | which are used to filter the state fetched. If `state_key` is None, | |
509 | all events are returned of the given type. | |
510 | May be None, which matches any key. | |
511 | filtered_types(list[str]|None): Only apply filtering via `types` to this | |
512 | list of event types. Other types of events are returned unfiltered. | |
513 | If None, `types` filtering is applied to all events. | |
750 | state_filter (StateFilter): The state filter used to fetch state | |
751 | from the database. | |
514 | 752 | |
515 | 753 | Returns: |
516 | 754 | deferred: A dict of (event_id) -> (type, state_key) -> [state_events] |
520 | 758 | ) |
521 | 759 | |
522 | 760 | groups = set(itervalues(event_to_groups)) |
523 | group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) | |
761 | group_to_state = yield self._get_state_for_groups(groups, state_filter) | |
524 | 762 | |
525 | 763 | state_event_map = yield self.get_events( |
526 | 764 | [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], |
539 | 777 | defer.returnValue({event: event_to_state[event] for event in event_ids}) |
540 | 778 | |
541 | 779 | @defer.inlineCallbacks |
542 | def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None): | |
780 | def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): | |
543 | 781 | """ |
544 | 782 | Get the state dicts corresponding to a list of events, containing the event_ids |
545 | 783 | of the state events (as opposed to the events themselves) |
546 | 784 | |
547 | 785 | Args: |
548 | 786 | event_ids(list(str)): events whose state should be returned |
549 | types(list[(str, str|None)]|None): List of (type, state_key) tuples | |
550 | which are used to filter the state fetched. If `state_key` is None, | |
551 | all events are returned of the given type. | |
552 | May be None, which matches any key. | |
553 | filtered_types(list[str]|None): Only apply filtering via `types` to this | |
554 | list of event types. Other types of events are returned unfiltered. | |
555 | If None, `types` filtering is applied to all events. | |
787 | state_filter (StateFilter): The state filter used to fetch state | |
788 | from the database. | |
556 | 789 | |
557 | 790 | Returns: |
558 | 791 | A deferred dict from event_id -> (type, state_key) -> event_id |
562 | 795 | ) |
563 | 796 | |
564 | 797 | groups = set(itervalues(event_to_groups)) |
565 | group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) | |
798 | group_to_state = yield self._get_state_for_groups(groups, state_filter) | |
566 | 799 | |
567 | 800 | event_to_state = { |
568 | 801 | event_id: group_to_state[group] |
572 | 805 | defer.returnValue({event: event_to_state[event] for event in event_ids}) |
573 | 806 | |
574 | 807 | @defer.inlineCallbacks |
575 | def get_state_for_event(self, event_id, types=None, filtered_types=None): | |
808 | def get_state_for_event(self, event_id, state_filter=StateFilter.all()): | |
576 | 809 | """ |
577 | 810 | Get the state dict corresponding to a particular event |
578 | 811 | |
579 | 812 | Args: |
580 | 813 | event_id(str): event whose state should be returned |
581 | types(list[(str, str|None)]|None): List of (type, state_key) tuples | |
582 | which are used to filter the state fetched. If `state_key` is None, | |
583 | all events are returned of the given type. | |
584 | May be None, which matches any key. | |
585 | filtered_types(list[str]|None): Only apply filtering via `types` to this | |
586 | list of event types. Other types of events are returned unfiltered. | |
587 | If None, `types` filtering is applied to all events. | |
814 | state_filter (StateFilter): The state filter used to fetch state | |
815 | from the database. | |
588 | 816 | |
589 | 817 | Returns: |
590 | 818 | A deferred dict from (type, state_key) -> state_event |
591 | 819 | """ |
592 | state_map = yield self.get_state_for_events([event_id], types, filtered_types) | |
820 | state_map = yield self.get_state_for_events([event_id], state_filter) | |
593 | 821 | defer.returnValue(state_map[event_id]) |
594 | 822 | |
595 | 823 | @defer.inlineCallbacks |
596 | def get_state_ids_for_event(self, event_id, types=None, filtered_types=None): | |
824 | def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): | |
597 | 825 | """ |
598 | 826 | Get the state dict corresponding to a particular event |
599 | 827 | |
600 | 828 | Args: |
601 | 829 | event_id(str): event whose state should be returned |
602 | types(list[(str, str|None)]|None): List of (type, state_key) tuples | |
603 | which are used to filter the state fetched. If `state_key` is None, | |
604 | all events are returned of the given type. | |
605 | May be None, which matches any key. | |
606 | filtered_types(list[str]|None): Only apply filtering via `types` to this | |
607 | list of event types. Other types of events are returned unfiltered. | |
608 | If None, `types` filtering is applied to all events. | |
830 | state_filter (StateFilter): The state filter used to fetch state | |
831 | from the database. | |
609 | 832 | |
610 | 833 | Returns: |
611 | 834 | A deferred dict from (type, state_key) -> state_event |
612 | 835 | """ |
613 | state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types) | |
836 | state_map = yield self.get_state_ids_for_events([event_id], state_filter) | |
614 | 837 | defer.returnValue(state_map[event_id]) |
615 | 838 | |
616 | 839 | @cached(max_entries=50000) |
641 | 864 | |
642 | 865 | defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) |
643 | 866 | |
644 | def _get_some_state_from_cache(self, cache, group, types, filtered_types=None): | |
867 | def _get_state_for_group_using_cache(self, cache, group, state_filter): | |
645 | 868 | """Checks if group is in cache. See `_get_state_for_groups` |
646 | 869 | |
647 | 870 | Args: |
648 | 871 | cache(DictionaryCache): the state group cache to use |
649 | 872 | group(int): The state group to lookup |
650 | types(list[str, str|None]): List of 2-tuples of the form | |
651 | (`type`, `state_key`), where a `state_key` of `None` matches all | |
652 | state_keys for the `type`. | |
653 | filtered_types(list[str]|None): Only apply filtering via `types` to this | |
654 | list of event types. Other types of events are returned unfiltered. | |
655 | If None, `types` filtering is applied to all events. | |
873 | state_filter (StateFilter): The state filter used to fetch state | |
874 | from the database. | |
656 | 875 | |
657 | 876 | Returns 2-tuple (`state_dict`, `got_all`). |
658 | 877 | `got_all` is a bool indicating if we successfully retrieved all |
661 | 880 | """ |
662 | 881 | is_all, known_absent, state_dict_ids = cache.get(group) |
663 | 882 | |
664 | type_to_key = {} | |
883 | if is_all or state_filter.is_full(): | |
884 | # Either we have everything or want everything, either way | |
885 | # `is_all` tells us whether we've gotten everything. | |
886 | return state_filter.filter_state(state_dict_ids), is_all | |
665 | 887 | |
666 | 888 | # tracks whether any of our requested types are missing from the cache |
667 | 889 | missing_types = False |
668 | 890 | |
669 | for typ, state_key in types: | |
670 | key = (typ, state_key) | |
671 | ||
672 | if ( | |
673 | state_key is None or | |
674 | (filtered_types is not None and typ not in filtered_types) | |
675 | ): | |
676 | type_to_key[typ] = None | |
677 | # we mark the type as missing from the cache because | |
678 | # when the cache was populated it might have been done with a | |
679 | # restricted set of state_keys, so the wildcard will not work | |
680 | # and the cache may be incomplete. | |
681 | missing_types = True | |
682 | else: | |
683 | if type_to_key.get(typ, object()) is not None: | |
684 | type_to_key.setdefault(typ, set()).add(state_key) | |
685 | ||
891 | if state_filter.has_wildcards(): | |
892 | # We don't know if we fetched all the state keys for the types in | |
893 | # the filter that are wildcards, so we have to assume that we may | |
894 | # have missed some. | |
895 | missing_types = True | |
896 | else: | |
897 | # There aren't any wild cards, so `concrete_types()` returns the | |
898 | # complete list of (type, state_key) pairs we want. | |
899 | for key in state_filter.concrete_types(): | |
686 | 900 | if key not in state_dict_ids and key not in known_absent: |
687 | 901 | missing_types = True |
688 | ||
689 | sentinel = object() | |
690 | ||
691 | def include(typ, state_key): | |
692 | valid_state_keys = type_to_key.get(typ, sentinel) | |
693 | if valid_state_keys is sentinel: | |
694 | return filtered_types is not None and typ not in filtered_types | |
695 | if valid_state_keys is None: | |
696 | return True | |
697 | if state_key in valid_state_keys: | |
698 | return True | |
699 | return False | |
700 | ||
701 | got_all = is_all | |
702 | if not got_all: | |
703 | # the cache is incomplete. We may still have got all the results we need, if | |
704 | # we don't have any wildcards in the match list. | |
705 | if not missing_types and filtered_types is None: | |
706 | got_all = True | |
707 | ||
708 | return { | |
709 | k: v for k, v in iteritems(state_dict_ids) | |
710 | if include(k[0], k[1]) | |
711 | }, got_all | |
712 | ||
713 | def _get_all_state_from_cache(self, cache, group): | |
714 | """Checks if group is in cache. See `_get_state_for_groups` | |
715 | ||
716 | Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool | |
717 | indicating if we successfully retrieved all requests state from the | |
718 | cache, if False we need to query the DB for the missing state. | |
719 | ||
720 | Args: | |
721 | cache(DictionaryCache): the state group cache to use | |
722 | group: The state group to lookup | |
723 | """ | |
724 | is_all, _, state_dict_ids = cache.get(group) | |
725 | ||
726 | return state_dict_ids, is_all | |
902 | break | |
903 | ||
904 | return state_filter.filter_state(state_dict_ids), not missing_types | |
727 | 905 | |
728 | 906 | @defer.inlineCallbacks |
729 | def _get_state_for_groups(self, groups, types=None, filtered_types=None): | |
907 | def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): | |
730 | 908 | """Gets the state at each of a list of state groups, optionally |
731 | 909 | filtering by type/state_key |
732 | 910 | |
733 | 911 | Args: |
734 | 912 | groups (iterable[int]): list of state groups for which we want |
735 | 913 | to get the state. |
736 | types (None|iterable[(str, None|str)]): | |
737 | indicates the state type/keys required. If None, the whole | |
738 | state is fetched and returned. | |
739 | ||
740 | Otherwise, each entry should be a `(type, state_key)` tuple to | |
741 | include in the response. A `state_key` of None is a wildcard | |
742 | meaning that we require all state with that type. | |
743 | filtered_types(list[str]|None): Only apply filtering via `types` to this | |
744 | list of event types. Other types of events are returned unfiltered. | |
745 | If None, `types` filtering is applied to all events. | |
746 | ||
914 | state_filter (StateFilter): The state filter used to fetch state | |
915 | from the database. | |
747 | 916 | Returns: |
748 | 917 | Deferred[dict[int, dict[tuple[str, str], str]]]: |
749 | 918 | dict of state_group_id -> (dict of (type, state_key) -> event id) |
750 | 919 | """ |
751 | if types is not None: | |
752 | non_member_types = [t for t in types if t[0] != EventTypes.Member] | |
753 | ||
754 | if filtered_types is not None and EventTypes.Member not in filtered_types: | |
755 | # we want all of the membership events | |
756 | member_types = None | |
757 | else: | |
758 | member_types = [t for t in types if t[0] == EventTypes.Member] | |
759 | ||
760 | else: | |
761 | non_member_types = None | |
762 | member_types = None | |
763 | ||
764 | non_member_state = yield self._get_state_for_groups_using_cache( | |
765 | groups, self._state_group_cache, non_member_types, filtered_types, | |
766 | ) | |
767 | # XXX: we could skip this entirely if member_types is [] | |
768 | member_state = yield self._get_state_for_groups_using_cache( | |
769 | # we set filtered_types=None as member_state only ever contain members. | |
770 | groups, self._state_group_members_cache, member_types, None, | |
771 | ) | |
772 | ||
773 | state = non_member_state | |
920 | ||
921 | member_filter, non_member_filter = state_filter.get_member_split() | |
922 | ||
923 | # Now we look them up in the member and non-member caches | |
924 | non_member_state, incomplete_groups_nm, = ( | |
925 | yield self._get_state_for_groups_using_cache( | |
926 | groups, self._state_group_cache, | |
927 | state_filter=non_member_filter, | |
928 | ) | |
929 | ) | |
930 | ||
931 | member_state, incomplete_groups_m, = ( | |
932 | yield self._get_state_for_groups_using_cache( | |
933 | groups, self._state_group_members_cache, | |
934 | state_filter=member_filter, | |
935 | ) | |
936 | ) | |
937 | ||
938 | state = dict(non_member_state) | |
774 | 939 | for group in groups: |
775 | 940 | state[group].update(member_state[group]) |
776 | 941 | |
942 | # Now fetch any missing groups from the database | |
943 | ||
944 | incomplete_groups = incomplete_groups_m | incomplete_groups_nm | |
945 | ||
946 | if not incomplete_groups: | |
947 | defer.returnValue(state) | |
948 | ||
949 | cache_sequence_nm = self._state_group_cache.sequence | |
950 | cache_sequence_m = self._state_group_members_cache.sequence | |
951 | ||
952 | # Help the cache hit ratio by expanding the filter a bit | |
953 | db_state_filter = state_filter.return_expanded() | |
954 | ||
955 | group_to_state_dict = yield self._get_state_groups_from_groups( | |
956 | list(incomplete_groups), | |
957 | state_filter=db_state_filter, | |
958 | ) | |
959 | ||
960 | # Now let's update the caches | |
961 | self._insert_into_cache( | |
962 | group_to_state_dict, | |
963 | db_state_filter, | |
964 | cache_seq_num_members=cache_sequence_m, | |
965 | cache_seq_num_non_members=cache_sequence_nm, | |
966 | ) | |
967 | ||
968 | # And finally update the result dict, by filtering out any extra | |
969 | # stuff we pulled out of the database. | |
970 | for group, group_state_dict in iteritems(group_to_state_dict): | |
971 | # We just replace any existing entries, as we will have loaded | |
972 | # everything we need from the database anyway. | |
973 | state[group] = state_filter.filter_state(group_state_dict) | |
974 | ||
777 | 975 | defer.returnValue(state) |
778 | 976 | |
779 | @defer.inlineCallbacks | |
780 | 977 | def _get_state_for_groups_using_cache( |
781 | self, groups, cache, types=None, filtered_types=None | |
978 | self, groups, cache, state_filter, | |
782 | 979 | ): |
783 | 980 | """Gets the state at each of a list of state groups, optionally |
784 | 981 | filtering by type/state_key, querying from a specific cache. |
789 | 986 | cache (DictionaryCache): the cache of group ids to state dicts which |
790 | 987 | we will pass through - either the normal state cache or the specific |
791 | 988 | members state cache. |
792 | types (None|iterable[(str, None|str)]): | |
793 | indicates the state type/keys required. If None, the whole | |
794 | state is fetched and returned. | |
795 | ||
796 | Otherwise, each entry should be a `(type, state_key)` tuple to | |
797 | include in the response. A `state_key` of None is a wildcard | |
798 | meaning that we require all state with that type. | |
799 | filtered_types(list[str]|None): Only apply filtering via `types` to this | |
800 | list of event types. Other types of events are returned unfiltered. | |
801 | If None, `types` filtering is applied to all events. | |
802 | ||
803 | Returns: | |
804 | Deferred[dict[int, dict[tuple[str, str], str]]]: | |
805 | dict of state_group_id -> (dict of (type, state_key) -> event id) | |
806 | """ | |
807 | if types: | |
808 | types = frozenset(types) | |
989 | state_filter (StateFilter): The state filter used to fetch state | |
990 | from the database. | |
991 | ||
992 | Returns: | |
993 | tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of | |
994 | dict of state_group_id -> (dict of (type, state_key) -> event id) | |
995 | of entries in the cache, and the state group ids either missing | |
996 | from the cache or incomplete. | |
997 | """ | |
809 | 998 | results = {} |
810 | missing_groups = [] | |
811 | if types is not None: | |
812 | for group in set(groups): | |
813 | state_dict_ids, got_all = self._get_some_state_from_cache( | |
814 | cache, group, types, filtered_types | |
815 | ) | |
816 | results[group] = state_dict_ids | |
817 | ||
818 | if not got_all: | |
819 | missing_groups.append(group) | |
999 | incomplete_groups = set() | |
1000 | for group in set(groups): | |
1001 | state_dict_ids, got_all = self._get_state_for_group_using_cache( | |
1002 | cache, group, state_filter | |
1003 | ) | |
1004 | results[group] = state_dict_ids | |
1005 | ||
1006 | if not got_all: | |
1007 | incomplete_groups.add(group) | |
1008 | ||
1009 | return results, incomplete_groups | |
1010 | ||
1011 | def _insert_into_cache(self, group_to_state_dict, state_filter, | |
1012 | cache_seq_num_members, cache_seq_num_non_members): | |
1013 | """Inserts results from querying the database into the relevant cache. | |
1014 | ||
1015 | Args: | |
1016 | group_to_state_dict (dict): The new entries pulled from database. | |
1017 | Map from state group to state dict | |
1018 | state_filter (StateFilter): The state filter used to fetch state | |
1019 | from the database. | |
1020 | cache_seq_num_members (int): Sequence number of member cache since | |
1021 | last lookup in cache | |
1022 | cache_seq_num_non_members (int): Sequence number of non-member cache since | |
1023 | last lookup in cache | |
1024 | """ | |
1025 | ||
1026 | # We need to work out which types we've fetched from the DB for the | |
1027 | # member vs non-member caches. This should be as accurate as possible, | |
1028 | # but can be an underestimate (e.g. when we have wild cards) | |
1029 | ||
1030 | member_filter, non_member_filter = state_filter.get_member_split() | |
1031 | if member_filter.is_full(): | |
1032 | # We fetched all member events | |
1033 | member_types = None | |
820 | 1034 | else: |
821 | for group in set(groups): | |
822 | state_dict_ids, got_all = self._get_all_state_from_cache( | |
823 | cache, group | |
824 | ) | |
825 | ||
826 | results[group] = state_dict_ids | |
827 | ||
828 | if not got_all: | |
829 | missing_groups.append(group) | |
830 | ||
831 | if missing_groups: | |
832 | # Okay, so we have some missing_types, let's fetch them. | |
833 | cache_seq_num = cache.sequence | |
834 | ||
835 | # the DictionaryCache knows if it has *all* the state, but | |
836 | # does not know if it has all of the keys of a particular type, | |
837 | # which makes wildcard lookups expensive unless we have a complete | |
838 | # cache. Hence, if we are doing a wildcard lookup, populate the | |
839 | # cache fully so that we can do an efficient lookup next time. | |
840 | ||
841 | if filtered_types or (types and any(k is None for (t, k) in types)): | |
842 | types_to_fetch = None | |
843 | else: | |
844 | types_to_fetch = types | |
845 | ||
846 | group_to_state_dict = yield self._get_state_groups_from_groups( | |
847 | missing_groups, types_to_fetch, cache == self._state_group_members_cache, | |
848 | ) | |
849 | ||
850 | for group, group_state_dict in iteritems(group_to_state_dict): | |
851 | state_dict = results[group] | |
852 | ||
853 | # update the result, filtering by `types`. | |
854 | if types: | |
855 | for k, v in iteritems(group_state_dict): | |
856 | (typ, _) = k | |
857 | if ( | |
858 | (k in types or (typ, None) in types) or | |
859 | (filtered_types and typ not in filtered_types) | |
860 | ): | |
861 | state_dict[k] = v | |
1035 | # `concrete_types()` will only return a subset when there are wild | |
1036 | # cards in the filter, but that's fine. | |
1037 | member_types = member_filter.concrete_types() | |
1038 | ||
1039 | if non_member_filter.is_full(): | |
1040 | # We fetched all non member events | |
1041 | non_member_types = None | |
1042 | else: | |
1043 | non_member_types = non_member_filter.concrete_types() | |
1044 | ||
1045 | for group, group_state_dict in iteritems(group_to_state_dict): | |
1046 | state_dict_members = {} | |
1047 | state_dict_non_members = {} | |
1048 | ||
1049 | for k, v in iteritems(group_state_dict): | |
1050 | if k[0] == EventTypes.Member: | |
1051 | state_dict_members[k] = v | |
862 | 1052 | else: |
863 | state_dict.update(group_state_dict) | |
864 | ||
865 | # update the cache with all the things we fetched from the | |
866 | # database. | |
867 | cache.update( | |
868 | cache_seq_num, | |
869 | key=group, | |
870 | value=group_state_dict, | |
871 | fetched_keys=types_to_fetch, | |
872 | ) | |
873 | ||
874 | defer.returnValue(results) | |
1053 | state_dict_non_members[k] = v | |
1054 | ||
1055 | self._state_group_members_cache.update( | |
1056 | cache_seq_num_members, | |
1057 | key=group, | |
1058 | value=state_dict_members, | |
1059 | fetched_keys=member_types, | |
1060 | ) | |
1061 | ||
1062 | self._state_group_cache.update( | |
1063 | cache_seq_num_non_members, | |
1064 | key=group, | |
1065 | value=state_dict_non_members, | |
1066 | fetched_keys=non_member_types, | |
1067 | ) | |
875 | 1068 | |
876 | 1069 | def store_state_group(self, event_id, room_id, prev_group, delta_ids, |
877 | 1070 | current_state_ids): |
1180 | 1373 | continue |
1181 | 1374 | |
1182 | 1375 | prev_state = self._get_state_groups_from_groups_txn( |
1183 | txn, [prev_group], types=None | |
1376 | txn, [prev_group], | |
1184 | 1377 | ) |
1185 | 1378 | prev_state = prev_state[prev_group] |
1186 | 1379 | |
1187 | 1380 | curr_state = self._get_state_groups_from_groups_txn( |
1188 | txn, [state_group], types=None | |
1381 | txn, [state_group], | |
1189 | 1382 | ) |
1190 | 1383 | curr_state = curr_state[state_group] |
1191 | 1384 |
29 | 29 | # py2 sqlite has buffer hardcoded as only binary type, so we must use it, |
30 | 30 | # despite being deprecated and removed in favor of memoryview |
31 | 31 | if six.PY2: |
32 | db_binary_type = buffer | |
32 | db_binary_type = six.moves.builtins.buffer | |
33 | 33 | else: |
34 | 34 | db_binary_type = memoryview |
35 | 35 |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | import logging |
16 | import re | |
16 | 17 | from itertools import islice |
17 | 18 | |
18 | 19 | import attr |
137 | 138 | |
138 | 139 | if not consumeErrors: |
139 | 140 | return failure |
141 | ||
142 | ||
143 | def glob_to_regex(glob): | |
144 | """Converts a glob to a compiled regex object. | |
145 | ||
146 | The regex is anchored at the beginning and end of the string. | |
147 | ||
148 | Args: | |
149 | glob (str) | |
150 | ||
151 | Returns: | |
152 | re.RegexObject | |
153 | """ | |
154 | res = '' | |
155 | for c in glob: | |
156 | if c == '*': | |
157 | res = res + '.*' | |
158 | elif c == '?': | |
159 | res = res + '.' | |
160 | else: | |
161 | res = res + re.escape(c) | |
162 | ||
163 | # \A anchors at start of string, \Z at end of string | |
164 | return re.compile(r"\A" + res + r"\Z", re.IGNORECASE) |
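For example, the helper turns an alias-style glob into an anchored, case-insensitive pattern:

    pattern = glob_to_regex("#unofficial_*")
    bool(pattern.match("#Unofficial_chat:example.com"))   # True (case-insensitive)
    bool(pattern.match("#official:example.com"))          # False (anchored at the start)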
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | import logging |
16 | ||
17 | from six import integer_types | |
16 | 18 | |
17 | 19 | from sortedcontainers import SortedDict |
18 | 20 | |
46 | 48 | def has_entity_changed(self, entity, stream_pos): |
47 | 49 | """Returns True if the entity may have been updated since stream_pos |
48 | 50 | """ |
49 | assert type(stream_pos) is int or type(stream_pos) is long | |
51 | assert type(stream_pos) in integer_types | |
50 | 52 | |
51 | 53 | if stream_pos < self._earliest_known_stream_pos: |
52 | 54 | self.metrics.inc_misses() |
24 | 24 | import logging |
25 | 25 | import threading |
26 | 26 | |
27 | from twisted.internet import defer | |
27 | from twisted.internet import defer, threads | |
28 | 28 | |
29 | 29 | logger = logging.getLogger(__name__) |
30 | 30 | |
561 | 561 | return result |
562 | 562 | |
563 | 563 | |
564 | # modules to ignore in `logcontext_tracer` | |
565 | _to_ignore = [ | |
566 | "synapse.util.logcontext", | |
567 | "synapse.http.server", | |
568 | "synapse.storage._base", | |
569 | "synapse.util.async_helpers", | |
570 | ] | |
571 | ||
572 | ||
573 | def logcontext_tracer(frame, event, arg): | |
574 | """A tracer that logs whenever a logcontext "unexpectedly" changes within | |
575 | a function. Probably inaccurate. | |
576 | ||
577 | Use by calling `sys.settrace(logcontext_tracer)` in the main thread. | |
578 | """ | |
579 | if event == 'call': | |
580 | name = frame.f_globals["__name__"] | |
581 | if name.startswith("synapse"): | |
582 | if name == "synapse.util.logcontext": | |
583 | if frame.f_code.co_name in ["__enter__", "__exit__"]: | |
584 | tracer = frame.f_back.f_trace | |
585 | if tracer: | |
586 | tracer.just_changed = True | |
587 | ||
588 | tracer = frame.f_trace | |
589 | if tracer: | |
590 | return tracer | |
591 | ||
592 | if not any(name.startswith(ig) for ig in _to_ignore): | |
593 | return LineTracer() | |
594 | ||
595 | ||
596 | class LineTracer(object): | |
597 | __slots__ = ["context", "just_changed"] | |
598 | ||
599 | def __init__(self): | |
600 | self.context = LoggingContext.current_context() | |
601 | self.just_changed = False | |
602 | ||
603 | def __call__(self, frame, event, arg): | |
604 | if event in 'line': | |
605 | if self.just_changed: | |
606 | self.context = LoggingContext.current_context() | |
607 | self.just_changed = False | |
608 | else: | |
609 | c = LoggingContext.current_context() | |
610 | if c != self.context: | |
611 | logger.info( | |
612 | "Context changed! %s -> %s, %s, %s", | |
613 | self.context, c, | |
614 | frame.f_code.co_filename, frame.f_lineno | |
615 | ) | |
616 | self.context = c | |
617 | ||
618 | return self | |
564 | def defer_to_thread(reactor, f, *args, **kwargs): | |
565 | """ | |
566 | Calls the function `f` using a thread from the reactor's default threadpool and | |
567 | returns the result as a Deferred. | |
568 | ||
569 | Creates a new logcontext for `f`, which is created as a child of the current | |
570 | logcontext (so its CPU usage metrics will get attributed to the current | |
571 | logcontext). `f` should preserve the logcontext it is given. | |
572 | ||
573 | The result deferred follows the Synapse logcontext rules: you should `yield` | |
574 | on it. | |
575 | ||
576 | Args: | |
577 | reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread | |
578 | the Deferred will be invoked, and whose threadpool we should use for the | |
579 | function. | |
580 | ||
581 | Normally this will be hs.get_reactor(). | |
582 | ||
583 | f (callable): The function to call. | |
584 | ||
585 | args: positional arguments to pass to f. | |
586 | ||
587 | kwargs: keyword arguments to pass to f. | |
588 | ||
589 | Returns: | |
590 | Deferred: A Deferred which fires a callback with the result of `f`, or an | |
591 | errback if `f` throws an exception. | |
592 | """ | |
593 | return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) | |
594 | ||
595 | ||
596 | def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): | |
597 | """ | |
598 | A wrapper for twisted.internet.threads.deferToThreadpool, which handles | |
599 | logcontexts correctly. | |
600 | ||
601 | Calls the function `f` using a thread from the given threadpool and returns | |
602 | the result as a Deferred. | |
603 | ||
604 | Creates a new logcontext for `f`, which is created as a child of the current | |
605 | logcontext (so its CPU usage metrics will get attributed to the current | |
606 | logcontext). `f` should preserve the logcontext it is given. | |
607 | ||
608 | The result deferred follows the Synapse logcontext rules: you should `yield` | |
609 | on it. | |
610 | ||
611 | Args: | |
612 | reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread | |
613 | the Deferred will be invoked. Normally this will be hs.get_reactor(). | |
614 | ||
615 | threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for | |
616 | running `f`. Normally this will be hs.get_reactor().getThreadPool(). | |
617 | ||
618 | f (callable): The function to call. | |
619 | ||
620 | args: positional arguments to pass to f. | |
621 | ||
622 | kwargs: keyword arguments to pass to f. | |
623 | ||
624 | Returns: | |
625 | Deferred: A Deferred which fires a callback with the result of `f`, or an | |
626 | errback if `f` throws an exception. | |
627 | """ | |
628 | logcontext = LoggingContext.current_context() | |
629 | ||
630 | def g(): | |
631 | with LoggingContext(parent_context=logcontext): | |
632 | return f(*args, **kwargs) | |
633 | ||
634 | return make_deferred_yieldable( | |
635 | threads.deferToThreadPool(reactor, threadpool, g) | |
636 | ) |
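A usage sketch for `defer_to_thread` (the homeserver variable and blocking helper are hypothetical; `defer` is the usual `twisted.internet.defer`):

    @defer.inlineCallbacks
    def hash_file(hs, path):
        # Push the blocking read onto the reactor's default threadpool and
        # yield on the logcontext-friendly deferred it returns.
        digest = yield defer_to_thread(hs.get_reactor(), read_and_hash_file, path)
        defer.returnValue(digest)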
69 | 69 | Returns: |
70 | 70 | twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` |
71 | 71 | """ |
72 | if not isinstance(password, bytes): | |
73 | password = password.encode('ascii') | |
72 | 74 | |
73 | 75 | checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( |
74 | 76 | **{username: password} |
81 | 83 | ) |
82 | 84 | |
83 | 85 | factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) |
84 | factory.publicKeys['ssh-rsa'] = Key.fromString(PUBLIC_KEY) | |
85 | factory.privateKeys['ssh-rsa'] = Key.fromString(PRIVATE_KEY) | |
86 | factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY) | |
87 | factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY) | |
86 | 88 | |
87 | 89 | return factory |
22 | 22 | |
23 | 23 | from synapse.api.constants import EventTypes, Membership |
24 | 24 | from synapse.events.utils import prune_event |
25 | from synapse.storage.state import StateFilter | |
25 | 26 | from synapse.types import get_domain_from_id |
26 | 27 | |
27 | 28 | logger = logging.getLogger(__name__) |
71 | 72 | ) |
72 | 73 | event_id_to_state = yield store.get_state_for_events( |
73 | 74 | frozenset(e.event_id for e in events), |
74 | types=types, | |
75 | state_filter=StateFilter.from_types(types), | |
75 | 76 | ) |
76 | 77 | |
77 | 78 | ignore_dict_content = yield store.get_global_account_data_by_type_for_user( |
272 | 273 | # need to check membership (as we know the server is in the room). |
273 | 274 | event_to_state_ids = yield store.get_state_ids_for_events( |
274 | 275 | frozenset(e.event_id for e in events), |
275 | types=( | |
276 | (EventTypes.RoomHistoryVisibility, ""), | |
276 | state_filter=StateFilter.from_types( | |
277 | types=((EventTypes.RoomHistoryVisibility, ""),), | |
277 | 278 | ) |
278 | 279 | ) |
279 | 280 | |
313 | 314 | # of the history vis and membership state at those events. |
314 | 315 | event_to_state_ids = yield store.get_state_ids_for_events( |
315 | 316 | frozenset(e.event_id for e in events), |
316 | types=( | |
317 | (EventTypes.RoomHistoryVisibility, ""), | |
318 | (EventTypes.Member, None), | |
317 | state_filter=StateFilter.from_types( | |
318 | types=( | |
319 | (EventTypes.RoomHistoryVisibility, ""), | |
320 | (EventTypes.Member, None), | |
321 | ), | |
319 | 322 | ) |
320 | 323 | ) |
321 | 324 |
0 | 0 | #!/usr/bin/env python |
1 | 1 | # -*- coding: utf-8 -*- |
2 | 2 | # Copyright 2014-2016 OpenMarket Ltd |
3 | # Copyright 2018 New Vector Ltd | |
3 | 4 | # |
4 | 5 | # Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 6 | # you may not use this file except in compliance with the License. |
47 | 48 | |
48 | 49 | |
49 | 50 | def write(message, colour=NORMAL, stream=sys.stdout): |
50 | if colour == NORMAL: | |
51 | # Lets check if we're writing to a TTY before colouring | |
52 | should_colour = False | |
53 | try: | |
54 | should_colour = stream.isatty() | |
55 | except AttributeError: | |
56 | # Just in case `isatty` isn't defined on everything. The python | |
57 | # docs are incredibly vague. | |
58 | pass | |
59 | ||
60 | if not should_colour: | |
51 | 61 | stream.write(message + "\n") |
52 | 62 | else: |
53 | 63 | stream.write(colour + message + NORMAL + "\n") |
65 | 75 | |
66 | 76 | try: |
67 | 77 | subprocess.check_call(args) |
68 | write("started synapse.app.homeserver(%r)" % | |
69 | (configfile,), colour=GREEN) | |
78 | write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN) | |
70 | 79 | except subprocess.CalledProcessError as e: |
71 | 80 | write( |
72 | 81 | "error starting (exit code: %d); see above for logs" % e.returncode, |
75 | 84 | |
76 | 85 | |
77 | 86 | def start_worker(app, configfile, worker_configfile): |
78 | args = [ | |
79 | "python", "-B", | |
80 | "-m", app, | |
81 | "-c", configfile, | |
82 | "-c", worker_configfile | |
83 | ] | |
87 | args = [sys.executable, "-B", "-m", app, "-c", configfile, "-c", worker_configfile] | |
84 | 88 | |
85 | 89 | try: |
86 | 90 | subprocess.check_call(args) |
87 | 91 | write("started %s(%r)" % (app, worker_configfile), colour=GREEN) |
88 | 92 | except subprocess.CalledProcessError as e: |
89 | 93 | write( |
90 | "error starting %s(%r) (exit code: %d); see above for logs" % ( | |
91 | app, worker_configfile, e.returncode, | |
92 | ), | |
94 | "error starting %s(%r) (exit code: %d); see above for logs" | |
95 | % (app, worker_configfile, e.returncode), | |
93 | 96 | colour=RED, |
94 | 97 | ) |
95 | 98 | |
109 | 112 | abort("Cannot stop %s: Unknown error" % (app,)) |
110 | 113 | |
111 | 114 | |
112 | Worker = collections.namedtuple("Worker", [ | |
113 | "app", "configfile", "pidfile", "cache_factor", "cache_factors", | |
114 | ]) | |
115 | Worker = collections.namedtuple( | |
116 | "Worker", ["app", "configfile", "pidfile", "cache_factor", "cache_factors"] | |
117 | ) | |
115 | 118 | |
116 | 119 | |
117 | 120 | def main(): |
130 | 133 | help="the homeserver config file, defaults to homeserver.yaml", |
131 | 134 | ) |
132 | 135 | parser.add_argument( |
133 | "-w", "--worker", | |
134 | metavar="WORKERCONFIG", | |
135 | help="start or stop a single worker", | |
136 | ) | |
137 | parser.add_argument( | |
138 | "-a", "--all-processes", | |
136 | "-w", "--worker", metavar="WORKERCONFIG", help="start or stop a single worker" | |
137 | ) | |
138 | parser.add_argument( | |
139 | "-a", | |
140 | "--all-processes", | |
139 | 141 | metavar="WORKERCONFIGDIR", |
140 | 142 | help="start or stop all the workers in the given directory" |
141 | " and the main synapse process", | |
143 | " and the main synapse process", | |
142 | 144 | ) |
143 | 145 | |
144 | 146 | options = parser.parse_args() |
145 | 147 | |
146 | 148 | if options.worker and options.all_processes: |
147 | write( | |
148 | 'Cannot use "--worker" with "--all-processes"', | |
149 | stream=sys.stderr | |
150 | ) | |
149 | write('Cannot use "--worker" with "--all-processes"', stream=sys.stderr) | |
151 | 150 | sys.exit(1) |
152 | 151 | |
153 | 152 | configfile = options.configfile |
156 | 155 | write( |
157 | 156 | "No config file found\n" |
158 | 157 | "To generate a config file, run '%s -c %s --generate-config" |
159 | " --server-name=<server name>'\n" % ( | |
160 | " ".join(SYNAPSE), options.configfile | |
161 | ), | |
158 | " --server-name=<server name>'\n" % (" ".join(SYNAPSE), options.configfile), | |
162 | 159 | stream=sys.stderr, |
163 | 160 | ) |
164 | 161 | sys.exit(1) |
183 | 180 | worker_configfile = options.worker |
184 | 181 | if not os.path.exists(worker_configfile): |
185 | 182 | write( |
186 | "No worker config found at %r" % (worker_configfile,), | |
187 | stream=sys.stderr, | |
183 | "No worker config found at %r" % (worker_configfile,), stream=sys.stderr | |
188 | 184 | ) |
189 | 185 | sys.exit(1) |
190 | 186 | worker_configfiles.append(worker_configfile) |
200 | 196 | stream=sys.stderr, |
201 | 197 | ) |
202 | 198 | sys.exit(1) |
203 | worker_configfiles.extend(sorted(glob.glob( | |
204 | os.path.join(worker_configdir, "*.yaml") | |
205 | ))) | |
199 | worker_configfiles.extend( | |
200 | sorted(glob.glob(os.path.join(worker_configdir, "*.yaml"))) | |
201 | ) | |
206 | 202 | |
207 | 203 | workers = [] |
208 | 204 | for worker_configfile in worker_configfiles: |
212 | 208 | if worker_app == "synapse.app.homeserver": |
213 | 209 | # We need to special case all of this to pick up options that may |
214 | 210 | # be set in the main config file or in this worker config file. |
215 | worker_pidfile = ( | |
216 | worker_config.get("pid_file") | |
217 | or pidfile | |
218 | ) | |
219 | worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor | |
211 | worker_pidfile = worker_config.get("pid_file") or pidfile | |
212 | worker_cache_factor = ( | |
213 | worker_config.get("synctl_cache_factor") or cache_factor | |
214 | ) | |
220 | 215 | worker_cache_factors = ( |
221 | worker_config.get("synctl_cache_factors") | |
222 | or cache_factors | |
216 | worker_config.get("synctl_cache_factors") or cache_factors | |
223 | 217 | ) |
224 | 218 | daemonize = worker_config.get("daemonize") or config.get("daemonize") |
225 | 219 | assert daemonize, "Main process must have daemonize set to true" |
228 | 222 | for key in worker_config: |
229 | 223 | if key == "worker_app": # But we allow worker_app |
230 | 224 | continue |
231 | assert not key.startswith("worker_"), \ | |
232 | "Main process cannot use worker_* config" | |
225 | assert not key.startswith( | |
226 | "worker_" | |
227 | ), "Main process cannot use worker_* config" | |
233 | 228 | else: |
234 | 229 | worker_pidfile = worker_config["worker_pid_file"] |
235 | 230 | worker_daemonize = worker_config["worker_daemonize"] |
236 | 231 | assert worker_daemonize, "In config %r: expected '%s' to be True" % ( |
237 | worker_configfile, "worker_daemonize") | |
232 | worker_configfile, | |
233 | "worker_daemonize", | |
234 | ) | |
238 | 235 | worker_cache_factor = worker_config.get("synctl_cache_factor") |
239 | 236 | worker_cache_factors = worker_config.get("synctl_cache_factors", {}) |
240 | workers.append(Worker( | |
241 | worker_app, worker_configfile, worker_pidfile, worker_cache_factor, | |
242 | worker_cache_factors, | |
243 | )) | |
237 | workers.append( | |
238 | Worker( | |
239 | worker_app, | |
240 | worker_configfile, | |
241 | worker_pidfile, | |
242 | worker_cache_factor, | |
243 | worker_cache_factors, | |
244 | ) | |
245 | ) | |
244 | 246 | |
245 | 247 | action = options.action |
246 | 248 |
59 | 59 | invalid_filters = [ |
60 | 60 | {"boom": {}}, |
61 | 61 | {"account_data": "Hello World"}, |
62 | {"event_fields": ["\\foo"]}, | |
62 | {"event_fields": [r"\\foo"]}, | |
63 | 63 | {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, |
64 | 64 | {"event_format": "other"}, |
65 | 65 | {"room": {"not_rooms": ["#foo:pik-test"]}}, |
108 | 108 | "event_format": "client", |
109 | 109 | "event_fields": ["type", "content", "sender"], |
110 | 110 | }, |
111 | ||
112 | # a single backslash should be permitted (though it is debatable whether | |
113 | # it should be permitted before anything other than `.`, and what that | |
114 | # actually means) | |
115 | # | |
116 | # (note that event_fields is implemented in | |
117 | # synapse.events.utils.serialize_event, and so whether this actually works | |
118 | # is tested elsewhere. We just want to check that it is allowed through the | |
119 | # filter validation) | |
120 | {"event_fields": [r"foo\.bar"]}, | |
111 | 121 | ] |
112 | 122 | for filter in valid_filters: |
113 | 123 | try: |
66 | 66 | with open(log_config_file) as f: |
67 | 67 | config = f.read() |
68 | 68 | # find the 'filename' line |
69 | matches = re.findall("^\s*filename:\s*(.*)$", config, re.M) | |
69 | matches = re.findall(r"^\s*filename:\s*(.*)$", config, re.M) | |
70 | 70 | self.assertEqual(1, len(matches)) |
71 | 71 | self.assertEqual(matches[0], expected) |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2018 New Vector Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | import yaml | |
16 | ||
17 | from synapse.config.room_directory import RoomDirectoryConfig | |
18 | ||
19 | from tests import unittest | |
20 | ||
21 | ||
22 | class RoomDirectoryConfigTestCase(unittest.TestCase): | |
23 | def test_alias_creation_acl(self): | |
24 | config = yaml.load(""" | |
25 | alias_creation_rules: | |
26 | - user_id: "*bob*" | |
27 | alias: "*" | |
28 | action: "deny" | |
29 | - user_id: "*" | |
30 | alias: "#unofficial_*" | |
31 | action: "allow" | |
32 | - user_id: "@foo*:example.com" | |
33 | alias: "*" | |
34 | action: "allow" | |
35 | - user_id: "@gah:example.com" | |
36 | alias: "#goo:example.com" | |
37 | action: "allow" | |
38 | """) | |
39 | ||
40 | rd_config = RoomDirectoryConfig() | |
41 | rd_config.read_config(config) | |
42 | ||
43 | self.assertFalse(rd_config.is_alias_creation_allowed( | |
44 | user_id="@bob:example.com", | |
45 | alias="#test:example.com", | |
46 | )) | |
47 | ||
48 | self.assertTrue(rd_config.is_alias_creation_allowed( | |
49 | user_id="@test:example.com", | |
50 | alias="#unofficial_st:example.com", | |
51 | )) | |
52 | ||
53 | self.assertTrue(rd_config.is_alias_creation_allowed( | |
54 | user_id="@foobar:example.com", | |
55 | alias="#test:example.com", | |
56 | )) | |
57 | ||
58 | self.assertTrue(rd_config.is_alias_creation_allowed( | |
59 | user_id="@gah:example.com", | |
60 | alias="#goo:example.com", | |
61 | )) | |
62 | ||
63 | self.assertFalse(rd_config.is_alias_creation_allowed( | |
64 | user_id="@test:example.com", | |
65 | alias="#test:example.com", | |
66 | )) |
155 | 155 | room_id="!foo:bar", |
156 | 156 | content={"key.with.dots": {}}, |
157 | 157 | ), |
158 | ["content.key\.with\.dots"], | |
158 | [r"content.key\.with\.dots"], | |
159 | 159 | ), |
160 | 160 | {"content": {"key.with.dots": {}}}, |
161 | 161 | ) |
171 | 171 | "nested.dot.key": {"leaf.key": 42, "not_me_either": 1}, |
172 | 172 | }, |
173 | 173 | ), |
174 | ["content.nested\.dot\.key.leaf\.key"], | |
174 | [r"content.nested\.dot\.key.leaf\.key"], | |
175 | 175 | ), |
176 | 176 | {"content": {"nested.dot.key": {"leaf.key": 42}}}, |
177 | 177 | ) |
17 | 17 | |
18 | 18 | from twisted.internet import defer |
19 | 19 | |
20 | from synapse.config.room_directory import RoomDirectoryConfig | |
20 | 21 | from synapse.handlers.directory import DirectoryHandler |
22 | from synapse.rest.client.v1 import directory, room | |
21 | 23 | from synapse.types import RoomAlias |
22 | 24 | |
23 | 25 | from tests import unittest |
101 | 103 | ) |
102 | 104 | |
103 | 105 | self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) |
106 | ||
107 | ||
108 | class TestCreateAliasACL(unittest.HomeserverTestCase): | |
109 | user_id = "@test:test" | |
110 | ||
111 | servlets = [directory.register_servlets, room.register_servlets] | |
112 | ||
113 | def prepare(self, hs, reactor, clock): | |
114 | # We cheekily override the config to add custom alias creation rules | |
115 | config = {} | |
116 | config["alias_creation_rules"] = [ | |
117 | { | |
118 | "user_id": "*", | |
119 | "alias": "#unofficial_*", | |
120 | "action": "allow", | |
121 | } | |
122 | ] | |
123 | ||
124 | rd_config = RoomDirectoryConfig() | |
125 | rd_config.read_config(config) | |
126 | ||
127 | self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed | |
128 | ||
129 | return hs | |
130 | ||
131 | def test_denied(self): | |
132 | room_id = self.helper.create_room_as(self.user_id) | |
133 | ||
134 | request, channel = self.make_request( | |
135 | "PUT", | |
136 | b"directory/room/%23test%3Atest", | |
137 | ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), | |
138 | ) | |
139 | self.render(request) | |
140 | self.assertEquals(403, channel.code, channel.result) | |
141 | ||
142 | def test_allowed(self): | |
143 | room_id = self.helper.create_room_as(self.user_id) | |
144 | ||
145 | request, channel = self.make_request( | |
146 | "PUT", | |
147 | b"directory/room/%23unofficial_test%3Atest", | |
148 | ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), | |
149 | ) | |
150 | self.render(request) | |
151 | self.assertEquals(200, channel.code, channel.result) |
18 | 18 | |
19 | 19 | from synapse.api.errors import ResourceLimitError |
20 | 20 | from synapse.handlers.register import RegistrationHandler |
21 | from synapse.types import UserID, create_requester | |
21 | from synapse.types import RoomAlias, UserID, create_requester | |
22 | 22 | |
23 | 23 | from tests.utils import setup_test_homeserver |
24 | 24 | |
40 | 40 | self.mock_captcha_client = Mock() |
41 | 41 | self.hs = yield setup_test_homeserver( |
42 | 42 | self.addCleanup, |
43 | handlers=None, | |
44 | http_client=None, | |
45 | 43 | expire_access_token=True, |
46 | profile_handler=Mock(), | |
47 | 44 | ) |
48 | 45 | self.macaroon_generator = Mock( |
49 | 46 | generate_access_token=Mock(return_value='secret') |
50 | 47 | ) |
51 | 48 | self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) |
52 | self.hs.handlers = RegistrationHandlers(self.hs) | |
53 | 49 | self.handler = self.hs.get_handlers().registration_handler |
54 | 50 | self.store = self.hs.get_datastore() |
55 | 51 | self.hs.config.max_mau_value = 50 |
56 | 52 | self.lots_of_users = 100 |
57 | 53 | self.small_number_of_users = 1 |
58 | 54 | |
55 | self.requester = create_requester("@requester:test") | |
56 | ||
59 | 57 | @defer.inlineCallbacks |
60 | 58 | def test_user_is_created_and_logged_in_if_doesnt_exist(self): |
61 | local_part = "someone" | |
62 | display_name = "someone" | |
63 | user_id = "@someone:test" | |
64 | requester = create_requester("@as:test") | |
59 | frank = UserID.from_string("@frank:test") | |
60 | user_id = frank.to_string() | |
61 | requester = create_requester(user_id) | |
65 | 62 | result_user_id, result_token = yield self.handler.get_or_create_user( |
66 | requester, local_part, display_name | |
63 | requester, frank.localpart, "Frankie" | |
67 | 64 | ) |
68 | 65 | self.assertEquals(result_user_id, user_id) |
69 | 66 | self.assertEquals(result_token, 'secret') |
77 | 74 | token="jkv;g498752-43gj['eamb!-5", |
78 | 75 | password_hash=None, |
79 | 76 | ) |
80 | local_part = "frank" | |
81 | display_name = "Frank" | |
82 | user_id = "@frank:test" | |
83 | requester = create_requester("@as:test") | |
77 | local_part = frank.localpart | |
78 | user_id = frank.to_string() | |
79 | requester = create_requester(user_id) | |
84 | 80 | result_user_id, result_token = yield self.handler.get_or_create_user( |
85 | requester, local_part, display_name | |
81 | requester, local_part, None | |
86 | 82 | ) |
87 | 83 | self.assertEquals(result_user_id, user_id) |
88 | 84 | self.assertEquals(result_token, 'secret') |
91 | 87 | def test_mau_limits_when_disabled(self): |
92 | 88 | self.hs.config.limit_usage_by_mau = False |
93 | 89 | # Ensure does not throw exception |
94 | yield self.handler.get_or_create_user("requester", 'a', "display_name") | |
90 | yield self.handler.get_or_create_user(self.requester, 'a', "display_name") | |
95 | 91 | |
96 | 92 | @defer.inlineCallbacks |
97 | 93 | def test_get_or_create_user_mau_not_blocked(self): |
100 | 96 | return_value=defer.succeed(self.hs.config.max_mau_value - 1) |
101 | 97 | ) |
102 | 98 | # Ensure does not throw exception |
103 | yield self.handler.get_or_create_user("@user:server", 'c', "User") | |
99 | yield self.handler.get_or_create_user(self.requester, 'c', "User") | |
104 | 100 | |
105 | 101 | @defer.inlineCallbacks |
106 | 102 | def test_get_or_create_user_mau_blocked(self): |
109 | 105 | return_value=defer.succeed(self.lots_of_users) |
110 | 106 | ) |
111 | 107 | with self.assertRaises(ResourceLimitError): |
112 | yield self.handler.get_or_create_user("requester", 'b', "display_name") | |
108 | yield self.handler.get_or_create_user(self.requester, 'b', "display_name") | |
113 | 109 | |
114 | 110 | self.store.get_monthly_active_count = Mock( |
115 | 111 | return_value=defer.succeed(self.hs.config.max_mau_value) |
116 | 112 | ) |
117 | 113 | with self.assertRaises(ResourceLimitError): |
118 | yield self.handler.get_or_create_user("requester", 'b', "display_name") | |
114 | yield self.handler.get_or_create_user(self.requester, 'b', "display_name") | |
119 | 115 | |
120 | 116 | @defer.inlineCallbacks |
121 | 117 | def test_register_mau_blocked(self): |
146 | 142 | ) |
147 | 143 | with self.assertRaises(ResourceLimitError): |
148 | 144 | yield self.handler.register_saml2(localpart="local_part") |
145 | ||
146 | @defer.inlineCallbacks | |
147 | def test_auto_create_auto_join_rooms(self): | |
148 | room_alias_str = "#room:test" | |
149 | self.hs.config.auto_join_rooms = [room_alias_str] | |
150 | res = yield self.handler.register(localpart='jeff') | |
151 | rooms = yield self.store.get_rooms_for_user(res[0]) | |
152 | ||
153 | directory_handler = self.hs.get_handlers().directory_handler | |
154 | room_alias = RoomAlias.from_string(room_alias_str) | |
155 | room_id = yield directory_handler.get_association(room_alias) | |
156 | ||
157 | self.assertTrue(room_id['room_id'] in rooms) | |
158 | self.assertEqual(len(rooms), 1) | |
159 | ||
160 | @defer.inlineCallbacks | |
161 | def test_auto_create_auto_join_rooms_with_no_rooms(self): | |
162 | self.hs.config.auto_join_rooms = [] | |
163 | frank = UserID.from_string("@frank:test") | |
164 | res = yield self.handler.register(frank.localpart) | |
165 | self.assertEqual(res[0], frank.to_string()) | |
166 | rooms = yield self.store.get_rooms_for_user(res[0]) | |
167 | self.assertEqual(len(rooms), 0) | |
168 | ||
169 | @defer.inlineCallbacks | |
170 | def test_auto_create_auto_join_where_room_is_another_domain(self): | |
171 | self.hs.config.auto_join_rooms = ["#room:another"] | |
172 | frank = UserID.from_string("@frank:test") | |
173 | res = yield self.handler.register(frank.localpart) | |
174 | self.assertEqual(res[0], frank.to_string()) | |
175 | rooms = yield self.store.get_rooms_for_user(res[0]) | |
176 | self.assertEqual(len(rooms), 0) | |
177 | ||
178 | @defer.inlineCallbacks | |
179 | def test_auto_create_auto_join_where_auto_create_is_false(self): | |
180 | self.hs.config.autocreate_auto_join_rooms = False | |
181 | room_alias_str = "#room:test" | |
182 | self.hs.config.auto_join_rooms = [room_alias_str] | |
183 | res = yield self.handler.register(localpart='jeff') | |
184 | rooms = yield self.store.get_rooms_for_user(res[0]) | |
185 | self.assertEqual(len(rooms), 0) |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2018 New Vector Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from synapse.handlers.room_list import RoomListNextBatch | |
16 | ||
17 | import tests.unittest | |
18 | import tests.utils | |
19 | ||
20 | ||
21 | class RoomListTestCase(tests.unittest.TestCase): | |
22 | """ Tests RoomList's RoomListNextBatch. """ | |
23 | ||
24 | def setUp(self): | |
25 | pass | |
26 | ||
27 | def test_check_read_batch_tokens(self): | |
28 | batch_token = RoomListNextBatch( | |
29 | stream_ordering="abcdef", | |
30 | public_room_stream_id="123", | |
31 | current_limit=20, | |
32 | direction_is_forward=True, | |
33 | ).to_token() | |
34 | next_batch = RoomListNextBatch.from_token(batch_token) | |
35 | self.assertEquals(next_batch.stream_ordering, "abcdef") | |
36 | self.assertEquals(next_batch.public_room_stream_id, "123") | |
37 | self.assertEquals(next_batch.current_limit, 20) | |
38 | self.assertEquals(next_batch.direction_is_forward, True) |
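The round-trip test above only pins down that to_token()/from_token() are inverses over the four fields. A hedged sketch of one way such an opaque pagination token could work (illustrative only; the actual RoomListNextBatch encoding is not part of this diff):

import base64
import json
from collections import namedtuple

class NextBatch(namedtuple("NextBatch", (
        "stream_ordering", "public_room_stream_id",
        "current_limit", "direction_is_forward"))):
    def to_token(self):
        # Pack the fields into a URL-safe, opaque string.
        return base64.urlsafe_b64encode(
            json.dumps(list(self)).encode("ascii")
        ).decode("ascii")

    @classmethod
    def from_token(cls, token):
        return cls(*json.loads(base64.urlsafe_b64decode(token)))

batch = NextBatch("abcdef", "123", 20, True)
assert NextBatch.from_token(batch.to_token()) == batch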
22 | 22 | from twisted.internet import defer |
23 | 23 | |
24 | 24 | from synapse.api.constants import Membership |
25 | from synapse.rest.client.v1 import room | |
25 | from synapse.rest.client.v1 import admin, login, room | |
26 | 26 | |
27 | 27 | from tests import unittest |
28 | 28 | |
798 | 798 | self.assertEquals(token, channel.json_body['start']) |
799 | 799 | self.assertTrue("chunk" in channel.json_body) |
800 | 800 | self.assertTrue("end" in channel.json_body) |
801 | ||
802 | ||
803 | class RoomSearchTestCase(unittest.HomeserverTestCase): | |
804 | servlets = [ | |
805 | admin.register_servlets, | |
806 | room.register_servlets, | |
807 | login.register_servlets, | |
808 | ] | |
809 | user_id = True | |
810 | hijack_auth = False | |
811 | ||
812 | def prepare(self, reactor, clock, hs): | |
813 | ||
814 | # Register the user who does the searching | |
815 | self.user_id = self.register_user("user", "pass") | |
816 | self.access_token = self.login("user", "pass") | |
817 | ||
818 | # Register the user who sends the message | |
819 | self.other_user_id = self.register_user("otheruser", "pass") | |
820 | self.other_access_token = self.login("otheruser", "pass") | |
821 | ||
822 | # Create a room | |
823 | self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) | |
824 | ||
825 | # Invite the other person | |
826 | self.helper.invite( | |
827 | room=self.room, | |
828 | src=self.user_id, | |
829 | tok=self.access_token, | |
830 | targ=self.other_user_id, | |
831 | ) | |
832 | ||
833 | # The other user joins | |
834 | self.helper.join( | |
835 | room=self.room, user=self.other_user_id, tok=self.other_access_token | |
836 | ) | |
837 | ||
838 | def test_finds_message(self): | |
839 | """ | |
840 | The search functionality will search for content in messages if asked to | |
841 | do so. | |
842 | """ | |
843 | # The other user sends some messages | |
844 | self.helper.send(self.room, body="Hi!", tok=self.other_access_token) | |
845 | self.helper.send(self.room, body="There!", tok=self.other_access_token) | |
846 | ||
847 | request, channel = self.make_request( | |
848 | "POST", | |
849 | "/search?access_token=%s" % (self.access_token,), | |
850 | { | |
851 | "search_categories": { | |
852 | "room_events": {"keys": ["content.body"], "search_term": "Hi"} | |
853 | } | |
854 | }, | |
855 | ) | |
856 | self.render(request) | |
857 | ||
858 | # Check we get the results we expect -- one search result, of the sent | |
859 | # messages | |
860 | self.assertEqual(channel.code, 200) | |
861 | results = channel.json_body["search_categories"]["room_events"] | |
862 | self.assertEqual(results["count"], 1) | |
863 | self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") | |
864 | ||
865 | # No context was requested, so we should get none. | |
866 | self.assertEqual(results["results"][0]["context"], {}) | |
867 | ||
868 | def test_include_context(self): | |
869 | """ | |
870 | When event_context includes include_profile, profile information will be | |
871 | included in the search response. | |
872 | """ | |
873 | # The other user sends some messages | |
874 | self.helper.send(self.room, body="Hi!", tok=self.other_access_token) | |
875 | self.helper.send(self.room, body="There!", tok=self.other_access_token) | |
876 | ||
877 | request, channel = self.make_request( | |
878 | "POST", | |
879 | "/search?access_token=%s" % (self.access_token,), | |
880 | { | |
881 | "search_categories": { | |
882 | "room_events": { | |
883 | "keys": ["content.body"], | |
884 | "search_term": "Hi", | |
885 | "event_context": {"include_profile": True}, | |
886 | } | |
887 | } | |
888 | }, | |
889 | ) | |
890 | self.render(request) | |
891 | ||
892 | # Check we get the results we expect -- one search result, of the sent | |
893 | # messages | |
894 | self.assertEqual(channel.code, 200) | |
895 | results = channel.json_body["search_categories"]["room_events"] | |
896 | self.assertEqual(results["count"], 1) | |
897 | self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") | |
898 | ||
899 | # We should get context info, like the two users, and the display names. | |
900 | context = results["results"][0]["context"] | |
901 | self.assertEqual(len(context["profile_info"].keys()), 2) | |
902 | self.assertEqual( | |
903 | context["profile_info"][self.other_user_id]["displayname"], "otheruser" | |
904 | ) |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2018 New Vector | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from mock import Mock | |
16 | ||
17 | from synapse._scripts.register_new_matrix_user import request_registration | |
18 | ||
19 | from tests.unittest import TestCase | |
20 | ||
21 | ||
22 | class RegisterTestCase(TestCase): | |
23 | def test_success(self): | |
24 | """ | |
25 | The script will fetch a nonce, and then generate a MAC with it, and then | |
26 | post that MAC. | |
27 | """ | |
28 | ||
29 | def get(url, verify=None): | |
30 | r = Mock() | |
31 | r.status_code = 200 | |
32 | r.json = lambda: {"nonce": "a"} | |
33 | return r | |
34 | ||
35 | def post(url, json=None, verify=None): | |
36 | # Make sure we are sent the correct info | |
37 | self.assertEqual(json["username"], "user") | |
38 | self.assertEqual(json["password"], "pass") | |
39 | self.assertEqual(json["nonce"], "a") | |
40 | # We want a 40-char hex MAC | |
41 | self.assertEqual(len(json["mac"]), 40) | |
42 | ||
43 | r = Mock() | |
44 | r.status_code = 200 | |
45 | return r | |
46 | ||
47 | requests = Mock() | |
48 | requests.get = get | |
49 | requests.post = post | |
50 | ||
51 | # The fake stdout will be written here | |
52 | out = [] | |
53 | err_code = [] | |
54 | ||
55 | request_registration( | |
56 | "user", | |
57 | "pass", | |
58 | "matrix.org", | |
59 | "shared", | |
60 | admin=False, | |
61 | requests=requests, | |
62 | _print=out.append, | |
63 | exit=err_code.append, | |
64 | ) | |
65 | ||
66 | # We should get the success message making sure everything is OK. | |
67 | self.assertIn("Success!", out) | |
68 | ||
69 | # sys.exit shouldn't have been called. | |
70 | self.assertEqual(err_code, []) | |
71 | ||
72 | def test_failure_nonce(self): | |
73 | """ | |
74 | If the script fails to fetch a nonce, it throws an error and quits. | |
75 | """ | |
76 | ||
77 | def get(url, verify=None): | |
78 | r = Mock() | |
79 | r.status_code = 404 | |
80 | r.reason = "Not Found" | |
81 | r.json = lambda: {"not": "error"} | |
82 | return r | |
83 | ||
84 | requests = Mock() | |
85 | requests.get = get | |
86 | ||
87 | # The fake stdout will be written here | |
88 | out = [] | |
89 | err_code = [] | |
90 | ||
91 | request_registration( | |
92 | "user", | |
93 | "pass", | |
94 | "matrix.org", | |
95 | "shared", | |
96 | admin=False, | |
97 | requests=requests, | |
98 | _print=out.append, | |
99 | exit=err_code.append, | |
100 | ) | |
101 | ||
102 | # Exit was called | |
103 | self.assertEqual(err_code, [1]) | |
104 | ||
105 | # We got an error message | |
106 | self.assertIn("ERROR! Received 404 Not Found", out) | |
107 | self.assertNotIn("Success!", out) | |
108 | ||
109 | def test_failure_post(self): | |
110 | """ | |
111 | The script will fetch a nonce, and then if the final POST fails, will | |
112 | report an error and quit. | |
113 | """ | |
114 | ||
115 | def get(url, verify=None): | |
116 | r = Mock() | |
117 | r.status_code = 200 | |
118 | r.json = lambda: {"nonce": "a"} | |
119 | return r | |
120 | ||
121 | def post(url, json=None, verify=None): | |
122 | # Make sure we are sent the correct info | |
123 | self.assertEqual(json["username"], "user") | |
124 | self.assertEqual(json["password"], "pass") | |
125 | self.assertEqual(json["nonce"], "a") | |
126 | # We want a 40-char hex MAC | |
127 | self.assertEqual(len(json["mac"]), 40) | |
128 | ||
129 | r = Mock() | |
130 | # Then 500 because we're jerks | |
131 | r.status_code = 500 | |
132 | r.reason = "Broken" | |
133 | return r | |
134 | ||
135 | requests = Mock() | |
136 | requests.get = get | |
137 | requests.post = post | |
138 | ||
139 | # The fake stdout will be written here | |
140 | out = [] | |
141 | err_code = [] | |
142 | ||
143 | request_registration( | |
144 | "user", | |
145 | "pass", | |
146 | "matrix.org", | |
147 | "shared", | |
148 | admin=False, | |
149 | requests=requests, | |
150 | _print=out.append, | |
151 | exit=err_code.append, | |
152 | ) | |
153 | ||
154 | # Exit was called | |
155 | self.assertEqual(err_code, [1]) | |
156 | ||
157 | # We got an error message | |
158 | self.assertIn("ERROR! Received 500 Broken", out) | |
159 | self.assertNotIn("Success!", out) |
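The fake post() above checks that the script sends the username, password, the nonce it just fetched, and a 40-character hex MAC, which is the length of an HMAC-SHA1 hex digest. Assuming the MAC is an HMAC-SHA1 keyed with the shared secret over the nonce, username, password and admin flag joined by NUL bytes (an assumption about the shared-secret registration scheme, not something shown in this diff), the client-side computation would look roughly like:

import hashlib
import hmac

def registration_mac(shared_secret, nonce, username, password, admin=False):
    # Assumed construction: HMAC-SHA1 over nonce, username, password and
    # the admin flag, separated by NUL bytes, keyed with the shared secret.
    mac = hmac.new(shared_secret.encode("utf8"), digestmod=hashlib.sha1)
    mac.update(nonce.encode("utf8"))
    mac.update(b"\x00")
    mac.update(username.encode("utf8"))
    mac.update(b"\x00")
    mac.update(password.encode("utf8"))
    mac.update(b"\x00")
    mac.update(b"admin" if admin else b"notadmin")
    return mac.hexdigest()  # 40 hex characters, matching the test's check

assert len(registration_mac("shared", "a", "user", "pass")) == 40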
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2018 New Vector Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | import itertools | |
16 | ||
17 | from six.moves import zip | |
18 | ||
19 | import attr | |
20 | ||
21 | from synapse.api.constants import EventTypes, JoinRules, Membership | |
22 | from synapse.event_auth import auth_types_for_event | |
23 | from synapse.events import FrozenEvent | |
24 | from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store | |
25 | from synapse.types import EventID | |
26 | ||
27 | from tests import unittest | |
28 | ||
29 | ALICE = "@alice:example.com" | |
30 | BOB = "@bob:example.com" | |
31 | CHARLIE = "@charlie:example.com" | |
32 | EVELYN = "@evelyn:example.com" | |
33 | ZARA = "@zara:example.com" | |
34 | ||
35 | ROOM_ID = "!test:example.com" | |
36 | ||
37 | MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN} | |
38 | MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN} | |
39 | ||
40 | ||
41 | ORIGIN_SERVER_TS = 0 | |
42 | ||
43 | ||
44 | class FakeEvent(object): | |
45 | """A fake event we use as a convenience. | |
46 | ||
47 | NOTE: Again as a convenience we use "node_ids" rather than event_ids to | |
48 | refer to events. The event_id has node_id as localpart and example.com | |
49 | as domain. | |
50 | """ | |
51 | def __init__(self, id, sender, type, state_key, content): | |
52 | self.node_id = id | |
53 | self.event_id = EventID(id, "example.com").to_string() | |
54 | self.sender = sender | |
55 | self.type = type | |
56 | self.state_key = state_key | |
57 | self.content = content | |
58 | ||
59 | def to_event(self, auth_events, prev_events): | |
60 | """Given the auth_events and prev_events, convert to a Frozen Event | |
61 | ||
62 | Args: | |
63 | auth_events (list[str]): list of event_ids | |
64 | prev_events (list[str]): list of event_ids | |
65 | ||
66 | Returns: | |
67 | FrozenEvent | |
68 | """ | |
69 | global ORIGIN_SERVER_TS | |
70 | ||
71 | ts = ORIGIN_SERVER_TS | |
72 | ORIGIN_SERVER_TS = ORIGIN_SERVER_TS + 1 | |
73 | ||
74 | event_dict = { | |
75 | "auth_events": [(a, {}) for a in auth_events], | |
76 | "prev_events": [(p, {}) for p in prev_events], | |
77 | "event_id": self.node_id, | |
78 | "sender": self.sender, | |
79 | "type": self.type, | |
80 | "content": self.content, | |
81 | "origin_server_ts": ts, | |
82 | "room_id": ROOM_ID, | |
83 | } | |
84 | ||
85 | if self.state_key is not None: | |
86 | event_dict["state_key"] = self.state_key | |
87 | ||
88 | return FrozenEvent(event_dict) | |
89 | ||
90 | ||
91 | # All graphs start with this set of events | |
92 | INITIAL_EVENTS = [ | |
93 | FakeEvent( | |
94 | id="CREATE", | |
95 | sender=ALICE, | |
96 | type=EventTypes.Create, | |
97 | state_key="", | |
98 | content={"creator": ALICE}, | |
99 | ), | |
100 | FakeEvent( | |
101 | id="IMA", | |
102 | sender=ALICE, | |
103 | type=EventTypes.Member, | |
104 | state_key=ALICE, | |
105 | content=MEMBERSHIP_CONTENT_JOIN, | |
106 | ), | |
107 | FakeEvent( | |
108 | id="IPOWER", | |
109 | sender=ALICE, | |
110 | type=EventTypes.PowerLevels, | |
111 | state_key="", | |
112 | content={"users": {ALICE: 100}}, | |
113 | ), | |
114 | FakeEvent( | |
115 | id="IJR", | |
116 | sender=ALICE, | |
117 | type=EventTypes.JoinRules, | |
118 | state_key="", | |
119 | content={"join_rule": JoinRules.PUBLIC}, | |
120 | ), | |
121 | FakeEvent( | |
122 | id="IMB", | |
123 | sender=BOB, | |
124 | type=EventTypes.Member, | |
125 | state_key=BOB, | |
126 | content=MEMBERSHIP_CONTENT_JOIN, | |
127 | ), | |
128 | FakeEvent( | |
129 | id="IMC", | |
130 | sender=CHARLIE, | |
131 | type=EventTypes.Member, | |
132 | state_key=CHARLIE, | |
133 | content=MEMBERSHIP_CONTENT_JOIN, | |
134 | ), | |
135 | FakeEvent( | |
136 | id="IMZ", | |
137 | sender=ZARA, | |
138 | type=EventTypes.Member, | |
139 | state_key=ZARA, | |
140 | content=MEMBERSHIP_CONTENT_JOIN, | |
141 | ), | |
142 | FakeEvent( | |
143 | id="START", | |
144 | sender=ZARA, | |
145 | type=EventTypes.Message, | |
146 | state_key=None, | |
147 | content={}, | |
148 | ), | |
149 | FakeEvent( | |
150 | id="END", | |
151 | sender=ZARA, | |
152 | type=EventTypes.Message, | |
153 | state_key=None, | |
154 | content={}, | |
155 | ), | |
156 | ] | |
157 | ||
158 | INITIAL_EDGES = [ | |
159 | "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE", | |
160 | ] | |
161 | ||
162 | ||
163 | class StateTestCase(unittest.TestCase): | |
164 | def test_ban_vs_pl(self): | |
165 | events = [ | |
166 | FakeEvent( | |
167 | id="PA", | |
168 | sender=ALICE, | |
169 | type=EventTypes.PowerLevels, | |
170 | state_key="", | |
171 | content={ | |
172 | "users": { | |
173 | ALICE: 100, | |
174 | BOB: 50, | |
175 | } | |
176 | }, | |
177 | ), | |
178 | FakeEvent( | |
179 | id="MA", | |
180 | sender=ALICE, | |
181 | type=EventTypes.Member, | |
182 | state_key=ALICE, | |
183 | content={"membership": Membership.JOIN}, | |
184 | ), | |
185 | FakeEvent( | |
186 | id="MB", | |
187 | sender=ALICE, | |
188 | type=EventTypes.Member, | |
189 | state_key=BOB, | |
190 | content={"membership": Membership.BAN}, | |
191 | ), | |
192 | FakeEvent( | |
193 | id="PB", | |
194 | sender=BOB, | |
195 | type=EventTypes.PowerLevels, | |
196 | state_key='', | |
197 | content={ | |
198 | "users": { | |
199 | ALICE: 100, | |
200 | BOB: 50, | |
201 | }, | |
202 | }, | |
203 | ), | |
204 | ] | |
205 | ||
206 | edges = [ | |
207 | ["END", "MB", "MA", "PA", "START"], | |
208 | ["END", "PB", "PA"], | |
209 | ] | |
210 | ||
211 | expected_state_ids = ["PA", "MA", "MB"] | |
212 | ||
213 | self.do_check(events, edges, expected_state_ids) | |
214 | ||
215 | def test_join_rule_evasion(self): | |
216 | events = [ | |
217 | FakeEvent( | |
218 | id="JR", | |
219 | sender=ALICE, | |
220 | type=EventTypes.JoinRules, | |
221 | state_key="", | |
222 | content={"join_rules": JoinRules.PRIVATE}, | |
223 | ), | |
224 | FakeEvent( | |
225 | id="ME", | |
226 | sender=EVELYN, | |
227 | type=EventTypes.Member, | |
228 | state_key=EVELYN, | |
229 | content={"membership": Membership.JOIN}, | |
230 | ), | |
231 | ] | |
232 | ||
233 | edges = [ | |
234 | ["END", "JR", "START"], | |
235 | ["END", "ME", "START"], | |
236 | ] | |
237 | ||
238 | expected_state_ids = ["JR"] | |
239 | ||
240 | self.do_check(events, edges, expected_state_ids) | |
241 | ||
242 | def test_offtopic_pl(self): | |
243 | events = [ | |
244 | FakeEvent( | |
245 | id="PA", | |
246 | sender=ALICE, | |
247 | type=EventTypes.PowerLevels, | |
248 | state_key="", | |
249 | content={ | |
250 | "users": { | |
251 | ALICE: 100, | |
252 | BOB: 50, | |
253 | } | |
254 | }, | |
255 | ), | |
256 | FakeEvent( | |
257 | id="PB", | |
258 | sender=BOB, | |
259 | type=EventTypes.PowerLevels, | |
260 | state_key='', | |
261 | content={ | |
262 | "users": { | |
263 | ALICE: 100, | |
264 | BOB: 50, | |
265 | CHARLIE: 50, | |
266 | }, | |
267 | }, | |
268 | ), | |
269 | FakeEvent( | |
270 | id="PC", | |
271 | sender=CHARLIE, | |
272 | type=EventTypes.PowerLevels, | |
273 | state_key='', | |
274 | content={ | |
275 | "users": { | |
276 | ALICE: 100, | |
277 | BOB: 50, | |
278 | CHARLIE: 0, | |
279 | }, | |
280 | }, | |
281 | ), | |
282 | ] | |
283 | ||
284 | edges = [ | |
285 | ["END", "PC", "PB", "PA", "START"], | |
286 | ["END", "PA"], | |
287 | ] | |
288 | ||
289 | expected_state_ids = ["PC"] | |
290 | ||
291 | self.do_check(events, edges, expected_state_ids) | |
292 | ||
293 | def test_topic_basic(self): | |
294 | events = [ | |
295 | FakeEvent( | |
296 | id="T1", | |
297 | sender=ALICE, | |
298 | type=EventTypes.Topic, | |
299 | state_key="", | |
300 | content={}, | |
301 | ), | |
302 | FakeEvent( | |
303 | id="PA1", | |
304 | sender=ALICE, | |
305 | type=EventTypes.PowerLevels, | |
306 | state_key='', | |
307 | content={ | |
308 | "users": { | |
309 | ALICE: 100, | |
310 | BOB: 50, | |
311 | }, | |
312 | }, | |
313 | ), | |
314 | FakeEvent( | |
315 | id="T2", | |
316 | sender=ALICE, | |
317 | type=EventTypes.Topic, | |
318 | state_key="", | |
319 | content={}, | |
320 | ), | |
321 | FakeEvent( | |
322 | id="PA2", | |
323 | sender=ALICE, | |
324 | type=EventTypes.PowerLevels, | |
325 | state_key='', | |
326 | content={ | |
327 | "users": { | |
328 | ALICE: 100, | |
329 | BOB: 0, | |
330 | }, | |
331 | }, | |
332 | ), | |
333 | FakeEvent( | |
334 | id="PB", | |
335 | sender=BOB, | |
336 | type=EventTypes.PowerLevels, | |
337 | state_key='', | |
338 | content={ | |
339 | "users": { | |
340 | ALICE: 100, | |
341 | BOB: 50, | |
342 | }, | |
343 | }, | |
344 | ), | |
345 | FakeEvent( | |
346 | id="T3", | |
347 | sender=BOB, | |
348 | type=EventTypes.Topic, | |
349 | state_key="", | |
350 | content={}, | |
351 | ), | |
352 | ] | |
353 | ||
354 | edges = [ | |
355 | ["END", "PA2", "T2", "PA1", "T1", "START"], | |
356 | ["END", "T3", "PB", "PA1"], | |
357 | ] | |
358 | ||
359 | expected_state_ids = ["PA2", "T2"] | |
360 | ||
361 | self.do_check(events, edges, expected_state_ids) | |
362 | ||
363 | def test_topic_reset(self): | |
364 | events = [ | |
365 | FakeEvent( | |
366 | id="T1", | |
367 | sender=ALICE, | |
368 | type=EventTypes.Topic, | |
369 | state_key="", | |
370 | content={}, | |
371 | ), | |
372 | FakeEvent( | |
373 | id="PA", | |
374 | sender=ALICE, | |
375 | type=EventTypes.PowerLevels, | |
376 | state_key='', | |
377 | content={ | |
378 | "users": { | |
379 | ALICE: 100, | |
380 | BOB: 50, | |
381 | }, | |
382 | }, | |
383 | ), | |
384 | FakeEvent( | |
385 | id="T2", | |
386 | sender=BOB, | |
387 | type=EventTypes.Topic, | |
388 | state_key="", | |
389 | content={}, | |
390 | ), | |
391 | FakeEvent( | |
392 | id="MB", | |
393 | sender=ALICE, | |
394 | type=EventTypes.Member, | |
395 | state_key=BOB, | |
396 | content={"membership": Membership.BAN}, | |
397 | ), | |
398 | ] | |
399 | ||
400 | edges = [ | |
401 | ["END", "MB", "T2", "PA", "T1", "START"], | |
402 | ["END", "T1"], | |
403 | ] | |
404 | ||
405 | expected_state_ids = ["T1", "MB", "PA"] | |
406 | ||
407 | self.do_check(events, edges, expected_state_ids) | |
408 | ||
409 | def test_topic(self): | |
410 | events = [ | |
411 | FakeEvent( | |
412 | id="T1", | |
413 | sender=ALICE, | |
414 | type=EventTypes.Topic, | |
415 | state_key="", | |
416 | content={}, | |
417 | ), | |
418 | FakeEvent( | |
419 | id="PA1", | |
420 | sender=ALICE, | |
421 | type=EventTypes.PowerLevels, | |
422 | state_key='', | |
423 | content={ | |
424 | "users": { | |
425 | ALICE: 100, | |
426 | BOB: 50, | |
427 | }, | |
428 | }, | |
429 | ), | |
430 | FakeEvent( | |
431 | id="T2", | |
432 | sender=ALICE, | |
433 | type=EventTypes.Topic, | |
434 | state_key="", | |
435 | content={}, | |
436 | ), | |
437 | FakeEvent( | |
438 | id="PA2", | |
439 | sender=ALICE, | |
440 | type=EventTypes.PowerLevels, | |
441 | state_key='', | |
442 | content={ | |
443 | "users": { | |
444 | ALICE: 100, | |
445 | BOB: 0, | |
446 | }, | |
447 | }, | |
448 | ), | |
449 | FakeEvent( | |
450 | id="PB", | |
451 | sender=BOB, | |
452 | type=EventTypes.PowerLevels, | |
453 | state_key='', | |
454 | content={ | |
455 | "users": { | |
456 | ALICE: 100, | |
457 | BOB: 50, | |
458 | }, | |
459 | }, | |
460 | ), | |
461 | FakeEvent( | |
462 | id="T3", | |
463 | sender=BOB, | |
464 | type=EventTypes.Topic, | |
465 | state_key="", | |
466 | content={}, | |
467 | ), | |
468 | FakeEvent( | |
469 | id="MZ1", | |
470 | sender=ZARA, | |
471 | type=EventTypes.Message, | |
472 | state_key=None, | |
473 | content={}, | |
474 | ), | |
475 | FakeEvent( | |
476 | id="T4", | |
477 | sender=ALICE, | |
478 | type=EventTypes.Topic, | |
479 | state_key="", | |
480 | content={}, | |
481 | ), | |
482 | ] | |
483 | ||
484 | edges = [ | |
485 | ["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"], | |
486 | ["END", "MZ1", "T3", "PB", "PA1"], | |
487 | ] | |
488 | ||
489 | expected_state_ids = ["T4", "PA2"] | |
490 | ||
491 | self.do_check(events, edges, expected_state_ids) | |
492 | ||
493 | def do_check(self, events, edges, expected_state_ids): | |
494 | """Take a list of events and edges and calculate the state of the | |
495 | graph at END, and asserts it matches `expected_state_ids` | |
496 | ||
497 | Args: | |
498 | events (list[FakeEvent]) | |
499 | edges (list[list[str]]): A list of chains of event edges, e.g. | |
500 | `[[A, B, C]]` are edges A->B and B->C. | |
501 | expected_state_ids (list[str]): The expected state at END, (excluding | |
502 | the keys that haven't changed since START). | |
503 | """ | |
504 | # We want to sort the events into topological order for processing. | |
505 | graph = {} | |
506 | ||
507 | # node_id -> FakeEvent | |
508 | fake_event_map = {} | |
509 | ||
510 | for ev in itertools.chain(INITIAL_EVENTS, events): | |
511 | graph[ev.node_id] = set() | |
512 | fake_event_map[ev.node_id] = ev | |
513 | ||
514 | for a, b in pairwise(INITIAL_EDGES): | |
515 | graph[a].add(b) | |
516 | ||
517 | for edge_list in edges: | |
518 | for a, b in pairwise(edge_list): | |
519 | graph[a].add(b) | |
520 | ||
521 | # event_id -> FrozenEvent | |
522 | event_map = {} | |
523 | # node_id -> state | |
524 | state_at_event = {} | |
525 | ||
526 | # We copy the map as the sort consumes the graph | |
527 | graph_copy = {k: set(v) for k, v in graph.items()} | |
528 | ||
529 | for node_id in lexicographical_topological_sort(graph_copy, key=lambda e: e): | |
530 | fake_event = fake_event_map[node_id] | |
531 | event_id = fake_event.event_id | |
532 | ||
533 | prev_events = list(graph[node_id]) | |
534 | ||
535 | if len(prev_events) == 0: | |
536 | state_before = {} | |
537 | elif len(prev_events) == 1: | |
538 | state_before = dict(state_at_event[prev_events[0]]) | |
539 | else: | |
540 | state_d = resolve_events_with_store( | |
541 | [state_at_event[n] for n in prev_events], | |
542 | event_map=event_map, | |
543 | state_res_store=TestStateResolutionStore(event_map), | |
544 | ) | |
545 | ||
546 | self.assertTrue(state_d.called) | |
547 | state_before = state_d.result | |
548 | ||
549 | state_after = dict(state_before) | |
550 | if fake_event.state_key is not None: | |
551 | state_after[(fake_event.type, fake_event.state_key)] = event_id | |
552 | ||
553 | auth_types = set(auth_types_for_event(fake_event)) | |
554 | ||
555 | auth_events = [] | |
556 | for key in auth_types: | |
557 | if key in state_before: | |
558 | auth_events.append(state_before[key]) | |
559 | ||
560 | event = fake_event.to_event(auth_events, prev_events) | |
561 | ||
562 | state_at_event[node_id] = state_after | |
563 | event_map[event_id] = event | |
564 | ||
565 | expected_state = {} | |
566 | for node_id in expected_state_ids: | |
567 | # expected_state_ids are node IDs rather than event IDs, | |
568 | # so we have to convert | |
569 | event_id = EventID(node_id, "example.com").to_string() | |
570 | event = event_map[event_id] | |
571 | ||
572 | key = (event.type, event.state_key) | |
573 | ||
574 | expected_state[key] = event_id | |
575 | ||
576 | start_state = state_at_event["START"] | |
577 | end_state = { | |
578 | key: value | |
579 | for key, value in state_at_event["END"].items() | |
580 | if key in expected_state or start_state.get(key) != value | |
581 | } | |
582 | ||
583 | self.assertEqual(expected_state, end_state) | |
584 | ||
585 | ||
586 | class LexicographicalTestCase(unittest.TestCase): | |
587 | def test_simple(self): | |
588 | graph = { | |
589 | "l": {"o"}, | |
590 | "m": {"n", "o"}, | |
591 | "n": {"o"}, | |
592 | "o": set(), | |
593 | "p": {"o"}, | |
594 | } | |
595 | ||
596 | res = list(lexicographical_topological_sort(graph, key=lambda x: x)) | |
597 | ||
598 | self.assertEqual(["o", "l", "n", "m", "p"], res) | |
599 | ||
600 | ||
601 | def pairwise(iterable): | |
602 | "s -> (s0,s1), (s1,s2), (s2, s3), ..." | |
603 | a, b = itertools.tee(iterable) | |
604 | next(b, None) | |
605 | return zip(a, b) | |
606 | ||
607 | ||
608 | @attr.s | |
609 | class TestStateResolutionStore(object): | |
610 | event_map = attr.ib() | |
611 | ||
612 | def get_events(self, event_ids, allow_rejected=False): | |
613 | """Get events from the database | |
614 | ||
615 | Args: | |
616 | event_ids (list): The event_ids of the events to fetch | |
617 | allow_rejected (bool): If True return rejected events. | |
618 | ||
619 | Returns: | |
620 | Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. | |
621 | """ | |
622 | ||
623 | return { | |
624 | eid: self.event_map[eid] | |
625 | for eid in event_ids | |
626 | if eid in self.event_map | |
627 | } | |
628 | ||
629 | def get_auth_chain(self, event_ids): | |
630 | """Gets the full auth chain for a set of events (including rejected | |
631 | events). | |
632 | ||
633 | Includes the given event IDs in the result. | |
634 | ||
635 | Note that: | |
636 | 1. All events must be state events. | |
637 | 2. For v1 rooms this may not have the full auth chain in the | |
638 | presence of rejected events | |
639 | ||
640 | Args: | |
641 | event_ids (list): The event IDs of the events to fetch the auth | |
642 | chain for. Must be state events. | |
643 | ||
644 | Returns: | |
645 | Deferred[list[str]]: List of event IDs of the auth chain. | |
646 | """ | |
647 | ||
648 | # Simple DFS for auth chain | |
649 | result = set() | |
650 | stack = list(event_ids) | |
651 | while stack: | |
652 | event_id = stack.pop() | |
653 | if event_id in result: | |
654 | continue | |
655 | ||
656 | result.add(event_id) | |
657 | ||
658 | event = self.event_map[event_id] | |
659 | for aid, _ in event.auth_events: | |
660 | stack.append(aid) | |
661 | ||
662 | return list(result) |
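do_check() builds the event DAG from "edge chains": each list in edges (and INITIAL_EDGES) is consumed pairwise, so ["END", "MB", "MA", "PA", "START"] contributes the edges END->MB, MB->MA, MA->PA and PA->START, each pointing from an event to one of its prev_events. A short illustration using the pairwise() helper defined in this file:

chain = ["END", "MB", "MA", "PA", "START"]
graph = {node: set() for node in chain}
for a, b in pairwise(chain):
    graph[a].add(b)  # b becomes one of a's prev_events

assert graph["END"] == {"MB"}
assert graph["PA"] == {"START"}
assert graph["START"] == set()  # START has no prev_events in this chain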
51 | 51 | now = int(self.hs.get_clock().time_msec()) |
52 | 52 | self.store.user_add_threepid(user1, "email", user1_email, now, now) |
53 | 53 | self.store.user_add_threepid(user2, "email", user2_email, now, now) |
54 | self.store.initialise_reserved_users(threepids) | |
54 | ||
55 | self.store.runInteraction( | |
56 | "initialise", self.store._initialise_reserved_users, threepids | |
57 | ) | |
55 | 58 | self.pump() |
56 | 59 | |
57 | 60 | active_count = self.store.get_monthly_active_count() |
198 | 201 | {'medium': 'email', 'address': user2_email}, |
199 | 202 | ] |
200 | 203 | self.hs.config.mau_limits_reserved_threepids = threepids |
201 | self.store.initialise_reserved_users(threepids) | |
204 | self.store.runInteraction( | |
205 | "initialise", self.store._initialise_reserved_users, threepids | |
206 | ) | |
207 | ||
202 | 208 | self.pump() |
203 | 209 | count = self.store.get_registered_reserved_users_count() |
204 | 210 | self.assertEquals(self.get_success(count), 0) |
17 | 17 | from twisted.internet import defer |
18 | 18 | |
19 | 19 | from synapse.api.constants import EventTypes, Membership |
20 | from synapse.storage.state import StateFilter | |
20 | 21 | from synapse.types import RoomID, UserID |
21 | 22 | |
22 | 23 | import tests.unittest |
147 | 148 | |
148 | 149 | # check we get the full state as of the final event |
149 | 150 | state = yield self.store.get_state_for_event( |
150 | e5.event_id, None, filtered_types=None | |
151 | e5.event_id, | |
151 | 152 | ) |
152 | 153 | |
153 | 154 | self.assertIsNotNone(e4) |
165 | 166 | |
166 | 167 | # check we can filter to the m.room.name event (with a '' state key) |
167 | 168 | state = yield self.store.get_state_for_event( |
168 | e5.event_id, [(EventTypes.Name, '')], filtered_types=None | |
169 | e5.event_id, StateFilter.from_types([(EventTypes.Name, '')]) | |
169 | 170 | ) |
170 | 171 | |
171 | 172 | self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) |
172 | 173 | |
173 | 174 | # check we can filter to the m.room.name event (with a wildcard None state key) |
174 | 175 | state = yield self.store.get_state_for_event( |
175 | e5.event_id, [(EventTypes.Name, None)], filtered_types=None | |
176 | e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) | |
176 | 177 | ) |
177 | 178 | |
178 | 179 | self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) |
179 | 180 | |
180 | 181 | # check we can grab the m.room.member events (with a wildcard None state key) |
181 | 182 | state = yield self.store.get_state_for_event( |
182 | e5.event_id, [(EventTypes.Member, None)], filtered_types=None | |
183 | e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) | |
183 | 184 | ) |
184 | 185 | |
185 | 186 | self.assertStateMapEqual( |
186 | 187 | {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state |
187 | 188 | ) |
188 | 189 | |
189 | # check we can use filtered_types to grab a specific room member | |
190 | # without filtering out the other event types | |
190 | # check we can grab a specific room member without filtering out the | |
191 | # other event types | |
191 | 192 | state = yield self.store.get_state_for_event( |
192 | 193 | e5.event_id, |
193 | [(EventTypes.Member, self.u_alice.to_string())], | |
194 | filtered_types=[EventTypes.Member], | |
194 | state_filter=StateFilter( | |
195 | types={EventTypes.Member: {self.u_alice.to_string()}}, | |
196 | include_others=True, | |
197 | ) | |
195 | 198 | ) |
196 | 199 | |
197 | 200 | self.assertStateMapEqual( |
203 | 206 | state, |
204 | 207 | ) |
205 | 208 | |
206 | # check that types=[], filtered_types=[EventTypes.Member] | |
207 | # doesn't return all members | |
208 | state = yield self.store.get_state_for_event( | |
209 | e5.event_id, [], filtered_types=[EventTypes.Member] | |
209 | # check that we can grab everything except members | |
210 | state = yield self.store.get_state_for_event( | |
211 | e5.event_id, state_filter=StateFilter( | |
212 | types={EventTypes.Member: set()}, | |
213 | include_others=True, | |
214 | ), | |
210 | 215 | ) |
211 | 216 | |
212 | 217 | self.assertStateMapEqual( |
214 | 219 | ) |
215 | 220 | |
216 | 221 | ####################################################### |
217 | # _get_some_state_from_cache tests against a full cache | |
222 | # _get_state_for_group_using_cache tests against a full cache | |
218 | 223 | ####################################################### |
219 | 224 | |
220 | 225 | room_id = self.room.to_string() |
221 | 226 | group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) |
222 | 227 | group = list(group_ids.keys())[0] |
223 | 228 | |
224 | # test _get_some_state_from_cache correctly filters out members with types=[] | |
225 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
226 | self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] | |
229 | # test _get_state_for_group_using_cache correctly filters out members | |
230 | # with types=[] | |
231 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
232 | self.store._state_group_cache, group, | |
233 | state_filter=StateFilter( | |
234 | types={EventTypes.Member: set()}, | |
235 | include_others=True, | |
236 | ), | |
227 | 237 | ) |
228 | 238 | |
229 | 239 | self.assertEqual(is_all, True) |
235 | 245 | state_dict, |
236 | 246 | ) |
237 | 247 | |
238 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
239 | self.store._state_group_members_cache, | |
240 | group, | |
241 | [], | |
242 | filtered_types=[EventTypes.Member], | |
248 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
249 | self.store._state_group_members_cache, | |
250 | group, | |
251 | state_filter=StateFilter( | |
252 | types={EventTypes.Member: set()}, | |
253 | include_others=True, | |
254 | ), | |
243 | 255 | ) |
244 | 256 | |
245 | 257 | self.assertEqual(is_all, True) |
246 | 258 | self.assertDictEqual({}, state_dict) |
247 | 259 | |
248 | # test _get_some_state_from_cache correctly filters in members with wildcard types | |
249 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
260 | # test _get_state_for_group_using_cache correctly filters in members | |
261 | # with wildcard types | |
262 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
250 | 263 | self.store._state_group_cache, |
251 | 264 | group, |
252 | [(EventTypes.Member, None)], | |
253 | filtered_types=[EventTypes.Member], | |
265 | state_filter=StateFilter( | |
266 | types={EventTypes.Member: None}, | |
267 | include_others=True, | |
268 | ), | |
254 | 269 | ) |
255 | 270 | |
256 | 271 | self.assertEqual(is_all, True) |
262 | 277 | state_dict, |
263 | 278 | ) |
264 | 279 | |
265 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
266 | self.store._state_group_members_cache, | |
267 | group, | |
268 | [(EventTypes.Member, None)], | |
269 | filtered_types=[EventTypes.Member], | |
280 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
281 | self.store._state_group_members_cache, | |
282 | group, | |
283 | state_filter=StateFilter( | |
284 | types={EventTypes.Member: None}, | |
285 | include_others=True, | |
286 | ), | |
270 | 287 | ) |
271 | 288 | |
272 | 289 | self.assertEqual(is_all, True) |
279 | 296 | state_dict, |
280 | 297 | ) |
281 | 298 | |
282 | # test _get_some_state_from_cache correctly filters in members with specific types | |
283 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
299 | # test _get_state_for_group_using_cache correctly filters in members | |
300 | # with specific types | |
301 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
284 | 302 | self.store._state_group_cache, |
285 | 303 | group, |
286 | [(EventTypes.Member, e5.state_key)], | |
287 | filtered_types=[EventTypes.Member], | |
304 | state_filter=StateFilter( | |
305 | types={EventTypes.Member: {e5.state_key}}, | |
306 | include_others=True, | |
307 | ), | |
288 | 308 | ) |
289 | 309 | |
290 | 310 | self.assertEqual(is_all, True) |
296 | 316 | state_dict, |
297 | 317 | ) |
298 | 318 | |
299 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
300 | self.store._state_group_members_cache, | |
301 | group, | |
302 | [(EventTypes.Member, e5.state_key)], | |
303 | filtered_types=[EventTypes.Member], | |
319 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
320 | self.store._state_group_members_cache, | |
321 | group, | |
322 | state_filter=StateFilter( | |
323 | types={EventTypes.Member: {e5.state_key}}, | |
324 | include_others=True, | |
325 | ), | |
304 | 326 | ) |
305 | 327 | |
306 | 328 | self.assertEqual(is_all, True) |
307 | 329 | self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) |
308 | 330 | |
309 | # test _get_some_state_from_cache correctly filters in members with specific types | |
310 | # and no filtered_types | |
311 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
312 | self.store._state_group_members_cache, | |
313 | group, | |
314 | [(EventTypes.Member, e5.state_key)], | |
315 | filtered_types=None, | |
331 | # test _get_state_for_group_using_cache correctly filters in members | |
332 | # with specific types | |
333 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
334 | self.store._state_group_members_cache, | |
335 | group, | |
336 | state_filter=StateFilter( | |
337 | types={EventTypes.Member: {e5.state_key}}, | |
338 | include_others=False, | |
339 | ), | |
316 | 340 | ) |
317 | 341 | |
318 | 342 | self.assertEqual(is_all, True) |
356 | 380 | ############################################ |
357 | 381 | # test that things work with a partial cache |
358 | 382 | |
359 | # test _get_some_state_from_cache correctly filters out members with types=[] | |
383 | # test _get_state_for_group_using_cache correctly filters out members | |
384 | # with types=[] | |
360 | 385 | room_id = self.room.to_string() |
361 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
362 | self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] | |
386 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
387 | self.store._state_group_cache, group, | |
388 | state_filter=StateFilter( | |
389 | types={EventTypes.Member: set()}, | |
390 | include_others=True, | |
391 | ), | |
363 | 392 | ) |
364 | 393 | |
365 | 394 | self.assertEqual(is_all, False) |
366 | 395 | self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) |
367 | 396 | |
368 | 397 | room_id = self.room.to_string() |
369 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
370 | self.store._state_group_members_cache, | |
371 | group, | |
372 | [], | |
373 | filtered_types=[EventTypes.Member], | |
398 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
399 | self.store._state_group_members_cache, | |
400 | group, | |
401 | state_filter=StateFilter( | |
402 | types={EventTypes.Member: set()}, | |
403 | include_others=True, | |
404 | ), | |
374 | 405 | ) |
375 | 406 | |
376 | 407 | self.assertEqual(is_all, True) |
377 | 408 | self.assertDictEqual({}, state_dict) |
378 | 409 | |
379 | # test _get_some_state_from_cache correctly filters in members wildcard types | |
380 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
410 | # test _get_state_for_group_using_cache correctly filters in members | |
411 | # with wildcard types | 
412 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
381 | 413 | self.store._state_group_cache, |
382 | 414 | group, |
383 | [(EventTypes.Member, None)], | |
384 | filtered_types=[EventTypes.Member], | |
415 | state_filter=StateFilter( | |
416 | types={EventTypes.Member: None}, | |
417 | include_others=True, | |
418 | ), | |
385 | 419 | ) |
386 | 420 | |
387 | 421 | self.assertEqual(is_all, False) |
388 | 422 | self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) |
389 | 423 | |
390 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
391 | self.store._state_group_members_cache, | |
392 | group, | |
393 | [(EventTypes.Member, None)], | |
394 | filtered_types=[EventTypes.Member], | |
424 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
425 | self.store._state_group_members_cache, | |
426 | group, | |
427 | state_filter=StateFilter( | |
428 | types={EventTypes.Member: None}, | |
429 | include_others=True, | |
430 | ), | |
395 | 431 | ) |
396 | 432 | |
397 | 433 | self.assertEqual(is_all, True) |
403 | 439 | state_dict, |
404 | 440 | ) |
405 | 441 | |
406 | # test _get_some_state_from_cache correctly filters in members with specific types | |
407 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
442 | # test _get_state_for_group_using_cache correctly filters in members | |
443 | # with specific types | |
444 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
408 | 445 | self.store._state_group_cache, |
409 | 446 | group, |
410 | [(EventTypes.Member, e5.state_key)], | |
411 | filtered_types=[EventTypes.Member], | |
447 | state_filter=StateFilter( | |
448 | types={EventTypes.Member: {e5.state_key}}, | |
449 | include_others=True, | |
450 | ), | |
412 | 451 | ) |
413 | 452 | |
414 | 453 | self.assertEqual(is_all, False) |
415 | 454 | self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) |
416 | 455 | |
417 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
418 | self.store._state_group_members_cache, | |
419 | group, | |
420 | [(EventTypes.Member, e5.state_key)], | |
421 | filtered_types=[EventTypes.Member], | |
456 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
457 | self.store._state_group_members_cache, | |
458 | group, | |
459 | state_filter=StateFilter( | |
460 | types={EventTypes.Member: {e5.state_key}}, | |
461 | include_others=True, | |
462 | ), | |
422 | 463 | ) |
423 | 464 | |
424 | 465 | self.assertEqual(is_all, True) |
425 | 466 | self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) |
426 | 467 | |
427 | # test _get_some_state_from_cache correctly filters in members with specific types | |
428 | # and no filtered_types | |
429 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
468 | # test _get_state_for_group_using_cache correctly filters in members | |
469 | # with specific types | |
470 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
430 | 471 | self.store._state_group_cache, |
431 | 472 | group, |
432 | [(EventTypes.Member, e5.state_key)], | |
433 | filtered_types=None, | |
473 | state_filter=StateFilter( | |
474 | types={EventTypes.Member: {e5.state_key}}, | |
475 | include_others=False, | |
476 | ), | |
434 | 477 | ) |
435 | 478 | |
436 | 479 | self.assertEqual(is_all, False) |
437 | 480 | self.assertDictEqual({}, state_dict) |
438 | 481 | |
439 | (state_dict, is_all) = yield self.store._get_some_state_from_cache( | |
440 | self.store._state_group_members_cache, | |
441 | group, | |
442 | [(EventTypes.Member, e5.state_key)], | |
443 | filtered_types=None, | |
482 | (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( | |
483 | self.store._state_group_members_cache, | |
484 | group, | |
485 | state_filter=StateFilter( | |
486 | types={EventTypes.Member: {e5.state_key}}, | |
487 | include_others=False, | |
488 | ), | |
444 | 489 | ) |
445 | 490 | |
446 | 491 | self.assertEqual(is_all, True) |
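Every hunk in this file swaps the old (types, filtered_types) arguments for a StateFilter. Read alongside the updated comments, the translation used above is as follows (a sketch keyed off the hunks themselves, with a placeholder state key standing in for e5.state_key / self.u_alice.to_string()):

from synapse.api.constants import EventTypes
from synapse.storage.state import StateFilter

# old: types=[], filtered_types=[EventTypes.Member]  ("everything except members")
everything_but_members = StateFilter(
    types={EventTypes.Member: set()}, include_others=True,
)
# old: types=[(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
all_members_and_everything_else = StateFilter(
    types={EventTypes.Member: None}, include_others=True,
)
# old: types=[(EventTypes.Member, key)], filtered_types=[EventTypes.Member]
one_member_and_everything_else = StateFilter(
    types={EventTypes.Member: {"@alice:test"}}, include_others=True,
)
# old: types=[(EventTypes.Member, key)], filtered_types=None
one_member_only = StateFilter(
    types={EventTypes.Member: {"@alice:test"}}, include_others=False,
)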
123 | 123 | config.user_consent_server_notice_content = None |
124 | 124 | config.block_events_without_consent_error = None |
125 | 125 | config.media_storage_providers = [] |
126 | config.autocreate_auto_join_rooms = True | |
126 | 127 | config.auto_join_rooms = [] |
127 | 128 | config.limit_usage_by_mau = False |
128 | 129 | config.hs_disabled = False |
2 | 2 | |
3 | 3 | [base] |
4 | 4 | deps = |
5 | coverage | |
6 | 5 | Twisted>=17.1 |
7 | 6 | mock |
8 | 7 | python-subunit |
25 | 24 | |
26 | 25 | commands = |
27 | 26 | /usr/bin/find "{toxinidir}" -name '*.pyc' -delete |
28 | coverage run {env:COVERAGE_OPTS:} --source="{toxinidir}/synapse" \ | |
29 | "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} | |
30 | {env:DUMP_COVERAGE_COMMAND:coverage report -m} | |
27 | "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} | |
31 | 28 | |
32 | 29 | [testenv:py27] |
33 | 30 | |
107 | 104 | |
108 | 105 | [testenv:pep8] |
109 | 106 | skip_install = True |
110 | basepython = python2.7 | |
107 | basepython = python3.6 | |
111 | 108 | deps = |
112 | 109 | flake8 |
113 | commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}" | |
110 | commands = /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}" | |
114 | 111 | |
115 | 112 | [testenv:check_isort] |
116 | 113 | skip_install = True |